diff --git a/services/forecasting/app/api/forecasts.py b/services/forecasting/app/api/forecasts.py
index 1a9b6ca9..81d63ab3 100644
--- a/services/forecasting/app/api/forecasts.py
+++ b/services/forecasting/app/api/forecasts.py
@@ -47,7 +47,7 @@ async def create_single_forecast(
     )
     
     # Generate forecast
-    forecast = await forecasting_service.generate_forecast(request, db)
+    forecast = await forecasting_service.generate_forecast(tenant_id, request, db)
     
     # Convert to response model
     return ForecastResponse(
diff --git a/services/forecasting/app/schemas/forecasts.py b/services/forecasting/app/schemas/forecasts.py
index b9435ecf..ac274426 100644
--- a/services/forecasting/app/schemas/forecasts.py
+++ b/services/forecasting/app/schemas/forecasts.py
@@ -24,14 +24,7 @@ class ForecastRequest(BaseModel):
     """Request schema for generating forecasts"""
     tenant_id: str = Field(..., description="Tenant ID")
     product_name: str = Field(..., description="Product name")
-    location: str = Field(..., description="Location identifier")
     forecast_date: date = Field(..., description="Date for which to generate forecast")
-    business_type: BusinessType = Field(BusinessType.INDIVIDUAL, description="Business model type")
-    
-    # Optional context
-    include_weather: bool = Field(True, description="Include weather data in forecast")
-    include_traffic: bool = Field(True, description="Include traffic data in forecast")
-    confidence_level: float = Field(0.8, ge=0.5, le=0.95, description="Confidence level for intervals")
     
     @validator('forecast_date')
     def validate_forecast_date(cls, v):
diff --git a/services/forecasting/app/services/data_client.py b/services/forecasting/app/services/data_client.py
new file mode 100644
index 00000000..5dc23c96
--- /dev/null
+++ b/services/forecasting/app/services/data_client.py
@@ -0,0 +1,64 @@
+# services/forecasting/app/services/data_client.py
+"""
+Forecasting Service Data Client
+Migrated to use the shared service clients - much simpler now!
+"""
+
+import structlog
+from typing import Dict, Any, List, Optional
+from datetime import datetime
+
+# Import the shared clients
+from shared.clients import get_data_client, get_service_clients
+from app.core.config import settings
+
+logger = structlog.get_logger()
+
+class DataClient:
+    """
+    Data client for the forecasting service
+    Now uses the shared data service client under the hood
+    """
+    
+    def __init__(self):
+        # Get the shared data client configured for this service
+        self.data_client = get_data_client(settings, "forecasting")
+        
+        # Or alternatively, get all clients at once:
+        # self.clients = get_service_clients(settings, "forecasting")
+        # Then use: self.clients.data.get_sales_data(...)
+    
+    async def fetch_weather_forecast(
+        self,
+        tenant_id: str,
+        days: int,
+        latitude: Optional[float] = None,
+        longitude: Optional[float] = None
+    ) -> List[Dict[str, Any]]:
+        """
+        Fetch weather forecast data
+        All the error handling and retry logic now lives in the shared base client.
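+
+        Example (illustrative; uses the module-level data_client instance
+        defined at the bottom of this file, with a placeholder tenant ID):
+
+            records = await data_client.fetch_weather_forecast("tenant-123", days=7)
+            next_day = records[0] if records else {}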
+        """
+        try:
+            weather_data = await self.data_client.get_weather_forecast(
+                tenant_id=tenant_id,
+                days=days,
+                latitude=latitude,
+                longitude=longitude
+            )
+            
+            if weather_data:
+                logger.info(f"Fetched {len(weather_data)} weather records",
+                           tenant_id=tenant_id)
+                return weather_data
+            else:
+                logger.warning("No weather data returned", tenant_id=tenant_id)
+                return []
+                
+        except Exception as e:
+            logger.error(f"Error fetching weather data: {e}", tenant_id=tenant_id)
+            return []
+
+# Global instance - same as before, but much simpler implementation
+data_client = DataClient()
\ No newline at end of file
diff --git a/services/forecasting/app/services/forecasting_service.py b/services/forecasting/app/services/forecasting_service.py
index 1ce60dd2..eea7145b 100644
--- a/services/forecasting/app/services/forecasting_service.py
+++ b/services/forecasting/app/services/forecasting_service.py
@@ -21,6 +21,8 @@ from app.services.prediction_service import PredictionService
 from app.services.messaging import publish_forecast_completed, publish_alert_created
 from app.core.config import settings
 from shared.monitoring.metrics import MetricsCollector
+from app.services.model_client import ModelClient
+from app.services.data_client import DataClient
 
 logger = structlog.get_logger()
 metrics = MetricsCollector("forecasting-service")
@@ -33,6 +35,8 @@ class ForecastingService:
     
     def __init__(self):
         self.prediction_service = PredictionService()
+        self.model_client = ModelClient()
+        self.data_client = DataClient()
     
-    async def generate_forecast(self, request: ForecastRequest, db: AsyncSession) -> Forecast:
+    async def generate_forecast(self, tenant_id: str, request: ForecastRequest, db: AsyncSession) -> Forecast:
         """Generate a single forecast for a product"""
@@ -47,8 +51,7 @@ class ForecastingService:
             
             # Get the latest trained model for this tenant/product
             model_info = await self._get_latest_model(
                 request.tenant_id,
-                request.product_name,
-                request.location
+                request.product_name,
             )
             
             if not model_info:
@@ -66,10 +69,9 @@ class ForecastingService:
             
             # Create forecast record
             forecast = Forecast(
-                tenant_id=uuid.UUID(request.tenant_id),
-                product_name=request.product_name,
-                location=request.location,
-                forecast_date=datetime.combine(request.forecast_date, datetime.min.time()),
+                tenant_id=uuid.UUID(tenant_id),
+                product_name=request.product_name,
+                forecast_date=datetime.combine(request.forecast_date, datetime.min.time()),
                 
                 # Prediction results
                 predicted_demand=prediction_result["demand"],
@@ -243,27 +245,12 @@ class ForecastingService:
             logger.error("Error retrieving forecasts", error=str(e))
             raise
     
-    async def _get_latest_model(self, tenant_id: str, product_name: str, location: str) -> Optional[Dict[str, Any]]:
+    async def _get_latest_model(self, tenant_id: str, product_name: str) -> Optional[Dict[str, Any]]:
         """Get the latest trained model for a tenant/product combination"""
         try:
-            # Call training service to get model information
-            async with httpx.AsyncClient() as client:
-                response = await client.get(
-                    f"{settings.TRAINING_SERVICE_URL}/tenants/{tenant_id}/models/{product_name}/active",
-                    params={},
-                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
-                )
-                
-                if response.status_code == 200:
-                    return response.json()
-                elif response.status_code == 404:
-                    logger.warning("No model found",
-                                 tenant_id=tenant_id,
-                                 product=product_name)
-                    return None
-                else:
-                    response.raise_for_status()
+            # This lookup lives on the model client (training service), not the data client
+            model_data = await self.model_client.get_best_model_for_forecasting(tenant_id, product_name)
+            return model_data
             
         except Exception as e:
             logger.error("Error getting latest model", error=str(e))
@@ -275,22 +262,15 @@ class ForecastingService:
         
         features = {
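+            # NOTE: these keys must match the feature names the model saw at training time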
"date": request.forecast_date.isoformat(), "day_of_week": request.forecast_date.weekday(), - "is_weekend": request.forecast_date.weekday() >= 5, - "business_type": request.business_type.value + "is_weekend": request.forecast_date.weekday() >= 5 } # Add Spanish holidays features["is_holiday"] = await self._is_spanish_holiday(request.forecast_date) - # Add weather data if requested - if request.include_weather: - weather_data = await self._get_weather_forecast(request.forecast_date) - features.update(weather_data) - - # Add traffic data if requested - if request.include_traffic: - traffic_data = await self._get_traffic_forecast(request.forecast_date, request.location) - features.update(traffic_data) + + weather_data = await self._get_weather_forecast(request.tenant_id, 1) + features.update(weather_data) return features @@ -315,61 +295,16 @@ class ForecastingService: logger.warning("Error checking holiday status", error=str(e)) return False - async def _get_weather_forecast(self, forecast_date: date) -> Dict[str, Any]: + async def _get_weather_forecast(self, tenant_id: str, days: str) -> Dict[str, Any]: """Get weather forecast for the date""" try: - # Call data service for weather forecast - async with httpx.AsyncClient() as client: - response = await client.get( - f"{settings.DATA_SERVICE_URL}/api/v1/weather/forecast", - params={"date": forecast_date.isoformat()}, - headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN} - ) - - if response.status_code == 200: - weather = response.json() - return { - "temperature": weather.get("temperature"), - "precipitation": weather.get("precipitation"), - "humidity": weather.get("humidity"), - "weather_description": weather.get("description") - } - else: - return {} - + weather_data = await self.data_client.fetch_weather_forecast(tenant_id, days) + return weather_data except Exception as e: logger.warning("Error getting weather forecast", error=str(e)) return {} - async def _get_traffic_forecast(self, forecast_date: date, location: str) -> Dict[str, Any]: - """Get traffic forecast for the date and location""" - - try: - # Call data service for traffic forecast - async with httpx.AsyncClient() as client: - response = await client.get( - f"{settings.DATA_SERVICE_URL}/api/v1/traffic/forecast", - params={ - "date": forecast_date.isoformat(), - "location": location - }, - headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN} - ) - - if response.status_code == 200: - traffic = response.json() - return { - "traffic_volume": traffic.get("volume"), - "pedestrian_count": traffic.get("pedestrian_count") - } - else: - return {} - - except Exception as e: - logger.warning("Error getting traffic forecast", error=str(e)) - return {} - async def _check_and_create_alerts(self, forecast: Forecast, db: AsyncSession): """Check forecast and create alerts if needed""" diff --git a/services/forecasting/app/services/model_client.py b/services/forecasting/app/services/model_client.py new file mode 100644 index 00000000..08c4d0b7 --- /dev/null +++ b/services/forecasting/app/services/model_client.py @@ -0,0 +1,183 @@ +# services/forecasting/app/services/model_client.py +""" +Forecast Service Model Client +Demonstrates calling training service to get models +""" + +import structlog +from typing import Dict, Any, List, Optional + +# Import shared clients - no more code duplication! 
+from shared.clients import get_service_clients, get_training_client, get_data_client
+from app.core.config import settings
+
+logger = structlog.get_logger()
+
+class ModelClient:
+    """
+    Client for managing models in the forecasting service
+    Shows how to call multiple services cleanly
+    """
+    
+    def __init__(self):
+        # Option 1: Get all clients at once
+        self.clients = get_service_clients(settings, "forecasting")
+        
+        # Option 2: Get specific clients
+        # self.training_client = get_training_client(settings, "forecasting")
+        # self.data_client = get_data_client(settings, "forecasting")
+    
+    async def get_available_models(
+        self,
+        tenant_id: str,
+        model_type: Optional[str] = None
+    ) -> List[Dict[str, Any]]:
+        """
+        Get available trained models from the training service
+        """
+        try:
+            models = await self.clients.training.list_models(
+                tenant_id=tenant_id,
+                status="deployed",  # Only get deployed models
+                model_type=model_type
+            )
+            
+            if models:
+                logger.info(f"Found {len(models)} available models",
+                           tenant_id=tenant_id, model_type=model_type)
+                return models
+            else:
+                logger.warning("No available models found", tenant_id=tenant_id)
+                return []
+                
+        except Exception as e:
+            logger.error(f"Error fetching available models: {e}", tenant_id=tenant_id)
+            return []
+    
+    async def get_best_model_for_forecasting(
+        self,
+        tenant_id: str,
+        product_id: Optional[str] = None
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Get the best model for forecasting based on performance metrics
+        
+        Note: product_id is accepted for API symmetry but not yet applied as a
+        filter; the training client's latest-model lookup is tenant-wide.
+        """
+        try:
+            # Get latest model
+            latest_model = await self.clients.training.get_latest_model(
+                tenant_id=tenant_id,
+                model_type="forecasting"
+            )
+            
+            if not latest_model:
+                logger.warning("No trained models found", tenant_id=tenant_id)
+                return None
+            
+            # Get model metrics to validate quality
+            metrics = await self.clients.training.get_model_metrics(
+                tenant_id=tenant_id,
+                model_id=latest_model["id"]
+            )
+            
+            if metrics and metrics.get("accuracy", 0) > 0.7:  # 70% accuracy threshold
+                logger.info(f"Selected model {latest_model['id']} with accuracy {metrics.get('accuracy')}",
+                           tenant_id=tenant_id)
+                return latest_model
+            else:
+                # Guard against metrics being None before reading the accuracy value
+                accuracy = metrics.get("accuracy", "unknown") if metrics else "unknown"
+                logger.warning(f"Model accuracy too low: {accuracy}",
+                              tenant_id=tenant_id)
+                return None
+                
+        except Exception as e:
+            logger.error(f"Error selecting best model: {e}", tenant_id=tenant_id)
+            return None
+    
+    async def validate_model_data_compatibility(
+        self,
+        tenant_id: str,
+        model_id: str,
+        forecast_start_date: str,
+        forecast_end_date: str
+    ) -> Dict[str, Any]:
+        """
+        Validate that we have sufficient data for the model to make forecasts
+        Demonstrates calling both training and data services
+        """
+        try:
+            # Get model details from training service
+            model = await self.clients.training.get_model(
+                tenant_id=tenant_id,
+                model_id=model_id
+            )
+            
+            if not model:
+                return {"is_valid": False, "error": "Model not found"}
+            
+            # Get data statistics from data service
+            data_stats = await self.clients.data.get_data_statistics(
+                tenant_id=tenant_id,
+                start_date=forecast_start_date,
+                end_date=forecast_end_date
+            )
+            
+            if not data_stats:
+                return {"is_valid": False, "error": "Could not retrieve data statistics"}
+            
+            # Check if we have minimum required data points
+            min_required = model.get("metadata", {}).get("min_data_points", 30)
+            available_points = data_stats.get("total_records", 0)
+            
+            is_valid = available_points >= min_required
+            
+            result = {
+                "is_valid": is_valid,
+                "model_id": model_id,
+                "required_points": min_required,
+                "available_points": available_points,
"data_coverage": data_stats.get("coverage_percentage", 0) + } + + if not is_valid: + result["error"] = f"Insufficient data: need {min_required}, have {available_points}" + + logger.info("Model data compatibility check completed", + tenant_id=tenant_id, model_id=model_id, is_valid=is_valid) + + return result + + except Exception as e: + logger.error(f"Error validating model compatibility: {e}", + tenant_id=tenant_id, model_id=model_id) + return {"is_valid": False, "error": str(e)} + + async def trigger_model_retraining( + self, + tenant_id: str, + include_weather: bool = True, + include_traffic: bool = False + ) -> Optional[Dict[str, Any]]: + """ + Trigger a new training job if current model is outdated + """ + try: + # Create training job through training service + job = await self.clients.training.create_training_job( + tenant_id=tenant_id, + include_weather=include_weather, + include_traffic=include_traffic, + min_data_points=50 # Higher threshold for forecasting + ) + + if job: + logger.info(f"Training job created: {job['job_id']}", tenant_id=tenant_id) + return job + else: + logger.error("Failed to create training job", tenant_id=tenant_id) + return None + + except Exception as e: + logger.error(f"Error triggering model retraining: {e}", tenant_id=tenant_id) + return None + +# Global instance +model_client = ModelClient() \ No newline at end of file diff --git a/services/training/app/core/service_auth.py b/services/training/app/core/service_auth.py deleted file mode 100644 index e07ad6b4..00000000 --- a/services/training/app/core/service_auth.py +++ /dev/null @@ -1,81 +0,0 @@ -import time -import structlog -from typing import Dict, Any -from shared.auth.jwt_handler import JWTHandler -from app.core.config import settings - -logger = structlog.get_logger() - -class ServiceAuthenticator: - """Handles service-to-service authentication via gateway""" - - def __init__(self): - self.jwt_handler = JWTHandler(settings.JWT_SECRET_KEY) - self._cached_token = None - self._token_expires_at = 0 - - async def get_service_token(self) -> str: - """ - Get a valid service token, using cache when possible - Creates JWT tokens that the gateway will accept - """ - current_time = int(time.time()) - - # Return cached token if still valid (with 5 min buffer) - if (self._cached_token and - self._token_expires_at > current_time + 300): - return self._cached_token - - # Create new service token - token_expires_at = current_time + 3600 # 1 hour - - service_payload = { - # ✅ Required fields for gateway middleware - "sub": "training-service", - "user_id": "training-service", - "email": "training-service@internal", - "type": "access", # ✅ Must be "access" for gateway - - # ✅ Expiration and timing - "exp": token_expires_at, - "iat": current_time, - "iss": "training-service", - - # ✅ Service identification - "service": "training", - "full_name": "Training Service", - "is_verified": True, - "is_active": True, - - # ✅ Optional tenant context (can be overridden per request) - "tenant_id": None - } - - try: - token = self.jwt_handler.create_access_token_from_payload(service_payload) - - # Cache the token - self._cached_token = token - self._token_expires_at = token_expires_at - - logger.debug("Created new service token", expires_at=token_expires_at) - return token - - except Exception as e: - logger.error(f"Failed to create service token: {e}") - raise ValueError(f"Service token creation failed: {e}") - - def get_request_headers(self, tenant_id: str = None) -> Dict[str, str]: - """Get standard headers for service 
requests""" - headers = { - "X-Service": "training-service", - "User-Agent": "training-service/1.0.0" - } - - if tenant_id: - headers["X-Tenant-ID"] = str(tenant_id) - - return headers - -# Global authenticator instance -service_auth = ServiceAuthenticator() \ No newline at end of file diff --git a/services/training/app/services/data_client.py b/services/training/app/services/data_client.py index 1bf29abd..e1ac00af 100644 --- a/services/training/app/services/data_client.py +++ b/services/training/app/services/data_client.py @@ -1,219 +1,89 @@ -import httpx +# services/training/app/services/data_client.py +""" +Training Service Data Client +Migrated to use shared service clients - much simpler now! +""" + import structlog -from typing import List, Dict, Any, Optional +from typing import Dict, Any, List, Optional +from datetime import datetime + +# Import the shared clients +from shared.clients import get_data_client, get_service_clients from app.core.config import settings -from app.core.service_auth import service_auth logger = structlog.get_logger() -class DataServiceClient: - """Client for fetching data through the API Gateway""" +class DataClient: + """ + Data client for training service + Now uses the shared data service client under the hood + """ def __init__(self): - self.base_url = settings.API_GATEWAY_URL - self.timeout = 2000.0 - - async def fetch_sales_data(self, tenant_id: str) -> List[Dict[str, Any]]: - """ - Fetch all sales data for training (no pagination limits) - FIXED: Retrieves ALL records instead of being limited to 1000 - """ - try: - # Get service token - token = await service_auth.get_service_token() - - # Prepare headers - headers = service_auth.get_request_headers(tenant_id) - headers["Authorization"] = f"Bearer {token}" - - all_records = [] - page = 0 - page_size = 5000 # Use maximum allowed by API - - while True: - # Prepare query parameters for pagination - params = { - "limit": page_size, - "offset": page * page_size - } - - logger.info(f"Fetching sales data page {page + 1} (offset: {page * page_size})", - tenant_id=tenant_id) - - # Make GET request via gateway with pagination - async with httpx.AsyncClient(timeout=self.timeout) as client: - response = await client.get( - f"{self.base_url}/api/v1/tenants/{tenant_id}/sales", - headers=headers, - params=params - ) - - if response.status_code == 200: - page_data = response.json() - - # Handle different response formats - if isinstance(page_data, list): - # Direct list response (no pagination metadata) - records = page_data - logger.info(f"Retrieved {len(records)} records from page {page + 1} (direct list)") - - # For direct list responses, we need to check if we got the max possible - # If we got less than page_size, we're done - if len(records) == 0: - logger.info("No records in response, pagination complete") - break - elif len(records) < page_size: - # Got fewer than requested, this is the last page - all_records.extend(records) - logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}") - break - else: - # Got full page, there might be more - all_records.extend(records) - logger.info(f"Full page retrieved: {len(records)} records, continuing to next page") - - elif isinstance(page_data, dict): - # Paginated response format - records = page_data.get('records', page_data.get('data', [])) - total_available = page_data.get('total', 0) - - logger.info(f"Retrieved {len(records)} records from page {page + 1} (paginated response)") - - if not records: - logger.info("No more records found 
in paginated response") - break - - all_records.extend(records) - - # Check if we've got all available records - if len(all_records) >= total_available: - logger.info(f"Retrieved all available records: {len(all_records)}/{total_available}") - break - - else: - logger.warning(f"Unexpected response format: {type(page_data)}") - records = [] - break - - page += 1 - - # Safety break to prevent infinite loops - if page > 100: # Max 500,000 records (100 * 5000) - logger.warning("Reached maximum page limit, stopping pagination") - break - - elif response.status_code == 401: - logger.error("Authentication failed with gateway", - tenant_id=tenant_id, - response_text=response.text) - return [] - - elif response.status_code == 404: - logger.warning("Sales data endpoint not found", - tenant_id=tenant_id, - url=response.url) - return [] - - else: - logger.error(f"Gateway request failed: HTTP {response.status_code}", - tenant_id=tenant_id, - response_text=response.text) - return [] - - logger.info(f"Successfully fetched {len(all_records)} total sales records via gateway", - tenant_id=tenant_id) - return all_records - - except httpx.TimeoutException: - logger.error("Timeout when fetching sales data via gateway", - tenant_id=tenant_id) - return [] - - except Exception as e: - logger.error(f"Error fetching sales data via gateway: {e}", - tenant_id=tenant_id) - return [] + # Get the shared data client configured for this service + self.data_client = get_data_client(settings, "training") - - async def fetch_weather_data( - self, + # Or alternatively, get all clients at once: + # self.clients = get_service_clients(settings, "training") + # Then use: self.clients.data.get_sales_data(...) + + async def fetch_sales_data( + self, tenant_id: str, - start_date: str, - end_date: str, - latitude: Optional[float] = None, - longitude: Optional[float] = None + start_date: Optional[str] = None, + end_date: Optional[str] = None, + product_id: Optional[str] = None, + fetch_all: bool = True ) -> List[Dict[str, Any]]: """ - Fetch historical weather data for training via API Gateway using POST + Fetch sales data for training + + Args: + tenant_id: Tenant identifier + start_date: Start date in ISO format + end_date: End date in ISO format + product_id: Optional product filter + fetch_all: If True, fetches ALL records using pagination (original behavior) + If False, fetches limited records (standard API response) """ try: - # Get service token - token = await service_auth.get_service_token() - - # Prepare headers - headers = service_auth.get_request_headers(tenant_id) - headers["Authorization"] = f"Bearer {token}" - headers["Content-Type"] = "application/json" - - # Prepare request payload with proper date handling - payload = { - "start_date": start_date, # Already in ISO format from calling code - "end_date": end_date, # Already in ISO format from calling code - "latitude": latitude or 40.4168, # Default Madrid coordinates - "longitude": longitude or -3.7038 - } - - logger.info(f"Weather request payload: {payload}", tenant_id=tenant_id) - - # Make POST request via gateway - async with httpx.AsyncClient(timeout=self.timeout) as client: - response = await client.post( - f"{self.base_url}/api/v1/tenants/{tenant_id}/weather/historical", - headers=headers, - json=payload + if fetch_all: + # Use paginated method to get ALL records (original behavior) + sales_data = await self.data_client.get_all_sales_data( + tenant_id=tenant_id, + start_date=start_date, + end_date=end_date, + product_id=product_id, + aggregation="daily", + 
page_size=5000, # Match original page size + max_pages=100 # Safety limit (500k records max) ) + else: + # Use standard method for limited results + sales_data = await self.data_client.get_sales_data( + tenant_id=tenant_id, + start_date=start_date, + end_date=end_date, + product_id=product_id, + aggregation="daily" + ) + sales_data = sales_data or [] + + if sales_data: + logger.info(f"Fetched {len(sales_data)} sales records", + tenant_id=tenant_id, product_id=product_id, fetch_all=fetch_all) + return sales_data + else: + logger.warning("No sales data returned", tenant_id=tenant_id) + return [] - logger.info(f"Weather data request: {response.status_code}", - tenant_id=tenant_id, - url=response.url) - - if response.status_code == 200: - data = response.json() - logger.info(f"Successfully fetched {len(data)} weather records") - return data - elif response.status_code == 400: - error_details = response.text - logger.error(f"Weather API validation error (400): {error_details}") - - # Try to parse the error and provide helpful info - try: - error_json = response.json() - if 'detail' in error_json: - detail = error_json['detail'] - if 'End date must be after start date' in str(detail): - logger.error(f"Date range issue: start={start_date}, end={end_date}") - elif 'Date range cannot exceed 90 days' in str(detail): - logger.error(f"Date range too large: {start_date} to {end_date}") - except: - pass - - return [] - elif response.status_code == 401: - logger.error("Authentication failed for weather API") - return [] - else: - logger.error(f"Failed to fetch weather data: {response.status_code} - {response.text}") - return [] - - except httpx.TimeoutException: - logger.error("Timeout when fetching weather data") - return [] except Exception as e: - logger.error(f"Error fetching weather data: {str(e)}") + logger.error(f"Error fetching sales data: {e}", tenant_id=tenant_id) return [] - - async def fetch_traffic_data( - self, + + async def fetch_weather_data( + self, tenant_id: str, start_date: str, end_date: str, @@ -221,65 +91,90 @@ class DataServiceClient: longitude: Optional[float] = None ) -> List[Dict[str, Any]]: """ - Fetch historical traffic data for training via API Gateway using POST + Fetch weather data for training + All the error handling and retry logic is now in the base client! 
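+
+        Example (illustrative; dates are ISO strings, tenant ID is a placeholder):
+
+            records = await data_client.fetch_weather_data(
+                "tenant-123", "2024-01-01", "2024-03-31"
+            )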
""" try: - # Get service token - token = await service_auth.get_service_token() - - # Prepare headers - headers = service_auth.get_request_headers(tenant_id) - headers["Authorization"] = f"Bearer {token}" - headers["Content-Type"] = "application/json" - - # Prepare request payload - payload = { - "start_date": start_date, # Already in ISO format from calling code - "end_date": end_date, # Already in ISO format from calling code - "latitude": latitude or 40.4168, # Default Madrid coordinates - "longitude": longitude or -3.7038 - } - - logger.info(f"Traffic request payload: {payload}", tenant_id=tenant_id) - - # Madrid traffic data can take 5-10 minutes to download and process - timeout_config = httpx.Timeout( - connect=30.0, # Connection timeout - read=600.0, # Read timeout: 10 minutes (was 30s) - write=30.0, # Write timeout - pool=30.0 # Pool timeout + weather_data = await self.data_client.get_weather_historical( + tenant_id=tenant_id, + start_date=start_date, + end_date=end_date, + latitude=latitude, + longitude=longitude ) - # Make POST request via gateway - async with httpx.AsyncClient(timeout=timeout_config) as client: - response = await client.post( - f"{self.base_url}/api/v1/tenants/{tenant_id}/traffic/historical", - headers=headers, - json=payload - ) + if weather_data: + logger.info(f"Fetched {len(weather_data)} weather records", + tenant_id=tenant_id) + return weather_data + else: + logger.warning("No weather data returned", tenant_id=tenant_id) + return [] - logger.info(f"Traffic data request: {response.status_code}", - tenant_id=tenant_id, - url=response.url) - - if response.status_code == 200: - data = response.json() - logger.info(f"Successfully fetched {len(data)} traffic records") - return data - elif response.status_code == 400: - error_details = response.text - logger.error(f"Traffic API validation error (400): {error_details}") - return [] - elif response.status_code == 401: - logger.error("Authentication failed for traffic API") - return [] - else: - logger.error(f"Failed to fetch traffic data: {response.status_code} - {response.text}") - return [] - - except httpx.TimeoutException: - logger.error("Timeout when fetching traffic data") - return [] except Exception as e: - logger.error(f"Error fetching traffic data: {str(e)}") - return [] \ No newline at end of file + logger.error(f"Error fetching weather data: {e}", tenant_id=tenant_id) + return [] + + async def fetch_traffic_data( + self, + tenant_id: str, + start_date: str, + end_date: str, + latitude: Optional[float] = None, + longitude: Optional[float] = None + ) -> List[Dict[str, Any]]: + """ + Fetch traffic data for training + """ + try: + traffic_data = await self.data_client.get_traffic_data( + tenant_id=tenant_id, + start_date=start_date, + end_date=end_date, + latitude=latitude, + longitude=longitude + ) + + if traffic_data: + logger.info(f"Fetched {len(traffic_data)} traffic records", + tenant_id=tenant_id) + return traffic_data + else: + logger.warning("No traffic data returned", tenant_id=tenant_id) + return [] + + except Exception as e: + logger.error(f"Error fetching traffic data: {e}", tenant_id=tenant_id) + return [] + + async def validate_data_quality( + self, + tenant_id: str, + start_date: str, + end_date: str + ) -> Dict[str, Any]: + """ + Validate data quality before training + """ + try: + validation_result = await self.data_client.validate_data_quality( + tenant_id=tenant_id, + start_date=start_date, + end_date=end_date + ) + + if validation_result: + logger.info("Data validation completed", + 
tenant_id=tenant_id, + is_valid=validation_result.get("is_valid", False)) + return validation_result + else: + logger.warning("Data validation failed", tenant_id=tenant_id) + return {"is_valid": False, "errors": ["Validation service unavailable"]} + + except Exception as e: + logger.error(f"Error validating data: {e}", tenant_id=tenant_id) + return {"is_valid": False, "errors": [str(e)]} + +# Global instance - same as before, but much simpler implementation +data_client = DataClient() \ No newline at end of file diff --git a/services/training/app/services/training_orchestrator.py b/services/training/app/services/training_orchestrator.py index cc907948..c417859b 100644 --- a/services/training/app/services/training_orchestrator.py +++ b/services/training/app/services/training_orchestrator.py @@ -13,7 +13,7 @@ from concurrent.futures import ThreadPoolExecutor from datetime import timezone import pandas as pd -from app.services.data_client import DataServiceClient +from app.services.data_client import DataClient from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange logger = logging.getLogger(__name__) @@ -37,7 +37,7 @@ class TrainingDataOrchestrator: madrid_client=None, weather_client=None, date_alignment_service: DateAlignmentService = None): - self.data_client = DataServiceClient() + self.data_client = DataClient() self.date_alignment_service = date_alignment_service or DateAlignmentService() self.max_concurrent_requests = 3 diff --git a/shared/clients/__init__.py b/shared/clients/__init__.py new file mode 100644 index 00000000..03c75218 --- /dev/null +++ b/shared/clients/__init__.py @@ -0,0 +1,106 @@ +# shared/clients/__init__.py +""" +Service Client Factory and Convenient Imports +Provides easy access to all service clients +""" + +from .base_service_client import BaseServiceClient, ServiceAuthenticator +from .training_client import TrainingServiceClient +from .data_client import DataServiceClient +from .forecast_client import ForecastServiceClient + +# Import config +from shared.config.base import BaseServiceSettings + +# Cache clients to avoid recreating them +_client_cache = {} + +def get_training_client(config: BaseServiceSettings = None, service_name: str = "unknown") -> TrainingServiceClient: + """Get or create a training service client""" + if config is None: + from app.core.config import settings as config + + cache_key = f"training_{service_name}" + if cache_key not in _client_cache: + _client_cache[cache_key] = TrainingServiceClient(config, service_name) + return _client_cache[cache_key] + +def get_data_client(config: BaseServiceSettings = None, service_name: str = "unknown") -> DataServiceClient: + """Get or create a data service client""" + if config is None: + from app.core.config import settings as config + + cache_key = f"data_{service_name}" + if cache_key not in _client_cache: + _client_cache[cache_key] = DataServiceClient(config, service_name) + return _client_cache[cache_key] + +def get_forecast_client(config: BaseServiceSettings = None, service_name: str = "unknown") -> ForecastServiceClient: + """Get or create a forecast service client""" + if config is None: + from app.core.config import settings as config + + cache_key = f"forecast_{service_name}" + if cache_key not in _client_cache: + _client_cache[cache_key] = ForecastServiceClient(config, service_name) + return _client_cache[cache_key] + +class ServiceClients: + """Convenient wrapper for all service clients""" + + def __init__(self, config: 
BaseServiceSettings = None, service_name: str = "unknown"): + self.service_name = service_name + self.config = config or self._get_default_config() + + # Initialize clients lazily + self._training_client = None + self._data_client = None + self._forecast_client = None + + def _get_default_config(self): + """Get default config from app settings""" + try: + from app.core.config import settings + return settings + except ImportError: + raise ImportError("Could not import app config. Please provide config explicitly.") + + @property + def training(self) -> TrainingServiceClient: + """Get training service client""" + if self._training_client is None: + self._training_client = get_training_client(self.config, self.service_name) + return self._training_client + + @property + def data(self) -> DataServiceClient: + """Get data service client""" + if self._data_client is None: + self._data_client = get_data_client(self.config, self.service_name) + return self._data_client + + @property + def forecast(self) -> ForecastServiceClient: + """Get forecast service client""" + if self._forecast_client is None: + self._forecast_client = get_forecast_client(self.config, self.service_name) + return self._forecast_client + +# Convenience function to get all clients +def get_service_clients(config: BaseServiceSettings = None, service_name: str = "unknown") -> ServiceClients: + """Get a wrapper with all service clients""" + return ServiceClients(config, service_name) + +# Export all classes for direct import +__all__ = [ + 'BaseServiceClient', + 'ServiceAuthenticator', + 'TrainingServiceClient', + 'DataServiceClient', + 'ForecastServiceClient', + 'ServiceClients', + 'get_training_client', + 'get_data_client', + 'get_forecast_client', + 'get_service_clients' +] \ No newline at end of file diff --git a/shared/clients/base_service_client.py b/shared/clients/base_service_client.py new file mode 100644 index 00000000..43bcccce --- /dev/null +++ b/shared/clients/base_service_client.py @@ -0,0 +1,363 @@ +# shared/clients/base_service_client.py +""" +Base Service Client for Inter-Service Communication +Provides a reusable foundation for all service-to-service API calls +""" + +import time +import asyncio +import httpx +import structlog +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, List, Union +from urllib.parse import urljoin + +from shared.auth.jwt_handler import JWTHandler +from shared.config.base import BaseServiceSettings + +logger = structlog.get_logger() + +class ServiceAuthenticator: + """Handles service-to-service authentication via gateway""" + + def __init__(self, service_name: str, config: BaseServiceSettings): + self.service_name = service_name + self.config = config + self.jwt_handler = JWTHandler(config.JWT_SECRET_KEY) + self._cached_token = None + self._token_expires_at = 0 + + async def get_service_token(self) -> str: + """Get a valid service token, using cache when possible""" + current_time = int(time.time()) + + # Return cached token if still valid (with 5 min buffer) + if (self._cached_token and + self._token_expires_at > current_time + 300): + return self._cached_token + + # Create new service token + token_expires_at = current_time + 3600 # 1 hour + + service_payload = { + "sub": f"{self.service_name}-service", + "user_id": f"{self.service_name}-service", + "email": f"{self.service_name}-service@internal", + "type": "access", + "exp": token_expires_at, + "iat": current_time, + "iss": f"{self.service_name}-service", + "service": self.service_name, + "full_name": 
f"{self.service_name.title()} Service", + "is_verified": True, + "is_active": True, + "tenant_id": None + } + + try: + token = self.jwt_handler.create_access_token_from_payload(service_payload) + self._cached_token = token + self._token_expires_at = token_expires_at + + logger.debug("Created new service token", service=self.service_name, expires_at=token_expires_at) + return token + + except Exception as e: + logger.error(f"Failed to create service token: {e}", service=self.service_name) + raise ValueError(f"Service token creation failed: {e}") + + def get_request_headers(self, tenant_id: Optional[str] = None) -> Dict[str, str]: + """Get standard headers for service requests""" + headers = { + "X-Service": f"{self.service_name}-service", + "User-Agent": f"{self.service_name}-service/1.0.0" + } + + if tenant_id: + headers["X-Tenant-ID"] = str(tenant_id) + + return headers + + +class BaseServiceClient(ABC): + """ + Base class for all inter-service communication clients + Provides common functionality for API calls through the gateway + """ + + def __init__(self, service_name: str, config: BaseServiceSettings): + self.service_name = service_name + self.config = config + self.gateway_url = config.GATEWAY_URL + self.authenticator = ServiceAuthenticator(service_name, config) + + # HTTP client configuration + self.timeout = config.HTTP_TIMEOUT + self.retries = config.HTTP_RETRIES + self.retry_delay = config.HTTP_RETRY_DELAY + + @abstractmethod + def get_service_base_path(self) -> str: + """Return the base path for this service's APIs""" + pass + + async def _make_request( + self, + method: str, + endpoint: str, + tenant_id: Optional[str] = None, + data: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + timeout: Optional[Union[int, httpx.Timeout]] = None + ) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]: + """ + Make an authenticated request to another service via gateway + + Args: + method: HTTP method (GET, POST, PUT, DELETE) + endpoint: API endpoint (will be prefixed with service base path) + tenant_id: Optional tenant ID for tenant-scoped requests + data: Request body data (for POST/PUT) + params: Query parameters + headers: Additional headers + timeout: Request timeout override + + Returns: + Response data or None if request failed + """ + try: + # Get service token + token = await self.authenticator.get_service_token() + + # Build headers + request_headers = self.authenticator.get_request_headers(tenant_id) + request_headers["Authorization"] = f"Bearer {token}" + request_headers["Content-Type"] = "application/json" + + if headers: + request_headers.update(headers) + + # Build URL + base_path = self.get_service_base_path() + if tenant_id: + # For tenant-scoped endpoints + full_endpoint = f"{base_path}/tenants/{tenant_id}/{endpoint.lstrip('/')}" + else: + # For non-tenant endpoints + full_endpoint = f"{base_path}/{endpoint.lstrip('/')}" + + url = urljoin(self.gateway_url, full_endpoint) + + # Make request with retries + for attempt in range(self.retries + 1): + try: + # Handle different timeout configurations + if isinstance(timeout, httpx.Timeout): + client_timeout = timeout + else: + client_timeout = timeout or self.timeout + + async with httpx.AsyncClient(timeout=client_timeout) as client: + response = await client.request( + method=method, + url=url, + json=data, + params=params, + headers=request_headers + ) + + if response.status_code == 200: + return response.json() + elif response.status_code == 201: + 
return response.json()
+                        elif response.status_code == 204:
+                            return {}  # No content success
+                        elif response.status_code == 401:
+                            # Token might be expired: clear the cache, mint a new
+                            # token for the retry, and refresh the auth header
+                            if attempt == 0:
+                                self.authenticator._cached_token = None
+                                logger.warning("Token expired, retrying with new token")
+                                token = await self.authenticator.get_service_token()
+                                request_headers["Authorization"] = f"Bearer {token}"
+                                continue
+                            else:
+                                logger.error("Authentication failed after retry")
+                                return None
+                        elif response.status_code == 404:
+                            logger.warning(f"Endpoint not found: {url}")
+                            return None
+                        else:
+                            error_detail = "Unknown error"
+                            try:
+                                error_json = response.json()
+                                error_detail = error_json.get('detail', f"HTTP {response.status_code}")
+                            except Exception:
+                                error_detail = f"HTTP {response.status_code}: {response.text}"
+                            
+                            logger.error(f"Request failed: {error_detail}",
+                                         url=url, status_code=response.status_code)
+                            return None
+                
+                except httpx.TimeoutException:
+                    if attempt < self.retries:
+                        logger.warning(f"Request timeout, retrying ({attempt + 1}/{self.retries})")
+                        await asyncio.sleep(self.retry_delay * (2 ** attempt))  # Exponential backoff
+                        continue
+                    else:
+                        logger.error(f"Request timeout after {self.retries} retries", url=url)
+                        return None
+                
+                except Exception as e:
+                    if attempt < self.retries:
+                        logger.warning(f"Request failed, retrying ({attempt + 1}/{self.retries}): {e}")
+                        await asyncio.sleep(self.retry_delay * (2 ** attempt))
+                        continue
+                    else:
+                        logger.error(f"Request failed after {self.retries} retries: {e}", url=url)
+                        return None
+        
+        except Exception as e:
+            logger.error(f"Unexpected error in _make_request: {e}")
+            return None
+    
+    async def _make_paginated_request(
+        self,
+        endpoint: str,
+        tenant_id: Optional[str] = None,
+        params: Optional[Dict[str, Any]] = None,
+        page_size: int = 5000,
+        max_pages: int = 100,
+        timeout: Optional[Union[int, httpx.Timeout]] = None
+    ) -> List[Dict[str, Any]]:
+        """
+        Make paginated GET requests to fetch all records
+        Handles both direct list and paginated object responses
+        
+        Args:
+            endpoint: API endpoint
+            tenant_id: Optional tenant ID
+            params: Base query parameters
+            page_size: Records per page (default 5000)
+            max_pages: Maximum pages to fetch (safety limit)
+            timeout: Request timeout override
+        
+        Returns:
+            List of all records from all pages
+        """
+        all_records = []
+        page = 0
+        base_params = params or {}
+        
+        logger.info(f"Starting paginated request to {endpoint}",
+                    tenant_id=tenant_id, page_size=page_size)
+        
+        while page < max_pages:
+            # Prepare pagination parameters
+            page_params = base_params.copy()
+            page_params.update({
+                "limit": page_size,
+                "offset": page * page_size
+            })
+            
+            logger.debug(f"Fetching page {page + 1} (offset: {page * page_size})",
+                         tenant_id=tenant_id)
+            
+            # Make request for this page
+            result = await self._make_request(
+                "GET",
+                endpoint,
+                tenant_id=tenant_id,
+                params=page_params,
+                timeout=timeout
+            )
+            
+            if result is None:
+                logger.error(f"Failed to fetch page {page + 1}", tenant_id=tenant_id)
+                break
+            
+            # Handle different response formats
+            if isinstance(result, list):
+                # Direct list response (no pagination metadata)
+                records = result
+                logger.debug(f"Retrieved {len(records)} records from page {page + 1} (direct list)")
+                
+                if len(records) == 0:
+                    logger.info("No records in response, pagination complete")
+                    break
+                elif len(records) < page_size:
+                    # Got fewer than requested, this is the last page
+                    all_records.extend(records)
+                    logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}")
+                    break
+                else:
+                    # Got full page, there might be more
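+                    # A page that is exactly full may also be the last one; the
+                    # next iteration then sees an empty response and breaks out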
all_records.extend(records) + logger.debug(f"Full page retrieved: {len(records)} records, continuing to next page") + + elif isinstance(result, dict): + # Paginated response format + records = result.get('records', result.get('data', [])) + total_available = result.get('total', 0) + + logger.debug(f"Retrieved {len(records)} records from page {page + 1} (paginated response)") + + if not records: + logger.info("No more records found in paginated response") + break + + all_records.extend(records) + + # Check if we've got all available records + if len(all_records) >= total_available: + logger.info(f"Retrieved all available records: {len(all_records)}/{total_available}") + break + + else: + logger.warning(f"Unexpected response format: {type(result)}") + break + + page += 1 + + if page >= max_pages: + logger.warning(f"Reached maximum page limit ({max_pages}), stopping pagination") + + logger.info(f"Pagination complete: fetched {len(all_records)} total records", + tenant_id=tenant_id, pages_fetched=page) + + return all_records + + async def get(self, endpoint: str, tenant_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]: + """Make a GET request""" + return await self._make_request("GET", endpoint, tenant_id=tenant_id, params=params) + + async def get_paginated( + self, + endpoint: str, + tenant_id: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + page_size: int = 5000, + max_pages: int = 100, + timeout: Optional[Union[int, httpx.Timeout]] = None + ) -> List[Dict[str, Any]]: + """Make a paginated GET request to fetch all records""" + return await self._make_paginated_request( + endpoint, + tenant_id=tenant_id, + params=params, + page_size=page_size, + max_pages=max_pages, + timeout=timeout + ) + + async def post(self, endpoint: str, data: Dict[str, Any], tenant_id: Optional[str] = None) -> Optional[Dict[str, Any]]: + """Make a POST request""" + return await self._make_request("POST", endpoint, tenant_id=tenant_id, data=data) + + async def put(self, endpoint: str, data: Dict[str, Any], tenant_id: Optional[str] = None) -> Optional[Dict[str, Any]]: + """Make a PUT request""" + return await self._make_request("PUT", endpoint, tenant_id=tenant_id, data=data) + + async def delete(self, endpoint: str, tenant_id: Optional[str] = None) -> Optional[Dict[str, Any]]: + """Make a DELETE request""" + return await self._make_request("DELETE", endpoint, tenant_id=tenant_id) \ No newline at end of file diff --git a/shared/clients/data_client.py b/shared/clients/data_client.py new file mode 100644 index 00000000..a73d8e11 --- /dev/null +++ b/shared/clients/data_client.py @@ -0,0 +1,399 @@ +# shared/clients/data_client.py +""" +Data Service Client +Handles all API calls to the data service +""" + +import httpx +import structlog +from typing import Dict, Any, Optional, List, Union +from .base_service_client import BaseServiceClient +from shared.config.base import BaseServiceSettings + +logger = structlog.get_logger() + + +class DataServiceClient(BaseServiceClient): + """Client for communicating with the data service""" + + def __init__(self, config: BaseServiceSettings, calling_service_name: str = "unknown"): + super().__init__(calling_service_name, config) + + def get_service_base_path(self) -> str: + return "/api/v1" + + # ================================================================ + # SALES DATA (with advanced pagination support) + # ================================================================ + + async def get_sales_data( + self, + 
tenant_id: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + product_id: Optional[str] = None, + aggregation: str = "daily" + ) -> Optional[List[Dict[str, Any]]]: + """Get sales data for a date range""" + params = {"aggregation": aggregation} + if start_date: + params["start_date"] = start_date + if end_date: + params["end_date"] = end_date + if product_id: + params["product_id"] = product_id + + result = await self.get("sales", tenant_id=tenant_id, params=params) + return result.get("sales", []) if result else None + + async def get_all_sales_data( + self, + tenant_id: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + product_id: Optional[str] = None, + aggregation: str = "daily", + page_size: int = 5000, + max_pages: int = 100 + ) -> List[Dict[str, Any]]: + """ + Get ALL sales data using pagination (equivalent to original fetch_sales_data) + Retrieves all records without pagination limits + """ + params = {"aggregation": aggregation} + if start_date: + params["start_date"] = start_date + if end_date: + params["end_date"] = end_date + if product_id: + params["product_id"] = product_id + + # Use the inherited paginated request method + try: + all_records = await self.get_paginated( + "sales", + tenant_id=tenant_id, + params=params, + page_size=page_size, + max_pages=max_pages, + timeout=2000.0 # Match original timeout + ) + + logger.info(f"Successfully fetched {len(all_records)} total sales records via gateway", + tenant_id=tenant_id) + return all_records + + except AttributeError as e: + # Fallback: implement pagination directly if inheritance isn't working + logger.warning(f"Using fallback pagination due to: {e}") + return await self._fallback_paginated_sales(tenant_id, params, page_size, max_pages) + + async def _fallback_paginated_sales( + self, + tenant_id: str, + base_params: Dict[str, Any], + page_size: int = 5000, + max_pages: int = 100 + ) -> List[Dict[str, Any]]: + """ + Fallback pagination implementation for sales data + This replicates your original pagination logic directly + """ + all_records = [] + page = 0 + + logger.info(f"Starting fallback paginated request for sales data", + tenant_id=tenant_id, page_size=page_size) + + while page < max_pages: + # Prepare pagination parameters + params = base_params.copy() + params.update({ + "limit": page_size, + "offset": page * page_size + }) + + logger.info(f"Fetching sales data page {page + 1} (offset: {page * page_size})", + tenant_id=tenant_id) + + # Make request using the base client's _make_request method + result = await self._make_request( + "GET", + "sales", + tenant_id=tenant_id, + params=params, + timeout=2000.0 + ) + + if result is None: + logger.error(f"Failed to fetch page {page + 1}", tenant_id=tenant_id) + break + + # Handle different response formats (from your original code) + if isinstance(result, list): + # Direct list response (no pagination metadata) + records = result + logger.info(f"Retrieved {len(records)} records from page {page + 1} (direct list)") + + if len(records) == 0: + logger.info("No records in response, pagination complete") + break + elif len(records) < page_size: + # Got fewer than requested, this is the last page + all_records.extend(records) + logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}") + break + else: + # Got full page, there might be more + all_records.extend(records) + logger.info(f"Full page retrieved: {len(records)} records, continuing to next page") + + elif isinstance(result, dict): + # 
Paginated response format
+                records = result.get('records', result.get('data', []))
+                total_available = result.get('total', 0)
+                
+                logger.info(f"Retrieved {len(records)} records from page {page + 1} (paginated response)")
+                
+                if not records:
+                    logger.info("No more records found in paginated response")
+                    break
+                
+                all_records.extend(records)
+                
+                # Check if we've got all available records
+                if len(all_records) >= total_available:
+                    logger.info(f"Retrieved all available records: {len(all_records)}/{total_available}")
+                    break
+            
+            else:
+                logger.warning(f"Unexpected response format: {type(result)}")
+                break
+            
+            page += 1
+        
+        logger.info(f"Fallback pagination complete: fetched {len(all_records)} total records",
+                    tenant_id=tenant_id, pages_fetched=page)
+        
+        return all_records
+    
+    async def upload_sales_data(
+        self,
+        tenant_id: str,
+        sales_data: List[Dict[str, Any]]
+    ) -> Optional[Dict[str, Any]]:
+        """Upload sales data"""
+        data = {"sales": sales_data}
+        return await self.post("sales", data=data, tenant_id=tenant_id)
+    
+    # ================================================================
+    # WEATHER DATA
+    # ================================================================
+    
+    async def get_weather_historical(
+        self,
+        tenant_id: str,
+        start_date: str,
+        end_date: str,
+        latitude: Optional[float] = None,
+        longitude: Optional[float] = None
+    ) -> Optional[List[Dict[str, Any]]]:
+        """
+        Get historical weather data for a date range and location
+        Uses a POST request, as in the original implementation
+        """
+        # Prepare request payload with proper date handling
+        payload = {
+            "start_date": start_date,  # Already in ISO format from calling code
+            "end_date": end_date,      # Already in ISO format from calling code
+            "latitude": latitude or 40.4168,   # Default Madrid coordinates
+            "longitude": longitude or -3.7038
+        }
+        
+        logger.info(f"Weather request payload: {payload}", tenant_id=tenant_id)
+        
+        # Use POST request with extended timeout
+        result = await self._make_request(
+            "POST",
+            "weather/historical",
+            tenant_id=tenant_id,
+            data=payload,
+            timeout=2000.0  # Match original timeout
+        )
+        
+        if result:
+            logger.info(f"Successfully fetched {len(result)} weather records")
+            return result
+        else:
+            logger.error("Failed to fetch weather data")
+            return []
+    
+    async def get_weather_forecast(
+        self,
+        tenant_id: str,
+        days: int,
+        latitude: Optional[float] = None,
+        longitude: Optional[float] = None
+    ) -> Optional[List[Dict[str, Any]]]:
+        """
+        Get the weather forecast for the next `days` days at a location
+        Uses a POST request, mirroring the historical endpoint
+        """
+        payload = {
+            "days": days,  # Forecast horizon in days
+            "latitude": latitude or 40.4168,   # Default Madrid coordinates
+            "longitude": longitude or -3.7038
+        }
+        
+        logger.info(f"Weather forecast request payload: {payload}", tenant_id=tenant_id)
+        
+        # Use POST request with extended timeout
+        result = await self._make_request(
+            "POST",
+            "weather/forecast",
+            tenant_id=tenant_id,
+            data=payload,
+            timeout=2000.0  # Match original timeout
+        )
+        
+        if result:
+            logger.info(f"Successfully fetched {len(result)} weather forecast records for {days} days")
+            return result
+        else:
+            logger.error("Failed to fetch weather forecast data")
+            return []
+    
+    # ================================================================
+    # TRAFFIC DATA
+    # ================================================================
+    
+    async def get_traffic_data(
+        self,
+        tenant_id: str,
+        start_date: str,
+        end_date: str,
+        latitude: Optional[float] = None,
+        longitude: 
Optional[float] = None + ) -> Optional[List[Dict[str, Any]]]: + """ + Get traffic data for a date range and location + Uses POST request with extended timeout for Madrid traffic data processing + """ + # Prepare request payload + payload = { + "start_date": start_date, # Already in ISO format from calling code + "end_date": end_date, # Already in ISO format from calling code + "latitude": latitude or 40.4168, # Default Madrid coordinates + "longitude": longitude or -3.7038 + } + + logger.info(f"Traffic request payload: {payload}", tenant_id=tenant_id) + + # Madrid traffic data can take 5-10 minutes to download and process + traffic_timeout = httpx.Timeout( + connect=30.0, # Connection timeout + read=600.0, # Read timeout: 10 minutes (was 30s) + write=30.0, # Write timeout + pool=30.0 # Pool timeout + ) + + # Use POST request with extended timeout + result = await self._make_request( + "POST", + "traffic/historical", + tenant_id=tenant_id, + data=payload, + timeout=traffic_timeout + ) + + if result: + logger.info(f"Successfully fetched {len(result)} traffic records") + return result + else: + logger.error("Failed to fetch traffic data") + return [] + + # ================================================================ + # PRODUCTS + # ================================================================ + + async def get_products(self, tenant_id: str) -> Optional[List[Dict[str, Any]]]: + """Get all products for a tenant""" + result = await self.get("products", tenant_id=tenant_id) + return result.get("products", []) if result else None + + async def get_product(self, tenant_id: str, product_id: str) -> Optional[Dict[str, Any]]: + """Get a specific product""" + return await self.get(f"products/{product_id}", tenant_id=tenant_id) + + async def create_product( + self, + tenant_id: str, + name: str, + category: str, + price: float, + **kwargs + ) -> Optional[Dict[str, Any]]: + """Create a new product""" + data = { + "name": name, + "category": category, + "price": price, + **kwargs + } + return await self.post("products", data=data, tenant_id=tenant_id) + + async def update_product( + self, + tenant_id: str, + product_id: str, + **updates + ) -> Optional[Dict[str, Any]]: + """Update a product""" + return await self.put(f"products/{product_id}", data=updates, tenant_id=tenant_id) + + # ================================================================ + # STORES & LOCATIONS + # ================================================================ + + async def get_stores(self, tenant_id: str) -> Optional[List[Dict[str, Any]]]: + """Get all stores for a tenant""" + result = await self.get("stores", tenant_id=tenant_id) + return result.get("stores", []) if result else None + + async def get_store(self, tenant_id: str, store_id: str) -> Optional[Dict[str, Any]]: + """Get a specific store""" + return await self.get(f"stores/{store_id}", tenant_id=tenant_id) + + # ================================================================ + # DATA VALIDATION & HEALTH + # ================================================================ + + async def validate_data_quality( + self, + tenant_id: str, + start_date: str, + end_date: str + ) -> Optional[Dict[str, Any]]: + """Validate data quality for a date range""" + params = { + "start_date": start_date, + "end_date": end_date + } + return await self.get("validation", tenant_id=tenant_id, params=params) + + async def get_data_statistics( + self, + tenant_id: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + """Get data 
diff --git a/shared/clients/forecast_client.py b/shared/clients/forecast_client.py
new file mode 100644
index 00000000..3289a3ea
--- /dev/null
+++ b/shared/clients/forecast_client.py
@@ -0,0 +1,175 @@
+# shared/clients/forecast_client.py
+"""
+Forecast Service Client
+Handles all API calls to the forecasting service
+"""
+
+from typing import Dict, Any, Optional, List
+from .base_service_client import BaseServiceClient
+from shared.config.base import BaseServiceSettings
+
+
+class ForecastServiceClient(BaseServiceClient):
+    """Client for communicating with the forecasting service"""
+
+    def __init__(self, config: BaseServiceSettings, calling_service_name: str = "unknown"):
+        super().__init__(calling_service_name, config)
+
+    def get_service_base_path(self) -> str:
+        return "/api/v1"
+
+    # ================================================================
+    # FORECASTS
+    # ================================================================
+
+    async def create_forecast(
+        self,
+        tenant_id: str,
+        model_id: str,
+        start_date: str,
+        end_date: str,
+        product_ids: Optional[List[str]] = None,
+        include_confidence_intervals: bool = True,
+        **kwargs
+    ) -> Optional[Dict[str, Any]]:
+        """Create a new forecast"""
+        data = {
+            "model_id": model_id,
+            "start_date": start_date,
+            "end_date": end_date,
+            "include_confidence_intervals": include_confidence_intervals,
+            **kwargs
+        }
+        if product_ids:
+            data["product_ids"] = product_ids
+
+        return await self.post("forecasts", data=data, tenant_id=tenant_id)
+
+    async def get_forecast(self, tenant_id: str, forecast_id: str) -> Optional[Dict[str, Any]]:
+        """Get forecast details"""
+        return await self.get(f"forecasts/{forecast_id}", tenant_id=tenant_id)
+
+    async def list_forecasts(
+        self,
+        tenant_id: str,
+        status: Optional[str] = None,
+        model_id: Optional[str] = None,
+        limit: int = 50
+    ) -> Optional[List[Dict[str, Any]]]:
+        """List forecasts for a tenant"""
+        params = {"limit": limit}
+        if status:
+            params["status"] = status
+        if model_id:
+            params["model_id"] = model_id
+
+        result = await self.get("forecasts", tenant_id=tenant_id, params=params)
+        return result.get("forecasts", []) if result else None
+
+    async def delete_forecast(self, tenant_id: str, forecast_id: str) -> Optional[Dict[str, Any]]:
+        """Delete a forecast"""
+        return await self.delete(f"forecasts/{forecast_id}", tenant_id=tenant_id)
+
+    # ================================================================
+    # PREDICTIONS
+    # ================================================================
+
+    async def get_predictions(
+        self,
+        tenant_id: str,
+        forecast_id: str,
+        start_date: Optional[str] = None,
+        end_date: Optional[str] = None,
+        product_id: Optional[str] = None
+    ) -> Optional[List[Dict[str, Any]]]:
+        """Get predictions from a forecast"""
+        params = {}
+        if start_date:
+            params["start_date"] = start_date
+        if end_date:
+            params["end_date"] = end_date
+        if product_id:
+            params["product_id"] = product_id
+
+        result = await self.get(f"forecasts/{forecast_id}/predictions", tenant_id=tenant_id, params=params)
+        return result.get("predictions", []) if result else None
+
+    async def create_realtime_prediction(
+        self,
+        tenant_id: str,
+        model_id: str,
+        target_date: str,
+        features: Dict[str, Any],
+        **kwargs
+    ) -> Optional[Dict[str, Any]]:
+        """Create a real-time prediction"""
+        data = {
+            "model_id": model_id,
+            "target_date": target_date,
+            "features": features,
+            **kwargs
+        }
+        return await self.post("predictions", data=data, tenant_id=tenant_id)
+
+    # ================================================================
+    # FORECAST VALIDATION & METRICS
+    # ================================================================
+
+    async def get_forecast_accuracy(
+        self,
+        tenant_id: str,
+        forecast_id: str,
+        start_date: Optional[str] = None,
+        end_date: Optional[str] = None
+    ) -> Optional[Dict[str, Any]]:
+        """Get forecast accuracy metrics"""
+        params = {}
+        if start_date:
+            params["start_date"] = start_date
+        if end_date:
+            params["end_date"] = end_date
+
+        return await self.get(f"forecasts/{forecast_id}/accuracy", tenant_id=tenant_id, params=params)
+
+    async def compare_forecasts(
+        self,
+        tenant_id: str,
+        forecast_ids: List[str],
+        metric: str = "mape"
+    ) -> Optional[Dict[str, Any]]:
+        """Compare multiple forecasts"""
+        data = {
+            "forecast_ids": forecast_ids,
+            "metric": metric
+        }
+        return await self.post("forecasts/compare", data=data, tenant_id=tenant_id)
+
+    # ================================================================
+    # FORECAST SCENARIOS
+    # ================================================================
+
+    async def create_scenario_forecast(
+        self,
+        tenant_id: str,
+        model_id: str,
+        scenario_name: str,
+        scenario_data: Dict[str, Any],
+        start_date: str,
+        end_date: str,
+        **kwargs
+    ) -> Optional[Dict[str, Any]]:
+        """Create a scenario-based forecast"""
+        data = {
+            "model_id": model_id,
+            "scenario_name": scenario_name,
+            "scenario_data": scenario_data,
+            "start_date": start_date,
+            "end_date": end_date,
+            **kwargs
+        }
+        return await self.post("scenarios", data=data, tenant_id=tenant_id)
+
+    async def list_scenarios(self, tenant_id: str) -> Optional[List[Dict[str, Any]]]:
+        """List forecast scenarios for a tenant"""
+        result = await self.get("scenarios", tenant_id=tenant_id)
+        return result.get("scenarios", []) if result else None
\ No newline at end of file
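Usage sketch for the new forecast client (illustrative only, not part of the diff): the caller name and argument values are invented, and the response is assumed to carry an `id` field, which this diff does not confirm.

    # Hypothetical caller of ForecastServiceClient
    from shared.clients.forecast_client import ForecastServiceClient
    from shared.config.base import BaseServiceSettings

    async def run_weekly_forecast(settings: BaseServiceSettings, tenant_id: str) -> None:
        client = ForecastServiceClient(settings, calling_service_name="gateway")

        forecast = await client.create_forecast(
            tenant_id=tenant_id,
            model_id="model-123",        # assumed model ID
            start_date="2025-09-15",
            end_date="2025-09-21",
        )
        if forecast is None:
            return  # client methods return None when the service call fails

        predictions = await client.get_predictions(tenant_id, forecast["id"])
        print(predictions)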
"""Create a real-time prediction""" + data = { + "model_id": model_id, + "target_date": target_date, + "features": features, + **kwargs + } + return await self.post("predictions", data=data, tenant_id=tenant_id) + + # ================================================================ + # FORECAST VALIDATION & METRICS + # ================================================================ + + async def get_forecast_accuracy( + self, + tenant_id: str, + forecast_id: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + """Get forecast accuracy metrics""" + params = {} + if start_date: + params["start_date"] = start_date + if end_date: + params["end_date"] = end_date + + return await self.get(f"forecasts/{forecast_id}/accuracy", tenant_id=tenant_id, params=params) + + async def compare_forecasts( + self, + tenant_id: str, + forecast_ids: List[str], + metric: str = "mape" + ) -> Optional[Dict[str, Any]]: + """Compare multiple forecasts""" + data = { + "forecast_ids": forecast_ids, + "metric": metric + } + return await self.post("forecasts/compare", data=data, tenant_id=tenant_id) + + # ================================================================ + # FORECAST SCENARIOS + # ================================================================ + + async def create_scenario_forecast( + self, + tenant_id: str, + model_id: str, + scenario_name: str, + scenario_data: Dict[str, Any], + start_date: str, + end_date: str, + **kwargs + ) -> Optional[Dict[str, Any]]: + """Create a scenario-based forecast""" + data = { + "model_id": model_id, + "scenario_name": scenario_name, + "scenario_data": scenario_data, + "start_date": start_date, + "end_date": end_date, + **kwargs + } + return await self.post("scenarios", data=data, tenant_id=tenant_id) + + async def list_scenarios(self, tenant_id: str) -> Optional[List[Dict[str, Any]]]: + """List forecast scenarios for a tenant""" + result = await self.get("scenarios", tenant_id=tenant_id) + return result.get("scenarios", []) if result else None \ No newline at end of file diff --git a/shared/clients/training_client.py b/shared/clients/training_client.py new file mode 100644 index 00000000..f098d5b4 --- /dev/null +++ b/shared/clients/training_client.py @@ -0,0 +1,134 @@ +# shared/clients/training_client.py +""" +Training Service Client +Handles all API calls to the training service +""" + +from typing import Dict, Any, Optional, List +from .base_service_client import BaseServiceClient +from shared.config.base import BaseServiceSettings + + +class TrainingServiceClient(BaseServiceClient): + """Client for communicating with the training service""" + + def __init__(self, config: BaseServiceSettings, calling_service_name: str = "unknown"): + super().__init__(calling_service_name, config) + + def get_service_base_path(self) -> str: + return "/api/v1" + + # ================================================================ + # TRAINING JOBS + # ================================================================ + + async def create_training_job( + self, + tenant_id: str, + include_weather: bool = True, + include_traffic: bool = False, + min_data_points: int = 30, + **kwargs + ) -> Optional[Dict[str, Any]]: + """Create a new training job""" + data = { + "include_weather": include_weather, + "include_traffic": include_traffic, + "min_data_points": min_data_points, + **kwargs + } + return await self.post("jobs", data=data, tenant_id=tenant_id) + + async def get_training_job(self, tenant_id: str, job_id: str) -> Optional[Dict[str, 
diff --git a/tests/test_onboarding_flow.sh b/tests/test_onboarding_flow.sh
index 959e52be..8dab4d67 100755
--- a/tests/test_onboarding_flow.sh
+++ b/tests/test_onboarding_flow.sh
@@ -569,28 +569,22 @@ echo ""
 log_step "5.1. Testing basic dashboard functionality"
-# Test basic forecasting capability (if training completed)
-if [ -n "$TRAINING_TASK_ID" ]; then
-    # Use a real product name from our CSV for forecasting
-    FIRST_PRODUCT=$(echo "$REAL_PRODUCTS" | sed 's/"//g' | cut -d',' -f1)
+# Use a real product name from our CSV for forecasting
+FIRST_PRODUCT=$(echo "$REAL_PRODUCTS" | sed 's/"//g' | cut -d',' -f1)
 
-    FORECAST_RESPONSE=$(curl -s -X POST "$API_BASE/api/v1/forecasting/predict" \
-        -H "Content-Type: application/json" \
-        -H "Authorization: Bearer $ACCESS_TOKEN" \
-        -H "X-Tenant-ID: $TENANT_ID" \
-        -d "{
-            \"products\": [\"$FIRST_PRODUCT\"],
-            \"forecast_days\": 7,
-            \"date\": \"2024-01-15\"
+FORECAST_RESPONSE=$(curl -s -X POST "$API_BASE/api/v1/tenants/$TENANT_ID/forecast/single" \
+    -H "Content-Type: application/json" \
+    -H "Authorization: Bearer $ACCESS_TOKEN" \
+    -d "{
+        \"products\": [\"$FIRST_PRODUCT\"],
+        \"forecast_days\": 7,
+        \"date\": \"2025-09-15\"
     }")
 
-    if echo "$FORECAST_RESPONSE" | grep -q '"predictions"\|"forecast"'; then
-        log_success "Forecasting service is accessible"
-    else
-        log_warning "Forecasting may not be ready yet (model training required)"
-    fi
+if echo "$FORECAST_RESPONSE" | grep -q '"predictions"\|"forecast"'; then
+    log_success "Forecasting service is accessible"
 else
-    log_warning "Skipping forecast test - no training task ID available"
+    log_warning "Forecasting may not be ready yet (model training required)"
 fi
 
 echo ""
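For debugging outside the onboarding script, the same smoke test can be issued from Python (a minimal sketch; the URL and JSON payload are copied from the test above, while the function name, token handling, and base URL are assumptions):

    import asyncio
    import httpx

    async def smoke_test_forecast(api_base: str, token: str, tenant_id: str, product: str) -> None:
        # Mirrors the curl call in tests/test_onboarding_flow.sh
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                f"{api_base}/api/v1/tenants/{tenant_id}/forecast/single",
                headers={"Authorization": f"Bearer {token}"},
                json={"products": [product], "forecast_days": 7, "date": "2025-09-15"},
            )
            print(resp.status_code, resp.text)

    # e.g. asyncio.run(smoke_test_forecast("http://localhost:8000", token, tenant_id, "Croissant"))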