# services/training/app/services/data_client.py
"""
Training Service Data Client

Migrated to use the shared service clients with timeout configuration.
"""

from typing import Dict, Any, List, Optional

import httpx
import structlog

# Import the shared clients
from shared.clients import get_sales_client, get_external_client
from app.core.config import settings
from app.core import constants as const
from app.utils.circuit_breaker import circuit_breaker_registry, CircuitBreakerError
from app.utils.retry import with_retry

logger = structlog.get_logger()


class DataClient:
    """
    Data client for the training service.

    Uses the shared data service clients under the hood.
    """

    def __init__(self):
        # Get the specialized clients with timeout configuration
        self.sales_client = get_sales_client(settings, "training")
        self.external_client = get_external_client(settings, "training")

        # ExternalServiceClient always has a get_stored_traffic_data_for_training method
        self.supports_stored_traffic_data = True

        # Configure timeouts for HTTP clients
        self._configure_timeouts()

        # Initialize circuit breakers for external services
        self._init_circuit_breakers()

    def _configure_timeouts(self):
        """Configure appropriate timeouts for the underlying HTTP clients."""
        timeout = httpx.Timeout(
            connect=const.HTTP_TIMEOUT_DEFAULT,
            read=const.HTTP_TIMEOUT_LONG_RUNNING,
            write=const.HTTP_TIMEOUT_DEFAULT,
            pool=const.HTTP_TIMEOUT_DEFAULT,
        )

        # Apply the timeout where a client exposes its httpx client.
        # Note: BaseServiceClient manages its own HTTP client internally.
        if hasattr(self.sales_client, 'client') and isinstance(self.sales_client.client, httpx.AsyncClient):
            self.sales_client.client.timeout = timeout

        if hasattr(self.external_client, 'client') and isinstance(self.external_client.client, httpx.AsyncClient):
            self.external_client.client.timeout = timeout

    def _init_circuit_breakers(self):
        """Initialize circuit breakers for external service calls."""
        # Sales service circuit breaker
        self.sales_cb = circuit_breaker_registry.get_or_create(
            name="sales_service",
            failure_threshold=5,
            recovery_timeout=60.0,
            expected_exception=Exception,
        )

        # Weather service circuit breaker
        self.weather_cb = circuit_breaker_registry.get_or_create(
            name="weather_service",
            failure_threshold=3,  # Weather is optional, so fail faster
            recovery_timeout=30.0,
            expected_exception=Exception,
        )

        # Traffic service circuit breaker
        self.traffic_cb = circuit_breaker_registry.get_or_create(
            name="traffic_service",
            failure_threshold=3,  # Traffic is optional, so fail faster
            recovery_timeout=30.0,
            expected_exception=Exception,
        )
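
    # Note: the breakers above assume the interface used throughout this
    # module -- `await breaker.call(func, *args)` raises CircuitBreakerError
    # once failure_threshold consecutive failures occur, then probes again
    # after recovery_timeout seconds. Because _fetch_sales_data_internal is
    # also wrapped in @with_retry(max_attempts=3), a single breaker "failure"
    # may represent up to three underlying attempts.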

    @with_retry(max_attempts=3, initial_delay=1.0, max_delay=10.0)
    async def _fetch_sales_data_internal(
        self,
        tenant_id: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        product_id: Optional[str] = None,
        fetch_all: bool = True
    ) -> List[Dict[str, Any]]:
        """Fetch sales data, retrying automatically on transient failures."""
        if fetch_all:
            sales_data = await self.sales_client.get_all_sales_data(
                tenant_id=tenant_id,
                start_date=start_date,
                end_date=end_date,
                product_id=product_id,
                aggregation="daily",
                page_size=1000,
                max_pages=100
            )
        else:
            sales_data = await self.sales_client.get_sales_data(
                tenant_id=tenant_id,
                start_date=start_date,
                end_date=end_date,
                product_id=product_id,
                aggregation="daily"
            )
        sales_data = sales_data or []

        if not sales_data:
            logger.error("No sales data returned", tenant_id=tenant_id)
            raise ValueError(f"No sales data available for tenant {tenant_id}")

        logger.info("Fetched sales records",
                    record_count=len(sales_data),
                    tenant_id=tenant_id,
                    product_id=product_id,
                    fetch_all=fetch_all)
        return sales_data

    async def fetch_sales_data(
        self,
        tenant_id: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        product_id: Optional[str] = None,
        fetch_all: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Fetch sales data for training, protected by the sales circuit breaker.
        """
        try:
            return await self.sales_cb.call(
                self._fetch_sales_data_internal,
                tenant_id, start_date, end_date, product_id, fetch_all
            )
        except CircuitBreakerError as exc:
            logger.error("Sales service circuit breaker open", error_message=str(exc))
            raise RuntimeError(f"Sales service unavailable: {exc}") from exc
        except ValueError:
            # "No data" is a caller-visible condition, not a service failure
            raise
        except Exception as exc:
            logger.error("Error fetching sales data", tenant_id=tenant_id, error_message=str(exc))
            raise RuntimeError(f"Failed to fetch sales data: {exc}") from exc

    async def fetch_weather_data(
        self,
        tenant_id: str,
        start_date: str,
        end_date: str,
        latitude: Optional[float] = None,
        longitude: Optional[float] = None
    ) -> List[Dict[str, Any]]:
        """
        Fetch weather data for training.

        Error handling and retry logic live in the base client.
        """
        try:
            weather_data = await self.external_client.get_weather_historical(
                tenant_id=tenant_id,
                start_date=start_date,
                end_date=end_date,
                latitude=latitude,
                longitude=longitude
            )

            if weather_data:
                logger.info("Fetched weather records",
                            record_count=len(weather_data), tenant_id=tenant_id)
                return weather_data

            logger.warning("No weather data returned, will use synthetic data", tenant_id=tenant_id)
            return []

        except Exception as exc:
            logger.warning("Error fetching weather data, will use synthetic data",
                           tenant_id=tenant_id, error_message=str(exc))
            return []
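
    # Design note: weather (like traffic below) is an optional signal, so this
    # method degrades gracefully -- any failure yields an empty list, and the
    # caller is expected to substitute synthetic data (see the log messages
    # above) rather than abort training.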

    async def fetch_traffic_data_unified(
        self,
        tenant_id: str,
        start_date: str,
        end_date: str,
        latitude: Optional[float] = None,
        longitude: Optional[float] = None,
        force_refresh: bool = False
    ) -> List[Dict[str, Any]]:
        """
        Unified traffic data fetching with a cache-first strategy.

        Strategy:
        1. Check whether stored/cached traffic data exists for the date range.
        2. If it exists and force_refresh is False, return the cached data.
        3. Otherwise, fetch fresh data.
        4. Either way, the data is fetched at most once per call.

        Args:
            tenant_id: Tenant identifier
            start_date: Start date string (ISO format)
            end_date: End date string (ISO format)
            latitude: Optional latitude for location-based data
            longitude: Optional longitude for location-based data
            force_refresh: If True, bypass the cache and fetch fresh data
        """
        cache_key = f"{tenant_id}_{start_date}_{end_date}_{latitude}_{longitude}"

        try:
            # Step 1: Try stored/cached data first (unless force_refresh)
            if not force_refresh and self.supports_stored_traffic_data:
                logger.info("Attempting to fetch cached traffic data",
                            tenant_id=tenant_id, cache_key=cache_key)

                try:
                    cached_data = await self.external_client.get_stored_traffic_data_for_training(
                        tenant_id=tenant_id,
                        start_date=start_date,
                        end_date=end_date,
                        latitude=latitude,
                        longitude=longitude
                    )

                    if cached_data:
                        logger.info("✅ Using cached traffic data",
                                    record_count=len(cached_data), tenant_id=tenant_id)
                        return cached_data

                    logger.info("No cached traffic data found, fetching fresh data",
                                tenant_id=tenant_id)
                except Exception as cache_error:
                    logger.warning("Cache fetch failed, falling back to fresh data",
                                   tenant_id=tenant_id, error_message=str(cache_error))

            # Step 2: Fetch fresh data when there is no cache hit or force_refresh is set
            logger.info("Fetching fresh traffic data" + (" (force refresh)" if force_refresh else ""),
                        tenant_id=tenant_id)

            fresh_data = await self.external_client.get_traffic_data(
                tenant_id=tenant_id,
                start_date=start_date,
                end_date=end_date,
                latitude=latitude,
                longitude=longitude
            )

            if fresh_data:
                logger.info("✅ Fetched fresh traffic data",
                            record_count=len(fresh_data), tenant_id=tenant_id)
                return fresh_data

            logger.warning("No fresh traffic data available", tenant_id=tenant_id)
            return []

        except Exception as exc:
            logger.error("Error in unified traffic data fetch",
                         tenant_id=tenant_id, cache_key=cache_key, error_message=str(exc))
            return []
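
    # Illustrative use of the unified fetcher (values are examples only):
    #
    #     traffic = await data_client.fetch_traffic_data_unified(
    #         tenant_id="tenant-1",
    #         start_date="2024-01-01",
    #         end_date="2024-03-31",
    #         force_refresh=False,  # cache-first; True bypasses the cache
    #     )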

    # Legacy methods for backward compatibility - these delegate to the unified method
    async def fetch_traffic_data(
        self,
        tenant_id: str,
        start_date: str,
        end_date: str,
        latitude: Optional[float] = None,
        longitude: Optional[float] = None
    ) -> List[Dict[str, Any]]:
        """Legacy method - delegates to the unified fetcher (cache-first)."""
        logger.info("Legacy fetch_traffic_data called - delegating to unified method", tenant_id=tenant_id)
        return await self.fetch_traffic_data_unified(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            latitude=latitude,
            longitude=longitude,
            force_refresh=False  # Use cache-first for legacy calls
        )

    async def fetch_stored_traffic_data_for_training(
        self,
        tenant_id: str,
        start_date: str,
        end_date: str,
        latitude: Optional[float] = None,
        longitude: Optional[float] = None
    ) -> List[Dict[str, Any]]:
        """Legacy method - delegates to the unified fetcher (cache-first)."""
        logger.info("Legacy fetch_stored_traffic_data_for_training called - delegating to unified method",
                    tenant_id=tenant_id)
        return await self.fetch_traffic_data_unified(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            latitude=latitude,
            longitude=longitude,
            force_refresh=False  # Use cache-first for training calls
        )

    async def refresh_traffic_data(
        self,
        tenant_id: str,
        start_date: str,
        end_date: str,
        latitude: Optional[float] = None,
        longitude: Optional[float] = None
    ) -> List[Dict[str, Any]]:
        """Convenience method to force-refresh traffic data."""
        logger.info("Force refreshing traffic data (bypassing cache)", tenant_id=tenant_id)
        return await self.fetch_traffic_data_unified(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            latitude=latitude,
            longitude=longitude,
            force_refresh=True  # Force fresh data
        )

    async def validate_data_quality(
        self,
        tenant_id: str,
        start_date: str,
        end_date: str,
        sales_data: Optional[List[Dict[str, Any]]] = None
    ) -> Dict[str, Any]:
        """
        Validate data quality before training with comprehensive checks.
        """
        try:
            errors = []
            warnings = []

            # If sales data was provided, validate it directly
            if sales_data is not None:
                if not sales_data:
                    errors.append("No sales data available for the specified period")
                    return {"is_valid": False, "errors": errors, "warnings": warnings}

                # Check minimum data points
                if len(sales_data) < 30:
                    errors.append(f"Insufficient data points: {len(sales_data)} (minimum 30 required)")
                elif len(sales_data) < 90:
                    warnings.append(f"Limited data points: {len(sales_data)} (recommended 90+)")

                # Check for required fields (sample the first few records)
                required_fields = ['date', 'inventory_product_id']
                for record in sales_data[:5]:
                    missing = [f for f in required_fields if f not in record or record[f] is None]
                    if missing:
                        errors.append(f"Missing required fields: {missing}")
                        break

                # Check for data quality issues
                zero_count = sum(1 for r in sales_data if r.get('quantity', 0) == 0)
                zero_ratio = zero_count / len(sales_data)
                if zero_ratio > 0.9:
                    errors.append(f"Too many zero values: {zero_ratio:.1%} of records")
                elif zero_ratio > 0.7:
                    warnings.append(f"High zero value ratio: {zero_ratio:.1%}")

                # Check product diversity
                unique_products = {r.get('inventory_product_id') for r in sales_data if r.get('inventory_product_id')}
                if len(unique_products) == 0:
                    errors.append("No valid product IDs found in sales data")
                elif len(unique_products) == 1:
                    warnings.append("Only one product found - consider adding more products")

            else:
                # No data supplied: fetch it for validation
                sales_data = await self.fetch_sales_data(
                    tenant_id=tenant_id,
                    start_date=start_date,
                    end_date=end_date,
                    fetch_all=False
                )

                if not sales_data:
                    errors.append("Unable to fetch sales data for validation")
                    return {"is_valid": False, "errors": errors, "warnings": warnings}

                # Recursive call with the fetched data
                return await self.validate_data_quality(
                    tenant_id, start_date, end_date, sales_data
                )

            is_valid = len(errors) == 0
            result = {
                "is_valid": is_valid,
                "errors": errors,
                "warnings": warnings,
                # sales_data is non-empty on every path that reaches this point
                "data_points": len(sales_data),
                "unique_products": len(unique_products),
            }

            if is_valid:
                logger.info("Data validation passed",
                            tenant_id=tenant_id,
                            data_points=result["data_points"],
                            warnings_count=len(warnings))
            else:
                logger.error("Data validation failed",
                             tenant_id=tenant_id,
                             errors=errors)

            return result

        except Exception as exc:
            logger.error("Error validating data", tenant_id=tenant_id, error_message=str(exc))
            raise ValueError(f"Data validation failed: {exc}") from exc


# Global instance - same interface as before, with a much simpler implementation
data_client = DataClient()
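
# Illustrative end-to-end use from an async caller (identifiers such as
# "tenant-1" and the date values are examples only):
#
#     report = await data_client.validate_data_quality(
#         tenant_id="tenant-1", start_date="2024-01-01", end_date="2024-03-31"
#     )
#     if report["is_valid"]:
#         sales = await data_client.fetch_sales_data("tenant-1", fetch_all=True)
#         weather = await data_client.fetch_weather_data(
#             "tenant-1", "2024-01-01", "2024-03-31"
#         )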