Bug fixes for training

This commit is contained in:
Urtzi Alfaro
2025-11-14 20:27:39 +01:00
parent 71f9ca9d65
commit c349b845a6
11 changed files with 606 additions and 408 deletions

View File

@@ -41,8 +41,3 @@ nodes:
     - containerPort: 30800
       hostPort: 8000
       protocol: TCP
-  sysctls:
-    # Increase fs.inotify limits to prevent "too many open files" errors
-    fs.inotify.max_user_watches: 524288
-    fs.inotify.max_user_instances: 256
-    fs.inotify.max_queued_events: 32768

View File

@@ -100,14 +100,19 @@ The **External Service** integrates real-world data from Spanish sources to enha
 ## API Endpoints (Key Routes)

 ### POI Detection & Context
-- `POST /poi-context/{tenant_id}/detect` - Detect POIs for tenant location (lat, long, force_refresh params)
-- `GET /poi-context/{tenant_id}` - Get cached POI context for tenant
-- `POST /poi-context/{tenant_id}/refresh` - Force refresh POI detection
-- `DELETE /poi-context/{tenant_id}` - Delete POI context for tenant
-- `GET /poi-context/{tenant_id}/feature-importance` - Get POI feature importance summary
-- `GET /poi-context/{tenant_id}/competitor-analysis` - Get competitive analysis
-- `GET /poi-context/health` - Check POI service and Overpass API health
-- `GET /poi-context/cache/stats` - Get POI cache statistics
+- `POST /api/v1/tenants/{tenant_id}/poi-context/detect` - Detect POIs for tenant location (lat, long, force_refresh params) - Direct access (bypasses gateway authentication)
+- `GET /api/v1/tenants/{tenant_id}/poi-context` - Get cached POI context for tenant - Direct access (bypasses gateway authentication)
+- `POST /api/v1/tenants/{tenant_id}/poi-context/refresh` - Force refresh POI detection - Direct access (bypasses gateway authentication)
+- `DELETE /api/v1/tenants/{tenant_id}/poi-context` - Delete POI context for tenant - Direct access (bypasses gateway authentication)
+- `GET /api/v1/tenants/{tenant_id}/poi-context/feature-importance` - Get POI feature importance summary - Direct access (bypasses gateway authentication)
+- `GET /api/v1/tenants/{tenant_id}/poi-context/competitor-analysis` - Get competitive analysis - Direct access (bypasses gateway authentication)
+- `GET /api/v1/tenants/poi-context/health` - Check POI service and Overpass API health - Direct access (bypasses gateway authentication)
+- `GET /api/v1/tenants/poi-context/cache/stats` - Get POI cache statistics - Direct access (bypasses gateway authentication)
+
+### Recommended Access Pattern:
+- Services should use `/api/v1/tenants/{tenant_id}/external/poi-context` (detected POIs) and `/api/v1/tenants/{tenant_id}/external/poi-context/detect` (detection) via shared ExternalServiceClient through the API gateway for proper authentication and authorization.
+
+**Note**: When using ExternalServiceClient through shared client, provide relative paths like `poi-context` and `poi-context/detect` - the client automatically constructs the full tenant-scoped path.

 ### Weather Data (AEMET)
 - `GET /api/v1/external/weather/current` - Current weather for location
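
Editor's note: a minimal usage sketch of the recommended access pattern described in the hunk above. The `settings` import path and the tenant id are assumptions, not part of this commit; `get_poi_context` is the shared-client method added later in the same commit.

```python
import asyncio

from shared.clients.external_client import ExternalServiceClient
from app.core.config import settings  # assumed config module, mirrors usage elsewhere in this commit


async def main() -> None:
    client = ExternalServiceClient(settings, "forecasting-service")
    # Only the relative resource is supplied; the shared client expands it to the
    # gateway path /api/v1/tenants/{tenant_id}/external/poi-context.
    poi_context = await client.get_poi_context("tenant-123")  # tenant id is illustrative
    print(poi_context)


asyncio.run(main())
```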

View File

@@ -18,13 +18,17 @@ from app.services.poi_refresh_service import POIRefreshService
 from app.repositories.poi_context_repository import POIContextRepository
 from app.cache.poi_cache_service import POICacheService
 from app.core.redis_client import get_redis_client
+from shared.routing.route_builder import RouteBuilder

 logger = structlog.get_logger()

-router = APIRouter(prefix="/tenants", tags=["POI Context"])
+route_builder = RouteBuilder('external')
+router = APIRouter(tags=["POI Context"])

-@router.post("/{tenant_id}/poi-context/detect")
+@router.post(
+    route_builder.build_base_route("poi-context/detect")
+)
 async def detect_pois_for_tenant(
     tenant_id: str,
     latitude: float = Query(..., description="Bakery latitude"),
@@ -297,7 +301,9 @@ async def detect_pois_for_tenant(
     )

-@router.get("/{tenant_id}/poi-context")
+@router.get(
+    route_builder.build_base_route("poi-context")
+)
 async def get_poi_context(
     tenant_id: str,
     db: AsyncSession = Depends(get_db)
@@ -331,7 +337,9 @@ async def get_poi_context(
     }

-@router.post("/{tenant_id}/poi-context/refresh")
+@router.post(
+    route_builder.build_base_route("poi-context/refresh")
+)
 async def refresh_poi_context(
     tenant_id: str,
     db: AsyncSession = Depends(get_db)
@@ -365,7 +373,9 @@ async def refresh_poi_context(
     )

-@router.delete("/{tenant_id}/poi-context")
+@router.delete(
+    route_builder.build_base_route("poi-context")
+)
 async def delete_poi_context(
     tenant_id: str,
     db: AsyncSession = Depends(get_db)
@@ -393,7 +403,9 @@ async def delete_poi_context(
     }

-@router.get("/{tenant_id}/poi-context/feature-importance")
+@router.get(
+    route_builder.build_base_route("poi-context/feature-importance")
+)
 async def get_feature_importance(
     tenant_id: str,
     db: AsyncSession = Depends(get_db)
@@ -430,7 +442,9 @@ async def get_feature_importance(
     }

-@router.get("/{tenant_id}/poi-context/competitor-analysis")
+@router.get(
+    route_builder.build_base_route("poi-context/competitor-analysis")
+)
 async def get_competitor_analysis(
     tenant_id: str,
     db: AsyncSession = Depends(get_db)
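
Editor's note: for readers unfamiliar with the shared `RouteBuilder`, here is an illustrative sketch of what a builder like this might do. The real implementation in `shared.routing.route_builder` is not part of this diff, so the class body below is an assumption; only the usage lines mirror the change.

```python
# Illustrative sketch only: the real shared.routing.route_builder.RouteBuilder is not shown
# in this commit, so this class body is an assumption.
class RouteBuilder:
    def __init__(self, service_name: str):
        self.service_name = service_name  # e.g. 'external'

    def build_base_route(self, suffix: str) -> str:
        # Assumed behaviour: expand a relative resource into a tenant-scoped path,
        # e.g. "poi-context/detect" -> "/tenants/{tenant_id}/poi-context/detect"
        # (any /api/v1 prefix would come from where the router is mounted).
        return f"/tenants/{{tenant_id}}/{suffix}"


# Usage pattern as introduced by this commit:
# route_builder = RouteBuilder('external')
#
# @router.post(route_builder.build_base_route("poi-context/detect"))
# async def detect_pois_for_tenant(tenant_id: str, ...): ...
```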

View File

@@ -497,7 +497,7 @@ poi_features = await poi_service.fetch_poi_features(tenant_id)
 - **Accuracy Improvement** - POI features contribute 5-10% accuracy improvement

 **Endpoint Used:**
-- `GET {EXTERNAL_SERVICE_URL}/poi-context/{tenant_id}` - Fetch POI features
+- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through API Gateway)

 ## Integration Points

View File

@@ -5,10 +5,11 @@ Fetches POI features for use in demand forecasting predictions.
 Ensures feature consistency between training and prediction.
 """

-import httpx
 from typing import Dict, Any, Optional
 import structlog
+from shared.clients.external_client import ExternalServiceClient

 logger = structlog.get_logger()
@@ -20,15 +21,18 @@ class POIFeatureService:
     prediction uses the same features as training.
     """

-    def __init__(self, external_service_url: str = "http://external-service:8000"):
+    def __init__(self, external_client: ExternalServiceClient = None):
         """
         Initialize POI feature service.

         Args:
-            external_service_url: Base URL for external service
+            external_client: External service client instance (optional)
         """
-        self.external_service_url = external_service_url.rstrip("/")
-        self.poi_context_endpoint = f"{self.external_service_url}/poi-context"
+        if external_client is None:
+            from app.core.config import settings
+            self.external_client = ExternalServiceClient(settings, "forecasting-service")
+        else:
+            self.external_client = external_client

     async def get_poi_features(
         self,
@@ -44,21 +48,10 @@ class POIFeatureService:
             Dictionary with POI features or empty dict if not available
         """
         try:
-            async with httpx.AsyncClient(timeout=10.0) as client:
-                response = await client.get(
-                    f"{self.poi_context_endpoint}/{tenant_id}"
-                )
-
-                if response.status_code == 404:
-                    logger.warning(
-                        "No POI context found for tenant",
-                        tenant_id=tenant_id
-                    )
-                    return {}
-
-                response.raise_for_status()
-                data = response.json()
-
-                poi_context = data.get("poi_context", {})
+            result = await self.external_client.get_poi_context(tenant_id)
+
+            if result:
+                poi_context = result.get("poi_context", {})
                 ml_features = poi_context.get("ml_features", {})

                 logger.info(
@@ -68,17 +61,16 @@ class POIFeatureService:
                 )

                 return ml_features
-
-        except httpx.HTTPError as e:
-            logger.error(
-                "Failed to fetch POI features for forecasting",
-                tenant_id=tenant_id,
-                error=str(e)
+            else:
+                logger.warning(
+                    "No POI context found for tenant",
+                    tenant_id=tenant_id
                 )
                 return {}

         except Exception as e:
             logger.error(
-                "Unexpected error fetching POI features",
+                "Failed to fetch POI features for forecasting",
                 tenant_id=tenant_id,
                 error=str(e),
                 exc_info=True
@@ -87,17 +79,18 @@ class POIFeatureService:
     async def check_poi_service_health(self) -> bool:
         """
-        Check if POI service is accessible.
+        Check if POI service is accessible through the external client.

         Returns:
             True if service is healthy, False otherwise
         """
         try:
-            async with httpx.AsyncClient(timeout=5.0) as client:
-                response = await client.get(
-                    f"{self.poi_context_endpoint}/health"
-                )
-                return response.status_code == 200
+            # Test the external service health by attempting to get POI context for a dummy tenant
+            # This will go through the proper authentication and routing
+            dummy_context = await self.external_client.get_poi_context("test-tenant")
+            # If we can successfully make a request (even if it returns None for missing tenant),
+            # it means the service is accessible
+            return True
         except Exception as e:
             logger.error(
                 "POI service health check failed",

View File

@@ -581,7 +581,7 @@ for feature_name in poi_features.keys():
 ```

 **Endpoint Used:**
-- `GET {EXTERNAL_SERVICE_URL}/poi-context/{tenant_id}` - Fetch POI features
+- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through API Gateway)

 ## Integration Points

View File

@@ -5,11 +5,12 @@ Integrates POI features into ML training pipeline.
 Fetches POI context from External service and merges features into training data.
 """

-import httpx
 from typing import Dict, Any, Optional, List
 import structlog
 import pandas as pd
+from shared.clients.external_client import ExternalServiceClient

 logger = structlog.get_logger()
@@ -21,15 +22,18 @@ class POIFeatureIntegrator:
     to training dataframes for location-based demand forecasting.
     """

-    def __init__(self, external_service_url: str = "http://external-service:8000"):
+    def __init__(self, external_client: ExternalServiceClient = None):
         """
         Initialize POI feature integrator.

         Args:
-            external_service_url: Base URL for external service
+            external_client: External service client instance (optional)
         """
-        self.external_service_url = external_service_url.rstrip("/")
-        self.poi_context_endpoint = f"{self.external_service_url}/poi-context"
+        if external_client is None:
+            from app.core.config import settings
+            self.external_client = ExternalServiceClient(settings, "training-service")
+        else:
+            self.external_client = external_client

     async def fetch_poi_features(
         self,
@@ -53,33 +57,28 @@ class POIFeatureIntegrator:
             Dictionary with POI features or None if detection fails
         """
         try:
-            async with httpx.AsyncClient(timeout=60.0) as client:
             # Try to get existing POI context first
             if not force_refresh:
-                try:
-                    response = await client.get(
-                        f"{self.poi_context_endpoint}/{tenant_id}"
-                    )
-                    if response.status_code == 200:
-                        data = response.json()
-                        poi_context = data.get("poi_context", {})
+                existing_context = await self.external_client.get_poi_context(tenant_id)
+                if existing_context:
+                    poi_context = existing_context.get("poi_context", {})
+                    ml_features = poi_context.get("ml_features", {})

                     # Check if stale
-                    if not data.get("is_stale", False):
+                    is_stale = existing_context.get("is_stale", False)
+                    if not is_stale:
                         logger.info(
                             "Using existing POI context",
                             tenant_id=tenant_id
                         )
-                        return poi_context.get("ml_features", {})
+                        return ml_features
                     else:
                         logger.info(
                             "POI context is stale, refreshing",
                             tenant_id=tenant_id
                         )
                         force_refresh = True
-            except httpx.HTTPStatusError as e:
-                if e.response.status_code != 404:
-                    raise
+                else:
                     logger.info(
                         "No existing POI context, will detect",
                         tenant_id=tenant_id
@@ -92,18 +91,15 @@ class POIFeatureIntegrator:
                 location=(latitude, longitude)
             )

-            response = await client.post(
-                f"{self.poi_context_endpoint}/{tenant_id}/detect",
-                params={
-                    "latitude": latitude,
-                    "longitude": longitude,
-                    "force_refresh": force_refresh
-                }
+            detection_result = await self.external_client.detect_poi_for_tenant(
+                tenant_id=tenant_id,
+                latitude=latitude,
+                longitude=longitude,
+                force_refresh=force_refresh
             )
-            response.raise_for_status()
-
-            result = response.json()
-            poi_context = result.get("poi_context", {})
+
+            if detection_result:
+                poi_context = detection_result.get("poi_context", {})
                 ml_features = poi_context.get("ml_features", {})

                 logger.info(
@@ -114,15 +110,13 @@ class POIFeatureIntegrator:
                 )

                 return ml_features
-
-        except httpx.HTTPError as e:
+            else:
                 logger.error(
-                "Failed to fetch POI features",
-                tenant_id=tenant_id,
-                error=str(e),
-                exc_info=True
+                    "POI detection failed",
+                    tenant_id=tenant_id
                 )
                 return None

         except Exception as e:
             logger.error(
                 "Unexpected error fetching POI features",
@@ -185,17 +179,18 @@ class POIFeatureIntegrator:
     async def check_poi_service_health(self) -> bool:
         """
-        Check if POI service is accessible.
+        Check if POI service is accessible through the external client.

         Returns:
             True if service is healthy, False otherwise
         """
         try:
-            async with httpx.AsyncClient(timeout=5.0) as client:
-                response = await client.get(
-                    f"{self.poi_context_endpoint}/health"
-                )
-                return response.status_code == 200
+            # We can test the external service health by attempting to get POI context for a dummy tenant
+            # This will go through the proper authentication and routing
+            dummy_context = await self.external_client.get_poi_context("test-tenant")
+            # If we can successfully make a request (even if it returns None for missing tenant),
+            # it means the service is accessible
+            return True
         except Exception as e:
             logger.error(
                 "POI service health check failed",

View File

@@ -375,7 +375,6 @@ class EnhancedBakeryMLTrainer:
         try:
             # Use provided session or create new one to prevent nested sessions and deadlocks
             should_create_session = session is None
-            db_session = session if session is not None else None

             if should_create_session:
                 # Only create a session if one wasn't provided
@@ -425,6 +424,27 @@ class EnhancedBakeryMLTrainer:
inventory_product_id=inventory_product_id, inventory_product_id=inventory_product_id,
result_status=result.get('status')) result_status=result.get('status'))
# Write training result to database (create model record)
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record in database
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, inventory_product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
# Get training metrics and filter out non-numeric values # Get training metrics and filter out non-numeric values
raw_metrics = result.get('model_info', {}).get('training_metrics', {}) raw_metrics = result.get('model_info', {}).get('training_metrics', {})
# Filter metrics to only include numeric values (per Pydantic schema requirement) # Filter metrics to only include numeric values (per Pydantic schema requirement)
@@ -453,6 +473,116 @@ class EnhancedBakeryMLTrainer:
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}" "message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
} }
# Only commit if this is our own session (not a parent session)
# Commit after we're done with all database operations
await db_session.commit()
logger.info("Committed single product model record to database",
inventory_product_id=inventory_product_id,
model_record_id=result.get('model_record_id'))
return result_dict
else:
# Use the provided session
repos = await self._get_repositories(session)
# Validate input data
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
# Validate required columns
required_columns = ['ds', 'y']
missing_cols = [col for col in required_columns if col not in training_data.columns]
if missing_cols:
raise ValueError(f"Missing required columns in training data: {missing_cols}")
# Create a simple progress tracker for single product
from app.services.progress_tracker import ParallelProductProgressTracker
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
total_products=1
)
# Ensure training data has proper data types before training
if 'ds' in training_data.columns:
training_data['ds'] = pd.to_datetime(training_data['ds'])
if 'y' in training_data.columns:
training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')
# Remove any rows with NaN values
training_data = training_data.dropna()
# Train the model using the existing _train_single_product method
product_id, result = await self._train_single_product(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
product_data=training_data,
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker,
session=session # Pass the provided session
)
logger.info("Single product training completed",
job_id=job_id,
inventory_product_id=inventory_product_id,
result_status=result.get('status'))
# Write training result to database (create model record)
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record in database
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, inventory_product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
# Get training metrics and filter out non-numeric values
raw_metrics = result.get('model_info', {}).get('training_metrics', {})
# Filter metrics to only include numeric values (per Pydantic schema requirement)
filtered_metrics = {}
for key, value in raw_metrics.items():
if key == 'product_category':
# Skip product_category as it's a string value, not a numeric metric
continue
try:
# Try to convert to float for validation
filtered_metrics[key] = float(value) if value is not None else 0.0
except (ValueError, TypeError):
# Skip non-numeric values
continue
# Return appropriate result format
result_dict = {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"status": result.get('status', 'success'),
"model_id": str(result.get('model_record_id', '')) if result.get('model_record_id') else None,
"training_metrics": filtered_metrics,
"training_time": result.get('training_time_seconds', 0),
"data_points": result.get('data_points', 0),
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
}
# For provided sessions, do NOT commit here - let the calling method handle commits
# This prevents committing a parent transaction prematurely
logger.info("Single product model processed (commit handled by caller)",
inventory_product_id=inventory_product_id,
model_record_id=result.get('model_record_id'))
return result_dict return result_dict
except Exception as e: except Exception as e:
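
Editor's note: the two branches above follow a caller-owns-the-transaction rule; the trainer commits only when it created the session itself and otherwise leaves the commit to the caller. A condensed, illustrative sketch of that rule (simplified names, not the trainer's actual code):

```python
async def train_with_optional_session(database_manager, do_training, session=None):
    # Illustrative sketch of the session-ownership rule; `database_manager` and
    # `do_training` are stand-ins, not the trainer's real API.
    owns_session = session is None

    if owns_session:
        # We created this session, so we are responsible for committing it.
        async with database_manager.get_session() as own_session:
            result = await do_training(own_session)
            await own_session.commit()
            return result

    # A parent session was provided: do the work, but never commit here,
    # so the caller's transaction is not committed prematurely.
    return await do_training(session)
```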

View File

@@ -186,6 +186,7 @@ class TrainedModel(Base):
             "training_samples": self.training_samples,
             "hyperparameters": self.hyperparameters,
             "features_used": self.features_used,
+            "features": self.features_used,  # Alias for frontend compatibility (ModelDetailsModal expects 'features')
             "product_category": self.product_category,
             "is_active": self.is_active,
             "is_production": self.is_production,

View File

@@ -732,12 +732,47 @@ class EnhancedTrainingService:
                                             current_step: str = None,
                                             error_message: str = None,
                                             results: Dict = None,
-                                            tenant_id: str = None):
-        """Update job status using repository pattern"""
+                                            tenant_id: str = None,
+                                            session = None):
+        """Update job status using repository pattern
+
+        Args:
+            session: Optional database session to reuse. If None, creates a new session.
+        """
         try:
+            # Use provided session or create new one
+            should_create_session = session is None
+
+            if should_create_session:
                 async with self.database_manager.get_session() as session:
                     await self._init_repositories(session)
+                    await self._update_job_status_impl(
+                        session, job_id, status, progress, current_step,
+                        error_message, results, tenant_id
+                    )
+            else:
+                # Reuse provided session (don't commit - let caller control transaction)
+                await self._init_repositories(session)
+                await self._update_job_status_impl(
+                    session, job_id, status, progress, current_step,
+                    error_message, results, tenant_id, auto_commit=False
+                )
+
+        except Exception as e:
+            logger.error("Failed to update job status using repository",
+                        job_id=job_id,
+                        error=str(e))
+
+    async def _update_job_status_impl(self,
+                                      session,
+                                      job_id: str,
+                                      status: str,
+                                      progress: int = None,
+                                      current_step: str = None,
+                                      error_message: str = None,
+                                      results: Dict = None,
+                                      tenant_id: str = None,
+                                      auto_commit: bool = True):
+        """Implementation of job status update"""
+
         # Check if log exists, create if not
         existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
@@ -771,6 +806,7 @@ class EnhancedTrainingService:
                 try:
                     await self.training_log_repo.create_training_log(log_data)
+                    if auto_commit:
                         await session.commit()  # Explicit commit so other sessions can see it
                     logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
                 except Exception as create_error:
@@ -788,6 +824,7 @@ class EnhancedTrainingService:
                         current_step=current_step,
                         status=status
                     )
+                    if auto_commit:
                         await session.commit()
                 else:
                     raise
@@ -817,33 +854,35 @@ class EnhancedTrainingService:
                 if update_data:
                     await self.training_log_repo.update(existing_log.id, update_data)
+                    if auto_commit:
                         await session.commit()  # Explicit commit after updates

-        except Exception as e:
-            logger.error("Failed to update job status using repository",
-                        job_id=job_id,
-                        error=str(e))
-
     async def start_single_product_training(self,
                                             tenant_id: str,
                                             inventory_product_id: str,
                                             job_id: str,
                                             bakery_location: tuple = (40.4168, -3.7038)) -> Dict[str, Any]:
-        """Start enhanced single product training using repository pattern"""
+        """Start enhanced single product training using repository pattern with single session"""
+        # Create a single database session for all operations to avoid connection pool exhaustion
+        async with self.database_manager.get_session() as session:
+            await self._init_repositories(session)
             try:
                 logger.info("Starting enhanced single product training",
                            tenant_id=tenant_id,
                            inventory_product_id=inventory_product_id,
                            job_id=job_id)

-                # Create initial training log
+                # Create initial training log (using shared session)
                 await self._update_job_status_repository(
                     job_id=job_id,
                     status="running",
                     progress=0,
                     current_step="Fetching training data",
-                    tenant_id=tenant_id
+                    tenant_id=tenant_id,
+                    session=session
                 )
+                await session.commit()  # Commit after initial log creation

                 # Prepare training data for all products to get weather/traffic data
                 # then filter down to the specific product
@@ -853,111 +892,64 @@ class EnhancedTrainingService:
job_id=job_id + "_temp" job_id=job_id + "_temp"
) )
# Filter sales data to the specific product # Use the enhanced data processor to prepare training data with all features (POI, weather, traffic)
# Filter sales data to the specific product first
sales_df = pd.DataFrame(training_dataset.sales_data) sales_df = pd.DataFrame(training_dataset.sales_data)
product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id] product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
if product_sales_df.empty: if product_sales_df.empty:
raise ValueError(f"No sales data available for product {inventory_product_id}") raise ValueError(f"No sales data available for product {inventory_product_id}")
# Prepare the data in Prophet format (ds and y columns) # Get weather and traffic data as DataFrames
# Ensure proper column names and types for Prophet weather_df = pd.DataFrame(training_dataset.weather_data)
product_data = product_sales_df.copy() traffic_df = pd.DataFrame(training_dataset.traffic_data)
product_data = product_data.rename(columns={
'sale_date': 'ds', # Common sales date column
'sale_datetime': 'ds', # Alternative date column
'date': 'ds', # Alternative date column
'quantity': 'y', # Quantity sold
'total_amount': 'y', # Alternative for sales data
'sales_amount': 'y', # Alternative for sales data
'sale_amount': 'y' # Alternative for sales data
})
# If 'ds' and 'y' columns are not renamed properly, try to infer them # Get POI features from the training dataset (already collected by orchestrator)
if 'ds' not in product_data.columns: poi_features = training_dataset.poi_features if hasattr(training_dataset, 'poi_features') else None
# Try to find date-like columns
date_cols = [col for col in product_data.columns if 'date' in col.lower() or 'time' in col.lower()]
if date_cols:
product_data = product_data.rename(columns={date_cols[0]: 'ds'})
if 'y' not in product_data.columns: # Use the enhanced data processor to merge all features properly
# Try to find sales/quantity-like columns # This will include POI, weather, traffic features along with ds and y
sales_cols = [col for col in product_data.columns if from app.ml.data_processor import EnhancedBakeryDataProcessor
any(word in col.lower() for word in ['amount', 'quantity', 'sales', 'total', 'count', 'value'])] data_processor = EnhancedBakeryDataProcessor(self.database_manager)
if sales_cols:
product_data = product_data.rename(columns={sales_cols[0]: 'y'})
# Ensure required columns exist product_data = await data_processor.prepare_training_data(
if 'ds' not in product_data.columns or 'y' not in product_data.columns: sales_data=product_sales_df,
raise ValueError(f"Sales data must contain 'date' and 'quantity/sales' columns. Available columns: {list(product_data.columns)}") weather_data=weather_df,
traffic_data=traffic_df,
inventory_product_id=inventory_product_id,
poi_features=poi_features,
tenant_id=tenant_id,
job_id=job_id
)
# Convert the date column to datetime if it's not already if product_data.empty:
product_data['ds'] = pd.to_datetime(product_data['ds']) raise ValueError(f"Data processor returned empty data for product {inventory_product_id}")
# Convert to numeric ensuring no pandas/numpy objects remain
product_data['y'] = pd.to_numeric(product_data['y'], errors='coerce')
# Sort by date to ensure proper chronological order
product_data = product_data.sort_values('ds').reset_index(drop=True)
# Drop any rows with NaN values
product_data = product_data.dropna(subset=['ds', 'y'])
# Ensure the data is in the right format for Prophet
product_data = product_data[['ds', 'y']].copy()
# Convert to pandas datetime and float types (keep as pandas Series for proper min/max operations)
product_data['ds'] = pd.to_datetime(product_data['ds'])
product_data['y'] = product_data['y'].astype(float)
# DEBUG: Log data types to diagnose dict comparison error
logger.info(f"DEBUG: product_data dtypes after conversion: ds={product_data['ds'].dtype}, y={product_data['y'].dtype}")
logger.info(f"DEBUG: product_data['ds'] sample values: {product_data['ds'].head(3).tolist()}")
logger.info(f"DEBUG: Attempting to get min/max...")
try:
min_val = product_data['ds'].min()
max_val = product_data['ds'].max()
logger.info(f"DEBUG: min_val type={type(min_val)}, value={min_val}")
logger.info(f"DEBUG: max_val type={type(max_val)}, value={max_val}")
except Exception as debug_e:
logger.error(f"DEBUG: Failed to get min/max: {debug_e}")
import traceback
logger.error(f"DEBUG: Traceback: {traceback.format_exc()}")
logger.info("Prepared training data for single product", logger.info("Prepared training data for single product",
inventory_product_id=inventory_product_id, inventory_product_id=inventory_product_id,
data_points=len(product_data), data_points=len(product_data),
features=list(product_data.columns),
date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}") date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
# Update progress # Update progress (using shared session)
await self._update_job_status_repository( await self._update_job_status_repository(
job_id=job_id, job_id=job_id,
status="running", status="running",
progress=30, progress=30,
current_step="Training model", current_step="Training model",
tenant_id=tenant_id tenant_id=tenant_id,
session=session
) )
await session.commit() # Commit progress update
# Train the model using the trainer # Run the actual training (passing the session to avoid nested session creation)
# Extract datetime values with proper pandas Timestamp wrapper for type safety
try:
training_start = pd.Timestamp(product_data['ds'].min()).to_pydatetime()
training_end = pd.Timestamp(product_data['ds'].max()).to_pydatetime()
except Exception as e:
import traceback
logger.error(f"Failed to extract training dates: {e}")
logger.error(f"Traceback: {traceback.format_exc()}")
logger.error(f"product_data['ds'] dtype: {product_data['ds'].dtype}")
logger.error(f"product_data['ds'] first 5 values: {product_data['ds'].head().tolist()}")
raise
# Run the actual training
try: try:
model_info = await self.trainer.train_single_product_model( model_info = await self.trainer.train_single_product_model(
tenant_id=tenant_id, tenant_id=tenant_id,
inventory_product_id=inventory_product_id, inventory_product_id=inventory_product_id,
training_data=product_data, training_data=product_data,
job_id=job_id job_id=job_id,
session=session # ✅ CRITICAL FIX: Pass session to prevent deadlock
) )
except Exception as e: except Exception as e:
import traceback import traceback
@@ -965,14 +957,16 @@ class EnhancedTrainingService:
                     logger.error(f"Full traceback: {traceback.format_exc()}")
                     raise

-                # Update progress
+                # Update progress (using shared session)
                 await self._update_job_status_repository(
                     job_id=job_id,
                     status="running",
                     progress=80,
                     current_step="Saving model",
-                    tenant_id=tenant_id
+                    tenant_id=tenant_id,
+                    session=session
                 )
+                await session.commit()  # Commit progress update

                 # The model should already be saved by train_single_product_model
                 # Return appropriate response
@@ -1012,15 +1006,17 @@ class EnhancedTrainingService:
                             inventory_product_id=inventory_product_id,
                             error=str(e))

-                # Update status to failed
+                # Update status to failed (using shared session)
                 await self._update_job_status_repository(
                     job_id=job_id,
                     status="failed",
                     progress=0,
                     current_step="Training failed",
                     error_message=str(e),
-                    tenant_id=tenant_id
+                    tenant_id=tenant_id,
+                    session=session
                 )
+                await session.commit()  # Commit failure status

                 raise
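
Editor's note: taken together, these hunks move `start_single_product_training` onto one shared session that is passed to both the status updates and the trainer. A condensed sketch of the resulting flow (simplified, not verbatim service code):

```python
async def run_single_product_training(service, tenant_id, inventory_product_id, job_id, product_data):
    # `service` stands in for an EnhancedTrainingService instance; argument names mirror the diff.
    async with service.database_manager.get_session() as session:
        await service._init_repositories(session)

        # Status updates reuse the shared session and skip their internal commit...
        await service._update_job_status_repository(
            job_id=job_id, status="running", progress=0,
            current_step="Fetching training data",
            tenant_id=tenant_id, session=session,
        )
        await session.commit()  # ...so the caller commits at each boundary it chooses.

        # The trainer receives the same session, avoiding the nested sessions,
        # pool exhaustion and deadlocks this commit fixes.
        return await service.trainer.train_single_product_model(
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            training_data=product_data,
            job_id=job_id,
            session=session,
        )
```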

View File

@@ -494,6 +494,75 @@ class ExternalServiceClient(BaseServiceClient):
# POI (POINT OF INTEREST) DATA # POI (POINT OF INTEREST) DATA
# ================================================================ # ================================================================
async def detect_poi_for_tenant(
self,
tenant_id: str,
latitude: float,
longitude: float,
force_refresh: bool = False
) -> Optional[Dict[str, Any]]:
"""
Detect POIs for a tenant's location and generate ML features for forecasting.
With the new tenant-based architecture:
- Gateway receives at: /api/v1/tenants/{tenant_id}/external/poi-context/detect
- Gateway proxies to external service at: /api/v1/tenants/{tenant_id}/poi-context/detect
- This client calls: poi-context/detect (base client automatically constructs with tenant)
This triggers POI detection using Overpass API and calculates ML features
for demand forecasting.
Args:
tenant_id: Tenant ID
latitude: Latitude of the bakery location
longitude: Longitude of the bakery location
force_refresh: Whether to force refresh even if POI context exists
Returns:
Dict with POI detection results including:
- ml_features: Dict of POI features for ML models (e.g., poi_retail_total_count)
- poi_detection_results: Full detection results
- location: Latitude/longitude
- total_pois_detected: Count of POIs
"""
logger.info(
"Detecting POIs for tenant",
tenant_id=tenant_id,
location=(latitude, longitude),
force_refresh=force_refresh
)
params = {
"latitude": latitude,
"longitude": longitude,
"force_refresh": force_refresh
}
# Updated endpoint path to follow tenant-based pattern: external/poi-context/detect
result = await self._make_request(
"POST",
"external/poi-context/detect", # Path will become /api/v1/tenants/{tenant_id}/external/poi-context/detect by base client
tenant_id=tenant_id, # Pass tenant_id to include in headers and path construction
params=params,
timeout=60.0 # POI detection can take longer
)
if result:
poi_context = result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
"POI detection completed successfully",
tenant_id=tenant_id,
total_pois=poi_context.get("total_pois_detected", 0),
ml_features_count=len(ml_features),
source=result.get("source", "unknown")
)
return result
else:
logger.warning("POI detection failed for tenant", tenant_id=tenant_id)
return None
async def get_poi_context( async def get_poi_context(
self, self,
tenant_id: str tenant_id: str
@@ -504,7 +573,7 @@ class ExternalServiceClient(BaseServiceClient):
         With the new tenant-based architecture:
         - Gateway receives at: /api/v1/tenants/{tenant_id}/external/poi-context
         - Gateway proxies to external service at: /api/v1/tenants/{tenant_id}/poi-context
-        - This client calls: /tenants/{tenant_id}/poi-context
+        - This client calls: poi-context (base client automatically constructs with tenant)

         This retrieves stored POI detection results and calculated ML features
         that should be included in demand forecasting predictions.
@@ -521,11 +590,11 @@ class ExternalServiceClient(BaseServiceClient):
         """
         logger.info("Fetching POI context for forecasting", tenant_id=tenant_id)

-        # Updated endpoint path to follow tenant-based pattern: /tenants/{tenant_id}/poi-context
+        # Updated endpoint path to follow tenant-based pattern: external/poi-context
         result = await self._make_request(
             "GET",
-            f"tenants/{tenant_id}/poi-context",  # Updated path: /tenants/{tenant_id}/poi-context
+            "external/poi-context",  # Path will become /api/v1/tenants/{tenant_id}/external/poi-context by base client
-            tenant_id=tenant_id,  # Pass tenant_id to include in headers for authentication
+            tenant_id=tenant_id,  # Pass tenant_id to include in headers and path construction
             timeout=5.0
         )
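
Editor's note: a hedged caller sketch for the new `detect_poi_for_tenant` helper added in this file; the `settings` import path, tenant id and coordinates are illustrative, while the method name, parameters and response shape follow the diff above.

```python
from shared.clients.external_client import ExternalServiceClient
from app.core.config import settings  # assumed config module


async def detect_example() -> dict:
    client = ExternalServiceClient(settings, "training-service")

    result = await client.detect_poi_for_tenant(
        tenant_id="tenant-123",
        latitude=40.4168,
        longitude=-3.7038,
        force_refresh=False,
    )
    # On success the payload carries poi_context.ml_features for the forecasting models.
    return (result or {}).get("poi_context", {}).get("ml_features", {})
```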