# Training Service (ML Model Management)

## Overview

The **Training Service** is the machine learning pipeline engine of Bakery-IA, responsible for training, versioning, and managing Prophet forecasting models. It orchestrates the entire ML workflow from data collection to model deployment, providing real-time progress updates via WebSocket and ensuring bakeries always have the most accurate prediction models. This service enables continuous learning and model improvement without requiring data science expertise.

## Key Features

### Automated ML Pipeline
- **One-Click Model Training** - Train models for all products with a single API call
- **Background Job Processing** - Asynchronous training with job queue management
- **Multi-Product Training** - Process multiple products in parallel
- **Progress Tracking** - Real-time WebSocket updates on training status
- **Automatic Model Versioning** - Track all model versions with performance metrics
- **Model Artifact Storage** - Persist trained models for fast prediction loading

### Training Job Management
- **Job Queue** - FIFO queue for training requests
- **Job Status Tracking** - Monitor pending, running, completed, and failed jobs
- **Concurrent Job Control** - Limit parallel training jobs to prevent resource exhaustion
- **Timeout Handling** - Automatic job termination after maximum duration
- **Error Recovery** - Detailed error messages and retry capabilities
- **Job History** - Complete audit trail of all training executions
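
Concurrent job control and timeout handling can be combined in a few lines of `asyncio`. The sketch below is illustrative only: `JobResult`, `run_job_with_limits`, and `run_all` are hypothetical names, not the service's actual implementation. The defaults mirror the documented `MAX_CONCURRENT_JOBS` (3) and `MAX_TRAINING_TIME_MINUTES` (30) settings.

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, List, Optional


@dataclass
class JobResult:
    job_id: str
    status: str  # "completed" or "failed" (timeouts are reported as failed)
    error: Optional[str] = None


async def run_job_with_limits(
    job_id: str,
    train: Callable[[], Awaitable[None]],
    semaphore: asyncio.Semaphore,
    timeout_seconds: float,
) -> JobResult:
    """Run one training coroutine under a concurrency cap and a deadline."""
    async with semaphore:  # waits here while the cap is reached
        try:
            await asyncio.wait_for(train(), timeout=timeout_seconds)
            return JobResult(job_id, "completed")
        except asyncio.TimeoutError:
            return JobResult(job_id, "failed", f"timed out after {timeout_seconds}s")
        except Exception as exc:  # report the error instead of crashing the worker
            return JobResult(job_id, "failed", str(exc))


async def run_all(
    jobs: Dict[str, Callable[[], Awaitable[None]]],
    max_concurrent: int = 3,
    timeout_seconds: float = 30 * 60,
) -> List[JobResult]:
    """jobs maps job_id -> zero-argument coroutine function (hypothetical shape)."""
    semaphore = asyncio.Semaphore(max_concurrent)
    tasks = [
        run_job_with_limits(job_id, train, semaphore, timeout_seconds)
        for job_id, train in jobs.items()
    ]
    return await asyncio.gather(*tasks)
```

Because each job returns a `JobResult` rather than raising, one failing product does not abort the rest of the batch.
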

### Model Performance Tracking
- **Accuracy Metrics** - MAE, RMSE, R², MAPE for each trained model
- **Historical Comparison** - Compare current vs. previous model performance
- **Per-Product Analytics** - Track which products have the best forecast accuracy
- **Training Duration Tracking** - Monitor training performance and optimization
- **Model Selection** - Automatically deploy best-performing models
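
One way to read "automatically deploy best-performing models" is as a comparison between a newly trained candidate and the currently deployed model. The sketch below is an assumption about how such a rule could look, not the service's actual policy: `should_deploy`, the choice of MAE as the deciding metric, and the `min_improvement` margin are all hypothetical.

```python
from typing import Optional


def should_deploy(
    candidate: dict,
    deployed: Optional[dict],
    min_improvement: float = 0.0,
) -> bool:
    """Return True if the candidate model should replace the deployed one.

    Both arguments are metric dicts containing at least an "mae" key,
    e.g. {"mae": 12.3, "rmse": 18.7}. Lower MAE is treated as better
    (an assumption made for this sketch).
    """
    if deployed is None:
        return True  # nothing deployed yet, so any trained model is an improvement
    # Require the candidate to beat the deployed MAE by a relative margin
    threshold = deployed["mae"] * (1.0 - min_improvement)
    return candidate["mae"] < threshold
```

A non-zero `min_improvement` (for example `0.05`) would keep the deployed model in place unless the candidate is clearly better, which avoids churn from noise between training runs.
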

### Real-Time Communication
- **WebSocket Live Updates** - Real-time progress percentage and status messages
- **Training Logs** - Detailed step-by-step execution logs
- **Completion Notifications** - RabbitMQ events for training completion
- **Error Alerts** - Immediate notification of training failures

### Feature Engineering
- **Historical Data Aggregation** - Collect sales data for model training
- **External Data Integration** - Fetch weather, traffic, holiday data
- **POI Feature Integration** - Merge location-based POI features into training data
- **Feature Extraction** - Generate 30+ temporal, contextual, and location-based features
- **Data Validation** - Ensure minimum data requirements before training
- **Outlier Detection** - Filter anomalous data points
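
The data validation and outlier detection steps can be sketched with the standard library alone. This document does not say which outlier rule the service uses; the interquartile-range (IQR) rule below is a common default chosen purely for illustration. `filter_outliers_iqr`, `validate_and_filter`, and the multiplier `k=1.5` are hypothetical; `min_days=30` mirrors the documented `MIN_TRAINING_DATA_DAYS` default.

```python
import statistics
from typing import List, Tuple


class InsufficientDataError(ValueError):
    """Raised when there is too little history to train on."""


def filter_outliers_iqr(values: List[float], k: float = 1.5) -> Tuple[List[float], List[float]]:
    """Split values into (kept, dropped) using the IQR rule."""
    q1, _, q3 = statistics.quantiles(values, n=4)  # quartile cut points
    iqr = q3 - q1
    low, high = q1 - k * iqr, q3 + k * iqr
    kept = [v for v in values if low <= v <= high]
    dropped = [v for v in values if v < low or v > high]
    return kept, dropped


def validate_and_filter(daily_sales: List[float], min_days: int = 30) -> Tuple[List[float], List[float]]:
    """Check the minimum-history requirement, then filter outliers."""
    if len(daily_sales) < min_days:
        raise InsufficientDataError(
            f"Need {min_days}+ days of data, got {len(daily_sales)}"
        )
    return filter_outliers_iqr(daily_sales)
```

Returning the dropped points alongside the kept ones makes it easy to log what was filtered, which helps when a legitimate spike (a holiday rush, say) is being discarded by mistake.
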

## Technical Capabilities

### ML Training Pipeline

```python
# Training workflow
from prophet import Prophet

async def train_model_pipeline(tenant_id: str, product_id: str):
    """Complete ML training pipeline"""

    # Step 1: Data Collection
    sales_data = await fetch_historical_sales(tenant_id, product_id)
    if len(sales_data) < MIN_TRAINING_DAYS:
        raise InsufficientDataError(f"Need {MIN_TRAINING_DAYS}+ days of data")

    # Step 2: Feature Engineering
    # (Merging the weather, traffic, and holiday data into `features` is not shown here.)
    features = engineer_features(sales_data)
    weather_data = await fetch_weather_data(tenant_id)
    traffic_data = await fetch_traffic_data(tenant_id)
    holiday_data = await fetch_holiday_calendar()
    poi_features = await fetch_poi_features(tenant_id)  # NEW: Location context

    # Merge POI features into training dataframe
    features = merge_poi_features(features, poi_features)

    # Step 3: Prophet Model Training
    model = Prophet(
        seasonality_mode='additive',
        daily_seasonality=True,
        weekly_seasonality=True,
        yearly_seasonality=True,
    )
    model.add_country_holidays(country_name='ES')
    # Extra columns only influence the fit if they were registered
    # with model.add_regressor() before this call.
    model.fit(features)

    # Step 4: Model Validation
    metrics = calculate_performance_metrics(model, sales_data)

    # Step 5: Model Storage
    model_path = save_model_artifact(model, tenant_id, product_id)

    # Step 6: Model Registration
    await register_model_in_database(model_path, metrics)

    # Step 7: Notification
    await publish_training_complete_event(tenant_id, product_id, metrics)

    return model, metrics
```

### WebSocket Progress Updates

```python
# Real-time progress broadcasting
from datetime import datetime, timezone

async def broadcast_training_progress(job_id: str, progress: dict):
    """Send progress update to connected clients"""

    message = {
        "type": "training_progress",
        "job_id": job_id,
        "progress": {
            "percentage": progress["percentage"],  # 0-100
            "current_step": progress["step"],  # Step description
            "products_completed": progress["completed"],
            "products_total": progress["total"],
            "estimated_time_remaining": progress["eta"],  # Seconds
            "started_at": progress["start_time"]
        },
        # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated since Python 3.12)
        "timestamp": datetime.now(timezone.utc).isoformat()
    }

    await websocket_manager.broadcast(job_id, message)
```
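
The `websocket_manager` used above is not shown in this document. A minimal in-memory manager might look like the following sketch. It assumes each connection object exposes an async `send_json` method, as FastAPI's `WebSocket` does; the class name `ConnectionManager` and its methods are hypothetical.

```python
from collections import defaultdict
from typing import Any, Dict, Set


class ConnectionManager:
    """Tracks connections per job and fans messages out to them."""

    def __init__(self) -> None:
        self._connections: Dict[str, Set[Any]] = defaultdict(set)

    def connect(self, job_id: str, websocket: Any) -> None:
        self._connections[job_id].add(websocket)

    def disconnect(self, job_id: str, websocket: Any) -> None:
        self._connections[job_id].discard(websocket)
        if not self._connections[job_id]:
            del self._connections[job_id]  # forget jobs nobody is watching

    async def broadcast(self, job_id: str, message: dict) -> int:
        """Send message to every client watching job_id; return how many received it."""
        delivered = 0
        # Iterate over a copy because disconnect() mutates the set
        for websocket in list(self._connections.get(job_id, ())):
            try:
                await websocket.send_json(message)
                delivered += 1
            except Exception:
                # A failed send usually means the client went away; drop it
                self.disconnect(job_id, websocket)
        return delivered
```

An in-memory registry like this only reaches clients connected to the same process. With several service replicas, progress messages would need to travel through a shared channel first; that part is outside this sketch.
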

### Model Artifact Management (MinIO Storage)

```python
# Model storage and retrieval using MinIO
import io
import uuid
from datetime import datetime, timezone

import joblib
from prophet import Prophet
from shared.clients.minio_client import minio_client

# Save trained model to MinIO
def save_model_artifact(model: Prophet, tenant_id: str, product_id: str) -> str:
    """Serialize and store model in MinIO"""
    # Timestamp-style version label; note it is not part of the object key below
    version = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    model_id = str(uuid.uuid4())
    object_name = f"models/{tenant_id}/{product_id}/{model_id}.pkl"

    # Serialize model (joblib.dump writes to file-like objects)
    buffer = io.BytesIO()
    joblib.dump(model, buffer)
    model_data = buffer.getvalue()

    # Upload to MinIO
    minio_client.put_object(
        bucket_name="training-models",
        object_name=object_name,
        data=model_data,
        content_type="application/octet-stream"
    )

    # Return MinIO path
    return f"minio://training-models/{object_name}"

# Load trained model from MinIO
def load_model_artifact(model_path: str) -> Prophet:
    """Load serialized model from MinIO"""
    # Parse MinIO path: minio://bucket_name/object_path
    _, bucket_and_path = model_path.split("://", 1)
    bucket_name, object_name = bucket_and_path.split("/", 1)

    # Download from MinIO (the shared client is expected to return raw bytes here)
    model_data = minio_client.get_object(bucket_name, object_name)

    # Deserialize (joblib.load reads from file-like objects).
    # joblib uses pickle underneath, so only load artifacts from a trusted bucket.
    buffer = io.BytesIO(model_data)
    return joblib.load(buffer)
```

### Performance Metrics Calculation

```python
import numpy as np
import pandas as pd
from prophet import Prophet
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

def calculate_performance_metrics(model: Prophet, actual_data: pd.DataFrame) -> dict:
    """Calculate comprehensive model performance metrics"""

    # Make predictions on validation set.
    # Sort by date first so the rows line up with the forecast output.
    actual_data = actual_data.sort_values('ds')
    predictions = model.predict(actual_data)

    # Compare plain arrays so differing DataFrame indexes cannot misalign rows
    y_true = actual_data['y'].to_numpy(dtype=float)
    y_pred = predictions['yhat'].to_numpy(dtype=float)

    # Calculate metrics
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)

    # MAPE is undefined where the actual value is 0 (e.g. days with no sales),
    # so those rows are skipped instead of dividing by zero
    nonzero = y_true != 0
    mape = (
        np.mean(np.abs((y_true[nonzero] - y_pred[nonzero]) / y_true[nonzero])) * 100
        if nonzero.any() else float('nan')
    )

    return {
        "mae": float(mae),      # Mean Absolute Error
        "rmse": float(rmse),    # Root Mean Square Error
        "r2_score": float(r2),  # R-squared
        "mape": float(mape),    # Mean Absolute Percentage Error
        "accuracy": float(100 - mape) if mape < 100 else 0.0
    }
```
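
These metrics are only informative if the rows passed as `actual_data` were held out from training. For time series the split should be chronological rather than random. The sketch below shows one way to do that; `chronological_split` and the 80/20 default are illustrative assumptions, since this document does not specify how the service splits its data.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def chronological_split(
    rows: Sequence[T], holdout_fraction: float = 0.2
) -> Tuple[List[T], List[T]]:
    """Split date-ordered rows into (train, validation) without shuffling.

    The most recent rows become the validation set, so the model is always
    scored on data that comes after everything it was trained on.
    """
    if not 0.0 < holdout_fraction < 1.0:
        raise ValueError("holdout_fraction must be between 0 and 1")
    n_holdout = max(1, int(len(rows) * holdout_fraction))
    if n_holdout >= len(rows):
        raise ValueError("not enough rows to split")
    split_at = len(rows) - n_holdout
    return list(rows[:split_at]), list(rows[split_at:])
```

With 30 days of history (the documented minimum) and the default fraction, this leaves 24 days for training and 6 for validation, which is a small sample; accuracy figures for products near the minimum should be read with that in mind.
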

## Business Value

### For Bakery Owners
- **Continuous Improvement** - Models automatically improve with more data
- **No ML Expertise Required** - One-click training, no data science skills needed
- **Always Up-to-Date** - Weekly automatic retraining keeps models accurate
- **Transparent Performance** - Clear accuracy metrics show forecast reliability
- **Cost Savings** - Automated ML pipeline eliminates need for data scientists

### For Operations Managers
- **Model Version Control** - Track and compare model versions over time
- **Performance Monitoring** - Identify products with poor forecast accuracy
- **Training Scheduling** - Schedule retraining during low-traffic hours
- **Resource Management** - Control concurrent training jobs to prevent overload

### For Platform Operations
- **Scalable ML Pipeline** - Train models for thousands of products
- **Background Processing** - Non-blocking training jobs
- **Error Handling** - Robust error recovery and retry mechanisms
- **Cost Optimization** - Efficient model storage and caching

## Technology Stack

- **Framework**: FastAPI (Python 3.11+) - Async web framework with WebSocket support
- **Database**: PostgreSQL 17 - Training logs, model metadata, job queue
- **ML Library**: Prophet (`prophet` package, formerly `fbprophet`) - Time series forecasting
- **Model Storage**: MinIO (S3-compatible) - Distributed object storage with TLS
- **Serialization**: Joblib - Model serialization
- **WebSocket**: FastAPI WebSocket - Real-time progress updates
- **Messaging**: RabbitMQ 4.1 - Training completion events
- **ORM**: SQLAlchemy 2.0 (async) - Database abstraction
- **Data Processing**: Pandas, NumPy - Data manipulation
- **Logging**: Structlog - Structured JSON logging
- **Metrics**: Prometheus Client - Custom metrics

## API Endpoints (Key Routes)

### Training Management
- `POST /api/v1/training/start` - Start training job for tenant
- `POST /api/v1/training/start/{product_id}` - Train specific product
- `POST /api/v1/training/stop/{job_id}` - Stop running training job
- `GET /api/v1/training/status/{job_id}` - Get job status and progress
- `GET /api/v1/training/history` - Get training job history
- `DELETE /api/v1/training/jobs/{job_id}` - Delete training job record

### Model Management
- `GET /api/v1/training/models` - List all trained models
- `GET /api/v1/training/models/{model_id}` - Get specific model details
- `GET /api/v1/training/models/{model_id}/metrics` - Get model performance metrics
- `GET /api/v1/training/models/latest/{product_id}` - Get latest model for product
- `POST /api/v1/training/models/{model_id}/deploy` - Deploy specific model version
- `DELETE /api/v1/training/models/{model_id}` - Delete model artifact

### WebSocket
- `WS /api/v1/training/ws/{job_id}` - Connect to training progress stream

### Analytics
- `GET /api/v1/training/analytics/performance` - Overall training performance
- `GET /api/v1/training/analytics/accuracy` - Model accuracy distribution
- `GET /api/v1/training/analytics/duration` - Training duration statistics

## Database Schema

### Main Tables

**training_job_queue**
```sql
CREATE TABLE training_job_queue (
    id UUID PRIMARY KEY,
    tenant_id UUID NOT NULL,
    job_name VARCHAR(255),
    products_to_train TEXT[],  -- Array of product IDs
    status VARCHAR(50) NOT NULL,  -- pending, running, completed, failed
    priority INTEGER DEFAULT 0,
    progress_percentage INTEGER DEFAULT 0,
    current_step VARCHAR(255),
    products_completed INTEGER DEFAULT 0,
    products_total INTEGER,
    started_at TIMESTAMP,
    completed_at TIMESTAMP,
    estimated_completion TIMESTAMP,
    error_message TEXT,
    retry_count INTEGER DEFAULT 0,
    created_by UUID,
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW()
);
```

**trained_models**
```sql
CREATE TABLE trained_models (
    id UUID PRIMARY KEY,
    tenant_id UUID NOT NULL,
    product_id UUID NOT NULL,
    model_version VARCHAR(50) NOT NULL,
    model_path VARCHAR(500) NOT NULL,
    training_job_id UUID REFERENCES training_job_queue(id),
    algorithm VARCHAR(50) DEFAULT 'prophet',
    hyperparameters JSONB,
    training_duration_seconds INTEGER,
    training_data_points INTEGER,
    is_deployed BOOLEAN DEFAULT FALSE,
    deployed_at TIMESTAMP,
    created_at TIMESTAMP DEFAULT NOW(),
    UNIQUE(tenant_id, product_id, model_version)
);
```

**model_performance_metrics**
```sql
CREATE TABLE model_performance_metrics (
    id UUID PRIMARY KEY,
    model_id UUID REFERENCES trained_models(id),
    tenant_id UUID NOT NULL,
    product_id UUID NOT NULL,
    mae DECIMAL(10, 4),  -- Mean Absolute Error
    rmse DECIMAL(10, 4),  -- Root Mean Square Error
    r2_score DECIMAL(10, 6),  -- R-squared
    mape DECIMAL(10, 4),  -- Mean Absolute Percentage Error
    accuracy_percentage DECIMAL(5, 2),
    validation_data_points INTEGER,
    created_at TIMESTAMP DEFAULT NOW()
);
```

**model_training_logs**
```sql
CREATE TABLE model_training_logs (
    id UUID PRIMARY KEY,
    training_job_id UUID REFERENCES training_job_queue(id),
    tenant_id UUID NOT NULL,
    product_id UUID,
    log_level VARCHAR(20),  -- DEBUG, INFO, WARNING, ERROR
    message TEXT,
    step_name VARCHAR(100),
    execution_time_ms INTEGER,
    metadata JSONB,
    created_at TIMESTAMP DEFAULT NOW()
);
```

**model_artifacts** (Metadata only; the artifact files themselves are stored in MinIO)
```sql
CREATE TABLE model_artifacts (
    id UUID PRIMARY KEY,
    model_id UUID REFERENCES trained_models(id),
    artifact_type VARCHAR(50),  -- model_file, feature_list, scaler, etc.
    file_path VARCHAR(500),
    file_size_bytes BIGINT,
    checksum VARCHAR(64),  -- SHA-256 hash
    created_at TIMESTAMP DEFAULT NOW()
);
```
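
The `checksum` column can be filled and verified with the standard library. This is a sketch that assumes the artifact is available as bytes (as `model_data` is in `save_model_artifact`); the function names `sha256_hex` and `verify_artifact` are hypothetical.

```python
import hashlib
import hmac


def sha256_hex(data: bytes) -> str:
    """Return the SHA-256 digest as 64 lowercase hex characters (fits VARCHAR(64))."""
    return hashlib.sha256(data).hexdigest()


def verify_artifact(data: bytes, expected_checksum: str) -> bool:
    """Compare the stored checksum with a freshly computed one."""
    # compare_digest does a constant-time comparison of the two strings
    return hmac.compare_digest(sha256_hex(data), expected_checksum.lower())
```

Verifying the checksum before handing the bytes to `joblib.load` is worthwhile, because deserializing a tampered pickle-based artifact can execute arbitrary code.
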

## Events & Messaging

### Published Events (RabbitMQ)

**Exchange**: `training`
**Routing Key**: `training.completed`

**Training Completed Event**
```json
{
  "event_type": "training_completed",
  "tenant_id": "uuid",
  "job_id": "uuid",
  "job_name": "Weekly retraining - All products",
  "status": "completed",
  "results": {
    "successful_trainings": 25,
    "failed_trainings": 2,
    "total_products": 27,
    "models_created": [
      {
        "product_id": "uuid",
        "product_name": "Baguette",
        "model_version": "20251106_143022",
        "accuracy": 82.5,
        "mae": 12.3,
        "rmse": 18.7,
        "r2_score": 0.78
      }
    ],
    "average_accuracy": 79.8,
    "training_duration_seconds": 342
  },
  "started_at": "2025-11-06T14:25:00Z",
  "completed_at": "2025-11-06T14:30:42Z",
  "timestamp": "2025-11-06T14:30:42Z"
}
```

**Training Failed Event**
```json
{
  "event_type": "training_failed",
  "tenant_id": "uuid",
  "job_id": "uuid",
  "product_id": "uuid",
  "product_name": "Croissant",
  "error_type": "InsufficientDataError",
  "error_message": "Product requires minimum 30 days of sales data. Currently: 15 days.",
  "recommended_action": "Collect more sales data before retraining",
  "severity": "medium",
  "timestamp": "2025-11-06T14:28:15Z"
}
```

### Consumed Events
- **From Orchestrator**: Scheduled training triggers
- **From Sales**: New sales data imported (triggers retraining)

## Custom Metrics (Prometheus)

```python
from prometheus_client import Counter, Gauge, Histogram

# Training job metrics
training_jobs_total = Counter(
    'training_jobs_total',
    'Total training jobs started',
    ['tenant_id', 'status']  # completed, failed, cancelled
)

training_duration_seconds = Histogram(
    'training_duration_seconds',
    'Training job duration',
    ['tenant_id'],
    buckets=[10, 30, 60, 120, 300, 600, 1800, 3600]  # seconds
)

models_trained_total = Counter(
    'models_trained_total',
    'Total models successfully trained',
    ['tenant_id', 'product_category']
)

# Model performance metrics
model_accuracy_distribution = Histogram(
    'model_accuracy_percentage',
    'Distribution of model accuracy scores',
    ['tenant_id'],
    buckets=[50, 60, 70, 75, 80, 85, 90, 95, 100]  # percentage
)

model_mae_distribution = Histogram(
    'model_mae',
    'Distribution of Mean Absolute Error',
    ['tenant_id'],
    buckets=[1, 5, 10, 20, 30, 50, 100]  # units
)

# WebSocket metrics
websocket_connections_total = Gauge(
    'training_websocket_connections',
    'Active WebSocket connections',
    ['tenant_id']
)

websocket_messages_sent = Counter(
    'training_websocket_messages_total',
    'Total WebSocket messages sent',
    ['tenant_id', 'message_type']
)
```

## Configuration

### Environment Variables

**Service Configuration:**
- `PORT` - Service port (default: 8004)
- `DATABASE_URL` - PostgreSQL connection string
- `RABBITMQ_URL` - RabbitMQ connection string

**MinIO Configuration:**
- `MINIO_ENDPOINT` - MinIO server endpoint (default: minio.bakery-ia.svc.cluster.local:9000)
- `MINIO_ACCESS_KEY` - MinIO access key
- `MINIO_SECRET_KEY` - MinIO secret key
- `MINIO_USE_SSL` - Enable TLS (default: true)
- `MINIO_MODEL_BUCKET` - Bucket for models (default: training-models)

**Training Configuration:**
- `MAX_CONCURRENT_JOBS` - Maximum parallel training jobs (default: 3)
- `MAX_TRAINING_TIME_MINUTES` - Job timeout (default: 30)
- `MIN_TRAINING_DATA_DAYS` - Minimum history required (default: 30)
- `ENABLE_AUTO_DEPLOYMENT` - Auto-deploy after training (default: true)

**Prophet Configuration:**
- `PROPHET_DAILY_SEASONALITY` - Enable daily patterns (default: true)
- `PROPHET_WEEKLY_SEASONALITY` - Enable weekly patterns (default: true)
- `PROPHET_YEARLY_SEASONALITY` - Enable yearly patterns (default: true)
- `PROPHET_INTERVAL_WIDTH` - Confidence interval (default: 0.95)
- `PROPHET_CHANGEPOINT_PRIOR_SCALE` - Trend flexibility (default: 0.05)

**WebSocket Configuration:**
- `WEBSOCKET_HEARTBEAT_INTERVAL` - Ping interval seconds (default: 30)
- `WEBSOCKET_MAX_CONNECTIONS` - Max connections per tenant (default: 10)
- `WEBSOCKET_MESSAGE_QUEUE_SIZE` - Message buffer size (default: 100)

**Storage Configuration (MinIO):**
- `MINIO_MODEL_LIFECYCLE_DAYS` - Days to keep old model versions (default: 90)
- `MINIO_CACHE_TTL_SECONDS` - Model cache TTL in seconds (default: 3600)
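
As an illustration of how the training variables and their defaults fit together, the sketch below loads them into a typed settings object. `TrainingSettings` and `_as_bool` are hypothetical names and the service may well use a different settings mechanism; only the variable names and default values come from the list above.

```python
import os
from dataclasses import dataclass
from typing import Mapping


def _as_bool(raw: str) -> bool:
    """Interpret common truthy spellings; anything else counts as False."""
    return raw.strip().lower() in {"1", "true", "yes", "on"}


@dataclass(frozen=True)
class TrainingSettings:
    max_concurrent_jobs: int = 3
    max_training_time_minutes: int = 30
    min_training_data_days: int = 30
    enable_auto_deployment: bool = True

    @classmethod
    def from_env(cls, env: Mapping[str, str] = os.environ) -> "TrainingSettings":
        """Build settings from environment variables, falling back to the documented defaults."""
        return cls(
            max_concurrent_jobs=int(env.get("MAX_CONCURRENT_JOBS", "3")),
            max_training_time_minutes=int(env.get("MAX_TRAINING_TIME_MINUTES", "30")),
            min_training_data_days=int(env.get("MIN_TRAINING_DATA_DAYS", "30")),
            enable_auto_deployment=_as_bool(env.get("ENABLE_AUTO_DEPLOYMENT", "true")),
        )
```

Passing the environment in as a mapping keeps the loader easy to test: `TrainingSettings.from_env({"MAX_CONCURRENT_JOBS": "1"})` exercises an override without touching the real process environment.
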

## Development Setup

### Prerequisites
- Python 3.11+
- PostgreSQL 17
- RabbitMQ 4.1
- MinIO (S3-compatible object storage)

### Local Development
```bash
# Create virtual environment
cd services/training
python -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements.txt

# Set environment variables
export DATABASE_URL=postgresql://user:pass@localhost:5432/training
export RABBITMQ_URL=amqp://guest:guest@localhost:5672/
export MINIO_ENDPOINT=localhost:9000
export MINIO_ACCESS_KEY=minioadmin
export MINIO_SECRET_KEY=minioadmin
export MINIO_USE_SSL=false  # Use true in production

# Start MinIO locally (if not using K8s)
docker run -p 9000:9000 -p 9001:9001 minio/minio server /data --console-address ":9001"

# Run database migrations
alembic upgrade head

# Run the service
python main.py
```

### Testing
```bash
# Unit tests
pytest tests/unit/ -v

# Integration tests (requires services)
pytest tests/integration/ -v

# WebSocket tests
pytest tests/websocket/ -v

# Test with coverage
pytest --cov=app tests/ --cov-report=html
```

### WebSocket Testing
```python
# Test WebSocket connection
import asyncio
import websockets
import json

async def test_training_progress():
    uri = "ws://localhost:8004/api/v1/training/ws/job-id-here"
    async with websockets.connect(uri) as websocket:
        while True:
            message = await websocket.recv()
            data = json.loads(message)

            # Check the message type first: a completion message
            # may not carry a "progress" payload
            if data['type'] == 'training_completed':
                print("Training finished!")
                break

            progress = data.get('progress', {})
            print(f"Progress: {progress.get('percentage')}%")
            print(f"Step: {progress.get('current_step')}")

asyncio.run(test_training_progress())
```

## POI Feature Integration

### How POI Features Enhance Training

The Training Service integrates location-based POI features from the External Service to improve forecast accuracy:

**POI Features Included:**
- `school_density` - Number of schools within a 1 km radius
- `office_density` - Number of offices and business centers nearby
- `residential_density` - Residential area proximity
- `transport_hub_proximity` - Distance to metro, bus, train stations
- `commercial_zone_score` - Commercial activity in the area
- `restaurant_density` - Nearby restaurants and cafes
- `competitor_proximity` - Distance to competing bakeries
- And 11+ more location-based features

**Integration Process:**
1. **Fetch POI Context** - Retrieve tenant's POI features from External Service (`/poi-context/{tenant_id}`)
2. **Extract ML Features** - Parse `ml_features` JSON object from POI context
3. **Merge with Training Data** - Add POI features as additional columns in training dataframe
4. **Prophet Training** - Include POI features as regressors in Prophet model
5. **Feature Importance** - Track which POI features most impact predictions

**Example POI Feature Integration:**
```python
from app.ml.poi_feature_integrator import POIFeatureIntegrator

# Initialize POI integrator
poi_integrator = POIFeatureIntegrator(external_service_url)

# Fetch and merge POI features
poi_features = await poi_integrator.fetch_poi_features(tenant_id)
training_df = poi_integrator.merge_poi_features(training_df, poi_features)

# POI features now available as columns:
# training_df['school_density'], training_df['office_density'], etc.

# Add POI features as Prophet regressors
# (regressors must be registered before prophet_model.fit() is called)
for feature_name in poi_features.keys():
    prophet_model.add_regressor(feature_name)
```

**Endpoint Used:**
- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through API Gateway)

## Integration Points

### Dependencies (Services Called)
- **Sales Service** - Fetch historical sales data for training
- **External Service** - Fetch weather, traffic, holiday, and POI feature data
- **PostgreSQL** - Store job queue, models, metrics, logs
- **RabbitMQ** - Publish training completion events
- **MinIO** - Store model artifacts (S3-compatible object storage with TLS)

### Dependents (Services That Call This)
- **Forecasting Service** - Load trained models for predictions
- **Orchestrator Service** - Trigger scheduled training jobs
- **Frontend Dashboard** - Display training progress and model metrics
- **AI Insights Service** - Analyze model performance patterns

## Security Measures

### Data Protection
- **Tenant Isolation** - All training jobs scoped to tenant_id
- **Model Access Control** - Only tenant can access their models
- **Input Validation** - Validate all training parameters
- **Rate Limiting** - Prevent training job spam

### Model Security
- **Model Checksums** - SHA-256 hash verification for artifacts
- **Version Control** - Track all model versions with audit trail
- **Access Logging** - Log all model access and deployment
- **Secure Storage** - Model files stored with restricted permissions

### WebSocket Security
- **JWT Authentication** - Authenticate WebSocket connections
- **Connection Limits** - Max connections per tenant
- **Message Validation** - Validate all WebSocket messages
- **Heartbeat Monitoring** - Detect and close stale connections

## Performance Optimization

### Training Performance
1. **Parallel Processing** - Train multiple products concurrently
2. **Data Caching** - Cache fetched external data across products
3. **Incremental Training** - Only retrain changed products
4. **Resource Limits** - CPU/memory limits per training job
5. **Priority Queue** - Prioritize important products first
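
The priority queue item can be illustrated with `heapq`. The sketch assumes that a larger `priority` value means "more urgent" (this document does not define the direction of the `priority` column) and that jobs with equal priority are served in arrival order, which matches the FIFO behaviour described under Training Job Management. `JobQueue` is a hypothetical in-memory stand-in, not the database-backed queue.

```python
import heapq
import itertools
from typing import List, Tuple


class JobQueue:
    """Pops the highest-priority job first; ties are served first-in, first-out."""

    def __init__(self) -> None:
        self._heap: List[Tuple[int, int, str]] = []
        self._counter = itertools.count()  # insertion order, used as tie-breaker

    def push(self, job_id: str, priority: int = 0) -> None:
        # heapq is a min-heap, so negate priority to pop larger values first
        heapq.heappush(self._heap, (-priority, next(self._counter), job_id))

    def pop(self) -> str:
        """Remove and return the next job id (raises IndexError when empty)."""
        return heapq.heappop(self._heap)[2]

    def __len__(self) -> int:
        return len(self._heap)
```

The insertion counter matters: without it, two jobs with the same priority would be ordered by comparing their ids, which would break arrival order.
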

### Storage Optimization (MinIO)
1. **Object Versioning** - MinIO maintains version history when bucket versioning is enabled
2. **Lifecycle Policies** - Auto-cleanup old versions after 90 days
3. **TLS Encryption** - Secure communication with MinIO
4. **Distributed Storage** - MinIO handles replication and availability

### WebSocket Optimization
1. **Message Batching** - Batch progress updates (every 2 seconds)
2. **Connection Pooling** - Reuse WebSocket connections
3. **Compression** - Enable WebSocket message compression
4. **Heartbeat** - Keep connections alive efficiently
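
Batching progress updates to one message every 2 seconds can be done with a small throttle that always keeps the most recent update. The sketch below is one possible shape, not the service's implementation; `ProgressThrottle` and its methods are hypothetical, and the clock is injectable so the behaviour can be tested without sleeping.

```python
import time
from typing import Callable, Optional


class ProgressThrottle:
    """Lets at most one update through per interval, keeping the latest unsent one."""

    def __init__(
        self,
        interval_seconds: float = 2.0,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._interval = interval_seconds
        self._clock = clock
        self._last_sent: Optional[float] = None
        self._pending: Optional[dict] = None

    def offer(self, update: dict) -> Optional[dict]:
        """Return the update if it should be sent now, otherwise hold it and return None."""
        now = self._clock()
        if self._last_sent is None or now - self._last_sent >= self._interval:
            self._last_sent = now
            self._pending = None
            return update
        self._pending = update  # newer updates overwrite older unsent ones
        return None

    def flush(self) -> Optional[dict]:
        """Return any held-back update, e.g. when the job finishes."""
        pending, self._pending = self._pending, None
        return pending
```

Calling `flush()` at the end of a job ensures the final progress state is delivered even if it arrived inside the 2-second window.
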

## Troubleshooting

### Common Issues

**Issue**: Training jobs stuck in "pending" status
- **Cause**: Max concurrent jobs reached or worker process crashed
- **Solution**: Check `MAX_CONCURRENT_JOBS` setting, restart service

**Issue**: WebSocket connection drops during training
- **Cause**: Network timeout or client disconnection
- **Solution**: Implement auto-reconnect logic in client
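
Auto-reconnect logic usually waits a little longer after each failed attempt. The generator below produces such a schedule; it is a generic exponential-backoff sketch, and the function name and default values are illustrative assumptions rather than settings defined by the service.

```python
import random
from typing import Iterator


def backoff_delays(
    base: float = 1.0,
    factor: float = 2.0,
    max_delay: float = 30.0,
    jitter: float = 0.0,
) -> Iterator[float]:
    """Yield an endless sequence of wait times: base, base*factor, ... capped at max_delay."""
    delay = base
    while True:
        # Optional jitter spreads out clients that disconnected at the same moment
        yield delay + random.uniform(0, jitter * delay)
        delay = min(delay * factor, max_delay)
```

A client would take the next delay after each failed connection attempt, sleep for that long, and start a fresh generator once a connection succeeds. After reconnecting, `GET /api/v1/training/status/{job_id}` can be used to catch up on progress that was missed while disconnected.
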

**Issue**: "Insufficient data" errors for many products
- **Cause**: Products need 30+ days of sales history
- **Solution**: Import more historical sales data or reduce `MIN_TRAINING_DATA_DAYS`

**Issue**: Low model accuracy (<70%)
- **Cause**: Insufficient data, outliers, or changing business patterns
- **Solution**: Clean outliers, add more features, or manually adjust Prophet params

### Debug Mode
```bash
# Enable detailed logging
export LOG_LEVEL=DEBUG
export PROPHET_VERBOSE=1

# Enable training profiling
export ENABLE_PROFILING=1

# Disable concurrent jobs for debugging
export MAX_CONCURRENT_JOBS=1
```

## Competitive Advantages

1. **One-Click ML** - No data science expertise required
2. **Real-Time Visibility** - WebSocket progress updates unique in bakery software
3. **Continuous Learning** - Automatic weekly retraining
4. **Version Control** - Track and compare all model versions
5. **Production-Ready** - Robust error handling and retry mechanisms
6. **Scalable** - Train models for thousands of products
7. **Spanish Market** - Optimized for Spanish bakery patterns and holidays

## Future Enhancements

- **Hyperparameter Tuning** - Automatic optimization of Prophet parameters
- **A/B Testing** - Deploy multiple models and compare performance
- **Distributed Training** - Scale across multiple machines
- **GPU Acceleration** - Use GPUs for deep learning models
- **AutoML** - Automatic algorithm selection (Prophet vs LSTM vs ARIMA)
- **Model Explainability** - SHAP values to explain predictions
- **Custom Algorithms** - Support for user-provided ML models
- **Transfer Learning** - Use pre-trained models from similar bakeries

---

**For VUE Madrid Business Plan**: The Training Service demonstrates advanced ML engineering capabilities with automated pipeline management and real-time monitoring. The ability to continuously improve forecast accuracy without manual intervention represents significant operational efficiency and competitive advantage. This self-learning system is a key differentiator in the bakery software market and showcases technical innovation suitable for EU technology grants and investor presentations.