"""
Enhanced Training Service with Repository Pattern
Main training service that uses the repository pattern for data access
"""
from typing import Dict, List, Any, Optional
import uuid
import structlog
from datetime import datetime, date, timezone
from decimal import Decimal
from sqlalchemy.ext.asyncio import AsyncSession
import json
import numpy as np
import pandas as pd
from app.ml.trainer import EnhancedBakeryMLTrainer
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
from app.services.training_orchestrator import TrainingDataOrchestrator
from app.core.training_constants import (
PROGRESS_DATA_VALIDATION,
PROGRESS_DATA_PREPARATION_COMPLETE,
PROGRESS_ML_TRAINING_START,
PROGRESS_TRAINING_COMPLETE,
PROGRESS_STORING_MODELS,
PROGRESS_STORING_METRICS,
MAX_ESTIMATED_TIME_REMAINING_SECONDS
)
# Import repositories
from app.repositories import (
ModelRepository,
TrainingLogRepository,
PerformanceRepository,
JobQueueRepository,
ArtifactRepository
)
# Import shared database components
from shared.database.unit_of_work import UnitOfWork
from shared.database.transactions import transactional
from shared.database.exceptions import DatabaseError
from app.core.database import database_manager
logger = structlog.get_logger()
def make_json_serializable(obj):
"""Convert numpy/pandas types, datetime, and UUID objects to JSON-serializable Python types"""
# Handle None values
if obj is None:
return None
# Handle basic datetime types first (most common)
if isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, date):
return obj.isoformat()
# Handle pandas timestamp types
if hasattr(pd, 'Timestamp') and isinstance(obj, pd.Timestamp):
return obj.isoformat()
# Handle numpy datetime types
if hasattr(np, 'datetime64') and isinstance(obj, np.datetime64):
return pd.Timestamp(obj).isoformat()
# Handle numeric types
# NOTE: pd.Int64Dtype / pd.Float64Dtype are dtype classes, not scalar types,
# so checking the numpy scalar hierarchy is sufficient here
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, pd.Series):
return obj.tolist()
elif isinstance(obj, pd.DataFrame):
return obj.to_dict('records')
elif isinstance(obj, Decimal):
return float(obj)
# Handle UUID types
elif isinstance(obj, uuid.UUID):
return str(obj)
elif hasattr(obj, '__class__') and 'UUID' in str(obj.__class__):
# Handle any UUID-like objects (including asyncpg.pgproto.pgproto.UUID)
return str(obj)
# Handle collections recursively
elif isinstance(obj, dict):
return {k: make_json_serializable(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [make_json_serializable(item) for item in obj]
elif isinstance(obj, set):
return [make_json_serializable(item) for item in obj]
# Handle other common types
elif isinstance(obj, (str, int, float, bool)):
return obj
# Last resort: try to convert to string
else:
try:
# For any other object, try to convert to string
return str(obj)
except Exception:
# If all else fails, return None
return None
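# Illustrative round-trip for the converter above (values are made up):
#
#     payload = {
#         "id": uuid.uuid4(),
#         "trained_at": datetime(2025, 1, 1, 12, 0),
#         "samples": np.int64(120),
#         "mape": Decimal("12.5"),
#     }
#     json.dumps(make_json_serializable(payload))  # should not raise TypeError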
class EnhancedTrainingService:
"""
Enhanced training service using repository pattern.
Coordinates the complete training pipeline with proper data abstraction.
"""
def __init__(self, session: Optional[AsyncSession] = None):
self.session = session
self.database_manager = database_manager
# Initialize repositories
if session:
self.model_repo = ModelRepository(session)
self.training_log_repo = TrainingLogRepository(session)
self.performance_repo = PerformanceRepository(session)
self.queue_repo = JobQueueRepository(session)
self.artifact_repo = ArtifactRepository(session)
# Initialize training components
self.trainer = EnhancedBakeryMLTrainer(database_manager=self.database_manager)
self.date_alignment_service = DateAlignmentService()
self.orchestrator = TrainingDataOrchestrator(
date_alignment_service=self.date_alignment_service
)
async def _init_repositories(self, session: AsyncSession):
"""Initialize repositories with session"""
self.model_repo = ModelRepository(session)
self.training_log_repo = TrainingLogRepository(session)
self.performance_repo = PerformanceRepository(session)
self.queue_repo = JobQueueRepository(session)
self.artifact_repo = ArtifactRepository(session)
async def start_training_job(
self,
tenant_id: str,
bakery_location: tuple[float, float] = (40.4168, -3.7038),
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None,
job_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Start a complete training job for a tenant using the repository pattern.
Args:
tenant_id: Tenant identifier
bakery_location: Bakery coordinates (lat, lon)
requested_start: Optional explicit start date
requested_end: Optional explicit end date
job_id: Optional job identifier
Returns:
Training job results
"""
if not job_id:
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info("Starting enhanced training job",
job_id=job_id,
tenant_id=tenant_id)
# Get session and initialize repositories
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
try:
# Check if training log already exists, update if found, create if not
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if existing_log:
logger.info("Training log already exists, updating status", job_id=job_id)
training_log = await self.training_log_repo.update_log_progress(
job_id, 0, "initializing", "running"
)
await session.commit()
else:
# Create new training log entry
log_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "running",
"progress": 0,
"current_step": "initializing"
}
try:
training_log = await self.training_log_repo.create_training_log(log_data)
await session.commit() # Explicit commit so other sessions can see it
except Exception as create_error:
# Handle race condition: log may have been created by another session
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
logger.debug("Training log already exists (race condition), updating instead", job_id=job_id)
await session.rollback()
training_log = await self.training_log_repo.update_log_progress(
job_id, 0, "initializing", "running"
)
await session.commit()
else:
raise
# Step 1: Prepare training dataset (includes sales data validation)
logger.info("Step 1: Preparing and aligning training data (with validation)")
await self.training_log_repo.update_log_progress(
job_id, PROGRESS_DATA_VALIDATION, "data_validation", "running"
)
# Orchestrator now handles sales data validation to eliminate duplicate fetching
training_dataset = await self.orchestrator.prepare_training_data(
tenant_id=tenant_id,
bakery_location=bakery_location,
requested_start=requested_start,
requested_end=requested_end,
job_id=job_id
)
# Log the results from orchestrator's unified sales data fetch
logger.info(f"Sales data validation completed: {len(training_dataset.sales_data)} records",
tenant_id=tenant_id, job_id=job_id)
await self.training_log_repo.update_log_progress(
job_id, PROGRESS_DATA_PREPARATION_COMPLETE, "data_preparation_complete", "running"
)
# Step 2: Execute ML training pipeline
logger.info("Step 2: Starting ML training pipeline")
await self.training_log_repo.update_log_progress(
job_id, PROGRESS_ML_TRAINING_START, "ml_training", "running"
)
# ✅ FIX: Commit the session to prevent deadlock with trainer's nested session
# The trainer creates its own session, so we need to ensure this update is committed
await session.commit()
logger.debug("Committed session after ml_training progress update")
training_results = await self.trainer.train_tenant_models(
tenant_id=tenant_id,
training_dataset=training_dataset,
job_id=job_id,
session=session # Pass the main session to avoid nested sessions
)
await self.training_log_repo.update_log_progress(
job_id, PROGRESS_TRAINING_COMPLETE, "training_complete", "running"
)
# Publish progress event (85%)
from app.services.training_events import publish_training_progress
await publish_training_progress(
job_id=job_id,
tenant_id=tenant_id,
progress=PROGRESS_TRAINING_COMPLETE,
current_step="Training Complete",
step_details="All products trained successfully"
)
# Step 3: Store model records using repository
logger.info("Step 3: Storing model records")
logger.debug("Training results structure",
keys=list(training_results.keys()) if isinstance(training_results, dict) else "not_dict",
training_results_type=type(training_results).__name__)
stored_models = await self._store_trained_models(
tenant_id, job_id, training_results
)
await self.training_log_repo.update_log_progress(
job_id, PROGRESS_STORING_MODELS, "storing_models", "running"
)
# Publish progress event (92%)
await publish_training_progress(
job_id=job_id,
tenant_id=tenant_id,
progress=PROGRESS_STORING_MODELS,
current_step="Storing Models",
step_details=f"Saved {len(stored_models)} trained models to database"
)
# Step 4: Create performance metrics
await self.training_log_repo.update_log_progress(
job_id, PROGRESS_STORING_METRICS, "storing_performance_metrics", "running"
)
# Publish progress event (94%)
await publish_training_progress(
job_id=job_id,
tenant_id=tenant_id,
progress=PROGRESS_STORING_METRICS,
current_step="Storing Performance Metrics",
step_details="Saving model performance metrics"
)
await self._create_performance_metrics(
tenant_id, stored_models, training_results
)
# Step 4.5: Save training performance metrics for future estimations
await self._save_training_performance_metrics(
tenant_id, job_id, training_results, training_log
)
# Step 5: Complete training log
final_result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "completed",
"training_results": training_results,
"stored_models": [{
"id": str(model.id),
"inventory_product_id": str(model.inventory_product_id),
"model_type": model.model_type,
"model_path": model.model_path,
"is_active": model.is_active,
"training_samples": model.training_samples
} for model in stored_models],
"data_summary": {
"sales_records": int(len(training_dataset.sales_data)),
"weather_records": int(len(training_dataset.weather_data)),
"traffic_records": int(len(training_dataset.traffic_data)),
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat()
},
"data_sources_used": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"completed_at": datetime.now().isoformat()
}
# Make sure all data is JSON-serializable before saving to database
json_safe_result = make_json_serializable(final_result)
# Ensure results is a proper dict for database storage
if not isinstance(json_safe_result, dict):
logger.warning("JSON safe result is not a dict, wrapping it", result_type=type(json_safe_result))
json_safe_result = {"training_data": json_safe_result}
# Double-check JSON serialization by attempting to serialize (json is imported at module level)
try:
json.dumps(json_safe_result)
logger.debug("Results successfully JSON-serializable", job_id=job_id)
except (TypeError, ValueError) as e:
logger.error("Results still not JSON-serializable after cleaning",
job_id=job_id, error=str(e))
# Create a minimal safe result
json_safe_result = {
"status": "completed",
"job_id": job_id,
"models_created": final_result.get("products_trained", 0),
"error": "Result serialization failed"
}
await self.training_log_repo.complete_training_log(
job_id, results=json_safe_result
)
# CRITICAL: Commit the session to persist the completed status to database
# Without this commit, the status update is lost when the session closes
await session.commit()
logger.info("Enhanced training job completed successfully",
job_id=job_id,
models_created=len(stored_models))
return self._create_detailed_training_response(final_result)
except Exception as e:
logger.error("Enhanced training job failed",
job_id=job_id,
error=str(e),
exc_info=True)
# Mark as failed in database
await self.training_log_repo.complete_training_log(
job_id, error_message=str(e)
)
# Commit the failure status to database
await session.commit()
error_result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "failed",
"error_message": str(e),
"completed_at": datetime.now().isoformat()
}
# Ensure error result is JSON serializable
error_result = make_json_serializable(error_result)
return self._create_detailed_training_response(error_result)
async def _store_trained_models(
self,
tenant_id: str,
job_id: str,
training_results: Dict[str, Any]
) -> List:
"""
Retrieve or verify stored models from training results.
NOTE: Model records are now created by the trainer during parallel execution.
This method retrieves the already-created models instead of creating duplicates.
"""
stored_models = []
try:
# Check if models were already created by the trainer (new approach)
# The trainer now writes models sequentially after parallel training
training_results_dict = training_results.get("training_results", {})
# Get list of successfully trained products
successful_products = [
product_id for product_id, result in training_results_dict.items()
if result.get('status') == 'success' and result.get('model_record_id')
]
logger.info("Retrieving models created by trainer",
successful_products=len(successful_products),
job_id=job_id)
# Retrieve the models that were already created by the trainer
for product_id in successful_products:
result = training_results_dict[product_id]
model_record_id = result.get('model_record_id')
if model_record_id:
try:
# Get the model from the database using base repository method
model = await self.model_repo.get_by_id(model_record_id)
if model:
stored_models.append(model)
logger.debug("Retrieved model from database",
model_id=model_record_id,
inventory_product_id=product_id)
except Exception as e:
logger.warning("Could not retrieve model record",
model_id=model_record_id,
inventory_product_id=product_id,
error=str(e))
logger.info("Models retrieval complete",
models_retrieved=len(stored_models),
expected=len(successful_products))
return stored_models
except Exception as e:
logger.error("Failed to retrieve stored models",
tenant_id=tenant_id,
job_id=job_id,
error=str(e))
return stored_models
async def _create_performance_metrics(
self,
tenant_id: str,
stored_models: List,
training_results: Dict[str, Any]
):
"""
Verify performance metrics for stored models.
NOTE: Performance metrics are now created by the trainer during model creation.
This method now just verifies they exist rather than creating duplicates.
"""
try:
logger.info("Verifying performance metrics",
models_count=len(stored_models))
# Performance metrics are already created by the trainer
# This method is kept for compatibility but doesn't create duplicates
for model in stored_models:
logger.debug("Performance metrics already created for model",
model_id=str(model.id),
inventory_product_id=str(model.inventory_product_id))
logger.info("Performance metrics verification complete",
models_count=len(stored_models))
except Exception as e:
logger.error("Failed to verify performance metrics",
tenant_id=tenant_id,
error=str(e))
async def _save_training_performance_metrics(
self,
tenant_id: str,
job_id: str,
training_results: Dict[str, Any],
training_log
):
"""
Save aggregated training performance metrics for time estimation.
This data is used to predict future training durations.
"""
try:
from app.models.training import TrainingPerformanceMetrics
from shared.database.repository import BaseRepository
# Extract timing and success data
models_trained = training_results.get("models_trained", {})
total_products = len(models_trained)
successful_products = sum(1 for m in models_trained.values() if m.get("status") == "completed")
failed_products = total_products - successful_products
# Calculate total duration
if training_log.start_time and training_log.end_time:
total_duration_seconds = (training_log.end_time - training_log.start_time).total_seconds()
else:
# Fallback to elapsed time
total_duration_seconds = training_results.get("total_training_time", 0)
# Calculate average time per product
if successful_products > 0:
avg_time_per_product = total_duration_seconds / successful_products
else:
avg_time_per_product = 0
# Extract timing breakdown if available
data_analysis_time = training_results.get("data_analysis_time_seconds")
training_time = training_results.get("training_time_seconds")
finalization_time = training_results.get("finalization_time_seconds")
# Create performance metrics record
metric_data = {
"tenant_id": tenant_id,
"job_id": job_id,
"total_products": total_products,
"successful_products": successful_products,
"failed_products": failed_products,
"total_duration_seconds": total_duration_seconds,
"avg_time_per_product": avg_time_per_product,
"data_analysis_time_seconds": data_analysis_time,
"training_time_seconds": training_time,
"finalization_time_seconds": finalization_time,
"completed_at": datetime.now(timezone.utc)
}
# Create a temporary repository for the TrainingPerformanceMetrics model
# Use the session from one of the initialized repositories to ensure it's available
session = self.model_repo.session # This should be the same session used by all repositories
metrics_repo = BaseRepository(TrainingPerformanceMetrics, session)
# Use repository to create record
await metrics_repo.create(metric_data)
logger.info("Saved training performance metrics for future estimations",
tenant_id=tenant_id,
job_id=job_id,
avg_time_per_product=avg_time_per_product,
total_products=total_products,
successful_products=successful_products)
except Exception as e:
logger.error("Failed to save training performance metrics",
tenant_id=tenant_id,
job_id=job_id,
error=str(e))
async def get_training_status(self, job_id: str) -> Dict[str, Any]:
"""Get training job status using repository"""
try:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
log = await self.training_log_repo.get_log_by_job_id(job_id)
if not log:
return {"error": "Job not found"}
# Calculate estimated time remaining based on progress and elapsed time
estimated_time_remaining_seconds = None
if log.status == "running" and log.progress > 0 and log.start_time:
from datetime import datetime, timezone
elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds()
if elapsed_time > 0 and log.progress > 0: # Double-check progress is positive
# Calculate estimated total time based on progress
# Use max(log.progress, 1) as additional safety against division by zero
estimated_total_time = (elapsed_time / max(log.progress, 1)) * 100
estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time)
# Cap at reasonable maximum (30 minutes)
estimated_time_remaining_seconds = max(0, min(estimated_time_remaining_seconds, MAX_ESTIMATED_TIME_REMAINING_SECONDS))
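# Worked example of the extrapolation above (illustrative numbers only):
# 120 s elapsed at 40% progress -> estimated_total_time = (120 / 40) * 100 = 300 s,
# so roughly 300 - 120 = 180 s remaining, then clamped to [0, MAX_ESTIMATED_TIME_REMAINING_SECONDS].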
# Extract products info from results if available
products_total = 0
products_completed = 0
products_failed = 0
if log.results:
products_total = log.results.get("total_products", 0)
products_completed = log.results.get("successful_trainings", 0)
products_failed = log.results.get("failed_trainings", 0)
return {
"job_id": job_id,
"tenant_id": log.tenant_id,
"status": log.status,
"progress": log.progress,
"current_step": log.current_step,
"started_at": log.start_time.isoformat() if log.start_time else None,
"completed_at": log.end_time.isoformat() if log.end_time else None,
"error_message": log.error_message,
"results": log.results,
"products_total": products_total,
"products_completed": products_completed,
"products_failed": products_failed,
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
"message": log.current_step
}
except Exception as e:
logger.error("Failed to get training status",
job_id=job_id,
error=str(e))
return {"error": f"Failed to get status: {str(e)}"}
async def get_tenant_models(
self,
tenant_id: str,
active_only: bool = True,
skip: int = 0,
limit: int = 100
) -> List[Dict[str, Any]]:
"""Get models for a tenant using repository"""
try:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
if active_only:
models = await self.model_repo.get_multi(
filters={"tenant_id": tenant_id, "is_active": True},
skip=skip,
limit=limit,
order_by="created_at",
order_desc=True
)
else:
models = await self.model_repo.get_models_by_tenant(
tenant_id, skip=skip, limit=limit
)
return [model.to_dict() for model in models]
except Exception as e:
logger.error("Failed to get tenant models",
tenant_id=tenant_id,
error=str(e))
return []
async def get_model_performance(self, model_id: str) -> Dict[str, Any]:
"""Get model performance metrics using repository"""
try:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
# Get model summary
model_summary = await self.model_repo.get_model_performance_summary(model_id)
# Get latest performance metrics
latest_metric = await self.performance_repo.get_latest_metric_for_model(model_id)
if latest_metric:
model_summary["latest_metrics"] = {
"mae": latest_metric.mae,
"mse": latest_metric.mse,
"rmse": latest_metric.rmse,
"mape": latest_metric.mape,
"r2_score": latest_metric.r2_score,
"accuracy_percentage": latest_metric.accuracy_percentage,
"measured_at": latest_metric.measured_at.isoformat() if latest_metric.measured_at else None
}
return model_summary
except Exception as e:
logger.error("Failed to get model performance",
model_id=model_id,
error=str(e))
return {}
async def get_tenant_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get comprehensive tenant statistics using repositories"""
try:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
# Get model statistics
model_stats = await self.model_repo.get_model_statistics(tenant_id)
# Get job statistics
job_stats = await self.training_log_repo.get_job_statistics(tenant_id)
# Get performance trends
performance_trends = await self.performance_repo.get_performance_trends(tenant_id)
# Get queue status
queue_status = await self.queue_repo.get_queue_status(tenant_id)
# Get artifact statistics
artifact_stats = await self.artifact_repo.get_artifact_statistics(tenant_id)
return {
"tenant_id": tenant_id,
"models": model_stats,
"training_jobs": job_stats,
"performance": performance_trends,
"queue": queue_status,
"artifacts": artifact_stats,
"summary": {
"total_active_models": model_stats.get("active_models", 0),
"total_training_jobs": job_stats.get("total_jobs", 0),
"success_rate": job_stats.get("success_rate", 0.0),
"products_with_models": len(model_stats.get("models_by_product", {})),
"total_storage_mb": artifact_stats.get("total_storage", {}).get("total_size_mb", 0.0)
}
}
except Exception as e:
logger.error("Failed to get tenant statistics",
tenant_id=tenant_id,
error=str(e))
return {"error": f"Failed to get statistics: {str(e)}"}
async def _update_job_status_repository(self,
job_id: str,
status: str,
progress: int = None,
current_step: str = None,
error_message: str = None,
results: Dict = None,
tenant_id: str = None,
session = None):
"""Update job status using repository pattern
Args:
session: Optional database session to reuse. If None, creates a new session.
"""
try:
# Use provided session or create new one
should_create_session = session is None
if should_create_session:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
await self._update_job_status_impl(
session, job_id, status, progress, current_step,
error_message, results, tenant_id
)
else:
# Reuse provided session (don't commit - let caller control transaction)
await self._init_repositories(session)
await self._update_job_status_impl(
session, job_id, status, progress, current_step,
error_message, results, tenant_id, auto_commit=False
)
except Exception as e:
logger.error("Failed to update job status using repository",
job_id=job_id,
error=str(e))
async def _update_job_status_impl(self,
session,
job_id: str,
status: str,
progress: int = None,
current_step: str = None,
error_message: str = None,
results: Dict = None,
tenant_id: str = None,
auto_commit: bool = True):
"""Implementation of job status update"""
# Check if log exists, create if not
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if not existing_log:
# Create initial log entry
if not tenant_id:
# Extract tenant_id from job_id if not provided
# Format: enhanced_training_{tenant_id}_{job_suffix}
try:
parts = job_id.split('_')
if len(parts) >= 3 and parts[0] == 'enhanced' and parts[1] == 'training':
tenant_id = parts[2]
except Exception:
logger.warning(f"Could not extract tenant_id from job_id {job_id}")
if tenant_id:
log_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": status or "pending",
"progress": progress or 0,
"current_step": current_step or "initializing",
"start_time": datetime.now(timezone.utc)
}
if error_message:
log_data["error_message"] = error_message
if results:
# Ensure results are JSON-serializable before storing
log_data["results"] = make_json_serializable(results)
try:
await self.training_log_repo.create_training_log(log_data)
if auto_commit:
await session.commit() # Explicit commit so other sessions can see it
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
except Exception as create_error:
# Handle race condition: another session may have created the log
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
await session.rollback()
# Query again to get the existing log
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if existing_log:
# Update the existing log instead
await self.training_log_repo.update_log_progress(
job_id=job_id,
progress=progress,
current_step=current_step,
status=status
)
if auto_commit:
await session.commit()
else:
raise
else:
logger.error("Cannot create training log without tenant_id", job_id=job_id)
return
else:
# Update existing log
await self.training_log_repo.update_log_progress(
job_id=job_id,
progress=progress,
current_step=current_step,
status=status
)
# Update additional fields if provided
if error_message or results:
update_data = {}
if error_message:
update_data["error_message"] = error_message
if results:
# Ensure results are JSON-serializable before storing
update_data["results"] = make_json_serializable(results)
if status in ["completed", "failed"]:
update_data["end_time"] = datetime.now(timezone.utc)
if update_data:
await self.training_log_repo.update(existing_log.id, update_data)
if auto_commit:
await session.commit() # Explicit commit after updates
async def start_single_product_training(self,
tenant_id: str,
inventory_product_id: str,
job_id: str,
bakery_location: tuple = (40.4168, -3.7038)) -> Dict[str, Any]:
"""Start enhanced single product training using repository pattern with single session"""
# Create a single database session for all operations to avoid connection pool exhaustion
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
try:
logger.info("Starting enhanced single product training",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
job_id=job_id)
2025-11-14 20:27:39 +01:00
# Create initial training log (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Fetching training data",
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit after initial log creation
# Prepare training data for all products to get weather/traffic data
# then filter down to the specific product
training_dataset = await self.orchestrator.prepare_training_data(
2025-11-05 13:34:56 +01:00
tenant_id=tenant_id,
bakery_location=bakery_location,
job_id=job_id + "_temp"
)
# Use the enhanced data processor to prepare training data with all features (POI, weather, traffic)
# Filter sales data to the specific product first
sales_df = pd.DataFrame(training_dataset.sales_data)
product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
if product_sales_df.empty:
raise ValueError(f"No sales data available for product {inventory_product_id}")
# Get weather and traffic data as DataFrames
weather_df = pd.DataFrame(training_dataset.weather_data)
traffic_df = pd.DataFrame(training_dataset.traffic_data)
# Get POI features from the training dataset (already collected by orchestrator)
poi_features = training_dataset.poi_features if hasattr(training_dataset, 'poi_features') else None
# Use the enhanced data processor to merge all features properly
# This will include POI, weather, traffic features along with ds and y
from app.ml.data_processor import EnhancedBakeryDataProcessor
data_processor = EnhancedBakeryDataProcessor(self.database_manager)
product_data = await data_processor.prepare_training_data(
sales_data=product_sales_df,
weather_data=weather_df,
traffic_data=traffic_df,
inventory_product_id=inventory_product_id,
poi_features=poi_features,
tenant_id=tenant_id,
job_id=job_id
)
if product_data.empty:
raise ValueError(f"Data processor returned empty data for product {inventory_product_id}")
logger.info("Prepared training data for single product",
inventory_product_id=inventory_product_id,
data_points=len(product_data),
features=list(product_data.columns),
date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
# Update progress (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=30,
current_step="Training model",
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit progress update
# Run the actual training (passing the session to avoid nested session creation)
try:
model_info = await self.trainer.train_single_product_model(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
training_data=product_data,
job_id=job_id,
session=session # ✅ CRITICAL FIX: Pass session to prevent deadlock
)
except Exception as e:
import traceback
logger.error(f"Training failed with error: {e}")
logger.error(f"Full traceback: {traceback.format_exc()}")
raise
# Update progress (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=80,
current_step="Saving model",
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit progress update
# The model should already be saved by train_single_product_model
# Return appropriate response
return {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"status": "completed",
"message": "Enhanced single product training completed successfully",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": 15, # Default estimate for single product
"training_results": {
"total_products": 1,
"successful_trainings": 1,
"failed_trainings": 0,
"products": [{
"inventory_product_id": inventory_product_id,
"status": "completed",
"model_id": str(model_info.get('model_id', f"model_{inventory_product_id}_{job_id[:8]}")) if model_info.get('model_id') else None,
"data_points": len(product_data) if product_data is not None else 0,
# Filter metrics to ensure only numeric values are included
"metrics": {
k: float(v)
for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items()
if k != 'product_category' and isinstance(v, (int, float, Decimal))
}
}],
"overall_training_time_seconds": model_info.get('training_time', 45.2)
},
"enhanced_features": True,
"repository_integration": True,
"completed_at": datetime.now(timezone.utc).isoformat()
}
except Exception as e:
logger.error("Enhanced single product training failed",
inventory_product_id=inventory_product_id,
error=str(e))
# Update status to failed (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Training failed",
error_message=str(e),
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit failure status
raise
def _create_detailed_training_response(self, final_result: Dict[str, Any]) -> Dict[str, Any]:
"""Convert final result to detailed training response"""
try:
training_results_data = final_result.get("training_results", {})
stored_models = final_result.get("stored_models", [])
# Convert stored models to product results
products = []
for model in stored_models:
products.append({
"inventory_product_id": model.get("inventory_product_id"),
"status": "completed",
"model_id": model.get("id"),
"data_points": model.get("training_samples", 0),
"metrics": {
"mape": model.get("mape"),
"mae": model.get("mae"),
"rmse": model.get("rmse"),
"r2_score": model.get("r2_score")
},
"model_path": model.get("model_path")
})
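# Note: the stored-model summaries built in start_training_job only include id,
# inventory_product_id, model_type, model_path, is_active and training_samples,
# so the metric fields above may resolve to None in this response.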
# Build the response
response_data = {
"job_id": final_result["job_id"],
"tenant_id": final_result["tenant_id"],
"status": final_result["status"],
"message": f"Training {final_result['status']} successfully",
"created_at": datetime.now(),
"estimated_duration_minutes": final_result.get("estimated_duration_minutes", 15),
"training_results": {
"total_products": len(products),
"successful_trainings": len([p for p in products if p["status"] == "completed"]),
"failed_trainings": len([p for p in products if p["status"] == "failed"]),
"products": products,
"overall_training_time_seconds": training_results_data.get("total_training_time", 0)
},
"data_summary": final_result.get("data_summary", {}),
"completed_at": final_result.get("completed_at")
}
return response_data
except Exception as e:
logger.error("Failed to create detailed response", error=str(e))
return final_result