Training Service (ML Model Management)
Overview
The Training Service is the machine learning pipeline engine of Bakery-IA, responsible for training, versioning, and managing Prophet forecasting models. It orchestrates the entire ML workflow from data collection to model deployment, providing real-time progress updates via WebSocket and keeping each bakery's prediction models current as new sales data arrives. The service enables continuous learning and model improvement without requiring data science expertise.
Key Features
Automated ML Pipeline
- One-Click Model Training - Train models for all products with a single API call
- Background Job Processing - Asynchronous training with job queue management
- Multi-Product Training - Process multiple products in parallel
- Progress Tracking - Real-time WebSocket updates on training status
- Automatic Model Versioning - Track all model versions with performance metrics
- Model Artifact Storage - Persist trained models for fast prediction loading
Training Job Management
- Job Queue - FIFO queue for training requests
- Job Status Tracking - Monitor pending, running, completed, and failed jobs
- Concurrent Job Control - Limit parallel training jobs to prevent resource exhaustion
- Timeout Handling - Automatic job termination after maximum duration
- Error Recovery - Detailed error messages and retry capabilities
- Job History - Complete audit trail of all training executions
Model Performance Tracking
- Accuracy Metrics - MAE, RMSE, R², MAPE for each trained model
- Historical Comparison - Compare current vs. previous model performance
- Per-Product Analytics - Track which products have the best forecast accuracy
- Training Duration Tracking - Monitor training performance and optimization
- Model Selection - Automatically deploy best-performing models (see the sketch after this list)
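A minimal, self-contained sketch of how "deploy the best-performing model" could be decided, assuming each candidate version is summarised by its stored accuracy and MAE; the helper name and record shape are illustrative, not the service's actual API.

def select_model_to_deploy(candidates: list[dict]) -> dict:
    """Pick the model version to deploy: highest accuracy wins, ties broken by lower MAE.
    Each candidate is an illustrative record like {"model_id": ..., "accuracy": ..., "mae": ...}."""
    return max(candidates, key=lambda m: (m["accuracy"], -m["mae"]))

# Example: the second record wins on accuracy
best = select_model_to_deploy([
    {"model_id": "a", "accuracy": 78.2, "mae": 14.1},
    {"model_id": "b", "accuracy": 82.5, "mae": 12.3},
])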
Real-Time Communication
- WebSocket Live Updates - Real-time progress percentage and status messages
- Training Logs - Detailed step-by-step execution logs
- Completion Notifications - RabbitMQ events for training completion
- Error Alerts - Immediate notification of training failures
Feature Engineering
- Historical Data Aggregation - Collect sales data for model training
- External Data Integration - Fetch weather, traffic, holiday data
- Feature Extraction - Generate 20+ temporal and contextual features (see the sketch after this list)
- Data Validation - Ensure minimum data requirements before training
- Outlier Detection - Filter anomalous data points
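As referenced in the list above, a hedged sketch of the temporal slice of feature extraction using pandas; the exact feature set and function name are assumptions, not the service's actual implementation.

import pandas as pd

def engineer_temporal_features(sales: pd.DataFrame) -> pd.DataFrame:
    """Derive calendar and rolling-demand features from a daily sales frame with 'ds' (date) and 'y' (units)."""
    df = sales.copy()
    df["ds"] = pd.to_datetime(df["ds"])
    df["day_of_week"] = df["ds"].dt.dayofweek
    df["is_weekend"] = df["day_of_week"].isin([5, 6]).astype(int)
    df["month"] = df["ds"].dt.month
    df["week_of_year"] = df["ds"].dt.isocalendar().week.astype(int)
    # Shifted rolling means so a row never sees its own target value
    df["sales_7d_mean"] = df["y"].shift(1).rolling(7).mean()
    df["sales_28d_mean"] = df["y"].shift(1).rolling(28).mean()
    return df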
Technical Capabilities
ML Training Pipeline
# Training workflow
from prophet import Prophet

async def train_model_pipeline(tenant_id: str, product_id: str):
    """Complete ML training pipeline"""
    # Step 1: Data Collection
    sales_data = await fetch_historical_sales(tenant_id, product_id)
    if len(sales_data) < MIN_TRAINING_DAYS:
        raise InsufficientDataError(f"Need {MIN_TRAINING_DAYS}+ days of data")

    # Step 2: Feature Engineering
    features = engineer_features(sales_data)
    weather_data = await fetch_weather_data(tenant_id)
    traffic_data = await fetch_traffic_data(tenant_id)
    holiday_data = await fetch_holiday_calendar()

    # Step 3: Prophet Model Training
    model = Prophet(
        seasonality_mode='additive',
        daily_seasonality=True,
        weekly_seasonality=True,
        yearly_seasonality=True,
    )
    model.add_country_holidays(country_name='ES')
    model.fit(features)

    # Step 4: Model Validation
    metrics = calculate_performance_metrics(model, sales_data)

    # Step 5: Model Storage
    model_path = save_model_artifact(model, tenant_id, product_id)

    # Step 6: Model Registration
    await register_model_in_database(model_path, metrics)

    # Step 7: Notification
    await publish_training_complete_event(tenant_id, product_id, metrics)

    return model, metrics
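The pipeline above fetches weather, traffic, and holiday data, but the snippet does not show how those signals reach Prophet. One plausible wiring, sketched under the assumption that the external frames share a 'ds' date column and use illustrative column names:

# Sketch: merge external signals into the training frame and register them as
# extra regressors before calling model.fit(features). Column names are assumptions.
features = features.merge(weather_data, on="ds", how="left")
features = features.merge(traffic_data, on="ds", how="left")

model.add_regressor("temperature")
model.add_regressor("precipitation")
model.add_regressor("foot_traffic")
# Any regressor added here must also be supplied in the future dataframe at prediction time.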
WebSocket Progress Updates
# Real-time progress broadcasting
from datetime import datetime

async def broadcast_training_progress(job_id: str, progress: dict):
    """Send progress update to connected clients"""
    message = {
        "type": "training_progress",
        "job_id": job_id,
        "progress": {
            "percentage": progress["percentage"],          # 0-100
            "current_step": progress["step"],              # Step description
            "products_completed": progress["completed"],
            "products_total": progress["total"],
            "estimated_time_remaining": progress["eta"],   # Seconds
            "started_at": progress["start_time"]
        },
        "timestamp": datetime.utcnow().isoformat()
    }
    await websocket_manager.broadcast(job_id, message)
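The broadcast above relies on a websocket_manager object. A minimal sketch of such a per-job connection registry built on FastAPI's WebSocket type; the real implementation may differ.

from collections import defaultdict
from fastapi import WebSocket

class WebSocketManager:
    """Tracks connected clients per training job and fans progress messages out to them."""

    def __init__(self) -> None:
        self._connections: dict[str, set[WebSocket]] = defaultdict(set)

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        await websocket.accept()
        self._connections[job_id].add(websocket)

    def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        self._connections[job_id].discard(websocket)

    async def broadcast(self, job_id: str, message: dict) -> None:
        # Drop connections that fail mid-send (closed tab, network drop, ...)
        for ws in list(self._connections[job_id]):
            try:
                await ws.send_json(message)
            except Exception:
                self.disconnect(job_id, ws)

websocket_manager = WebSocketManager()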
Model Artifact Management
# Model storage and retrieval
from datetime import datetime
from pathlib import Path

import joblib
from prophet import Prophet

# Save trained model
def save_model_artifact(model: Prophet, tenant_id: str, product_id: str) -> str:
    """Serialize and store model"""
    model_dir = Path(f"/models/{tenant_id}/{product_id}")
    model_dir.mkdir(parents=True, exist_ok=True)
    version = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    model_path = model_dir / f"model_v{version}.pkl"
    joblib.dump(model, model_path)
    return str(model_path)

# Load trained model
def load_model_artifact(model_path: str) -> Prophet:
    """Load serialized model"""
    return joblib.load(model_path)
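Because the model_artifacts table below records a SHA-256 checksum for each artifact, a helper along these lines could compute it at save time; this is a sketch, not the service's actual code.

import hashlib
from pathlib import Path

def compute_artifact_checksum(model_path: str) -> str:
    """Return the SHA-256 hex digest of a saved model file, streamed in 1 MB chunks."""
    digest = hashlib.sha256()
    with Path(model_path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest()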
Performance Metrics Calculation
import numpy as np
import pandas as pd
from prophet import Prophet
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

def calculate_performance_metrics(model: Prophet, actual_data: pd.DataFrame) -> dict:
    """Calculate comprehensive model performance metrics"""
    # Predict over the evaluation frame (must contain 'ds' and 'y' columns)
    predictions = model.predict(actual_data)

    # Calculate metrics
    mae = mean_absolute_error(actual_data['y'], predictions['yhat'])
    rmse = np.sqrt(mean_squared_error(actual_data['y'], predictions['yhat']))
    r2 = r2_score(actual_data['y'], predictions['yhat'])
    # MAPE is undefined where y == 0; zero-sales days should be filtered beforehand
    mape = np.mean(np.abs((actual_data['y'] - predictions['yhat']) / actual_data['y'])) * 100

    return {
        "mae": float(mae),        # Mean Absolute Error
        "rmse": float(rmse),      # Root Mean Square Error
        "r2_score": float(r2),    # R-squared
        "mape": float(mape),      # Mean Absolute Percentage Error
        "accuracy": float(100 - mape) if mape < 100 else 0.0
    }
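For a stricter out-of-sample view than the in-sample check above, Prophet ships rolling-origin cross-validation. A hedged sketch with illustrative window sizes; they assume a few months of history, more than the 30-day minimum.

from prophet.diagnostics import cross_validation, performance_metrics

cv_results = cross_validation(
    model,               # a fitted Prophet instance
    initial="90 days",   # length of the first training window
    period="7 days",     # spacing between cutoff dates
    horizon="14 days",   # how far ahead each fold forecasts
)
cv_metrics = performance_metrics(cv_results)  # mae, rmse, mape, coverage per horizon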
Business Value
For Bakery Owners
- Continuous Improvement - Models automatically improve with more data
- No ML Expertise Required - One-click training, no data science skills needed
- Always Up-to-Date - Weekly automatic retraining keeps models accurate
- Transparent Performance - Clear accuracy metrics show forecast reliability
- Cost Savings - Automated ML pipeline reduces the need for dedicated data-science staff
For Operations Managers
- Model Version Control - Track and compare model versions over time
- Performance Monitoring - Identify products with poor forecast accuracy
- Training Scheduling - Schedule retraining during low-traffic hours
- Resource Management - Control concurrent training jobs to prevent overload
For Platform Operations
- Scalable ML Pipeline - Train models for thousands of products
- Background Processing - Non-blocking training jobs
- Error Handling - Robust error recovery and retry mechanisms
- Cost Optimization - Efficient model storage and caching
Technology Stack
- Framework: FastAPI (Python 3.11+) - Async web framework with WebSocket support
- Database: PostgreSQL 17 - Training logs, model metadata, job queue
- ML Library: Prophet - Time series forecasting (the prophet package, formerly published as fbprophet)
- Model Storage: Joblib - Model serialization
- File System: Persistent volumes - Model artifact storage
- WebSocket: FastAPI WebSocket - Real-time progress updates
- Messaging: RabbitMQ 4.1 - Training completion events
- ORM: SQLAlchemy 2.0 (async) - Database abstraction
- Data Processing: Pandas, NumPy - Data manipulation
- Logging: Structlog - Structured JSON logging
- Metrics: Prometheus Client - Custom metrics
API Endpoints (Key Routes)
Training Management
- POST /api/v1/training/start - Start training job for tenant
- POST /api/v1/training/start/{product_id} - Train specific product
- POST /api/v1/training/stop/{job_id} - Stop running training job
- GET /api/v1/training/status/{job_id} - Get job status and progress
- GET /api/v1/training/history - Get training job history
- DELETE /api/v1/training/jobs/{job_id} - Delete training job record
Model Management
- GET /api/v1/training/models - List all trained models
- GET /api/v1/training/models/{model_id} - Get specific model details
- GET /api/v1/training/models/{model_id}/metrics - Get model performance metrics
- GET /api/v1/training/models/latest/{product_id} - Get latest model for product
- POST /api/v1/training/models/{model_id}/deploy - Deploy specific model version
- DELETE /api/v1/training/models/{model_id} - Delete model artifact
WebSocket
- WS /api/v1/training/ws/{job_id} - Connect to training progress stream
Analytics
- GET /api/v1/training/analytics/performance - Overall training performance
- GET /api/v1/training/analytics/accuracy - Model accuracy distribution
- GET /api/v1/training/analytics/duration - Training duration statistics
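A hypothetical client flow against the routes above using httpx; the response field names ("job_id", "status", "progress_percentage", "current_step") are assumptions about the payload shape, not a documented contract.

import time
import httpx

BASE = "http://localhost:8004/api/v1/training"
HEADERS = {"Authorization": "Bearer <jwt>"}  # a tenant JWT is assumed to be required

with httpx.Client(base_url=BASE, headers=HEADERS, timeout=30) as client:
    # Start a training job for every product of the authenticated tenant
    job = client.post("/start").json()        # assumed to return {"job_id": "..."}
    job_id = job["job_id"]

    # Poll until the job finishes; the WebSocket stream is the push-based alternative
    while True:
        status = client.get(f"/status/{job_id}").json()
        print(status.get("progress_percentage"), status.get("current_step"))
        if status.get("status") in ("completed", "failed"):
            break
        time.sleep(5)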
Database Schema
Main Tables
training_job_queue
CREATE TABLE training_job_queue (
id UUID PRIMARY KEY,
tenant_id UUID NOT NULL,
job_name VARCHAR(255),
products_to_train TEXT[], -- Array of product IDs
status VARCHAR(50) NOT NULL, -- pending, running, completed, failed
priority INTEGER DEFAULT 0,
progress_percentage INTEGER DEFAULT 0,
current_step VARCHAR(255),
products_completed INTEGER DEFAULT 0,
products_total INTEGER,
started_at TIMESTAMP,
completed_at TIMESTAMP,
estimated_completion TIMESTAMP,
error_message TEXT,
retry_count INTEGER DEFAULT 0,
created_by UUID,
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
);
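One common way to serve this table from several workers without double-processing is a FOR UPDATE SKIP LOCKED claim; a sketch with async SQLAlchemy, while the service's actual worker loop may differ.

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

CLAIM_NEXT_JOB = text("""
    UPDATE training_job_queue
       SET status = 'running', started_at = NOW(), updated_at = NOW()
     WHERE id = (
           SELECT id FROM training_job_queue
            WHERE status = 'pending'
            ORDER BY priority DESC, created_at
            FOR UPDATE SKIP LOCKED
            LIMIT 1
     )
    RETURNING id, tenant_id, products_to_train
""")

async def claim_next_job(session: AsyncSession):
    """Atomically claim the next pending job so concurrent workers never pick the same one."""
    result = await session.execute(CLAIM_NEXT_JOB)
    await session.commit()
    return result.first()  # None when the queue is empty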
trained_models
CREATE TABLE trained_models (
id UUID PRIMARY KEY,
tenant_id UUID NOT NULL,
product_id UUID NOT NULL,
model_version VARCHAR(50) NOT NULL,
model_path VARCHAR(500) NOT NULL,
training_job_id UUID REFERENCES training_job_queue(id),
algorithm VARCHAR(50) DEFAULT 'prophet',
hyperparameters JSONB,
training_duration_seconds INTEGER,
training_data_points INTEGER,
is_deployed BOOLEAN DEFAULT FALSE,
deployed_at TIMESTAMP,
created_at TIMESTAMP DEFAULT NOW(),
UNIQUE(tenant_id, product_id, model_version)
);
model_performance_metrics
CREATE TABLE model_performance_metrics (
id UUID PRIMARY KEY,
model_id UUID REFERENCES trained_models(id),
tenant_id UUID NOT NULL,
product_id UUID NOT NULL,
mae DECIMAL(10, 4), -- Mean Absolute Error
rmse DECIMAL(10, 4), -- Root Mean Square Error
r2_score DECIMAL(10, 6), -- R-squared
mape DECIMAL(10, 4), -- Mean Absolute Percentage Error
accuracy_percentage DECIMAL(5, 2),
validation_data_points INTEGER,
created_at TIMESTAMP DEFAULT NOW()
);
model_training_logs
CREATE TABLE model_training_logs (
id UUID PRIMARY KEY,
training_job_id UUID REFERENCES training_job_queue(id),
tenant_id UUID NOT NULL,
product_id UUID,
log_level VARCHAR(20), -- DEBUG, INFO, WARNING, ERROR
message TEXT,
step_name VARCHAR(100),
execution_time_ms INTEGER,
metadata JSONB,
created_at TIMESTAMP DEFAULT NOW()
);
model_artifacts (Metadata only, actual files on disk)
CREATE TABLE model_artifacts (
id UUID PRIMARY KEY,
model_id UUID REFERENCES trained_models(id),
artifact_type VARCHAR(50), -- model_file, feature_list, scaler, etc.
file_path VARCHAR(500),
file_size_bytes BIGINT,
checksum VARCHAR(64), -- SHA-256 hash
created_at TIMESTAMP DEFAULT NOW()
);
Events & Messaging
Published Events (RabbitMQ)
Exchange: training
Routing Key: training.completed
Training Completed Event
{
"event_type": "training_completed",
"tenant_id": "uuid",
"job_id": "uuid",
"job_name": "Weekly retraining - All products",
"status": "completed",
"results": {
"successful_trainings": 25,
"failed_trainings": 2,
"total_products": 27,
"models_created": [
{
"product_id": "uuid",
"product_name": "Baguette",
"model_version": "20251106_143022",
"accuracy": 82.5,
"mae": 12.3,
"rmse": 18.7,
"r2_score": 0.78
}
],
"average_accuracy": 79.8,
"training_duration_seconds": 342
},
"started_at": "2025-11-06T14:25:00Z",
"completed_at": "2025-11-06T14:30:42Z",
"timestamp": "2025-11-06T14:30:42Z"
}
Training Failed Event
{
"event_type": "training_failed",
"tenant_id": "uuid",
"job_id": "uuid",
"product_id": "uuid",
"product_name": "Croissant",
"error_type": "InsufficientDataError",
"error_message": "Product requires minimum 30 days of sales data. Currently: 15 days.",
"recommended_action": "Collect more sales data before retraining",
"severity": "medium",
"timestamp": "2025-11-06T14:28:15Z"
}
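A sketch of publishing the completion event shown above with aio-pika; the library choice and topic exchange type are assumptions, while the exchange name and routing key come from this section.

import json
import aio_pika

async def publish_training_completed(rabbitmq_url: str, payload: dict) -> None:
    """Publish a training.completed event on the 'training' exchange."""
    connection = await aio_pika.connect_robust(rabbitmq_url)
    async with connection:
        channel = await connection.channel()
        exchange = await channel.declare_exchange(
            "training", aio_pika.ExchangeType.TOPIC, durable=True
        )
        await exchange.publish(
            aio_pika.Message(
                body=json.dumps(payload).encode(),
                content_type="application/json",
                delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
            ),
            routing_key="training.completed",
        )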
Consumed Events
- From Orchestrator: Scheduled training triggers
- From Sales: New sales data imported (triggers retraining)
Custom Metrics (Prometheus)
from prometheus_client import Counter, Gauge, Histogram

# Training job metrics
training_jobs_total = Counter(
    'training_jobs_total',
    'Total training jobs started',
    ['tenant_id', 'status']  # completed, failed, cancelled
)

training_duration_seconds = Histogram(
    'training_duration_seconds',
    'Training job duration',
    ['tenant_id'],
    buckets=[10, 30, 60, 120, 300, 600, 1800, 3600]  # seconds
)

models_trained_total = Counter(
    'models_trained_total',
    'Total models successfully trained',
    ['tenant_id', 'product_category']
)

# Model performance metrics
model_accuracy_distribution = Histogram(
    'model_accuracy_percentage',
    'Distribution of model accuracy scores',
    ['tenant_id'],
    buckets=[50, 60, 70, 75, 80, 85, 90, 95, 100]  # percentage
)

model_mae_distribution = Histogram(
    'model_mae',
    'Distribution of Mean Absolute Error',
    ['tenant_id'],
    buckets=[1, 5, 10, 20, 30, 50, 100]  # units
)

# WebSocket metrics
websocket_connections_total = Gauge(
    'training_websocket_connections',
    'Active WebSocket connections',
    ['tenant_id']
)

websocket_messages_sent = Counter(
    'training_websocket_messages_total',
    'Total WebSocket messages sent',
    ['tenant_id', 'message_type']
)
Configuration
Environment Variables
Service Configuration:
- PORT - Service port (default: 8004)
- DATABASE_URL - PostgreSQL connection string
- RABBITMQ_URL - RabbitMQ connection string
- MODEL_STORAGE_PATH - Path for model artifacts (default: /models)
Training Configuration:
- MAX_CONCURRENT_JOBS - Maximum parallel training jobs (default: 3)
- MAX_TRAINING_TIME_MINUTES - Job timeout (default: 30)
- MIN_TRAINING_DATA_DAYS - Minimum history required (default: 30)
- ENABLE_AUTO_DEPLOYMENT - Auto-deploy after training (default: true)
Prophet Configuration:
- PROPHET_DAILY_SEASONALITY - Enable daily patterns (default: true)
- PROPHET_WEEKLY_SEASONALITY - Enable weekly patterns (default: true)
- PROPHET_YEARLY_SEASONALITY - Enable yearly patterns (default: true)
- PROPHET_INTERVAL_WIDTH - Confidence interval (default: 0.95)
- PROPHET_CHANGEPOINT_PRIOR_SCALE - Trend flexibility (default: 0.05)
WebSocket Configuration:
- WEBSOCKET_HEARTBEAT_INTERVAL - Ping interval in seconds (default: 30)
- WEBSOCKET_MAX_CONNECTIONS - Max connections per tenant (default: 10)
- WEBSOCKET_MESSAGE_QUEUE_SIZE - Message buffer size (default: 100)
Storage Configuration:
- MODEL_RETENTION_DAYS - Days to keep old models (default: 90)
- MAX_MODEL_VERSIONS_PER_PRODUCT - Version limit (default: 10)
- ENABLE_MODEL_COMPRESSION - Compress model files (default: true)
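A sketch of loading a subset of these variables with pydantic-settings; the use of that library is an assumption, and the defaults mirror the values listed above.

from pydantic_settings import BaseSettings

class TrainingSettings(BaseSettings):
    """Field names match the environment variables documented above."""
    PORT: int = 8004
    DATABASE_URL: str = "postgresql://user:pass@localhost:5432/training"
    RABBITMQ_URL: str = "amqp://guest:guest@localhost:5672/"
    MODEL_STORAGE_PATH: str = "/models"
    MAX_CONCURRENT_JOBS: int = 3
    MAX_TRAINING_TIME_MINUTES: int = 30
    MIN_TRAINING_DATA_DAYS: int = 30
    PROPHET_INTERVAL_WIDTH: float = 0.95
    MODEL_RETENTION_DAYS: int = 90

settings = TrainingSettings()  # environment values override these defaults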
Development Setup
Prerequisites
- Python 3.11+
- PostgreSQL 17
- RabbitMQ 4.1
- Persistent storage for model artifacts
Local Development
# Create virtual environment
cd services/training
python -m venv venv
source venv/bin/activate
# Install dependencies
pip install -r requirements.txt
# Set environment variables
export DATABASE_URL=postgresql://user:pass@localhost:5432/training
export RABBITMQ_URL=amqp://guest:guest@localhost:5672/
export MODEL_STORAGE_PATH=/tmp/models
# Create model storage directory
mkdir -p /tmp/models
# Run database migrations
alembic upgrade head
# Run the service
python main.py
Testing
# Unit tests
pytest tests/unit/ -v
# Integration tests (requires services)
pytest tests/integration/ -v
# WebSocket tests
pytest tests/websocket/ -v
# Test with coverage
pytest --cov=app tests/ --cov-report=html
WebSocket Testing
# Test WebSocket connection
import asyncio
import json

import websockets

async def test_training_progress():
    uri = "ws://localhost:8004/api/v1/training/ws/job-id-here"
    async with websockets.connect(uri) as websocket:
        while True:
            message = await websocket.recv()
            data = json.loads(message)
            # Check for completion first, since that message may carry no progress payload
            if data['type'] == 'training_completed':
                print("Training finished!")
                break
            print(f"Progress: {data['progress']['percentage']}%")
            print(f"Step: {data['progress']['current_step']}")

asyncio.run(test_training_progress())
Integration Points
Dependencies (Services Called)
- Sales Service - Fetch historical sales data for training
- External Service - Fetch weather, traffic, holiday data
- PostgreSQL - Store job queue, models, metrics, logs
- RabbitMQ - Publish training completion events
- File System - Store model artifacts
Dependents (Services That Call This)
- Forecasting Service - Load trained models for predictions
- Orchestrator Service - Trigger scheduled training jobs
- Frontend Dashboard - Display training progress and model metrics
- AI Insights Service - Analyze model performance patterns
Security Measures
Data Protection
- Tenant Isolation - All training jobs scoped to tenant_id
- Model Access Control - Only tenant can access their models
- Input Validation - Validate all training parameters
- Rate Limiting - Prevent training job spam
Model Security
- Model Checksums - SHA-256 hash verification for artifacts
- Version Control - Track all model versions with audit trail
- Access Logging - Log all model access and deployment
- Secure Storage - Model files stored with restricted permissions
WebSocket Security
- JWT Authentication - Authenticate WebSocket connections
- Connection Limits - Max connections per tenant
- Message Validation - Validate all WebSocket messages
- Heartbeat Monitoring - Detect and close stale connections
Performance Optimization
Training Performance
- Parallel Processing - Train multiple products concurrently
- Data Caching - Cache fetched external data across products
- Incremental Training - Only retrain changed products
- Resource Limits - CPU/memory limits per training job
- Priority Queue - Prioritize important products first
Storage Optimization
- Model Compression - Compress model artifacts (gzip)
- Old Model Cleanup - Automatic deletion after retention period
- Version Limits - Keep only N most recent versions
- Deduplication - Avoid storing identical models
WebSocket Optimization
- Message Batching - Batch progress updates (every 2 seconds; see the sketch after this list)
- Connection Pooling - Reuse WebSocket connections
- Compression - Enable WebSocket message compression
- Heartbeat - Keep connections alive efficiently
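A sketch of the 2-second batching mentioned above: keep only the latest progress message per job and flush on a timer. The class name and interval handling are illustrative, not the service's actual code.

import asyncio

class ProgressBatcher:
    """Coalesces progress updates per job and broadcasts at most once per interval."""

    def __init__(self, broadcast, interval_seconds: float = 2.0) -> None:
        self._broadcast = broadcast        # async callable(job_id, message)
        self._interval = interval_seconds
        self._latest: dict[str, dict] = {}

    def update(self, job_id: str, message: dict) -> None:
        self._latest[job_id] = message     # newer updates overwrite older ones

    async def run(self) -> None:
        while True:
            await asyncio.sleep(self._interval)
            pending, self._latest = self._latest, {}
            for job_id, message in pending.items():
                await self._broadcast(job_id, message)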
Troubleshooting
Common Issues
Issue: Training jobs stuck in "pending" status
- Cause: Max concurrent jobs reached or worker process crashed
- Solution: Check the MAX_CONCURRENT_JOBS setting and restart the service
Issue: WebSocket connection drops during training
- Cause: Network timeout or client disconnection
- Solution: Implement auto-reconnect logic in client
Issue: "Insufficient data" errors for many products
- Cause: Products need 30+ days of sales history
- Solution: Import more historical sales data or reduce MIN_TRAINING_DATA_DAYS
Issue: Low model accuracy (<70%)
- Cause: Insufficient data, outliers, or changing business patterns
- Solution: Clean outliers, add more features, or manually adjust Prophet params
Debug Mode
# Enable detailed logging
export LOG_LEVEL=DEBUG
export PROPHET_VERBOSE=1
# Enable training profiling
export ENABLE_PROFILING=1
# Disable concurrent jobs for debugging
export MAX_CONCURRENT_JOBS=1
Competitive Advantages
- One-Click ML - No data science expertise required
- Real-Time Visibility - Live WebSocket progress updates, rarely found in bakery software
- Continuous Learning - Automatic weekly retraining
- Version Control - Track and compare all model versions
- Production-Ready - Robust error handling and retry mechanisms
- Scalable - Train models for thousands of products
- Spanish Market - Optimized for Spanish bakery patterns and holidays
Future Enhancements
- Hyperparameter Tuning - Automatic optimization of Prophet parameters
- A/B Testing - Deploy multiple models and compare performance
- Distributed Training - Scale across multiple machines
- GPU Acceleration - Use GPUs for deep learning models
- AutoML - Automatic algorithm selection (Prophet vs LSTM vs ARIMA)
- Model Explainability - SHAP values to explain predictions
- Custom Algorithms - Support for user-provided ML models
- Transfer Learning - Use pre-trained models from similar bakeries
For VUE Madrid Business Plan: The Training Service demonstrates advanced ML engineering capabilities with automated pipeline management and real-time monitoring. The ability to continuously improve forecast accuracy without manual intervention represents significant operational efficiency and competitive advantage. This self-learning system is a key differentiator in the bakery software market and showcases technical innovation suitable for EU technology grants and investor presentations.