Bug fixes for training

Urtzi Alfaro
2025-11-14 20:27:39 +01:00
parent 71f9ca9d65
commit c349b845a6
11 changed files with 606 additions and 408 deletions

View File

@@ -40,9 +40,4 @@ nodes:
# Direct gateway access (backup)
- containerPort: 30800
hostPort: 8000
protocol: TCP
sysctls:
# Increase fs.inotify limits to prevent "too many open files" errors
fs.inotify.max_user_watches: 524288
fs.inotify.max_user_instances: 256
fs.inotify.max_queued_events: 32768
protocol: TCP

View File

@@ -100,14 +100,19 @@ The **External Service** integrates real-world data from Spanish sources to enha
## API Endpoints (Key Routes)
### POI Detection & Context
- `POST /poi-context/{tenant_id}/detect` - Detect POIs for tenant location (lat, long, force_refresh params)
- `GET /poi-context/{tenant_id}` - Get cached POI context for tenant
- `POST /poi-context/{tenant_id}/refresh` - Force refresh POI detection
- `DELETE /poi-context/{tenant_id}` - Delete POI context for tenant
- `GET /poi-context/{tenant_id}/feature-importance` - Get POI feature importance summary
- `GET /poi-context/{tenant_id}/competitor-analysis` - Get competitive analysis
- `GET /poi-context/health` - Check POI service and Overpass API health
- `GET /poi-context/cache/stats` - Get POI cache statistics
- `POST /api/v1/tenants/{tenant_id}/poi-context/detect` - Detect POIs for tenant location (lat, long, force_refresh params) - Direct access (bypasses gateway authentication)
- `GET /api/v1/tenants/{tenant_id}/poi-context` - Get cached POI context for tenant - Direct access (bypasses gateway authentication)
- `POST /api/v1/tenants/{tenant_id}/poi-context/refresh` - Force refresh POI detection - Direct access (bypasses gateway authentication)
- `DELETE /api/v1/tenants/{tenant_id}/poi-context` - Delete POI context for tenant - Direct access (bypasses gateway authentication)
- `GET /api/v1/tenants/{tenant_id}/poi-context/feature-importance` - Get POI feature importance summary - Direct access (bypasses gateway authentication)
- `GET /api/v1/tenants/{tenant_id}/poi-context/competitor-analysis` - Get competitive analysis - Direct access (bypasses gateway authentication)
- `GET /api/v1/tenants/poi-context/health` - Check POI service and Overpass API health - Direct access (bypasses gateway authentication)
- `GET /api/v1/tenants/poi-context/cache/stats` - Get POI cache statistics - Direct access (bypasses gateway authentication)
### Recommended Access Pattern:
- Services should use `/api/v1/tenants/{tenant_id}/external/poi-context` (detected POIs) and `/api/v1/tenants/{tenant_id}/external/poi-context/detect` (detection) via the shared ExternalServiceClient, which routes through the API gateway for proper authentication and authorization.
**Note**: When using the shared ExternalServiceClient, provide relative paths such as `poi-context` and `poi-context/detect` - the client automatically constructs the full tenant-scoped path (see the sketch below).
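For example, a service that needs POI ML features can go through the shared client like this - a minimal sketch, assuming the `app.core.config.settings` import path and `"forecasting-service"` caller name used elsewhere in this commit; the tenant id is a placeholder:

```
from shared.clients.external_client import ExternalServiceClient
from app.core.config import settings  # settings import as used elsewhere in this commit

async def load_poi_ml_features(tenant_id: str) -> dict:
    # The shared client expands the relative POI path into
    # /api/v1/tenants/{tenant_id}/external/poi-context and calls it via the gateway.
    client = ExternalServiceClient(settings, "forecasting-service")
    result = await client.get_poi_context(tenant_id)
    poi_context = (result or {}).get("poi_context", {})
    return poi_context.get("ml_features", {})
```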
### Weather Data (AEMET)
- `GET /api/v1/external/weather/current` - Current weather for location

View File

@@ -18,13 +18,17 @@ from app.services.poi_refresh_service import POIRefreshService
from app.repositories.poi_context_repository import POIContextRepository
from app.cache.poi_cache_service import POICacheService
from app.core.redis_client import get_redis_client
from shared.routing.route_builder import RouteBuilder
logger = structlog.get_logger()
router = APIRouter(prefix="/tenants", tags=["POI Context"])
route_builder = RouteBuilder('external')
router = APIRouter(tags=["POI Context"])
@router.post("/{tenant_id}/poi-context/detect")
@router.post(
route_builder.build_base_route("poi-context/detect")
)
async def detect_pois_for_tenant(
tenant_id: str,
latitude: float = Query(..., description="Bakery latitude"),
@@ -297,7 +301,9 @@ async def detect_pois_for_tenant(
)
@router.get("/{tenant_id}/poi-context")
@router.get(
route_builder.build_base_route("poi-context")
)
async def get_poi_context(
tenant_id: str,
db: AsyncSession = Depends(get_db)
@@ -331,7 +337,9 @@ async def get_poi_context(
}
@router.post("/{tenant_id}/poi-context/refresh")
@router.post(
route_builder.build_base_route("poi-context/refresh")
)
async def refresh_poi_context(
tenant_id: str,
db: AsyncSession = Depends(get_db)
@@ -365,7 +373,9 @@ async def refresh_poi_context(
)
@router.delete("/{tenant_id}/poi-context")
@router.delete(
route_builder.build_base_route("poi-context")
)
async def delete_poi_context(
tenant_id: str,
db: AsyncSession = Depends(get_db)
@@ -393,7 +403,9 @@ async def delete_poi_context(
}
@router.get("/{tenant_id}/poi-context/feature-importance")
@router.get(
route_builder.build_base_route("poi-context/feature-importance")
)
async def get_feature_importance(
tenant_id: str,
db: AsyncSession = Depends(get_db)
@@ -430,7 +442,9 @@ async def get_feature_importance(
}
@router.get("/{tenant_id}/poi-context/competitor-analysis")
@router.get(
route_builder.build_base_route("poi-context/competitor-analysis")
)
async def get_competitor_analysis(
tenant_id: str,
db: AsyncSession = Depends(get_db)
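Editor's note on the decorator change above: a minimal sketch of the *assumed* `RouteBuilder` behaviour - the real implementation lives in `shared.routing.route_builder` and is not shown in this diff, so the class below is purely illustrative:

```
# Assumption: build_base_route() prefixes the resource with a tenant path parameter so the
# router, mounted under /api/v1/tenants, exposes the endpoints listed in the README above.
class RouteBuilderSketch:
    def __init__(self, service_name: str):
        self.service_name = service_name

    def build_base_route(self, resource: str) -> str:
        return f"/{{tenant_id}}/{resource}"

assert RouteBuilderSketch("external").build_base_route("poi-context/detect") == "/{tenant_id}/poi-context/detect"
```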

View File

@@ -497,7 +497,7 @@ poi_features = await poi_service.fetch_poi_features(tenant_id)
- **Accuracy Improvement** - POI features contribute a 5-10% accuracy improvement
**Endpoint Used:**
- `GET {EXTERNAL_SERVICE_URL}/poi-context/{tenant_id}` - Fetch POI features
- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through API Gateway)
## Integration Points

View File

@@ -5,10 +5,11 @@ Fetches POI features for use in demand forecasting predictions.
Ensures feature consistency between training and prediction.
"""
import httpx
from typing import Dict, Any, Optional
import structlog
from shared.clients.external_client import ExternalServiceClient
logger = structlog.get_logger()
@@ -20,15 +21,18 @@ class POIFeatureService:
prediction uses the same features as training.
"""
def __init__(self, external_service_url: str = "http://external-service:8000"):
def __init__(self, external_client: ExternalServiceClient = None):
"""
Initialize POI feature service.
Args:
external_service_url: Base URL for external service
external_client: External service client instance (optional)
"""
self.external_service_url = external_service_url.rstrip("/")
self.poi_context_endpoint = f"{self.external_service_url}/poi-context"
if external_client is None:
from app.core.config import settings
self.external_client = ExternalServiceClient(settings, "forecasting-service")
else:
self.external_client = external_client
async def get_poi_features(
self,
@@ -44,21 +48,10 @@ class POIFeatureService:
Dictionary with POI features or empty dict if not available
"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
f"{self.poi_context_endpoint}/{tenant_id}"
)
result = await self.external_client.get_poi_context(tenant_id)
if response.status_code == 404:
logger.warning(
"No POI context found for tenant",
tenant_id=tenant_id
)
return {}
response.raise_for_status()
data = response.json()
poi_context = data.get("poi_context", {})
if result:
poi_context = result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
@@ -68,17 +61,16 @@ class POIFeatureService:
)
return ml_features
else:
logger.warning(
"No POI context found for tenant",
tenant_id=tenant_id
)
return {}
except httpx.HTTPError as e:
logger.error(
"Failed to fetch POI features for forecasting",
tenant_id=tenant_id,
error=str(e)
)
return {}
except Exception as e:
logger.error(
"Unexpected error fetching POI features",
"Failed to fetch POI features for forecasting",
tenant_id=tenant_id,
error=str(e),
exc_info=True
@@ -87,17 +79,18 @@ class POIFeatureService:
async def check_poi_service_health(self) -> bool:
"""
Check if POI service is accessible.
Check if POI service is accessible through the external client.
Returns:
True if service is healthy, False otherwise
"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(
f"{self.poi_context_endpoint}/health"
)
return response.status_code == 200
# Test the external service health by attempting to get POI context for a dummy tenant
# This will go through the proper authentication and routing
dummy_context = await self.external_client.get_poi_context("test-tenant")
# If we can successfully make a request (even if it returns None for missing tenant),
# it means the service is accessible
return True
except Exception as e:
logger.error(
"POI service health check failed",

View File

@@ -581,7 +581,7 @@ for feature_name in poi_features.keys():
```
**Endpoint Used:**
- `GET {EXTERNAL_SERVICE_URL}/poi-context/{tenant_id}` - Fetch POI features
- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through API Gateway)
## Integration Points

View File

@@ -5,11 +5,12 @@ Integrates POI features into ML training pipeline.
Fetches POI context from External service and merges features into training data.
"""
import httpx
from typing import Dict, Any, Optional, List
import structlog
import pandas as pd
from shared.clients.external_client import ExternalServiceClient
logger = structlog.get_logger()
@@ -21,15 +22,18 @@ class POIFeatureIntegrator:
to training dataframes for location-based demand forecasting.
"""
def __init__(self, external_service_url: str = "http://external-service:8000"):
def __init__(self, external_client: ExternalServiceClient = None):
"""
Initialize POI feature integrator.
Args:
external_service_url: Base URL for external service
external_client: External service client instance (optional)
"""
self.external_service_url = external_service_url.rstrip("/")
self.poi_context_endpoint = f"{self.external_service_url}/poi-context"
if external_client is None:
from app.core.config import settings
self.external_client = ExternalServiceClient(settings, "training-service")
else:
self.external_client = external_client
async def fetch_poi_features(
self,
@@ -53,57 +57,49 @@ class POIFeatureIntegrator:
Dictionary with POI features or None if detection fails
"""
try:
async with httpx.AsyncClient(timeout=60.0) as client:
# Try to get existing POI context first
if not force_refresh:
try:
response = await client.get(
f"{self.poi_context_endpoint}/{tenant_id}"
)
if response.status_code == 200:
data = response.json()
poi_context = data.get("poi_context", {})
# Try to get existing POI context first
if not force_refresh:
existing_context = await self.external_client.get_poi_context(tenant_id)
if existing_context:
poi_context = existing_context.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
# Check if stale
if not data.get("is_stale", False):
logger.info(
"Using existing POI context",
tenant_id=tenant_id
)
return poi_context.get("ml_features", {})
else:
logger.info(
"POI context is stale, refreshing",
tenant_id=tenant_id
)
force_refresh = True
except httpx.HTTPStatusError as e:
if e.response.status_code != 404:
raise
# Check if stale
is_stale = existing_context.get("is_stale", False)
if not is_stale:
logger.info(
"No existing POI context, will detect",
"Using existing POI context",
tenant_id=tenant_id
)
return ml_features
else:
logger.info(
"POI context is stale, refreshing",
tenant_id=tenant_id
)
force_refresh = True
else:
logger.info(
"No existing POI context, will detect",
tenant_id=tenant_id
)
# Detect or refresh POIs
logger.info(
"Detecting POIs for tenant",
tenant_id=tenant_id,
location=(latitude, longitude)
)
# Detect or refresh POIs
logger.info(
"Detecting POIs for tenant",
tenant_id=tenant_id,
location=(latitude, longitude)
)
response = await client.post(
f"{self.poi_context_endpoint}/{tenant_id}/detect",
params={
"latitude": latitude,
"longitude": longitude,
"force_refresh": force_refresh
}
)
response.raise_for_status()
detection_result = await self.external_client.detect_poi_for_tenant(
tenant_id=tenant_id,
latitude=latitude,
longitude=longitude,
force_refresh=force_refresh
)
result = response.json()
poi_context = result.get("poi_context", {})
if detection_result:
poi_context = detection_result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
@@ -114,15 +110,13 @@ class POIFeatureIntegrator:
)
return ml_features
else:
logger.error(
"POI detection failed",
tenant_id=tenant_id
)
return None
except httpx.HTTPError as e:
logger.error(
"Failed to fetch POI features",
tenant_id=tenant_id,
error=str(e),
exc_info=True
)
return None
except Exception as e:
logger.error(
"Unexpected error fetching POI features",
@@ -185,17 +179,18 @@ class POIFeatureIntegrator:
async def check_poi_service_health(self) -> bool:
"""
Check if POI service is accessible.
Check if POI service is accessible through the external client.
Returns:
True if service is healthy, False otherwise
"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(
f"{self.poi_context_endpoint}/health"
)
return response.status_code == 200
# We can test the external service health by attempting to get POI context for a dummy tenant
# This will go through the proper authentication and routing
dummy_context = await self.external_client.get_poi_context("test-tenant")
# If we can successfully make a request (even if it returns None for missing tenant),
# it means the service is accessible
return True
except Exception as e:
logger.error(
"POI service health check failed",

View File

@@ -375,40 +375,143 @@ class EnhancedBakeryMLTrainer:
try:
# Use provided session or create new one to prevent nested sessions and deadlocks
should_create_session = session is None
db_session = session if session is not None else None
if should_create_session:
# Only create a session if one wasn't provided
async with self.database_manager.get_session() as db_session:
repos = await self._get_repositories(db_session)
# Validate input data
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
# Validate required columns
required_columns = ['ds', 'y']
missing_cols = [col for col in required_columns if col not in training_data.columns]
if missing_cols:
raise ValueError(f"Missing required columns in training data: {missing_cols}")
# Create a simple progress tracker for single product
from app.services.progress_tracker import ParallelProductProgressTracker
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
total_products=1
)
# Ensure training data has proper data types before training
if 'ds' in training_data.columns:
training_data['ds'] = pd.to_datetime(training_data['ds'])
if 'y' in training_data.columns:
training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')
# Remove any rows with NaN values
training_data = training_data.dropna()
# Train the model using the existing _train_single_product method
product_id, result = await self._train_single_product(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
product_data=training_data,
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker,
session=db_session # Pass the session to prevent nested sessions
)
logger.info("Single product training completed",
job_id=job_id,
inventory_product_id=inventory_product_id,
result_status=result.get('status'))
# Write training result to database (create model record)
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record in database
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, inventory_product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
# Get training metrics and filter out non-numeric values
raw_metrics = result.get('model_info', {}).get('training_metrics', {})
# Filter metrics to only include numeric values (per Pydantic schema requirement)
filtered_metrics = {}
for key, value in raw_metrics.items():
if key == 'product_category':
# Skip product_category as it's a string value, not a numeric metric
continue
try:
# Try to convert to float for validation
filtered_metrics[key] = float(value) if value is not None else 0.0
except (ValueError, TypeError):
# Skip non-numeric values
continue
# Return appropriate result format
result_dict = {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"status": result.get('status', 'success'),
"model_id": str(result.get('model_record_id', '')) if result.get('model_record_id') else None,
"training_metrics": filtered_metrics,
"training_time": result.get('training_time_seconds', 0),
"data_points": result.get('data_points', 0),
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
}
# Only commit if this is our own session (not a parent session)
# Commit after we're done with all database operations
await db_session.commit()
logger.info("Committed single product model record to database",
inventory_product_id=inventory_product_id,
model_record_id=result.get('model_record_id'))
return result_dict
else:
# Use the provided session
repos = await self._get_repositories(session)
# Validate input data
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
# Validate required columns
required_columns = ['ds', 'y']
missing_cols = [col for col in required_columns if col not in training_data.columns]
if missing_cols:
raise ValueError(f"Missing required columns in training data: {missing_cols}")
# Create a simple progress tracker for single product
from app.services.progress_tracker import ParallelProductProgressTracker
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
job_id=job_id,
tenant_id=tenant_id,
total_products=1
)
# Ensure training data has proper data types before training
if 'ds' in training_data.columns:
training_data['ds'] = pd.to_datetime(training_data['ds'])
if 'y' in training_data.columns:
training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')
# Remove any rows with NaN values
training_data = training_data.dropna()
# Train the model using the existing _train_single_product method
product_id, result = await self._train_single_product(
tenant_id=tenant_id,
@@ -417,14 +520,35 @@ class EnhancedBakeryMLTrainer:
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker,
session=db_session # Pass the session to prevent nested sessions
session=session # Pass the provided session
)
logger.info("Single product training completed",
job_id=job_id,
inventory_product_id=inventory_product_id,
result_status=result.get('status'))
# Write training result to database (create model record)
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record in database
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, inventory_product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
# Get training metrics and filter out non-numeric values
raw_metrics = result.get('model_info', {}).get('training_metrics', {})
# Filter metrics to only include numeric values (per Pydantic schema requirement)
@@ -439,7 +563,7 @@ class EnhancedBakeryMLTrainer:
except (ValueError, TypeError):
# Skip non-numeric values
continue
# Return appropriate result format
result_dict = {
"job_id": job_id,
@@ -452,7 +576,13 @@ class EnhancedBakeryMLTrainer:
"data_points": result.get('data_points', 0),
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
}
# For provided sessions, do NOT commit here - let the calling method handle commits
# This prevents committing a parent transaction prematurely
logger.info("Single product model processed (commit handled by caller)",
inventory_product_id=inventory_product_id,
model_record_id=result.get('model_record_id'))
return result_dict
except Exception as e:
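The two branches above implement one idiom: commit only a session this method created, and leave a caller-provided session untouched so the caller controls the transaction. A condensed, hedged sketch of that idiom - `run_with_optional_session` and `work` are illustrative names, not helpers from the codebase:

```
from typing import Any, Awaitable, Callable, Optional

async def run_with_optional_session(
    database_manager: Any,
    work: Callable[[Any], Awaitable[Any]],
    session: Optional[Any] = None,
) -> Any:
    """Illustrative helper: run `work(session)` and commit only a session we created."""
    if session is None:
        async with database_manager.get_session() as owned_session:
            result = await work(owned_session)
            await owned_session.commit()  # we own this transaction, so we commit it
            return result
    # Caller-provided session: never commit here; the caller sets the transaction boundary.
    return await work(session)
```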

View File

@@ -186,6 +186,7 @@ class TrainedModel(Base):
"training_samples": self.training_samples,
"hyperparameters": self.hyperparameters,
"features_used": self.features_used,
"features": self.features_used, # Alias for frontend compatibility (ModelDetailsModal expects 'features')
"product_category": self.product_category,
"is_active": self.is_active,
"is_production": self.is_production,

View File

@@ -732,297 +732,293 @@ class EnhancedTrainingService:
current_step: str = None,
error_message: str = None,
results: Dict = None,
tenant_id: str = None):
"""Update job status using repository pattern"""
tenant_id: str = None,
session = None):
"""Update job status using repository pattern
Args:
session: Optional database session to reuse. If None, creates a new session.
"""
try:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
# Use provided session or create new one
should_create_session = session is None
# Check if log exists, create if not
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if not existing_log:
# Create initial log entry
if not tenant_id:
# Extract tenant_id from job_id if not provided
# Format: enhanced_training_{tenant_id}_{job_suffix}
try:
parts = job_id.split('_')
if len(parts) >= 3 and parts[0] == 'enhanced' and parts[1] == 'training':
tenant_id = parts[2]
except Exception:
logger.warning(f"Could not extract tenant_id from job_id {job_id}")
if tenant_id:
log_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": status or "pending",
"progress": progress or 0,
"current_step": current_step or "initializing",
"start_time": datetime.now(timezone.utc)
}
if error_message:
log_data["error_message"] = error_message
if results:
# Ensure results are JSON-serializable before storing
log_data["results"] = make_json_serializable(results)
try:
await self.training_log_repo.create_training_log(log_data)
await session.commit() # Explicit commit so other sessions can see it
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
except Exception as create_error:
# Handle race condition: another session may have created the log
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
await session.rollback()
# Query again to get the existing log
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if existing_log:
# Update the existing log instead
await self.training_log_repo.update_log_progress(
job_id=job_id,
progress=progress,
current_step=current_step,
status=status
)
await session.commit()
else:
raise
else:
logger.error("Cannot create training log without tenant_id", job_id=job_id)
return
else:
# Update existing log
await self.training_log_repo.update_log_progress(
job_id=job_id,
progress=progress,
current_step=current_step,
status=status
if should_create_session:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
await self._update_job_status_impl(
session, job_id, status, progress, current_step,
error_message, results, tenant_id
)
# Update additional fields if provided
if error_message or results:
update_data = {}
if error_message:
update_data["error_message"] = error_message
if results:
# Ensure results are JSON-serializable before storing
update_data["results"] = make_json_serializable(results)
if status in ["completed", "failed"]:
update_data["end_time"] = datetime.now(timezone.utc)
if update_data:
await self.training_log_repo.update(existing_log.id, update_data)
await session.commit() # Explicit commit after updates
else:
# Reuse provided session (don't commit - let caller control transaction)
await self._init_repositories(session)
await self._update_job_status_impl(
session, job_id, status, progress, current_step,
error_message, results, tenant_id, auto_commit=False
)
except Exception as e:
logger.error("Failed to update job status using repository",
job_id=job_id,
error=str(e))
async def _update_job_status_impl(self,
session,
job_id: str,
status: str,
progress: int = None,
current_step: str = None,
error_message: str = None,
results: Dict = None,
tenant_id: str = None,
auto_commit: bool = True):
"""Implementation of job status update"""
# Check if log exists, create if not
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if not existing_log:
# Create initial log entry
if not tenant_id:
# Extract tenant_id from job_id if not provided
# Format: enhanced_training_{tenant_id}_{job_suffix}
try:
parts = job_id.split('_')
if len(parts) >= 3 and parts[0] == 'enhanced' and parts[1] == 'training':
tenant_id = parts[2]
except Exception:
logger.warning(f"Could not extract tenant_id from job_id {job_id}")
if tenant_id:
log_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": status or "pending",
"progress": progress or 0,
"current_step": current_step or "initializing",
"start_time": datetime.now(timezone.utc)
}
if error_message:
log_data["error_message"] = error_message
if results:
# Ensure results are JSON-serializable before storing
log_data["results"] = make_json_serializable(results)
try:
await self.training_log_repo.create_training_log(log_data)
if auto_commit:
await session.commit() # Explicit commit so other sessions can see it
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
except Exception as create_error:
# Handle race condition: another session may have created the log
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
await session.rollback()
# Query again to get the existing log
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if existing_log:
# Update the existing log instead
await self.training_log_repo.update_log_progress(
job_id=job_id,
progress=progress,
current_step=current_step,
status=status
)
if auto_commit:
await session.commit()
else:
raise
else:
logger.error("Cannot create training log without tenant_id", job_id=job_id)
return
else:
# Update existing log
await self.training_log_repo.update_log_progress(
job_id=job_id,
progress=progress,
current_step=current_step,
status=status
)
# Update additional fields if provided
if error_message or results:
update_data = {}
if error_message:
update_data["error_message"] = error_message
if results:
# Ensure results are JSON-serializable before storing
update_data["results"] = make_json_serializable(results)
if status in ["completed", "failed"]:
update_data["end_time"] = datetime.now(timezone.utc)
if update_data:
await self.training_log_repo.update(existing_log.id, update_data)
if auto_commit:
await session.commit() # Explicit commit after updates
async def start_single_product_training(self,
tenant_id: str,
inventory_product_id: str,
job_id: str,
bakery_location: tuple = (40.4168, -3.7038)) -> Dict[str, Any]:
"""Start enhanced single product training using repository pattern"""
try:
logger.info("Starting enhanced single product training",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
job_id=job_id)
# Create initial training log
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Fetching training data",
tenant_id=tenant_id
)
# Prepare training data for all products to get weather/traffic data
# then filter down to the specific product
training_dataset = await self.orchestrator.prepare_training_data(
tenant_id=tenant_id,
bakery_location=bakery_location,
job_id=job_id + "_temp"
)
# Filter sales data to the specific product
sales_df = pd.DataFrame(training_dataset.sales_data)
product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
if product_sales_df.empty:
raise ValueError(f"No sales data available for product {inventory_product_id}")
# Prepare the data in Prophet format (ds and y columns)
# Ensure proper column names and types for Prophet
product_data = product_sales_df.copy()
product_data = product_data.rename(columns={
'sale_date': 'ds', # Common sales date column
'sale_datetime': 'ds', # Alternative date column
'date': 'ds', # Alternative date column
'quantity': 'y', # Quantity sold
'total_amount': 'y', # Alternative for sales data
'sales_amount': 'y', # Alternative for sales data
'sale_amount': 'y' # Alternative for sales data
})
# If 'ds' and 'y' columns are not renamed properly, try to infer them
if 'ds' not in product_data.columns:
# Try to find date-like columns
date_cols = [col for col in product_data.columns if 'date' in col.lower() or 'time' in col.lower()]
if date_cols:
product_data = product_data.rename(columns={date_cols[0]: 'ds'})
if 'y' not in product_data.columns:
# Try to find sales/quantity-like columns
sales_cols = [col for col in product_data.columns if
any(word in col.lower() for word in ['amount', 'quantity', 'sales', 'total', 'count', 'value'])]
if sales_cols:
product_data = product_data.rename(columns={sales_cols[0]: 'y'})
# Ensure required columns exist
if 'ds' not in product_data.columns or 'y' not in product_data.columns:
raise ValueError(f"Sales data must contain 'date' and 'quantity/sales' columns. Available columns: {list(product_data.columns)}")
# Convert the date column to datetime if it's not already
product_data['ds'] = pd.to_datetime(product_data['ds'])
# Convert to numeric ensuring no pandas/numpy objects remain
product_data['y'] = pd.to_numeric(product_data['y'], errors='coerce')
# Sort by date to ensure proper chronological order
product_data = product_data.sort_values('ds').reset_index(drop=True)
# Drop any rows with NaN values
product_data = product_data.dropna(subset=['ds', 'y'])
# Ensure the data is in the right format for Prophet
product_data = product_data[['ds', 'y']].copy()
"""Start enhanced single product training using repository pattern with single session"""
# Create a single database session for all operations to avoid connection pool exhaustion
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
# Convert to pandas datetime and float types (keep as pandas Series for proper min/max operations)
product_data['ds'] = pd.to_datetime(product_data['ds'])
product_data['y'] = product_data['y'].astype(float)
# DEBUG: Log data types to diagnose dict comparison error
logger.info(f"DEBUG: product_data dtypes after conversion: ds={product_data['ds'].dtype}, y={product_data['y'].dtype}")
logger.info(f"DEBUG: product_data['ds'] sample values: {product_data['ds'].head(3).tolist()}")
logger.info(f"DEBUG: Attempting to get min/max...")
try:
min_val = product_data['ds'].min()
max_val = product_data['ds'].max()
logger.info(f"DEBUG: min_val type={type(min_val)}, value={min_val}")
logger.info(f"DEBUG: max_val type={type(max_val)}, value={max_val}")
except Exception as debug_e:
logger.error(f"DEBUG: Failed to get min/max: {debug_e}")
import traceback
logger.error(f"DEBUG: Traceback: {traceback.format_exc()}")
logger.info("Starting enhanced single product training",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
job_id=job_id)
logger.info("Prepared training data for single product",
inventory_product_id=inventory_product_id,
data_points=len(product_data),
date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
# Update progress
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=30,
current_step="Training model",
tenant_id=tenant_id
)
# Train the model using the trainer
# Extract datetime values with proper pandas Timestamp wrapper for type safety
try:
training_start = pd.Timestamp(product_data['ds'].min()).to_pydatetime()
training_end = pd.Timestamp(product_data['ds'].max()).to_pydatetime()
except Exception as e:
import traceback
logger.error(f"Failed to extract training dates: {e}")
logger.error(f"Traceback: {traceback.format_exc()}")
logger.error(f"product_data['ds'] dtype: {product_data['ds'].dtype}")
logger.error(f"product_data['ds'] first 5 values: {product_data['ds'].head().tolist()}")
raise
# Run the actual training
try:
model_info = await self.trainer.train_single_product_model(
# Create initial training log (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Fetching training data",
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit after initial log creation
# Prepare training data for all products to get weather/traffic data
# then filter down to the specific product
training_dataset = await self.orchestrator.prepare_training_data(
tenant_id=tenant_id,
bakery_location=bakery_location,
job_id=job_id + "_temp"
)
# Use the enhanced data processor to prepare training data with all features (POI, weather, traffic)
# Filter sales data to the specific product first
sales_df = pd.DataFrame(training_dataset.sales_data)
product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
if product_sales_df.empty:
raise ValueError(f"No sales data available for product {inventory_product_id}")
# Get weather and traffic data as DataFrames
weather_df = pd.DataFrame(training_dataset.weather_data)
traffic_df = pd.DataFrame(training_dataset.traffic_data)
# Get POI features from the training dataset (already collected by orchestrator)
poi_features = training_dataset.poi_features if hasattr(training_dataset, 'poi_features') else None
# Use the enhanced data processor to merge all features properly
# This will include POI, weather, traffic features along with ds and y
from app.ml.data_processor import EnhancedBakeryDataProcessor
data_processor = EnhancedBakeryDataProcessor(self.database_manager)
product_data = await data_processor.prepare_training_data(
sales_data=product_sales_df,
weather_data=weather_df,
traffic_data=traffic_df,
inventory_product_id=inventory_product_id,
training_data=product_data,
poi_features=poi_features,
tenant_id=tenant_id,
job_id=job_id
)
except Exception as e:
import traceback
logger.error(f"Training failed with error: {e}")
logger.error(f"Full traceback: {traceback.format_exc()}")
raise
# Update progress
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=80,
current_step="Saving model",
tenant_id=tenant_id
)
# The model should already be saved by train_single_product_model
# Return appropriate response
return {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"status": "completed",
"message": "Enhanced single product training completed successfully",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": 15, # Default estimate for single product
"training_results": {
"total_products": 1,
"successful_trainings": 1,
"failed_trainings": 0,
"products": [{
"inventory_product_id": inventory_product_id,
"status": "completed",
"model_id": str(model_info.get('model_id', f"model_{inventory_product_id}_{job_id[:8]}")) if model_info.get('model_id') else None,
"data_points": len(product_data) if product_data is not None else 0,
# Filter metrics to ensure only numeric values are included
"metrics": {
k: float(v) if not isinstance(v, (int, float)) else v
for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items()
if k != 'product_category' and v is not None
}
}],
"overall_training_time_seconds": model_info.get('training_time', 45.2)
},
"enhanced_features": True,
"repository_integration": True,
"completed_at": datetime.now(timezone.utc).isoformat()
}
except Exception as e:
logger.error("Enhanced single product training failed",
if product_data.empty:
raise ValueError(f"Data processor returned empty data for product {inventory_product_id}")
logger.info("Prepared training data for single product",
inventory_product_id=inventory_product_id,
data_points=len(product_data),
features=list(product_data.columns),
date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
# Update progress (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=30,
current_step="Training model",
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit progress update
# Run the actual training (passing the session to avoid nested session creation)
try:
model_info = await self.trainer.train_single_product_model(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
# Update status to failed
await self._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Training failed",
error_message=str(e),
tenant_id=tenant_id
)
raise
training_data=product_data,
job_id=job_id,
session=session # ✅ CRITICAL FIX: Pass session to prevent deadlock
)
except Exception as e:
import traceback
logger.error(f"Training failed with error: {e}")
logger.error(f"Full traceback: {traceback.format_exc()}")
raise
# Update progress (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=80,
current_step="Saving model",
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit progress update
# The model should already be saved by train_single_product_model
# Return appropriate response
return {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"status": "completed",
"message": "Enhanced single product training completed successfully",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": 15, # Default estimate for single product
"training_results": {
"total_products": 1,
"successful_trainings": 1,
"failed_trainings": 0,
"products": [{
"inventory_product_id": inventory_product_id,
"status": "completed",
"model_id": str(model_info.get('model_id', f"model_{inventory_product_id}_{job_id[:8]}")) if model_info.get('model_id') else None,
"data_points": len(product_data) if product_data is not None else 0,
# Filter metrics to ensure only numeric values are included
"metrics": {
k: float(v) if not isinstance(v, (int, float)) else v
for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items()
if k != 'product_category' and v is not None
}
}],
"overall_training_time_seconds": model_info.get('training_time', 45.2)
},
"enhanced_features": True,
"repository_integration": True,
"completed_at": datetime.now(timezone.utc).isoformat()
}
except Exception as e:
logger.error("Enhanced single product training failed",
inventory_product_id=inventory_product_id,
error=str(e))
# Update status to failed (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Training failed",
error_message=str(e),
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit failure status
raise
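In short, the status-update refactor splits a thin wrapper (which decides whether it owns the session) from `_update_job_status_impl` (which takes an `auto_commit` flag). A condensed, hedged sketch of that shape - `service` stands in for the `EnhancedTrainingService` instance and the keyword fields are passed through unchanged:

```
async def update_job_status(service, job_id: str, status: str, session=None, **fields):
    # Condensed restatement of the wrapper above; illustrative, not the service's exact code.
    if session is None:
        # Own session: create it, update, and commit so other sessions can see the log.
        async with service.database_manager.get_session() as owned:
            await service._init_repositories(owned)
            await service._update_job_status_impl(owned, job_id, status, auto_commit=True, **fields)
    else:
        # Reused session: update only; the caller commits at its own transaction boundary.
        await service._init_repositories(session)
        await service._update_job_status_impl(session, job_id, status, auto_commit=False, **fields)
```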
def _create_detailed_training_response(self, final_result: Dict[str, Any]) -> Dict[str, Any]:
"""Convert final result to detailed training response"""

View File

@@ -494,6 +494,75 @@ class ExternalServiceClient(BaseServiceClient):
# POI (POINT OF INTEREST) DATA
# ================================================================
async def detect_poi_for_tenant(
self,
tenant_id: str,
latitude: float,
longitude: float,
force_refresh: bool = False
) -> Optional[Dict[str, Any]]:
"""
Detect POIs for a tenant's location and generate ML features for forecasting.
With the new tenant-based architecture:
- Gateway receives at: /api/v1/tenants/{tenant_id}/external/poi-context/detect
- Gateway proxies to external service at: /api/v1/tenants/{tenant_id}/poi-context/detect
- This client calls: poi-context/detect (base client automatically constructs with tenant)
This triggers POI detection using Overpass API and calculates ML features
for demand forecasting.
Args:
tenant_id: Tenant ID
latitude: Latitude of the bakery location
longitude: Longitude of the bakery location
force_refresh: Whether to force refresh even if POI context exists
Returns:
Dict with POI detection results including:
- ml_features: Dict of POI features for ML models (e.g., poi_retail_total_count)
- poi_detection_results: Full detection results
- location: Latitude/longitude
- total_pois_detected: Count of POIs
"""
logger.info(
"Detecting POIs for tenant",
tenant_id=tenant_id,
location=(latitude, longitude),
force_refresh=force_refresh
)
params = {
"latitude": latitude,
"longitude": longitude,
"force_refresh": force_refresh
}
# Updated endpoint path to follow tenant-based pattern: external/poi-context/detect
result = await self._make_request(
"POST",
"external/poi-context/detect", # Path will become /api/v1/tenants/{tenant_id}/external/poi-context/detect by base client
tenant_id=tenant_id, # Pass tenant_id to include in headers and path construction
params=params,
timeout=60.0 # POI detection can take longer
)
if result:
poi_context = result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
"POI detection completed successfully",
tenant_id=tenant_id,
total_pois=poi_context.get("total_pois_detected", 0),
ml_features_count=len(ml_features),
source=result.get("source", "unknown")
)
return result
else:
logger.warning("POI detection failed for tenant", tenant_id=tenant_id)
return None
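For callers, a hedged usage sketch of the new method - the signature comes from the definition above, while the helper name, coordinates, and tenant id handling are placeholders:

```
async def refresh_tenant_pois(client, tenant_id: str) -> dict:
    # `client` is an ExternalServiceClient instance; this triggers detection through the gateway.
    result = await client.detect_poi_for_tenant(
        tenant_id=tenant_id,
        latitude=40.4168,      # placeholder coordinates
        longitude=-3.7038,
        force_refresh=True,
    )
    return (result or {}).get("poi_context", {}).get("ml_features", {})
```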
async def get_poi_context(
self,
tenant_id: str
@@ -504,7 +573,7 @@ class ExternalServiceClient(BaseServiceClient):
With the new tenant-based architecture:
- Gateway receives at: /api/v1/tenants/{tenant_id}/external/poi-context
- Gateway proxies to external service at: /api/v1/tenants/{tenant_id}/poi-context
- This client calls: /tenants/{tenant_id}/poi-context
- This client calls: poi-context (base client automatically constructs with tenant)
This retrieves stored POI detection results and calculated ML features
that should be included in demand forecasting predictions.
@@ -521,11 +590,11 @@ class ExternalServiceClient(BaseServiceClient):
"""
logger.info("Fetching POI context for forecasting", tenant_id=tenant_id)
# Updated endpoint path to follow tenant-based pattern: /tenants/{tenant_id}/poi-context
# Updated endpoint path to follow tenant-based pattern: external/poi-context
result = await self._make_request(
"GET",
f"tenants/{tenant_id}/poi-context", # Updated path: /tenants/{tenant_id}/poi-context
tenant_id=tenant_id, # Pass tenant_id to include in headers for authentication
"external/poi-context", # Path will become /api/v1/tenants/{tenant_id}/external/poi-context by base client
tenant_id=tenant_id, # Pass tenant_id to include in headers and path construction
timeout=5.0
)