"""
Enhanced Training Service with Repository Pattern

Main training service that uses the repository pattern for data access
"""

from typing import Dict, List, Any, Optional
import uuid
import structlog
from datetime import datetime, date, timezone
from decimal import Decimal
from sqlalchemy.ext.asyncio import AsyncSession
import json
import numpy as np
import pandas as pd

from app.ml.trainer import EnhancedBakeryMLTrainer
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
from app.services.training_orchestrator import TrainingDataOrchestrator
from app.core.training_constants import (
    PROGRESS_DATA_VALIDATION,
    PROGRESS_DATA_PREPARATION_COMPLETE,
    PROGRESS_ML_TRAINING_START,
    PROGRESS_TRAINING_COMPLETE,
    PROGRESS_STORING_MODELS,
    PROGRESS_STORING_METRICS,
    MAX_ESTIMATED_TIME_REMAINING_SECONDS
)

# Import repositories
from app.repositories import (
    ModelRepository,
    TrainingLogRepository,
    PerformanceRepository,
    JobQueueRepository,
    ArtifactRepository
)

# Import shared database components
from shared.database.unit_of_work import UnitOfWork
from shared.database.transactions import transactional
from shared.database.exceptions import DatabaseError
from app.core.database import database_manager

logger = structlog.get_logger()


def make_json_serializable(obj):
    """Convert numpy/pandas types, datetime, and UUID objects to JSON-serializable Python types"""

    # Handle None values
    if obj is None:
        return None

    # Handle basic datetime types first (most common)
    if isinstance(obj, datetime):
        return obj.isoformat()
    elif isinstance(obj, date):
        return obj.isoformat()

    # Handle pandas timestamp types
    if hasattr(pd, 'Timestamp') and isinstance(obj, pd.Timestamp):
        return obj.isoformat()

    # Handle numpy datetime types
    if hasattr(np, 'datetime64') and isinstance(obj, np.datetime64):
        return pd.Timestamp(obj).isoformat()

    # Handle numeric types
    if isinstance(obj, (np.integer, pd.Int64Dtype)):
        return int(obj)
    elif isinstance(obj, (np.floating, pd.Float64Dtype)):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, pd.Series):
        return obj.tolist()
    elif isinstance(obj, pd.DataFrame):
        return obj.to_dict('records')
    elif isinstance(obj, Decimal):
        return float(obj)

    # Handle UUID types
    elif isinstance(obj, uuid.UUID):
        return str(obj)
    elif hasattr(obj, '__class__') and 'UUID' in str(obj.__class__):
        # Handle any UUID-like objects (including asyncpg.pgproto.pgproto.UUID)
        return str(obj)

    # Handle collections recursively
    elif isinstance(obj, dict):
        return {k: make_json_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [make_json_serializable(item) for item in obj]
    elif isinstance(obj, set):
        return [make_json_serializable(item) for item in obj]

    # Handle other common types
    elif isinstance(obj, (str, int, float, bool)):
        return obj

    # Last resort: try to convert to string
    else:
        try:
            # For any other object, try to convert to string
            return str(obj)
        except Exception:
            # If all else fails, return None
            return None
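

# Editor's sketch: a minimal, self-contained example of make_json_serializable(), kept as a
# private helper that nothing in this module calls. The payload keys below are invented for
# illustration and are not taken from the rest of this codebase.
def _example_make_json_serializable() -> str:
    """Illustrative round trip: clean a mixed payload, then confirm json.dumps accepts it."""
    payload = {
        "model_id": uuid.uuid4(),                   # UUID        -> str
        "trained_at": datetime(2025, 1, 1, 12, 0),  # datetime    -> ISO-8601 string
        "mae": np.float64(1.25),                    # numpy float -> float
        "samples": np.int64(420),                   # numpy int   -> int
        "tags": {"bread", "daily"},                 # set         -> list
    }
    return json.dumps(make_json_serializable(payload))
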
class EnhancedTrainingService:
    """
    Enhanced training service using repository pattern.
    Coordinates the complete training pipeline with proper data abstraction.
    """

    def __init__(self, session: AsyncSession = None):
        self.session = session
        self.database_manager = database_manager

        # Initialize repositories
        if session:
            self.model_repo = ModelRepository(session)
            self.training_log_repo = TrainingLogRepository(session)
            self.performance_repo = PerformanceRepository(session)
            self.queue_repo = JobQueueRepository(session)
            self.artifact_repo = ArtifactRepository(session)

        # Initialize training components
        self.trainer = EnhancedBakeryMLTrainer(database_manager=self.database_manager)
        self.date_alignment_service = DateAlignmentService()
        self.orchestrator = TrainingDataOrchestrator(
            date_alignment_service=self.date_alignment_service
        )
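
    # Usage note (editor sketch, not from the original source): the service supports two
    # construction modes. Passing a session binds the repositories immediately:
    #     service = EnhancedTrainingService(session)
    # Constructing it without a session is also valid, because the public methods below
    # open their own session via database_manager.get_session() and then call
    # _init_repositories() on it:
    #     service = EnhancedTrainingService()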

    async def _init_repositories(self, session: AsyncSession):
        """Initialize repositories with session"""
        self.model_repo = ModelRepository(session)
        self.training_log_repo = TrainingLogRepository(session)
        self.performance_repo = PerformanceRepository(session)
        self.queue_repo = JobQueueRepository(session)
        self.artifact_repo = ArtifactRepository(session)

    async def start_training_job(
        self,
        tenant_id: str,
        bakery_location: tuple[float, float] = (40.4168, -3.7038),
        requested_start: Optional[datetime] = None,
        requested_end: Optional[datetime] = None,
        job_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Start a complete training job for a tenant using repository pattern.

        Args:
            tenant_id: Tenant identifier
            bakery_location: Bakery coordinates (lat, lon)
            requested_start: Optional explicit start date
            requested_end: Optional explicit end date
            job_id: Optional job identifier

        Returns:
            Training job results
        """
        if not job_id:
            job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        logger.info("Starting enhanced training job",
                    job_id=job_id,
                    tenant_id=tenant_id)

        # Get session and initialize repositories
        async with self.database_manager.get_session() as session:
            await self._init_repositories(session)

            try:
                # Check if training log already exists, update if found, create if not
                existing_log = await self.training_log_repo.get_log_by_job_id(job_id)

                if existing_log:
                    logger.info("Training log already exists, updating status", job_id=job_id)
                    training_log = await self.training_log_repo.update_log_progress(
                        job_id, 0, "initializing", "running"
                    )
                    await session.commit()
                else:
                    # Create new training log entry
                    log_data = {
                        "job_id": job_id,
                        "tenant_id": tenant_id,
                        "status": "running",
                        "progress": 0,
                        "current_step": "initializing"
                    }
                    try:
                        training_log = await self.training_log_repo.create_training_log(log_data)
                        await session.commit()  # Explicit commit so other sessions can see it
                    except Exception as create_error:
                        # Handle race condition: log may have been created by another session
                        if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
                            logger.debug("Training log already exists (race condition), updating instead", job_id=job_id)
                            await session.rollback()
                            training_log = await self.training_log_repo.update_log_progress(
                                job_id, 0, "initializing", "running"
                            )
                            await session.commit()
                        else:
                            raise

                # Step 1: Prepare training dataset (includes sales data validation)
                logger.info("Step 1: Preparing and aligning training data (with validation)")
                await self.training_log_repo.update_log_progress(
                    job_id, PROGRESS_DATA_VALIDATION, "data_validation", "running"
                )

                # Orchestrator now handles sales data validation to eliminate duplicate fetching
                training_dataset = await self.orchestrator.prepare_training_data(
                    tenant_id=tenant_id,
                    bakery_location=bakery_location,
                    requested_start=requested_start,
                    requested_end=requested_end,
                    job_id=job_id
                )

                # Log the results from orchestrator's unified sales data fetch
                logger.info(f"Sales data validation completed: {len(training_dataset.sales_data)} records",
                            tenant_id=tenant_id, job_id=job_id)

                await self.training_log_repo.update_log_progress(
                    job_id, PROGRESS_DATA_PREPARATION_COMPLETE, "data_preparation_complete", "running"
                )

                # Step 2: Execute ML training pipeline
                logger.info("Step 2: Starting ML training pipeline")
                await self.training_log_repo.update_log_progress(
                    job_id, PROGRESS_ML_TRAINING_START, "ml_training", "running"
                )

                # ✅ FIX: Commit the session to prevent deadlock with trainer's nested session
                # The trainer creates its own session, so we need to ensure this update is committed
                await session.commit()
                logger.debug("Committed session after ml_training progress update")

                training_results = await self.trainer.train_tenant_models(
                    tenant_id=tenant_id,
                    training_dataset=training_dataset,
                    job_id=job_id,
                    session=session  # Pass the main session to avoid nested sessions
                )

                await self.training_log_repo.update_log_progress(
                    job_id, PROGRESS_TRAINING_COMPLETE, "training_complete", "running"
                )

                # Publish progress event (85%)
                from app.services.training_events import publish_training_progress
                await publish_training_progress(
                    job_id=job_id,
                    tenant_id=tenant_id,
                    progress=PROGRESS_TRAINING_COMPLETE,
                    current_step="Training Complete",
                    step_details="All products trained successfully"
                )

                # Step 3: Store model records using repository
                logger.info("Step 3: Storing model records")
                logger.debug("Training results structure",
                             keys=list(training_results.keys()) if isinstance(training_results, dict) else "not_dict",
                             training_results_type=type(training_results).__name__)

                stored_models = await self._store_trained_models(
                    tenant_id, job_id, training_results
                )

                await self.training_log_repo.update_log_progress(
                    job_id, PROGRESS_STORING_MODELS, "storing_models", "running"
                )

                # Publish progress event (92%)
                await publish_training_progress(
                    job_id=job_id,
                    tenant_id=tenant_id,
                    progress=PROGRESS_STORING_MODELS,
                    current_step="Storing Models",
                    step_details=f"Saved {len(stored_models)} trained models to database"
                )

                # Step 4: Create performance metrics
                await self.training_log_repo.update_log_progress(
                    job_id, PROGRESS_STORING_METRICS, "storing_performance_metrics", "running"
                )

                # Publish progress event (94%)
                await publish_training_progress(
                    job_id=job_id,
                    tenant_id=tenant_id,
                    progress=PROGRESS_STORING_METRICS,
                    current_step="Storing Performance Metrics",
                    step_details="Saving model performance metrics"
                )

                await self._create_performance_metrics(
                    tenant_id, stored_models, training_results
                )

                # Step 4.5: Save training performance metrics for future estimations
                await self._save_training_performance_metrics(
                    tenant_id, job_id, training_results, training_log
                )

                # Step 5: Complete training log
                final_result = {
                    "job_id": job_id,
                    "tenant_id": tenant_id,
                    "status": "completed",
                    "training_results": training_results,
                    "stored_models": [{
                        "id": str(model.id),
                        "inventory_product_id": str(model.inventory_product_id),
                        "model_type": model.model_type,
                        "model_path": model.model_path,
                        "is_active": model.is_active,
                        "training_samples": model.training_samples
                    } for model in stored_models],
                    "data_summary": {
                        "sales_records": int(len(training_dataset.sales_data)),
                        "weather_records": int(len(training_dataset.weather_data)),
                        "traffic_records": int(len(training_dataset.traffic_data)),
                        "date_range": {
                            "start": training_dataset.date_range.start.isoformat(),
                            "end": training_dataset.date_range.end.isoformat()
                        },
                        "data_sources_used": [source.value for source in training_dataset.date_range.available_sources],
                        "constraints_applied": training_dataset.date_range.constraints
                    },
                    "completed_at": datetime.now().isoformat()
                }

                # Make sure all data is JSON-serializable before saving to database
                json_safe_result = make_json_serializable(final_result)

                # Ensure results is a proper dict for database storage
                if not isinstance(json_safe_result, dict):
                    logger.warning("JSON safe result is not a dict, wrapping it", result_type=type(json_safe_result))
                    json_safe_result = {"training_data": json_safe_result}

                # Double-check JSON serialization by attempting to serialize
                try:
                    json.dumps(json_safe_result)
                    logger.debug("Results successfully JSON-serializable", job_id=job_id)
                except (TypeError, ValueError) as e:
                    logger.error("Results still not JSON-serializable after cleaning",
                                 job_id=job_id, error=str(e))
                    # Create a minimal safe result
                    json_safe_result = {
                        "status": "completed",
                        "job_id": job_id,
                        "models_created": final_result.get("products_trained", 0),
                        "error": "Result serialization failed"
                    }

                await self.training_log_repo.complete_training_log(
                    job_id, results=json_safe_result
                )

                # CRITICAL: Commit the session to persist the completed status to database
                # Without this commit, the status update is lost when the session closes
                await session.commit()

                logger.info("Enhanced training job completed successfully",
                            job_id=job_id,
                            models_created=len(stored_models))

                return self._create_detailed_training_response(final_result)

            except Exception as e:
                logger.error("Enhanced training job failed",
                             job_id=job_id,
                             error=str(e),
                             exc_info=True)

                # Mark as failed in database
                await self.training_log_repo.complete_training_log(
                    job_id, error_message=str(e)
                )

                # Commit the failure status to database
                await session.commit()

                error_result = {
                    "job_id": job_id,
                    "tenant_id": tenant_id,
                    "status": "failed",
                    "error_message": str(e),
                    "completed_at": datetime.now().isoformat()
                }

                # Ensure error result is JSON serializable
                error_result = make_json_serializable(error_result)

                return self._create_detailed_training_response(error_result)

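    # Editor's sketch: how a caller (for example an API endpoint or a background task,
    # both assumed rather than shown in this file) might invoke start_training_job().
    # The tenant and job identifiers below are hypothetical.
    #
    #     service = EnhancedTrainingService()
    #     result = await service.start_training_job(
    #         tenant_id="tenant-123",
    #         bakery_location=(40.4168, -3.7038),     # defaults to these Madrid coordinates
    #         job_id="training_tenant-123_ab12cd34",  # optional; generated when omitted
    #     )
    #     # `result` is whatever _create_detailed_training_response() builds from the
    #     # final (or error) result dictionary.
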
    async def _store_trained_models(
        self,
        tenant_id: str,
        job_id: str,
        training_results: Dict[str, Any]
    ) -> List:
        """
        Retrieve or verify stored models from training results.

        NOTE: Model records are now created by the trainer during parallel execution.
        This method retrieves the already-created models instead of creating duplicates.
        """
        stored_models = []

        try:
            # Check if models were already created by the trainer (new approach)
            # The trainer now writes models sequentially after parallel training
            training_results_dict = training_results.get("training_results", {})

            # Get list of successfully trained products
            successful_products = [
                product_id for product_id, result in training_results_dict.items()
                if result.get('status') == 'success' and result.get('model_record_id')
            ]

            logger.info("Retrieving models created by trainer",
                        successful_products=len(successful_products),
                        job_id=job_id)

            # Retrieve the models that were already created by the trainer
            for product_id in successful_products:
                result = training_results_dict[product_id]
                model_record_id = result.get('model_record_id')

                if model_record_id:
                    try:
                        # Get the model from the database using base repository method
                        model = await self.model_repo.get_by_id(model_record_id)
                        if model:
                            stored_models.append(model)
                            logger.debug("Retrieved model from database",
                                         model_id=model_record_id,
                                         inventory_product_id=product_id)
                    except Exception as e:
                        logger.warning("Could not retrieve model record",
                                       model_id=model_record_id,
                                       inventory_product_id=product_id,
                                       error=str(e))

            logger.info("Models retrieval complete",
                        models_retrieved=len(stored_models),
                        expected=len(successful_products))

            return stored_models

        except Exception as e:
            logger.error("Failed to retrieve stored models",
                         tenant_id=tenant_id,
                         job_id=job_id,
                         error=str(e))
            return stored_models

    async def _create_performance_metrics(
        self,
        tenant_id: str,
        stored_models: List,
        training_results: Dict[str, Any]
    ):
        """
        Verify performance metrics for stored models.

        NOTE: Performance metrics are now created by the trainer during model creation.
        This method now just verifies they exist rather than creating duplicates.
        """
        try:
            logger.info("Verifying performance metrics",
                        models_count=len(stored_models))

            # Performance metrics are already created by the trainer
            # This method is kept for compatibility but doesn't create duplicates
            for model in stored_models:
                logger.debug("Performance metrics already created for model",
                             model_id=str(model.id),
                             inventory_product_id=str(model.inventory_product_id))

            logger.info("Performance metrics verification complete",
                        models_count=len(stored_models))

        except Exception as e:
            logger.error("Failed to verify performance metrics",
                         tenant_id=tenant_id,
                         error=str(e))

    async def _save_training_performance_metrics(
        self,
        tenant_id: str,
        job_id: str,
        training_results: Dict[str, Any],
        training_log
    ):
        """
        Save aggregated training performance metrics for time estimation.
        This data is used to predict future training durations.
        """
        try:
            from app.models.training import TrainingPerformanceMetrics
            from shared.database.repository import BaseRepository

            # Extract timing and success data
            models_trained = training_results.get("models_trained", {})
            total_products = len(models_trained)
            successful_products = sum(1 for m in models_trained.values() if m.get("status") == "completed")
            failed_products = total_products - successful_products

            # Calculate total duration
            if training_log.start_time and training_log.end_time:
                total_duration_seconds = (training_log.end_time - training_log.start_time).total_seconds()
            else:
                # Fallback to elapsed time
                total_duration_seconds = training_results.get("total_training_time", 0)

            # Calculate average time per product
            if successful_products > 0:
                avg_time_per_product = total_duration_seconds / successful_products
            else:
                avg_time_per_product = 0

            # Extract timing breakdown if available
            data_analysis_time = training_results.get("data_analysis_time_seconds")
            training_time = training_results.get("training_time_seconds")
            finalization_time = training_results.get("finalization_time_seconds")

            # Create performance metrics record
            metric_data = {
                "tenant_id": tenant_id,
                "job_id": job_id,
                "total_products": total_products,
                "successful_products": successful_products,
                "failed_products": failed_products,
                "total_duration_seconds": total_duration_seconds,
                "avg_time_per_product": avg_time_per_product,
                "data_analysis_time_seconds": data_analysis_time,
                "training_time_seconds": training_time,
                "finalization_time_seconds": finalization_time,
                "completed_at": datetime.now(timezone.utc)
            }

            # Create a temporary repository for the TrainingPerformanceMetrics model
            # Use the session from one of the initialized repositories to ensure it's available
            session = self.model_repo.session  # This should be the same session used by all repositories
            metrics_repo = BaseRepository(TrainingPerformanceMetrics, session)

            # Use repository to create record
            await metrics_repo.create(metric_data)

            logger.info("Saved training performance metrics for future estimations",
                        tenant_id=tenant_id,
                        job_id=job_id,
                        avg_time_per_product=avg_time_per_product,
                        total_products=total_products,
                        successful_products=successful_products)

        except Exception as e:
            logger.error("Failed to save training performance metrics",
                         tenant_id=tenant_id,
                         job_id=job_id,
                         error=str(e))

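    # Editor's sketch: one way the saved metrics could be consumed for a future-duration
    # estimate. The consumer below is assumed; it is not implemented in this file.
    #
    #     # next_run_estimate_seconds = last_metrics.avg_time_per_product * product_count
    #     # e.g. 12.5 s/product * 40 products -> ~500 s (about 8 minutes)
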
    async def get_training_status(self, job_id: str) -> Dict[str, Any]:
        """Get training job status using repository"""
        try:
            async with self.database_manager.get_session() as session:
                await self._init_repositories(session)

                log = await self.training_log_repo.get_log_by_job_id(job_id)
                if not log:
                    return {"error": "Job not found"}

                # Calculate estimated time remaining based on progress and elapsed time
                estimated_time_remaining_seconds = None
                if log.status == "running" and log.progress > 0 and log.start_time:
                    elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds()
                    if elapsed_time > 0 and log.progress > 0:  # Double-check progress is positive
                        # Calculate estimated total time based on progress
                        # Use max(log.progress, 1) as additional safety against division by zero
                        estimated_total_time = (elapsed_time / max(log.progress, 1)) * 100
                        estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time)
                        # Cap at reasonable maximum (30 minutes)
                        estimated_time_remaining_seconds = max(0, min(estimated_time_remaining_seconds, MAX_ESTIMATED_TIME_REMAINING_SECONDS))
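                        # Worked example (editor note, illustrative numbers): with 120 s elapsed
                        # at progress=40, estimated_total_time = (120 / 40) * 100 = 300 s, so the
                        # remaining estimate is 300 - 120 = 180 s before clamping to
                        # [0, MAX_ESTIMATED_TIME_REMAINING_SECONDS].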

                # Extract products info from results if available
                products_total = 0
                products_completed = 0
                products_failed = 0

                if log.results:
                    products_total = log.results.get("total_products", 0)
                    products_completed = log.results.get("successful_trainings", 0)
                    products_failed = log.results.get("failed_trainings", 0)

                return {
                    "job_id": job_id,
                    "tenant_id": log.tenant_id,
                    "status": log.status,
                    "progress": log.progress,
                    "current_step": log.current_step,
                    "started_at": log.start_time.isoformat() if log.start_time else None,
                    "completed_at": log.end_time.isoformat() if log.end_time else None,
                    "error_message": log.error_message,
                    "results": log.results,
                    "products_total": products_total,
                    "products_completed": products_completed,
                    "products_failed": products_failed,
                    "estimated_time_remaining_seconds": estimated_time_remaining_seconds,
                    "message": log.current_step
                }

        except Exception as e:
            logger.error("Failed to get training status",
                         job_id=job_id,
                         error=str(e))
            return {"error": f"Failed to get status: {str(e)}"}

    async def get_tenant_models(
        self,
        tenant_id: str,
        active_only: bool = True,
        skip: int = 0,
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """Get models for a tenant using repository"""
        try:
            async with self.database_manager.get_session() as session:
                await self._init_repositories(session)

                if active_only:
                    models = await self.model_repo.get_multi(
                        filters={"tenant_id": tenant_id, "is_active": True},
                        skip=skip,
                        limit=limit,
                        order_by="created_at",
                        order_desc=True
                    )
                else:
                    models = await self.model_repo.get_models_by_tenant(
                        tenant_id, skip=skip, limit=limit
                    )

                return [model.to_dict() for model in models]

        except Exception as e:
            logger.error("Failed to get tenant models",
                         tenant_id=tenant_id,
                         error=str(e))
            return []

    async def get_model_performance(self, model_id: str) -> Dict[str, Any]:
        """Get model performance metrics using repository"""
        try:
            async with self.database_manager.get_session() as session:
                await self._init_repositories(session)

                # Get model summary
                model_summary = await self.model_repo.get_model_performance_summary(model_id)

                # Get latest performance metrics
                latest_metric = await self.performance_repo.get_latest_metric_for_model(model_id)

                if latest_metric:
                    model_summary["latest_metrics"] = {
                        "mae": latest_metric.mae,
                        "mse": latest_metric.mse,
                        "rmse": latest_metric.rmse,
                        "mape": latest_metric.mape,
                        "r2_score": latest_metric.r2_score,
                        "accuracy_percentage": latest_metric.accuracy_percentage,
                        "measured_at": latest_metric.measured_at.isoformat() if latest_metric.measured_at else None
                    }

                return model_summary

        except Exception as e:
            logger.error("Failed to get model performance",
                         model_id=model_id,
                         error=str(e))
            return {}


    async def get_tenant_statistics(self, tenant_id: str) -> Dict[str, Any]:
        """Get comprehensive tenant statistics using repositories"""
        try:
            async with self.database_manager.get_session() as session:
                await self._init_repositories(session)

                # Get model statistics
                model_stats = await self.model_repo.get_model_statistics(tenant_id)

                # Get job statistics
                job_stats = await self.training_log_repo.get_job_statistics(tenant_id)

                # Get performance trends
                performance_trends = await self.performance_repo.get_performance_trends(tenant_id)

                # Get queue status
                queue_status = await self.queue_repo.get_queue_status(tenant_id)

                # Get artifact statistics
                artifact_stats = await self.artifact_repo.get_artifact_statistics(tenant_id)

                return {
                    "tenant_id": tenant_id,
                    "models": model_stats,
                    "training_jobs": job_stats,
                    "performance": performance_trends,
                    "queue": queue_status,
                    "artifacts": artifact_stats,
                    "summary": {
                        "total_active_models": model_stats.get("active_models", 0),
                        "total_training_jobs": job_stats.get("total_jobs", 0),
                        "success_rate": job_stats.get("success_rate", 0.0),
                        "products_with_models": len(model_stats.get("models_by_product", {})),
                        "total_storage_mb": artifact_stats.get("total_storage", {}).get("total_size_mb", 0.0)
                    }
                }

        except Exception as e:
            logger.error("Failed to get tenant statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {"error": f"Failed to get statistics: {str(e)}"}
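
    # Illustrative read of the aggregated statistics (a sketch; assumes a service
    # instance `training_service` and that the repositories returned data):
    #
    #     stats = await training_service.get_tenant_statistics(tenant_id)
    #     summary = stats.get("summary", {})
    #     logger.info("Tenant training overview",
    #                 tenant_id=tenant_id,
    #                 active_models=summary.get("total_active_models", 0),
    #                 success_rate=summary.get("success_rate", 0.0))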

    async def _update_job_status_repository(self,
                                            job_id: str,
                                            status: str,
                                            progress: Optional[int] = None,
                                            current_step: Optional[str] = None,
                                            error_message: Optional[str] = None,
                                            results: Optional[Dict] = None,
                                            tenant_id: Optional[str] = None,
                                            session=None):
        """Update job status using the repository pattern.

        Args:
            session: Optional database session to reuse. If None, a new session is
                created and committed here; if provided, no commit is issued so the
                caller keeps control of the transaction.
        """
        try:
            # Use the provided session or create a new one
            should_create_session = session is None

            if should_create_session:
                async with self.database_manager.get_session() as session:
                    await self._init_repositories(session)
                    await self._update_job_status_impl(
                        session, job_id, status, progress, current_step,
                        error_message, results, tenant_id
                    )
            else:
                # Reuse provided session (don't commit - let caller control transaction)
                await self._init_repositories(session)
                await self._update_job_status_impl(
                    session, job_id, status, progress, current_step,
                    error_message, results, tenant_id, auto_commit=False
                )
        except Exception as e:
            logger.error("Failed to update job status using repository",
                         job_id=job_id,
                         error=str(e))
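
    # Both calling modes of _update_job_status_repository, side by side (a sketch;
    # `job_id`, `tenant_id` and `session` stand in for the caller's own values):
    #
    #     # Standalone: opens its own session and commits immediately
    #     await self._update_job_status_repository(job_id, "running", progress=10,
    #                                              tenant_id=tenant_id)
    #
    #     # Shared session: reuses the caller's session; the caller commits
    #     await self._update_job_status_repository(job_id, "running", progress=10,
    #                                              tenant_id=tenant_id, session=session)
    #     await session.commit()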

    async def _update_job_status_impl(self,
                                      session,
                                      job_id: str,
                                      status: str,
                                      progress: Optional[int] = None,
                                      current_step: Optional[str] = None,
                                      error_message: Optional[str] = None,
                                      results: Optional[Dict] = None,
                                      tenant_id: Optional[str] = None,
                                      auto_commit: bool = True):
        """Implementation of the job status update shared by both calling modes."""
        # Check if a log already exists for this job; create one if not
        existing_log = await self.training_log_repo.get_log_by_job_id(job_id)

        if not existing_log:
            # Create the initial log entry
            if not tenant_id:
                # Extract tenant_id from the job_id if it was not provided
                # Expected format: enhanced_training_{tenant_id}_{job_suffix}
                try:
                    parts = job_id.split('_')
                    if len(parts) >= 3 and parts[0] == 'enhanced' and parts[1] == 'training':
                        tenant_id = parts[2]
                except Exception:
                    logger.warning(f"Could not extract tenant_id from job_id {job_id}")
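
            # Worked example of the extraction above (values are hypothetical):
            # job_id "enhanced_training_acme-bakery_1a2b3c4d" splits into
            # ["enhanced", "training", "acme-bakery", "1a2b3c4d"], so tenant_id
            # becomes "acme-bakery". Note that a tenant_id containing underscores
            # would be truncated to its first segment by this parse; any other
            # job_id shape leaves tenant_id unset and log creation is skipped below.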

            if tenant_id:
                log_data = {
                    "job_id": job_id,
                    "tenant_id": tenant_id,
                    "status": status or "pending",
                    "progress": progress or 0,
                    "current_step": current_step or "initializing",
                    "start_time": datetime.now(timezone.utc)
                }

                if error_message:
                    log_data["error_message"] = error_message
                if results:
                    # Ensure results are JSON-serializable before storing
                    log_data["results"] = make_json_serializable(results)

                try:
                    await self.training_log_repo.create_training_log(log_data)
                    if auto_commit:
                        await session.commit()  # Explicit commit so other sessions can see it
                    logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
                except Exception as create_error:
                    # Handle race condition: another session may have created the log
                    if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
                        logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
                        await session.rollback()
                        # Query again to get the existing log
                        existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
                        if existing_log:
                            # Update the existing log instead
                            await self.training_log_repo.update_log_progress(
                                job_id=job_id,
                                progress=progress,
                                current_step=current_step,
                                status=status
                            )
                            if auto_commit:
                                await session.commit()
                    else:
                        raise
            else:
                logger.error("Cannot create training log without tenant_id", job_id=job_id)
                return
        else:
            # Update existing log
            await self.training_log_repo.update_log_progress(
                job_id=job_id,
                progress=progress,
                current_step=current_step,
                status=status
            )

            # Update additional fields if provided
            if error_message or results:
                update_data = {}
                if error_message:
                    update_data["error_message"] = error_message
                if results:
                    # Ensure results are JSON-serializable before storing
                    update_data["results"] = make_json_serializable(results)
                if status in ["completed", "failed"]:
                    update_data["end_time"] = datetime.now(timezone.utc)

                if update_data:
                    await self.training_log_repo.update(existing_log.id, update_data)

            if auto_commit:
                await session.commit()  # Explicit commit after updates
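
    # How a caller might kick off single-product training (a sketch; the background
    # task mechanism and the concrete argument values are assumptions, not part of
    # this service):
    #
    #     job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
    #     background_tasks.add_task(
    #         training_service.start_single_product_training,
    #         tenant_id=tenant_id,
    #         inventory_product_id=inventory_product_id,
    #         job_id=job_id,
    #         bakery_location=(40.4168, -3.7038),
    #     )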

    async def start_single_product_training(self,
                                            tenant_id: str,
                                            inventory_product_id: str,
                                            job_id: str,
                                            bakery_location: tuple = (40.4168, -3.7038)) -> Dict[str, Any]:
        """Start enhanced single product training using repository pattern with single session"""
        # Create a single database session for all operations to avoid connection pool exhaustion
        async with self.database_manager.get_session() as session:
            await self._init_repositories(session)

            try:
                logger.info("Starting enhanced single product training",
                            tenant_id=tenant_id,
                            inventory_product_id=inventory_product_id,
                            job_id=job_id)

                # Create initial training log (using shared session)
                await self._update_job_status_repository(
                    job_id=job_id,
                    status="running",
                    progress=0,
                    current_step="Fetching training data",
                    tenant_id=tenant_id,
                    session=session
                )
                await session.commit()  # Commit after initial log creation

                # Prepare training data for all products to get weather/traffic data,
                # then filter down to the specific product
                training_dataset = await self.orchestrator.prepare_training_data(
                    tenant_id=tenant_id,
                    bakery_location=bakery_location,
                    job_id=job_id + "_temp"
                )

                # Use the enhanced data processor to prepare training data with all features (POI, weather, traffic).
                # Filter sales data to the specific product first.
                sales_df = pd.DataFrame(training_dataset.sales_data)
                product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]

                if product_sales_df.empty:
                    raise ValueError(f"No sales data available for product {inventory_product_id}")

                # Get weather and traffic data as DataFrames
                weather_df = pd.DataFrame(training_dataset.weather_data)
                traffic_df = pd.DataFrame(training_dataset.traffic_data)

                # Get POI features from the training dataset (already collected by orchestrator)
                poi_features = getattr(training_dataset, 'poi_features', None)

                # Use the enhanced data processor to merge all features properly.
                # This will include POI, weather, traffic features along with ds and y.
                from app.ml.data_processor import EnhancedBakeryDataProcessor
                data_processor = EnhancedBakeryDataProcessor(self.database_manager)

                product_data = await data_processor.prepare_training_data(
                    sales_data=product_sales_df,
                    weather_data=weather_df,
                    traffic_data=traffic_df,
                    inventory_product_id=inventory_product_id,
                    poi_features=poi_features,
                    tenant_id=tenant_id,
                    job_id=job_id
                )

                if product_data.empty:
                    raise ValueError(f"Data processor returned empty data for product {inventory_product_id}")

                logger.info("Prepared training data for single product",
                            inventory_product_id=inventory_product_id,
                            data_points=len(product_data),
                            features=list(product_data.columns),
                            date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
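
                # At this point product_data is expected to contain a 'ds' datetime
                # column and a 'y' target column, plus the merged exogenous feature
                # columns (weather, traffic, POI). Illustrative shape only - the
                # concrete feature names below are assumptions, the real ones come
                # from the data processor:
                #
                #     ds           y    temperature  traffic_level  poi_count ...
                #     2025-01-01  42.0          9.5           0.73         12 ...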

                # Update progress (using shared session)
                await self._update_job_status_repository(
                    job_id=job_id,
                    status="running",
                    progress=30,
                    current_step="Training model",
                    tenant_id=tenant_id,
                    session=session
                )
                await session.commit()  # Commit progress update

                # Run the actual training (passing the session to avoid nested session creation)
                try:
                    model_info = await self.trainer.train_single_product_model(
                        tenant_id=tenant_id,
                        inventory_product_id=inventory_product_id,
                        training_data=product_data,
                        job_id=job_id,
                        session=session  # Pass the shared session to prevent a deadlock from nested session creation
                    )
                except Exception as e:
                    import traceback
                    logger.error(f"Training failed with error: {e}")
                    logger.error(f"Full traceback: {traceback.format_exc()}")
                    raise

                # Update progress (using shared session)
                await self._update_job_status_repository(
                    job_id=job_id,
                    status="running",
                    progress=80,
                    current_step="Saving model",
                    tenant_id=tenant_id,
                    session=session
                )
                await session.commit()  # Commit progress update

                # The model has already been saved by train_single_product_model,
                # so only the response needs to be assembled here.
                return {
                    "job_id": job_id,
                    "tenant_id": tenant_id,
                    "inventory_product_id": inventory_product_id,
                    "status": "completed",
                    "message": "Enhanced single product training completed successfully",
                    "created_at": datetime.now(timezone.utc),
                    "estimated_duration_minutes": 15,  # Default estimate for single product
                    "training_results": {
                        "total_products": 1,
                        "successful_trainings": 1,
                        "failed_trainings": 0,
                        "products": [{
                            "inventory_product_id": inventory_product_id,
                            "status": "completed",
                            "model_id": str(model_info['model_id']) if model_info.get('model_id') else None,
                            "data_points": len(product_data) if product_data is not None else 0,
                            # Keep only numeric metric values (Decimal values from the database are coerced to float)
                            "metrics": {
                                k: float(v)
                                for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items()
                                if k != 'product_category' and isinstance(v, (int, float, Decimal))
                            }
                        }],
                        "overall_training_time_seconds": model_info.get('training_time', 45.2)
                    },
                    "enhanced_features": True,
                    "repository_integration": True,
                    "completed_at": datetime.now(timezone.utc).isoformat()
                }

            except Exception as e:
                logger.error("Enhanced single product training failed",
                             inventory_product_id=inventory_product_id,
                             error=str(e))

                # Update status to failed (using shared session)
                await self._update_job_status_repository(
                    job_id=job_id,
                    status="failed",
                    progress=0,
                    current_step="Training failed",
                    error_message=str(e),
                    tenant_id=tenant_id,
                    session=session
                )
                await session.commit()  # Commit failure status

                raise

    def _create_detailed_training_response(self, final_result: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a final training result into a detailed training response"""
        try:
            training_results_data = final_result.get("training_results", {})
            stored_models = final_result.get("stored_models", [])

            # Convert stored models to per-product results
            products = []
            for model in stored_models:
                products.append({
                    "inventory_product_id": model.get("inventory_product_id"),
                    "status": "completed",
                    "model_id": model.get("id"),
                    "data_points": model.get("training_samples", 0),
                    "metrics": {
                        "mape": model.get("mape"),
                        "mae": model.get("mae"),
                        "rmse": model.get("rmse"),
                        "r2_score": model.get("r2_score")
                    },
                    "model_path": model.get("model_path")
                })

            # Build the response
            response_data = {
                "job_id": final_result["job_id"],
                "tenant_id": final_result["tenant_id"],
                "status": final_result["status"],
                "message": f"Training {final_result['status']}",
                "created_at": datetime.now(timezone.utc),
                "estimated_duration_minutes": final_result.get("estimated_duration_minutes", 15),
                "training_results": {
                    "total_products": len(products),
                    "successful_trainings": len([p for p in products if p["status"] == "completed"]),
                    "failed_trainings": len([p for p in products if p["status"] == "failed"]),
                    "products": products,
                    "overall_training_time_seconds": training_results_data.get("total_training_time", 0)
                },
                "data_summary": final_result.get("data_summary", {}),
                "completed_at": final_result.get("completed_at")
            }

            return response_data

        except Exception as e:
            logger.error("Failed to create detailed response", error=str(e))
            return final_result
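
    # Illustrative input/output for _create_detailed_training_response (a sketch;
    # the field values are made up, the keys mirror the ones read above):
    #
    #     final_result["stored_models"] = [{
    #         "id": "a1b2c3", "inventory_product_id": "croissant-001",
    #         "training_samples": 365, "mape": 12.4, "mae": 3.1,
    #         "rmse": 4.2, "r2_score": 0.81, "model_path": "models/a1b2c3.pkl",
    #     }]
    #
    # becomes one entry in training_results["products"] with status "completed"
    # and those metrics nested under "metrics".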