# Training Service (ML Model Management)

## Overview

The **Training Service** is the machine learning pipeline engine of Bakery-IA, responsible for training, versioning, and managing Prophet forecasting models. It orchestrates the entire ML workflow from data collection to model deployment, providing real-time progress updates via WebSocket and ensuring bakeries always have the most accurate prediction models. This service enables continuous learning and model improvement without requiring data science expertise.

## Key Features

### Automated ML Pipeline

- **One-Click Model Training** - Train models for all products with a single API call
- **Background Job Processing** - Asynchronous training with job queue management
- **Multi-Product Training** - Process multiple products in parallel
- **Progress Tracking** - Real-time WebSocket updates on training status
- **Automatic Model Versioning** - Track all model versions with performance metrics
- **Model Artifact Storage** - Persist trained models for fast prediction loading

### Training Job Management

- **Job Queue** - FIFO queue for training requests
- **Job Status Tracking** - Monitor pending, running, completed, and failed jobs
- **Concurrent Job Control** - Limit parallel training jobs to prevent resource exhaustion
- **Timeout Handling** - Automatic job termination after maximum duration
- **Error Recovery** - Detailed error messages and retry capabilities
- **Job History** - Complete audit trail of all training executions

### Model Performance Tracking

- **Accuracy Metrics** - MAE, RMSE, R², MAPE for each trained model
- **Historical Comparison** - Compare current vs. previous model performance
- **Per-Product Analytics** - Track which products have the best forecast accuracy
- **Training Duration Tracking** - Monitor training performance and optimization
- **Model Selection** - Automatically deploy best-performing models

### Real-Time Communication

- **WebSocket Live Updates** - Real-time progress percentage and status messages
- **Training Logs** - Detailed step-by-step execution logs
- **Completion Notifications** - RabbitMQ events for training completion
- **Error Alerts** - Immediate notification of training failures

### Feature Engineering

- **Historical Data Aggregation** - Collect sales data for model training
- **External Data Integration** - Fetch weather, traffic, holiday data
- **POI Feature Integration** - Merge location-based POI features into training data
- **Feature Extraction** - Generate 30+ temporal, contextual, and location-based features
- **Data Validation** - Ensure minimum data requirements before training
- **Outlier Detection** - Filter anomalous data points (see the sketch below)
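The document does not pin down the outlier rule. A minimal sketch of the outlier-detection step, assuming a standard IQR fence on daily sales (the `filter_outliers` helper and the `k` parameter are illustrative, not part of the service API):

```python
import pandas as pd

def filter_outliers(df: pd.DataFrame, column: str = "y", k: float = 1.5) -> pd.DataFrame:
    """Drop rows whose sales fall outside the IQR fence (a conservative default)."""
    q1, q3 = df[column].quantile(0.25), df[column].quantile(0.75)
    iqr = q3 - q1
    lower, upper = q1 - k * iqr, q3 + k * iqr
    return df[(df[column] >= lower) & (df[column] <= upper)]
```

Widening `k` keeps more promotional spikes in the training data; tightening it smooths the fit at the cost of dropping genuine demand peaks.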
## Technical Capabilities

### ML Training Pipeline

```python
# Training workflow
from prophet import Prophet

async def train_model_pipeline(tenant_id: str, product_id: str):
    """Complete ML training pipeline"""
    # Step 1: Data Collection
    sales_data = await fetch_historical_sales(tenant_id, product_id)
    if len(sales_data) < MIN_TRAINING_DATA_DAYS:
        raise InsufficientDataError(f"Need {MIN_TRAINING_DATA_DAYS}+ days of data")

    # Step 2: Feature Engineering
    features = engineer_features(sales_data)
    weather_data = await fetch_weather_data(tenant_id)
    traffic_data = await fetch_traffic_data(tenant_id)
    holiday_data = await fetch_holiday_calendar()
    poi_features = await fetch_poi_features(tenant_id)  # NEW: Location context

    # Merge POI features into the training dataframe
    # (weather, traffic, and holiday data are merged the same way)
    features = merge_poi_features(features, poi_features)

    # Step 3: Prophet Model Training
    model = Prophet(
        seasonality_mode='additive',
        daily_seasonality=True,
        weekly_seasonality=True,
        yearly_seasonality=True,
    )
    model.add_country_holidays(country_name='ES')
    model.fit(features)

    # Step 4: Model Validation
    metrics = calculate_performance_metrics(model, sales_data)

    # Step 5: Model Storage
    model_path = save_model_artifact(model, tenant_id, product_id)

    # Step 6: Model Registration
    await register_model_in_database(model_path, metrics)

    # Step 7: Notification
    await publish_training_complete_event(tenant_id, product_id, metrics)

    return model, metrics
```

### WebSocket Progress Updates

```python
# Real-time progress broadcasting
from datetime import datetime

async def broadcast_training_progress(job_id: str, progress: dict):
    """Send progress update to connected clients"""
    message = {
        "type": "training_progress",
        "job_id": job_id,
        "progress": {
            "percentage": progress["percentage"],         # 0-100
            "current_step": progress["step"],             # Step description
            "products_completed": progress["completed"],
            "products_total": progress["total"],
            "estimated_time_remaining": progress["eta"],  # Seconds
            "started_at": progress["start_time"]
        },
        "timestamp": datetime.utcnow().isoformat()
    }
    await websocket_manager.broadcast(job_id, message)
```

### Model Artifact Management

```python
# Model storage and retrieval
from datetime import datetime
from pathlib import Path

import joblib
from prophet import Prophet

# Save trained model
def save_model_artifact(model: Prophet, tenant_id: str, product_id: str) -> str:
    """Serialize and store model"""
    model_dir = Path(f"/models/{tenant_id}/{product_id}")
    model_dir.mkdir(parents=True, exist_ok=True)

    version = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    model_path = model_dir / f"model_v{version}.pkl"
    joblib.dump(model, model_path)

    return str(model_path)

# Load trained model
def load_model_artifact(model_path: str) -> Prophet:
    """Load serialized model"""
    return joblib.load(model_path)
```

### Performance Metrics Calculation

```python
import numpy as np
import pandas as pd
from prophet import Prophet
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

def calculate_performance_metrics(model: Prophet, actual_data: pd.DataFrame) -> dict:
    """Calculate comprehensive model performance metrics"""
    # Make predictions on the validation data
    predictions = model.predict(actual_data)

    y_true = actual_data['y'].to_numpy()
    y_pred = predictions['yhat'].to_numpy()

    # Calculate metrics
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)

    # Exclude zero-sales days from MAPE to avoid division by zero
    nonzero = y_true != 0
    mape = np.mean(np.abs((y_true[nonzero] - y_pred[nonzero]) / y_true[nonzero])) * 100

    return {
        "mae": float(mae),      # Mean Absolute Error
        "rmse": float(rmse),    # Root Mean Square Error
        "r2_score": float(r2),  # R-squared
        "mape": float(mape),    # Mean Absolute Percentage Error
        "accuracy": float(100 - mape) if mape < 100 else 0.0
    }
```
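Note that the pipeline above computes metrics on the same data the model was fit on, which tends to overstate accuracy. A time-based holdout split is the usual remedy for time series; the sketch below is one way that could be wired in (`train_with_holdout` and `holdout_days` are illustrative names, not part of the service):

```python
import pandas as pd
from prophet import Prophet

def train_with_holdout(features: pd.DataFrame, holdout_days: int = 14):
    """Fit on everything except the most recent `holdout_days`; validate on that tail."""
    cutoff = features["ds"].max() - pd.Timedelta(days=holdout_days)
    train_df = features[features["ds"] <= cutoff]
    valid_df = features[features["ds"] > cutoff]

    model = Prophet(daily_seasonality=True, weekly_seasonality=True, yearly_seasonality=True)
    model.fit(train_df)

    # Evaluate on days the model has never seen
    metrics = calculate_performance_metrics(model, valid_df)
    return model, metrics
```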
products
- **Background Processing** - Non-blocking training jobs
- **Error Handling** - Robust error recovery and retry mechanisms
- **Cost Optimization** - Efficient model storage and caching

## Technology Stack

- **Framework**: FastAPI (Python 3.11+) - Async web framework with WebSocket support
- **Database**: PostgreSQL 17 - Training logs, model metadata, job queue
- **ML Library**: Prophet (`prophet`, formerly `fbprophet`) - Time series forecasting
- **Model Storage**: Joblib - Model serialization
- **File System**: Persistent volumes - Model artifact storage
- **WebSocket**: FastAPI WebSocket - Real-time progress updates
- **Messaging**: RabbitMQ 4.1 - Training completion events
- **ORM**: SQLAlchemy 2.0 (async) - Database abstraction
- **Data Processing**: Pandas, NumPy - Data manipulation
- **Logging**: Structlog - Structured JSON logging
- **Metrics**: Prometheus Client - Custom metrics

## API Endpoints (Key Routes)

### Training Management

- `POST /api/v1/training/start` - Start training job for tenant
- `POST /api/v1/training/start/{product_id}` - Train specific product
- `POST /api/v1/training/stop/{job_id}` - Stop running training job
- `GET /api/v1/training/status/{job_id}` - Get job status and progress
- `GET /api/v1/training/history` - Get training job history
- `DELETE /api/v1/training/jobs/{job_id}` - Delete training job record

### Model Management

- `GET /api/v1/training/models` - List all trained models
- `GET /api/v1/training/models/{model_id}` - Get specific model details
- `GET /api/v1/training/models/{model_id}/metrics` - Get model performance metrics
- `GET /api/v1/training/models/latest/{product_id}` - Get latest model for product
- `POST /api/v1/training/models/{model_id}/deploy` - Deploy specific model version
- `DELETE /api/v1/training/models/{model_id}` - Delete model artifact

### WebSocket

- `WS /api/v1/training/ws/{job_id}` - Connect to training progress stream

### Analytics

- `GET /api/v1/training/analytics/performance` - Overall training performance
- `GET /api/v1/training/analytics/accuracy` - Model accuracy distribution
- `GET /api/v1/training/analytics/duration` - Training duration statistics
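As a usage sketch for the training-management routes above, a client could start a job and poll its status like this (the gateway URL, bearer token, and the `job_id` response field are assumptions about the surrounding setup, not documented behavior):

```python
import asyncio

import httpx

async def start_and_poll(base_url: str, token: str) -> dict:
    headers = {"Authorization": f"Bearer {token}"}  # JWT auth assumed
    async with httpx.AsyncClient(base_url=base_url, headers=headers) as client:
        resp = await client.post("/api/v1/training/start")
        resp.raise_for_status()
        job_id = resp.json()["job_id"]  # assumed response field

        while True:
            status = (await client.get(f"/api/v1/training/status/{job_id}")).json()
            if status.get("status") in ("completed", "failed"):
                return status
            await asyncio.sleep(5)  # or subscribe to the WebSocket stream instead

# asyncio.run(start_and_poll("http://localhost:8004", "your-jwt-here"))
```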
## Database Schema

### Main Tables

**training_job_queue**

```sql
CREATE TABLE training_job_queue (
    id UUID PRIMARY KEY,
    tenant_id UUID NOT NULL,
    job_name VARCHAR(255),
    products_to_train TEXT[],           -- Array of product IDs
    status VARCHAR(50) NOT NULL,        -- pending, running, completed, failed
    priority INTEGER DEFAULT 0,
    progress_percentage INTEGER DEFAULT 0,
    current_step VARCHAR(255),
    products_completed INTEGER DEFAULT 0,
    products_total INTEGER,
    started_at TIMESTAMP,
    completed_at TIMESTAMP,
    estimated_completion TIMESTAMP,
    error_message TEXT,
    retry_count INTEGER DEFAULT 0,
    created_by UUID,
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW()
);
```

**trained_models**

```sql
CREATE TABLE trained_models (
    id UUID PRIMARY KEY,
    tenant_id UUID NOT NULL,
    product_id UUID NOT NULL,
    model_version VARCHAR(50) NOT NULL,
    model_path VARCHAR(500) NOT NULL,
    training_job_id UUID REFERENCES training_job_queue(id),
    algorithm VARCHAR(50) DEFAULT 'prophet',
    hyperparameters JSONB,
    training_duration_seconds INTEGER,
    training_data_points INTEGER,
    is_deployed BOOLEAN DEFAULT FALSE,
    deployed_at TIMESTAMP,
    created_at TIMESTAMP DEFAULT NOW(),
    UNIQUE(tenant_id, product_id, model_version)
);
```

**model_performance_metrics**

```sql
CREATE TABLE model_performance_metrics (
    id UUID PRIMARY KEY,
    model_id UUID REFERENCES trained_models(id),
    tenant_id UUID NOT NULL,
    product_id UUID NOT NULL,
    mae DECIMAL(10, 4),                 -- Mean Absolute Error
    rmse DECIMAL(10, 4),                -- Root Mean Square Error
    r2_score DECIMAL(10, 6),            -- R-squared
    mape DECIMAL(10, 4),                -- Mean Absolute Percentage Error
    accuracy_percentage DECIMAL(5, 2),
    validation_data_points INTEGER,
    created_at TIMESTAMP DEFAULT NOW()
);
```

**model_training_logs**

```sql
CREATE TABLE model_training_logs (
    id UUID PRIMARY KEY,
    training_job_id UUID REFERENCES training_job_queue(id),
    tenant_id UUID NOT NULL,
    product_id UUID,
    log_level VARCHAR(20),              -- DEBUG, INFO, WARNING, ERROR
    message TEXT,
    step_name VARCHAR(100),
    execution_time_ms INTEGER,
    metadata JSONB,
    created_at TIMESTAMP DEFAULT NOW()
);
```

**model_artifacts** (metadata only; actual files on disk)

```sql
CREATE TABLE model_artifacts (
    id UUID PRIMARY KEY,
    model_id UUID REFERENCES trained_models(id),
    artifact_type VARCHAR(50),          -- model_file, feature_list, scaler, etc.
    file_path VARCHAR(500),
    file_size_bytes BIGINT,
    checksum VARCHAR(64),               -- SHA-256 hash
    created_at TIMESTAMP DEFAULT NOW()
);
```

## Events & Messaging

### Published Events (RabbitMQ)

**Exchange**: `training`
**Routing Key**: `training.completed`

**Training Completed Event**

```json
{
  "event_type": "training_completed",
  "tenant_id": "uuid",
  "job_id": "uuid",
  "job_name": "Weekly retraining - All products",
  "status": "completed",
  "results": {
    "successful_trainings": 25,
    "failed_trainings": 2,
    "total_products": 27,
    "models_created": [
      {
        "product_id": "uuid",
        "product_name": "Baguette",
        "model_version": "20251106_143022",
        "accuracy": 82.5,
        "mae": 12.3,
        "rmse": 18.7,
        "r2_score": 0.78
      }
    ],
    "average_accuracy": 79.8,
    "training_duration_seconds": 342
  },
  "started_at": "2025-11-06T14:25:00Z",
  "completed_at": "2025-11-06T14:30:42Z",
  "timestamp": "2025-11-06T14:30:42Z"
}
```

**Training Failed Event**

```json
{
  "event_type": "training_failed",
  "tenant_id": "uuid",
  "job_id": "uuid",
  "product_id": "uuid",
  "product_name": "Croissant",
  "error_type": "InsufficientDataError",
  "error_message": "Product requires minimum 30 days of sales data. Currently: 15 days.",
  "recommended_action": "Collect more sales data before retraining",
  "severity": "medium",
  "timestamp": "2025-11-06T14:28:15Z"
}
```

### Consumed Events

- **From Orchestrator**: Scheduled training triggers
- **From Sales**: New sales data imported (triggers retraining)
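A minimal publisher for the events above, assuming the service uses `aio_pika` as its RabbitMQ client (the document names RabbitMQ but not the library, so treat this as a sketch):

```python
import json

import aio_pika

async def publish_training_event(rabbitmq_url: str, event: dict) -> None:
    """Publish a training event to the `training` topic exchange."""
    connection = await aio_pika.connect_robust(rabbitmq_url)
    async with connection:
        channel = await connection.channel()
        exchange = await channel.declare_exchange(
            "training", aio_pika.ExchangeType.TOPIC, durable=True
        )
        await exchange.publish(
            aio_pika.Message(
                body=json.dumps(event).encode(),
                content_type="application/json",
            ),
            routing_key="training.completed",
        )
```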
## Custom Metrics (Prometheus)

```python
from prometheus_client import Counter, Gauge, Histogram

# Training job metrics
training_jobs_total = Counter(
    'training_jobs_total',
    'Total training jobs started',
    ['tenant_id', 'status']  # completed, failed, cancelled
)

training_duration_seconds = Histogram(
    'training_duration_seconds',
    'Training job duration',
    ['tenant_id'],
    buckets=[10, 30, 60, 120, 300, 600, 1800, 3600]  # seconds
)

models_trained_total = Counter(
    'models_trained_total',
    'Total models successfully trained',
    ['tenant_id', 'product_category']
)

# Model performance metrics
model_accuracy_distribution = Histogram(
    'model_accuracy_percentage',
    'Distribution of model accuracy scores',
    ['tenant_id'],
    buckets=[50, 60, 70, 75, 80, 85, 90, 95, 100]  # percentage
)

model_mae_distribution = Histogram(
    'model_mae',
    'Distribution of Mean Absolute Error',
    ['tenant_id'],
    buckets=[1, 5, 10, 20, 30, 50, 100]  # units
)

# WebSocket metrics
websocket_connections_total = Gauge(
    'training_websocket_connections',
    'Active WebSocket connections',
    ['tenant_id']
)

websocket_messages_sent = Counter(
    'training_websocket_messages_total',
    'Total WebSocket messages sent',
    ['tenant_id', 'message_type']
)
```

## Configuration

### Environment Variables

**Service Configuration:**

- `PORT` - Service port (default: 8004)
- `DATABASE_URL` - PostgreSQL connection string
- `RABBITMQ_URL` - RabbitMQ connection string
- `MODEL_STORAGE_PATH` - Path for model artifacts (default: /models)

**Training Configuration:**

- `MAX_CONCURRENT_JOBS` - Maximum parallel training jobs (default: 3)
- `MAX_TRAINING_TIME_MINUTES` - Job timeout (default: 30)
- `MIN_TRAINING_DATA_DAYS` - Minimum history required (default: 30)
- `ENABLE_AUTO_DEPLOYMENT` - Auto-deploy after training (default: true)

**Prophet Configuration:**

- `PROPHET_DAILY_SEASONALITY` - Enable daily patterns (default: true)
- `PROPHET_WEEKLY_SEASONALITY` - Enable weekly patterns (default: true)
- `PROPHET_YEARLY_SEASONALITY` - Enable yearly patterns (default: true)
- `PROPHET_INTERVAL_WIDTH` - Confidence interval (default: 0.95)
- `PROPHET_CHANGEPOINT_PRIOR_SCALE` - Trend flexibility (default: 0.05)

**WebSocket Configuration:**

- `WEBSOCKET_HEARTBEAT_INTERVAL` - Ping interval in seconds (default: 30)
- `WEBSOCKET_MAX_CONNECTIONS` - Max connections per tenant (default: 10)
- `WEBSOCKET_MESSAGE_QUEUE_SIZE` - Message buffer size (default: 100)

**Storage Configuration:**

- `MODEL_RETENTION_DAYS` - Days to keep old models (default: 90)
- `MAX_MODEL_VERSIONS_PER_PRODUCT` - Version limit (default: 10, see the sketch below)
- `ENABLE_MODEL_COMPRESSION` - Compress model files (default: true)
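The retention settings imply a periodic cleanup pass over the artifact directory. A sketch of how `MAX_MODEL_VERSIONS_PER_PRODUCT` could be enforced (the `prune_model_versions` helper is illustrative; the `model_v%Y%m%d_%H%M%S.pkl` filenames produced by `save_model_artifact` sort chronologically, which this relies on):

```python
from pathlib import Path

def prune_model_versions(model_dir: Path, keep: int = 10) -> list[Path]:
    """Delete all but the `keep` newest model_v*.pkl files for one product."""
    versions = sorted(model_dir.glob("model_v*.pkl"))  # timestamped names sort oldest-first
    stale = versions[:-keep] if keep > 0 else versions
    for path in stale:
        path.unlink()
    return stale

# Example: keep the 10 most recent versions (the MAX_MODEL_VERSIONS_PER_PRODUCT default)
# prune_model_versions(Path(f"/models/{tenant_id}/{product_id}"), keep=10)
```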
## Development Setup

### Prerequisites

- Python 3.11+
- PostgreSQL 17
- RabbitMQ 4.1
- Persistent storage for model artifacts

### Local Development

```bash
# Create virtual environment
cd services/training
python -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements.txt

# Set environment variables
export DATABASE_URL=postgresql://user:pass@localhost:5432/training
export RABBITMQ_URL=amqp://guest:guest@localhost:5672/
export MODEL_STORAGE_PATH=/tmp/models

# Create model storage directory
mkdir -p /tmp/models

# Run database migrations
alembic upgrade head

# Run the service
python main.py
```

### Testing

```bash
# Unit tests
pytest tests/unit/ -v

# Integration tests (requires services)
pytest tests/integration/ -v

# WebSocket tests
pytest tests/websocket/ -v

# Test with coverage
pytest --cov=app tests/ --cov-report=html
```

### WebSocket Testing

```python
# Test WebSocket connection
import asyncio
import json

import websockets

async def test_training_progress():
    uri = "ws://localhost:8004/api/v1/training/ws/job-id-here"
    async with websockets.connect(uri) as websocket:
        while True:
            message = await websocket.recv()
            data = json.loads(message)

            # Check the message type before reading progress fields;
            # completion messages may not carry a progress payload
            if data["type"] == "training_completed":
                print("Training finished!")
                break
            if data["type"] == "training_progress":
                print(f"Progress: {data['progress']['percentage']}%")
                print(f"Step: {data['progress']['current_step']}")

asyncio.run(test_training_progress())
```

## POI Feature Integration

### How POI Features Enhance Training

The Training Service integrates location-based POI features from the External Service to improve forecast accuracy:

**POI Features Included:**

- `school_density` - Number of schools within a 1 km radius
- `office_density` - Number of offices and business centers nearby
- `residential_density` - Residential area proximity
- `transport_hub_proximity` - Distance to metro, bus, and train stations
- `commercial_zone_score` - Commercial activity in the area
- `restaurant_density` - Nearby restaurants and cafes
- `competitor_proximity` - Distance to competing bakeries
- And 11+ more location-based features

**Integration Process:**

1. **Fetch POI Context** - Retrieve the tenant's POI features from the External Service (`/poi-context/{tenant_id}`)
2. **Extract ML Features** - Parse the `ml_features` JSON object from the POI context
3. **Merge with Training Data** - Add POI features as additional columns in the training dataframe
4. **Prophet Training** - Include POI features as regressors in the Prophet model
5. **Feature Importance** - Track which POI features most impact predictions

**Example POI Feature Integration:**

```python
from app.ml.poi_feature_integrator import POIFeatureIntegrator

# Initialize POI integrator
poi_integrator = POIFeatureIntegrator(external_service_url)

# Fetch and merge POI features
poi_features = await poi_integrator.fetch_poi_features(tenant_id)
training_df = poi_integrator.merge_poi_features(training_df, poi_features)

# POI features now available as columns:
# training_df['school_density'], training_df['office_density'], etc.

# Add POI features as Prophet regressors (must happen before fit)
for feature_name in poi_features.keys():
    prophet_model.add_regressor(feature_name)
```

**Endpoint Used:**

- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through the API Gateway)
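When regressors are used, Prophet also needs each POI column filled in at prediction time. Continuing the snippet above, a sketch (assuming `poi_features` is a flat name-to-value mapping, as the regressor loop implies):

```python
# Forecast the next 7 days; every registered regressor must also
# appear as a column in the future dataframe.
future = prophet_model.make_future_dataframe(periods=7)
for feature_name, value in poi_features.items():
    # POI features describe the bakery's fixed location, so a constant
    # value per column is a reasonable assumption here.
    future[feature_name] = value

forecast = prophet_model.predict(future)
print(forecast[["ds", "yhat", "yhat_lower", "yhat_upper"]].tail(7))
```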
## Integration Points

### Dependencies (Services Called)

- **Sales Service** - Fetch historical sales data for training
- **External Service** - Fetch weather, traffic, holiday, and POI feature data
- **PostgreSQL** - Store job queue, models, metrics, logs
- **RabbitMQ** - Publish training completion events
- **File System** - Store model artifacts

### Dependents (Services That Call This)

- **Forecasting Service** - Load trained models for predictions
- **Orchestrator Service** - Trigger scheduled training jobs
- **Frontend Dashboard** - Display training progress and model metrics
- **AI Insights Service** - Analyze model performance patterns

## Security Measures

### Data Protection

- **Tenant Isolation** - All training jobs scoped to tenant_id
- **Model Access Control** - Only the owning tenant can access its models
- **Input Validation** - Validate all training parameters
- **Rate Limiting** - Prevent training job spam

### Model Security

- **Model Checksums** - SHA-256 hash verification for artifacts
- **Version Control** - Track all model versions with an audit trail
- **Access Logging** - Log all model access and deployment
- **Secure Storage** - Model files stored with restricted permissions

### WebSocket Security

- **JWT Authentication** - Authenticate WebSocket connections
- **Connection Limits** - Max connections per tenant
- **Message Validation** - Validate all WebSocket messages
- **Heartbeat Monitoring** - Detect and close stale connections

## Performance Optimization

### Training Performance

1. **Parallel Processing** - Train multiple products concurrently
2. **Data Caching** - Cache fetched external data across products
3. **Incremental Training** - Only retrain changed products
4. **Resource Limits** - CPU/memory limits per training job
5. **Priority Queue** - Prioritize important products first

### Storage Optimization

1. **Model Compression** - Compress model artifacts (gzip)
2. **Old Model Cleanup** - Automatic deletion after the retention period
3. **Version Limits** - Keep only the N most recent versions
4. **Deduplication** - Avoid storing identical models

### WebSocket Optimization

1. **Message Batching** - Batch progress updates every 2 seconds (see the sketch after this list)
2. **Connection Pooling** - Reuse WebSocket connections
3. **Compression** - Enable WebSocket message compression
4. **Heartbeat** - Keep connections alive efficiently
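A sketch of the batching idea from item 1, assuming one writer task per job and reusing the `broadcast_training_progress` helper shown earlier (the `ProgressBatcher` class is illustrative, not part of the service):

```python
import asyncio

class ProgressBatcher:
    """Coalesce rapid progress updates; flush at most once per `interval` seconds."""

    def __init__(self, job_id: str, interval: float = 2.0):
        self.job_id = job_id
        self.interval = interval
        self._latest: dict | None = None

    def update(self, progress: dict) -> None:
        # Between flushes, newer updates simply overwrite older ones.
        self._latest = progress

    async def run(self) -> None:
        while True:
            await asyncio.sleep(self.interval)
            if self._latest is not None:
                await broadcast_training_progress(self.job_id, self._latest)
                self._latest = None
```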
## Troubleshooting

### Common Issues

**Issue**: Training jobs stuck in "pending" status

- **Cause**: Max concurrent jobs reached or worker process crashed
- **Solution**: Check the `MAX_CONCURRENT_JOBS` setting, restart the service

**Issue**: WebSocket connection drops during training

- **Cause**: Network timeout or client disconnection
- **Solution**: Implement auto-reconnect logic in the client

**Issue**: "Insufficient data" errors for many products

- **Cause**: Products need 30+ days of sales history
- **Solution**: Import more historical sales data or reduce `MIN_TRAINING_DATA_DAYS`

**Issue**: Low model accuracy (<70%)

- **Cause**: Insufficient data, outliers, or changing business patterns
- **Solution**: Clean outliers, add more features, or manually adjust Prophet parameters

### Debug Mode

```bash
# Enable detailed logging
export LOG_LEVEL=DEBUG
export PROPHET_VERBOSE=1

# Enable training profiling
export ENABLE_PROFILING=1

# Disable concurrent jobs for debugging
export MAX_CONCURRENT_JOBS=1
```

## Competitive Advantages

1. **One-Click ML** - No data science expertise required
2. **Real-Time Visibility** - WebSocket progress updates unique in bakery software
3. **Continuous Learning** - Automatic weekly retraining
4. **Version Control** - Track and compare all model versions
5. **Production-Ready** - Robust error handling and retry mechanisms
6. **Scalable** - Train models for thousands of products
7. **Spanish Market** - Optimized for Spanish bakery patterns and holidays

## Future Enhancements

- **Hyperparameter Tuning** - Automatic optimization of Prophet parameters
- **A/B Testing** - Deploy multiple models and compare performance
- **Distributed Training** - Scale across multiple machines
- **GPU Acceleration** - Use GPUs for deep learning models
- **AutoML** - Automatic algorithm selection (Prophet vs. LSTM vs. ARIMA)
- **Model Explainability** - SHAP values to explain predictions
- **Custom Algorithms** - Support for user-provided ML models
- **Transfer Learning** - Use pre-trained models from similar bakeries

---

**For VUE Madrid Business Plan**: The Training Service demonstrates advanced ML engineering capabilities with automated pipeline management and real-time monitoring. The ability to continuously improve forecast accuracy without manual intervention represents significant operational efficiency and competitive advantage. This self-learning system is a key differentiator in the bakery software market and showcases technical innovation suitable for EU technology grants and investor presentations.