From c349b845a6ff98da992bcd96b7b5b996d6972c4d Mon Sep 17 00:00:00 2001 From: Urtzi Alfaro Date: Fri, 14 Nov 2025 20:27:39 +0100 Subject: [PATCH] Bug fixes of training --- kind-config.yaml | 7 +- services/external/README.md | 21 +- services/external/app/api/poi_context.py | 28 +- services/forecasting/README.md | 2 +- .../app/services/poi_feature_service.py | 59 +- services/training/README.md | 2 +- .../training/app/ml/poi_feature_integrator.py | 121 ++-- services/training/app/ml/trainer.py | 160 +++++- services/training/app/models/training.py | 1 + .../training/app/services/training_service.py | 536 +++++++++--------- shared/clients/external_client.py | 77 ++- 11 files changed, 606 insertions(+), 408 deletions(-) diff --git a/kind-config.yaml b/kind-config.yaml index eb4ac449..3a1d7fc8 100644 --- a/kind-config.yaml +++ b/kind-config.yaml @@ -40,9 +40,4 @@ nodes: # Direct gateway access (backup) - containerPort: 30800 hostPort: 8000 - protocol: TCP - sysctls: - # Increase fs.inotify limits to prevent "too many open files" errors - fs.inotify.max_user_watches: 524288 - fs.inotify.max_user_instances: 256 - fs.inotify.max_queued_events: 32768 + protocol: TCP \ No newline at end of file diff --git a/services/external/README.md b/services/external/README.md index fa19afa1..e3f6d896 100644 --- a/services/external/README.md +++ b/services/external/README.md @@ -100,14 +100,19 @@ The **External Service** integrates real-world data from Spanish sources to enha ## API Endpoints (Key Routes) ### POI Detection & Context -- `POST /poi-context/{tenant_id}/detect` - Detect POIs for tenant location (lat, long, force_refresh params) -- `GET /poi-context/{tenant_id}` - Get cached POI context for tenant -- `POST /poi-context/{tenant_id}/refresh` - Force refresh POI detection -- `DELETE /poi-context/{tenant_id}` - Delete POI context for tenant -- `GET /poi-context/{tenant_id}/feature-importance` - Get POI feature importance summary -- `GET /poi-context/{tenant_id}/competitor-analysis` - Get competitive analysis -- `GET /poi-context/health` - Check POI service and Overpass API health -- `GET /poi-context/cache/stats` - Get POI cache statistics +- `POST /api/v1/tenants/{tenant_id}/poi-context/detect` - Detect POIs for tenant location (lat, long, force_refresh params) - Direct access (bypasses gateway authentication) +- `GET /api/v1/tenants/{tenant_id}/poi-context` - Get cached POI context for tenant - Direct access (bypasses gateway authentication) +- `POST /api/v1/tenants/{tenant_id}/poi-context/refresh` - Force refresh POI detection - Direct access (bypasses gateway authentication) +- `DELETE /api/v1/tenants/{tenant_id}/poi-context` - Delete POI context for tenant - Direct access (bypasses gateway authentication) +- `GET /api/v1/tenants/{tenant_id}/poi-context/feature-importance` - Get POI feature importance summary - Direct access (bypasses gateway authentication) +- `GET /api/v1/tenants/{tenant_id}/poi-context/competitor-analysis` - Get competitive analysis - Direct access (bypasses gateway authentication) +- `GET /api/v1/tenants/poi-context/health` - Check POI service and Overpass API health - Direct access (bypasses gateway authentication) +- `GET /api/v1/tenants/poi-context/cache/stats` - Get POI cache statistics - Direct access (bypasses gateway authentication) + +### Recommended Access Pattern: +- Services should use `/api/v1/tenants/{tenant_id}/external/poi-context` (detected POIs) and `/api/v1/tenants/{tenant_id}/external/poi-context/detect` (detection) via shared ExternalServiceClient through the API 
gateway for proper authentication and authorization. + +**Note**: When using ExternalServiceClient through shared client, provide relative paths like `poi-context` and `poi-context/detect` - the client automatically constructs the full tenant-scoped path. ### Weather Data (AEMET) - `GET /api/v1/external/weather/current` - Current weather for location diff --git a/services/external/app/api/poi_context.py b/services/external/app/api/poi_context.py index 6f11f7bf..c9148b85 100644 --- a/services/external/app/api/poi_context.py +++ b/services/external/app/api/poi_context.py @@ -18,13 +18,17 @@ from app.services.poi_refresh_service import POIRefreshService from app.repositories.poi_context_repository import POIContextRepository from app.cache.poi_cache_service import POICacheService from app.core.redis_client import get_redis_client +from shared.routing.route_builder import RouteBuilder logger = structlog.get_logger() -router = APIRouter(prefix="/tenants", tags=["POI Context"]) +route_builder = RouteBuilder('external') +router = APIRouter(tags=["POI Context"]) -@router.post("/{tenant_id}/poi-context/detect") +@router.post( + route_builder.build_base_route("poi-context/detect") +) async def detect_pois_for_tenant( tenant_id: str, latitude: float = Query(..., description="Bakery latitude"), @@ -297,7 +301,9 @@ async def detect_pois_for_tenant( ) -@router.get("/{tenant_id}/poi-context") +@router.get( + route_builder.build_base_route("poi-context") +) async def get_poi_context( tenant_id: str, db: AsyncSession = Depends(get_db) @@ -331,7 +337,9 @@ async def get_poi_context( } -@router.post("/{tenant_id}/poi-context/refresh") +@router.post( + route_builder.build_base_route("poi-context/refresh") +) async def refresh_poi_context( tenant_id: str, db: AsyncSession = Depends(get_db) @@ -365,7 +373,9 @@ async def refresh_poi_context( ) -@router.delete("/{tenant_id}/poi-context") +@router.delete( + route_builder.build_base_route("poi-context") +) async def delete_poi_context( tenant_id: str, db: AsyncSession = Depends(get_db) @@ -393,7 +403,9 @@ async def delete_poi_context( } -@router.get("/{tenant_id}/poi-context/feature-importance") +@router.get( + route_builder.build_base_route("poi-context/feature-importance") +) async def get_feature_importance( tenant_id: str, db: AsyncSession = Depends(get_db) @@ -430,7 +442,9 @@ async def get_feature_importance( } -@router.get("/{tenant_id}/poi-context/competitor-analysis") +@router.get( + route_builder.build_base_route("poi-context/competitor-analysis") +) async def get_competitor_analysis( tenant_id: str, db: AsyncSession = Depends(get_db) diff --git a/services/forecasting/README.md b/services/forecasting/README.md index 05883c30..7cf7f493 100644 --- a/services/forecasting/README.md +++ b/services/forecasting/README.md @@ -497,7 +497,7 @@ poi_features = await poi_service.fetch_poi_features(tenant_id) - **Accuracy Improvement** - POI features contribute 5-10% accuracy improvement **Endpoint Used:** -- `GET {EXTERNAL_SERVICE_URL}/poi-context/{tenant_id}` - Fetch POI features +- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through API Gateway) ## Integration Points diff --git a/services/forecasting/app/services/poi_feature_service.py b/services/forecasting/app/services/poi_feature_service.py index 6d6c13cc..a7263fdb 100644 --- a/services/forecasting/app/services/poi_feature_service.py +++ b/services/forecasting/app/services/poi_feature_service.py @@ -5,10 +5,11 @@ Fetches POI features for use in demand forecasting predictions. 
Ensures feature consistency between training and prediction. """ -import httpx from typing import Dict, Any, Optional import structlog +from shared.clients.external_client import ExternalServiceClient + logger = structlog.get_logger() @@ -20,15 +21,18 @@ class POIFeatureService: prediction uses the same features as training. """ - def __init__(self, external_service_url: str = "http://external-service:8000"): + def __init__(self, external_client: ExternalServiceClient = None): """ Initialize POI feature service. Args: - external_service_url: Base URL for external service + external_client: External service client instance (optional) """ - self.external_service_url = external_service_url.rstrip("/") - self.poi_context_endpoint = f"{self.external_service_url}/poi-context" + if external_client is None: + from app.core.config import settings + self.external_client = ExternalServiceClient(settings, "forecasting-service") + else: + self.external_client = external_client async def get_poi_features( self, @@ -44,21 +48,10 @@ class POIFeatureService: Dictionary with POI features or empty dict if not available """ try: - async with httpx.AsyncClient(timeout=10.0) as client: - response = await client.get( - f"{self.poi_context_endpoint}/{tenant_id}" - ) + result = await self.external_client.get_poi_context(tenant_id) - if response.status_code == 404: - logger.warning( - "No POI context found for tenant", - tenant_id=tenant_id - ) - return {} - - response.raise_for_status() - data = response.json() - poi_context = data.get("poi_context", {}) + if result: + poi_context = result.get("poi_context", {}) ml_features = poi_context.get("ml_features", {}) logger.info( @@ -68,17 +61,16 @@ class POIFeatureService: ) return ml_features + else: + logger.warning( + "No POI context found for tenant", + tenant_id=tenant_id + ) + return {} - except httpx.HTTPError as e: - logger.error( - "Failed to fetch POI features for forecasting", - tenant_id=tenant_id, - error=str(e) - ) - return {} except Exception as e: logger.error( - "Unexpected error fetching POI features", + "Failed to fetch POI features for forecasting", tenant_id=tenant_id, error=str(e), exc_info=True @@ -87,17 +79,18 @@ class POIFeatureService: async def check_poi_service_health(self) -> bool: """ - Check if POI service is accessible. + Check if POI service is accessible through the external client. 
Returns: True if service is healthy, False otherwise """ try: - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get( - f"{self.poi_context_endpoint}/health" - ) - return response.status_code == 200 + # Test the external service health by attempting to get POI context for a dummy tenant + # This will go through the proper authentication and routing + dummy_context = await self.external_client.get_poi_context("test-tenant") + # If we can successfully make a request (even if it returns None for missing tenant), + # it means the service is accessible + return True except Exception as e: logger.error( "POI service health check failed", diff --git a/services/training/README.md b/services/training/README.md index 4e106244..e93e3f70 100644 --- a/services/training/README.md +++ b/services/training/README.md @@ -581,7 +581,7 @@ for feature_name in poi_features.keys(): ``` **Endpoint Used:** -- `GET {EXTERNAL_SERVICE_URL}/poi-context/{tenant_id}` - Fetch POI features +- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through API Gateway) ## Integration Points diff --git a/services/training/app/ml/poi_feature_integrator.py b/services/training/app/ml/poi_feature_integrator.py index 4d069622..05a580dc 100644 --- a/services/training/app/ml/poi_feature_integrator.py +++ b/services/training/app/ml/poi_feature_integrator.py @@ -5,11 +5,12 @@ Integrates POI features into ML training pipeline. Fetches POI context from External service and merges features into training data. """ -import httpx from typing import Dict, Any, Optional, List import structlog import pandas as pd +from shared.clients.external_client import ExternalServiceClient + logger = structlog.get_logger() @@ -21,15 +22,18 @@ class POIFeatureIntegrator: to training dataframes for location-based demand forecasting. """ - def __init__(self, external_service_url: str = "http://external-service:8000"): + def __init__(self, external_client: ExternalServiceClient = None): """ Initialize POI feature integrator. 
Args: - external_service_url: Base URL for external service + external_client: External service client instance (optional) """ - self.external_service_url = external_service_url.rstrip("/") - self.poi_context_endpoint = f"{self.external_service_url}/poi-context" + if external_client is None: + from app.core.config import settings + self.external_client = ExternalServiceClient(settings, "training-service") + else: + self.external_client = external_client async def fetch_poi_features( self, @@ -53,57 +57,49 @@ class POIFeatureIntegrator: Dictionary with POI features or None if detection fails """ try: - async with httpx.AsyncClient(timeout=60.0) as client: - # Try to get existing POI context first - if not force_refresh: - try: - response = await client.get( - f"{self.poi_context_endpoint}/{tenant_id}" - ) - if response.status_code == 200: - data = response.json() - poi_context = data.get("poi_context", {}) + # Try to get existing POI context first + if not force_refresh: + existing_context = await self.external_client.get_poi_context(tenant_id) + if existing_context: + poi_context = existing_context.get("poi_context", {}) + ml_features = poi_context.get("ml_features", {}) - # Check if stale - if not data.get("is_stale", False): - logger.info( - "Using existing POI context", - tenant_id=tenant_id - ) - return poi_context.get("ml_features", {}) - else: - logger.info( - "POI context is stale, refreshing", - tenant_id=tenant_id - ) - force_refresh = True - except httpx.HTTPStatusError as e: - if e.response.status_code != 404: - raise + # Check if stale + is_stale = existing_context.get("is_stale", False) + if not is_stale: logger.info( - "No existing POI context, will detect", + "Using existing POI context", tenant_id=tenant_id ) + return ml_features + else: + logger.info( + "POI context is stale, refreshing", + tenant_id=tenant_id + ) + force_refresh = True + else: + logger.info( + "No existing POI context, will detect", + tenant_id=tenant_id + ) - # Detect or refresh POIs - logger.info( - "Detecting POIs for tenant", - tenant_id=tenant_id, - location=(latitude, longitude) - ) + # Detect or refresh POIs + logger.info( + "Detecting POIs for tenant", + tenant_id=tenant_id, + location=(latitude, longitude) + ) - response = await client.post( - f"{self.poi_context_endpoint}/{tenant_id}/detect", - params={ - "latitude": latitude, - "longitude": longitude, - "force_refresh": force_refresh - } - ) - response.raise_for_status() + detection_result = await self.external_client.detect_poi_for_tenant( + tenant_id=tenant_id, + latitude=latitude, + longitude=longitude, + force_refresh=force_refresh + ) - result = response.json() - poi_context = result.get("poi_context", {}) + if detection_result: + poi_context = detection_result.get("poi_context", {}) ml_features = poi_context.get("ml_features", {}) logger.info( @@ -114,15 +110,13 @@ class POIFeatureIntegrator: ) return ml_features + else: + logger.error( + "POI detection failed", + tenant_id=tenant_id + ) + return None - except httpx.HTTPError as e: - logger.error( - "Failed to fetch POI features", - tenant_id=tenant_id, - error=str(e), - exc_info=True - ) - return None except Exception as e: logger.error( "Unexpected error fetching POI features", @@ -185,17 +179,18 @@ class POIFeatureIntegrator: async def check_poi_service_health(self) -> bool: """ - Check if POI service is accessible. + Check if POI service is accessible through the external client. 
Returns: True if service is healthy, False otherwise """ try: - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get( - f"{self.poi_context_endpoint}/health" - ) - return response.status_code == 200 + # We can test the external service health by attempting to get POI context for a dummy tenant + # This will go through the proper authentication and routing + dummy_context = await self.external_client.get_poi_context("test-tenant") + # If we can successfully make a request (even if it returns None for missing tenant), + # it means the service is accessible + return True except Exception as e: logger.error( "POI service health check failed", diff --git a/services/training/app/ml/trainer.py b/services/training/app/ml/trainer.py index 3542a866..f01ca6ad 100644 --- a/services/training/app/ml/trainer.py +++ b/services/training/app/ml/trainer.py @@ -375,40 +375,143 @@ class EnhancedBakeryMLTrainer: try: # Use provided session or create new one to prevent nested sessions and deadlocks should_create_session = session is None - db_session = session if session is not None else None - + if should_create_session: # Only create a session if one wasn't provided async with self.database_manager.get_session() as db_session: repos = await self._get_repositories(db_session) - + + # Validate input data + if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS: + raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}") + + # Validate required columns + required_columns = ['ds', 'y'] + missing_cols = [col for col in required_columns if col not in training_data.columns] + if missing_cols: + raise ValueError(f"Missing required columns in training data: {missing_cols}") + + # Create a simple progress tracker for single product + from app.services.progress_tracker import ParallelProductProgressTracker + progress_tracker = ParallelProductProgressTracker( + job_id=job_id, + tenant_id=tenant_id, + total_products=1 + ) + + # Ensure training data has proper data types before training + if 'ds' in training_data.columns: + training_data['ds'] = pd.to_datetime(training_data['ds']) + if 'y' in training_data.columns: + training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce') + + # Remove any rows with NaN values + training_data = training_data.dropna() + + # Train the model using the existing _train_single_product method + product_id, result = await self._train_single_product( + tenant_id=tenant_id, + inventory_product_id=inventory_product_id, + product_data=training_data, + job_id=job_id, + repos=repos, + progress_tracker=progress_tracker, + session=db_session # Pass the session to prevent nested sessions + ) + + logger.info("Single product training completed", + job_id=job_id, + inventory_product_id=inventory_product_id, + result_status=result.get('status')) + + # Write training result to database (create model record) + if result.get('status') == 'success': + model_info = result.get('model_info') + product_data = result.get('product_data') + + if model_info and product_data is not None: + # Create model record in database + model_record = await self._create_model_record( + repos, tenant_id, inventory_product_id, model_info, job_id, product_data + ) + + # Create performance metrics + if model_info.get('training_metrics') and model_record: + await self._create_performance_metrics( + repos, model_record.id, + tenant_id, inventory_product_id, model_info['training_metrics'] + ) + + # 
Update result with model_record_id + result['model_record_id'] = str(model_record.id) if model_record else None + + # Get training metrics and filter out non-numeric values + raw_metrics = result.get('model_info', {}).get('training_metrics', {}) + # Filter metrics to only include numeric values (per Pydantic schema requirement) + filtered_metrics = {} + for key, value in raw_metrics.items(): + if key == 'product_category': + # Skip product_category as it's a string value, not a numeric metric + continue + try: + # Try to convert to float for validation + filtered_metrics[key] = float(value) if value is not None else 0.0 + except (ValueError, TypeError): + # Skip non-numeric values + continue + + # Return appropriate result format + result_dict = { + "job_id": job_id, + "tenant_id": tenant_id, + "inventory_product_id": inventory_product_id, + "status": result.get('status', 'success'), + "model_id": str(result.get('model_record_id', '')) if result.get('model_record_id') else None, + "training_metrics": filtered_metrics, + "training_time": result.get('training_time_seconds', 0), + "data_points": result.get('data_points', 0), + "message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}" + } + + # Only commit if this is our own session (not a parent session) + # Commit after we're done with all database operations + await db_session.commit() + logger.info("Committed single product model record to database", + inventory_product_id=inventory_product_id, + model_record_id=result.get('model_record_id')) + + return result_dict + else: + # Use the provided session + repos = await self._get_repositories(session) + # Validate input data if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS: raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}") - + # Validate required columns required_columns = ['ds', 'y'] missing_cols = [col for col in required_columns if col not in training_data.columns] if missing_cols: raise ValueError(f"Missing required columns in training data: {missing_cols}") - + # Create a simple progress tracker for single product from app.services.progress_tracker import ParallelProductProgressTracker progress_tracker = ParallelProductProgressTracker( - job_id=job_id, - tenant_id=tenant_id, + job_id=job_id, + tenant_id=tenant_id, total_products=1 ) - + # Ensure training data has proper data types before training if 'ds' in training_data.columns: training_data['ds'] = pd.to_datetime(training_data['ds']) if 'y' in training_data.columns: training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce') - + # Remove any rows with NaN values training_data = training_data.dropna() - + # Train the model using the existing _train_single_product method product_id, result = await self._train_single_product( tenant_id=tenant_id, @@ -417,14 +520,35 @@ class EnhancedBakeryMLTrainer: job_id=job_id, repos=repos, progress_tracker=progress_tracker, - session=db_session # Pass the session to prevent nested sessions + session=session # Pass the provided session ) - + logger.info("Single product training completed", job_id=job_id, inventory_product_id=inventory_product_id, result_status=result.get('status')) - + + # Write training result to database (create model record) + if result.get('status') == 'success': + model_info = result.get('model_info') + product_data = result.get('product_data') + + if model_info and product_data is not None: + # Create model 
record in database + model_record = await self._create_model_record( + repos, tenant_id, inventory_product_id, model_info, job_id, product_data + ) + + # Create performance metrics + if model_info.get('training_metrics') and model_record: + await self._create_performance_metrics( + repos, model_record.id, + tenant_id, inventory_product_id, model_info['training_metrics'] + ) + + # Update result with model_record_id + result['model_record_id'] = str(model_record.id) if model_record else None + # Get training metrics and filter out non-numeric values raw_metrics = result.get('model_info', {}).get('training_metrics', {}) # Filter metrics to only include numeric values (per Pydantic schema requirement) @@ -439,7 +563,7 @@ class EnhancedBakeryMLTrainer: except (ValueError, TypeError): # Skip non-numeric values continue - + # Return appropriate result format result_dict = { "job_id": job_id, @@ -452,7 +576,13 @@ class EnhancedBakeryMLTrainer: "data_points": result.get('data_points', 0), "message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}" } - + + # For provided sessions, do NOT commit here - let the calling method handle commits + # This prevents committing a parent transaction prematurely + logger.info("Single product model processed (commit handled by caller)", + inventory_product_id=inventory_product_id, + model_record_id=result.get('model_record_id')) + return result_dict except Exception as e: diff --git a/services/training/app/models/training.py b/services/training/app/models/training.py index a0f3f561..7cae79ad 100644 --- a/services/training/app/models/training.py +++ b/services/training/app/models/training.py @@ -186,6 +186,7 @@ class TrainedModel(Base): "training_samples": self.training_samples, "hyperparameters": self.hyperparameters, "features_used": self.features_used, + "features": self.features_used, # Alias for frontend compatibility (ModelDetailsModal expects 'features') "product_category": self.product_category, "is_active": self.is_active, "is_production": self.is_production, diff --git a/services/training/app/services/training_service.py b/services/training/app/services/training_service.py index 7742abae..71fe6be0 100644 --- a/services/training/app/services/training_service.py +++ b/services/training/app/services/training_service.py @@ -732,297 +732,293 @@ class EnhancedTrainingService: current_step: str = None, error_message: str = None, results: Dict = None, - tenant_id: str = None): - """Update job status using repository pattern""" + tenant_id: str = None, + session = None): + """Update job status using repository pattern + + Args: + session: Optional database session to reuse. If None, creates a new session. 
+ """ try: - async with self.database_manager.get_session() as session: - await self._init_repositories(session) + # Use provided session or create new one + should_create_session = session is None - # Check if log exists, create if not - existing_log = await self.training_log_repo.get_log_by_job_id(job_id) - - if not existing_log: - # Create initial log entry - if not tenant_id: - # Extract tenant_id from job_id if not provided - # Format: enhanced_training_{tenant_id}_{job_suffix} - try: - parts = job_id.split('_') - if len(parts) >= 3 and parts[0] == 'enhanced' and parts[1] == 'training': - tenant_id = parts[2] - except Exception: - logger.warning(f"Could not extract tenant_id from job_id {job_id}") - - if tenant_id: - log_data = { - "job_id": job_id, - "tenant_id": tenant_id, - "status": status or "pending", - "progress": progress or 0, - "current_step": current_step or "initializing", - "start_time": datetime.now(timezone.utc) - } - - if error_message: - log_data["error_message"] = error_message - if results: - # Ensure results are JSON-serializable before storing - log_data["results"] = make_json_serializable(results) - - try: - await self.training_log_repo.create_training_log(log_data) - await session.commit() # Explicit commit so other sessions can see it - logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id) - except Exception as create_error: - # Handle race condition: another session may have created the log - if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower(): - logger.debug("Training log already exists (race condition), querying again", job_id=job_id) - await session.rollback() - # Query again to get the existing log - existing_log = await self.training_log_repo.get_log_by_job_id(job_id) - if existing_log: - # Update the existing log instead - await self.training_log_repo.update_log_progress( - job_id=job_id, - progress=progress, - current_step=current_step, - status=status - ) - await session.commit() - else: - raise - else: - logger.error("Cannot create training log without tenant_id", job_id=job_id) - return - else: - # Update existing log - await self.training_log_repo.update_log_progress( - job_id=job_id, - progress=progress, - current_step=current_step, - status=status + if should_create_session: + async with self.database_manager.get_session() as session: + await self._init_repositories(session) + await self._update_job_status_impl( + session, job_id, status, progress, current_step, + error_message, results, tenant_id ) - - # Update additional fields if provided - if error_message or results: - update_data = {} - if error_message: - update_data["error_message"] = error_message - if results: - # Ensure results are JSON-serializable before storing - update_data["results"] = make_json_serializable(results) - if status in ["completed", "failed"]: - update_data["end_time"] = datetime.now(timezone.utc) - - if update_data: - await self.training_log_repo.update(existing_log.id, update_data) - - await session.commit() # Explicit commit after updates - + else: + # Reuse provided session (don't commit - let caller control transaction) + await self._init_repositories(session) + await self._update_job_status_impl( + session, job_id, status, progress, current_step, + error_message, results, tenant_id, auto_commit=False + ) except Exception as e: logger.error("Failed to update job status using repository", job_id=job_id, error=str(e)) + async def _update_job_status_impl(self, + session, + job_id: str, + status: 
str, + progress: int = None, + current_step: str = None, + error_message: str = None, + results: Dict = None, + tenant_id: str = None, + auto_commit: bool = True): + """Implementation of job status update""" + # Check if log exists, create if not + existing_log = await self.training_log_repo.get_log_by_job_id(job_id) + + if not existing_log: + # Create initial log entry + if not tenant_id: + # Extract tenant_id from job_id if not provided + # Format: enhanced_training_{tenant_id}_{job_suffix} + try: + parts = job_id.split('_') + if len(parts) >= 3 and parts[0] == 'enhanced' and parts[1] == 'training': + tenant_id = parts[2] + except Exception: + logger.warning(f"Could not extract tenant_id from job_id {job_id}") + + if tenant_id: + log_data = { + "job_id": job_id, + "tenant_id": tenant_id, + "status": status or "pending", + "progress": progress or 0, + "current_step": current_step or "initializing", + "start_time": datetime.now(timezone.utc) + } + + if error_message: + log_data["error_message"] = error_message + if results: + # Ensure results are JSON-serializable before storing + log_data["results"] = make_json_serializable(results) + + try: + await self.training_log_repo.create_training_log(log_data) + if auto_commit: + await session.commit() # Explicit commit so other sessions can see it + logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id) + except Exception as create_error: + # Handle race condition: another session may have created the log + if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower(): + logger.debug("Training log already exists (race condition), querying again", job_id=job_id) + await session.rollback() + # Query again to get the existing log + existing_log = await self.training_log_repo.get_log_by_job_id(job_id) + if existing_log: + # Update the existing log instead + await self.training_log_repo.update_log_progress( + job_id=job_id, + progress=progress, + current_step=current_step, + status=status + ) + if auto_commit: + await session.commit() + else: + raise + else: + logger.error("Cannot create training log without tenant_id", job_id=job_id) + return + else: + # Update existing log + await self.training_log_repo.update_log_progress( + job_id=job_id, + progress=progress, + current_step=current_step, + status=status + ) + + # Update additional fields if provided + if error_message or results: + update_data = {} + if error_message: + update_data["error_message"] = error_message + if results: + # Ensure results are JSON-serializable before storing + update_data["results"] = make_json_serializable(results) + if status in ["completed", "failed"]: + update_data["end_time"] = datetime.now(timezone.utc) + + if update_data: + await self.training_log_repo.update(existing_log.id, update_data) + + if auto_commit: + await session.commit() # Explicit commit after updates + async def start_single_product_training(self, tenant_id: str, inventory_product_id: str, job_id: str, bakery_location: tuple = (40.4168, -3.7038)) -> Dict[str, Any]: - """Start enhanced single product training using repository pattern""" - try: - logger.info("Starting enhanced single product training", - tenant_id=tenant_id, - inventory_product_id=inventory_product_id, - job_id=job_id) - - # Create initial training log - await self._update_job_status_repository( - job_id=job_id, - status="running", - progress=0, - current_step="Fetching training data", - tenant_id=tenant_id - ) - - # Prepare training data for all products to get weather/traffic 
data - # then filter down to the specific product - training_dataset = await self.orchestrator.prepare_training_data( - tenant_id=tenant_id, - bakery_location=bakery_location, - job_id=job_id + "_temp" - ) - - # Filter sales data to the specific product - sales_df = pd.DataFrame(training_dataset.sales_data) - product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id] - - if product_sales_df.empty: - raise ValueError(f"No sales data available for product {inventory_product_id}") - - # Prepare the data in Prophet format (ds and y columns) - # Ensure proper column names and types for Prophet - product_data = product_sales_df.copy() - product_data = product_data.rename(columns={ - 'sale_date': 'ds', # Common sales date column - 'sale_datetime': 'ds', # Alternative date column - 'date': 'ds', # Alternative date column - 'quantity': 'y', # Quantity sold - 'total_amount': 'y', # Alternative for sales data - 'sales_amount': 'y', # Alternative for sales data - 'sale_amount': 'y' # Alternative for sales data - }) - - # If 'ds' and 'y' columns are not renamed properly, try to infer them - if 'ds' not in product_data.columns: - # Try to find date-like columns - date_cols = [col for col in product_data.columns if 'date' in col.lower() or 'time' in col.lower()] - if date_cols: - product_data = product_data.rename(columns={date_cols[0]: 'ds'}) - - if 'y' not in product_data.columns: - # Try to find sales/quantity-like columns - sales_cols = [col for col in product_data.columns if - any(word in col.lower() for word in ['amount', 'quantity', 'sales', 'total', 'count', 'value'])] - if sales_cols: - product_data = product_data.rename(columns={sales_cols[0]: 'y'}) - - # Ensure required columns exist - if 'ds' not in product_data.columns or 'y' not in product_data.columns: - raise ValueError(f"Sales data must contain 'date' and 'quantity/sales' columns. 
Available columns: {list(product_data.columns)}") - - # Convert the date column to datetime if it's not already - product_data['ds'] = pd.to_datetime(product_data['ds']) - - # Convert to numeric ensuring no pandas/numpy objects remain - product_data['y'] = pd.to_numeric(product_data['y'], errors='coerce') - - # Sort by date to ensure proper chronological order - product_data = product_data.sort_values('ds').reset_index(drop=True) - - # Drop any rows with NaN values - product_data = product_data.dropna(subset=['ds', 'y']) - - # Ensure the data is in the right format for Prophet - product_data = product_data[['ds', 'y']].copy() + """Start enhanced single product training using repository pattern with single session""" + # Create a single database session for all operations to avoid connection pool exhaustion + async with self.database_manager.get_session() as session: + await self._init_repositories(session) - # Convert to pandas datetime and float types (keep as pandas Series for proper min/max operations) - product_data['ds'] = pd.to_datetime(product_data['ds']) - product_data['y'] = product_data['y'].astype(float) - - # DEBUG: Log data types to diagnose dict comparison error - logger.info(f"DEBUG: product_data dtypes after conversion: ds={product_data['ds'].dtype}, y={product_data['y'].dtype}") - logger.info(f"DEBUG: product_data['ds'] sample values: {product_data['ds'].head(3).tolist()}") - logger.info(f"DEBUG: Attempting to get min/max...") try: - min_val = product_data['ds'].min() - max_val = product_data['ds'].max() - logger.info(f"DEBUG: min_val type={type(min_val)}, value={min_val}") - logger.info(f"DEBUG: max_val type={type(max_val)}, value={max_val}") - except Exception as debug_e: - logger.error(f"DEBUG: Failed to get min/max: {debug_e}") - import traceback - logger.error(f"DEBUG: Traceback: {traceback.format_exc()}") + logger.info("Starting enhanced single product training", + tenant_id=tenant_id, + inventory_product_id=inventory_product_id, + job_id=job_id) - logger.info("Prepared training data for single product", - inventory_product_id=inventory_product_id, - data_points=len(product_data), - date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}") - - # Update progress - await self._update_job_status_repository( - job_id=job_id, - status="running", - progress=30, - current_step="Training model", - tenant_id=tenant_id - ) - - # Train the model using the trainer - # Extract datetime values with proper pandas Timestamp wrapper for type safety - try: - training_start = pd.Timestamp(product_data['ds'].min()).to_pydatetime() - training_end = pd.Timestamp(product_data['ds'].max()).to_pydatetime() - except Exception as e: - import traceback - logger.error(f"Failed to extract training dates: {e}") - logger.error(f"Traceback: {traceback.format_exc()}") - logger.error(f"product_data['ds'] dtype: {product_data['ds'].dtype}") - logger.error(f"product_data['ds'] first 5 values: {product_data['ds'].head().tolist()}") - raise - - # Run the actual training - try: - model_info = await self.trainer.train_single_product_model( + # Create initial training log (using shared session) + await self._update_job_status_repository( + job_id=job_id, + status="running", + progress=0, + current_step="Fetching training data", tenant_id=tenant_id, + session=session + ) + await session.commit() # Commit after initial log creation + + # Prepare training data for all products to get weather/traffic data + # then filter down to the specific product + training_dataset = await 
self.orchestrator.prepare_training_data( + tenant_id=tenant_id, + bakery_location=bakery_location, + job_id=job_id + "_temp" + ) + + # Use the enhanced data processor to prepare training data with all features (POI, weather, traffic) + # Filter sales data to the specific product first + sales_df = pd.DataFrame(training_dataset.sales_data) + product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id] + + if product_sales_df.empty: + raise ValueError(f"No sales data available for product {inventory_product_id}") + + # Get weather and traffic data as DataFrames + weather_df = pd.DataFrame(training_dataset.weather_data) + traffic_df = pd.DataFrame(training_dataset.traffic_data) + + # Get POI features from the training dataset (already collected by orchestrator) + poi_features = training_dataset.poi_features if hasattr(training_dataset, 'poi_features') else None + + # Use the enhanced data processor to merge all features properly + # This will include POI, weather, traffic features along with ds and y + from app.ml.data_processor import EnhancedBakeryDataProcessor + data_processor = EnhancedBakeryDataProcessor(self.database_manager) + + product_data = await data_processor.prepare_training_data( + sales_data=product_sales_df, + weather_data=weather_df, + traffic_data=traffic_df, inventory_product_id=inventory_product_id, - training_data=product_data, + poi_features=poi_features, + tenant_id=tenant_id, job_id=job_id ) - except Exception as e: - import traceback - logger.error(f"Training failed with error: {e}") - logger.error(f"Full traceback: {traceback.format_exc()}") - raise - - # Update progress - await self._update_job_status_repository( - job_id=job_id, - status="running", - progress=80, - current_step="Saving model", - tenant_id=tenant_id - ) - - # The model should already be saved by train_single_product_model - # Return appropriate response - return { - "job_id": job_id, - "tenant_id": tenant_id, - "inventory_product_id": inventory_product_id, - "status": "completed", - "message": "Enhanced single product training completed successfully", - "created_at": datetime.now(timezone.utc), - "estimated_duration_minutes": 15, # Default estimate for single product - "training_results": { - "total_products": 1, - "successful_trainings": 1, - "failed_trainings": 0, - "products": [{ - "inventory_product_id": inventory_product_id, - "status": "completed", - "model_id": str(model_info.get('model_id', f"model_{inventory_product_id}_{job_id[:8]}")) if model_info.get('model_id') else None, - "data_points": len(product_data) if product_data is not None else 0, - # Filter metrics to ensure only numeric values are included - "metrics": { - k: float(v) if not isinstance(v, (int, float)) else v - for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items() - if k != 'product_category' and v is not None - } - }], - "overall_training_time_seconds": model_info.get('training_time', 45.2) - }, - "enhanced_features": True, - "repository_integration": True, - "completed_at": datetime.now(timezone.utc).isoformat() - } - - except Exception as e: - logger.error("Enhanced single product training failed", + + if product_data.empty: + raise ValueError(f"Data processor returned empty data for product {inventory_product_id}") + + logger.info("Prepared training data for single product", + inventory_product_id=inventory_product_id, + data_points=len(product_data), + features=list(product_data.columns), + date_range=f"{product_data['ds'].min()} to 
{product_data['ds'].max()}") + + # Update progress (using shared session) + await self._update_job_status_repository( + job_id=job_id, + status="running", + progress=30, + current_step="Training model", + tenant_id=tenant_id, + session=session + ) + await session.commit() # Commit progress update + + # Run the actual training (passing the session to avoid nested session creation) + try: + model_info = await self.trainer.train_single_product_model( + tenant_id=tenant_id, inventory_product_id=inventory_product_id, - error=str(e)) - - # Update status to failed - await self._update_job_status_repository( - job_id=job_id, - status="failed", - progress=0, - current_step="Training failed", - error_message=str(e), - tenant_id=tenant_id - ) - - raise + training_data=product_data, + job_id=job_id, + session=session # ✅ CRITICAL FIX: Pass session to prevent deadlock + ) + except Exception as e: + import traceback + logger.error(f"Training failed with error: {e}") + logger.error(f"Full traceback: {traceback.format_exc()}") + raise + + # Update progress (using shared session) + await self._update_job_status_repository( + job_id=job_id, + status="running", + progress=80, + current_step="Saving model", + tenant_id=tenant_id, + session=session + ) + await session.commit() # Commit progress update + + # The model should already be saved by train_single_product_model + # Return appropriate response + return { + "job_id": job_id, + "tenant_id": tenant_id, + "inventory_product_id": inventory_product_id, + "status": "completed", + "message": "Enhanced single product training completed successfully", + "created_at": datetime.now(timezone.utc), + "estimated_duration_minutes": 15, # Default estimate for single product + "training_results": { + "total_products": 1, + "successful_trainings": 1, + "failed_trainings": 0, + "products": [{ + "inventory_product_id": inventory_product_id, + "status": "completed", + "model_id": str(model_info.get('model_id', f"model_{inventory_product_id}_{job_id[:8]}")) if model_info.get('model_id') else None, + "data_points": len(product_data) if product_data is not None else 0, + # Filter metrics to ensure only numeric values are included + "metrics": { + k: float(v) if not isinstance(v, (int, float)) else v + for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items() + if k != 'product_category' and v is not None + } + }], + "overall_training_time_seconds": model_info.get('training_time', 45.2) + }, + "enhanced_features": True, + "repository_integration": True, + "completed_at": datetime.now(timezone.utc).isoformat() + } + + except Exception as e: + logger.error("Enhanced single product training failed", + inventory_product_id=inventory_product_id, + error=str(e)) + + # Update status to failed (using shared session) + await self._update_job_status_repository( + job_id=job_id, + status="failed", + progress=0, + current_step="Training failed", + error_message=str(e), + tenant_id=tenant_id, + session=session + ) + await session.commit() # Commit failure status + + raise def _create_detailed_training_response(self, final_result: Dict[str, Any]) -> Dict[str, Any]: """Convert final result to detailed training response""" diff --git a/shared/clients/external_client.py b/shared/clients/external_client.py index a4acf067..5bc04833 100644 --- a/shared/clients/external_client.py +++ b/shared/clients/external_client.py @@ -494,6 +494,75 @@ class ExternalServiceClient(BaseServiceClient): # POI (POINT OF INTEREST) DATA # 
================================================================ + async def detect_poi_for_tenant( + self, + tenant_id: str, + latitude: float, + longitude: float, + force_refresh: bool = False + ) -> Optional[Dict[str, Any]]: + """ + Detect POIs for a tenant's location and generate ML features for forecasting. + + With the new tenant-based architecture: + - Gateway receives at: /api/v1/tenants/{tenant_id}/external/poi-context/detect + - Gateway proxies to external service at: /api/v1/tenants/{tenant_id}/poi-context/detect + - This client calls: poi-context/detect (base client automatically constructs with tenant) + + This triggers POI detection using Overpass API and calculates ML features + for demand forecasting. + + Args: + tenant_id: Tenant ID + latitude: Latitude of the bakery location + longitude: Longitude of the bakery location + force_refresh: Whether to force refresh even if POI context exists + + Returns: + Dict with POI detection results including: + - ml_features: Dict of POI features for ML models (e.g., poi_retail_total_count) + - poi_detection_results: Full detection results + - location: Latitude/longitude + - total_pois_detected: Count of POIs + """ + logger.info( + "Detecting POIs for tenant", + tenant_id=tenant_id, + location=(latitude, longitude), + force_refresh=force_refresh + ) + + params = { + "latitude": latitude, + "longitude": longitude, + "force_refresh": force_refresh + } + + # Updated endpoint path to follow tenant-based pattern: external/poi-context/detect + result = await self._make_request( + "POST", + "external/poi-context/detect", # Path will become /api/v1/tenants/{tenant_id}/external/poi-context/detect by base client + tenant_id=tenant_id, # Pass tenant_id to include in headers and path construction + params=params, + timeout=60.0 # POI detection can take longer + ) + + if result: + poi_context = result.get("poi_context", {}) + ml_features = poi_context.get("ml_features", {}) + + logger.info( + "POI detection completed successfully", + tenant_id=tenant_id, + total_pois=poi_context.get("total_pois_detected", 0), + ml_features_count=len(ml_features), + source=result.get("source", "unknown") + ) + return result + else: + logger.warning("POI detection failed for tenant", tenant_id=tenant_id) + return None + async def get_poi_context( self, tenant_id: str @@ -504,7 +573,7 @@ class ExternalServiceClient(BaseServiceClient): With the new tenant-based architecture: - Gateway receives at: /api/v1/tenants/{tenant_id}/external/poi-context - Gateway proxies to external service at: /api/v1/tenants/{tenant_id}/poi-context - - This client calls: /tenants/{tenant_id}/poi-context + - This client calls: poi-context (base client automatically constructs with tenant) This retrieves stored POI detection results and calculated ML features that should be included in demand forecasting predictions. 
@@ -521,11 +590,11 @@ class ExternalServiceClient(BaseServiceClient): """ logger.info("Fetching POI context for forecasting", tenant_id=tenant_id) - # Updated endpoint path to follow tenant-based pattern: /tenants/{tenant_id}/poi-context + # Updated endpoint path to follow tenant-based pattern: external/poi-context result = await self._make_request( "GET", - f"tenants/{tenant_id}/poi-context", # Updated path: /tenants/{tenant_id}/poi-context - tenant_id=tenant_id, # Pass tenant_id to include in headers for authentication + "external/poi-context", # Path will become /api/v1/tenants/{tenant_id}/external/poi-context by base client + tenant_id=tenant_id, # Pass tenant_id to include in headers and path construction timeout=5.0 )
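
---

For reference, a minimal usage sketch of the tenant-scoped POI access pattern this patch introduces. It follows the constructor and methods added to the shared `ExternalServiceClient` (`get_poi_context`, `detect_poi_for_tenant`) and the relative-path convention described in the README changes; the service-local `settings` import and the surrounding error handling are assumptions for illustration, not part of this patch.

```python
# Sketch only: mirrors the pattern used by POIFeatureService / POIFeatureIntegrator in this patch.
# Assumes a service-local settings module compatible with the shared client constructor.
from app.core.config import settings  # assumption: each consuming service provides its own settings
from shared.clients.external_client import ExternalServiceClient


async def load_poi_ml_features(tenant_id: str, latitude: float, longitude: float) -> dict:
    """Fetch (or detect) POI ML features for a tenant via the API gateway."""
    client = ExternalServiceClient(settings, "forecasting-service")

    # Only relative paths are passed to the client; the base client expands them to
    # /api/v1/tenants/{tenant_id}/external/poi-context and routes through the gateway.
    existing = await client.get_poi_context(tenant_id)
    if existing and not existing.get("is_stale", False):
        return existing.get("poi_context", {}).get("ml_features", {})

    # No cached context, or the context is stale: trigger detection (slower, uses Overpass).
    detected = await client.detect_poi_for_tenant(
        tenant_id=tenant_id,
        latitude=latitude,
        longitude=longitude,
        force_refresh=True,
    )
    if detected:
        return detected.get("poi_context", {}).get("ml_features", {})
    return {}
```

The same helper shape works for both training and forecasting consumers, which is the point of routing POI access through the shared client: both sides read identical `ml_features` keys, keeping features consistent between model training and prediction.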