
Training Service (ML Model Management)

Overview

The Training Service is the machine learning pipeline engine of Bakery-IA, responsible for training, versioning, and managing Prophet forecasting models. It orchestrates the entire ML workflow from data collection to model deployment, provides real-time progress updates via WebSocket, and keeps each bakery's prediction models current as new sales data arrives. This enables continuous learning and model improvement without requiring data science expertise.

Key Features

Automated ML Pipeline

  • One-Click Model Training - Train models for all products with a single API call
  • Background Job Processing - Asynchronous training with job queue management
  • Multi-Product Training - Process multiple products in parallel
  • Progress Tracking - Real-time WebSocket updates on training status
  • Automatic Model Versioning - Track all model versions with performance metrics
  • Model Artifact Storage - Persist trained models for fast prediction loading

Training Job Management

  • Job Queue - FIFO queue for training requests
  • Job Status Tracking - Monitor pending, running, completed, and failed jobs
  • Concurrent Job Control - Limit parallel training jobs to prevent resource exhaustion
  • Timeout Handling - Automatic job termination after maximum duration
  • Error Recovery - Detailed error messages and retry capabilities
  • Job History - Complete audit trail of all training executions
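
A minimal sketch of how the concurrency limit and timeout handling above might be wired together with asyncio. The constants mirror the MAX_CONCURRENT_JOBS and MAX_TRAINING_TIME_MINUTES settings documented below; run_training_job and mark_job_failed are hypothetical stand-ins for the real worker and status-update functions:

import asyncio

MAX_CONCURRENT_JOBS = 3          # mirrors the MAX_CONCURRENT_JOBS env var
MAX_TRAINING_TIME_MINUTES = 30   # mirrors the MAX_TRAINING_TIME_MINUTES env var

job_semaphore = asyncio.Semaphore(MAX_CONCURRENT_JOBS)

async def execute_job(job_id: str) -> None:
    """Run one queued job under the global concurrency limit, with a timeout."""
    async with job_semaphore:                  # at most 3 jobs train at once
        try:
            await asyncio.wait_for(
                run_training_job(job_id),      # hypothetical worker coroutine
                timeout=MAX_TRAINING_TIME_MINUTES * 60,
            )
        except asyncio.TimeoutError:
            # hypothetical status update; marks the queue row as failed
            await mark_job_failed(job_id, "Exceeded MAX_TRAINING_TIME_MINUTES")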

Model Performance Tracking

  • Accuracy Metrics - MAE, RMSE, R², MAPE for each trained model
  • Historical Comparison - Compare current vs. previous model performance
  • Per-Product Analytics - Track which products have the best forecast accuracy
  • Training Duration Tracking - Monitor training performance and optimization
  • Model Selection - Automatically deploy best-performing models
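
One plausible shape for the "deploy best-performing model" decision above, comparing the freshly trained model's accuracy against the currently deployed version. The helper names are illustrative, not the service's actual API; the is_deployed flag corresponds to the trained_models column described later:

async def maybe_deploy(new_model_id: str, product_id: str) -> bool:
    """Deploy the new model only if it beats the current one on accuracy."""
    new = await get_model_metrics(new_model_id)             # hypothetical lookup
    current = await get_deployed_model_metrics(product_id)  # hypothetical lookup

    # First model for this product, or an improvement: deploy it
    if current is None or new["accuracy"] >= current["accuracy"]:
        await deploy_model(new_model_id)   # flips is_deployed, sets deployed_at
        return True
    return False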

Real-Time Communication

  • WebSocket Live Updates - Real-time progress percentage and status messages
  • Training Logs - Detailed step-by-step execution logs
  • Completion Notifications - RabbitMQ events for training completion
  • Error Alerts - Immediate notification of training failures

Feature Engineering

  • Historical Data Aggregation - Collect sales data for model training
  • External Data Integration - Fetch weather, traffic, holiday data
  • POI Feature Integration - Merge location-based POI features into training data
  • Feature Extraction - Generate 30+ temporal, contextual, and location-based features
  • Data Validation - Ensure minimum data requirements before training
  • Outlier Detection - Filter anomalous data points
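
The outlier-detection step above might look like a simple IQR filter over the daily sales column; a sketch (the 1.5×IQR rule is a common default, not necessarily the service's exact method):

import pandas as pd

def filter_outliers(df: pd.DataFrame, column: str = "y") -> pd.DataFrame:
    """Drop rows whose sales fall outside 1.5x the interquartile range."""
    q1, q3 = df[column].quantile([0.25, 0.75])
    iqr = q3 - q1
    lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr
    return df[df[column].between(lower, upper)]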

Technical Capabilities

ML Training Pipeline

# Training workflow (data-fetching and persistence helpers are service-internal)
from prophet import Prophet

async def train_model_pipeline(tenant_id: str, product_id: str):
    """Complete ML training pipeline"""

    # Step 1: Data Collection
    sales_data = await fetch_historical_sales(tenant_id, product_id)
    if len(sales_data) < MIN_TRAINING_DATA_DAYS:
        raise InsufficientDataError(f"Need {MIN_TRAINING_DATA_DAYS}+ days of data")

    # Step 2: Feature Engineering
    features = engineer_features(sales_data)
    weather_data = await fetch_weather_data(tenant_id)
    traffic_data = await fetch_traffic_data(tenant_id)
    holiday_data = await fetch_holiday_calendar()
    poi_features = await fetch_poi_features(tenant_id)  # Location context

    # Merge external and POI features into the training dataframe
    # (merge_external_features is an illustrative helper for the regressor merge)
    features = merge_external_features(features, weather_data, traffic_data, holiday_data)
    features = merge_poi_features(features, poi_features)

    # Step 3: Prophet Model Training
    model = Prophet(
        seasonality_mode='additive',
        daily_seasonality=True,
        weekly_seasonality=True,
        yearly_seasonality=True,
    )
    model.add_country_holidays(country_name='ES')
    model.fit(features)

    # Step 4: Model Validation
    metrics = calculate_performance_metrics(model, sales_data)

    # Step 5: Model Storage
    model_path = save_model_artifact(model, tenant_id, product_id)

    # Step 6: Model Registration
    await register_model_in_database(model_path, metrics)

    # Step 7: Notification
    await publish_training_complete_event(tenant_id, product_id, metrics)

    return model, metrics

WebSocket Progress Updates

# Real-time progress broadcasting
from datetime import datetime

async def broadcast_training_progress(job_id: str, progress: dict):
    """Send progress update to connected clients"""

    message = {
        "type": "training_progress",
        "job_id": job_id,
        "progress": {
            "percentage": progress["percentage"],          # 0-100
            "current_step": progress["step"],              # Step description
            "products_completed": progress["completed"],
            "products_total": progress["total"],
            "estimated_time_remaining": progress["eta"],   # Seconds
            "started_at": progress["start_time"]
        },
        "timestamp": datetime.utcnow().isoformat()
    }

    await websocket_manager.broadcast(job_id, message)

Model Artifact Management

# Model storage and retrieval
from datetime import datetime
from pathlib import Path

import joblib
from prophet import Prophet

# Save trained model
def save_model_artifact(model: Prophet, tenant_id: str, product_id: str) -> str:
    """Serialize and store model"""
    model_dir = Path(f"/models/{tenant_id}/{product_id}")
    model_dir.mkdir(parents=True, exist_ok=True)

    version = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    model_path = model_dir / f"model_v{version}.pkl"

    joblib.dump(model, model_path)
    return str(model_path)

# Load trained model
def load_model_artifact(model_path: str) -> Prophet:
    """Load serialized model"""
    return joblib.load(model_path)
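
When ENABLE_MODEL_COMPRESSION (documented under Configuration) is on, the same joblib call can write a compressed artifact. A sketch, assuming the flag maps to a joblib compress level; level 3 is an illustrative choice:

# Compressed variant of save_model_artifact
def save_model_artifact_compressed(model: Prophet, model_path: str) -> None:
    """Write a compressed artifact; higher compress = smaller file, slower dump."""
    joblib.dump(model, model_path, compress=3)  # accepts levels 0-9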

Performance Metrics Calculation

import numpy as np
import pandas as pd
from prophet import Prophet
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

def calculate_performance_metrics(model: Prophet, actual_data: pd.DataFrame) -> dict:
    """Calculate comprehensive model performance metrics"""

    # Make predictions on validation set
    predictions = model.predict(actual_data)
    y_true = actual_data['y'].to_numpy()
    y_pred = predictions['yhat'].to_numpy()

    # Calculate metrics
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)

    # MAPE is undefined where actual sales are zero, so mask those days out
    nonzero = y_true != 0
    mape = np.mean(np.abs((y_true[nonzero] - y_pred[nonzero]) / y_true[nonzero])) * 100

    return {
        "mae": float(mae),           # Mean Absolute Error
        "rmse": float(rmse),         # Root Mean Square Error
        "r2_score": float(r2),       # R-squared
        "mape": float(mape),         # Mean Absolute Percentage Error
        "accuracy": float(100 - mape) if mape < 100 else 0.0
    }

Business Value

For Bakery Owners

  • Continuous Improvement - Models automatically improve with more data
  • No ML Expertise Required - One-click training, no data science skills needed
  • Always Up-to-Date - Weekly automatic retraining keeps models accurate
  • Transparent Performance - Clear accuracy metrics show forecast reliability
  • Cost Savings - Automated ML pipeline eliminates need for data scientists

For Operations Managers

  • Model Version Control - Track and compare model versions over time
  • Performance Monitoring - Identify products with poor forecast accuracy
  • Training Scheduling - Schedule retraining during low-traffic hours
  • Resource Management - Control concurrent training jobs to prevent overload

For Platform Operations

  • Scalable ML Pipeline - Train models for thousands of products
  • Background Processing - Non-blocking training jobs
  • Error Handling - Robust error recovery and retry mechanisms
  • Cost Optimization - Efficient model storage and caching

Technology Stack

  • Framework: FastAPI (Python 3.11+) - Async web framework with WebSocket support
  • Database: PostgreSQL 17 - Training logs, model metadata, job queue
  • ML Library: Prophet (the prophet package, formerly fbprophet) - Time series forecasting
  • Model Storage: Joblib - Model serialization
  • File System: Persistent volumes - Model artifact storage
  • WebSocket: FastAPI WebSocket - Real-time progress updates
  • Messaging: RabbitMQ 4.1 - Training completion events
  • ORM: SQLAlchemy 2.0 (async) - Database abstraction
  • Data Processing: Pandas, NumPy - Data manipulation
  • Logging: Structlog - Structured JSON logging
  • Metrics: Prometheus Client - Custom metrics

API Endpoints (Key Routes)

Training Management

  • POST /api/v1/training/start - Start training job for tenant
  • POST /api/v1/training/start/{product_id} - Train specific product
  • POST /api/v1/training/stop/{job_id} - Stop running training job
  • GET /api/v1/training/status/{job_id} - Get job status and progress
  • GET /api/v1/training/history - Get training job history
  • DELETE /api/v1/training/jobs/{job_id} - Delete training job record

Model Management

  • GET /api/v1/training/models - List all trained models
  • GET /api/v1/training/models/{model_id} - Get specific model details
  • GET /api/v1/training/models/{model_id}/metrics - Get model performance metrics
  • GET /api/v1/training/models/latest/{product_id} - Get latest model for product
  • POST /api/v1/training/models/{model_id}/deploy - Deploy specific model version
  • DELETE /api/v1/training/models/{model_id} - Delete model artifact

WebSocket

  • WS /api/v1/training/ws/{job_id} - Connect to training progress stream

Analytics

  • GET /api/v1/training/analytics/performance - Overall training performance
  • GET /api/v1/training/analytics/accuracy - Model accuracy distribution
  • GET /api/v1/training/analytics/duration - Training duration statistics

Database Schema

Main Tables

training_job_queue

CREATE TABLE training_job_queue (
    id UUID PRIMARY KEY,
    tenant_id UUID NOT NULL,
    job_name VARCHAR(255),
    products_to_train TEXT[],          -- Array of product IDs
    status VARCHAR(50) NOT NULL,        -- pending, running, completed, failed
    priority INTEGER DEFAULT 0,
    progress_percentage INTEGER DEFAULT 0,
    current_step VARCHAR(255),
    products_completed INTEGER DEFAULT 0,
    products_total INTEGER,
    started_at TIMESTAMP,
    completed_at TIMESTAMP,
    estimated_completion TIMESTAMP,
    error_message TEXT,
    retry_count INTEGER DEFAULT 0,
    created_by UUID,
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW()
);

trained_models

CREATE TABLE trained_models (
    id UUID PRIMARY KEY,
    tenant_id UUID NOT NULL,
    product_id UUID NOT NULL,
    model_version VARCHAR(50) NOT NULL,
    model_path VARCHAR(500) NOT NULL,
    training_job_id UUID REFERENCES training_job_queue(id),
    algorithm VARCHAR(50) DEFAULT 'prophet',
    hyperparameters JSONB,
    training_duration_seconds INTEGER,
    training_data_points INTEGER,
    is_deployed BOOLEAN DEFAULT FALSE,
    deployed_at TIMESTAMP,
    created_at TIMESTAMP DEFAULT NOW(),
    UNIQUE(tenant_id, product_id, model_version)
);

model_performance_metrics

CREATE TABLE model_performance_metrics (
    id UUID PRIMARY KEY,
    model_id UUID REFERENCES trained_models(id),
    tenant_id UUID NOT NULL,
    product_id UUID NOT NULL,
    mae DECIMAL(10, 4),                -- Mean Absolute Error
    rmse DECIMAL(10, 4),               -- Root Mean Square Error
    r2_score DECIMAL(10, 6),           -- R-squared
    mape DECIMAL(10, 4),               -- Mean Absolute Percentage Error
    accuracy_percentage DECIMAL(5, 2),
    validation_data_points INTEGER,
    created_at TIMESTAMP DEFAULT NOW()
);

model_training_logs

CREATE TABLE model_training_logs (
    id UUID PRIMARY KEY,
    training_job_id UUID REFERENCES training_job_queue(id),
    tenant_id UUID NOT NULL,
    product_id UUID,
    log_level VARCHAR(20),              -- DEBUG, INFO, WARNING, ERROR
    message TEXT,
    step_name VARCHAR(100),
    execution_time_ms INTEGER,
    metadata JSONB,
    created_at TIMESTAMP DEFAULT NOW()
);

model_artifacts (Metadata only, actual files on disk)

CREATE TABLE model_artifacts (
    id UUID PRIMARY KEY,
    model_id UUID REFERENCES trained_models(id),
    artifact_type VARCHAR(50),          -- model_file, feature_list, scaler, etc.
    file_path VARCHAR(500),
    file_size_bytes BIGINT,
    checksum VARCHAR(64),               -- SHA-256 hash
    created_at TIMESTAMP DEFAULT NOW()
);
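
The checksum column above can be computed at artifact-save time. A minimal sketch with hashlib, reading the file in chunks to bound memory use:

import hashlib
from pathlib import Path

def compute_artifact_checksum(file_path: str) -> str:
    """SHA-256 hash of a model artifact file, for the model_artifacts.checksum column."""
    digest = hashlib.sha256()
    with Path(file_path).open("rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            digest.update(chunk)
    return digest.hexdigest()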

Events & Messaging

Published Events (RabbitMQ)

Exchange: training
Routing Key: training.completed

Training Completed Event

{
    "event_type": "training_completed",
    "tenant_id": "uuid",
    "job_id": "uuid",
    "job_name": "Weekly retraining - All products",
    "status": "completed",
    "results": {
        "successful_trainings": 25,
        "failed_trainings": 2,
        "total_products": 27,
        "models_created": [
            {
                "product_id": "uuid",
                "product_name": "Baguette",
                "model_version": "20251106_143022",
                "accuracy": 82.5,
                "mae": 12.3,
                "rmse": 18.7,
                "r2_score": 0.78
            }
        ],
        "average_accuracy": 79.8,
        "training_duration_seconds": 342
    },
    "started_at": "2025-11-06T14:25:00Z",
    "completed_at": "2025-11-06T14:30:42Z",
    "timestamp": "2025-11-06T14:30:42Z"
}

Training Failed Event

{
    "event_type": "training_failed",
    "tenant_id": "uuid",
    "job_id": "uuid",
    "product_id": "uuid",
    "product_name": "Croissant",
    "error_type": "InsufficientDataError",
    "error_message": "Product requires minimum 30 days of sales data. Currently: 15 days.",
    "recommended_action": "Collect more sales data before retraining",
    "severity": "medium",
    "timestamp": "2025-11-06T14:28:15Z"
}
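
A minimal publishing sketch for these events, assuming aio_pika as the RabbitMQ client (the README does not name the library, so treat this as illustrative):

import json

import aio_pika

async def publish_training_event(connection: aio_pika.RobustConnection, event: dict) -> None:
    """Publish a training event to the 'training' topic exchange."""
    channel = await connection.channel()
    exchange = await channel.declare_exchange(
        "training", aio_pika.ExchangeType.TOPIC, durable=True
    )
    await exchange.publish(
        aio_pika.Message(
            body=json.dumps(event).encode(),
            content_type="application/json",
        ),
        routing_key="training.completed",
    )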

Consumed Events

  • From Orchestrator: Scheduled training triggers
  • From Sales: New sales data imported (triggers retraining)

Custom Metrics (Prometheus)

# Training job metrics
from prometheus_client import Counter, Gauge, Histogram

training_jobs_total = Counter(
    'training_jobs_total',
    'Total training jobs started',
    ['tenant_id', 'status']  # completed, failed, cancelled
)

training_duration_seconds = Histogram(
    'training_duration_seconds',
    'Training job duration',
    ['tenant_id'],
    buckets=[10, 30, 60, 120, 300, 600, 1800, 3600]  # seconds
)

models_trained_total = Counter(
    'models_trained_total',
    'Total models successfully trained',
    ['tenant_id', 'product_category']
)

# Model performance metrics
model_accuracy_distribution = Histogram(
    'model_accuracy_percentage',
    'Distribution of model accuracy scores',
    ['tenant_id'],
    buckets=[50, 60, 70, 75, 80, 85, 90, 95, 100]  # percentage
)

model_mae_distribution = Histogram(
    'model_mae',
    'Distribution of Mean Absolute Error',
    ['tenant_id'],
    buckets=[1, 5, 10, 20, 30, 50, 100]  # units
)

# WebSocket metrics
websocket_connections_total = Gauge(
    'training_websocket_connections',
    'Active WebSocket connections',
    ['tenant_id']
)

websocket_messages_sent = Counter(
    'training_websocket_messages_total',
    'Total WebSocket messages sent',
    ['tenant_id', 'message_type']
)
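
How these metrics might be used around a training run; a sketch (the label values and run_training entry point are illustrative):

# Count a completed run and the model it produced
training_jobs_total.labels(tenant_id="tenant-uuid", status="completed").inc()
models_trained_total.labels(tenant_id="tenant-uuid", product_category="bread").inc()

# Time a job with the Histogram's built-in context manager
with training_duration_seconds.labels(tenant_id="tenant-uuid").time():
    run_training()  # hypothetical synchronous entry point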

Configuration

Environment Variables

Service Configuration:

  • PORT - Service port (default: 8004)
  • DATABASE_URL - PostgreSQL connection string
  • RABBITMQ_URL - RabbitMQ connection string
  • MODEL_STORAGE_PATH - Path for model artifacts (default: /models)

Training Configuration:

  • MAX_CONCURRENT_JOBS - Maximum parallel training jobs (default: 3)
  • MAX_TRAINING_TIME_MINUTES - Job timeout (default: 30)
  • MIN_TRAINING_DATA_DAYS - Minimum history required (default: 30)
  • ENABLE_AUTO_DEPLOYMENT - Auto-deploy after training (default: true)

Prophet Configuration:

  • PROPHET_DAILY_SEASONALITY - Enable daily patterns (default: true)
  • PROPHET_WEEKLY_SEASONALITY - Enable weekly patterns (default: true)
  • PROPHET_YEARLY_SEASONALITY - Enable yearly patterns (default: true)
  • PROPHET_INTERVAL_WIDTH - Confidence interval (default: 0.95)
  • PROPHET_CHANGEPOINT_PRIOR_SCALE - Trend flexibility (default: 0.05)

WebSocket Configuration:

  • WEBSOCKET_HEARTBEAT_INTERVAL - Ping interval seconds (default: 30)
  • WEBSOCKET_MAX_CONNECTIONS - Max connections per tenant (default: 10)
  • WEBSOCKET_MESSAGE_QUEUE_SIZE - Message buffer size (default: 100)

Storage Configuration:

  • MODEL_RETENTION_DAYS - Days to keep old models (default: 90)
  • MAX_MODEL_VERSIONS_PER_PRODUCT - Version limit (default: 10)
  • ENABLE_MODEL_COMPRESSION - Compress model files (default: true)
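
These variables could be loaded into a typed settings object with pydantic-settings; a sketch (an assumption, since the README does not specify the config mechanism), shown with a subset of the variables and their documented defaults:

from pydantic_settings import BaseSettings

class TrainingSettings(BaseSettings):
    """Typed view of the environment variables above."""
    port: int = 8004
    database_url: str                      # required, no default
    rabbitmq_url: str                      # required, no default
    model_storage_path: str = "/models"
    max_concurrent_jobs: int = 3
    max_training_time_minutes: int = 30
    min_training_data_days: int = 30
    enable_auto_deployment: bool = True
    model_retention_days: int = 90

settings = TrainingSettings()  # reads PORT, DATABASE_URL, ... from the environment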

Development Setup

Prerequisites

  • Python 3.11+
  • PostgreSQL 17
  • RabbitMQ 4.1
  • Persistent storage for model artifacts

Local Development

# Create virtual environment
cd services/training
python -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements.txt

# Set environment variables
export DATABASE_URL=postgresql://user:pass@localhost:5432/training
export RABBITMQ_URL=amqp://guest:guest@localhost:5672/
export MODEL_STORAGE_PATH=/tmp/models

# Create model storage directory
mkdir -p /tmp/models

# Run database migrations
alembic upgrade head

# Run the service
python main.py

Testing

# Unit tests
pytest tests/unit/ -v

# Integration tests (requires services)
pytest tests/integration/ -v

# WebSocket tests
pytest tests/websocket/ -v

# Test with coverage
pytest --cov=app tests/ --cov-report=html

WebSocket Testing

# Test WebSocket connection
import asyncio
import json

import websockets

async def test_training_progress():
    uri = "ws://localhost:8004/api/v1/training/ws/job-id-here"
    async with websockets.connect(uri) as websocket:
        while True:
            message = await websocket.recv()
            data = json.loads(message)

            # Check the message type before reading progress fields:
            # completion messages may not carry a progress payload
            if data['type'] == 'training_completed':
                print("Training finished!")
                break

            print(f"Progress: {data['progress']['percentage']}%")
            print(f"Step: {data['progress']['current_step']}")

asyncio.run(test_training_progress())

POI Feature Integration

How POI Features Enhance Training

The Training Service integrates location-based POI features from the External Service to improve forecast accuracy:

POI Features Included:

  • school_density - Number of schools within 1km radius
  • office_density - Number of offices and business centers nearby
  • residential_density - Residential area proximity
  • transport_hub_proximity - Distance to metro, bus, train stations
  • commercial_zone_score - Commercial activity in the area
  • restaurant_density - Nearby restaurants and cafes
  • competitor_proximity - Distance to competing bakeries
  • And 11+ more location-based features

Integration Process:

  1. Fetch POI Context - Retrieve tenant's POI features from External Service (/poi-context/{tenant_id})
  2. Extract ML Features - Parse ml_features JSON object from POI context
  3. Merge with Training Data - Add POI features as additional columns in training dataframe
  4. Prophet Training - Include POI features as regressors in Prophet model
  5. Feature Importance - Track which POI features most impact predictions

Example POI Feature Integration:

from app.ml.poi_feature_integrator import POIFeatureIntegrator

# Initialize POI integrator
poi_integrator = POIFeatureIntegrator(external_service_url)

# Fetch and merge POI features
poi_features = await poi_integrator.fetch_poi_features(tenant_id)
training_df = poi_integrator.merge_poi_features(training_df, poi_features)

# POI features now available as columns:
# training_df['school_density'], training_df['office_density'], etc.

# Add POI features as Prophet regressors
# (regressors must be registered before fit, and their columns supplied at predict time)
for feature_name in poi_features:
    prophet_model.add_regressor(feature_name)

Endpoint Used:

  • Via shared client: /api/v1/tenants/{tenant_id}/external/poi-context (routed through API Gateway)
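
A minimal client-side sketch of that call, assuming httpx and an illustrative gateway base URL and bearer token:

import httpx

async def fetch_poi_context(gateway_url: str, tenant_id: str, token: str) -> dict:
    """Fetch the tenant's POI context through the API Gateway."""
    url = f"{gateway_url}/api/v1/tenants/{tenant_id}/external/poi-context"
    async with httpx.AsyncClient() as client:
        response = await client.get(url, headers={"Authorization": f"Bearer {token}"})
        response.raise_for_status()
        # the ml_features object described in the integration process above
        return response.json()["ml_features"]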

Integration Points

Dependencies (Services Called)

  • Sales Service - Fetch historical sales data for training
  • External Service - Fetch weather, traffic, holiday, and POI feature data
  • PostgreSQL - Store job queue, models, metrics, logs
  • RabbitMQ - Publish training completion events
  • File System - Store model artifacts

Dependents (Services That Call This)

  • Forecasting Service - Load trained models for predictions
  • Orchestrator Service - Trigger scheduled training jobs
  • Frontend Dashboard - Display training progress and model metrics
  • AI Insights Service - Analyze model performance patterns

Security Measures

Data Protection

  • Tenant Isolation - All training jobs scoped to tenant_id
  • Model Access Control - Only tenant can access their models
  • Input Validation - Validate all training parameters
  • Rate Limiting - Prevent training job spam

Model Security

  • Model Checksums - SHA-256 hash verification for artifacts
  • Version Control - Track all model versions with audit trail
  • Access Logging - Log all model access and deployment
  • Secure Storage - Model files stored with restricted permissions

WebSocket Security

  • JWT Authentication - Authenticate WebSocket connections
  • Connection Limits - Max connections per tenant
  • Message Validation - Validate all WebSocket messages
  • Heartbeat Monitoring - Detect and close stale connections

Performance Optimization

Training Performance

  1. Parallel Processing - Train multiple products concurrently
  2. Data Caching - Cache fetched external data across products
  3. Incremental Training - Only retrain changed products
  4. Resource Limits - CPU/memory limits per training job
  5. Priority Queue - Prioritize important products first

Storage Optimization

  1. Model Compression - Compress model artifacts (gzip)
  2. Old Model Cleanup - Automatic deletion after retention period
  3. Version Limits - Keep only N most recent versions
  4. Deduplication - Avoid storing identical models

WebSocket Optimization

  1. Message Batching - Batch progress updates (every 2 seconds)
  2. Connection Pooling - Reuse WebSocket connections
  3. Compression - Enable WebSocket message compression
  4. Heartbeat - Keep connections alive efficiently
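
A sketch of the 2-second message batching mentioned above: coalesce intermediate progress updates and broadcast only the newest one per interval (illustrative, reusing the websocket_manager from the broadcast example earlier):

import asyncio

async def batch_progress_updates(job_id: str, updates: asyncio.Queue, interval: float = 2.0) -> None:
    """Every `interval` seconds, broadcast only the newest queued update."""
    while True:
        await asyncio.sleep(interval)
        latest = None
        while not updates.empty():
            latest = updates.get_nowait()   # drop stale intermediate updates
        if latest is not None:
            await websocket_manager.broadcast(job_id, latest)
            if latest.get("type") == "training_completed":
                return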

Troubleshooting

Common Issues

Issue: Training jobs stuck in "pending" status

  • Cause: Max concurrent jobs reached or worker process crashed
  • Solution: Check MAX_CONCURRENT_JOBS setting, restart service

Issue: WebSocket connection drops during training

  • Cause: Network timeout or client disconnection
  • Solution: Implement auto-reconnect logic in client
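
A minimal auto-reconnect loop for the client side, using the websockets package as in the testing example above (handle_progress is a hypothetical message handler):

import asyncio
import websockets

async def connect_with_retry(uri: str, max_attempts: int = 5) -> None:
    """Reconnect with exponential backoff when the socket drops."""
    for attempt in range(max_attempts):
        try:
            async with websockets.connect(uri) as ws:
                async for message in ws:
                    handle_progress(message)  # hypothetical message handler
                return  # server closed the stream cleanly
        except (websockets.ConnectionClosed, OSError):
            await asyncio.sleep(2 ** attempt)  # back off before reconnecting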

Issue: "Insufficient data" errors for many products

  • Cause: Products need 30+ days of sales history
  • Solution: Import more historical sales data or reduce MIN_TRAINING_DATA_DAYS

Issue: Low model accuracy (<70%)

  • Cause: Insufficient data, outliers, or changing business patterns
  • Solution: Clean outliers, add more features, or manually adjust Prophet params

Debug Mode

# Enable detailed logging
export LOG_LEVEL=DEBUG
export PROPHET_VERBOSE=1

# Enable training profiling
export ENABLE_PROFILING=1

# Disable concurrent jobs for debugging
export MAX_CONCURRENT_JOBS=1

Competitive Advantages

  1. One-Click ML - No data science expertise required
  2. Real-Time Visibility - WebSocket progress updates unique in bakery software
  3. Continuous Learning - Automatic weekly retraining
  4. Version Control - Track and compare all model versions
  5. Production-Ready - Robust error handling and retry mechanisms
  6. Scalable - Train models for thousands of products
  7. Spanish Market - Optimized for Spanish bakery patterns and holidays

Future Enhancements

  • Hyperparameter Tuning - Automatic optimization of Prophet parameters
  • A/B Testing - Deploy multiple models and compare performance
  • Distributed Training - Scale across multiple machines
  • GPU Acceleration - Use GPUs for deep learning models
  • AutoML - Automatic algorithm selection (Prophet vs LSTM vs ARIMA)
  • Model Explainability - SHAP values to explain predictions
  • Custom Algorithms - Support for user-provided ML models
  • Transfer Learning - Use pre-trained models from similar bakeries

For VUE Madrid Business Plan: The Training Service demonstrates advanced ML engineering capabilities with automated pipeline management and real-time monitoring. The ability to continuously improve forecast accuracy without manual intervention represents significant operational efficiency and competitive advantage. This self-learning system is a key differentiator in the bakery software market and showcases technical innovation suitable for EU technology grants and investor presentations.