bakery-ia/services/training/app/services/training_service.py
"""
Enhanced Training Service with Repository Pattern
Main training service that uses the repository pattern for data access
"""
from typing import Dict, List, Any, Optional
import uuid
import structlog
from datetime import datetime, date, timezone
from decimal import Decimal
from sqlalchemy.ext.asyncio import AsyncSession
import json
import numpy as np
import pandas as pd
from app.ml.trainer import EnhancedBakeryMLTrainer
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
from app.services.training_orchestrator import TrainingDataOrchestrator
from app.core.training_constants import (
PROGRESS_DATA_VALIDATION,
PROGRESS_DATA_PREPARATION_COMPLETE,
PROGRESS_ML_TRAINING_START,
PROGRESS_TRAINING_COMPLETE,
PROGRESS_STORING_MODELS,
PROGRESS_STORING_METRICS,
MAX_ESTIMATED_TIME_REMAINING_SECONDS
)
# Import repositories
from app.repositories import (
ModelRepository,
TrainingLogRepository,
PerformanceRepository,
JobQueueRepository,
ArtifactRepository
)
# Import shared database components
from shared.database.unit_of_work import UnitOfWork
from shared.database.transactions import transactional
from shared.database.exceptions import DatabaseError
from app.core.database import database_manager
logger = structlog.get_logger()
def make_json_serializable(obj):
"""Convert numpy/pandas types, datetime, and UUID objects to JSON-serializable Python types"""
# Handle None values
if obj is None:
return None
# Handle basic datetime types first (most common)
if isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, date):
return obj.isoformat()
# Handle pandas timestamp types
if hasattr(pd, 'Timestamp') and isinstance(obj, pd.Timestamp):
return obj.isoformat()
# Handle numpy datetime types
if hasattr(np, 'datetime64') and isinstance(obj, np.datetime64):
return pd.Timestamp(obj).isoformat()
# Handle numeric types
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, pd.Series):
return obj.tolist()
elif isinstance(obj, pd.DataFrame):
return obj.to_dict('records')
elif isinstance(obj, Decimal):
return float(obj)
# Handle UUID types
elif isinstance(obj, uuid.UUID):
return str(obj)
elif hasattr(obj, '__class__') and 'UUID' in str(obj.__class__):
# Handle any UUID-like objects (including asyncpg.pgproto.pgproto.UUID)
return str(obj)
# Handle collections recursively
elif isinstance(obj, dict):
return {k: make_json_serializable(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [make_json_serializable(item) for item in obj]
elif isinstance(obj, set):
return [make_json_serializable(item) for item in obj]
# Handle other common types
elif isinstance(obj, (str, int, float, bool)):
return obj
# Last resort: try to convert to string
else:
try:
# For any other object, try to convert to string
return str(obj)
except Exception:
# If all else fails, return None
return None
class EnhancedTrainingService:
"""
Enhanced training service using repository pattern.
Coordinates the complete training pipeline with proper data abstraction.
"""
def __init__(self, session: Optional[AsyncSession] = None):
self.session = session
self.database_manager = database_manager
# Initialize repositories
if session:
self.model_repo = ModelRepository(session)
self.training_log_repo = TrainingLogRepository(session)
self.performance_repo = PerformanceRepository(session)
self.queue_repo = JobQueueRepository(session)
self.artifact_repo = ArtifactRepository(session)
# Initialize training components
self.trainer = EnhancedBakeryMLTrainer(database_manager=self.database_manager)
self.date_alignment_service = DateAlignmentService()
self.orchestrator = TrainingDataOrchestrator(
date_alignment_service=self.date_alignment_service
)
async def _init_repositories(self, session: AsyncSession):
"""Initialize repositories with session"""
self.model_repo = ModelRepository(session)
self.training_log_repo = TrainingLogRepository(session)
self.performance_repo = PerformanceRepository(session)
self.queue_repo = JobQueueRepository(session)
self.artifact_repo = ArtifactRepository(session)
async def start_training_job(
self,
tenant_id: str,
bakery_location: tuple[float, float] = (40.4168, -3.7038),
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None,
job_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Start a complete training job for a tenant using repository pattern.
Args:
tenant_id: Tenant identifier
bakery_location: Bakery coordinates (lat, lon)
requested_start: Optional explicit start date
requested_end: Optional explicit end date
job_id: Optional job identifier
Returns:
Training job results
"""
if not job_id:
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info("Starting enhanced training job",
job_id=job_id,
tenant_id=tenant_id)
# Get session and initialize repositories
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
try:
# Check if training log already exists, update if found, create if not
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if existing_log:
logger.info("Training log already exists, updating status", job_id=job_id)
training_log = await self.training_log_repo.update_log_progress(
job_id, 0, "initializing", "running"
)
await session.commit()
else:
# Create new training log entry
log_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "running",
"progress": 0,
"current_step": "initializing"
}
try:
training_log = await self.training_log_repo.create_training_log(log_data)
await session.commit() # Explicit commit so other sessions can see it
except Exception as create_error:
# Handle race condition: log may have been created by another session
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
logger.debug("Training log already exists (race condition), updating instead", job_id=job_id)
await session.rollback()
training_log = await self.training_log_repo.update_log_progress(
job_id, 0, "initializing", "running"
)
await session.commit()
else:
raise
# Step 1: Prepare training dataset (includes sales data validation)
logger.info("Step 1: Preparing and aligning training data (with validation)")
await self.training_log_repo.update_log_progress(
job_id, PROGRESS_DATA_VALIDATION, "data_validation", "running"
)
# Orchestrator now handles sales data validation to eliminate duplicate fetching
training_dataset = await self.orchestrator.prepare_training_data(
tenant_id=tenant_id,
bakery_location=bakery_location,
requested_start=requested_start,
requested_end=requested_end,
job_id=job_id
)
# Log the results from orchestrator's unified sales data fetch
logger.info(f"Sales data validation completed: {len(training_dataset.sales_data)} records",
tenant_id=tenant_id, job_id=job_id)
await self.training_log_repo.update_log_progress(
job_id, PROGRESS_DATA_PREPARATION_COMPLETE, "data_preparation_complete", "running"
)
# Step 2: Execute ML training pipeline
logger.info("Step 2: Starting ML training pipeline")
await self.training_log_repo.update_log_progress(
job_id, PROGRESS_ML_TRAINING_START, "ml_training", "running"
)
# ✅ FIX: Commit the session to prevent deadlock with trainer's nested session
# The trainer creates its own session, so we need to ensure this update is committed
await session.commit()
logger.debug("Committed session after ml_training progress update")
training_results = await self.trainer.train_tenant_models(
tenant_id=tenant_id,
training_dataset=training_dataset,
job_id=job_id,
session=session # Pass the main session to avoid nested sessions
)
await self.training_log_repo.update_log_progress(
job_id, PROGRESS_TRAINING_COMPLETE, "training_complete", "running"
)
# Publish progress event (85%)
from app.services.training_events import publish_training_progress
await publish_training_progress(
job_id=job_id,
tenant_id=tenant_id,
progress=PROGRESS_TRAINING_COMPLETE,
current_step="Training Complete",
step_details="All products trained successfully"
)
# Step 3: Store model records using repository
logger.info("Step 3: Storing model records")
logger.debug("Training results structure",
keys=list(training_results.keys()) if isinstance(training_results, dict) else "not_dict",
training_results_type=type(training_results).__name__)
stored_models = await self._store_trained_models(
tenant_id, job_id, training_results
)
await self.training_log_repo.update_log_progress(
job_id, PROGRESS_STORING_MODELS, "storing_models", "running"
)
# Publish progress event (92%)
await publish_training_progress(
job_id=job_id,
tenant_id=tenant_id,
progress=PROGRESS_STORING_MODELS,
current_step="Storing Models",
step_details=f"Saved {len(stored_models)} trained models to database"
)
# Step 4: Create performance metrics
await self.training_log_repo.update_log_progress(
job_id, PROGRESS_STORING_METRICS, "storing_performance_metrics", "running"
)
# Publish progress event (94%)
await publish_training_progress(
job_id=job_id,
tenant_id=tenant_id,
progress=PROGRESS_STORING_METRICS,
current_step="Storing Performance Metrics",
step_details="Saving model performance metrics"
)
await self._create_performance_metrics(
tenant_id, stored_models, training_results
)
# Step 4.5: Save training performance metrics for future estimations
await self._save_training_performance_metrics(
tenant_id, job_id, training_results, training_log
)
# Step 5: Complete training log
final_result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "completed",
"training_results": training_results,
"stored_models": [{
"id": str(model.id),
"inventory_product_id": str(model.inventory_product_id),
"model_type": model.model_type,
"model_path": model.model_path,
"is_active": model.is_active,
"training_samples": model.training_samples
} for model in stored_models],
"data_summary": {
"sales_records": int(len(training_dataset.sales_data)),
"weather_records": int(len(training_dataset.weather_data)),
"traffic_records": int(len(training_dataset.traffic_data)),
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat()
},
"data_sources_used": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"completed_at": datetime.now().isoformat()
}
# Make sure all data is JSON-serializable before saving to database
json_safe_result = make_json_serializable(final_result)
# Ensure results is a proper dict for database storage
if not isinstance(json_safe_result, dict):
logger.warning("JSON safe result is not a dict, wrapping it", result_type=type(json_safe_result))
json_safe_result = {"training_data": json_safe_result}
# Double-check JSON serialization by attempting to serialize
try:
json.dumps(json_safe_result)
logger.debug("Results successfully JSON-serializable", job_id=job_id)
except (TypeError, ValueError) as e:
logger.error("Results still not JSON-serializable after cleaning",
job_id=job_id, error=str(e))
# Create a minimal safe result
json_safe_result = {
"status": "completed",
"job_id": job_id,
"models_created": final_result.get("products_trained", 0),
"error": "Result serialization failed"
}
await self.training_log_repo.complete_training_log(
job_id, results=json_safe_result
)
# CRITICAL: Commit the session to persist the completed status to database
# Without this commit, the status update is lost when the session closes
await session.commit()
logger.info("Enhanced training job completed successfully",
job_id=job_id,
models_created=len(stored_models))
return self._create_detailed_training_response(final_result)
except Exception as e:
logger.error("Enhanced training job failed",
job_id=job_id,
error=str(e),
exc_info=True)
# Mark as failed in database
await self.training_log_repo.complete_training_log(
job_id, error_message=str(e)
)
# Commit the failure status to database
await session.commit()
error_result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "failed",
"error_message": str(e),
"completed_at": datetime.now().isoformat()
}
# Ensure error result is JSON serializable
error_result = make_json_serializable(error_result)
return self._create_detailed_training_response(error_result)
async def _store_trained_models(
self,
tenant_id: str,
job_id: str,
training_results: Dict[str, Any]
) -> List:
"""
Retrieve or verify stored models from training results.
NOTE: Model records are now created by the trainer during parallel execution.
This method retrieves the already-created models instead of creating duplicates.
"""
stored_models = []
try:
# Check if models were already created by the trainer (new approach)
# The trainer now writes models sequentially after parallel training
training_results_dict = training_results.get("training_results", {})
# Get list of successfully trained products
successful_products = [
product_id for product_id, result in training_results_dict.items()
if result.get('status') == 'success' and result.get('model_record_id')
]
logger.info("Retrieving models created by trainer",
successful_products=len(successful_products),
job_id=job_id)
# Retrieve the models that were already created by the trainer
for product_id in successful_products:
result = training_results_dict[product_id]
model_record_id = result.get('model_record_id')
if model_record_id:
try:
# Get the model from the database using base repository method
model = await self.model_repo.get_by_id(model_record_id)
if model:
stored_models.append(model)
logger.debug("Retrieved model from database",
model_id=model_record_id,
inventory_product_id=product_id)
except Exception as e:
logger.warning("Could not retrieve model record",
model_id=model_record_id,
inventory_product_id=product_id,
error=str(e))
logger.info("Models retrieval complete",
models_retrieved=len(stored_models),
expected=len(successful_products))
return stored_models
except Exception as e:
logger.error("Failed to retrieve stored models",
tenant_id=tenant_id,
job_id=job_id,
error=str(e))
return stored_models
async def _create_performance_metrics(
self,
tenant_id: str,
stored_models: List,
training_results: Dict[str, Any]
):
"""
Verify performance metrics for stored models.
NOTE: Performance metrics are now created by the trainer during model creation.
This method now just verifies they exist rather than creating duplicates.
"""
try:
logger.info("Verifying performance metrics",
models_count=len(stored_models))
# Performance metrics are already created by the trainer
# This method is kept for compatibility but doesn't create duplicates
for model in stored_models:
logger.debug("Performance metrics already created for model",
model_id=str(model.id),
inventory_product_id=str(model.inventory_product_id))
logger.info("Performance metrics verification complete",
models_count=len(stored_models))
except Exception as e:
logger.error("Failed to verify performance metrics",
tenant_id=tenant_id,
error=str(e))
async def _save_training_performance_metrics(
self,
tenant_id: str,
job_id: str,
training_results: Dict[str, Any],
training_log
):
"""
Save aggregated training performance metrics for time estimation.
This data is used to predict future training durations.
"""
try:
from app.models.training import TrainingPerformanceMetrics
from shared.database.repository import BaseRepository
# Extract timing and success data
models_trained = training_results.get("models_trained", {})
total_products = len(models_trained)
successful_products = sum(1 for m in models_trained.values() if m.get("status") == "completed")
failed_products = total_products - successful_products
# Calculate total duration
if training_log.start_time and training_log.end_time:
total_duration_seconds = (training_log.end_time - training_log.start_time).total_seconds()
else:
# Fallback to elapsed time
total_duration_seconds = training_results.get("total_training_time", 0)
# Calculate average time per product
if successful_products > 0:
avg_time_per_product = total_duration_seconds / successful_products
else:
avg_time_per_product = 0
# Extract timing breakdown if available
data_analysis_time = training_results.get("data_analysis_time_seconds")
training_time = training_results.get("training_time_seconds")
finalization_time = training_results.get("finalization_time_seconds")
# Create performance metrics record
metric_data = {
"tenant_id": tenant_id,
"job_id": job_id,
"total_products": total_products,
"successful_products": successful_products,
"failed_products": failed_products,
"total_duration_seconds": total_duration_seconds,
"avg_time_per_product": avg_time_per_product,
"data_analysis_time_seconds": data_analysis_time,
"training_time_seconds": training_time,
"finalization_time_seconds": finalization_time,
"completed_at": datetime.now(timezone.utc)
}
# Create a temporary repository for the TrainingPerformanceMetrics model
# Use the session from one of the initialized repositories to ensure it's available
session = self.model_repo.session # This should be the same session used by all repositories
metrics_repo = BaseRepository(TrainingPerformanceMetrics, session)
# Use repository to create record
await metrics_repo.create(metric_data)
logger.info("Saved training performance metrics for future estimations",
tenant_id=tenant_id,
job_id=job_id,
avg_time_per_product=avg_time_per_product,
total_products=total_products,
successful_products=successful_products)
except Exception as e:
logger.error("Failed to save training performance metrics",
tenant_id=tenant_id,
job_id=job_id,
error=str(e))
async def get_training_status(self, job_id: str) -> Dict[str, Any]:
"""Get training job status using repository"""
try:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
log = await self.training_log_repo.get_log_by_job_id(job_id)
if not log:
return {"error": "Job not found"}
# Calculate estimated time remaining based on progress and elapsed time
estimated_time_remaining_seconds = None
if log.status == "running" and log.progress > 0 and log.start_time:
elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds()
if elapsed_time > 0 and log.progress > 0: # Double-check progress is positive
# Calculate estimated total time based on progress
# Use max(log.progress, 1) as additional safety against division by zero
estimated_total_time = (elapsed_time / max(log.progress, 1)) * 100
estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time)
# Cap at reasonable maximum (30 minutes)
estimated_time_remaining_seconds = max(0, min(estimated_time_remaining_seconds, MAX_ESTIMATED_TIME_REMAINING_SECONDS))
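# Worked example: with elapsed_time = 120s at progress = 40,
# estimated_total_time = (120 / 40) * 100 = 300s, so roughly 180s remain,
# which is then clamped to [0, MAX_ESTIMATED_TIME_REMAINING_SECONDS].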
# Extract products info from results if available
products_total = 0
products_completed = 0
products_failed = 0
if log.results:
products_total = log.results.get("total_products", 0)
products_completed = log.results.get("successful_trainings", 0)
products_failed = log.results.get("failed_trainings", 0)
return {
"job_id": job_id,
"tenant_id": log.tenant_id,
"status": log.status,
"progress": log.progress,
"current_step": log.current_step,
"started_at": log.start_time.isoformat() if log.start_time else None,
"completed_at": log.end_time.isoformat() if log.end_time else None,
"error_message": log.error_message,
"results": log.results,
"products_total": products_total,
"products_completed": products_completed,
"products_failed": products_failed,
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
"message": log.current_step
}
except Exception as e:
logger.error("Failed to get training status",
job_id=job_id,
error=str(e))
return {"error": f"Failed to get status: {str(e)}"}
async def get_tenant_models(
self,
tenant_id: str,
active_only: bool = True,
skip: int = 0,
limit: int = 100
) -> List[Dict[str, Any]]:
"""Get models for a tenant using repository"""
try:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
if active_only:
models = await self.model_repo.get_multi(
filters={"tenant_id": tenant_id, "is_active": True},
skip=skip,
limit=limit,
order_by="created_at",
order_desc=True
)
else:
models = await self.model_repo.get_models_by_tenant(
tenant_id, skip=skip, limit=limit
)
return [model.to_dict() for model in models]
except Exception as e:
logger.error("Failed to get tenant models",
tenant_id=tenant_id,
error=str(e))
return []
async def get_model_performance(self, model_id: str) -> Dict[str, Any]:
"""Get model performance metrics using repository"""
try:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
# Get model summary
model_summary = await self.model_repo.get_model_performance_summary(model_id)
# Get latest performance metrics
latest_metric = await self.performance_repo.get_latest_metric_for_model(model_id)
if latest_metric:
model_summary["latest_metrics"] = {
"mae": latest_metric.mae,
"mse": latest_metric.mse,
"rmse": latest_metric.rmse,
"mape": latest_metric.mape,
"r2_score": latest_metric.r2_score,
"accuracy_percentage": latest_metric.accuracy_percentage,
"measured_at": latest_metric.measured_at.isoformat() if latest_metric.measured_at else None
}
return model_summary
except Exception as e:
logger.error("Failed to get model performance",
model_id=model_id,
error=str(e))
return {}
async def get_tenant_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get comprehensive tenant statistics using repositories"""
try:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
# Get model statistics
model_stats = await self.model_repo.get_model_statistics(tenant_id)
# Get job statistics
job_stats = await self.training_log_repo.get_job_statistics(tenant_id)
# Get performance trends
performance_trends = await self.performance_repo.get_performance_trends(tenant_id)
# Get queue status
queue_status = await self.queue_repo.get_queue_status(tenant_id)
# Get artifact statistics
artifact_stats = await self.artifact_repo.get_artifact_statistics(tenant_id)
return {
"tenant_id": tenant_id,
"models": model_stats,
"training_jobs": job_stats,
"performance": performance_trends,
"queue": queue_status,
"artifacts": artifact_stats,
"summary": {
"total_active_models": model_stats.get("active_models", 0),
"total_training_jobs": job_stats.get("total_jobs", 0),
"success_rate": job_stats.get("success_rate", 0.0),
"products_with_models": len(model_stats.get("models_by_product", {})),
"total_storage_mb": artifact_stats.get("total_storage", {}).get("total_size_mb", 0.0)
}
}
except Exception as e:
logger.error("Failed to get tenant statistics",
tenant_id=tenant_id,
error=str(e))
return {"error": f"Failed to get statistics: {str(e)}"}
async def _update_job_status_repository(self,
job_id: str,
status: str,
progress: int = None,
current_step: str = None,
error_message: str = None,
results: Dict = None,
tenant_id: str = None,
session = None):
"""Update job status using repository pattern
Args:
session: Optional database session to reuse. If None, creates a new session.
"""
try:
# Use provided session or create new one
should_create_session = session is None
if should_create_session:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
await self._update_job_status_impl(
session, job_id, status, progress, current_step,
error_message, results, tenant_id
)
else:
# Reuse provided session (don't commit - let caller control transaction)
await self._init_repositories(session)
await self._update_job_status_impl(
session, job_id, status, progress, current_step,
error_message, results, tenant_id, auto_commit=False
)
except Exception as e:
logger.error("Failed to update job status using repository",
job_id=job_id,
error=str(e))
async def _update_job_status_impl(self,
session,
job_id: str,
status: str,
progress: int = None,
current_step: str = None,
error_message: str = None,
results: Dict = None,
tenant_id: str = None,
auto_commit: bool = True):
"""Implementation of job status update"""
# Check if log exists, create if not
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if not existing_log:
# Create initial log entry
if not tenant_id:
# Extract tenant_id from job_id if not provided
# Format: enhanced_training_{tenant_id}_{job_suffix}
try:
parts = job_id.split('_')
if len(parts) >= 3 and parts[0] == 'enhanced' and parts[1] == 'training':
tenant_id = parts[2]
except Exception:
logger.warning(f"Could not extract tenant_id from job_id {job_id}")
if tenant_id:
log_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": status or "pending",
"progress": progress or 0,
"current_step": current_step or "initializing",
"start_time": datetime.now(timezone.utc)
}
if error_message:
log_data["error_message"] = error_message
if results:
# Ensure results are JSON-serializable before storing
log_data["results"] = make_json_serializable(results)
try:
await self.training_log_repo.create_training_log(log_data)
if auto_commit:
await session.commit() # Explicit commit so other sessions can see it
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
except Exception as create_error:
# Handle race condition: another session may have created the log
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
await session.rollback()
# Query again to get the existing log
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if existing_log:
# Update the existing log instead
await self.training_log_repo.update_log_progress(
job_id=job_id,
progress=progress,
current_step=current_step,
status=status
)
if auto_commit:
await session.commit()
else:
raise
else:
logger.error("Cannot create training log without tenant_id", job_id=job_id)
return
else:
# Update existing log
await self.training_log_repo.update_log_progress(
job_id=job_id,
progress=progress,
current_step=current_step,
status=status
)
# Update additional fields if provided
if error_message or results:
update_data = {}
if error_message:
update_data["error_message"] = error_message
if results:
# Ensure results are JSON-serializable before storing
update_data["results"] = make_json_serializable(results)
if status in ["completed", "failed"]:
update_data["end_time"] = datetime.now(timezone.utc)
if update_data:
await self.training_log_repo.update(existing_log.id, update_data)
if auto_commit:
await session.commit() # Explicit commit after updates
async def start_single_product_training(self,
tenant_id: str,
inventory_product_id: str,
job_id: str,
bakery_location: tuple = (40.4168, -3.7038)) -> Dict[str, Any]:
"""Start enhanced single product training using repository pattern with single session"""
# Create a single database session for all operations to avoid connection pool exhaustion
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
try:
logger.info("Starting enhanced single product training",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
job_id=job_id)
# Create initial training log (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Fetching training data",
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit after initial log creation
# Prepare training data for all products to get weather/traffic data
# then filter down to the specific product
training_dataset = await self.orchestrator.prepare_training_data(
tenant_id=tenant_id,
bakery_location=bakery_location,
job_id=job_id + "_temp"
)
# Use the enhanced data processor to prepare training data with all features (POI, weather, traffic)
# Filter sales data to the specific product first
sales_df = pd.DataFrame(training_dataset.sales_data)
product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
if product_sales_df.empty:
raise ValueError(f"No sales data available for product {inventory_product_id}")
# Get weather and traffic data as DataFrames
weather_df = pd.DataFrame(training_dataset.weather_data)
traffic_df = pd.DataFrame(training_dataset.traffic_data)
# Get POI features from the training dataset (already collected by orchestrator)
poi_features = training_dataset.poi_features if hasattr(training_dataset, 'poi_features') else None
# Use the enhanced data processor to merge all features properly
# This will include POI, weather, traffic features along with ds and y
from app.ml.data_processor import EnhancedBakeryDataProcessor
data_processor = EnhancedBakeryDataProcessor(self.database_manager)
product_data = await data_processor.prepare_training_data(
sales_data=product_sales_df,
weather_data=weather_df,
traffic_data=traffic_df,
inventory_product_id=inventory_product_id,
poi_features=poi_features,
tenant_id=tenant_id,
job_id=job_id
)
if product_data.empty:
raise ValueError(f"Data processor returned empty data for product {inventory_product_id}")
logger.info("Prepared training data for single product",
inventory_product_id=inventory_product_id,
data_points=len(product_data),
features=list(product_data.columns),
date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
# Update progress (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=30,
current_step="Training model",
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit progress update
# Run the actual training (passing the session to avoid nested session creation)
try:
model_info = await self.trainer.train_single_product_model(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
training_data=product_data,
job_id=job_id,
session=session # ✅ CRITICAL FIX: Pass session to prevent deadlock
)
except Exception as e:
import traceback
logger.error(f"Training failed with error: {e}")
logger.error(f"Full traceback: {traceback.format_exc()}")
raise
# Update progress (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=80,
current_step="Saving model",
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit progress update
# The model should already be saved by train_single_product_model
# Return appropriate response
return {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"status": "completed",
"message": "Enhanced single product training completed successfully",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": 15, # Default estimate for single product
"training_results": {
"total_products": 1,
"successful_trainings": 1,
"failed_trainings": 0,
"products": [{
"inventory_product_id": inventory_product_id,
"status": "completed",
"model_id": str(model_info.get('model_id', f"model_{inventory_product_id}_{job_id[:8]}")) if model_info.get('model_id') else None,
"data_points": len(product_data) if product_data is not None else 0,
# Filter metrics to ensure only numeric values are included
"metrics": {
k: float(v) if not isinstance(v, (int, float)) else v
for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items()
if k != 'product_category' and v is not None
}
}],
"overall_training_time_seconds": model_info.get('training_time', 45.2)
},
"enhanced_features": True,
"repository_integration": True,
"completed_at": datetime.now(timezone.utc).isoformat()
}
except Exception as e:
logger.error("Enhanced single product training failed",
inventory_product_id=inventory_product_id,
error=str(e))
# Update status to failed (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Training failed",
error_message=str(e),
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit failure status
raise
def _create_detailed_training_response(self, final_result: Dict[str, Any]) -> Dict[str, Any]:
"""Convert final result to detailed training response"""
try:
training_results_data = final_result.get("training_results", {})
stored_models = final_result.get("stored_models", [])
# Convert stored models to product results
products = []
for model in stored_models:
products.append({
"inventory_product_id": model.get("inventory_product_id"),
"status": "completed",
"model_id": model.get("id"),
"data_points": model.get("training_samples", 0),
"metrics": {
"mape": model.get("mape"),
"mae": model.get("mae"),
"rmse": model.get("rmse"),
"r2_score": model.get("r2_score")
},
"model_path": model.get("model_path")
})
# Build the response
response_data = {
"job_id": final_result["job_id"],
"tenant_id": final_result["tenant_id"],
"status": final_result["status"],
"message": f"Training {final_result['status']} successfully",
"created_at": datetime.now(),
"estimated_duration_minutes": final_result.get("estimated_duration_minutes", 15),
"training_results": {
"total_products": len(products),
"successful_trainings": len([p for p in products if p["status"] == "completed"]),
"failed_trainings": len([p for p in products if p["status"] == "failed"]),
"products": products,
"overall_training_time_seconds": training_results_data.get("total_training_time", 0)
},
"data_summary": final_result.get("data_summary", {}),
"completed_at": final_result.get("completed_at")
}
return response_data
except Exception as e:
logger.error("Failed to create detailed response", error=str(e))
return final_result