Initial commit - production deployment
services/training/Dockerfile (new file, 70 lines)
@@ -0,0 +1,70 @@
# =============================================================================
# Training Service Dockerfile - Environment-Configurable Base Images
# =============================================================================
# Build arguments for registry configuration:
# - BASE_REGISTRY: Registry URL (default: docker.io for Docker Hub)
# - PYTHON_IMAGE: Python image name and tag (default: python:3.11-slim)
# =============================================================================

ARG BASE_REGISTRY=docker.io
ARG PYTHON_IMAGE=python:3.11-slim

FROM ${BASE_REGISTRY}/${PYTHON_IMAGE} AS shared
WORKDIR /shared
COPY shared/ /shared/

ARG BASE_REGISTRY=docker.io
ARG PYTHON_IMAGE=python:3.11-slim
FROM ${BASE_REGISTRY}/${PYTHON_IMAGE}

WORKDIR /app

# Install system dependencies including cmdstan requirements
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    make \
    curl \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements
COPY shared/requirements-tracing.txt /tmp/

COPY services/training/requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r /tmp/requirements-tracing.txt

RUN pip install --no-cache-dir -r requirements.txt

# Copy shared libraries from the shared stage
COPY --from=shared /shared /app/shared

# Copy application code
COPY services/training/ .

# Add shared libraries to Python path
ENV PYTHONPATH="/app:/app/shared:${PYTHONPATH:-}"

# Set TMPDIR for cmdstan (directory will be created at runtime)
ENV TMPDIR=/tmp/cmdstan

# Install cmdstan for Prophet (required for model optimization)
# Suppress verbose output to reduce log noise
RUN python -m pip install --no-cache-dir cmdstanpy && \
    python -m cmdstanpy.install_cmdstan

# Expose port
EXPOSE 8000

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run application with increased WebSocket ping timeout to handle long training operations
# Default uvicorn ws-ping-timeout is 20s, increasing to 300s (5 minutes) to prevent
# premature disconnections during CPU-intensive ML training (typically 2-3 minutes)
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--ws-ping-timeout", "300"]
services/training/README.md (new file, 728 lines)
@@ -0,0 +1,728 @@
# Training Service (ML Model Management)

## Overview

The **Training Service** is the machine learning pipeline engine of Bakery-IA, responsible for training, versioning, and managing Prophet forecasting models. It orchestrates the entire ML workflow from data collection to model deployment, providing real-time progress updates via WebSocket and ensuring bakeries always have the most accurate prediction models. This service enables continuous learning and model improvement without requiring data science expertise.

## Key Features

### Automated ML Pipeline
- **One-Click Model Training** - Train models for all products with a single API call
- **Background Job Processing** - Asynchronous training with job queue management
- **Multi-Product Training** - Process multiple products in parallel
- **Progress Tracking** - Real-time WebSocket updates on training status
- **Automatic Model Versioning** - Track all model versions with performance metrics
- **Model Artifact Storage** - Persist trained models for fast prediction loading

### Training Job Management
- **Job Queue** - FIFO queue for training requests
- **Job Status Tracking** - Monitor pending, running, completed, and failed jobs
- **Concurrent Job Control** - Limit parallel training jobs to prevent resource exhaustion (see the sketch after this list)
- **Timeout Handling** - Automatic job termination after maximum duration
- **Error Recovery** - Detailed error messages and retry capabilities
- **Job History** - Complete audit trail of all training executions
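
A minimal sketch of how concurrent job control and timeouts might be enforced with an `asyncio.Semaphore`; the names (`MAX_CONCURRENT_JOBS`, `execute_training`, `mark_job_failed`) are illustrative assumptions, not the service's actual implementation:

```python
import asyncio

MAX_CONCURRENT_JOBS = 3  # assumed default, mirrors MAX_CONCURRENT_JOBS in Configuration
_job_slots = asyncio.Semaphore(MAX_CONCURRENT_JOBS)

async def run_training_job(job_id: str, timeout_minutes: int = 30) -> None:
    """Hypothetical wrapper: limit parallelism and enforce a per-job timeout."""
    async with _job_slots:  # blocks while MAX_CONCURRENT_JOBS jobs are already running
        try:
            await asyncio.wait_for(execute_training(job_id), timeout=timeout_minutes * 60)
        except asyncio.TimeoutError:
            await mark_job_failed(job_id, "Training exceeded maximum duration")

async def execute_training(job_id: str) -> None:
    ...  # the actual pipeline (see "ML Training Pipeline" below)

async def mark_job_failed(job_id: str, reason: str) -> None:
    ...  # persist the failure status in the job queue table
```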

### Model Performance Tracking
- **Accuracy Metrics** - MAE, RMSE, R², MAPE for each trained model
- **Historical Comparison** - Compare current vs. previous model performance
- **Per-Product Analytics** - Track which products have the best forecast accuracy
- **Training Duration Tracking** - Monitor training performance and optimization
- **Model Selection** - Automatically deploy best-performing models

### Real-Time Communication
- **WebSocket Live Updates** - Real-time progress percentage and status messages
- **Training Logs** - Detailed step-by-step execution logs
- **Completion Notifications** - RabbitMQ events for training completion
- **Error Alerts** - Immediate notification of training failures

### Feature Engineering
- **Historical Data Aggregation** - Collect sales data for model training
- **External Data Integration** - Fetch weather, traffic, holiday data
- **POI Feature Integration** - Merge location-based POI features into training data
- **Feature Extraction** - Generate 30+ temporal, contextual, and location-based features
- **Data Validation** - Ensure minimum data requirements before training
- **Outlier Detection** - Filter anomalous data points

## Technical Capabilities

### ML Training Pipeline

```python
# Training workflow
async def train_model_pipeline(tenant_id: str, product_id: str):
    """Complete ML training pipeline"""

    # Step 1: Data Collection
    sales_data = await fetch_historical_sales(tenant_id, product_id)
    if len(sales_data) < MIN_TRAINING_DAYS:
        raise InsufficientDataError(f"Need {MIN_TRAINING_DAYS}+ days of data")

    # Step 2: Feature Engineering
    features = engineer_features(sales_data)
    weather_data = await fetch_weather_data(tenant_id)
    traffic_data = await fetch_traffic_data(tenant_id)
    holiday_data = await fetch_holiday_calendar()
    poi_features = await fetch_poi_features(tenant_id)  # NEW: Location context

    # Merge POI features into training dataframe
    features = merge_poi_features(features, poi_features)

    # Step 3: Prophet Model Training
    model = Prophet(
        seasonality_mode='additive',
        daily_seasonality=True,
        weekly_seasonality=True,
        yearly_seasonality=True,
    )
    model.add_country_holidays(country_name='ES')
    model.fit(features)

    # Step 4: Model Validation
    metrics = calculate_performance_metrics(model, sales_data)

    # Step 5: Model Storage
    model_path = save_model_artifact(model, tenant_id, product_id)

    # Step 6: Model Registration
    await register_model_in_database(model_path, metrics)

    # Step 7: Notification
    await publish_training_complete_event(tenant_id, product_id, metrics)

    return model, metrics
```

### WebSocket Progress Updates

```python
# Real-time progress broadcasting
async def broadcast_training_progress(job_id: str, progress: dict):
    """Send progress update to connected clients"""

    message = {
        "type": "training_progress",
        "job_id": job_id,
        "progress": {
            "percentage": progress["percentage"],          # 0-100
            "current_step": progress["step"],              # Step description
            "products_completed": progress["completed"],
            "products_total": progress["total"],
            "estimated_time_remaining": progress["eta"],   # Seconds
            "started_at": progress["start_time"]
        },
        "timestamp": datetime.utcnow().isoformat()
    }

    await websocket_manager.broadcast(job_id, message)
```

### Model Artifact Management (MinIO Storage)

```python
# Model storage and retrieval using MinIO
import joblib
from shared.clients.minio_client import minio_client

# Save trained model to MinIO
def save_model_artifact(model: Prophet, tenant_id: str, product_id: str) -> str:
    """Serialize and store model in MinIO"""
    import io
    version = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    model_id = str(uuid.uuid4())
    object_name = f"models/{tenant_id}/{product_id}/{model_id}.pkl"

    # Serialize model (joblib.dump writes to file-like objects)
    buffer = io.BytesIO()
    joblib.dump(model, buffer)
    model_data = buffer.getvalue()

    # Upload to MinIO
    minio_client.put_object(
        bucket_name="training-models",
        object_name=object_name,
        data=model_data,
        content_type="application/octet-stream"
    )

    # Return MinIO path
    return f"minio://training-models/{object_name}"

# Load trained model from MinIO
def load_model_artifact(model_path: str) -> Prophet:
    """Load serialized model from MinIO"""
    import io
    # Parse MinIO path: minio://bucket_name/object_path
    _, bucket_and_path = model_path.split("://", 1)
    bucket_name, object_name = bucket_and_path.split("/", 1)

    # Download from MinIO
    model_data = minio_client.get_object(bucket_name, object_name)

    # Deserialize (joblib.load reads from file-like objects)
    buffer = io.BytesIO(model_data)
    return joblib.load(buffer)
```

### Performance Metrics Calculation

```python
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import numpy as np

def calculate_performance_metrics(model: Prophet, actual_data: pd.DataFrame) -> dict:
    """Calculate comprehensive model performance metrics"""

    # Make predictions on validation set
    predictions = model.predict(actual_data)

    # Calculate metrics
    mae = mean_absolute_error(actual_data['y'], predictions['yhat'])
    rmse = np.sqrt(mean_squared_error(actual_data['y'], predictions['yhat']))
    r2 = r2_score(actual_data['y'], predictions['yhat'])
    # Note: MAPE assumes actual_data['y'] contains no zeros (undefined otherwise)
    mape = np.mean(np.abs((actual_data['y'] - predictions['yhat']) / actual_data['y'])) * 100

    return {
        "mae": float(mae),        # Mean Absolute Error
        "rmse": float(rmse),      # Root Mean Square Error
        "r2_score": float(r2),    # R-squared
        "mape": float(mape),      # Mean Absolute Percentage Error
        "accuracy": float(100 - mape) if mape < 100 else 0.0
    }
```

## Business Value

### For Bakery Owners
- **Continuous Improvement** - Models automatically improve with more data
- **No ML Expertise Required** - One-click training, no data science skills needed
- **Always Up-to-Date** - Weekly automatic retraining keeps models accurate
- **Transparent Performance** - Clear accuracy metrics show forecast reliability
- **Cost Savings** - Automated ML pipeline eliminates the need for data scientists

### For Operations Managers
- **Model Version Control** - Track and compare model versions over time
- **Performance Monitoring** - Identify products with poor forecast accuracy
- **Training Scheduling** - Schedule retraining during low-traffic hours
- **Resource Management** - Control concurrent training jobs to prevent overload

### For Platform Operations
- **Scalable ML Pipeline** - Train models for thousands of products
- **Background Processing** - Non-blocking training jobs
- **Error Handling** - Robust error recovery and retry mechanisms
- **Cost Optimization** - Efficient model storage and caching

## Technology Stack

- **Framework**: FastAPI (Python 3.11+) - Async web framework with WebSocket support
- **Database**: PostgreSQL 17 - Training logs, model metadata, job queue
- **ML Library**: Prophet (`prophet` package, formerly fbprophet) - Time series forecasting
- **Model Storage**: MinIO (S3-compatible) - Distributed object storage with TLS
- **Serialization**: Joblib - Model serialization
- **WebSocket**: FastAPI WebSocket - Real-time progress updates
- **Messaging**: RabbitMQ 4.1 - Training completion events
- **ORM**: SQLAlchemy 2.0 (async) - Database abstraction
- **Data Processing**: Pandas, NumPy - Data manipulation
- **Logging**: Structlog - Structured JSON logging
- **Metrics**: Prometheus Client - Custom metrics

## API Endpoints (Key Routes)

### Training Management
- `POST /api/v1/training/start` - Start training job for tenant
- `POST /api/v1/training/start/{product_id}` - Train specific product
- `POST /api/v1/training/stop/{job_id}` - Stop running training job
- `GET /api/v1/training/status/{job_id}` - Get job status and progress
- `GET /api/v1/training/history` - Get training job history
- `DELETE /api/v1/training/jobs/{job_id}` - Delete training job record

### Model Management
- `GET /api/v1/training/models` - List all trained models
- `GET /api/v1/training/models/{model_id}` - Get specific model details
- `GET /api/v1/training/models/{model_id}/metrics` - Get model performance metrics
- `GET /api/v1/training/models/latest/{product_id}` - Get latest model for product
- `POST /api/v1/training/models/{model_id}/deploy` - Deploy specific model version
- `DELETE /api/v1/training/models/{model_id}` - Delete model artifact

### WebSocket
- `WS /api/v1/training/ws/{job_id}` - Connect to training progress stream

### Analytics
- `GET /api/v1/training/analytics/performance` - Overall training performance
- `GET /api/v1/training/analytics/accuracy` - Model accuracy distribution
- `GET /api/v1/training/analytics/duration` - Training duration statistics
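
**Example (illustrative) API usage:** the exact request and response schemas are defined by the service itself; as a rough sketch, where the `job_id` response field, port, and `Authorization` header are assumptions, starting a job and polling its status could look like:

```python
import asyncio
import httpx

BASE_URL = "http://localhost:8004/api/v1/training"  # assumed local port, see Configuration

async def start_and_wait(token: str) -> dict:
    """Illustrative client: start a training job, then poll its status until it finishes."""
    headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(base_url=BASE_URL, headers=headers) as client:
        resp = await client.post("/start")
        resp.raise_for_status()
        job_id = resp.json()["job_id"]  # assumed response field

        while True:
            job = (await client.get(f"/status/{job_id}")).json()
            if job.get("status") in ("completed", "failed"):
                return job
            await asyncio.sleep(5)

# asyncio.run(start_and_wait("YOUR_JWT"))
```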

## Database Schema

### Main Tables

**training_job_queue**
```sql
CREATE TABLE training_job_queue (
    id UUID PRIMARY KEY,
    tenant_id UUID NOT NULL,
    job_name VARCHAR(255),
    products_to_train TEXT[],           -- Array of product IDs
    status VARCHAR(50) NOT NULL,        -- pending, running, completed, failed
    priority INTEGER DEFAULT 0,
    progress_percentage INTEGER DEFAULT 0,
    current_step VARCHAR(255),
    products_completed INTEGER DEFAULT 0,
    products_total INTEGER,
    started_at TIMESTAMP,
    completed_at TIMESTAMP,
    estimated_completion TIMESTAMP,
    error_message TEXT,
    retry_count INTEGER DEFAULT 0,
    created_by UUID,
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW()
);
```

**trained_models**
```sql
CREATE TABLE trained_models (
    id UUID PRIMARY KEY,
    tenant_id UUID NOT NULL,
    product_id UUID NOT NULL,
    model_version VARCHAR(50) NOT NULL,
    model_path VARCHAR(500) NOT NULL,
    training_job_id UUID REFERENCES training_job_queue(id),
    algorithm VARCHAR(50) DEFAULT 'prophet',
    hyperparameters JSONB,
    training_duration_seconds INTEGER,
    training_data_points INTEGER,
    is_deployed BOOLEAN DEFAULT FALSE,
    deployed_at TIMESTAMP,
    created_at TIMESTAMP DEFAULT NOW(),
    UNIQUE(tenant_id, product_id, model_version)
);
```

**model_performance_metrics**
```sql
CREATE TABLE model_performance_metrics (
    id UUID PRIMARY KEY,
    model_id UUID REFERENCES trained_models(id),
    tenant_id UUID NOT NULL,
    product_id UUID NOT NULL,
    mae DECIMAL(10, 4),          -- Mean Absolute Error
    rmse DECIMAL(10, 4),         -- Root Mean Square Error
    r2_score DECIMAL(10, 6),     -- R-squared
    mape DECIMAL(10, 4),         -- Mean Absolute Percentage Error
    accuracy_percentage DECIMAL(5, 2),
    validation_data_points INTEGER,
    created_at TIMESTAMP DEFAULT NOW()
);
```

**model_training_logs**
```sql
CREATE TABLE model_training_logs (
    id UUID PRIMARY KEY,
    training_job_id UUID REFERENCES training_job_queue(id),
    tenant_id UUID NOT NULL,
    product_id UUID,
    log_level VARCHAR(20),       -- DEBUG, INFO, WARNING, ERROR
    message TEXT,
    step_name VARCHAR(100),
    execution_time_ms INTEGER,
    metadata JSONB,
    created_at TIMESTAMP DEFAULT NOW()
);
```

**model_artifacts** (metadata only; the artifact files themselves are stored in MinIO)
```sql
CREATE TABLE model_artifacts (
    id UUID PRIMARY KEY,
    model_id UUID REFERENCES trained_models(id),
    artifact_type VARCHAR(50),   -- model_file, feature_list, scaler, etc.
    file_path VARCHAR(500),
    file_size_bytes BIGINT,
    checksum VARCHAR(64),        -- SHA-256 hash
    created_at TIMESTAMP DEFAULT NOW()
);
```

## Events & Messaging

### Published Events (RabbitMQ)

**Exchange**: `training`
**Routing Key**: `training.completed`

**Training Completed Event**
```json
{
  "event_type": "training_completed",
  "tenant_id": "uuid",
  "job_id": "uuid",
  "job_name": "Weekly retraining - All products",
  "status": "completed",
  "results": {
    "successful_trainings": 25,
    "failed_trainings": 2,
    "total_products": 27,
    "models_created": [
      {
        "product_id": "uuid",
        "product_name": "Baguette",
        "model_version": "20251106_143022",
        "accuracy": 82.5,
        "mae": 12.3,
        "rmse": 18.7,
        "r2_score": 0.78
      }
    ],
    "average_accuracy": 79.8,
    "training_duration_seconds": 342
  },
  "started_at": "2025-11-06T14:25:00Z",
  "completed_at": "2025-11-06T14:30:42Z",
  "timestamp": "2025-11-06T14:30:42Z"
}
```

**Training Failed Event**
```json
{
  "event_type": "training_failed",
  "tenant_id": "uuid",
  "job_id": "uuid",
  "product_id": "uuid",
  "product_name": "Croissant",
  "error_type": "InsufficientDataError",
  "error_message": "Product requires minimum 30 days of sales data. Currently: 15 days.",
  "recommended_action": "Collect more sales data before retraining",
  "severity": "medium",
  "timestamp": "2025-11-06T14:28:15Z"
}
```

### Consumed Events
- **From Orchestrator**: Scheduled training triggers
- **From Sales**: New sales data imported (triggers retraining)
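
The service's actual messaging helper is not shown in this README; as a minimal sketch of publishing the completion event with `aio-pika` (the library choice and connection handling are assumptions), publishing to the `training` exchange could look like:

```python
import json
import aio_pika

async def publish_training_completed(event: dict, amqp_url: str) -> None:
    """Illustrative publisher for the `training` exchange / `training.completed` routing key."""
    connection = await aio_pika.connect_robust(amqp_url)
    async with connection:
        channel = await connection.channel()
        exchange = await channel.declare_exchange(
            "training", aio_pika.ExchangeType.TOPIC, durable=True
        )
        await exchange.publish(
            aio_pika.Message(
                body=json.dumps(event).encode(),
                content_type="application/json",
            ),
            routing_key="training.completed",
        )
```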

## Custom Metrics (Prometheus)

```python
from prometheus_client import Counter, Gauge, Histogram

# Training job metrics
training_jobs_total = Counter(
    'training_jobs_total',
    'Total training jobs started',
    ['tenant_id', 'status']  # completed, failed, cancelled
)

training_duration_seconds = Histogram(
    'training_duration_seconds',
    'Training job duration',
    ['tenant_id'],
    buckets=[10, 30, 60, 120, 300, 600, 1800, 3600]  # seconds
)

models_trained_total = Counter(
    'models_trained_total',
    'Total models successfully trained',
    ['tenant_id', 'product_category']
)

# Model performance metrics
model_accuracy_distribution = Histogram(
    'model_accuracy_percentage',
    'Distribution of model accuracy scores',
    ['tenant_id'],
    buckets=[50, 60, 70, 75, 80, 85, 90, 95, 100]  # percentage
)

model_mae_distribution = Histogram(
    'model_mae',
    'Distribution of Mean Absolute Error',
    ['tenant_id'],
    buckets=[1, 5, 10, 20, 30, 50, 100]  # units
)

# WebSocket metrics
websocket_connections_total = Gauge(
    'training_websocket_connections',
    'Active WebSocket connections',
    ['tenant_id']
)

websocket_messages_sent = Counter(
    'training_websocket_messages_total',
    'Total WebSocket messages sent',
    ['tenant_id', 'message_type']
)
```

## Configuration

### Environment Variables

**Service Configuration:**
- `PORT` - Service port (default: 8004)
- `DATABASE_URL` - PostgreSQL connection string
- `RABBITMQ_URL` - RabbitMQ connection string

**MinIO Configuration:**
- `MINIO_ENDPOINT` - MinIO server endpoint (default: minio.bakery-ia.svc.cluster.local:9000)
- `MINIO_ACCESS_KEY` - MinIO access key
- `MINIO_SECRET_KEY` - MinIO secret key
- `MINIO_USE_SSL` - Enable TLS (default: true)
- `MINIO_MODEL_BUCKET` - Bucket for models (default: training-models)

**Training Configuration:**
- `MAX_CONCURRENT_JOBS` - Maximum parallel training jobs (default: 3)
- `MAX_TRAINING_TIME_MINUTES` - Job timeout (default: 30)
- `MIN_TRAINING_DATA_DAYS` - Minimum history required (default: 30)
- `ENABLE_AUTO_DEPLOYMENT` - Auto-deploy after training (default: true)

**Prophet Configuration:**
- `PROPHET_DAILY_SEASONALITY` - Enable daily patterns (default: true)
- `PROPHET_WEEKLY_SEASONALITY` - Enable weekly patterns (default: true)
- `PROPHET_YEARLY_SEASONALITY` - Enable yearly patterns (default: true)
- `PROPHET_INTERVAL_WIDTH` - Confidence interval (default: 0.95)
- `PROPHET_CHANGEPOINT_PRIOR_SCALE` - Trend flexibility (default: 0.05)

**WebSocket Configuration:**
- `WEBSOCKET_HEARTBEAT_INTERVAL` - Ping interval in seconds (default: 30)
- `WEBSOCKET_MAX_CONNECTIONS` - Max connections per tenant (default: 10)
- `WEBSOCKET_MESSAGE_QUEUE_SIZE` - Message buffer size (default: 100)

**Storage Configuration (MinIO):**
- `MINIO_MODEL_LIFECYCLE_DAYS` - Days to keep old model versions (default: 90)
- `MINIO_CACHE_TTL_SECONDS` - Model cache TTL in seconds (default: 3600)
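
The service's real settings class lives in `app/core/config.py` (not shown in this README); as an assumption-laden sketch of how a subset of these variables might map to typed settings with `pydantic-settings`:

```python
from pydantic_settings import BaseSettings  # pydantic v2 settings package (assumed dependency)

class TrainingSettings(BaseSettings):
    """Illustrative mapping of the environment variables above to typed fields."""
    PORT: int = 8004
    DATABASE_URL: str = "postgresql+asyncpg://training_user:password@localhost:5432/training_db"
    RABBITMQ_URL: str = "amqp://guest:guest@localhost:5672/"

    MINIO_ENDPOINT: str = "minio.bakery-ia.svc.cluster.local:9000"
    MINIO_ACCESS_KEY: str = ""
    MINIO_SECRET_KEY: str = ""
    MINIO_USE_SSL: bool = True
    MINIO_MODEL_BUCKET: str = "training-models"

    MAX_CONCURRENT_JOBS: int = 3
    MAX_TRAINING_TIME_MINUTES: int = 30
    MIN_TRAINING_DATA_DAYS: int = 30
    ENABLE_AUTO_DEPLOYMENT: bool = True

settings = TrainingSettings()  # values are read from the environment at import time
```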

## Development Setup

### Prerequisites
- Python 3.11+
- PostgreSQL 17
- RabbitMQ 4.1
- MinIO (S3-compatible object storage)

### Local Development
```bash
# Create virtual environment
cd services/training
python -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements.txt

# Set environment variables
export DATABASE_URL=postgresql://user:pass@localhost:5432/training
export RABBITMQ_URL=amqp://guest:guest@localhost:5672/
export MINIO_ENDPOINT=localhost:9000
export MINIO_ACCESS_KEY=minioadmin
export MINIO_SECRET_KEY=minioadmin
export MINIO_USE_SSL=false  # Use true in production

# Start MinIO locally (if not using K8s)
docker run -p 9000:9000 -p 9001:9001 minio/minio server /data --console-address ":9001"

# Run database migrations
alembic upgrade head

# Run the service
python main.py
```

### Testing
```bash
# Unit tests
pytest tests/unit/ -v

# Integration tests (requires services)
pytest tests/integration/ -v

# WebSocket tests
pytest tests/websocket/ -v

# Test with coverage
pytest --cov=app tests/ --cov-report=html
```

### WebSocket Testing
```python
# Test WebSocket connection
import asyncio
import websockets
import json

async def test_training_progress():
    uri = "ws://localhost:8004/api/v1/training/ws/job-id-here"
    async with websockets.connect(uri) as websocket:
        while True:
            message = await websocket.recv()
            data = json.loads(message)

            # Check for completion first: the final message may not carry a progress payload
            if data['type'] == 'training_completed':
                print("Training finished!")
                break

            print(f"Progress: {data['progress']['percentage']}%")
            print(f"Step: {data['progress']['current_step']}")

asyncio.run(test_training_progress())
```

## POI Feature Integration

### How POI Features Enhance Training

The Training Service integrates location-based POI features from the External Service to improve forecast accuracy:

**POI Features Included:**
- `school_density` - Number of schools within a 1 km radius
- `office_density` - Number of offices and business centers nearby
- `residential_density` - Residential area proximity
- `transport_hub_proximity` - Distance to metro, bus, and train stations
- `commercial_zone_score` - Commercial activity in the area
- `restaurant_density` - Nearby restaurants and cafes
- `competitor_proximity` - Distance to competing bakeries
- And 11+ more location-based features

**Integration Process:**
1. **Fetch POI Context** - Retrieve the tenant's POI features from the External Service (`/poi-context/{tenant_id}`)
2. **Extract ML Features** - Parse the `ml_features` JSON object from the POI context
3. **Merge with Training Data** - Add POI features as additional columns in the training dataframe
4. **Prophet Training** - Include POI features as regressors in the Prophet model
5. **Feature Importance** - Track which POI features most impact predictions

**Example POI Feature Integration:**
```python
from app.ml.poi_feature_integrator import POIFeatureIntegrator

# Initialize POI integrator
poi_integrator = POIFeatureIntegrator(external_service_url)

# Fetch and merge POI features
poi_features = await poi_integrator.fetch_poi_features(tenant_id)
training_df = poi_integrator.merge_poi_features(training_df, poi_features)

# POI features now available as columns:
# training_df['school_density'], training_df['office_density'], etc.

# Add POI features as Prophet regressors
for feature_name in poi_features.keys():
    prophet_model.add_regressor(feature_name)
```

**Endpoint Used:**
- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through the API Gateway)

## Integration Points

### Dependencies (Services Called)
- **Sales Service** - Fetch historical sales data for training
- **External Service** - Fetch weather, traffic, holiday, and POI feature data
- **PostgreSQL** - Store job queue, models, metrics, logs
- **RabbitMQ** - Publish training completion events
- **MinIO** - Store model artifacts (S3-compatible object storage with TLS)

### Dependents (Services That Call This)
- **Forecasting Service** - Load trained models for predictions
- **Orchestrator Service** - Trigger scheduled training jobs
- **Frontend Dashboard** - Display training progress and model metrics
- **AI Insights Service** - Analyze model performance patterns

## Security Measures

### Data Protection
- **Tenant Isolation** - All training jobs scoped to tenant_id
- **Model Access Control** - Only the owning tenant can access its models
- **Input Validation** - Validate all training parameters
- **Rate Limiting** - Prevent training job spam

### Model Security
- **Model Checksums** - SHA-256 hash verification for artifacts
- **Version Control** - Track all model versions with an audit trail
- **Access Logging** - Log all model access and deployment
- **Secure Storage** - Model files stored with restricted permissions
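
For the `checksum` column in `model_artifacts`, a minimal sketch of computing and verifying a SHA-256 digest over a serialized model (the helper names are illustrative, not the service's actual functions):

```python
import hashlib

def sha256_checksum(data: bytes) -> str:
    """Hex-encoded SHA-256 digest of a serialized model artifact."""
    return hashlib.sha256(data).hexdigest()

def verify_artifact(data: bytes, expected_checksum: str) -> bool:
    """Compare a freshly computed digest against the value stored in model_artifacts."""
    return sha256_checksum(data) == expected_checksum
```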

### WebSocket Security
- **JWT Authentication** - Authenticate WebSocket connections
- **Connection Limits** - Max connections per tenant
- **Message Validation** - Validate all WebSocket messages
- **Heartbeat Monitoring** - Detect and close stale connections

## Performance Optimization

### Training Performance
1. **Parallel Processing** - Train multiple products concurrently
2. **Data Caching** - Cache fetched external data across products
3. **Incremental Training** - Only retrain changed products
4. **Resource Limits** - CPU/memory limits per training job
5. **Priority Queue** - Prioritize important products first

### Storage Optimization (MinIO)
1. **Object Versioning** - MinIO maintains version history automatically
2. **Lifecycle Policies** - Automatically clean up old versions after 90 days
3. **TLS Encryption** - Secure communication with MinIO
4. **Distributed Storage** - MinIO handles replication and availability

### WebSocket Optimization
1. **Message Batching** - Batch progress updates (every 2 seconds)
2. **Connection Pooling** - Reuse WebSocket connections
3. **Compression** - Enable WebSocket message compression
4. **Heartbeat** - Keep connections alive efficiently
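
A minimal sketch of the message-batching idea from item 1 above: coalesce progress updates and flush only the latest one per job every two seconds. The `websocket_manager` interface is assumed from the earlier broadcasting example; this is not the service's actual implementation.

```python
import asyncio

class BatchedProgressBroadcaster:
    """Illustrative batcher: keep only the latest update per job, flush periodically."""

    def __init__(self, websocket_manager, flush_interval: float = 2.0):
        self._manager = websocket_manager
        self._interval = flush_interval
        self._pending: dict[str, dict] = {}  # job_id -> latest progress message

    def queue_update(self, job_id: str, message: dict) -> None:
        # Later updates for the same job overwrite earlier ones between flushes
        self._pending[job_id] = message

    async def run(self) -> None:
        while True:
            await asyncio.sleep(self._interval)
            pending, self._pending = self._pending, {}
            for job_id, message in pending.items():
                await self._manager.broadcast(job_id, message)
```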

## Troubleshooting

### Common Issues

**Issue**: Training jobs stuck in "pending" status
- **Cause**: Max concurrent jobs reached or the worker process crashed
- **Solution**: Check the `MAX_CONCURRENT_JOBS` setting, restart the service

**Issue**: WebSocket connection drops during training
- **Cause**: Network timeout or client disconnection
- **Solution**: Implement auto-reconnect logic in the client (see the sketch below)

**Issue**: "Insufficient data" errors for many products
- **Cause**: Products need 30+ days of sales history
- **Solution**: Import more historical sales data or reduce `MIN_TRAINING_DATA_DAYS`

**Issue**: Low model accuracy (<70%)
- **Cause**: Insufficient data, outliers, or changing business patterns
- **Solution**: Clean outliers, add more features, or manually adjust Prophet parameters
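
A minimal client-side auto-reconnect sketch using the `websockets` library; the backoff values and terminal message types are assumptions for illustration:

```python
import asyncio
import json
import websockets

async def follow_job(uri: str) -> None:
    """Reconnect with exponential backoff until the job reports a terminal state."""
    delay = 1.0
    while True:
        try:
            async with websockets.connect(uri) as ws:
                delay = 1.0  # reset backoff once connected
                async for raw in ws:
                    data = json.loads(raw)
                    if data.get("type") in ("training_completed", "training_failed"):
                        return
        except (websockets.ConnectionClosed, OSError):
            await asyncio.sleep(delay)
            delay = min(delay * 2, 30.0)  # cap the backoff at 30 seconds
```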

### Debug Mode
```bash
# Enable detailed logging
export LOG_LEVEL=DEBUG
export PROPHET_VERBOSE=1

# Enable training profiling
export ENABLE_PROFILING=1

# Disable concurrent jobs for debugging
export MAX_CONCURRENT_JOBS=1
```

## Competitive Advantages

1. **One-Click ML** - No data science expertise required
2. **Real-Time Visibility** - WebSocket progress updates unique in bakery software
3. **Continuous Learning** - Automatic weekly retraining
4. **Version Control** - Track and compare all model versions
5. **Production-Ready** - Robust error handling and retry mechanisms
6. **Scalable** - Train models for thousands of products
7. **Spanish Market** - Optimized for Spanish bakery patterns and holidays

## Future Enhancements

- **Hyperparameter Tuning** - Automatic optimization of Prophet parameters
- **A/B Testing** - Deploy multiple models and compare performance
- **Distributed Training** - Scale across multiple machines
- **GPU Acceleration** - Use GPUs for deep learning models
- **AutoML** - Automatic algorithm selection (Prophet vs LSTM vs ARIMA)
- **Model Explainability** - SHAP values to explain predictions
- **Custom Algorithms** - Support for user-provided ML models
- **Transfer Learning** - Use pre-trained models from similar bakeries

---

**For VUE Madrid Business Plan**: The Training Service demonstrates advanced ML engineering capabilities with automated pipeline management and real-time monitoring. The ability to continuously improve forecast accuracy without manual intervention represents significant operational efficiency and competitive advantage. This self-learning system is a key differentiator in the bakery software market and showcases technical innovation suitable for EU technology grants and investor presentations.
services/training/alembic.ini (new file, 84 lines)
@@ -0,0 +1,84 @@
# ================================================================
# services/training/alembic.ini - Alembic Configuration
# ================================================================
[alembic]
# path to migration scripts
script_location = migrations

# template used to generate migration file names
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
timezone = Europe/Madrid

# max length of characters to apply to the
# "slug" field
truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
sourceless = false

# version of a migration file's filename format
version_num_format = %%s

# version path separator
version_path_separator = os

# set to 'true' to search source files recursively
# in each "version_locations" directory
recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
output_encoding = utf-8

# Database URL - will be overridden by environment variable or settings
sqlalchemy.url = postgresql+asyncpg://training_user:password@training-db-service:5432/training_db

[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts.

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
services/training/app/__init__.py (new file, empty)
services/training/app/api/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
"""
Training API Layer
HTTP endpoints for ML training operations and WebSocket connections
"""

from .training_jobs import router as training_jobs_router
from .training_operations import router as training_operations_router
from .models import router as models_router
from .websocket_operations import router as websocket_operations_router

__all__ = [
    "training_jobs_router",
    "training_operations_router",
    "models_router",
    "websocket_operations_router"
]
services/training/app/api/audit.py (new file, 237 lines)
@@ -0,0 +1,237 @@
# services/training/app/api/audit.py
"""
Audit Logs API - Retrieve audit trail for training service
"""

from fastapi import APIRouter, Depends, HTTPException, Query, Path, status
from typing import Optional, Dict, Any
from uuid import UUID
from datetime import datetime
import structlog
from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import AuditLog
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role
from shared.routing import RouteBuilder
from shared.models.audit_log_schemas import (
    AuditLogResponse,
    AuditLogListResponse,
    AuditLogStatsResponse
)
from app.core.database import database_manager

route_builder = RouteBuilder('training')
router = APIRouter(tags=["audit-logs"])
logger = structlog.get_logger()


async def get_db():
    """Database session dependency"""
    async with database_manager.get_session() as session:
        yield session


@router.get(
    route_builder.build_base_route("audit-logs"),
    response_model=AuditLogListResponse
)
@require_user_role(['admin', 'owner'])
async def get_audit_logs(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    start_date: Optional[datetime] = Query(None, description="Filter logs from this date"),
    end_date: Optional[datetime] = Query(None, description="Filter logs until this date"),
    user_id: Optional[UUID] = Query(None, description="Filter by user ID"),
    action: Optional[str] = Query(None, description="Filter by action type"),
    resource_type: Optional[str] = Query(None, description="Filter by resource type"),
    severity: Optional[str] = Query(None, description="Filter by severity level"),
    search: Optional[str] = Query(None, description="Search in description field"),
    limit: int = Query(100, ge=1, le=1000, description="Number of records to return"),
    offset: int = Query(0, ge=0, description="Number of records to skip"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get audit logs for training service.
    Requires admin or owner role.
    """
    try:
        logger.info(
            "Retrieving audit logs",
            tenant_id=tenant_id,
            user_id=current_user.get("user_id"),
            filters={
                "start_date": start_date,
                "end_date": end_date,
                "action": action,
                "resource_type": resource_type,
                "severity": severity
            }
        )

        # Build query filters
        filters = [AuditLog.tenant_id == tenant_id]

        if start_date:
            filters.append(AuditLog.created_at >= start_date)
        if end_date:
            filters.append(AuditLog.created_at <= end_date)
        if user_id:
            filters.append(AuditLog.user_id == user_id)
        if action:
            filters.append(AuditLog.action == action)
        if resource_type:
            filters.append(AuditLog.resource_type == resource_type)
        if severity:
            filters.append(AuditLog.severity == severity)
        if search:
            filters.append(AuditLog.description.ilike(f"%{search}%"))

        # Count total matching records
        count_query = select(func.count()).select_from(AuditLog).where(and_(*filters))
        total_result = await db.execute(count_query)
        total = total_result.scalar() or 0

        # Fetch paginated results
        query = (
            select(AuditLog)
            .where(and_(*filters))
            .order_by(AuditLog.created_at.desc())
            .limit(limit)
            .offset(offset)
        )

        result = await db.execute(query)
        audit_logs = result.scalars().all()

        # Convert to response models
        items = [AuditLogResponse.from_orm(log) for log in audit_logs]

        logger.info(
            "Successfully retrieved audit logs",
            tenant_id=tenant_id,
            total=total,
            returned=len(items)
        )

        return AuditLogListResponse(
            items=items,
            total=total,
            limit=limit,
            offset=offset,
            has_more=(offset + len(items)) < total
        )

    except Exception as e:
        logger.error(
            "Failed to retrieve audit logs",
            error=str(e),
            tenant_id=tenant_id
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to retrieve audit logs: {str(e)}"
        )


@router.get(
    route_builder.build_base_route("audit-logs/stats"),
    response_model=AuditLogStatsResponse
)
@require_user_role(['admin', 'owner'])
async def get_audit_log_stats(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    start_date: Optional[datetime] = Query(None, description="Filter logs from this date"),
    end_date: Optional[datetime] = Query(None, description="Filter logs until this date"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get audit log statistics for training service.
    Requires admin or owner role.
    """
    try:
        logger.info(
            "Retrieving audit log statistics",
            tenant_id=tenant_id,
            user_id=current_user.get("user_id")
        )

        # Build base filters
        filters = [AuditLog.tenant_id == tenant_id]
        if start_date:
            filters.append(AuditLog.created_at >= start_date)
        if end_date:
            filters.append(AuditLog.created_at <= end_date)

        # Total events
        count_query = select(func.count()).select_from(AuditLog).where(and_(*filters))
        total_result = await db.execute(count_query)
        total_events = total_result.scalar() or 0

        # Events by action
        action_query = (
            select(AuditLog.action, func.count().label('count'))
            .where(and_(*filters))
            .group_by(AuditLog.action)
        )
        action_result = await db.execute(action_query)
        events_by_action = {row.action: row.count for row in action_result}

        # Events by severity
        severity_query = (
            select(AuditLog.severity, func.count().label('count'))
            .where(and_(*filters))
            .group_by(AuditLog.severity)
        )
        severity_result = await db.execute(severity_query)
        events_by_severity = {row.severity: row.count for row in severity_result}

        # Events by resource type
        resource_query = (
            select(AuditLog.resource_type, func.count().label('count'))
            .where(and_(*filters))
            .group_by(AuditLog.resource_type)
        )
        resource_result = await db.execute(resource_query)
        events_by_resource_type = {row.resource_type: row.count for row in resource_result}

        # Date range
        date_range_query = (
            select(
                func.min(AuditLog.created_at).label('min_date'),
                func.max(AuditLog.created_at).label('max_date')
            )
            .where(and_(*filters))
        )
        date_result = await db.execute(date_range_query)
        date_row = date_result.one()

        logger.info(
            "Successfully retrieved audit log statistics",
            tenant_id=tenant_id,
            total_events=total_events
        )

        return AuditLogStatsResponse(
            total_events=total_events,
            events_by_action=events_by_action,
            events_by_severity=events_by_severity,
            events_by_resource_type=events_by_resource_type,
            date_range={
                "min": date_row.min_date,
                "max": date_row.max_date
            }
        )

    except Exception as e:
        logger.error(
            "Failed to retrieve audit log statistics",
            error=str(e),
            tenant_id=tenant_id
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to retrieve audit log statistics: {str(e)}"
        )
services/training/app/api/health.py (new file, 261 lines)
@@ -0,0 +1,261 @@
"""
Enhanced Health Check Endpoints
Comprehensive service health monitoring
"""

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import text
from typing import Dict, Any
import psutil
import os
from datetime import datetime, timezone
import logging

from app.core.database import database_manager
from app.utils.circuit_breaker import circuit_breaker_registry
from app.core.config import settings

logger = logging.getLogger(__name__)
router = APIRouter()


async def check_database_health() -> Dict[str, Any]:
    """Check database connectivity and performance"""
    try:
        start_time = datetime.now(timezone.utc)

        async with database_manager.async_engine.begin() as conn:
            # Simple connectivity check
            await conn.execute(text("SELECT 1"))

            # Check if we can access training tables
            result = await conn.execute(
                text("SELECT COUNT(*) FROM trained_models")
            )
            model_count = result.scalar()

            # Check connection pool stats (standard SQLAlchemy pool API)
            pool = database_manager.async_engine.pool
            pool_size = pool.size()
            pool_checked_out = pool.checkedout()

        response_time = (datetime.now(timezone.utc) - start_time).total_seconds()

        return {
            "status": "healthy",
            "response_time_seconds": round(response_time, 3),
            "model_count": model_count,
            "connection_pool": {
                "size": pool_size,
                "checked_out": pool_checked_out,
                "available": pool_size - pool_checked_out
            }
        }

    except Exception as e:
        logger.error(f"Database health check failed: {e}")
        return {
            "status": "unhealthy",
            "error": str(e)
        }


def check_system_resources() -> Dict[str, Any]:
    """Check system resource usage"""
    try:
        cpu_percent = psutil.cpu_percent(interval=0.1)
        memory = psutil.virtual_memory()
        disk = psutil.disk_usage('/')

        return {
            "status": "healthy",
            "cpu": {
                "usage_percent": cpu_percent,
                "count": psutil.cpu_count()
            },
            "memory": {
                "total_mb": round(memory.total / 1024 / 1024, 2),
                "used_mb": round(memory.used / 1024 / 1024, 2),
                "available_mb": round(memory.available / 1024 / 1024, 2),
                "usage_percent": memory.percent
            },
            "disk": {
                "total_gb": round(disk.total / 1024 / 1024 / 1024, 2),
                "used_gb": round(disk.used / 1024 / 1024 / 1024, 2),
                "free_gb": round(disk.free / 1024 / 1024 / 1024, 2),
                "usage_percent": disk.percent
            }
        }

    except Exception as e:
        logger.error(f"System resource check failed: {e}")
        return {
            "status": "error",
            "error": str(e)
        }


def check_model_storage() -> Dict[str, Any]:
    """Check MinIO model storage health"""
    try:
        from shared.clients.minio_client import minio_client

        # Check MinIO connectivity
        if not minio_client.health_check():
            return {
                "status": "unhealthy",
                "message": "MinIO service is not reachable",
                "storage_type": "minio"
            }

        bucket_name = settings.MINIO_MODEL_BUCKET

        # Check if bucket exists
        bucket_exists = minio_client.bucket_exists(bucket_name)
        if not bucket_exists:
            return {
                "status": "warning",
                "message": f"MinIO bucket does not exist: {bucket_name}",
                "storage_type": "minio"
            }

        # Count model files in MinIO
        model_objects = minio_client.list_objects(bucket_name, prefix="models/")
        model_files = [obj for obj in model_objects if obj.endswith('.pkl')]

        return {
            "status": "healthy",
            "storage_type": "minio",
            "endpoint": settings.MINIO_ENDPOINT,
            "bucket": bucket_name,
            "use_ssl": settings.MINIO_USE_SSL,
            "model_files": len(model_files),
            "bucket_exists": bucket_exists
        }

    except Exception as e:
        logger.error(f"MinIO storage check failed: {e}")
        return {
            "status": "error",
            "storage_type": "minio",
            "error": str(e)
        }


@router.get("/health")
async def health_check() -> Dict[str, Any]:
    """
    Basic health check endpoint.
    Returns 200 if service is running.
    """
    return {
        "status": "healthy",
        "service": "training-service",
        "timestamp": datetime.now(timezone.utc).isoformat()
    }


@router.get("/health/detailed")
async def detailed_health_check() -> Dict[str, Any]:
    """
    Detailed health check with component status.
    Includes database, system resources, and dependencies.
    """
    database_health = await check_database_health()
    system_health = check_system_resources()
    storage_health = check_model_storage()
    circuit_breakers = circuit_breaker_registry.get_all_states()

    # Determine overall status
    component_statuses = [
        database_health.get("status"),
        system_health.get("status"),
        storage_health.get("status")
    ]

    if "unhealthy" in component_statuses or "error" in component_statuses:
        overall_status = "unhealthy"
    elif "degraded" in component_statuses or "warning" in component_statuses:
        overall_status = "degraded"
    else:
        overall_status = "healthy"

    return {
        "status": overall_status,
        "service": "training-service",
        "version": "1.0.0",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "components": {
            "database": database_health,
            "system": system_health,
            "storage": storage_health
        },
        "circuit_breakers": circuit_breakers,
        "configuration": {
            "max_concurrent_jobs": settings.MAX_CONCURRENT_TRAINING_JOBS,
            "min_training_days": settings.MIN_TRAINING_DATA_DAYS,
            "pool_size": settings.DB_POOL_SIZE,
            "pool_max_overflow": settings.DB_MAX_OVERFLOW
        }
    }


@router.get("/health/ready")
async def readiness_check() -> Dict[str, Any]:
    """
    Readiness check for Kubernetes.
    Returns 200 only if service is ready to accept traffic.
    """
    database_health = await check_database_health()

    if database_health.get("status") != "healthy":
        raise HTTPException(
            status_code=503,
            detail="Service not ready: database unavailable"
        )

    storage_health = check_model_storage()
    if storage_health.get("status") == "error":
        raise HTTPException(
            status_code=503,
            detail="Service not ready: model storage unavailable"
        )

    return {
        "status": "ready",
        "timestamp": datetime.now(timezone.utc).isoformat()
    }


@router.get("/health/live")
async def liveness_check() -> Dict[str, Any]:
    """
    Liveness check for Kubernetes.
    Returns 200 if service process is alive.
    """
    return {
        "status": "alive",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "pid": os.getpid()
    }


@router.get("/metrics/system")
async def system_metrics() -> Dict[str, Any]:
    """
    Detailed system metrics for monitoring.
    """
    process = psutil.Process(os.getpid())

    return {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "process": {
            "pid": os.getpid(),
            "cpu_percent": process.cpu_percent(interval=0.1),
            "memory_mb": round(process.memory_info().rss / 1024 / 1024, 2),
            "threads": process.num_threads(),
            "open_files": len(process.open_files()),
            "connections": len(process.connections())
        },
        "system": check_system_resources()
    }
services/training/app/api/models.py (new file, 464 lines)
@@ -0,0 +1,464 @@
|
||||
"""
|
||||
Models API endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Optional
|
||||
import structlog
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.schemas.training import TrainedModelResponse, ModelMetricsResponse
|
||||
from app.services.training_service import EnhancedTrainingService
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import select, delete, func
|
||||
import uuid
|
||||
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
require_admin_role
|
||||
)
|
||||
from shared.routing import RouteBuilder
|
||||
from shared.auth.access_control import (
|
||||
require_user_role,
|
||||
admin_role_required,
|
||||
owner_role_required,
|
||||
require_subscription_tier,
|
||||
analytics_tier_required,
|
||||
enterprise_tier_required
|
||||
)
|
||||
|
||||
# Create route builder for consistent URL structure
|
||||
route_builder = RouteBuilder('training')
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
training_service = EnhancedTrainingService()
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("models") + "/{inventory_product_id}/active",
|
||||
response_model=TrainedModelResponse
|
||||
)
|
||||
async def get_active_model(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
inventory_product_id: str = Path(..., description="Inventory product UUID"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get the active model for a product - used by forecasting service
|
||||
"""
|
||||
try:
|
||||
logger.debug("Getting active model", tenant_id=tenant_id, inventory_product_id=inventory_product_id)
|
||||
# ✅ FIX: Wrap SQL with text() for SQLAlchemy 2.0 and add case-insensitive product name matching
|
||||
query = text("""
|
||||
SELECT * FROM trained_models
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND inventory_product_id = :inventory_product_id
|
||||
AND is_active = true
|
||||
AND is_production = true
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
""")
|
||||
|
||||
result = await db.execute(query, {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id
|
||||
})
|
||||
|
||||
model_record = result.fetchone()
|
||||
|
||||
if not model_record:
|
||||
logger.info("No active model found", tenant_id=tenant_id, inventory_product_id=inventory_product_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No active model found for product {inventory_product_id}"
|
||||
)
|
||||
|
||||
# ✅ FIX: Wrap update query with text() too
|
||||
update_query = text("""
|
||||
UPDATE trained_models
|
||||
SET last_used_at = :now
|
||||
WHERE id = :model_id
|
||||
""")
|
||||
|
||||
await db.execute(update_query, {
|
||||
"now": datetime.now(timezone.utc),
|
||||
"model_id": model_record.id
|
||||
})
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"model_id": str(model_record.id),
|
||||
"tenant_id": str(model_record.tenant_id),
|
||||
"inventory_product_id": str(model_record.inventory_product_id),
|
||||
"model_type": model_record.model_type,
|
||||
"model_path": model_record.model_path,
|
||||
"version": 1, # Default version
|
||||
"training_samples": model_record.training_samples or 0,
|
||||
"features": model_record.features_used or [],
|
||||
"hyperparameters": model_record.hyperparameters or {},
|
||||
"training_metrics": {
|
||||
"mape": model_record.mape or 0.0,
|
||||
"mae": model_record.mae or 0.0,
|
||||
"rmse": model_record.rmse or 0.0,
|
||||
"r2_score": model_record.r2_score or 0.0
|
||||
},
|
||||
"is_active": model_record.is_active,
|
||||
"created_at": model_record.created_at,
|
||||
"data_period_start": model_record.training_start_date,
|
||||
"data_period_end": model_record.training_end_date
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
|
||||
logger.error(f"Failed to get active model: {error_msg}", tenant_id=tenant_id, inventory_product_id=inventory_product_id)
|
||||
|
||||
# Handle client disconnection gracefully
|
||||
if "EndOfStream" in str(type(e)) or "WouldBlock" in str(type(e)):
|
||||
logger.info("Client disconnected during model retrieval", tenant_id=tenant_id, inventory_product_id=inventory_product_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_408_REQUEST_TIMEOUT,
|
||||
detail="Request connection closed"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve model"
|
||||
)
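# --- Illustrative usage (not part of the original file) ---
# A minimal sketch of how a consumer such as the forecasting service might call
# the endpoint above. The host, port, URL prefix and the httpx dependency are
# assumptions for illustration only.
async def _example_fetch_active_model(tenant_id: str, inventory_product_id: str) -> dict:
    import httpx  # assumed to be available in the calling service

    url = (
        f"http://training-service:8000/api/v1/training/{tenant_id}"
        f"/models/{inventory_product_id}/active"
    )
    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.get(url)
        response.raise_for_status()
        # Payload mirrors TrainedModelResponse: model_path, training_metrics, etc.
        return response.json()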
|
||||
|
||||
@router.get(
|
||||
route_builder.build_nested_resource_route("models", "model_id", "metrics"),
|
||||
response_model=ModelMetricsResponse
|
||||
)
|
||||
async def get_model_metrics(
|
||||
model_id: str = Path(..., description="Model ID"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get performance metrics for a specific model - used by forecasting service
|
||||
"""
|
||||
try:
|
||||
# Query the model by ID
|
||||
query = text("""
|
||||
SELECT * FROM trained_models
|
||||
WHERE id = :model_id
|
||||
""")
|
||||
|
||||
result = await db.execute(query, {"model_id": model_id})
|
||||
model_record = result.fetchone()
|
||||
|
||||
if not model_record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Model {model_id} not found"
|
||||
)
|
||||
|
||||
# Return metrics in the format expected by forecasting service
|
||||
metrics = {
|
||||
"model_id": str(model_record.id),
|
||||
"accuracy": model_record.r2_score or 0.0, # Use R2 as accuracy measure
|
||||
"mape": model_record.mape or 0.0,
|
||||
"mae": model_record.mae or 0.0,
|
||||
"rmse": model_record.rmse or 0.0,
|
||||
"r2_score": model_record.r2_score or 0.0,
|
||||
"training_samples": model_record.training_samples or 0,
|
||||
"features_used": model_record.features_used or [],
|
||||
"model_type": model_record.model_type,
|
||||
"created_at": model_record.created_at.isoformat() if model_record.created_at else None,
|
||||
"last_used_at": model_record.last_used_at.isoformat() if model_record.last_used_at else None
|
||||
}
|
||||
|
||||
logger.info(f"Retrieved metrics for model {model_id}",
|
||||
mape=metrics["mape"],
|
||||
accuracy=metrics["accuracy"])
|
||||
|
||||
return metrics
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get model metrics: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve model metrics"
|
||||
)
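# --- Illustrative usage (not part of the original file) ---
# A small sketch of how a caller might act on the metrics payload above, e.g. to
# flag a model for retraining. The 25% MAPE threshold is an arbitrary example,
# not a value defined by this service.
def _example_needs_retraining(metrics: dict, mape_threshold: float = 25.0) -> bool:
    mape = metrics.get("mape") or 0.0
    samples = metrics.get("training_samples") or 0
    # Only trust MAPE when the model was trained on at least some samples
    return samples > 0 and mape > mape_threshold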
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("models"),
|
||||
response_model=List[TrainedModelResponse]
|
||||
)
|
||||
async def list_models(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
status_filter: Optional[str] = Query(None, alias="status", description="Filter by status (active/inactive)"),
|
||||
model_type: Optional[str] = Query(None, description="Filter by model type"),
|
||||
limit: int = Query(50, ge=1, le=100, description="Maximum number of models to return"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
List models for a tenant - used by forecasting service for model discovery
|
||||
"""
|
||||
try:
|
||||
# Build query with filters
|
||||
query_parts = ["SELECT * FROM trained_models WHERE tenant_id = :tenant_id"]
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
if status_filter in ("deployed", "active"):
|
||||
query_parts.append("AND is_active = true AND is_production = true")
|
||||
elif status_filter == "inactive":
|
||||
query_parts.append("AND (is_active = false OR is_production = false)")
|
||||
|
||||
if model_type:
|
||||
query_parts.append("AND model_type = :model_type")
|
||||
params["model_type"] = model_type
|
||||
|
||||
query_parts.append("ORDER BY created_at DESC LIMIT :limit")
|
||||
params["limit"] = limit
|
||||
|
||||
query = text(" ".join(query_parts))
|
||||
result = await db.execute(query, params)
|
||||
model_records = result.fetchall()
|
||||
|
||||
models = []
|
||||
for record in model_records:
|
||||
models.append({
|
||||
"model_id": str(record.id),
|
||||
"tenant_id": str(record.tenant_id),
|
||||
"inventory_product_id": str(record.inventory_product_id),
|
||||
"model_type": record.model_type,
|
||||
"model_path": record.model_path,
|
||||
"version": 1, # Default version
|
||||
"training_samples": record.training_samples or 0,
|
||||
"features": record.features_used or [],
|
||||
"hyperparameters": record.hyperparameters or {},
|
||||
"training_metrics": {
|
||||
"mape": record.mape or 0.0,
|
||||
"mae": record.mae or 0.0,
|
||||
"rmse": record.rmse or 0.0,
|
||||
"r2_score": record.r2_score or 0.0
|
||||
},
|
||||
"is_active": record.is_active,
|
||||
"created_at": record.created_at,
|
||||
"data_period_start": record.training_start_date,
|
||||
"data_period_end": record.training_end_date
|
||||
})
|
||||
|
||||
logger.info(f"Retrieved {len(models)} models for tenant {tenant_id}")
|
||||
return models
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list models: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve models"
|
||||
)
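# Example (illustrative): GET .../models?status=active&model_type=prophet&limit=10
# returns up to 10 active production Prophet models for the tenant, newest first.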
|
||||
|
||||
@router.delete("/models/tenant/{tenant_id}")
|
||||
@require_user_role(['admin', 'owner'])
|
||||
async def delete_tenant_models_complete(
|
||||
tenant_id: str,
|
||||
current_user = Depends(get_current_user_dep),
|
||||
_admin_check = Depends(require_admin_role),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Delete all trained models and artifacts for a tenant.
|
||||
|
||||
**WARNING: This operation is irreversible!**
|
||||
|
||||
This endpoint:
|
||||
1. Cancels any active training jobs for the tenant
|
||||
2. Deletes all model artifacts (files) from storage
|
||||
3. Deletes model records from database
|
||||
4. Deletes training logs and performance metrics
|
||||
5. Publishes deletion event
|
||||
|
||||
Used by admin user deletion process to clean up all training data.
|
||||
"""
|
||||
|
||||
try:
|
||||
tenant_uuid = uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid tenant ID format"
|
||||
)
|
||||
|
||||
try:
|
||||
from app.models.training import (
|
||||
ModelTrainingLog,
|
||||
TrainedModel,
|
||||
ModelArtifact,
|
||||
ModelPerformanceMetric,
|
||||
TrainingJobQueue
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
deletion_stats = {
|
||||
"tenant_id": tenant_id,
|
||||
"deleted_at": datetime.now(timezone.utc).isoformat(),
|
||||
"jobs_cancelled": 0,
|
||||
"models_deleted": 0,
|
||||
"artifacts_deleted": 0,
|
||||
"minio_objects_deleted": 0,
|
||||
"training_logs_deleted": 0,
|
||||
"performance_metrics_deleted": 0,
|
||||
"errors": []
|
||||
}
|
||||
|
||||
# Step 1: Cancel active training jobs
|
||||
try:
|
||||
active_jobs_query = select(TrainingJobQueue).where(
|
||||
TrainingJobQueue.tenant_id == tenant_uuid,
|
||||
TrainingJobQueue.status.in_(["queued", "running", "pending"])
|
||||
)
|
||||
active_jobs_result = await db.execute(active_jobs_query)
|
||||
active_jobs = active_jobs_result.scalars().all()
|
||||
|
||||
for job in active_jobs:
|
||||
job.status = "cancelled"
|
||||
job.updated_at = datetime.now(timezone.utc)
|
||||
deletion_stats["jobs_cancelled"] += 1
|
||||
|
||||
if active_jobs:
|
||||
await db.commit()
|
||||
logger.info("Cancelled active training jobs",
|
||||
tenant_id=tenant_id,
|
||||
count=len(active_jobs))
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error cancelling training jobs: {str(e)}"
|
||||
deletion_stats["errors"].append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
# Step 2: Delete model artifact files from MinIO storage
|
||||
try:
|
||||
from shared.clients.minio_client import minio_client
|
||||
|
||||
bucket_name = settings.MINIO_MODEL_BUCKET
|
||||
prefix = f"models/{tenant_id}/"
|
||||
|
||||
# List all objects for this tenant
|
||||
objects_to_delete = minio_client.list_objects(bucket_name, prefix=prefix)
|
||||
|
||||
files_deleted = 0
|
||||
for obj_name in objects_to_delete:
|
||||
try:
|
||||
minio_client.delete_object(bucket_name, obj_name)
|
||||
files_deleted += 1
|
||||
logger.debug("Deleted MinIO object", object_name=obj_name)
|
||||
except Exception as e:
|
||||
error_msg = f"Error deleting MinIO object {obj_name}: {str(e)}"
|
||||
deletion_stats["errors"].append(error_msg)
|
||||
logger.warning(error_msg)
|
||||
|
||||
deletion_stats["minio_objects_deleted"] = files_deleted
|
||||
|
||||
logger.info("Deleted MinIO objects",
|
||||
tenant_id=tenant_id,
|
||||
files_deleted=files_deleted)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error processing MinIO objects: {str(e)}"
|
||||
deletion_stats["errors"].append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
# Step 3: Delete database records
|
||||
try:
|
||||
# Delete model performance metrics
|
||||
metrics_count_query = select(func.count(ModelPerformanceMetric.id)).where(
|
||||
ModelPerformanceMetric.tenant_id == tenant_uuid
|
||||
)
|
||||
metrics_count_result = await db.execute(metrics_count_query)
|
||||
metrics_count = metrics_count_result.scalar()
|
||||
|
||||
metrics_delete_query = delete(ModelPerformanceMetric).where(
|
||||
ModelPerformanceMetric.tenant_id == tenant_uuid
|
||||
)
|
||||
await db.execute(metrics_delete_query)
|
||||
deletion_stats["performance_metrics_deleted"] = metrics_count
|
||||
|
||||
# Delete model artifacts records
|
||||
artifacts_count_query = select(func.count(ModelArtifact.id)).where(
|
||||
ModelArtifact.tenant_id == tenant_uuid
|
||||
)
|
||||
artifacts_count_result = await db.execute(artifacts_count_query)
|
||||
artifacts_count = artifacts_count_result.scalar()
|
||||
|
||||
artifacts_delete_query = delete(ModelArtifact).where(
|
||||
ModelArtifact.tenant_id == tenant_uuid
|
||||
)
|
||||
await db.execute(artifacts_delete_query)
|
||||
deletion_stats["artifacts_deleted"] = artifacts_count
|
||||
|
||||
# Delete trained models
|
||||
models_count_query = select(func.count(TrainedModel.id)).where(
|
||||
TrainedModel.tenant_id == tenant_uuid
|
||||
)
|
||||
models_count_result = await db.execute(models_count_query)
|
||||
models_count = models_count_result.scalar()
|
||||
|
||||
models_delete_query = delete(TrainedModel).where(
|
||||
TrainedModel.tenant_id == tenant_uuid
|
||||
)
|
||||
await db.execute(models_delete_query)
|
||||
deletion_stats["models_deleted"] = models_count
|
||||
|
||||
# Delete training logs
|
||||
logs_count_query = select(func.count(ModelTrainingLog.id)).where(
|
||||
ModelTrainingLog.tenant_id == tenant_uuid
|
||||
)
|
||||
logs_count_result = await db.execute(logs_count_query)
|
||||
logs_count = logs_count_result.scalar()
|
||||
|
||||
logs_delete_query = delete(ModelTrainingLog).where(
|
||||
ModelTrainingLog.tenant_id == tenant_uuid
|
||||
)
|
||||
await db.execute(logs_delete_query)
|
||||
deletion_stats["training_logs_deleted"] = logs_count
|
||||
|
||||
# Delete job queue entries
|
||||
queue_delete_query = delete(TrainingJobQueue).where(
|
||||
TrainingJobQueue.tenant_id == tenant_uuid
|
||||
)
|
||||
await db.execute(queue_delete_query)
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info("Deleted training database records",
|
||||
tenant_id=tenant_id,
|
||||
models=models_count,
|
||||
artifacts=artifacts_count,
|
||||
logs=logs_count,
|
||||
metrics=metrics_count)
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
error_msg = f"Error deleting database records: {str(e)}"
|
||||
deletion_stats["errors"].append(error_msg)
|
||||
logger.error(error_msg)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=error_msg
|
||||
)
|
||||
|
||||
# Step 4: Models deleted successfully (MinIO cleanup already done in Step 2)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"All training data for tenant {tenant_id} deleted successfully",
|
||||
"deletion_details": deletion_stats
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error deleting tenant models",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete tenant models: {str(e)}"
|
||||
)
|
||||
410
services/training/app/api/monitoring.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Monitoring and Observability Endpoints
|
||||
Real-time service monitoring and diagnostics
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from sqlalchemy import text, func
|
||||
import logging
|
||||
|
||||
from app.core.database import database_manager
|
||||
from app.utils.circuit_breaker import circuit_breaker_registry
|
||||
from app.models.training import ModelTrainingLog, TrainingJobQueue, TrainedModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/monitoring/circuit-breakers")
|
||||
async def get_circuit_breaker_status() -> Dict[str, Any]:
|
||||
"""
|
||||
Get status of all circuit breakers.
|
||||
Useful for monitoring external service health.
|
||||
"""
|
||||
breakers = circuit_breaker_registry.get_all_states()
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"circuit_breakers": breakers,
|
||||
"summary": {
|
||||
"total": len(breakers),
|
||||
"open": sum(1 for b in breakers.values() if b["state"] == "open"),
|
||||
"half_open": sum(1 for b in breakers.values() if b["state"] == "half_open"),
|
||||
"closed": sum(1 for b in breakers.values() if b["state"] == "closed")
|
||||
}
|
||||
}
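# --- Illustrative note (not part of the original file) ---
# The endpoint above assumes a registry whose get_all_states() returns a mapping of
# breaker name -> state dict containing at least a "state" key ("open", "half_open"
# or "closed"), for example:
#
#     {
#         "weather-api": {"state": "closed", "failure_count": 0},
#         "minio": {"state": "open", "failure_count": 7},
#     }
#
# The breaker names and the extra keys shown here are hypothetical.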
|
||||
|
||||
|
||||
@router.post("/monitoring/circuit-breakers/{name}/reset")
|
||||
async def reset_circuit_breaker(name: str) -> Dict[str, str]:
|
||||
"""
|
||||
Manually reset a circuit breaker.
|
||||
Use with caution - only reset if you know the service has recovered.
|
||||
"""
|
||||
circuit_breaker_registry.reset(name)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Circuit breaker '{name}' has been reset",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
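# Example (illustrative): POST /monitoring/circuit-breakers/minio/reset
# returns the named breaker to its closed state; the breaker name is hypothetical.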
|
||||
|
||||
|
||||
@router.get("/monitoring/training-jobs")
|
||||
async def get_training_job_stats(
|
||||
hours: int = Query(default=24, ge=1, le=168, description="Look back period in hours")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get training job statistics for the specified period.
|
||||
"""
|
||||
try:
|
||||
since = datetime.now(timezone.utc) - timedelta(hours=hours)
|
||||
|
||||
async with database_manager.get_session() as session:
|
||||
# Get job counts by status
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT status, COUNT(*) as count
|
||||
FROM model_training_logs
|
||||
WHERE created_at >= :since
|
||||
GROUP BY status
|
||||
"""),
|
||||
{"since": since}
|
||||
)
|
||||
status_counts = dict(result.fetchall())
|
||||
|
||||
# Get average training time for completed jobs
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT AVG(EXTRACT(EPOCH FROM (end_time - start_time))) as avg_duration
|
||||
FROM model_training_logs
|
||||
WHERE status = 'completed'
|
||||
AND created_at >= :since
|
||||
AND end_time IS NOT NULL
|
||||
"""),
|
||||
{"since": since}
|
||||
)
|
||||
avg_duration = result.scalar()
|
||||
|
||||
# Get failure rate
|
||||
total = sum(status_counts.values())
|
||||
failed = status_counts.get('failed', 0)
|
||||
failure_rate = (failed / total * 100) if total > 0 else 0
|
||||
|
||||
# Get recent jobs
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT job_id, tenant_id, status, progress, start_time, end_time
|
||||
FROM model_training_logs
|
||||
WHERE created_at >= :since
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 10
|
||||
"""),
|
||||
{"since": since}
|
||||
)
|
||||
recent_jobs = [
|
||||
{
|
||||
"job_id": row.job_id,
|
||||
"tenant_id": str(row.tenant_id),
|
||||
"status": row.status,
|
||||
"progress": row.progress,
|
||||
"start_time": row.start_time.isoformat() if row.start_time else None,
|
||||
"end_time": row.end_time.isoformat() if row.end_time else None
|
||||
}
|
||||
for row in result.fetchall()
|
||||
]
|
||||
|
||||
return {
|
||||
"period_hours": hours,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"summary": {
|
||||
"total_jobs": total,
|
||||
"by_status": status_counts,
|
||||
"failure_rate_percent": round(failure_rate, 2),
|
||||
"avg_duration_seconds": round(avg_duration, 2) if avg_duration else None
|
||||
},
|
||||
"recent_jobs": recent_jobs
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get training job stats: {e}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/monitoring/models")
|
||||
async def get_model_stats() -> Dict[str, Any]:
|
||||
"""
|
||||
Get statistics about trained models.
|
||||
"""
|
||||
try:
|
||||
async with database_manager.get_session() as session:
|
||||
# Total models
|
||||
result = await session.execute(
|
||||
text("SELECT COUNT(*) FROM trained_models")
|
||||
)
|
||||
total_models = result.scalar()
|
||||
|
||||
# Active models
|
||||
result = await session.execute(
|
||||
text("SELECT COUNT(*) FROM trained_models WHERE is_active = true")
|
||||
)
|
||||
active_models = result.scalar()
|
||||
|
||||
# Production models
|
||||
result = await session.execute(
|
||||
text("SELECT COUNT(*) FROM trained_models WHERE is_production = true")
|
||||
)
|
||||
production_models = result.scalar()
|
||||
|
||||
# Models by type
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT model_type, COUNT(*) as count
|
||||
FROM trained_models
|
||||
GROUP BY model_type
|
||||
""")
|
||||
)
|
||||
models_by_type = dict(result.fetchall())
|
||||
|
||||
# Average model performance (MAPE)
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT AVG(mape) as avg_mape
|
||||
FROM trained_models
|
||||
WHERE mape IS NOT NULL
|
||||
AND is_active = true
|
||||
""")
|
||||
)
|
||||
avg_mape = result.scalar()
|
||||
|
||||
# Models created today
|
||||
today = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT COUNT(*) FROM trained_models
|
||||
WHERE created_at >= :today
|
||||
"""),
|
||||
{"today": today}
|
||||
)
|
||||
models_today = result.scalar()
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"summary": {
|
||||
"total_models": total_models,
|
||||
"active_models": active_models,
|
||||
"production_models": production_models,
|
||||
"models_created_today": models_today,
|
||||
"average_mape_percent": round(avg_mape, 2) if avg_mape else None
|
||||
},
|
||||
"by_type": models_by_type
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get model stats: {e}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/monitoring/queue")
|
||||
async def get_queue_status() -> Dict[str, Any]:
|
||||
"""
|
||||
Get training job queue status.
|
||||
"""
|
||||
try:
|
||||
async with database_manager.get_session() as session:
|
||||
# Queued jobs
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT COUNT(*) FROM training_job_queue
|
||||
WHERE status = 'queued'
|
||||
""")
|
||||
)
|
||||
queued = result.scalar()
|
||||
|
||||
# Running jobs
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT COUNT(*) FROM training_job_queue
|
||||
WHERE status = 'running'
|
||||
""")
|
||||
)
|
||||
running = result.scalar()
|
||||
|
||||
# Get oldest queued job
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT created_at FROM training_job_queue
|
||||
WHERE status = 'queued'
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
""")
|
||||
)
|
||||
oldest_queued = result.scalar()
|
||||
|
||||
# Calculate wait time
|
||||
if oldest_queued:
|
||||
wait_time_seconds = (datetime.now(timezone.utc) - oldest_queued).total_seconds()
|
||||
else:
|
||||
wait_time_seconds = 0
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"queue": {
|
||||
"queued": queued,
|
||||
"running": running,
|
||||
"oldest_wait_time_seconds": round(wait_time_seconds, 2) if oldest_queued else 0,
|
||||
"oldest_queued_at": oldest_queued.isoformat() if oldest_queued else None
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get queue status: {e}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/monitoring/performance")
|
||||
async def get_performance_metrics(
|
||||
tenant_id: Optional[str] = Query(None, description="Filter by tenant ID")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get model performance metrics.
|
||||
"""
|
||||
try:
|
||||
async with database_manager.get_session() as session:
|
||||
query_params = {}
|
||||
where_clause = ""
|
||||
|
||||
if tenant_id:
|
||||
where_clause = "WHERE tenant_id = :tenant_id"
|
||||
query_params["tenant_id"] = tenant_id
|
||||
|
||||
# Get performance distribution
|
||||
result = await session.execute(
|
||||
text(f"""
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
AVG(mape) as avg_mape,
|
||||
MIN(mape) as min_mape,
|
||||
MAX(mape) as max_mape,
|
||||
PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY mape) as median_mape,
|
||||
AVG(mae) as avg_mae,
|
||||
AVG(rmse) as avg_rmse
|
||||
FROM model_performance_metrics
|
||||
{where_clause}
|
||||
"""),
|
||||
query_params
|
||||
)
|
||||
stats = result.fetchone()
|
||||
|
||||
# Get accuracy distribution (buckets)
|
||||
result = await session.execute(
|
||||
text(f"""
|
||||
SELECT
|
||||
CASE
|
||||
WHEN mape <= 10 THEN 'excellent'
|
||||
WHEN mape <= 20 THEN 'good'
|
||||
WHEN mape <= 30 THEN 'acceptable'
|
||||
ELSE 'poor'
|
||||
END as accuracy_category,
|
||||
COUNT(*) as count
|
||||
FROM model_performance_metrics
|
||||
{where_clause}
|
||||
GROUP BY accuracy_category
|
||||
"""),
|
||||
query_params
|
||||
)
|
||||
distribution = dict(result.fetchall())
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"tenant_id": tenant_id,
|
||||
"statistics": {
|
||||
"total_metrics": stats.total if stats else 0,
|
||||
"avg_mape_percent": round(stats.avg_mape, 2) if stats and stats.avg_mape else None,
|
||||
"min_mape_percent": round(stats.min_mape, 2) if stats and stats.min_mape else None,
|
||||
"max_mape_percent": round(stats.max_mape, 2) if stats and stats.max_mape else None,
|
||||
"median_mape_percent": round(stats.median_mape, 2) if stats and stats.median_mape else None,
|
||||
"avg_mae": round(stats.avg_mae, 2) if stats and stats.avg_mae else None,
|
||||
"avg_rmse": round(stats.avg_rmse, 2) if stats and stats.avg_rmse else None
|
||||
},
|
||||
"distribution": distribution
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get performance metrics: {e}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/monitoring/alerts")
|
||||
async def get_alerts() -> Dict[str, Any]:
|
||||
"""
|
||||
Get active alerts and warnings based on system state.
|
||||
"""
|
||||
alerts = []
|
||||
warnings = []
|
||||
|
||||
try:
|
||||
# Check circuit breakers
|
||||
breakers = circuit_breaker_registry.get_all_states()
|
||||
for name, state in breakers.items():
|
||||
if state["state"] == "open":
|
||||
alerts.append({
|
||||
"type": "circuit_breaker_open",
|
||||
"severity": "high",
|
||||
"message": f"Circuit breaker '{name}' is OPEN - service unavailable",
|
||||
"details": state
|
||||
})
|
||||
elif state["state"] == "half_open":
|
||||
warnings.append({
|
||||
"type": "circuit_breaker_recovering",
|
||||
"severity": "medium",
|
||||
"message": f"Circuit breaker '{name}' is recovering",
|
||||
"details": state
|
||||
})
|
||||
|
||||
# Check queue backlog
|
||||
async with database_manager.get_session() as session:
|
||||
result = await session.execute(
|
||||
text("SELECT COUNT(*) FROM training_job_queue WHERE status = 'queued'")
|
||||
)
|
||||
queued = result.scalar()
|
||||
|
||||
if queued > 10:
|
||||
warnings.append({
|
||||
"type": "queue_backlog",
|
||||
"severity": "medium",
|
||||
"message": f"Training queue has {queued} pending jobs",
|
||||
"count": queued
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate alerts: {e}")
|
||||
alerts.append({
|
||||
"type": "monitoring_error",
|
||||
"severity": "high",
|
||||
"message": f"Failed to check system alerts: {str(e)}"
|
||||
})
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"summary": {
|
||||
"total_alerts": len(alerts),
|
||||
"total_warnings": len(warnings)
|
||||
},
|
||||
"alerts": alerts,
|
||||
"warnings": warnings
|
||||
}
|
||||
123
services/training/app/api/training_jobs.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Training Jobs API - ATOMIC CRUD operations
|
||||
Handles basic training job creation and retrieval
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query, Request
|
||||
from typing import List, Optional
|
||||
import structlog
|
||||
from shared.routing import RouteBuilder
|
||||
from shared.monitoring.decorators import track_execution_time
|
||||
from shared.monitoring.metrics import get_metrics_collector
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from app.services.training_service import EnhancedTrainingService
|
||||
from app.schemas.training import TrainingJobResponse
|
||||
from shared.database.base import create_database_manager
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
route_builder = RouteBuilder('training')
|
||||
|
||||
router = APIRouter(tags=["training-jobs"])
|
||||
|
||||
def get_enhanced_training_service():
|
||||
"""Dependency injection for EnhancedTrainingService"""
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
|
||||
return EnhancedTrainingService(database_manager)
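# --- Illustrative note (not part of the original file) ---
# Because the service is injected via Depends(get_enhanced_training_service), tests
# can replace it with FastAPI's dependency_overrides. A minimal sketch, assuming the
# FastAPI app instance is importable as app.main.app and FakeTrainingService is a
# hypothetical test double:
#
#     from app.main import app
#
#     app.dependency_overrides[get_enhanced_training_service] = lambda: FakeTrainingService()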
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_nested_resource_route("jobs", "job_id", "status")
|
||||
)
|
||||
@track_execution_time("enhanced_job_status_duration_seconds", "training-service")
|
||||
async def get_training_job_status(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
job_id: str = Path(..., description="Job ID"),
|
||||
request_obj: Request = None,
|
||||
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
|
||||
):
|
||||
"""
|
||||
Get training job status using repository pattern.
|
||||
"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
# Get status using enhanced service
|
||||
status_info = await enhanced_training_service.get_training_status(job_id)
|
||||
|
||||
if not status_info or status_info.get("error"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Training job not found"
|
||||
)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_status_requests_total")
|
||||
|
||||
return {
|
||||
**status_info,
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_status_errors_total")
|
||||
logger.error("Failed to get training status",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get training status"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("statistics")
|
||||
)
|
||||
@track_execution_time("enhanced_tenant_statistics_duration_seconds", "training-service")
|
||||
async def get_tenant_statistics(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
request_obj: Request = None,
|
||||
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
|
||||
):
|
||||
"""
|
||||
Get comprehensive tenant statistics using repository pattern.
|
||||
"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
# Get statistics using enhanced service
|
||||
statistics = await enhanced_training_service.get_tenant_statistics(tenant_id)
|
||||
|
||||
if statistics.get("error"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=statistics["error"]
|
||||
)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_statistics_requests_total")
|
||||
|
||||
return {
|
||||
**statistics,
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_statistics_errors_total")
|
||||
logger.error("Failed to get tenant statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get tenant statistics"
|
||||
)
|
||||
821
services/training/app/api/training_operations.py
Normal file
@@ -0,0 +1,821 @@
|
||||
"""
|
||||
Training Operations API - BUSINESS logic
|
||||
Handles training job execution and metrics
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path
|
||||
from typing import Optional, Dict, Any
|
||||
import structlog
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
import shared.redis_utils
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from shared.routing import RouteBuilder
|
||||
from shared.monitoring.decorators import track_execution_time
|
||||
from shared.monitoring.metrics import get_metrics_collector
|
||||
from shared.database.base import create_database_manager
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role, admin_role_required, service_only_access
|
||||
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
|
||||
from shared.subscription.plans import (
|
||||
get_training_job_quota,
|
||||
get_dataset_size_limit
|
||||
)
|
||||
|
||||
from app.services.training_service import EnhancedTrainingService
|
||||
from app.schemas.training import (
|
||||
TrainingJobRequest,
|
||||
SingleProductTrainingRequest,
|
||||
TrainingJobResponse
|
||||
)
|
||||
from app.utils.time_estimation import (
|
||||
calculate_initial_estimate,
|
||||
calculate_estimated_completion_time,
|
||||
get_historical_average_estimate
|
||||
)
|
||||
from app.services.training_events import (
|
||||
publish_training_started,
|
||||
publish_training_completed,
|
||||
publish_training_failed
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.models import AuditLog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
route_builder = RouteBuilder('training')
|
||||
|
||||
router = APIRouter(tags=["training-operations"])
|
||||
|
||||
# Initialize audit logger
|
||||
audit_logger = create_audit_logger("training-service", AuditLog)
|
||||
|
||||
# Redis client for rate limiting
|
||||
_redis_client = None
|
||||
|
||||
async def get_training_redis_client():
|
||||
"""Get or create Redis client for rate limiting"""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
# Initialize Redis if not already done
|
||||
try:
|
||||
from app.core.config import settings
|
||||
_redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
|
||||
except Exception:
|
||||
# Fallback to getting the client directly (if already initialized elsewhere)
|
||||
_redis_client = await shared.redis_utils.get_redis_client()
|
||||
return _redis_client
|
||||
|
||||
async def get_rate_limiter():
|
||||
"""Dependency for rate limiter"""
|
||||
redis_client = await get_training_redis_client()
|
||||
return create_rate_limiter(redis_client)
|
||||
|
||||
def get_enhanced_training_service():
|
||||
"""Dependency injection for EnhancedTrainingService"""
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
|
||||
return EnhancedTrainingService(database_manager)
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
|
||||
@require_user_role(['admin', 'owner'])
|
||||
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
|
||||
async def start_training_job(
|
||||
request: TrainingJobRequest,
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
request_obj: Request = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service),
|
||||
rate_limiter = Depends(get_rate_limiter),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Start a new training job for all tenant products (Admin+ only, quota enforced).
|
||||
|
||||
**RBAC:** Admin or Owner role required
|
||||
**Quotas:**
|
||||
- Starter: 1 training job/day, max 1,000 rows
|
||||
- Professional: 5 training jobs/day, max 10,000 rows
|
||||
- Enterprise: Unlimited jobs, unlimited rows
|
||||
|
||||
Enhanced immediate response pattern:
|
||||
1. Validate subscription tier and quotas
|
||||
2. Validate request with enhanced validation
|
||||
3. Create job record using repository pattern
|
||||
4. Return 200 with enhanced job details
|
||||
5. Execute enhanced training in background with repository tracking
|
||||
|
||||
Enhanced features:
|
||||
- Repository pattern for data access
|
||||
- Quota enforcement by subscription tier
|
||||
- Audit logging for all operations
|
||||
- Enhanced error handling and logging
|
||||
- Metrics tracking and monitoring
|
||||
- Transactional operations
|
||||
"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
# Get subscription tier and enforce quotas
|
||||
tier = current_user.get('subscription_tier', 'starter')
|
||||
|
||||
# Estimate the dataset size; ideally this comes from the request itself.
# Fall back to a conservative default when the request does not provide it.
|
||||
estimated_dataset_size = request.estimated_rows if hasattr(request, 'estimated_rows') else 500
|
||||
|
||||
# Initialize variables for later use
|
||||
quota_result = None
|
||||
quota_limit = None
|
||||
|
||||
try:
|
||||
# Validate dataset size limits
|
||||
await rate_limiter.validate_dataset_size(
|
||||
tenant_id, estimated_dataset_size, tier
|
||||
)
|
||||
|
||||
# Check daily training job quota
|
||||
quota_limit = get_training_job_quota(tier)
|
||||
quota_result = await rate_limiter.check_and_increment_quota(
|
||||
tenant_id,
|
||||
"training_jobs",
|
||||
quota_limit,
|
||||
period=86400 # 24 hours
|
||||
)
|
||||
|
||||
logger.info("Training job quota check passed",
|
||||
tenant_id=tenant_id,
|
||||
tier=tier,
|
||||
current_usage=quota_result.get('current', 0) if quota_result else 0,
|
||||
limit=quota_limit)
|
||||
|
||||
except HTTPException:
|
||||
# Quota or validation error - re-raise
|
||||
raise
|
||||
except Exception as quota_error:
|
||||
logger.error("Quota validation failed", error=str(quota_error))
|
||||
# Continue with job creation but log the error
|
||||
|
||||
try:
|
||||
# Check for existing running jobs before starting a new one;
# this prevents duplicate tenant-level training jobs
|
||||
async with enhanced_training_service.database_manager.get_session() as check_session:
|
||||
from app.repositories.training_log_repository import TrainingLogRepository
|
||||
log_repo = TrainingLogRepository(check_session)
|
||||
|
||||
# Check for active jobs (running or pending)
|
||||
active_jobs = await log_repo.get_active_jobs(tenant_id=tenant_id)
|
||||
pending_jobs = await log_repo.get_logs_by_tenant(
|
||||
tenant_id=tenant_id,
|
||||
status="pending",
|
||||
limit=10
|
||||
)
|
||||
|
||||
all_active = active_jobs + pending_jobs
|
||||
|
||||
if all_active:
|
||||
# Training job already in progress, return existing job info
|
||||
existing_job = all_active[0]
|
||||
logger.info("Training job already in progress, returning existing job",
|
||||
existing_job_id=existing_job.job_id,
|
||||
tenant_id=tenant_id,
|
||||
status=existing_job.status)
|
||||
|
||||
return TrainingJobResponse(
|
||||
job_id=existing_job.job_id,
|
||||
tenant_id=tenant_id,
|
||||
status=existing_job.status,
|
||||
message=f"Training job already in progress (started {existing_job.created_at.isoformat() if existing_job.created_at else 'recently'})",
|
||||
created_at=existing_job.created_at or datetime.now(timezone.utc),
|
||||
estimated_duration_minutes=existing_job.config.get("estimated_duration_minutes", 15) if existing_job.config else 15,
|
||||
training_results={
|
||||
"total_products": 0,
|
||||
"successful_trainings": 0,
|
||||
"failed_trainings": 0,
|
||||
"products": [],
|
||||
"overall_training_time_seconds": 0.0
|
||||
},
|
||||
data_summary=None,
|
||||
completed_at=None,
|
||||
error_details=None,
|
||||
processing_metadata={
|
||||
"background_task": True,
|
||||
"async_execution": True,
|
||||
"existing_job": True,
|
||||
"deduplication": True
|
||||
}
|
||||
)
|
||||
|
||||
# No existing job, proceed with creating new one
|
||||
# Generate enhanced job ID
|
||||
job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
logger.info("Creating enhanced training job using repository pattern",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
# Record job creation metrics
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_training_jobs_created_total")
|
||||
|
||||
# Calculate intelligent time estimate
|
||||
# We don't know exact product count yet, so use historical average or estimate
|
||||
try:
|
||||
# Try to get historical average for this tenant
|
||||
historical_avg = await get_historical_average_estimate(db, tenant_id)
|
||||
|
||||
# If no historical data, estimate based on typical product count (10-20 products)
|
||||
estimated_products = 15 # Conservative estimate
|
||||
estimated_duration_minutes = calculate_initial_estimate(
|
||||
total_products=estimated_products,
|
||||
avg_training_time_per_product=historical_avg if historical_avg else 60.0
|
||||
)
|
||||
except Exception as est_error:
|
||||
logger.warning("Could not calculate intelligent estimate, using default",
|
||||
error=str(est_error))
|
||||
estimated_duration_minutes = 15 # Default fallback
|
||||
|
||||
# Calculate estimated completion time
|
||||
estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
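# Example of the intended arithmetic (assuming calculate_initial_estimate simply
# multiplies products by the per-product time): 15 products x 60 s = 900 s, i.e.
# roughly 15 minutes, matching the fallback used in the except branch above.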
|
||||
|
||||
# Note: training.started event will be published by the trainer with accurate product count
|
||||
# We don't publish here to avoid duplicate events
|
||||
|
||||
# Add enhanced background task
|
||||
background_tasks.add_task(
|
||||
execute_training_job_background,
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id,
|
||||
bakery_location=(40.4168, -3.7038),
|
||||
requested_start=request.start_date,
|
||||
requested_end=request.end_date,
|
||||
estimated_duration_minutes=estimated_duration_minutes
|
||||
)
|
||||
|
||||
# Return enhanced immediate success response
|
||||
response_data = {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"status": "pending",
|
||||
"message": "Enhanced training job started successfully using repository pattern",
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"estimated_duration_minutes": estimated_duration_minutes,
|
||||
"training_results": {
|
||||
"total_products": 0,
|
||||
"successful_trainings": 0,
|
||||
"failed_trainings": 0,
|
||||
"products": [],
|
||||
"overall_training_time_seconds": 0.0
|
||||
},
|
||||
"data_summary": None,
|
||||
"completed_at": None,
|
||||
"error_details": None,
|
||||
"processing_metadata": {
|
||||
"background_task": True,
|
||||
"async_execution": True,
|
||||
"enhanced_features": True,
|
||||
"repository_pattern": True,
|
||||
"dependency_injection": True
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("Enhanced training job queued successfully",
|
||||
job_id=job_id,
|
||||
features=["repository-pattern", "dependency-injection", "enhanced-tracking"])
|
||||
|
||||
# Log audit event for training job creation
|
||||
try:
|
||||
from app.core.database import database_manager
|
||||
async with database_manager.get_session() as db:
|
||||
await audit_logger.log_event(
|
||||
db_session=db,
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"],
|
||||
action=AuditAction.CREATE.value,
|
||||
resource_type="training_job",
|
||||
resource_id=job_id,
|
||||
severity=AuditSeverity.MEDIUM.value,
|
||||
description=f"Started training job (tier: {tier})",
|
||||
audit_metadata={
|
||||
"job_id": job_id,
|
||||
"tier": tier,
|
||||
"estimated_dataset_size": estimated_dataset_size,
|
||||
"quota_usage": quota_result.get('current', 0) if quota_result else 0,
|
||||
"quota_limit": quota_limit if quota_limit else "unlimited"
|
||||
},
|
||||
endpoint="/jobs",
|
||||
method="POST"
|
||||
)
|
||||
except Exception as audit_error:
|
||||
logger.warning("Failed to log audit event", error=str(audit_error))
|
||||
|
||||
return TrainingJobResponse(**response_data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_training_validation_errors_total")
|
||||
logger.error("Enhanced training job validation error",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_training_job_errors_total")
|
||||
logger.error("Failed to queue enhanced training job",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to start enhanced training job"
|
||||
)
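# --- Illustrative usage (not part of the original file) ---
# The endpoint above returns immediately with a pending job; callers are expected to
# poll the job-status endpoint until the job finishes. A minimal client-side sketch,
# assuming httpx and the /jobs/{job_id}/status route exposed by training_jobs.py
# (host, port and URL prefix are assumptions):
async def _example_submit_and_poll(tenant_id: str, payload: dict, poll_seconds: float = 10.0) -> dict:
    import asyncio
    import httpx  # assumed to be available in the calling service

    base = f"http://training-service:8000/api/v1/training/{tenant_id}"
    async with httpx.AsyncClient(timeout=30.0) as client:
        job = (await client.post(f"{base}/jobs", json=payload)).json()
        job_id = job["job_id"]
        while True:
            state = (await client.get(f"{base}/jobs/{job_id}/status")).json()
            if state.get("status") in ("completed", "failed", "cancelled"):
                return state
            await asyncio.sleep(poll_seconds)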
|
||||
|
||||
|
||||
async def execute_training_job_background(
|
||||
tenant_id: str,
|
||||
job_id: str,
|
||||
bakery_location: tuple,
|
||||
requested_start: Optional[datetime] = None,
|
||||
requested_end: Optional[datetime] = None,
|
||||
estimated_duration_minutes: int = 15
|
||||
):
|
||||
"""
|
||||
Enhanced background task that executes the training job using repository pattern.
|
||||
|
||||
Enhanced features:
|
||||
- Repository pattern for all data operations
|
||||
- Enhanced error handling with structured logging
|
||||
- Transactional operations for data consistency
|
||||
- Comprehensive metrics tracking
|
||||
- Database connection pooling
|
||||
- Enhanced progress reporting
|
||||
"""
|
||||
|
||||
logger.info("Enhanced background training job started",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
features=["repository-pattern", "enhanced-tracking"])
|
||||
|
||||
# Get enhanced training service with dependency injection
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
|
||||
enhanced_training_service = EnhancedTrainingService(database_manager)
|
||||
|
||||
try:
|
||||
# Create initial training log entry first
|
||||
await enhanced_training_service._update_job_status_repository(
|
||||
job_id=job_id,
|
||||
status="pending",
|
||||
progress=0,
|
||||
current_step="Starting enhanced training job",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# The training.started event is published by the training service itself
# when it begins execution, so it is not emitted here
|
||||
|
||||
training_config = {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"bakery_location": {
|
||||
"latitude": bakery_location[0],
|
||||
"longitude": bakery_location[1]
|
||||
},
|
||||
"requested_start": requested_start.isoformat() if requested_start else None,
|
||||
"requested_end": requested_end.isoformat() if requested_end else None,
|
||||
"estimated_duration_minutes": estimated_duration_minutes,
|
||||
"background_execution": True,
|
||||
"enhanced_features": True,
|
||||
"repository_pattern": True,
|
||||
"api_version": "enhanced_v1"
|
||||
}
|
||||
|
||||
# Update job status using repository pattern
|
||||
await enhanced_training_service._update_job_status_repository(
|
||||
job_id=job_id,
|
||||
status="running",
|
||||
progress=0,
|
||||
current_step="Initializing enhanced training pipeline",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Execute the enhanced training pipeline with repository pattern
|
||||
result = await enhanced_training_service.start_training_job(
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id,
|
||||
bakery_location=bakery_location,
|
||||
requested_start=requested_start,
|
||||
requested_end=requested_end
|
||||
)
|
||||
|
||||
# Note: Final status is already updated by start_training_job() via complete_training_log()
|
||||
# No need for redundant update here - it was causing duplicate log entries
|
||||
|
||||
# Completion event is published by the training service
|
||||
|
||||
logger.info("Enhanced background training job completed successfully",
|
||||
job_id=job_id,
|
||||
models_created=result.get('products_trained', 0),
|
||||
features=["repository-pattern", "enhanced-tracking"])
|
||||
|
||||
except Exception as training_error:
|
||||
logger.error("Enhanced training pipeline failed",
|
||||
job_id=job_id,
|
||||
error=str(training_error))
|
||||
|
||||
try:
|
||||
await enhanced_training_service._update_job_status_repository(
|
||||
job_id=job_id,
|
||||
status="failed",
|
||||
progress=0,
|
||||
current_step="Enhanced training failed",
|
||||
error_message=str(training_error),
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
except Exception as status_error:
|
||||
logger.error("Failed to update job status after training error",
|
||||
job_id=job_id,
|
||||
status_error=str(status_error))
|
||||
|
||||
# Publish the failure event here as well, so subscribers are notified even when
# the training service fails before it can publish one itself
await publish_training_failed(job_id, tenant_id, str(training_error))
|
||||
|
||||
finally:
|
||||
logger.info("Enhanced background training job cleanup completed",
|
||||
job_id=job_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
|
||||
@require_user_role(['admin', 'owner'])
|
||||
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
|
||||
async def start_single_product_training(
|
||||
request: SingleProductTrainingRequest,
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
inventory_product_id: str = Path(..., description="Inventory product UUID"),
|
||||
background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
request_obj: Request = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
|
||||
):
|
||||
"""
|
||||
Start enhanced training for a single product (Admin+ only).
|
||||
|
||||
**RBAC:** Admin or Owner role required
|
||||
|
||||
Enhanced features:
|
||||
- Repository pattern for data access
|
||||
- Enhanced error handling and validation
|
||||
- Metrics tracking
|
||||
- Transactional operations
|
||||
- Background execution to prevent blocking
|
||||
"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
logger.info("Starting enhanced single product training",
|
||||
inventory_product_id=inventory_product_id,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
# Check whether this product is already being trained;
# this prevents duplicate training triggered by rapid repeated clicks
|
||||
async with enhanced_training_service.database_manager.get_session() as check_session:
|
||||
from app.repositories.training_log_repository import TrainingLogRepository
|
||||
log_repo = TrainingLogRepository(check_session)
|
||||
|
||||
# Check for active jobs for this specific product
|
||||
active_jobs = await log_repo.get_active_jobs(tenant_id=tenant_id)
|
||||
pending_jobs = await log_repo.get_logs_by_tenant(
|
||||
tenant_id=tenant_id,
|
||||
status="pending",
|
||||
limit=20
|
||||
)
|
||||
|
||||
all_active = active_jobs + pending_jobs
|
||||
|
||||
# Filter for jobs that include this specific product
|
||||
product_jobs = [
|
||||
job for job in all_active
|
||||
if job.config and (
|
||||
# Single product job for this product
|
||||
job.config.get("product_id") == inventory_product_id or
|
||||
# Tenant-wide job that would include this product
|
||||
job.config.get("job_type") == "tenant_training"
|
||||
)
|
||||
]
|
||||
|
||||
if product_jobs:
|
||||
existing_job = product_jobs[0]
|
||||
logger.warning("Product training already in progress, rejecting duplicate request",
|
||||
existing_job_id=existing_job.job_id,
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
status=existing_job.status)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail={
|
||||
"error": "Product training already in progress",
|
||||
"message": f"Product {inventory_product_id} is currently being trained in job {existing_job.job_id}",
|
||||
"existing_job_id": existing_job.job_id,
|
||||
"status": existing_job.status,
|
||||
"started_at": existing_job.created_at.isoformat() if existing_job.created_at else None
|
||||
}
|
||||
)
|
||||
|
||||
# No existing job, proceed with training
|
||||
# Record metrics
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_single_product_training_total")
|
||||
|
||||
# Generate enhanced job ID
|
||||
job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Create the initial training log entry so the job is visible to status queries immediately
|
||||
await enhanced_training_service._update_job_status_repository(
|
||||
job_id=job_id,
|
||||
status="pending",
|
||||
progress=0,
|
||||
current_step="Initializing single product training",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Add enhanced background task for single product training
|
||||
background_tasks.add_task(
|
||||
execute_single_product_training_background,
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
job_id=job_id,
|
||||
bakery_location=request.bakery_location or (40.4168, -3.7038),
|
||||
database_manager=enhanced_training_service.database_manager
|
||||
)
|
||||
|
||||
# Return immediate response with job info
|
||||
response_data = {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"status": "pending",
|
||||
"message": "Enhanced single product training started successfully",
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"estimated_duration_minutes": 15, # Default estimate for single product
|
||||
"training_results": {
|
||||
"total_products": 1,
|
||||
"successful_trainings": 0,
|
||||
"failed_trainings": 0,
|
||||
"products": [],
|
||||
"overall_training_time_seconds": 0.0
|
||||
},
|
||||
"data_summary": None,
|
||||
"completed_at": None,
|
||||
"error_details": None,
|
||||
"processing_metadata": {
|
||||
"background_task": True,
|
||||
"async_execution": True,
|
||||
"enhanced_features": True,
|
||||
"repository_pattern": True,
|
||||
"dependency_injection": True
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("Enhanced single product training queued successfully",
|
||||
inventory_product_id=inventory_product_id,
|
||||
job_id=job_id)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_single_product_training_queued_total")
|
||||
|
||||
return TrainingJobResponse(**response_data)
|
||||
|
||||
except ValueError as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_single_product_validation_errors_total")
|
||||
logger.error("Enhanced single product training validation error",
|
||||
error=str(e),
|
||||
inventory_product_id=inventory_product_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_single_product_training_errors_total")
|
||||
logger.error("Enhanced single product training failed",
|
||||
error=str(e),
|
||||
inventory_product_id=inventory_product_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Enhanced single product training failed"
|
||||
)
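# --- Illustrative note (not part of the original file) ---
# Clients of the endpoint above should treat HTTP 409 as "already training" rather
# than a failure: the response detail carries existing_job_id, which can be used to
# poll the job that is already running instead of retrying the POST.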
|
||||
|
||||
|
||||
async def execute_single_product_training_background(
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
job_id: str,
|
||||
bakery_location: tuple,
|
||||
database_manager
|
||||
):
|
||||
"""
|
||||
Enhanced background task that executes single product training using repository pattern.
|
||||
Uses a separate service instance to avoid session conflicts.
|
||||
"""
|
||||
logger.info("Enhanced background single product training started",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id)
|
||||
|
||||
# Create a new service instance with a fresh database session to avoid conflicts
|
||||
from app.services.training_service import EnhancedTrainingService
|
||||
fresh_training_service = EnhancedTrainingService(database_manager)
|
||||
|
||||
try:
|
||||
# Update job status to running
|
||||
await fresh_training_service._update_job_status_repository(
|
||||
job_id=job_id,
|
||||
status="running",
|
||||
progress=0,
|
||||
current_step="Starting single product training",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Execute the enhanced single product training with repository pattern
|
||||
result = await fresh_training_service.start_single_product_training(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
job_id=job_id,
|
||||
bakery_location=bakery_location
|
||||
)
|
||||
|
||||
logger.info("Enhanced background single product training completed successfully",
|
||||
job_id=job_id,
|
||||
inventory_product_id=inventory_product_id)
|
||||
|
||||
except Exception as training_error:
|
||||
logger.error("Enhanced single product training failed",
|
||||
job_id=job_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(training_error))
|
||||
|
||||
try:
|
||||
await fresh_training_service._update_job_status_repository(
|
||||
job_id=job_id,
|
||||
status="failed",
|
||||
progress=0,
|
||||
current_step="Single product training failed",
|
||||
error_message=str(training_error),
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
except Exception as status_error:
|
||||
logger.error("Failed to update job status after training error",
|
||||
job_id=job_id,
|
||||
status_error=str(status_error))
|
||||
|
||||
finally:
|
||||
logger.info("Enhanced background single product training cleanup completed",
|
||||
job_id=job_id,
|
||||
inventory_product_id=inventory_product_id)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint for the training operations"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "training-operations",
|
||||
"version": "3.0.0",
|
||||
"features": [
|
||||
"repository-pattern",
|
||||
"dependency-injection",
|
||||
"enhanced-error-handling",
|
||||
"metrics-tracking",
|
||||
"transactional-operations"
|
||||
],
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tenant Data Deletion Operations (Internal Service Only)
|
||||
# ============================================================================
|
||||
|
||||
@router.delete(
|
||||
route_builder.build_base_route("tenant/{tenant_id}", include_tenant_prefix=False),
|
||||
response_model=dict
|
||||
)
|
||||
@service_only_access
|
||||
async def delete_tenant_data(
|
||||
tenant_id: str = Path(..., description="Tenant ID to delete data for"),
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
):
|
||||
"""
|
||||
Delete all training data for a tenant (Internal service only)
|
||||
|
||||
This endpoint is called by the orchestrator during tenant deletion.
|
||||
It permanently deletes all training-related data including:
|
||||
- Trained models (all versions)
|
||||
- Model artifacts (files and metadata)
|
||||
- Training logs and job history
|
||||
- Model performance metrics
|
||||
- Training job queue entries
|
||||
- Audit logs
|
||||
|
||||
**WARNING**: This operation is irreversible!
|
||||
**NOTE**: Physical model files (.pkl) should be cleaned up separately
|
||||
|
||||
Returns:
|
||||
Deletion summary with counts of deleted records
|
||||
"""
|
||||
from app.services.tenant_deletion_service import TrainingTenantDeletionService
|
||||
from app.core.config import settings
|
||||
|
||||
try:
|
||||
logger.info("training.tenant_deletion.api_called", tenant_id=tenant_id)
|
||||
|
||||
db_manager = create_database_manager(settings.DATABASE_URL, "training")
|
||||
|
||||
async with db_manager.get_session() as session:
|
||||
deletion_service = TrainingTenantDeletionService(session)
|
||||
result = await deletion_service.safe_delete_tenant_data(tenant_id)
|
||||
|
||||
if not result.success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Tenant data deletion failed: {', '.join(result.errors)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Tenant data deletion completed successfully",
|
||||
"note": "Physical model files should be cleaned up separately from storage",
|
||||
"summary": result.to_dict()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("training.tenant_deletion.api_error",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete tenant data: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("tenant/{tenant_id}/deletion-preview", include_tenant_prefix=False),
|
||||
response_model=dict
|
||||
)
|
||||
@service_only_access
|
||||
async def preview_tenant_data_deletion(
|
||||
tenant_id: str = Path(..., description="Tenant ID to preview deletion for"),
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
):
|
||||
"""
|
||||
Preview what data would be deleted for a tenant (dry-run)
|
||||
|
||||
This endpoint shows counts of all data that would be deleted
|
||||
without actually deleting anything. Useful for:
|
||||
- Confirming deletion scope before execution
|
||||
- Auditing and compliance
|
||||
- Troubleshooting
|
||||
|
||||
Returns:
|
||||
Dictionary with entity names and their counts
|
||||
"""
|
||||
from app.services.tenant_deletion_service import TrainingTenantDeletionService
|
||||
from app.core.config import settings
|
||||
|
||||
try:
|
||||
logger.info("training.tenant_deletion.preview_called", tenant_id=tenant_id)
|
||||
|
||||
db_manager = create_database_manager(settings.DATABASE_URL, "training")
|
||||
|
||||
async with db_manager.get_session() as session:
|
||||
deletion_service = TrainingTenantDeletionService(session)
|
||||
preview = await deletion_service.get_tenant_data_preview(tenant_id)
|
||||
|
||||
total_records = sum(preview.values())
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"service": "training",
|
||||
"preview": preview,
|
||||
"total_records": total_records,
|
||||
"note": "Physical model files (.pkl, metadata) are not counted here",
|
||||
"warning": "These records will be permanently deleted and cannot be recovered"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("training.tenant_deletion.preview_error",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to preview tenant data deletion: {str(e)}"
|
||||
)
|
||||
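As a rough illustration of how an internal caller (such as the orchestrator) might use the two endpoints above, the sketch below previews the deletion scope before running the irreversible delete. It uses httpx; the in-cluster base URL, the service token, and the resolved paths are assumptions, since the concrete routes come from route_builder and the auth setup is not shown here.

import asyncio
import httpx

TRAINING_BASE_URL = "http://training-service:8000"   # assumed in-cluster address
SERVICE_TOKEN = "<service-to-service JWT>"           # assumed; issued by the platform auth

async def preview_then_delete(tenant_id: str) -> None:
    headers = {"Authorization": f"Bearer {SERVICE_TOKEN}"}
    async with httpx.AsyncClient(base_url=TRAINING_BASE_URL, timeout=30.0) as client:
        # Dry-run first: see how many records would be removed.
        preview = await client.get(
            f"/api/v1/training/tenant/{tenant_id}/deletion-preview",  # assumed resolved route
            headers=headers,
        )
        preview.raise_for_status()
        print("Records that would be deleted:", preview.json()["total_records"])

        # Irreversible step: only run once the preview has been reviewed.
        deletion = await client.delete(
            f"/api/v1/training/tenant/{tenant_id}",  # assumed resolved route
            headers=headers,
        )
        deletion.raise_for_status()
        print("Deletion summary:", deletion.json()["summary"])

# asyncio.run(preview_then_delete("00000000-0000-0000-0000-000000000000"))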
163
services/training/app/api/websocket_operations.py
Normal file
163
services/training/app/api/websocket_operations.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
WebSocket Operations for Training Service
|
||||
Simple WebSocket endpoint that accepts client connections and relays training progress broadcasts received from RabbitMQ
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query
|
||||
import structlog
|
||||
|
||||
from app.websocket.manager import websocket_manager
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from app.core.config import settings
|
||||
from app.services.training_service import EnhancedTrainingService
|
||||
from shared.database.base import create_database_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter(tags=["websocket"])
|
||||
|
||||
|
||||
def get_enhanced_training_service():
|
||||
"""Create EnhancedTrainingService instance"""
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
|
||||
return EnhancedTrainingService(database_manager)
|
||||
|
||||
|
||||
@router.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
|
||||
async def training_progress_websocket(
|
||||
websocket: WebSocket,
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
job_id: str = Path(..., description="Job ID"),
|
||||
token: str = Query(..., description="Authentication token")
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for real-time training progress updates.
|
||||
|
||||
This endpoint:
|
||||
1. Validates the authentication token
|
||||
2. Accepts the WebSocket connection
|
||||
3. Keeps the connection alive
|
||||
4. Receives broadcasts from RabbitMQ (via WebSocket manager)
|
||||
"""
|
||||
|
||||
# Validate token
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if not payload:
|
||||
await websocket.close(code=1008, reason="Invalid token")
|
||||
logger.warning("WebSocket connection rejected - invalid token",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
return
|
||||
|
||||
user_id = payload.get('user_id')
|
||||
if not user_id:
|
||||
await websocket.close(code=1008, reason="Invalid token payload")
|
||||
logger.warning("WebSocket connection rejected - no user_id in token",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
return
|
||||
|
||||
logger.info("WebSocket authentication successful",
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id)
|
||||
|
||||
except Exception as e:
|
||||
await websocket.close(code=1008, reason="Authentication failed")
|
||||
logger.warning("WebSocket authentication failed",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return
|
||||
|
||||
# Connect to WebSocket manager
|
||||
await websocket_manager.connect(job_id, websocket)
|
||||
|
||||
# Helper function to send current job status
|
||||
async def send_current_status():
|
||||
"""Fetch and send the current job status to the client"""
|
||||
try:
|
||||
training_service = get_enhanced_training_service()
|
||||
status_info = await training_service.get_training_status(job_id)
|
||||
|
||||
if status_info and not status_info.get("error"):
|
||||
# Map status to WebSocket message type
|
||||
ws_type = "progress"
|
||||
if status_info.get("status") == "completed":
|
||||
ws_type = "completed"
|
||||
elif status_info.get("status") == "failed":
|
||||
ws_type = "failed"
|
||||
|
||||
await websocket.send_json({
|
||||
"type": ws_type,
|
||||
"job_id": job_id,
|
||||
"data": {
|
||||
"progress": status_info.get("progress", 0),
|
||||
"current_step": status_info.get("current_step"),
|
||||
"status": status_info.get("status"),
|
||||
"products_total": status_info.get("products_total", 0),
|
||||
"products_completed": status_info.get("products_completed", 0),
|
||||
"products_failed": status_info.get("products_failed", 0),
|
||||
"estimated_time_remaining_seconds": status_info.get("estimated_time_remaining_seconds"),
|
||||
"message": status_info.get("message")
|
||||
}
|
||||
})
|
||||
logger.info("Sent current job status to client",
|
||||
job_id=job_id,
|
||||
status=status_info.get("status"),
|
||||
progress=status_info.get("progress"))
|
||||
except Exception as e:
|
||||
logger.error("Failed to send current job status",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
|
||||
try:
|
||||
# Send connection confirmation
|
||||
await websocket.send_json({
|
||||
"type": "connected",
|
||||
"job_id": job_id,
|
||||
"message": "Connected to training progress stream"
|
||||
})
|
||||
|
||||
# Immediately send current job status after connection
|
||||
# This handles the race condition where training completes before WebSocket connects
|
||||
await send_current_status()
|
||||
|
||||
# Keep connection alive and handle client messages
|
||||
ping_count = 0
|
||||
while True:
|
||||
try:
|
||||
# Receive messages from client (ping, get_status, etc.)
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# Handle ping/pong
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
ping_count += 1
|
||||
logger.debug("WebSocket ping/pong",
|
||||
job_id=job_id,
|
||||
ping_count=ping_count,
|
||||
connection_healthy=True)
|
||||
# Handle get_status request
|
||||
elif data == "get_status":
|
||||
await send_current_status()
|
||||
logger.info("Status requested by client", job_id=job_id)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info("Client disconnected", job_id=job_id)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error in WebSocket message loop",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
break
|
||||
|
||||
finally:
|
||||
# Disconnect from manager
|
||||
await websocket_manager.disconnect(job_id, websocket)
|
||||
logger.info("WebSocket connection closed",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
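For reference, a minimal client for the endpoint above might look like the following sketch. The path mirrors the route decorator exactly; the host/port (a local port-forward) and the token value are assumptions, and in production the connection would normally go through the gateway with a real user JWT.

import asyncio
import json
import websockets  # pip install websockets

async def follow_training_job(tenant_id: str, job_id: str, token: str) -> None:
    # Host and port are assumed; the path matches the WebSocket route above.
    url = (
        f"ws://localhost:8000/api/v1/tenants/{tenant_id}"
        f"/training/jobs/{job_id}/live?token={token}"
    )
    async with websockets.connect(url) as ws:
        await ws.send("get_status")          # request the current job snapshot
        while True:
            raw = await ws.recv()
            if raw == "pong":                # reply to our keepalive ping, skip it
                continue
            message = json.loads(raw)
            print(message["type"], message.get("data", {}).get("progress"))
            if message["type"] in ("completed", "failed"):
                break
            await ws.send("ping")            # keep the connection marked healthy

# asyncio.run(follow_training_job("tenant-id", "job-id", "<jwt>"))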
435
services/training/app/consumers/training_event_consumer.py
Normal file
435
services/training/app/consumers/training_event_consumer.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
Training Event Consumer
|
||||
Processes ML model retraining requests from RabbitMQ
|
||||
Queues training jobs and manages model lifecycle
|
||||
"""
|
||||
import json
|
||||
import structlog
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from shared.messaging import RabbitMQClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TrainingEventConsumer:
|
||||
"""
|
||||
Consumes training retraining events and queues ML training jobs
|
||||
Ensures no duplicate training jobs and manages priorities
|
||||
"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
|
||||
async def consume_training_events(
|
||||
self,
|
||||
rabbitmq_client: RabbitMQClient
|
||||
):
|
||||
"""
|
||||
Start consuming training events from RabbitMQ
|
||||
"""
|
||||
async def process_message(message):
|
||||
"""Process a single training event message"""
|
||||
try:
|
||||
async with message.process():
|
||||
# Parse event data
|
||||
event_data = json.loads(message.body.decode())
|
||||
logger.info(
|
||||
"Received training event",
|
||||
event_id=event_data.get('event_id'),
|
||||
event_type=event_data.get('event_type'),
|
||||
tenant_id=event_data.get('tenant_id')
|
||||
)
|
||||
|
||||
# Process the event
|
||||
await self.process_training_event(event_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error processing training event",
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# Start consuming events
|
||||
await rabbitmq_client.consume_events(
|
||||
exchange_name="training.events",
|
||||
queue_name="training.retraining.queue",
|
||||
routing_key="training.retrain.*",
|
||||
callback=process_message
|
||||
)
|
||||
|
||||
logger.info("Started consuming training events")
|
||||
|
||||
async def process_training_event(self, event_data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Process a training event based on type
|
||||
|
||||
Args:
|
||||
event_data: Full event payload from RabbitMQ
|
||||
|
||||
Returns:
|
||||
bool: True if processed successfully
|
||||
"""
|
||||
try:
|
||||
event_type = event_data.get('event_type')
|
||||
data = event_data.get('data', {})
|
||||
tenant_id = event_data.get('tenant_id')
|
||||
|
||||
if not tenant_id:
|
||||
logger.warning("Training event missing tenant_id", event_data=event_data)
|
||||
return False
|
||||
|
||||
# Route to appropriate handler
|
||||
if event_type == 'training.retrain.requested':
|
||||
success = await self._handle_retrain_requested(tenant_id, data, event_data)
|
||||
elif event_type == 'training.retrain.scheduled':
|
||||
success = await self._handle_retrain_scheduled(tenant_id, data)
|
||||
else:
|
||||
logger.warning("Unknown training event type", event_type=event_type)
|
||||
success = True # Mark as processed to avoid retry
|
||||
|
||||
if success:
|
||||
logger.info(
|
||||
"Training event processed successfully",
|
||||
event_type=event_type,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
"Training event processing failed",
|
||||
event_type=event_type,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error in process_training_event",
|
||||
error=str(e),
|
||||
event_id=event_data.get('event_id'),
|
||||
exc_info=True
|
||||
)
|
||||
return False
|
||||
|
||||
async def _handle_retrain_requested(
|
||||
self,
|
||||
tenant_id: str,
|
||||
data: Dict[str, Any],
|
||||
event_data: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
Handle retraining request event
|
||||
|
||||
Validates model, checks for existing jobs, queues training job
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
data: Retraining request data
|
||||
event_data: Full event payload
|
||||
|
||||
Returns:
|
||||
bool: True if handled successfully
|
||||
"""
|
||||
try:
|
||||
model_id = data.get('model_id')
|
||||
product_id = data.get('product_id')
|
||||
trigger_reason = data.get('trigger_reason', 'unknown')
|
||||
priority = data.get('priority', 'normal')
|
||||
event_id = event_data.get('event_id')
|
||||
|
||||
if not model_id:
|
||||
logger.warning("Retraining request missing model_id", data=data)
|
||||
return False
|
||||
|
||||
# Validate model exists
|
||||
from app.models import TrainedModel
|
||||
|
||||
stmt = select(TrainedModel).where(
|
||||
TrainedModel.id == UUID(model_id),
|
||||
TrainedModel.tenant_id == UUID(tenant_id)
|
||||
)
|
||||
result = await self.db_session.execute(stmt)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if not model:
|
||||
logger.error(
|
||||
"Model not found for retraining",
|
||||
model_id=model_id,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return False
|
||||
|
||||
# Check if model is already in training
|
||||
if model.status in ['training', 'retraining_queued']:
|
||||
logger.info(
|
||||
"Model already in training, skipping duplicate request",
|
||||
model_id=model_id,
|
||||
current_status=model.status
|
||||
)
|
||||
return True # Consider successful (idempotent)
|
||||
|
||||
# Check for existing job in queue
|
||||
from app.models import TrainingJobQueue
|
||||
|
||||
existing_job_stmt = select(TrainingJobQueue).where(
|
||||
TrainingJobQueue.model_id == UUID(model_id),
|
||||
TrainingJobQueue.status.in_(['pending', 'running'])
|
||||
)
|
||||
existing_job_result = await self.db_session.execute(existing_job_stmt)
|
||||
existing_job = existing_job_result.scalar_one_or_none()
|
||||
|
||||
if existing_job:
|
||||
logger.info(
|
||||
"Training job already queued, skipping duplicate",
|
||||
model_id=model_id,
|
||||
job_id=str(existing_job.id)
|
||||
)
|
||||
return True # Idempotent
|
||||
|
||||
# Queue training job
|
||||
job_id = await self._queue_training_job(
|
||||
tenant_id=tenant_id,
|
||||
model_id=model_id,
|
||||
product_id=product_id,
|
||||
trigger_reason=trigger_reason,
|
||||
priority=priority,
|
||||
event_id=event_id,
|
||||
metadata=data
|
||||
)
|
||||
|
||||
if not job_id:
|
||||
logger.error("Failed to queue training job", model_id=model_id)
|
||||
return False
|
||||
|
||||
# Update model status
|
||||
model.status = 'retraining_queued'
|
||||
model.updated_at = datetime.utcnow()
|
||||
await self.db_session.commit()
|
||||
|
||||
# Publish job queued event
|
||||
await self._publish_job_queued_event(
|
||||
tenant_id=tenant_id,
|
||||
model_id=model_id,
|
||||
job_id=job_id,
|
||||
priority=priority
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Retraining job queued successfully",
|
||||
model_id=model_id,
|
||||
job_id=job_id,
|
||||
trigger_reason=trigger_reason,
|
||||
priority=priority
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
await self.db_session.rollback()
|
||||
logger.error(
|
||||
"Error handling retrain requested",
|
||||
error=str(e),
|
||||
model_id=data.get('model_id'),
|
||||
exc_info=True
|
||||
)
|
||||
return False
|
||||
|
||||
async def _handle_retrain_scheduled(
|
||||
self,
|
||||
tenant_id: str,
|
||||
data: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
Handle scheduled retraining event
|
||||
|
||||
Similar to retrain_requested but for scheduled/batch retraining
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
data: Scheduled retraining data
|
||||
|
||||
Returns:
|
||||
bool: True if handled successfully
|
||||
"""
|
||||
try:
|
||||
# Similar logic to _handle_retrain_requested
|
||||
# but may have different priority or batching logic
|
||||
logger.info(
|
||||
"Handling scheduled retraining",
|
||||
tenant_id=tenant_id,
|
||||
model_count=len(data.get('models', []))
|
||||
)
|
||||
|
||||
# For now, redirect to retrain_requested handler
|
||||
success_count = 0
|
||||
for model_data in data.get('models', []):
|
||||
if await self._handle_retrain_requested(
|
||||
tenant_id,
|
||||
model_data,
|
||||
{'event_id': data.get('schedule_id'), 'tenant_id': tenant_id}
|
||||
):
|
||||
success_count += 1
|
||||
|
||||
logger.info(
|
||||
"Scheduled retraining processed",
|
||||
tenant_id=tenant_id,
|
||||
successful=success_count,
|
||||
total=len(data.get('models', []))
|
||||
)
|
||||
|
||||
return success_count > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error handling retrain scheduled",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id,
|
||||
exc_info=True
|
||||
)
|
||||
return False
|
||||
|
||||
async def _queue_training_job(
|
||||
self,
|
||||
tenant_id: str,
|
||||
model_id: str,
|
||||
product_id: str,
|
||||
trigger_reason: str,
|
||||
priority: str,
|
||||
event_id: str,
|
||||
metadata: Dict[str, Any]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Queue a training job in the database
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
model_id: Model ID to retrain
|
||||
product_id: Product ID
|
||||
trigger_reason: Why retraining was triggered
|
||||
priority: Job priority (low, normal, high)
|
||||
event_id: Originating event ID
|
||||
metadata: Additional job metadata
|
||||
|
||||
Returns:
|
||||
Job ID if successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
from app.models import TrainingJobQueue
|
||||
import uuid
|
||||
|
||||
# Map priority to numeric value for sorting
|
||||
priority_map = {
|
||||
'low': 1,
|
||||
'normal': 2,
|
||||
'high': 3,
|
||||
'critical': 4
|
||||
}
|
||||
|
||||
job = TrainingJobQueue(
|
||||
id=uuid.uuid4(),
|
||||
tenant_id=UUID(tenant_id),
|
||||
model_id=UUID(model_id),
|
||||
product_id=UUID(product_id) if product_id else None,
|
||||
job_type='retrain',
|
||||
status='pending',
|
||||
priority=priority,
|
||||
priority_score=priority_map.get(priority, 2),
|
||||
trigger_reason=trigger_reason,
|
||||
event_id=event_id,
|
||||
metadata=metadata,
|
||||
created_at=datetime.utcnow(),
|
||||
scheduled_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
self.db_session.add(job)
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"Training job created",
|
||||
job_id=str(job.id),
|
||||
model_id=model_id,
|
||||
priority=priority,
|
||||
trigger_reason=trigger_reason
|
||||
)
|
||||
|
||||
return str(job.id)
|
||||
|
||||
except Exception as e:
|
||||
await self.db_session.rollback()
|
||||
logger.error(
|
||||
"Failed to queue training job",
|
||||
model_id=model_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
async def _publish_job_queued_event(
|
||||
self,
|
||||
tenant_id: str,
|
||||
model_id: str,
|
||||
job_id: str,
|
||||
priority: str
|
||||
):
|
||||
"""
|
||||
Publish event that training job was queued
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
model_id: Model ID
|
||||
job_id: Training job ID
|
||||
priority: Job priority
|
||||
"""
|
||||
try:
|
||||
from shared.messaging import get_rabbitmq_client
|
||||
import uuid
|
||||
|
||||
rabbitmq_client = get_rabbitmq_client()
|
||||
if not rabbitmq_client:
|
||||
logger.warning("RabbitMQ client not available for event publishing")
|
||||
return
|
||||
|
||||
event_payload = {
|
||||
"event_id": str(uuid.uuid4()),
|
||||
"event_type": "training.retrain.queued",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"tenant_id": tenant_id,
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"model_id": model_id,
|
||||
"priority": priority,
|
||||
"status": "queued"
|
||||
}
|
||||
}
|
||||
|
||||
await rabbitmq_client.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.retrain.queued",
|
||||
event_data=event_payload
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Published job queued event",
|
||||
job_id=job_id,
|
||||
event_id=event_payload["event_id"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to publish job queued event",
|
||||
job_id=job_id,
|
||||
error=str(e)
|
||||
)
|
||||
# Don't fail the main operation if event publishing fails
|
||||
|
||||
|
||||
# Factory function for creating consumer instance
|
||||
def create_training_event_consumer(db_session: AsyncSession) -> TrainingEventConsumer:
|
||||
"""Create training event consumer instance"""
|
||||
return TrainingEventConsumer(db_session)
|
||||
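To show the payload shape this consumer expects, here is a hedged sketch that builds a training.retrain.requested event by hand and feeds it to process_training_event. The session factory and the UUID values are placeholders; in the running service the message arrives via RabbitMQ instead of being constructed locally.

import asyncio
import uuid
from datetime import datetime, timezone

from app.consumers.training_event_consumer import create_training_event_consumer

async def demo_retrain_event(session_factory) -> None:
    """session_factory is assumed to yield an AsyncSession (e.g. database_manager.get_session)."""
    event = {
        "event_id": str(uuid.uuid4()),
        "event_type": "training.retrain.requested",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "tenant_id": str(uuid.uuid4()),           # placeholder tenant UUID
        "data": {
            "model_id": str(uuid.uuid4()),        # placeholder model UUID
            "product_id": str(uuid.uuid4()),      # placeholder product UUID
            "trigger_reason": "accuracy_degradation",
            "priority": "high",
        },
    }
    async with session_factory() as session:
        consumer = create_training_event_consumer(session)
        handled = await consumer.process_training_event(event)
        print("event handled:", handled)   # False here, since the placeholder model does not exist

# asyncio.run(demo_retrain_event(database_manager.get_session))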
0
services/training/app/core/__init__.py
Normal file
0
services/training/app/core/__init__.py
Normal file
89
services/training/app/core/config.py
Normal file
89
services/training/app/core/config.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# ================================================================
|
||||
# TRAINING SERVICE CONFIGURATION
|
||||
# services/training/app/core/config.py
|
||||
# ================================================================
|
||||
|
||||
"""
|
||||
Training service configuration
|
||||
ML model training and management
|
||||
"""
|
||||
|
||||
from shared.config.base import BaseServiceSettings
|
||||
import os
|
||||
|
||||
class TrainingSettings(BaseServiceSettings):
|
||||
"""Training service specific settings"""
|
||||
|
||||
# Service Identity
|
||||
APP_NAME: str = "Training Service"
|
||||
SERVICE_NAME: str = "training-service"
|
||||
DESCRIPTION: str = "Machine learning model training service"
|
||||
|
||||
# Database configuration (secure approach - build from components)
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
"""Build database URL from secure components"""
|
||||
# Try complete URL first (for backward compatibility)
|
||||
complete_url = os.getenv("TRAINING_DATABASE_URL")
|
||||
if complete_url:
|
||||
return complete_url
|
||||
|
||||
# Build from components (secure approach)
|
||||
user = os.getenv("TRAINING_DB_USER", "training_user")
|
||||
password = os.getenv("TRAINING_DB_PASSWORD", "training_pass123")
|
||||
host = os.getenv("TRAINING_DB_HOST", "localhost")
|
||||
port = os.getenv("TRAINING_DB_PORT", "5432")
|
||||
name = os.getenv("TRAINING_DB_NAME", "training_db")
|
||||
|
||||
return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"
|
||||
|
||||
# Redis Database (dedicated for training cache)
|
||||
REDIS_DB: int = 1
|
||||
|
||||
# ML Model Storage
|
||||
MODEL_BACKUP_ENABLED: bool = os.getenv("MODEL_BACKUP_ENABLED", "true").lower() == "true"
|
||||
MODEL_VERSIONING_ENABLED: bool = os.getenv("MODEL_VERSIONING_ENABLED", "true").lower() == "true"
|
||||
|
||||
# MinIO Configuration
|
||||
MINIO_ENDPOINT: str = os.getenv("MINIO_ENDPOINT", "minio.bakery-ia.svc.cluster.local:9000")
|
||||
MINIO_ACCESS_KEY: str = os.getenv("MINIO_ACCESS_KEY", "training-service")
|
||||
MINIO_SECRET_KEY: str = os.getenv("MINIO_SECRET_KEY", "training-secret-key")
|
||||
MINIO_USE_SSL: bool = os.getenv("MINIO_USE_SSL", "true").lower() == "true"
|
||||
MINIO_MODEL_BUCKET: str = os.getenv("MINIO_MODEL_BUCKET", "training-models")
|
||||
MINIO_CONSOLE_PORT: str = os.getenv("MINIO_CONSOLE_PORT", "9001")
|
||||
MINIO_API_PORT: str = os.getenv("MINIO_API_PORT", "9000")
|
||||
MINIO_REGION: str = os.getenv("MINIO_REGION", "us-east-1")
|
||||
MINIO_MODEL_LIFECYCLE_DAYS: int = int(os.getenv("MINIO_MODEL_LIFECYCLE_DAYS", "90"))
|
||||
MINIO_CACHE_TTL_SECONDS: int = int(os.getenv("MINIO_CACHE_TTL_SECONDS", "3600"))
|
||||
|
||||
# Training Configuration
|
||||
MAX_CONCURRENT_TRAINING_JOBS: int = int(os.getenv("MAX_CONCURRENT_TRAINING_JOBS", "3"))
|
||||
|
||||
# Prophet Specific Configuration
|
||||
PROPHET_HOLIDAYS_PRIOR_SCALE: float = float(os.getenv("PROPHET_HOLIDAYS_PRIOR_SCALE", "10.0"))
|
||||
|
||||
# Spanish Holiday Integration
|
||||
ENABLE_CUSTOM_HOLIDAYS: bool = os.getenv("ENABLE_CUSTOM_HOLIDAYS", "true").lower() == "true"
|
||||
|
||||
# Data Processing
|
||||
DATA_PREPROCESSING_ENABLED: bool = True
|
||||
OUTLIER_DETECTION_ENABLED: bool = os.getenv("OUTLIER_DETECTION_ENABLED", "true").lower() == "true"
|
||||
SEASONAL_DECOMPOSITION_ENABLED: bool = os.getenv("SEASONAL_DECOMPOSITION_ENABLED", "true").lower() == "true"
|
||||
|
||||
# Model Validation
|
||||
CROSS_VALIDATION_ENABLED: bool = os.getenv("CROSS_VALIDATION_ENABLED", "true").lower() == "true"
|
||||
VALIDATION_SPLIT_RATIO: float = float(os.getenv("VALIDATION_SPLIT_RATIO", "0.2"))
|
||||
MIN_MODEL_ACCURACY: float = float(os.getenv("MIN_MODEL_ACCURACY", "0.7"))
|
||||
|
||||
# Distributed Training (for future scaling)
|
||||
DISTRIBUTED_TRAINING_ENABLED: bool = os.getenv("DISTRIBUTED_TRAINING_ENABLED", "false").lower() == "true"
|
||||
TRAINING_WORKER_COUNT: int = int(os.getenv("TRAINING_WORKER_COUNT", "1"))
|
||||
|
||||
PROPHET_DAILY_SEASONALITY: bool = True
|
||||
PROPHET_WEEKLY_SEASONALITY: bool = True
|
||||
PROPHET_YEARLY_SEASONALITY: bool = True
|
||||
|
||||
# Throttling settings for parallel training to prevent heartbeat blocking
|
||||
MAX_CONCURRENT_TRAININGS: int = int(os.getenv("MAX_CONCURRENT_TRAININGS", "3"))
|
||||
|
||||
settings = TrainingSettings()
|
||||
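A quick way to see the component-based URL construction in action is to set the individual TRAINING_DB_* variables and read the property, as in this sketch. It assumes the shared config package is importable, i.e. it runs inside the service image or with the repo's shared package on PYTHONPATH.

import os

# Component-based configuration; TRAINING_DATABASE_URL is deliberately left unset.
os.environ.update({
    "TRAINING_DB_USER": "training_user",
    "TRAINING_DB_PASSWORD": "s3cret",        # example value only
    "TRAINING_DB_HOST": "postgres-training",
    "TRAINING_DB_PORT": "5432",
    "TRAINING_DB_NAME": "training_db",
})

from app.core.config import TrainingSettings  # import after the env is prepared

settings = TrainingSettings()
print(settings.DATABASE_URL)
# -> postgresql+asyncpg://training_user:s3cret@postgres-training:5432/training_db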
97
services/training/app/core/constants.py
Normal file
97
services/training/app/core/constants.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Training Service Constants
|
||||
Centralized constants to avoid magic numbers throughout the codebase
|
||||
"""
|
||||
|
||||
# Data Validation Thresholds
|
||||
MIN_DATA_POINTS_REQUIRED = 30
|
||||
RECOMMENDED_DATA_POINTS = 90
|
||||
MAX_ZERO_RATIO_ERROR = 0.9 # 90% zeros = error
|
||||
HIGH_ZERO_RATIO_WARNING = 0.7 # 70% zeros = warning
|
||||
MAX_ZERO_RATIO_INTERMITTENT = 0.8 # Products with >80% zeros are intermittent
|
||||
MODERATE_SPARSITY_THRESHOLD = 0.6 # 60% zeros = moderate sparsity
|
||||
|
||||
# Training Time Periods (in days)
|
||||
MIN_NON_ZERO_DAYS = 30 # Minimum days with non-zero sales
|
||||
DATA_QUALITY_DAY_THRESHOLD_LOW = 90
|
||||
DATA_QUALITY_DAY_THRESHOLD_HIGH = 365
|
||||
MAX_TRAINING_RANGE_DAYS = 730 # 2 years
|
||||
MIN_TRAINING_RANGE_DAYS = 30
|
||||
|
||||
# Product Classification Thresholds
|
||||
HIGH_VOLUME_MEAN_SALES = 10.0
|
||||
HIGH_VOLUME_ZERO_RATIO = 0.3
|
||||
MEDIUM_VOLUME_MEAN_SALES = 5.0
|
||||
MEDIUM_VOLUME_ZERO_RATIO = 0.5
|
||||
LOW_VOLUME_MEAN_SALES = 2.0
|
||||
LOW_VOLUME_ZERO_RATIO = 0.7
|
||||
|
||||
# Hyperparameter Optimization
|
||||
OPTUNA_TRIALS_HIGH_VOLUME = 30
|
||||
OPTUNA_TRIALS_MEDIUM_VOLUME = 25
|
||||
OPTUNA_TRIALS_LOW_VOLUME = 20
|
||||
OPTUNA_TRIALS_INTERMITTENT = 15
|
||||
OPTUNA_TIMEOUT_SECONDS = 600
|
||||
|
||||
# Prophet Uncertainty Sampling
|
||||
UNCERTAINTY_SAMPLES_SPARSE_MIN = 100
|
||||
UNCERTAINTY_SAMPLES_SPARSE_MAX = 200
|
||||
UNCERTAINTY_SAMPLES_LOW_MIN = 150
|
||||
UNCERTAINTY_SAMPLES_LOW_MAX = 300
|
||||
UNCERTAINTY_SAMPLES_MEDIUM_MIN = 200
|
||||
UNCERTAINTY_SAMPLES_MEDIUM_MAX = 500
|
||||
UNCERTAINTY_SAMPLES_HIGH_MIN = 300
|
||||
UNCERTAINTY_SAMPLES_HIGH_MAX = 800
|
||||
|
||||
# MAPE Calculation
|
||||
MAPE_LOW_VOLUME_THRESHOLD = 2.0
|
||||
MAPE_MEDIUM_VOLUME_THRESHOLD = 5.0
|
||||
MAPE_CALCULATION_MIN_THRESHOLD = 0.5
|
||||
MAPE_CALCULATION_MID_THRESHOLD = 1.0
|
||||
MAPE_MAX_CAP = 200.0 # Cap MAPE at 200%
|
||||
MAPE_MEDIUM_CAP = 150.0
|
||||
|
||||
# Baseline MAPE estimates for improvement calculation
|
||||
BASELINE_MAPE_VERY_SPARSE = 80.0
|
||||
BASELINE_MAPE_SPARSE = 60.0
|
||||
BASELINE_MAPE_HIGH_VOLUME = 25.0
|
||||
BASELINE_MAPE_MEDIUM_VOLUME = 35.0
|
||||
BASELINE_MAPE_LOW_VOLUME = 45.0
|
||||
IMPROVEMENT_SIGNIFICANCE_THRESHOLD = 0.8 # Only claim improvement if MAPE < 80% of baseline
|
||||
|
||||
# Cross-validation
|
||||
CV_N_SPLITS = 2
|
||||
CV_MIN_VALIDATION_DAYS = 7
|
||||
|
||||
# Progress tracking
|
||||
PROGRESS_DATA_PREPARATION_START = 0
|
||||
PROGRESS_DATA_PREPARATION_END = 45
|
||||
PROGRESS_MODEL_TRAINING_START = 45
|
||||
PROGRESS_MODEL_TRAINING_END = 85
|
||||
PROGRESS_FINALIZATION_START = 85
|
||||
PROGRESS_FINALIZATION_END = 100
|
||||
|
||||
# HTTP Client Configuration
|
||||
HTTP_TIMEOUT_DEFAULT = 30.0 # seconds
|
||||
HTTP_TIMEOUT_LONG_RUNNING = 60.0 # for training data fetches
|
||||
HTTP_MAX_RETRIES = 3
|
||||
HTTP_RETRY_BACKOFF_FACTOR = 2.0
|
||||
|
||||
# WebSocket Configuration
|
||||
WEBSOCKET_PING_TIMEOUT = 60.0 # seconds
|
||||
WEBSOCKET_ACTIVITY_WARNING_THRESHOLD = 90.0 # seconds
|
||||
WEBSOCKET_CONSUMER_HEARTBEAT_INTERVAL = 10.0 # seconds
|
||||
|
||||
# Synthetic Data Generation
|
||||
SYNTHETIC_TEMP_DEFAULT = 50.0
|
||||
SYNTHETIC_TEMP_VARIATION = 100.0
|
||||
SYNTHETIC_TRAFFIC_DEFAULT = 50.0
|
||||
SYNTHETIC_TRAFFIC_VARIATION = 100.0
|
||||
|
||||
# Model Storage
|
||||
MODEL_FILE_EXTENSION = ".pkl"
|
||||
METADATA_FILE_EXTENSION = ".json"
|
||||
|
||||
# Data Quality Scoring
|
||||
MIN_QUALITY_SCORE = 0.1
|
||||
MAX_QUALITY_SCORE = 1.0
|
||||
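As an illustration of how the classification thresholds and Optuna budgets above are meant to combine, here is a small hypothetical helper. The actual classification rules live in the trainer and are not shown in this file, so treat this as a sketch rather than the service's real logic.

from app.core import constants as C

def pick_optuna_trials(mean_daily_sales: float, zero_ratio: float) -> int:
    """Hypothetical mapping from a product's volume profile to an Optuna trial budget."""
    if zero_ratio > C.MAX_ZERO_RATIO_INTERMITTENT:
        return C.OPTUNA_TRIALS_INTERMITTENT
    if mean_daily_sales >= C.HIGH_VOLUME_MEAN_SALES and zero_ratio <= C.HIGH_VOLUME_ZERO_RATIO:
        return C.OPTUNA_TRIALS_HIGH_VOLUME
    if mean_daily_sales >= C.MEDIUM_VOLUME_MEAN_SALES and zero_ratio <= C.MEDIUM_VOLUME_ZERO_RATIO:
        return C.OPTUNA_TRIALS_MEDIUM_VOLUME
    return C.OPTUNA_TRIALS_LOW_VOLUME

print(pick_optuna_trials(mean_daily_sales=12.0, zero_ratio=0.1))   # 30 (high volume)
print(pick_optuna_trials(mean_daily_sales=1.5, zero_ratio=0.85))   # 15 (intermittent)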
432
services/training/app/core/database.py
Normal file
432
services/training/app/core/database.py
Normal file
@@ -0,0 +1,432 @@
|
||||
# services/training/app/core/database.py
|
||||
"""
|
||||
Database configuration for training service
|
||||
Uses shared database infrastructure
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import AsyncGenerator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy import text
|
||||
|
||||
from shared.database.base import DatabaseManager, Base
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Initialize database manager with connection pooling configuration
|
||||
database_manager = DatabaseManager(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=settings.DB_POOL_SIZE,
|
||||
max_overflow=settings.DB_MAX_OVERFLOW,
|
||||
pool_timeout=settings.DB_POOL_TIMEOUT,
|
||||
pool_recycle=settings.DB_POOL_RECYCLE,
|
||||
pool_pre_ping=settings.DB_POOL_PRE_PING,
|
||||
echo=settings.DB_ECHO
|
||||
)
|
||||
|
||||
# Alias for convenience - matches the existing interface
|
||||
get_db = database_manager.get_db
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_background_db_session():
|
||||
async with database_manager.async_session_local() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def get_db_health() -> bool:
|
||||
"""
|
||||
Health check function for database connectivity
|
||||
Enhanced version of the shared functionality
|
||||
"""
|
||||
try:
|
||||
async with database_manager.async_engine.begin() as conn:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
logger.debug("Database health check passed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Database health check failed", error=str(e))
|
||||
return False
|
||||
|
||||
async def get_comprehensive_db_health() -> dict:
|
||||
"""
|
||||
Comprehensive health check that verifies both connectivity and table existence
|
||||
"""
|
||||
health_status = {
|
||||
"status": "healthy",
|
||||
"connectivity": False,
|
||||
"tables_exist": False,
|
||||
"tables_verified": [],
|
||||
"missing_tables": [],
|
||||
"errors": []
|
||||
}
|
||||
|
||||
try:
|
||||
# Test basic connectivity
|
||||
health_status["connectivity"] = await get_db_health()
|
||||
|
||||
if not health_status["connectivity"]:
|
||||
health_status["status"] = "unhealthy"
|
||||
health_status["errors"].append("Database connectivity failed")
|
||||
return health_status
|
||||
|
||||
# Test table existence
|
||||
tables_verified = await _verify_tables_exist()
|
||||
health_status["tables_exist"] = tables_verified
|
||||
|
||||
if tables_verified:
|
||||
health_status["tables_verified"] = [
|
||||
'model_training_logs', 'trained_models', 'model_performance_metrics',
|
||||
'training_job_queue', 'model_artifacts'
|
||||
]
|
||||
else:
|
||||
health_status["status"] = "unhealthy"
|
||||
health_status["errors"].append("Required tables missing or inaccessible")
|
||||
|
||||
# Try to identify which specific tables are missing
|
||||
try:
|
||||
async with database_manager.get_session() as session:
|
||||
for table_name in ['model_training_logs', 'trained_models', 'model_performance_metrics',
|
||||
'training_job_queue', 'model_artifacts']:
|
||||
try:
|
||||
await session.execute(text(f"SELECT 1 FROM {table_name} LIMIT 1"))
|
||||
health_status["tables_verified"].append(table_name)
|
||||
except Exception:
|
||||
health_status["missing_tables"].append(table_name)
|
||||
except Exception as e:
|
||||
health_status["errors"].append(f"Error checking individual tables: {str(e)}")
|
||||
|
||||
logger.debug("Comprehensive database health check completed",
|
||||
status=health_status["status"],
|
||||
connectivity=health_status["connectivity"],
|
||||
tables_exist=health_status["tables_exist"])
|
||||
|
||||
except Exception as e:
|
||||
health_status["status"] = "unhealthy"
|
||||
health_status["errors"].append(f"Health check failed: {str(e)}")
|
||||
logger.error("Comprehensive database health check failed", error=str(e))
|
||||
|
||||
return health_status
|
||||
|
||||
# Training service specific database utilities
|
||||
class TrainingDatabaseUtils:
|
||||
"""Training service specific database utilities"""
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_training_logs(days_old: int = 90):
|
||||
"""Clean up old training logs"""
|
||||
try:
|
||||
async with database_manager.async_session_local() as session:
|
||||
if settings.DATABASE_URL.startswith("sqlite"):
|
||||
query = text(
|
||||
"DELETE FROM model_training_logs "
|
||||
"WHERE start_time < datetime('now', :days_param)"
|
||||
)
|
||||
params = {"days_param": f"-{days_old} days"}
|
||||
else:
|
||||
query = text(
|
||||
"DELETE FROM model_training_logs "
|
||||
"WHERE start_time < NOW() - INTERVAL :days_param"
|
||||
)
|
||||
params = {"days_param": f"{days_old} days"}
|
||||
|
||||
result = await session.execute(query, params)
|
||||
await session.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info("Cleaned up old training logs",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Training logs cleanup failed", error=str(e))
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_models(days_old: int = 365):
|
||||
"""Clean up old inactive models"""
|
||||
try:
|
||||
async with database_manager.async_session_local() as session:
|
||||
if settings.DATABASE_URL.startswith("sqlite"):
|
||||
query = text(
|
||||
"DELETE FROM trained_models "
|
||||
"WHERE is_active = 0 AND created_at < datetime('now', :days_param)"
|
||||
)
|
||||
params = {"days_param": f"-{days_old} days"}
|
||||
else:
|
||||
query = text(
|
||||
"DELETE FROM trained_models "
|
||||
"WHERE is_active = false AND created_at < NOW() - INTERVAL :days_param"
|
||||
)
|
||||
params = {"days_param": f"{days_old} days"}
|
||||
|
||||
result = await session.execute(query, params)
|
||||
await session.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info("Cleaned up old models",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Model cleanup failed", error=str(e))
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def get_training_statistics(tenant_id: str = None) -> dict:
|
||||
"""Get training statistics"""
|
||||
try:
|
||||
async with database_manager.async_session_local() as session:
|
||||
# Base query for training logs
|
||||
if tenant_id:
|
||||
logs_query = text(
|
||||
"SELECT status, COUNT(*) as count "
|
||||
"FROM model_training_logs "
|
||||
"WHERE tenant_id = :tenant_id "
|
||||
"GROUP BY status"
|
||||
)
|
||||
models_query = text(
|
||||
"SELECT COUNT(*) as count "
|
||||
"FROM trained_models "
|
||||
"WHERE tenant_id = :tenant_id AND is_active = :is_active"
|
||||
)
|
||||
params = {"tenant_id": tenant_id}
|
||||
else:
|
||||
logs_query = text(
|
||||
"SELECT status, COUNT(*) as count "
|
||||
"FROM model_training_logs "
|
||||
"GROUP BY status"
|
||||
)
|
||||
models_query = text(
|
||||
"SELECT COUNT(*) as count "
|
||||
"FROM trained_models "
|
||||
"WHERE is_active = :is_active"
|
||||
)
|
||||
params = {}
|
||||
|
||||
# Get training job statistics
|
||||
logs_result = await session.execute(logs_query, params)
|
||||
job_stats = {row.status: row.count for row in logs_result.fetchall()}
|
||||
|
||||
# Get active models count
|
||||
active_models_result = await session.execute(
|
||||
models_query,
|
||||
{**params, "is_active": True}
|
||||
)
|
||||
active_models = active_models_result.scalar() or 0
|
||||
|
||||
# Get inactive models count
|
||||
inactive_models_result = await session.execute(
|
||||
models_query,
|
||||
{**params, "is_active": False}
|
||||
)
|
||||
inactive_models = inactive_models_result.scalar() or 0
|
||||
|
||||
return {
|
||||
"training_jobs": job_stats,
|
||||
"active_models": active_models,
|
||||
"inactive_models": inactive_models,
|
||||
"total_models": active_models + inactive_models
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get training statistics", error=str(e))
|
||||
return {
|
||||
"training_jobs": {},
|
||||
"active_models": 0,
|
||||
"inactive_models": 0,
|
||||
"total_models": 0
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def check_tenant_data_exists(tenant_id: str) -> bool:
|
||||
"""Check if tenant has any training data"""
|
||||
try:
|
||||
async with database_manager.async_session_local() as session:
|
||||
query = text(
|
||||
"SELECT COUNT(*) as count "
|
||||
"FROM model_training_logs "
|
||||
"WHERE tenant_id = :tenant_id "
|
||||
"LIMIT 1"
|
||||
)
|
||||
|
||||
result = await session.execute(query, {"tenant_id": tenant_id})
|
||||
count = result.scalar() or 0
|
||||
|
||||
return count > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to check tenant data existence",
|
||||
tenant_id=tenant_id, error=str(e))
|
||||
return False
|
||||
|
||||
# Enhanced database session dependency with better error handling
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Enhanced database session dependency with better logging and error handling
|
||||
"""
|
||||
async with database_manager.async_session_local() as session:
|
||||
try:
|
||||
logger.debug("Database session created")
|
||||
yield session
|
||||
except Exception as e:
|
||||
logger.error("Database session error", error=str(e), exc_info=True)
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
logger.debug("Database session closed")
|
||||
|
||||
# Database initialization for training service
|
||||
async def initialize_training_database():
|
||||
"""Initialize database tables for training service with retry logic and verification"""
|
||||
import asyncio
|
||||
from sqlalchemy import text
|
||||
|
||||
max_retries = 5
|
||||
retry_delay = 2.0
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
logger.info("Initializing training service database",
|
||||
attempt=attempt, max_retries=max_retries)
|
||||
|
||||
# Step 1: Test database connectivity first
|
||||
logger.info("Testing database connectivity...")
|
||||
connection_ok = await database_manager.test_connection()
|
||||
if not connection_ok:
|
||||
raise Exception("Database connection test failed")
|
||||
logger.info("Database connectivity verified")
|
||||
|
||||
# Step 2: Import models to ensure they're registered
|
||||
logger.info("Importing and registering database models...")
|
||||
from app.models.training import (
|
||||
ModelTrainingLog,
|
||||
TrainedModel,
|
||||
ModelPerformanceMetric,
|
||||
TrainingJobQueue,
|
||||
ModelArtifact
|
||||
)
|
||||
|
||||
# Verify models are registered in metadata
|
||||
expected_tables = {
|
||||
'model_training_logs', 'trained_models', 'model_performance_metrics',
|
||||
'training_job_queue', 'model_artifacts'
|
||||
}
|
||||
registered_tables = set(Base.metadata.tables.keys())
|
||||
missing_tables = expected_tables - registered_tables
|
||||
if missing_tables:
|
||||
raise Exception(f"Models not properly registered: {missing_tables}")
|
||||
|
||||
logger.info("Models registered successfully",
|
||||
tables=list(registered_tables))
|
||||
|
||||
# Step 3: Create tables using shared infrastructure with verification
|
||||
logger.info("Creating database tables...")
|
||||
await database_manager.create_tables()
|
||||
|
||||
# Step 4: Verify tables were actually created
|
||||
logger.info("Verifying table creation...")
|
||||
verification_successful = await _verify_tables_exist()
|
||||
|
||||
if not verification_successful:
|
||||
raise Exception("Table verification failed - tables were not created properly")
|
||||
|
||||
logger.info("Training service database initialized and verified successfully",
|
||||
attempt=attempt)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Database initialization failed",
|
||||
attempt=attempt,
|
||||
max_retries=max_retries,
|
||||
error=str(e))
|
||||
|
||||
if attempt == max_retries:
|
||||
logger.error("All database initialization attempts failed - giving up")
|
||||
raise Exception(f"Failed to initialize training database after {max_retries} attempts: {str(e)}")
|
||||
|
||||
# Wait before retry with exponential backoff
|
||||
wait_time = retry_delay * (2 ** (attempt - 1))
|
||||
logger.info("Retrying database initialization",
|
||||
retry_in_seconds=wait_time,
|
||||
next_attempt=attempt + 1)
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
async def _verify_tables_exist() -> bool:
|
||||
"""Verify that all required tables exist in the database"""
|
||||
try:
|
||||
async with database_manager.get_session() as session:
|
||||
# Check each required table exists and is accessible
|
||||
required_tables = [
|
||||
'model_training_logs',
|
||||
'trained_models',
|
||||
'model_performance_metrics',
|
||||
'training_job_queue',
|
||||
'model_artifacts'
|
||||
]
|
||||
|
||||
for table_name in required_tables:
|
||||
try:
|
||||
# Try to query the table structure
|
||||
result = await session.execute(
|
||||
text(f"SELECT 1 FROM {table_name} LIMIT 1")
|
||||
)
|
||||
logger.debug(f"Table {table_name} exists and is accessible")
|
||||
except Exception as table_error:
|
||||
# If it's a "relation does not exist" error, table creation failed
|
||||
if "does not exist" in str(table_error).lower():
|
||||
logger.error(f"Table {table_name} does not exist", error=str(table_error))
|
||||
return False
|
||||
# If it's an empty table, that's fine - table exists
|
||||
elif "no data" in str(table_error).lower():
|
||||
logger.debug(f"Table {table_name} exists but is empty (normal)")
|
||||
else:
|
||||
logger.warning(f"Unexpected error querying {table_name}", error=str(table_error))
|
||||
|
||||
logger.info("All required tables verified successfully",
|
||||
tables=required_tables)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Table verification failed", error=str(e))
|
||||
return False
|
||||
|
||||
# Database cleanup for training service
|
||||
async def cleanup_training_database():
|
||||
"""Cleanup database connections for training service"""
|
||||
try:
|
||||
logger.info("Cleaning up training service database connections")
|
||||
|
||||
# Close engine connections
|
||||
if hasattr(database_manager, 'async_engine') and database_manager.async_engine:
|
||||
await database_manager.async_engine.dispose()
|
||||
|
||||
logger.info("Training service database cleanup completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup training service database", error=str(e))
|
||||
|
||||
# Export the commonly used items to maintain compatibility
|
||||
__all__ = [
|
||||
'Base',
|
||||
'database_manager',
|
||||
'get_db',
|
||||
'get_db_session',
|
||||
'get_db_health',
|
||||
'TrainingDatabaseUtils',
|
||||
'initialize_training_database',
|
||||
'cleanup_training_database'
|
||||
]
|
||||
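For a quick smoke test of the utilities above, something like the following sketch can be run inside the service; it assumes the database is reachable with the configured credentials.

import asyncio

from app.core.database import (
    TrainingDatabaseUtils,
    get_comprehensive_db_health,
    initialize_training_database,
)

async def smoke_test() -> None:
    # Create/verify tables, then report connectivity and basic counts.
    await initialize_training_database()

    health = await get_comprehensive_db_health()
    print("health:", health["status"], "tables:", health["tables_verified"])

    stats = await TrainingDatabaseUtils.get_training_statistics()
    print("active models:", stats["active_models"], "jobs by status:", stats["training_jobs"])

asyncio.run(smoke_test())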
35
services/training/app/core/training_constants.py
Normal file
35
services/training/app/core/training_constants.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Training Progress Constants
|
||||
Centralized constants for training progress tracking and timing
|
||||
"""
|
||||
|
||||
# Progress Milestones (percentage)
|
||||
PROGRESS_STARTED = 0
|
||||
PROGRESS_DATA_VALIDATION = 10
|
||||
PROGRESS_DATA_ANALYSIS = 20
|
||||
PROGRESS_DATA_PREPARATION_COMPLETE = 30
|
||||
PROGRESS_ML_TRAINING_START = 40
|
||||
PROGRESS_TRAINING_COMPLETE = 85
|
||||
PROGRESS_STORING_MODELS = 92
|
||||
PROGRESS_STORING_METRICS = 94
|
||||
PROGRESS_COMPLETED = 100
|
||||
|
||||
# Progress Ranges
|
||||
PROGRESS_TRAINING_RANGE_START = 20 # After data analysis
|
||||
PROGRESS_TRAINING_RANGE_END = 80 # Before finalization
|
||||
PROGRESS_TRAINING_RANGE_WIDTH = PROGRESS_TRAINING_RANGE_END - PROGRESS_TRAINING_RANGE_START # 60%
|
||||
|
||||
# Time Limits and Intervals (seconds)
|
||||
MAX_ESTIMATED_TIME_REMAINING_SECONDS = 1800 # 30 minutes
|
||||
WEBSOCKET_HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
WEBSOCKET_RECONNECT_MAX_ATTEMPTS = 3
|
||||
WEBSOCKET_RECONNECT_INITIAL_DELAY_SECONDS = 1
|
||||
WEBSOCKET_RECONNECT_MAX_DELAY_SECONDS = 10
|
||||
|
||||
# Training Timeouts (seconds)
|
||||
TRAINING_SKIP_OPTION_DELAY_SECONDS = 120 # 2 minutes
|
||||
HTTP_POLLING_INTERVAL_MS = 5000 # 5 seconds
|
||||
HTTP_POLLING_DEBOUNCE_MS = 5000 # 5 seconds before enabling after WebSocket disconnect
|
||||
|
||||
# Frontend Display
|
||||
TRAINING_COMPLETION_DELAY_MS = 2000 # Delay before navigating after completion
|
||||
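To make the 20-80% training range concrete, here is a hypothetical helper that maps per-product completion into overall job progress; the real mapping lives in the training pipeline and may differ in detail.

from app.core.training_constants import (
    PROGRESS_TRAINING_RANGE_START,
    PROGRESS_TRAINING_RANGE_WIDTH,
)

def overall_progress(products_completed: int, products_total: int) -> int:
    """Hypothetical: scale per-product completion into the 20-80% training window."""
    if products_total <= 0:
        return PROGRESS_TRAINING_RANGE_START
    fraction = min(products_completed / products_total, 1.0)
    return PROGRESS_TRAINING_RANGE_START + round(fraction * PROGRESS_TRAINING_RANGE_WIDTH)

print(overall_progress(0, 12))    # 20
print(overall_progress(6, 12))    # 50
print(overall_progress(12, 12))   # 80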
265
services/training/app/main.py
Normal file
265
services/training/app/main.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# ================================================================
|
||||
# services/training/app/main.py
|
||||
# ================================================================
|
||||
"""
|
||||
Training Service Main Application
|
||||
ML training service for bakery demand forecasting
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from fastapi import FastAPI, Request
|
||||
from sqlalchemy import text
|
||||
from app.core.config import settings
|
||||
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
|
||||
from app.api import training_jobs, training_operations, models, health, monitoring, websocket_operations, audit
|
||||
from app.services.training_events import setup_messaging, cleanup_messaging
|
||||
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers
|
||||
from shared.service_base import StandardFastAPIService
|
||||
from shared.monitoring.system_metrics import SystemMetricsCollector
|
||||
|
||||
|
||||
class TrainingService(StandardFastAPIService):
|
||||
"""Training Service with standardized setup"""
|
||||
|
||||
def __init__(self):
|
||||
# Define expected database tables for health checks
|
||||
training_expected_tables = [
|
||||
'model_training_logs', 'trained_models', 'model_performance_metrics',
|
||||
'training_job_queue', 'model_artifacts'
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
service_name="training-service",
|
||||
app_name="Bakery Training Service",
|
||||
description="ML training service for bakery demand forecasting",
|
||||
version="1.0.0",
|
||||
log_level=settings.LOG_LEVEL,
|
||||
cors_origins=settings.CORS_ORIGINS_LIST,
|
||||
api_prefix="",
|
||||
database_manager=database_manager,
|
||||
expected_tables=training_expected_tables,
|
||||
enable_messaging=True
|
||||
)
|
||||
|
||||
async def _setup_messaging(self):
|
||||
"""Setup messaging for training service"""
|
||||
await setup_messaging()
|
||||
self.logger.info("Messaging setup completed")
|
||||
|
||||
# Initialize Redis pub/sub for cross-pod WebSocket broadcasting
|
||||
await self._setup_websocket_redis()
|
||||
|
||||
# Set up WebSocket event consumer (listens to RabbitMQ and broadcasts to WebSockets)
|
||||
success = await setup_websocket_event_consumer()
|
||||
if success:
|
||||
self.logger.info("WebSocket event consumer setup completed")
|
||||
else:
|
||||
self.logger.warning("WebSocket event consumer setup failed")
|
||||
|
||||
async def _setup_websocket_redis(self):
|
||||
"""
|
||||
Initialize Redis pub/sub for WebSocket cross-pod broadcasting.
|
||||
|
||||
CRITICAL FOR HORIZONTAL SCALING:
|
||||
Without this, WebSocket clients on Pod A won't receive events
|
||||
from training jobs running on Pod B.
|
||||
"""
|
||||
try:
|
||||
from app.websocket.manager import websocket_manager
|
||||
from app.core.config import settings
|
||||
|
||||
redis_url = settings.REDIS_URL
|
||||
success = await websocket_manager.initialize_redis(redis_url)
|
||||
|
||||
if success:
|
||||
self.logger.info("WebSocket Redis pub/sub initialized for horizontal scaling")
|
||||
else:
|
||||
self.logger.warning(
|
||||
"WebSocket Redis pub/sub failed to initialize. "
|
||||
"WebSocket events will only be delivered to local connections."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Failed to setup WebSocket Redis pub/sub",
|
||||
error=str(e))
|
||||
# Don't fail startup - WebSockets will work locally without Redis
|
||||
|
||||
async def _cleanup_messaging(self):
|
||||
"""Cleanup messaging for training service"""
|
||||
# Shutdown WebSocket Redis pub/sub
|
||||
try:
|
||||
from app.websocket.manager import websocket_manager
|
||||
await websocket_manager.shutdown()
|
||||
self.logger.info("WebSocket Redis pub/sub shutdown completed")
|
||||
except Exception as e:
|
||||
self.logger.warning("Error shutting down WebSocket Redis", error=str(e))
|
||||
|
||||
await cleanup_websocket_consumers()
|
||||
await cleanup_messaging()
|
||||
|
||||
async def verify_migrations(self):
|
||||
"""Verify database schema matches the latest migrations dynamically."""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
result = await session.execute(text("SELECT version_num FROM alembic_version"))
|
||||
version = result.scalar()
|
||||
|
||||
if not version:
|
||||
self.logger.error("No migration version found in database")
|
||||
raise RuntimeError("Database not initialized - no alembic version found")
|
||||
|
||||
self.logger.info(f"Migration verification successful: {version}")
|
||||
return version
|
||||
except Exception as e:
|
||||
self.logger.error(f"Migration verification failed: {e}")
|
||||
raise
|
||||
|
||||
async def on_startup(self, app: FastAPI):
|
||||
"""Custom startup logic including migration verification"""
|
||||
await self.verify_migrations()
|
||||
|
||||
# Initialize system metrics collection
|
||||
system_metrics = SystemMetricsCollector("training")
|
||||
self.logger.info("System metrics collection started")
|
||||
|
||||
# Recover stale jobs from previous pod crashes
|
||||
# This is important for horizontal scaling - jobs may be left in 'running'
|
||||
# state if a pod crashes. We mark them as failed so they can be retried.
|
||||
await self._recover_stale_jobs()
|
||||
|
||||
self.logger.info("Training service startup completed")
|
||||
|
||||
async def _recover_stale_jobs(self):
|
||||
"""
|
||||
Recover stale training jobs on startup.
|
||||
|
||||
When a pod crashes mid-training, jobs are left in 'running' or 'pending' state.
|
||||
This method finds jobs that haven't been updated in a while and marks them
|
||||
as failed so users can retry them.
|
||||
"""
|
||||
try:
|
||||
from app.repositories.training_log_repository import TrainingLogRepository
|
||||
|
||||
async with self.database_manager.get_session() as session:
|
||||
log_repo = TrainingLogRepository(session)
|
||||
|
||||
# Recover jobs that haven't been updated in 60 minutes
|
||||
# This is conservative - most training jobs complete within 30 minutes
|
||||
recovered = await log_repo.recover_stale_jobs(stale_threshold_minutes=60)
|
||||
|
||||
if recovered:
|
||||
self.logger.warning(
|
||||
"Recovered stale training jobs on startup",
|
||||
recovered_count=len(recovered),
|
||||
job_ids=[j.job_id for j in recovered]
|
||||
)
|
||||
else:
|
||||
self.logger.info("No stale training jobs to recover")
|
||||
|
||||
except Exception as e:
|
||||
# Don't fail startup if recovery fails - just log the error
|
||||
self.logger.error("Failed to recover stale jobs on startup", error=str(e))
|
||||
|
||||
async def on_shutdown(self, app: FastAPI):
|
||||
"""Custom shutdown logic for training service"""
|
||||
await cleanup_training_database()
|
||||
self.logger.info("Training database cleanup completed")
|
||||
|
||||
def get_service_features(self):
|
||||
"""Return training-specific features"""
|
||||
return [
|
||||
"ml_model_training",
|
||||
"demand_forecasting",
|
||||
"model_performance_tracking",
|
||||
"training_job_queue",
|
||||
"model_artifacts_management",
|
||||
"websocket_support",
|
||||
"messaging_integration"
|
||||
]
|
||||
|
||||
def setup_custom_middleware(self):
|
||||
"""Setup custom middleware for training service"""
|
||||
# Request middleware for logging and metrics
|
||||
@self.app.middleware("http")
|
||||
async def process_request(request: Request, call_next):
|
||||
"""Process requests with logging and metrics"""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
duration = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
self.logger.info(
|
||||
"Request completed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
status_code=response.status_code,
|
||||
duration_ms=round(duration * 1000, 2)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
duration = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
self.logger.error(
|
||||
"Request failed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
error=str(e),
|
||||
duration_ms=round(duration * 1000, 2)
|
||||
)
|
||||
raise
|
||||
|
||||
def setup_custom_endpoints(self):
|
||||
"""Setup custom endpoints for training service"""
|
||||
# Note: Metrics are exported via OpenTelemetry OTLP to SigNoz
|
||||
# The /metrics endpoint is not needed as metrics are pushed automatically
|
||||
# @self.app.get("/metrics")
|
||||
# async def get_metrics():
|
||||
# """Prometheus metrics endpoint"""
|
||||
# if self.metrics_collector:
|
||||
# return self.metrics_collector.get_metrics()
|
||||
# return {"status": "metrics not available"}
|
||||
|
||||
@self.app.get("/")
|
||||
async def root():
|
||||
return {"service": "training-service", "version": "1.0.0"}
|
||||
|
||||
|
||||
# Create service instance
|
||||
service = TrainingService()
|
||||
|
||||
# Create FastAPI app with standardized setup
|
||||
app = service.create_app(
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc"
|
||||
)
|
||||
|
||||
# Setup standard endpoints
|
||||
service.setup_standard_endpoints()
|
||||
|
||||
# Setup custom middleware
|
||||
service.setup_custom_middleware()
|
||||
|
||||
# Setup custom endpoints
|
||||
service.setup_custom_endpoints()
|
||||
|
||||
# Include API routers
|
||||
# IMPORTANT: Register audit router FIRST to avoid route matching conflicts
|
||||
service.add_router(audit.router)
|
||||
service.add_router(training_jobs.router, tags=["training-jobs"])
|
||||
service.add_router(training_operations.router, tags=["training-operations"])
|
||||
service.add_router(models.router, tags=["models"])
|
||||
service.add_router(health.router, tags=["health"])
|
||||
service.add_router(monitoring.router, tags=["monitoring"])
|
||||
service.add_router(websocket_operations.router, tags=["websocket"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host="0.0.0.0",
|
||||
port=settings.PORT,
|
||||
reload=settings.DEBUG,
|
||||
log_level=settings.LOG_LEVEL.lower()
|
||||
)
|
||||
14
services/training/app/ml/__init__.py
Normal file
@@ -0,0 +1,14 @@
"""
ML Pipeline Components
Machine learning training and prediction components
"""

from .trainer import EnhancedBakeryMLTrainer
from .data_processor import EnhancedBakeryDataProcessor
from .prophet_manager import BakeryProphetManager

__all__ = [
"EnhancedBakeryMLTrainer",
"EnhancedBakeryDataProcessor",
"BakeryProphetManager"
]
307
services/training/app/ml/calendar_features.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
Calendar-based Feature Engineering
|
||||
Hyperlocal school calendar and event features for demand forecasting
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import structlog
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime, date, time, timedelta
|
||||
from shared.clients.external_client import ExternalServiceClient
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class CalendarFeatureEngine:
|
||||
"""
|
||||
Generates features based on school calendars and local events
|
||||
for hyperlocal demand forecasting enhancement
|
||||
"""
|
||||
|
||||
def __init__(self, external_client: ExternalServiceClient):
|
||||
self.external_client = external_client
|
||||
self.calendar_cache = {} # Cache calendar data to avoid repeated API calls
|
||||
|
||||
async def get_calendar_for_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
city_id: Optional[str] = "madrid"
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get the assigned school calendar for a tenant
|
||||
If tenant has no assignment, returns None
|
||||
"""
|
||||
try:
|
||||
# Check cache first
|
||||
cache_key = f"tenant_{tenant_id}_calendar"
|
||||
if cache_key in self.calendar_cache:
|
||||
logger.debug("Using cached calendar", tenant_id=tenant_id)
|
||||
return self.calendar_cache[cache_key]
|
||||
|
||||
# Get tenant location context
|
||||
context = await self.external_client.get_tenant_location_context(tenant_id)
|
||||
|
||||
if not context or not context.get("calendar"):
|
||||
logger.info(
|
||||
"No calendar assigned to tenant, using default if available",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return None
|
||||
|
||||
calendar = context["calendar"]
|
||||
self.calendar_cache[cache_key] = calendar
|
||||
|
||||
logger.info(
|
||||
"Retrieved calendar for tenant",
|
||||
tenant_id=tenant_id,
|
||||
calendar_name=calendar.get("calendar_name")
|
||||
)
|
||||
|
||||
return calendar
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error retrieving calendar for tenant",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
def _is_date_in_holiday_period(
|
||||
self,
|
||||
check_date: date,
|
||||
holiday_periods: List[Dict[str, Any]]
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Check if a date falls within any holiday period
|
||||
|
||||
Returns:
|
||||
(is_holiday, holiday_name)
|
||||
"""
|
||||
for period in holiday_periods:
|
||||
start = datetime.strptime(period["start_date"], "%Y-%m-%d").date()
|
||||
end = datetime.strptime(period["end_date"], "%Y-%m-%d").date()
|
||||
|
||||
if start <= check_date <= end:
|
||||
return True, period["name"]
|
||||
|
||||
return False, None
|
||||
|
||||
def _is_school_hours_active(
|
||||
self,
|
||||
check_datetime: datetime,
|
||||
school_hours: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if datetime falls during school operating hours
|
||||
|
||||
Args:
|
||||
check_datetime: DateTime to check
|
||||
school_hours: School hours configuration dict
|
||||
|
||||
Returns:
|
||||
True if during school hours, False otherwise
|
||||
"""
|
||||
# Only check weekdays
|
||||
if check_datetime.weekday() >= 5: # Saturday=5, Sunday=6
|
||||
return False
|
||||
|
||||
check_time = check_datetime.time()
|
||||
|
||||
# Morning session
|
||||
morning_start = datetime.strptime(
|
||||
school_hours["morning_start"], "%H:%M"
|
||||
).time()
|
||||
morning_end = datetime.strptime(
|
||||
school_hours["morning_end"], "%H:%M"
|
||||
).time()
|
||||
|
||||
if morning_start <= check_time <= morning_end:
|
||||
return True
|
||||
|
||||
# Afternoon session (if applicable)
|
||||
if school_hours.get("has_afternoon_session", False):
|
||||
afternoon_start = datetime.strptime(
|
||||
school_hours["afternoon_start"], "%H:%M"
|
||||
).time()
|
||||
afternoon_end = datetime.strptime(
|
||||
school_hours["afternoon_end"], "%H:%M"
|
||||
).time()
|
||||
|
||||
if afternoon_start <= check_time <= afternoon_end:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _calculate_school_proximity_intensity(
|
||||
self,
|
||||
check_datetime: datetime,
|
||||
school_hours: Dict[str, Any]
|
||||
) -> float:
|
||||
"""
|
||||
Calculate intensity of school-related foot traffic
|
||||
Peaks during drop-off and pick-up times
|
||||
|
||||
Returns:
|
||||
Float between 0.0 (no impact) and 1.0 (peak impact)
|
||||
"""
|
||||
# Only weekdays
|
||||
if check_datetime.weekday() >= 5:
|
||||
return 0.0
|
||||
|
||||
check_time = check_datetime.time()
|
||||
|
||||
# Define peak windows (30 minutes before and after school start/end)
|
||||
morning_start = datetime.strptime(
|
||||
school_hours["morning_start"], "%H:%M"
|
||||
).time()
|
||||
morning_end = datetime.strptime(
|
||||
school_hours["morning_end"], "%H:%M"
|
||||
).time()
|
||||
|
||||
# Morning drop-off peak (30 min before to 15 min after start)
|
||||
drop_off_start = (
|
||||
datetime.combine(date.today(), morning_start) - timedelta(minutes=30)
|
||||
).time()
|
||||
drop_off_end = (
|
||||
datetime.combine(date.today(), morning_start) + timedelta(minutes=15)
|
||||
).time()
|
||||
|
||||
if drop_off_start <= check_time <= drop_off_end:
|
||||
return 1.0 # Peak morning traffic
|
||||
|
||||
# Pick-up peak at the end of the morning session (15 min before to 30 min after end)
|
||||
pickup_start = (
|
||||
datetime.combine(date.today(), morning_end) - timedelta(minutes=15)
|
||||
).time()
|
||||
pickup_end = (
|
||||
datetime.combine(date.today(), morning_end) + timedelta(minutes=30)
|
||||
).time()
|
||||
|
||||
if pickup_start <= check_time <= pickup_end:
|
||||
return 1.0 # Peak afternoon traffic
|
||||
|
||||
# During school hours (moderate impact)
|
||||
if morning_start <= check_time <= morning_end:
|
||||
return 0.3
|
||||
|
||||
# Afternoon session if applicable
|
||||
if school_hours.get("has_afternoon_session", False):
|
||||
afternoon_start = datetime.strptime(
|
||||
school_hours["afternoon_start"], "%H:%M"
|
||||
).time()
|
||||
afternoon_end = datetime.strptime(
|
||||
school_hours["afternoon_end"], "%H:%M"
|
||||
).time()
|
||||
|
||||
if afternoon_start <= check_time <= afternoon_end:
|
||||
return 0.3
|
||||
|
||||
return 0.0
|
||||
|
||||
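# Worked example (assuming a typical Spanish schedule with school_hours =
# {"morning_start": "09:00", "morning_end": "14:00"} and no afternoon session):
# a weekday 08:45 timestamp falls in the drop-off window and scores 1.0, 11:00 is
# during class and scores 0.3, 14:20 falls in the pick-up window and scores 1.0,
# and 20:00 or any weekend timestamp scores 0.0.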
async def add_calendar_features(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
tenant_id: str,
|
||||
date_column: str = "date"
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Add calendar-based features to dataframe
|
||||
|
||||
Features added:
|
||||
- is_school_holiday: Binary (1/0)
|
||||
- school_holiday_name: String (name of holiday or None)
|
||||
- school_hours_active: Binary (1/0) - if during school operating hours
|
||||
- school_proximity_intensity: Float (0.0-1.0) - peak during drop-off/pick-up
|
||||
|
||||
Args:
|
||||
df: DataFrame with date/datetime column
|
||||
tenant_id: Tenant ID to get calendar assignment
|
||||
date_column: Name of date column
|
||||
|
||||
Returns:
|
||||
DataFrame with added calendar features
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Adding calendar-based features",
|
||||
tenant_id=tenant_id,
|
||||
rows=len(df)
|
||||
)
|
||||
|
||||
# Get calendar for tenant
|
||||
calendar = await self.get_calendar_for_tenant(tenant_id)
|
||||
|
||||
if not calendar:
|
||||
logger.warning(
|
||||
"No calendar available, using fallback features",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
# Add default features (all zeros)
|
||||
df["is_school_holiday"] = 0
|
||||
df["school_holiday_name"] = None
|
||||
df["school_hours_active"] = 0
|
||||
df["school_proximity_intensity"] = 0.0
|
||||
return df
|
||||
|
||||
holiday_periods = calendar.get("holiday_periods", [])
|
||||
school_hours = calendar.get("school_hours", {})
|
||||
|
||||
# Initialize feature columns
|
||||
school_holidays = []
|
||||
holiday_names = []
|
||||
hours_active = []
|
||||
proximity_intensity = []
|
||||
|
||||
# Process each row
|
||||
for idx, row in df.iterrows():
|
||||
row_date = pd.to_datetime(row[date_column])
|
||||
|
||||
# Check if holiday
|
||||
is_holiday, holiday_name = self._is_date_in_holiday_period(
|
||||
row_date.date(),
|
||||
holiday_periods
|
||||
)
|
||||
school_holidays.append(1 if is_holiday else 0)
|
||||
holiday_names.append(holiday_name)
|
||||
|
||||
# Check if during school hours (requires time component)
|
||||
if hasattr(row_date, 'hour'): # Has time component
|
||||
hours_active.append(
|
||||
1 if self._is_school_hours_active(row_date, school_hours) else 0
|
||||
)
|
||||
proximity_intensity.append(
|
||||
self._calculate_school_proximity_intensity(row_date, school_hours)
|
||||
)
|
||||
else:
|
||||
# Date only, no time component
|
||||
hours_active.append(0)
|
||||
proximity_intensity.append(0.0)
|
||||
|
||||
# Add features to dataframe
|
||||
df["is_school_holiday"] = school_holidays
|
||||
df["school_holiday_name"] = holiday_names
|
||||
df["school_hours_active"] = hours_active
|
||||
df["school_proximity_intensity"] = proximity_intensity
|
||||
|
||||
logger.info(
|
||||
"Calendar features added successfully",
|
||||
tenant_id=tenant_id,
|
||||
holiday_periods_count=len(holiday_periods),
|
||||
holidays_found=sum(school_holidays)
|
||||
)
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error adding calendar features",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
# Return df with default features on error
|
||||
df["is_school_holiday"] = 0
|
||||
df["school_holiday_name"] = None
|
||||
df["school_hours_active"] = 0
|
||||
df["school_proximity_intensity"] = 0.0
|
||||
return df
|
||||
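The snippet below is an illustrative usage sketch rather than part of the committed file: it shows how a data-preparation step could invoke CalendarFeatureEngine on a daily sales frame. The client construction mirrors the pattern used elsewhere in this commit; the tenant ID is a placeholder.

import asyncio
import pandas as pd

from shared.clients.external_client import ExternalServiceClient
from app.core.config import settings
from app.ml.calendar_features import CalendarFeatureEngine

async def demo(tenant_id: str) -> pd.DataFrame:
    # Daily sales history with a plain date column
    df = pd.DataFrame({
        "date": pd.date_range("2024-03-01", periods=31, freq="D"),
        "quantity": range(31),
    })
    engine = CalendarFeatureEngine(ExternalServiceClient(settings, "training-service"))
    # Adds is_school_holiday, school_holiday_name, school_hours_active and
    # school_proximity_intensity (all zero/None when no calendar is assigned)
    return await engine.add_calendar_features(df, tenant_id=tenant_id, date_column="date")

# asyncio.run(demo("tenant-uuid"))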
1453
services/training/app/ml/data_processor.py
Normal file
File diff suppressed because it is too large
355
services/training/app/ml/enhanced_features.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""
|
||||
Enhanced Feature Engineering for Hybrid Prophet + XGBoost Models
|
||||
Adds lagged features, rolling statistics, and advanced interactions
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional
|
||||
import structlog
|
||||
from shared.ml.feature_calculator import HistoricalFeatureCalculator
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class AdvancedFeatureEngineer:
|
||||
"""
|
||||
Advanced feature engineering for hybrid forecasting models.
|
||||
Adds lagged features, rolling statistics, and complex interactions.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.feature_columns = []
|
||||
self.feature_calculator = HistoricalFeatureCalculator()
|
||||
|
||||
def add_lagged_features(self, df: pd.DataFrame, lag_days: List[int] = None) -> pd.DataFrame:
|
||||
"""
|
||||
Add lagged demand features for capturing recent trends.
|
||||
Uses shared feature calculator for consistency with prediction service.
|
||||
|
||||
Args:
|
||||
df: DataFrame with 'quantity' column
|
||||
lag_days: List of lag periods (default: [1, 7, 14])
|
||||
|
||||
Returns:
|
||||
DataFrame with added lagged features
|
||||
"""
|
||||
if lag_days is None:
|
||||
lag_days = [1, 7, 14]
|
||||
|
||||
# Use shared calculator for consistent lag calculation
|
||||
df = self.feature_calculator.calculate_lag_features(
|
||||
df,
|
||||
lag_days=lag_days,
|
||||
mode='training'
|
||||
)
|
||||
|
||||
# Update feature columns list
|
||||
for lag in lag_days:
|
||||
col_name = f'lag_{lag}_day'
|
||||
if col_name not in self.feature_columns:
|
||||
self.feature_columns.append(col_name)
|
||||
|
||||
logger.info(f"Added {len(lag_days)} lagged features (using shared calculator)", lags=lag_days)
|
||||
return df
|
||||
|
||||
def add_rolling_features(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
windows: List[int] = None,
|
||||
features: List[str] = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Add rolling statistics (mean, std, max, min).
|
||||
Uses shared feature calculator for consistency with prediction service.
|
||||
|
||||
Args:
|
||||
df: DataFrame with 'quantity' column
|
||||
windows: List of window sizes (default: [7, 14, 30])
|
||||
features: List of statistics to calculate (default: ['mean', 'std', 'max', 'min'])
|
||||
|
||||
Returns:
|
||||
DataFrame with rolling features
|
||||
"""
|
||||
if windows is None:
|
||||
windows = [7, 14, 30]
|
||||
|
||||
if features is None:
|
||||
features = ['mean', 'std', 'max', 'min']
|
||||
|
||||
# Use shared calculator for consistent rolling calculation
|
||||
df = self.feature_calculator.calculate_rolling_features(
|
||||
df,
|
||||
windows=windows,
|
||||
statistics=features,
|
||||
mode='training'
|
||||
)
|
||||
|
||||
# Update feature columns list
|
||||
for window in windows:
|
||||
for feature in features:
|
||||
col_name = f'rolling_{feature}_{window}d'
|
||||
if col_name not in self.feature_columns:
|
||||
self.feature_columns.append(col_name)
|
||||
|
||||
logger.info(f"Added rolling features (using shared calculator)", windows=windows, features=features)
|
||||
return df
|
||||
|
||||
def add_day_of_week_features(self, df: pd.DataFrame, date_column: str = 'date') -> pd.DataFrame:
|
||||
"""
|
||||
Add enhanced day-of-week features.
|
||||
|
||||
Args:
|
||||
df: DataFrame with date column
|
||||
date_column: Name of date column
|
||||
|
||||
Returns:
|
||||
DataFrame with day-of-week features
|
||||
"""
|
||||
df = df.copy()
|
||||
|
||||
# Day of week (0=Monday, 6=Sunday)
|
||||
df['day_of_week'] = df[date_column].dt.dayofweek
|
||||
|
||||
# Is weekend
|
||||
df['is_weekend'] = (df['day_of_week'] >= 5).astype(int)
|
||||
|
||||
# Is Friday (often higher demand due to weekend prep)
|
||||
df['is_friday'] = (df['day_of_week'] == 4).astype(int)
|
||||
|
||||
# Is Monday (often lower demand after weekend)
|
||||
df['is_monday'] = (df['day_of_week'] == 0).astype(int)
|
||||
|
||||
# Add to feature list
|
||||
for col in ['day_of_week', 'is_weekend', 'is_friday', 'is_monday']:
|
||||
if col not in self.feature_columns:
|
||||
self.feature_columns.append(col)
|
||||
|
||||
return df
|
||||
|
||||
def add_calendar_enhanced_features(self, df: pd.DataFrame, date_column: str = 'date') -> pd.DataFrame:
|
||||
"""
|
||||
Add enhanced calendar features beyond basic temporal features.
|
||||
|
||||
Args:
|
||||
df: DataFrame with date column
|
||||
date_column: Name of date column
|
||||
|
||||
Returns:
|
||||
DataFrame with enhanced calendar features
|
||||
"""
|
||||
df = df.copy()
|
||||
|
||||
# Month and quarter (if not already present)
|
||||
if 'month' not in df.columns:
|
||||
df['month'] = df[date_column].dt.month
|
||||
|
||||
if 'quarter' not in df.columns:
|
||||
df['quarter'] = df[date_column].dt.quarter
|
||||
|
||||
# Day of month
|
||||
df['day_of_month'] = df[date_column].dt.day
|
||||
|
||||
# Is month start/end
|
||||
df['is_month_start'] = (df['day_of_month'] <= 3).astype(int)
|
||||
df['is_month_end'] = (df[date_column].dt.is_month_end).astype(int)
|
||||
|
||||
# Week of year
|
||||
df['week_of_year'] = df[date_column].dt.isocalendar().week
|
||||
|
||||
# Payday indicators for Spain (high bakery traffic)
|
||||
# Spain commonly pays on: 28th, 15th, or last day of month
|
||||
df['is_payday'] = (
|
||||
(df['day_of_month'] == 15) | # Mid-month payday
|
||||
(df['day_of_month'] == 28) | # Common Spanish payday (28th)
|
||||
df[date_column].dt.is_month_end # End of month
|
||||
).astype(int)
|
||||
|
||||
# Add to feature list
|
||||
for col in ['month', 'quarter', 'day_of_month', 'is_month_start', 'is_month_end',
|
||||
'week_of_year', 'is_payday']:
|
||||
if col not in self.feature_columns:
|
||||
self.feature_columns.append(col)
|
||||
|
||||
return df
|
||||
|
||||
def add_interaction_features(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Add interaction features between variables.
|
||||
|
||||
Args:
|
||||
df: DataFrame with base features
|
||||
|
||||
Returns:
|
||||
DataFrame with interaction features
|
||||
"""
|
||||
df = df.copy()
|
||||
|
||||
# Weekend × Temperature (people buy more cold drinks on hot weekends)
|
||||
if 'is_weekend' in df.columns and 'temperature' in df.columns:
|
||||
df['weekend_temp_interaction'] = df['is_weekend'] * df['temperature']
|
||||
self.feature_columns.append('weekend_temp_interaction')
|
||||
|
||||
# Rain × Weekend (bad weather reduces weekend traffic)
|
||||
if 'is_weekend' in df.columns and 'precipitation' in df.columns:
|
||||
df['rain_weekend_interaction'] = df['is_weekend'] * (df['precipitation'] > 0).astype(int)
|
||||
self.feature_columns.append('rain_weekend_interaction')
|
||||
|
||||
# Friday × Traffic (high Friday traffic means weekend prep buying)
|
||||
if 'is_friday' in df.columns and 'traffic_volume' in df.columns:
|
||||
df['friday_traffic_interaction'] = df['is_friday'] * df['traffic_volume']
|
||||
self.feature_columns.append('friday_traffic_interaction')
|
||||
|
||||
# Month × Temperature (seasonal temperature patterns)
|
||||
if 'month' in df.columns and 'temperature' in df.columns:
|
||||
df['month_temp_interaction'] = df['month'] * df['temperature']
|
||||
self.feature_columns.append('month_temp_interaction')
|
||||
|
||||
# Payday × Weekend (big shopping days)
|
||||
if 'is_payday' in df.columns and 'is_weekend' in df.columns:
|
||||
df['payday_weekend_interaction'] = df['is_payday'] * df['is_weekend']
|
||||
self.feature_columns.append('payday_weekend_interaction')
|
||||
|
||||
logger.info(f"Added {len([c for c in self.feature_columns if 'interaction' in c])} interaction features")
|
||||
return df
|
||||
|
||||
def add_trend_features(self, df: pd.DataFrame, date_column: str = 'date') -> pd.DataFrame:
|
||||
"""
|
||||
Add trend-based features.
|
||||
Uses shared feature calculator for consistency with prediction service.
|
||||
|
||||
Args:
|
||||
df: DataFrame with date and quantity
|
||||
date_column: Name of date column
|
||||
|
||||
Returns:
|
||||
DataFrame with trend features
|
||||
"""
|
||||
# Use shared calculator for consistent trend calculation
|
||||
df = self.feature_calculator.calculate_trend_features(
|
||||
df,
|
||||
mode='training'
|
||||
)
|
||||
|
||||
# Update feature columns list
|
||||
for feature_name in ['days_since_start', 'momentum_1_7', 'trend_7_30', 'velocity_week']:
|
||||
if feature_name in df.columns and feature_name not in self.feature_columns:
|
||||
self.feature_columns.append(feature_name)
|
||||
|
||||
logger.debug("Added trend features (using shared calculator)")
|
||||
return df
|
||||
|
||||
def add_cyclical_encoding(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Add cyclical encoding for periodic features (day_of_week, month).
|
||||
Helps models understand that Monday follows Sunday and that January follows December.
|
||||
|
||||
Args:
|
||||
df: DataFrame with day_of_week and month columns
|
||||
|
||||
Returns:
|
||||
DataFrame with cyclical features
|
||||
"""
|
||||
df = df.copy()
|
||||
|
||||
# Day of week cyclical encoding
|
||||
if 'day_of_week' in df.columns:
|
||||
df['day_of_week_sin'] = np.sin(2 * np.pi * df['day_of_week'] / 7)
|
||||
df['day_of_week_cos'] = np.cos(2 * np.pi * df['day_of_week'] / 7)
|
||||
self.feature_columns.extend(['day_of_week_sin', 'day_of_week_cos'])
|
||||
|
||||
# Month cyclical encoding
|
||||
if 'month' in df.columns:
|
||||
df['month_sin'] = np.sin(2 * np.pi * df['month'] / 12)
|
||||
df['month_cos'] = np.cos(2 * np.pi * df['month'] / 12)
|
||||
self.feature_columns.extend(['month_sin', 'month_cos'])
|
||||
|
||||
logger.info("Added cyclical encoding for temporal features")
|
||||
return df
|
||||
|
||||
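# Worked example: Sunday (day_of_week=6) maps to (sin, cos) ≈ (-0.78, 0.62) and
# Monday (0) maps to (0.0, 1.0); the two points sit next to each other on the
# unit circle, so downstream models treat them as neighbours even though the
# raw integers 6 and 0 are far apart.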
def create_all_features(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
date_column: str = 'date',
|
||||
include_lags: bool = True,
|
||||
include_rolling: bool = True,
|
||||
include_interactions: bool = True,
|
||||
include_cyclical: bool = True
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Create all enhanced features in one go.
|
||||
|
||||
Args:
|
||||
df: DataFrame with base data
|
||||
date_column: Name of date column
|
||||
include_lags: Whether to include lagged features
|
||||
include_rolling: Whether to include rolling statistics
|
||||
include_interactions: Whether to include interaction features
|
||||
include_cyclical: Whether to include cyclical encoding
|
||||
|
||||
Returns:
|
||||
DataFrame with all enhanced features
|
||||
"""
|
||||
logger.info("Creating comprehensive feature set for hybrid model")
|
||||
|
||||
# Reset feature list
|
||||
self.feature_columns = []
|
||||
|
||||
# Day of week and calendar features (always needed)
|
||||
df = self.add_day_of_week_features(df, date_column)
|
||||
df = self.add_calendar_enhanced_features(df, date_column)
|
||||
|
||||
# Optional features
|
||||
if include_lags:
|
||||
df = self.add_lagged_features(df)
|
||||
|
||||
if include_rolling:
|
||||
df = self.add_rolling_features(df)
|
||||
|
||||
if include_interactions:
|
||||
df = self.add_interaction_features(df)
|
||||
|
||||
if include_cyclical:
|
||||
df = self.add_cyclical_encoding(df)
|
||||
|
||||
# Trend features (depends on lags and rolling)
|
||||
if include_lags or include_rolling:
|
||||
df = self.add_trend_features(df, date_column)
|
||||
|
||||
logger.info(f"Created {len(self.feature_columns)} enhanced features for hybrid model")
|
||||
|
||||
return df
|
||||
|
||||
def get_feature_columns(self) -> List[str]:
|
||||
"""Get list of all created feature column names."""
|
||||
return self.feature_columns.copy()
|
||||
|
||||
def fill_na_values(self, df: pd.DataFrame, strategy: str = 'forward_mean') -> pd.DataFrame:
|
||||
"""
|
||||
Fill NA values in lagged and rolling features.
|
||||
|
||||
IMPORTANT: Never uses backward fill to prevent data leakage in time series training.
|
||||
|
||||
Args:
|
||||
df: DataFrame with potential NA values
|
||||
strategy: 'forward_mean', 'zero', 'mean'
|
||||
|
||||
Returns:
|
||||
DataFrame with filled NA values
|
||||
"""
|
||||
df = df.copy()
|
||||
|
||||
if strategy == 'forward_mean':
# Forward fill first (use previous values); DataFrame.ffill() replaces the deprecated fillna(method='ffill')
df = df.ffill()
# Fill remaining gaps in numeric columns with their means (typically at the beginning of the series)
# NEVER use bfill as it leaks future information into training data
df = df.fillna(df.mean(numeric_only=True))

elif strategy == 'zero':
df = df.fillna(0)

elif strategy == 'mean':
df = df.fillna(df.mean(numeric_only=True))
|
||||
|
||||
return df
|
||||
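An illustrative usage sketch (not part of the committed file) of how a trainer can build the full feature set on a daily demand series; the synthetic data and column names simply follow the conventions used above.

import pandas as pd
from app.ml.enhanced_features import AdvancedFeatureEngineer

# Minimal daily demand series with the 'date' and 'quantity' columns the engineer expects
df = pd.DataFrame({
    "date": pd.date_range("2024-01-01", periods=120, freq="D"),
    "quantity": [20 + (i % 7) * 3 for i in range(120)],
})

engineer = AdvancedFeatureEngineer()
df_features = engineer.create_all_features(df, date_column="date")
df_features = engineer.fill_na_values(df_features)  # fill NAs left by lag/rolling windows

print(len(engineer.get_feature_columns()), "enhanced features created")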
253
services/training/app/ml/event_feature_generator.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
Event Feature Generator
|
||||
Converts calendar events into features for demand forecasting
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import date, timedelta
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class EventFeatureGenerator:
|
||||
"""
|
||||
Generate event-related features for demand forecasting.
|
||||
|
||||
Features include:
|
||||
- Binary flags for event presence
|
||||
- Event impact multipliers
|
||||
- Event type indicators
|
||||
- Days until/since major events
|
||||
"""
|
||||
|
||||
# Event type impact weights (default multipliers)
|
||||
EVENT_IMPACT_WEIGHTS = {
|
||||
'promotion': 1.3,
|
||||
'festival': 1.8,
|
||||
'holiday': 0.7, # Bakeries often close or have reduced demand
|
||||
'weather_event': 0.8, # Bad weather reduces foot traffic
|
||||
'school_break': 1.2,
|
||||
'sport_event': 1.4,
|
||||
'market': 1.5,
|
||||
'concert': 1.3,
|
||||
'local_event': 1.2
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def generate_event_features(
|
||||
self,
|
||||
dates: pd.DatetimeIndex,
|
||||
events: List[Dict[str, Any]]
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Generate event features for given dates.
|
||||
|
||||
Args:
|
||||
dates: Dates to generate features for
|
||||
events: List of event dictionaries with keys:
|
||||
- event_date: date
|
||||
- event_type: str
|
||||
- impact_multiplier: float (optional)
|
||||
- event_name: str
|
||||
|
||||
Returns:
|
||||
DataFrame with event features
|
||||
"""
|
||||
df = pd.DataFrame({'date': dates})
|
||||
|
||||
# Initialize feature columns
|
||||
df['has_event'] = 0
|
||||
df['event_impact'] = 1.0 # Neutral impact
|
||||
df['is_promotion'] = 0
|
||||
df['is_festival'] = 0
|
||||
df['is_local_event'] = 0
|
||||
df['days_to_next_event'] = 365
|
||||
df['days_since_last_event'] = 365
|
||||
|
||||
if not events:
|
||||
logger.debug("No events provided, returning default features")
|
||||
return df
|
||||
|
||||
# Convert events to DataFrame for easier processing
|
||||
events_df = pd.DataFrame(events)
|
||||
events_df['event_date'] = pd.to_datetime(events_df['event_date'])
|
||||
|
||||
for idx, row in df.iterrows():
|
||||
current_date = pd.to_datetime(row['date'])
|
||||
|
||||
# Check if there's an event on this date
|
||||
day_events = events_df[events_df['event_date'] == current_date]
|
||||
|
||||
if not day_events.empty:
|
||||
df.at[idx, 'has_event'] = 1
|
||||
|
||||
# Use custom impact multiplier if provided, else use default
|
||||
if 'impact_multiplier' in day_events.columns and not day_events['impact_multiplier'].isna().all():
|
||||
impact = day_events['impact_multiplier'].max()
|
||||
else:
|
||||
# Use default impact based on event type
|
||||
event_types = day_events['event_type'].tolist()
|
||||
impacts = [self.EVENT_IMPACT_WEIGHTS.get(et, 1.0) for et in event_types]
|
||||
impact = max(impacts)
|
||||
|
||||
df.at[idx, 'event_impact'] = impact
|
||||
|
||||
# Set event type flags
|
||||
event_types = day_events['event_type'].tolist()
|
||||
if 'promotion' in event_types:
|
||||
df.at[idx, 'is_promotion'] = 1
|
||||
if 'festival' in event_types:
|
||||
df.at[idx, 'is_festival'] = 1
|
||||
if 'local_event' in event_types or 'market' in event_types:
|
||||
df.at[idx, 'is_local_event'] = 1
|
||||
|
||||
# Calculate days to/from nearest event
|
||||
future_events = events_df[events_df['event_date'] > current_date]
|
||||
if not future_events.empty:
|
||||
next_event_date = future_events['event_date'].min()
|
||||
df.at[idx, 'days_to_next_event'] = (next_event_date - current_date).days
|
||||
|
||||
past_events = events_df[events_df['event_date'] < current_date]
|
||||
if not past_events.empty:
|
||||
last_event_date = past_events['event_date'].max()
|
||||
df.at[idx, 'days_since_last_event'] = (current_date - last_event_date).days
|
||||
|
||||
# Cap days values at 365
|
||||
df['days_to_next_event'] = df['days_to_next_event'].clip(upper=365)
|
||||
df['days_since_last_event'] = df['days_since_last_event'].clip(upper=365)
|
||||
|
||||
logger.debug("Generated event features",
|
||||
total_days=len(df),
|
||||
days_with_events=df['has_event'].sum())
|
||||
|
||||
return df
|
||||
|
||||
def add_event_features_to_forecast_data(
|
||||
self,
|
||||
forecast_data: pd.DataFrame,
|
||||
event_features: pd.DataFrame
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Add event features to forecast input data.
|
||||
|
||||
Args:
|
||||
forecast_data: Existing forecast data with 'date' column
|
||||
event_features: Event features from generate_event_features()
|
||||
|
||||
Returns:
|
||||
Enhanced forecast data with event features
|
||||
"""
|
||||
forecast_data = forecast_data.copy()
|
||||
forecast_data['date'] = pd.to_datetime(forecast_data['date'])
|
||||
event_features['date'] = pd.to_datetime(event_features['date'])
|
||||
|
||||
# Merge event features
|
||||
enhanced_data = forecast_data.merge(
|
||||
event_features[[
|
||||
'date', 'has_event', 'event_impact', 'is_promotion',
|
||||
'is_festival', 'is_local_event', 'days_to_next_event',
|
||||
'days_since_last_event'
|
||||
]],
|
||||
on='date',
|
||||
how='left'
|
||||
)
|
||||
|
||||
# Fill missing values with defaults. A single fillna(dict) call avoids the deprecated
# chained .fillna(..., inplace=True) pattern, which is unreliable under pandas copy-on-write.
enhanced_data = enhanced_data.fillna({
'has_event': 0,
'event_impact': 1.0,
'is_promotion': 0,
'is_festival': 0,
'is_local_event': 0,
'days_to_next_event': 365,
'days_since_last_event': 365
})
|
||||
|
||||
return enhanced_data
|
||||
|
||||
def get_event_summary(self, events: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Get summary statistics about events.
|
||||
|
||||
Args:
|
||||
events: List of event dictionaries
|
||||
|
||||
Returns:
|
||||
Summary dict with counts by type, avg impact, etc.
|
||||
"""
|
||||
if not events:
|
||||
return {
|
||||
'total_events': 0,
|
||||
'events_by_type': {},
|
||||
'avg_impact': 1.0
|
||||
}
|
||||
|
||||
events_df = pd.DataFrame(events)
|
||||
|
||||
summary = {
|
||||
'total_events': len(events),
|
||||
'events_by_type': events_df['event_type'].value_counts().to_dict(),
|
||||
'date_range': {
|
||||
'start': events_df['event_date'].min().isoformat() if not events_df.empty else None,
|
||||
'end': events_df['event_date'].max().isoformat() if not events_df.empty else None
|
||||
}
|
||||
}
|
||||
|
||||
if 'impact_multiplier' in events_df.columns:
|
||||
summary['avg_impact'] = float(events_df['impact_multiplier'].mean())
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def create_event_calendar_features(
|
||||
dates: pd.DatetimeIndex,
|
||||
tenant_id: str,
|
||||
event_repository = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Convenience function to fetch events from database and generate features.
|
||||
|
||||
Args:
|
||||
dates: Dates to generate features for
|
||||
tenant_id: Tenant UUID
|
||||
event_repository: EventRepository instance (optional)
|
||||
|
||||
Returns:
|
||||
DataFrame with event features
|
||||
"""
|
||||
if event_repository is None:
|
||||
logger.warning("No event repository provided, using empty events")
|
||||
events = []
|
||||
else:
|
||||
# Fetch events from database
|
||||
from datetime import date
|
||||
start_date = dates.min().date()
|
||||
end_date = dates.max().date()
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
|
||||
# NOTE: run_until_complete() only works when no event loop is already running,
# so this synchronous helper must not be called from inside async service code.
loop = asyncio.get_event_loop()
events_objects = loop.run_until_complete(
|
||||
event_repository.get_events_by_date_range(
|
||||
tenant_id=UUID(tenant_id),
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
confirmed_only=False
|
||||
)
|
||||
)
|
||||
|
||||
# Convert to dict format
|
||||
events = [event.to_dict() for event in events_objects]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch events from database: {e}")
|
||||
events = []
|
||||
|
||||
# Generate features
|
||||
generator = EventFeatureGenerator()
|
||||
return generator.generate_event_features(dates, events)
|
||||
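A brief usage sketch (illustrative only, with made-up event data) showing how the per-day event features come out:

import pandas as pd
from app.ml.event_feature_generator import EventFeatureGenerator

dates = pd.date_range("2024-05-01", periods=14, freq="D")
events = [
    {"event_date": "2024-05-04", "event_type": "market", "event_name": "Weekly street market"},
    {"event_date": "2024-05-10", "event_type": "promotion", "event_name": "2x1 croissants",
     "impact_multiplier": 1.25},
]

features = EventFeatureGenerator().generate_event_features(dates, events)
# 'has_event' is 1 on the two event days; 'event_impact' is 1.5 for the market
# (default weight) and 1.25 for the promotion (explicit multiplier)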
463
services/training/app/ml/hybrid_trainer.py
Normal file
@@ -0,0 +1,463 @@
|
||||
"""
|
||||
Hybrid Prophet + XGBoost Trainer
|
||||
Combines Prophet's seasonality modeling with XGBoost's pattern learning
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import io
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
import structlog
|
||||
from datetime import datetime, timezone
|
||||
import joblib
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_absolute_percentage_error
|
||||
from sklearn.model_selection import TimeSeriesSplit
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# Import XGBoost
|
||||
try:
|
||||
import xgboost as xgb
|
||||
except ImportError:
|
||||
raise ImportError("XGBoost not installed. Run: pip install xgboost")
|
||||
|
||||
from app.ml.prophet_manager import BakeryProphetManager
|
||||
from app.ml.enhanced_features import AdvancedFeatureEngineer
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class HybridProphetXGBoost:
|
||||
"""
|
||||
Hybrid forecasting model combining Prophet and XGBoost.
|
||||
|
||||
Approach:
|
||||
1. Train Prophet on historical data (captures trend, seasonality, holidays)
|
||||
2. Calculate residuals (actual - prophet_prediction)
|
||||
3. Train XGBoost on residuals using enhanced features
|
||||
4. Final prediction = prophet_prediction + xgboost_residual_prediction
|
||||
|
||||
Benefits:
|
||||
- Prophet handles seasonality, holidays, trends
|
||||
- XGBoost captures complex patterns Prophet misses
|
||||
- Maintains Prophet's interpretability
|
||||
- Improves accuracy by 10-25% over Prophet alone
|
||||
"""
|
||||
|
||||
def __init__(self, database_manager=None):
|
||||
self.prophet_manager = BakeryProphetManager(database_manager)
|
||||
self.feature_engineer = AdvancedFeatureEngineer()
|
||||
self.xgb_model = None
|
||||
self.feature_columns = []
|
||||
self.prophet_model_data = None
|
||||
|
||||
async def train_hybrid_model(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
df: pd.DataFrame,
|
||||
job_id: str,
|
||||
validation_split: float = 0.2,
|
||||
session = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Train hybrid Prophet + XGBoost model.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
df: Training data (must have 'ds', 'y' and regressor columns)
|
||||
job_id: Training job identifier
|
||||
validation_split: Fraction of data for validation
|
||||
session: Optional database session (uses parent session if provided to avoid nested sessions)
|
||||
|
||||
Returns:
|
||||
Dictionary with model metadata and performance metrics
|
||||
"""
|
||||
logger.info(
|
||||
"Starting hybrid Prophet + XGBoost training",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
data_points=len(df)
|
||||
)
|
||||
|
||||
# Step 1: Train Prophet model (base forecaster)
|
||||
logger.info("Step 1: Training Prophet base model")
|
||||
# ✅ FIX: Pass session to prophet_manager to avoid nested session issues
|
||||
prophet_result = await self.prophet_manager.train_bakery_model(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
df=df.copy(),
|
||||
job_id=job_id,
|
||||
session=session
|
||||
)
|
||||
|
||||
self.prophet_model_data = prophet_result
|
||||
|
||||
# Step 2: Create enhanced features for XGBoost
|
||||
logger.info("Step 2: Engineering enhanced features for XGBoost")
|
||||
df_enhanced = self._prepare_xgboost_features(df)
|
||||
|
||||
# Step 3: Split into train/validation
|
||||
split_idx = int(len(df_enhanced) * (1 - validation_split))
|
||||
train_df = df_enhanced.iloc[:split_idx].copy()
|
||||
val_df = df_enhanced.iloc[split_idx:].copy()
|
||||
|
||||
logger.info(
|
||||
"Data split",
|
||||
train_samples=len(train_df),
|
||||
val_samples=len(val_df)
|
||||
)
|
||||
|
||||
# Step 4: Get Prophet predictions on training data
|
||||
logger.info("Step 3: Generating Prophet predictions for residual calculation")
|
||||
train_prophet_pred = await self._get_prophet_predictions(prophet_result, train_df)
|
||||
val_prophet_pred = await self._get_prophet_predictions(prophet_result, val_df)
|
||||
|
||||
# Step 5: Calculate residuals (actual - prophet_prediction)
|
||||
train_residuals = train_df['y'].values - train_prophet_pred
|
||||
val_residuals = val_df['y'].values - val_prophet_pred
|
||||
|
||||
logger.info(
|
||||
"Residuals calculated",
|
||||
train_residual_mean=float(np.mean(train_residuals)),
|
||||
train_residual_std=float(np.std(train_residuals))
|
||||
)
|
||||
|
||||
# Step 6: Prepare feature matrix for XGBoost
|
||||
X_train = train_df[self.feature_columns].values
|
||||
X_val = val_df[self.feature_columns].values
|
||||
|
||||
# Step 7: Train XGBoost on residuals
|
||||
logger.info("Step 4: Training XGBoost on residuals")
|
||||
self.xgb_model = await self._train_xgboost(
|
||||
X_train, train_residuals,
|
||||
X_val, val_residuals
|
||||
)
|
||||
|
||||
# Step 8: Evaluate hybrid model
|
||||
logger.info("Step 5: Evaluating hybrid model performance")
|
||||
metrics = await self._evaluate_hybrid_model(
|
||||
train_df, val_df,
|
||||
train_prophet_pred, val_prophet_pred,
|
||||
prophet_result
|
||||
)
|
||||
|
||||
# Step 9: Save hybrid model
|
||||
model_data = self._package_hybrid_model(
|
||||
prophet_result, metrics, tenant_id, inventory_product_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Hybrid model training complete",
|
||||
prophet_mape=metrics['prophet_val_mape'],
|
||||
hybrid_mape=metrics['hybrid_val_mape'],
|
||||
improvement_pct=metrics['improvement_percentage']
|
||||
)
|
||||
|
||||
return model_data
|
||||
|
||||
def _prepare_xgboost_features(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Prepare enhanced features for XGBoost.
|
||||
|
||||
Args:
|
||||
df: Base dataframe with 'ds', 'y' and regressor columns
|
||||
|
||||
Returns:
|
||||
DataFrame with all enhanced features
|
||||
"""
|
||||
# Rename 'ds' to 'date' for feature engineering
|
||||
df_prep = df.copy()
|
||||
if 'ds' in df_prep.columns:
|
||||
df_prep['date'] = df_prep['ds']
|
||||
|
||||
# Ensure 'quantity' column for feature engineering
|
||||
if 'y' in df_prep.columns:
|
||||
df_prep['quantity'] = df_prep['y']
|
||||
|
||||
# Create all enhanced features
|
||||
df_enhanced = self.feature_engineer.create_all_features(
|
||||
df_prep,
|
||||
date_column='date',
|
||||
include_lags=True,
|
||||
include_rolling=True,
|
||||
include_interactions=True,
|
||||
include_cyclical=True
|
||||
)
|
||||
|
||||
# Fill NA values (from lagged features at beginning)
|
||||
df_enhanced = self.feature_engineer.fill_na_values(df_enhanced)
|
||||
|
||||
# Get feature column list (excluding target and date columns)
|
||||
self.feature_columns = [
|
||||
col for col in self.feature_engineer.get_feature_columns()
|
||||
if col in df_enhanced.columns
|
||||
]
|
||||
|
||||
# Also include original regressor columns if present
|
||||
regressor_cols = [
|
||||
col for col in df.columns
|
||||
if col not in ['ds', 'y', 'date', 'quantity'] and col in df_enhanced.columns
|
||||
]
|
||||
|
||||
self.feature_columns.extend(regressor_cols)
|
||||
self.feature_columns = list(set(self.feature_columns)) # Remove duplicates
|
||||
|
||||
logger.info(f"Prepared {len(self.feature_columns)} features for XGBoost")
|
||||
|
||||
return df_enhanced
|
||||
|
||||
async def _get_prophet_predictions(
|
||||
self,
|
||||
prophet_result: Dict[str, Any],
|
||||
df: pd.DataFrame
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Get Prophet predictions for given dataframe.
|
||||
|
||||
Args:
|
||||
prophet_result: Prophet model result from training (contains model_path)
|
||||
df: DataFrame with 'ds' column
|
||||
|
||||
Returns:
|
||||
Array of predictions
|
||||
"""
|
||||
# Get the model path from result instead of expecting the model object directly
|
||||
model_path = prophet_result.get('model_path')
|
||||
|
||||
if model_path is None:
|
||||
raise ValueError("Prophet model path not found in result")
|
||||
|
||||
# Load the actual Prophet model from the stored path
|
||||
try:
|
||||
if model_path.startswith("minio://"):
|
||||
# Use prophet_manager to load from MinIO
|
||||
prophet_model = await self.prophet_manager._load_model_from_minio(model_path)
|
||||
else:
|
||||
# Fallback to direct loading for local paths
|
||||
import joblib
|
||||
prophet_model = joblib.load(model_path)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load Prophet model from path {model_path}: {str(e)}")
|
||||
|
||||
# Prepare dataframe for prediction
|
||||
pred_df = df[['ds']].copy()
|
||||
|
||||
# Add regressors if present
|
||||
regressor_cols = [col for col in df.columns if col not in ['ds', 'y', 'date', 'quantity']]
|
||||
for col in regressor_cols:
|
||||
if col in df.columns:
|
||||
pred_df[col] = df[col]
|
||||
|
||||
# Get predictions
|
||||
forecast = prophet_model.predict(pred_df)
|
||||
|
||||
return forecast['yhat'].values
|
||||
|
||||
async def _train_xgboost(
|
||||
self,
|
||||
X_train: np.ndarray,
|
||||
y_train: np.ndarray,
|
||||
X_val: np.ndarray,
|
||||
y_val: np.ndarray
|
||||
) -> xgb.XGBRegressor:
|
||||
"""
|
||||
Train XGBoost model on residuals.
|
||||
|
||||
Args:
|
||||
X_train: Training features
|
||||
y_train: Training residuals
|
||||
X_val: Validation features
|
||||
y_val: Validation residuals
|
||||
|
||||
Returns:
|
||||
Trained XGBoost model
|
||||
"""
|
||||
# XGBoost parameters optimized for residual learning
|
||||
params = {
|
||||
'n_estimators': 100,
|
||||
'max_depth': 3, # Shallow trees to prevent overfitting
|
||||
'learning_rate': 0.1,
|
||||
'subsample': 0.8,
|
||||
'colsample_bytree': 0.8,
|
||||
'min_child_weight': 3,
|
||||
'reg_alpha': 0.1, # L1 regularization
|
||||
'reg_lambda': 1.0, # L2 regularization
|
||||
'objective': 'reg:squarederror',
|
||||
'random_state': 42,
|
||||
'n_jobs': -1,
|
||||
'early_stopping_rounds': 10
|
||||
}
|
||||
|
||||
# Initialize model
|
||||
model = xgb.XGBRegressor(**params)
|
||||
|
||||
# ✅ FIX: Run blocking model.fit() in thread pool to avoid blocking event loop
|
||||
import asyncio
|
||||
await asyncio.to_thread(
|
||||
model.fit,
|
||||
X_train, y_train,
|
||||
eval_set=[(X_val, y_val)],
|
||||
verbose=False
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"XGBoost training complete",
|
||||
best_iteration=model.best_iteration if hasattr(model, 'best_iteration') else None
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
async def _evaluate_hybrid_model(
|
||||
self,
|
||||
train_df: pd.DataFrame,
|
||||
val_df: pd.DataFrame,
|
||||
train_prophet_pred: np.ndarray,
|
||||
val_prophet_pred: np.ndarray,
|
||||
prophet_result: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Evaluate the overall performance of the hybrid model using threading for metrics.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Get XGBoost predictions on training and validation
|
||||
X_train = train_df[self.feature_columns].values
|
||||
X_val = val_df[self.feature_columns].values
|
||||
|
||||
train_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_train)
|
||||
val_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_val)
|
||||
|
||||
# Hybrid prediction = Prophet prediction + XGBoost residual prediction
|
||||
train_hybrid_pred = train_prophet_pred + train_xgb_pred
|
||||
val_hybrid_pred = val_prophet_pred + val_xgb_pred
|
||||
|
||||
actual_train = train_df['y'].values
|
||||
actual_val = val_df['y'].values
|
||||
|
||||
# Basic RMSE calculation
|
||||
train_rmse = float(np.sqrt(np.mean((actual_train - train_hybrid_pred)**2)))
|
||||
val_rmse = float(np.sqrt(np.mean((actual_val - val_hybrid_pred)**2)))
|
||||
|
||||
# MAE
|
||||
train_mae = float(np.mean(np.abs(actual_train - train_hybrid_pred)))
|
||||
val_mae = float(np.mean(np.abs(actual_val - val_hybrid_pred)))
|
||||
|
||||
# MAPE (with safety for zero sales)
|
||||
train_mape = float(np.mean(np.abs((actual_train - train_hybrid_pred) / np.maximum(actual_train, 1))))
|
||||
val_mape = float(np.mean(np.abs((actual_val - val_hybrid_pred) / np.maximum(actual_val, 1))))
|
||||
|
||||
# Calculate improvement
|
||||
prophet_metrics = prophet_result.get("metrics", {})
|
||||
prophet_val_mae = prophet_metrics.get("val_mae", val_mae) # Fallback to hybrid if missing
|
||||
prophet_val_mape = prophet_metrics.get("val_mape", val_mape)
|
||||
|
||||
improvement_pct = 0.0
|
||||
if prophet_val_mape > 0:
|
||||
improvement_pct = ((prophet_val_mape - val_mape) / prophet_val_mape) * 100
|
||||
|
||||
metrics = {
|
||||
"train_rmse": train_rmse,
|
||||
"val_rmse": val_rmse,
|
||||
"train_mae": train_mae,
|
||||
"val_mae": val_mae,
|
||||
"train_mape": train_mape,
|
||||
"val_mape": val_mape,
|
||||
"prophet_val_mape": prophet_val_mape,
|
||||
"hybrid_val_mape": val_mape,
|
||||
"improvement_percentage": float(improvement_pct),
|
||||
"prophet_metrics": prophet_metrics
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Hybrid model evaluation complete",
|
||||
val_rmse=val_rmse,
|
||||
val_mae=val_mae,
|
||||
val_mape=val_mape,
|
||||
improvement=improvement_pct
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
def _package_hybrid_model(
|
||||
self,
|
||||
prophet_result: Dict[str, Any],
|
||||
metrics: Dict[str, Any],
|
||||
tenant_id: str,
|
||||
inventory_product_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Package hybrid model for storage.
|
||||
"""
|
||||
return {
|
||||
'model_type': 'hybrid_prophet_xgboost',
|
||||
'prophet_model_path': prophet_result.get('model_path'),
|
||||
'xgboost_model': self.xgb_model,
|
||||
'feature_columns': self.feature_columns,
|
||||
'metrics': metrics,
|
||||
'tenant_id': tenant_id,
|
||||
'inventory_product_id': inventory_product_id,
|
||||
'trained_at': datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
async def predict(
|
||||
self,
|
||||
future_df: pd.DataFrame,
|
||||
model_data: Dict[str, Any]
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Make predictions using hybrid model.
|
||||
|
||||
Args:
|
||||
future_df: DataFrame with future dates and regressors
|
||||
model_data: Loaded hybrid model data
|
||||
|
||||
Returns:
|
||||
DataFrame with predictions
|
||||
"""
|
||||
# Step 1: Get Prophet model from path and make predictions
|
||||
prophet_model_path = model_data.get('prophet_model_path')
|
||||
if prophet_model_path is None:
|
||||
raise ValueError("Prophet model path not found in model data")
|
||||
|
||||
# Load the Prophet model from the stored path
|
||||
try:
|
||||
if prophet_model_path.startswith("minio://"):
|
||||
# Use prophet_manager to load from MinIO
|
||||
prophet_model = await self.prophet_manager._load_model_from_minio(prophet_model_path)
|
||||
else:
|
||||
# Fallback to direct loading for local paths
|
||||
import joblib
|
||||
prophet_model = joblib.load(prophet_model_path)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load Prophet model from path {prophet_model_path}: {str(e)}")
|
||||
|
||||
# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
|
||||
import asyncio
|
||||
prophet_forecast = await asyncio.to_thread(prophet_model.predict, future_df)
|
||||
|
||||
# Step 2: Prepare features for XGBoost
|
||||
future_enhanced = self._prepare_xgboost_features(future_df)
|
||||
|
||||
# Step 3: Get XGBoost predictions
|
||||
xgb_model = model_data['xgboost_model']
|
||||
feature_columns = model_data['feature_columns']
|
||||
X_future = future_enhanced[feature_columns].values
|
||||
# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
|
||||
xgb_pred = await asyncio.to_thread(xgb_model.predict, X_future)
|
||||
|
||||
# Step 4: Combine predictions
|
||||
hybrid_pred = prophet_forecast['yhat'].values + xgb_pred
|
||||
|
||||
# Step 5: Create result dataframe
|
||||
result = pd.DataFrame({
|
||||
'ds': future_df['ds'],
|
||||
'prophet_yhat': prophet_forecast['yhat'],
|
||||
'xgb_adjustment': xgb_pred,
|
||||
'yhat': hybrid_pred,
|
||||
'yhat_lower': prophet_forecast['yhat_lower'] + xgb_pred,
|
||||
'yhat_upper': prophet_forecast['yhat_upper'] + xgb_pred
|
||||
})
|
||||
|
||||
return result
|
||||
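An illustrative end-to-end sketch (not part of the committed file) of training the hybrid model and forecasting with it; the 'ds'/'y' column names follow the Prophet convention used above and the identifiers are placeholders.

import pandas as pd
from app.ml.hybrid_trainer import HybridProphetXGBoost

async def run_hybrid(df: pd.DataFrame, future: pd.DataFrame):
    # df must contain 'ds' (dates) and 'y' (daily quantity) plus any regressor columns
    trainer = HybridProphetXGBoost()
    model_data = await trainer.train_hybrid_model(
        tenant_id="tenant-uuid",
        inventory_product_id="product-uuid",
        df=df,
        job_id="job-123",
    )
    # future needs a 'ds' column (and the same regressors used during training)
    forecast = await trainer.predict(future, model_data)
    return model_data["metrics"]["hybrid_val_mape"], forecast[["ds", "yhat"]]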
257
services/training/app/ml/model_selector.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""
|
||||
Model Selection System
|
||||
Determines whether to use Prophet-only or Hybrid Prophet+XGBoost models
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, Any, Optional
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ModelSelector:
|
||||
"""
|
||||
Intelligent model selection based on data characteristics.
|
||||
|
||||
Decision Criteria:
|
||||
- Data size: Hybrid needs more data (min 90 days)
|
||||
- Complexity: High variance benefits from XGBoost
|
||||
- Seasonality strength: Weak seasonality benefits from XGBoost
|
||||
- Historical performance: Compare models on validation set
|
||||
"""
|
||||
|
||||
# Thresholds for model selection
|
||||
MIN_DATA_POINTS_HYBRID = 90 # Minimum data points for hybrid
|
||||
HIGH_VARIANCE_THRESHOLD = 0.5 # CV > 0.5 suggests complex patterns
|
||||
LOW_SEASONALITY_THRESHOLD = 0.3 # Weak seasonal patterns
|
||||
HYBRID_IMPROVEMENT_THRESHOLD = 0.05 # 5% MAPE improvement to justify hybrid
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def select_model_type(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
product_category: str = "unknown",
|
||||
force_prophet: bool = False,
|
||||
force_hybrid: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Select best model type based on data characteristics.
|
||||
|
||||
Args:
|
||||
df: Training data with 'y' column
|
||||
product_category: Product category (bread, pastries, etc.)
|
||||
force_prophet: Force Prophet-only model
|
||||
force_hybrid: Force hybrid model
|
||||
|
||||
Returns:
|
||||
"prophet" or "hybrid"
|
||||
"""
|
||||
# Honor forced selections
|
||||
if force_prophet:
|
||||
logger.info("Prophet-only model forced by configuration")
|
||||
return "prophet"
|
||||
|
||||
if force_hybrid:
|
||||
logger.info("Hybrid model forced by configuration")
|
||||
return "hybrid"
|
||||
|
||||
# Check minimum data requirements
|
||||
if len(df) < self.MIN_DATA_POINTS_HYBRID:
|
||||
logger.info(
|
||||
"Insufficient data for hybrid model, using Prophet",
|
||||
data_points=len(df),
|
||||
min_required=self.MIN_DATA_POINTS_HYBRID
|
||||
)
|
||||
return "prophet"
|
||||
|
||||
# Calculate data characteristics
|
||||
characteristics = self._analyze_data_characteristics(df)
|
||||
|
||||
# Decision logic
|
||||
score_hybrid = 0
|
||||
score_prophet = 0
|
||||
|
||||
# Factor 1: Data complexity (variance)
|
||||
if characteristics['coefficient_of_variation'] > self.HIGH_VARIANCE_THRESHOLD:
|
||||
score_hybrid += 2
|
||||
logger.debug("High variance detected, favoring hybrid", cv=characteristics['coefficient_of_variation'])
|
||||
else:
|
||||
score_prophet += 1
|
||||
|
||||
# Factor 2: Seasonality strength
|
||||
if characteristics['seasonality_strength'] < self.LOW_SEASONALITY_THRESHOLD:
|
||||
score_hybrid += 2
|
||||
logger.debug("Weak seasonality detected, favoring hybrid", strength=characteristics['seasonality_strength'])
|
||||
else:
|
||||
score_prophet += 1
|
||||
|
||||
# Factor 3: Data size (more data = better for hybrid)
|
||||
if len(df) > 180:
|
||||
score_hybrid += 1
|
||||
elif len(df) < 120:
|
||||
score_prophet += 1
|
||||
|
||||
# Factor 4: Product category considerations
|
||||
if product_category in ['seasonal', 'cakes']:
|
||||
# Event-driven products benefit from XGBoost pattern learning
|
||||
score_hybrid += 1
|
||||
elif product_category in ['bread', 'savory']:
|
||||
# Stable products work well with Prophet
|
||||
score_prophet += 1
|
||||
|
||||
# Factor 5: Zero ratio (sparse data)
|
||||
if characteristics['zero_ratio'] > 0.3:
|
||||
# High zero ratio suggests difficult forecasting, hybrid might help
|
||||
score_hybrid += 1
|
||||
|
||||
# Make decision
|
||||
selected_model = "hybrid" if score_hybrid > score_prophet else "prophet"
|
||||
|
||||
logger.info(
|
||||
"Model selection complete",
|
||||
selected_model=selected_model,
|
||||
score_hybrid=score_hybrid,
|
||||
score_prophet=score_prophet,
|
||||
data_points=len(df),
|
||||
cv=characteristics['coefficient_of_variation'],
|
||||
seasonality=characteristics['seasonality_strength'],
|
||||
category=product_category
|
||||
)
|
||||
|
||||
return selected_model
|
||||
|
||||
def _analyze_data_characteristics(self, df: pd.DataFrame) -> Dict[str, float]:
|
||||
"""
|
||||
Analyze time series characteristics.
|
||||
|
||||
Args:
|
||||
df: DataFrame with 'y' column (sales data)
|
||||
|
||||
Returns:
|
||||
Dictionary with data characteristics
|
||||
"""
|
||||
y = df['y'].values
|
||||
|
||||
# Coefficient of variation
|
||||
cv = np.std(y) / np.mean(y) if np.mean(y) > 0 else 0
|
||||
|
||||
# Zero ratio
|
||||
zero_ratio = (y == 0).sum() / len(y)
|
||||
|
||||
# Seasonality strength using autocorrelation at key lags (7 days, 30 days)
|
||||
# This better captures periodic patterns without using future data
|
||||
if len(df) >= 14:
|
||||
# Calculate autocorrelation at weekly lag (7 days)
|
||||
# Higher autocorrelation indicates stronger weekly patterns
|
||||
try:
|
||||
weekly_autocorr = pd.Series(y).autocorr(lag=7) if len(y) > 7 else 0
|
||||
|
||||
# Calculate autocorrelation at monthly lag if enough data
|
||||
monthly_autocorr = pd.Series(y).autocorr(lag=30) if len(y) > 30 else 0
|
||||
|
||||
# Combine autocorrelations (weekly weighted more for bakery data)
|
||||
seasonality_strength = abs(weekly_autocorr) * 0.7 + abs(monthly_autocorr) * 0.3
|
||||
|
||||
# Ensure in valid range [0, 1]
|
||||
seasonality_strength = max(0.0, min(1.0, seasonality_strength))
|
||||
except Exception:
|
||||
# Fallback to simpler calculation if autocorrelation fails
|
||||
seasonality_strength = 0.5
|
||||
else:
|
||||
seasonality_strength = 0.5 # Default
|
||||
|
||||
# Trend strength
|
||||
if len(df) >= 30:
|
||||
from scipy import stats
|
||||
x = np.arange(len(y))
|
||||
slope, _, r_value, _, _ = stats.linregress(x, y)
|
||||
trend_strength = abs(r_value)
|
||||
else:
|
||||
trend_strength = 0
|
||||
|
||||
return {
|
||||
'coefficient_of_variation': float(cv),
|
||||
'zero_ratio': float(zero_ratio),
|
||||
'seasonality_strength': float(seasonality_strength),
|
||||
'trend_strength': float(trend_strength),
|
||||
'mean': float(np.mean(y)),
|
||||
'std': float(np.std(y))
|
||||
}
|
||||
|
||||
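# Example of how these characteristics read: for 20 weeks of a strictly repeating
# weekly pattern such as [20, 22, 25, 24, 30, 45, 50], the lag-7 autocorrelation is
# about 1.0, so seasonality_strength lands near 0.8 (well above
# LOW_SEASONALITY_THRESHOLD), while the coefficient of variation is roughly 0.36
# (below HIGH_VARIANCE_THRESHOLD); both factors steer select_model_type() towards
# Prophet for such a product.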
def compare_models(
|
||||
self,
|
||||
prophet_metrics: Dict[str, float],
|
||||
hybrid_metrics: Dict[str, float]
|
||||
) -> str:
|
||||
"""
|
||||
Compare Prophet and Hybrid model performance.
|
||||
|
||||
Args:
|
||||
prophet_metrics: Prophet model metrics (with 'mape' key)
|
||||
hybrid_metrics: Hybrid model metrics (with 'mape' key)
|
||||
|
||||
Returns:
|
||||
"prophet" or "hybrid" based on better performance
|
||||
"""
|
||||
prophet_mape = prophet_metrics.get('mape', float('inf'))
|
||||
hybrid_mape = hybrid_metrics.get('mape', float('inf'))
|
||||
|
||||
# Calculate improvement
|
||||
if prophet_mape > 0:
|
||||
improvement = (prophet_mape - hybrid_mape) / prophet_mape
|
||||
else:
|
||||
improvement = 0
|
||||
|
||||
# Hybrid must improve by at least threshold to justify complexity
|
||||
if improvement >= self.HYBRID_IMPROVEMENT_THRESHOLD:
|
||||
logger.info(
|
||||
"Hybrid model selected based on performance",
|
||||
prophet_mape=prophet_mape,
|
||||
hybrid_mape=hybrid_mape,
|
||||
improvement=f"{improvement*100:.1f}%"
|
||||
)
|
||||
return "hybrid"
|
||||
else:
|
||||
logger.info(
|
||||
"Prophet model selected (hybrid improvement insufficient)",
|
||||
prophet_mape=prophet_mape,
|
||||
hybrid_mape=hybrid_mape,
|
||||
improvement=f"{improvement*100:.1f}%"
|
||||
)
|
||||
return "prophet"
|
||||
|
||||
|
||||
def should_use_hybrid_model(
|
||||
df: pd.DataFrame,
|
||||
product_category: str = "unknown",
|
||||
tenant_settings: Dict[str, Any] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Convenience function to determine if hybrid model should be used.
|
||||
|
||||
Args:
|
||||
df: Training data
|
||||
product_category: Product category
|
||||
tenant_settings: Optional tenant-specific settings
|
||||
|
||||
Returns:
|
||||
True if hybrid model should be used, False otherwise
|
||||
"""
|
||||
selector = ModelSelector()
|
||||
|
||||
# Check tenant settings
|
||||
force_prophet = tenant_settings.get('force_prophet_only', False) if tenant_settings else False
|
||||
force_hybrid = tenant_settings.get('force_hybrid', False) if tenant_settings else False
|
||||
|
||||
selected = selector.select_model_type(
|
||||
df=df,
|
||||
product_category=product_category,
|
||||
force_prophet=force_prophet,
|
||||
force_hybrid=force_hybrid
|
||||
)
|
||||
|
||||
return selected == "hybrid"
|
||||
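
A minimal usage sketch (not part of the committed file): assuming a training dataframe with the 'y' sales column described in _analyze_data_characteristics, model selection could be driven like this; the synthetic data, 'ds' column, and tenant settings are illustrative assumptions only.

    import pandas as pd

    # Synthetic 90-day sales history; only the 'y' column is required by the selector
    sales_history = pd.DataFrame({
        "ds": pd.date_range("2024-01-01", periods=90, freq="D"),
        "y": [120 + (i % 7) * 10 for i in range(90)],
    })

    use_hybrid = should_use_hybrid_model(
        df=sales_history,
        product_category="bread",
        tenant_settings={"force_prophet_only": False},  # hypothetical tenant setting
    )
    print("hybrid" if use_hybrid else "prophet")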
192
services/training/app/ml/poi_feature_integrator.py
Normal file
@@ -0,0 +1,192 @@
"""
POI Feature Integrator

Integrates POI features into ML training pipeline.
Fetches POI context from External service and merges features into training data.
"""

from typing import Dict, Any, Optional, List
import structlog
import pandas as pd

from shared.clients.external_client import ExternalServiceClient

logger = structlog.get_logger()


class POIFeatureIntegrator:
    """
    POI feature integration for ML training.

    Fetches POI context from External service and adds features
    to training dataframes for location-based demand forecasting.
    """

    def __init__(self, external_client: ExternalServiceClient = None):
        """
        Initialize POI feature integrator.

        Args:
            external_client: External service client instance (optional)
        """
        if external_client is None:
            from app.core.config import settings
            self.external_client = ExternalServiceClient(settings, "training-service")
        else:
            self.external_client = external_client

    async def fetch_poi_features(
        self,
        tenant_id: str,
        latitude: float,
        longitude: float,
        force_refresh: bool = False
    ) -> Optional[Dict[str, Any]]:
        """
        Fetch POI features for tenant location (optimized for training).

        First checks if POI context exists. If not, returns None without triggering detection.
        POI detection should be triggered during tenant registration, not during training.

        Args:
            tenant_id: Tenant UUID
            latitude: Bakery latitude
            longitude: Bakery longitude
            force_refresh: Force re-detection (only use if POI context already exists)

        Returns:
            Dictionary with POI features or None if not available
        """
        try:
            # Try to get existing POI context first
            existing_context = await self.external_client.get_poi_context(tenant_id)

            if existing_context:
                poi_context = existing_context.get("poi_context", {})
                ml_features = poi_context.get("ml_features", {})

                # Check if stale and force_refresh is requested
                is_stale = existing_context.get("is_stale", False)

                if not is_stale or not force_refresh:
                    logger.info(
                        "Using existing POI context",
                        tenant_id=tenant_id,
                        is_stale=is_stale,
                        feature_count=len(ml_features)
                    )
                    return ml_features
                else:
                    logger.info(
                        "POI context is stale and force_refresh=True, refreshing",
                        tenant_id=tenant_id
                    )
                    # Only refresh if explicitly requested and context exists
                    detection_result = await self.external_client.detect_poi_for_tenant(
                        tenant_id=tenant_id,
                        latitude=latitude,
                        longitude=longitude,
                        force_refresh=True
                    )

                    if detection_result:
                        poi_context = detection_result.get("poi_context", {})
                        ml_features = poi_context.get("ml_features", {})
                        logger.info(
                            "POI refresh completed",
                            tenant_id=tenant_id,
                            feature_count=len(ml_features)
                        )
                        return ml_features
                    else:
                        logger.warning(
                            "POI refresh failed, returning existing features",
                            tenant_id=tenant_id
                        )
                        return ml_features
            else:
                logger.info(
                    "No existing POI context found - POI detection should be triggered during tenant registration",
                    tenant_id=tenant_id
                )
                return None

        except Exception as e:
            logger.warning(
                "Error fetching POI features - returning None",
                tenant_id=tenant_id,
                error=str(e)
            )
            return None

    def add_poi_features_to_dataframe(
        self,
        df: pd.DataFrame,
        poi_features: Dict[str, Any]
    ) -> pd.DataFrame:
        """
        Add POI features to training dataframe.

        POI features are static (don't vary by date), so they're
        broadcast to all rows in the dataframe.

        Args:
            df: Training dataframe
            poi_features: Dictionary of POI ML features

        Returns:
            Dataframe with POI features added as columns
        """
        if not poi_features:
            logger.warning("No POI features to add")
            return df

        logger.info(
            "Adding POI features to dataframe",
            feature_count=len(poi_features),
            dataframe_rows=len(df)
        )

        # Add each POI feature as a column with constant value
        for feature_name, feature_value in poi_features.items():
            df[feature_name] = feature_value

        logger.info(
            "POI features added successfully",
            new_columns=list(poi_features.keys())
        )

        return df

    def get_poi_feature_names(self, poi_features: Dict[str, Any]) -> List[str]:
        """
        Get list of POI feature names for model registration.

        Args:
            poi_features: Dictionary of POI ML features

        Returns:
            List of feature names
        """
        return list(poi_features.keys()) if poi_features else []

    async def check_poi_service_health(self) -> bool:
        """
        Check if POI service is accessible through the external client.

        Returns:
            True if service is healthy, False otherwise
        """
        try:
            # We can test the external service health by attempting to get POI context for a dummy tenant
            # This will go through the proper authentication and routing
            dummy_context = await self.external_client.get_poi_context("test-tenant")
            # If we can successfully make a request (even if it returns None for missing tenant),
            # it means the service is accessible
            return True
        except Exception as e:
            logger.error(
                "POI service health check failed",
                error=str(e)
            )
            return False
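
An illustrative sketch (not part of the diff) of wiring the integrator into an async training step; the coordinates and tenant id are placeholders, and fetch_poi_features returns None when no POI context has been registered for the tenant.

    async def enrich_training_data(df, tenant_id: str):
        # Fetch static POI features and broadcast them onto the training dataframe
        integrator = POIFeatureIntegrator()
        poi_features = await integrator.fetch_poi_features(
            tenant_id=tenant_id,
            latitude=40.4168,   # placeholder coordinates
            longitude=-3.7038,
        )
        if poi_features:
            df = integrator.add_poi_features_to_dataframe(df, poi_features)
        # Feature names are returned so they can be registered with the trained model
        return df, integrator.get_poi_feature_names(poi_features or {})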
361
services/training/app/ml/product_categorizer.py
Normal file
@@ -0,0 +1,361 @@
"""
Product Categorization System
Classifies bakery products into categories for category-specific forecasting
"""

import pandas as pd
import numpy as np
from typing import Any, Dict, List, Optional
from enum import Enum
import structlog

logger = structlog.get_logger()


class ProductCategory(str, Enum):
    """Product categories for bakery items"""
    BREAD = "bread"
    PASTRIES = "pastries"
    CAKES = "cakes"
    DRINKS = "drinks"
    SEASONAL = "seasonal"
    SAVORY = "savory"
    UNKNOWN = "unknown"


class ProductCategorizer:
    """
    Automatic product categorization based on product name and sales patterns.

    Categories have different characteristics:
    - BREAD: Daily staple, high volume, consistent demand, short shelf life (1 day)
    - PASTRIES: Morning peak, weekend boost, medium shelf life (2-3 days)
    - CAKES: Event-driven, weekends, advance orders, longer shelf life (3-5 days)
    - DRINKS: Weather-dependent, hot/cold seasonal patterns
    - SEASONAL: Holiday-specific (roscón, panettone, etc.)
    - SAVORY: Lunch peak, weekday focus
    """

    def __init__(self):
        # Keywords for automatic classification
        self.category_keywords = {
            ProductCategory.BREAD: [
                'pan', 'baguette', 'hogaza', 'chapata', 'integral', 'centeno',
                'bread', 'loaf', 'barra', 'molde', 'candeal'
            ],
            ProductCategory.PASTRIES: [
                'croissant', 'napolitana', 'palmera', 'ensaimada', 'magdalena',
                'bollo', 'brioche', 'suizo', 'caracola', 'donut', 'berlina'
            ],
            ProductCategory.CAKES: [
                'tarta', 'pastel', 'bizcocho', 'cake', 'torta', 'milhojas',
                'saint honoré', 'selva negra', 'tres leches'
            ],
            ProductCategory.DRINKS: [
                'café', 'coffee', 'té', 'tea', 'zumo', 'juice', 'batido',
                'smoothie', 'refresco', 'agua', 'water'
            ],
            ProductCategory.SEASONAL: [
                'roscón', 'panettone', 'turrón', 'polvorón', 'mona de pascua',
                'huevo de pascua', 'buñuelo', 'torrija'
            ],
            ProductCategory.SAVORY: [
                'empanada', 'quiche', 'pizza', 'focaccia', 'salado', 'bocadillo',
                'sandwich', 'croqueta', 'hojaldre salado'
            ]
        }

    def categorize_product(
        self,
        product_name: str,
        product_id: str = None,
        sales_data: pd.DataFrame = None
    ) -> ProductCategory:
        """
        Categorize a product based on name and optional sales patterns.

        Args:
            product_name: Product name
            product_id: Optional product ID
            sales_data: Optional historical sales data for pattern analysis

        Returns:
            ProductCategory enum
        """
        # First try keyword matching
        category = self._categorize_by_keywords(product_name)

        if category != ProductCategory.UNKNOWN:
            logger.info("Product categorized by keywords",
                        product=product_name,
                        category=category.value)
            return category

        # If no keyword match and we have sales data, analyze patterns
        if sales_data is not None and len(sales_data) > 30:
            category = self._categorize_by_sales_pattern(product_name, sales_data)
            logger.info("Product categorized by sales pattern",
                        product=product_name,
                        category=category.value)
            return category

        logger.warning("Could not categorize product, using UNKNOWN",
                       product=product_name)
        return ProductCategory.UNKNOWN

    def _categorize_by_keywords(self, product_name: str) -> ProductCategory:
        """Categorize by matching keywords in product name"""
        product_name_lower = product_name.lower()

        # Check each category's keywords
        for category, keywords in self.category_keywords.items():
            for keyword in keywords:
                if keyword in product_name_lower:
                    return category

        return ProductCategory.UNKNOWN

    def _categorize_by_sales_pattern(
        self,
        product_name: str,
        sales_data: pd.DataFrame
    ) -> ProductCategory:
        """
        Categorize by analyzing sales patterns.

        Patterns:
        - BREAD: Consistent daily sales, low variance
        - PASTRIES: Weekend boost, morning peak
        - CAKES: Weekend spike, event correlation
        - DRINKS: Temperature correlation
        - SEASONAL: Concentrated in specific months
        - SAVORY: Weekday focus, lunch peak
        """
        try:
            # Ensure we have required columns
            if 'date' not in sales_data.columns or 'quantity' not in sales_data.columns:
                return ProductCategory.UNKNOWN

            sales_data = sales_data.copy()
            sales_data['date'] = pd.to_datetime(sales_data['date'])
            sales_data['day_of_week'] = sales_data['date'].dt.dayofweek
            sales_data['month'] = sales_data['date'].dt.month
            sales_data['is_weekend'] = sales_data['day_of_week'].isin([5, 6])

            # Calculate pattern metrics
            weekend_avg = sales_data[sales_data['is_weekend']]['quantity'].mean()
            weekday_avg = sales_data[~sales_data['is_weekend']]['quantity'].mean()
            overall_avg = sales_data['quantity'].mean()
            cv = sales_data['quantity'].std() / overall_avg if overall_avg > 0 else 0

            # Weekend ratio
            weekend_ratio = weekend_avg / weekday_avg if weekday_avg > 0 else 1.0

            # Seasonal concentration (Gini coefficient for months)
            monthly_sales = sales_data.groupby('month')['quantity'].sum()
            seasonal_concentration = self._gini_coefficient(monthly_sales.values)

            # Decision rules based on patterns
            if seasonal_concentration > 0.6:
                # High concentration in specific months = seasonal
                return ProductCategory.SEASONAL

            elif cv < 0.3 and weekend_ratio < 1.2:
                # Low variance, consistent daily = bread
                return ProductCategory.BREAD

            elif weekend_ratio > 1.5:
                # Strong weekend boost = cakes
                return ProductCategory.CAKES

            elif weekend_ratio > 1.2:
                # Moderate weekend boost = pastries
                return ProductCategory.PASTRIES

            elif weekend_ratio < 0.9:
                # Weekday focus = savory
                return ProductCategory.SAVORY

            else:
                return ProductCategory.UNKNOWN

        except Exception as e:
            logger.error(f"Error analyzing sales pattern: {e}")
            return ProductCategory.UNKNOWN

    def _gini_coefficient(self, values: np.ndarray) -> float:
        """Calculate Gini coefficient for concentration measurement"""
        if len(values) == 0:
            return 0.0

        sorted_values = np.sort(values)
        n = len(values)
        cumsum = np.cumsum(sorted_values)

        # Guard against all-zero sales (no concentration information)
        if cumsum[-1] == 0:
            return 0.0

        # Gini coefficient formula
        return (2 * np.sum((np.arange(1, n + 1) * sorted_values))) / (n * cumsum[-1]) - (n + 1) / n

    def get_category_characteristics(self, category: ProductCategory) -> Dict[str, Any]:
        """
        Get forecasting characteristics for a category.

        Returns hyperparameters and settings specific to the category.
        """
        characteristics = {
            ProductCategory.BREAD: {
                "shelf_life_days": 1,
                "demand_stability": "high",
                "seasonality_strength": "low",
                "weekend_factor": 0.95,  # Slightly lower on weekends
                "holiday_factor": 0.7,   # Much lower on holidays
                "weather_sensitivity": "low",
                "prophet_params": {
                    "seasonality_mode": "additive",
                    "yearly_seasonality": False,
                    "weekly_seasonality": True,
                    "daily_seasonality": False,
                    "changepoint_prior_scale": 0.01,  # Very stable
                    "seasonality_prior_scale": 5.0
                }
            },
            ProductCategory.PASTRIES: {
                "shelf_life_days": 2,
                "demand_stability": "medium",
                "seasonality_strength": "medium",
                "weekend_factor": 1.3,  # Boost on weekends
                "holiday_factor": 1.1,  # Slight boost on holidays
                "weather_sensitivity": "medium",
                "prophet_params": {
                    "seasonality_mode": "multiplicative",
                    "yearly_seasonality": True,
                    "weekly_seasonality": True,
                    "daily_seasonality": False,
                    "changepoint_prior_scale": 0.05,
                    "seasonality_prior_scale": 10.0
                }
            },
            ProductCategory.CAKES: {
                "shelf_life_days": 4,
                "demand_stability": "low",
                "seasonality_strength": "high",
                "weekend_factor": 2.0,  # Large weekend boost
                "holiday_factor": 1.5,  # Holiday boost
                "weather_sensitivity": "low",
                "prophet_params": {
                    "seasonality_mode": "multiplicative",
                    "yearly_seasonality": True,
                    "weekly_seasonality": True,
                    "daily_seasonality": False,
                    "changepoint_prior_scale": 0.1,  # More flexible
                    "seasonality_prior_scale": 15.0
                }
            },
            ProductCategory.DRINKS: {
                "shelf_life_days": 1,
                "demand_stability": "medium",
                "seasonality_strength": "high",
                "weekend_factor": 1.1,
                "holiday_factor": 1.2,
                "weather_sensitivity": "very_high",
                "prophet_params": {
                    "seasonality_mode": "multiplicative",
                    "yearly_seasonality": True,
                    "weekly_seasonality": True,
                    "daily_seasonality": False,
                    "changepoint_prior_scale": 0.08,
                    "seasonality_prior_scale": 12.0
                }
            },
            ProductCategory.SEASONAL: {
                "shelf_life_days": 7,
                "demand_stability": "very_low",
                "seasonality_strength": "very_high",
                "weekend_factor": 1.2,
                "holiday_factor": 3.0,  # Massive holiday boost
                "weather_sensitivity": "low",
                "prophet_params": {
                    "seasonality_mode": "multiplicative",
                    "yearly_seasonality": True,
                    "weekly_seasonality": False,
                    "daily_seasonality": False,
                    "changepoint_prior_scale": 0.2,  # Very flexible
                    "seasonality_prior_scale": 20.0
                }
            },
            ProductCategory.SAVORY: {
                "shelf_life_days": 1,
                "demand_stability": "medium",
                "seasonality_strength": "low",
                "weekend_factor": 0.8,  # Lower on weekends
                "holiday_factor": 0.6,  # Much lower on holidays
                "weather_sensitivity": "medium",
                "prophet_params": {
                    "seasonality_mode": "additive",
                    "yearly_seasonality": False,
                    "weekly_seasonality": True,
                    "daily_seasonality": False,
                    "changepoint_prior_scale": 0.03,
                    "seasonality_prior_scale": 7.0
                }
            },
            ProductCategory.UNKNOWN: {
                "shelf_life_days": 2,
                "demand_stability": "medium",
                "seasonality_strength": "medium",
                "weekend_factor": 1.0,
                "holiday_factor": 1.0,
                "weather_sensitivity": "medium",
                "prophet_params": {
                    "seasonality_mode": "multiplicative",
                    "yearly_seasonality": True,
                    "weekly_seasonality": True,
                    "daily_seasonality": False,
                    "changepoint_prior_scale": 0.05,
                    "seasonality_prior_scale": 10.0
                }
            }
        }

        return characteristics.get(category, characteristics[ProductCategory.UNKNOWN])

    def batch_categorize(
        self,
        products: List[Dict[str, Any]],
        sales_data: pd.DataFrame = None
    ) -> Dict[str, ProductCategory]:
        """
        Categorize multiple products at once.

        Args:
            products: List of dicts with 'id' and 'name' keys
            sales_data: Optional sales data with 'inventory_product_id' column

        Returns:
            Dict mapping product_id to category
        """
        results = {}

        for product in products:
            product_id = product.get('id')
            product_name = product.get('name', '')

            # Filter sales data for this product if available
            product_sales = None
            if sales_data is not None and 'inventory_product_id' in sales_data.columns:
                product_sales = sales_data[
                    sales_data['inventory_product_id'] == product_id
                ].copy()

            category = self.categorize_product(
                product_name=product_name,
                product_id=product_id,
                sales_data=product_sales
            )

            results[product_id] = category

        logger.info("Batch categorization complete",
                    total_products=len(products),
                    categories=dict(pd.Series(list(results.values())).value_counts()))

        return results
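
An illustrative sketch (not part of the diff) of categorizing a small product list and reading back the category-specific Prophet hyperparameters; the product ids and names are placeholders.

    categorizer = ProductCategorizer()
    categories = categorizer.batch_categorize(
        products=[
            {"id": "prod-001", "name": "Baguette tradicional"},
            {"id": "prod-002", "name": "Croissant de mantequilla"},
        ]
    )
    for product_id, category in categories.items():
        # Each category maps to its own Prophet settings and business factors
        params = categorizer.get_category_characteristics(category)["prophet_params"]
        print(product_id, category.value, params["seasonality_mode"])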
1089
services/training/app/ml/prophet_manager.py
Normal file
File diff suppressed because it is too large
284
services/training/app/ml/traffic_forecaster.py
Normal file
@@ -0,0 +1,284 @@
"""
Traffic Forecasting System
Predicts bakery foot traffic using weather and temporal features
"""

import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional
from prophet import Prophet
import structlog
from datetime import datetime, timedelta

logger = structlog.get_logger()


class TrafficForecaster:
    """
    Forecast bakery foot traffic using Prophet with weather and temporal features.

    Traffic patterns are influenced by:
    - Weather: Temperature, precipitation, conditions
    - Time: Day of week, holidays, season
    - Special events: Local events, promotions
    """

    def __init__(self):
        self.model = None
        self.is_trained = False

    def train(
        self,
        historical_traffic: pd.DataFrame,
        weather_data: pd.DataFrame = None
    ) -> Dict[str, Any]:
        """
        Train traffic forecasting model.

        Args:
            historical_traffic: DataFrame with columns ['date', 'traffic_count']
            weather_data: Optional weather data with columns ['date', 'temperature', 'precipitation', 'condition']

        Returns:
            Training metrics
        """
        try:
            logger.info("Training traffic forecasting model",
                        data_points=len(historical_traffic))

            # Prepare Prophet format
            df = historical_traffic.copy()
            df = df.rename(columns={'date': 'ds', 'traffic_count': 'y'})
            df['ds'] = pd.to_datetime(df['ds'])
            df = df.sort_values('ds')

            # Merge with weather data if available
            if weather_data is not None:
                weather_data = weather_data.copy()
                weather_data['date'] = pd.to_datetime(weather_data['date'])
                df = df.merge(weather_data, left_on='ds', right_on='date', how='left')

            # Create Prophet model with custom settings for traffic
            self.model = Prophet(
                seasonality_mode='multiplicative',
                yearly_seasonality=True,
                weekly_seasonality=True,
                daily_seasonality=False,
                changepoint_prior_scale=0.05,  # Moderate flexibility
                seasonality_prior_scale=10.0,
                holidays_prior_scale=10.0
            )

            # Add weather regressors if available
            if 'temperature' in df.columns:
                self.model.add_regressor('temperature')
            if 'precipitation' in df.columns:
                self.model.add_regressor('precipitation')
            if 'is_rainy' in df.columns:
                self.model.add_regressor('is_rainy')

            # Add Spanish holidays via Prophet's built-in country calendar
            self.model.add_country_holidays(country_name='ES')

            # Fit model
            self.model.fit(df)
            self.is_trained = True

            # Calculate training metrics
            predictions = self.model.predict(df)
            metrics = self._calculate_metrics(df['y'].values, predictions['yhat'].values)

            logger.info("Traffic forecasting model trained successfully",
                        mape=metrics['mape'],
                        rmse=metrics['rmse'])

            return metrics

        except Exception as e:
            logger.error(f"Failed to train traffic forecasting model: {e}")
            raise

    def predict(
        self,
        future_dates: pd.DatetimeIndex,
        weather_forecast: pd.DataFrame = None
    ) -> pd.DataFrame:
        """
        Predict traffic for future dates.

        Args:
            future_dates: Dates to predict traffic for
            weather_forecast: Optional weather forecast data

        Returns:
            DataFrame with columns ['date', 'predicted_traffic', 'yhat_lower', 'yhat_upper']
        """
        if not self.is_trained:
            raise ValueError("Model not trained. Call train() first.")

        try:
            # Create future dataframe
            future = pd.DataFrame({'ds': future_dates})

            # Add weather features if available
            if weather_forecast is not None:
                weather_forecast = weather_forecast.copy()
                weather_forecast['date'] = pd.to_datetime(weather_forecast['date'])
                future = future.merge(weather_forecast, left_on='ds', right_on='date', how='left')

            # Fill missing weather with defaults
            if 'temperature' in future.columns:
                future['temperature'] = future['temperature'].fillna(15.0)
            if 'precipitation' in future.columns:
                future['precipitation'] = future['precipitation'].fillna(0.0)
            if 'is_rainy' in future.columns:
                future['is_rainy'] = future['is_rainy'].fillna(0)

            # Predict
            forecast = self.model.predict(future)

            # Format results
            results = pd.DataFrame({
                'date': forecast['ds'],
                'predicted_traffic': forecast['yhat'].clip(lower=0),  # Traffic can't be negative
                'yhat_lower': forecast['yhat_lower'].clip(lower=0),
                'yhat_upper': forecast['yhat_upper'].clip(lower=0)
            })

            logger.info("Traffic predictions generated",
                        dates=len(results),
                        avg_traffic=results['predicted_traffic'].mean())

            return results

        except Exception as e:
            logger.error(f"Failed to predict traffic: {e}")
            raise

    def _calculate_metrics(self, actual: np.ndarray, predicted: np.ndarray) -> Dict[str, float]:
        """Calculate forecast accuracy metrics"""
        mae = np.mean(np.abs(actual - predicted))
        mse = np.mean((actual - predicted) ** 2)
        rmse = np.sqrt(mse)

        # MAPE (handle zeros)
        mask = actual != 0
        mape = np.mean(np.abs((actual[mask] - predicted[mask]) / actual[mask])) * 100 if mask.any() else 0

        return {
            'mae': float(mae),
            'mse': float(mse),
            'rmse': float(rmse),
            'mape': float(mape)
        }

    def _get_spanish_holidays(self, start_year: int, end_year: int) -> pd.DataFrame:
        """Get Spanish holidays for the date range"""
        try:
            import holidays

            es_holidays = holidays.Spain(years=range(start_year, end_year + 1))

            holiday_dates = []
            holiday_names = []

            for date, name in es_holidays.items():
                holiday_dates.append(date)
                holiday_names.append(name)

            return pd.DataFrame({
                'ds': pd.to_datetime(holiday_dates),
                'holiday': holiday_names
            })

        except Exception as e:
            logger.warning(f"Could not load Spanish holidays: {e}")
            return pd.DataFrame(columns=['ds', 'holiday'])


class TrafficFeatureGenerator:
    """
    Generate traffic-related features for demand forecasting.
    Uses predicted traffic as a feature in product demand models.
    """

    def __init__(self, traffic_forecaster: TrafficForecaster = None):
        self.traffic_forecaster = traffic_forecaster or TrafficForecaster()

    def generate_traffic_features(
        self,
        dates: pd.DatetimeIndex,
        weather_forecast: pd.DataFrame = None
    ) -> pd.DataFrame:
        """
        Generate traffic features for given dates.

        Args:
            dates: Dates to generate features for
            weather_forecast: Optional weather forecast

        Returns:
            DataFrame with traffic features
        """
        if not self.traffic_forecaster.is_trained:
            logger.warning("Traffic forecaster not trained, using default traffic values")
            return pd.DataFrame({
                'date': dates,
                'predicted_traffic': 100.0,  # Default baseline
                'traffic_normalized': 1.0
            })

        # Predict traffic
        traffic_predictions = self.traffic_forecaster.predict(dates, weather_forecast)

        # Normalize traffic (0-2 range, 1 = average)
        mean_traffic = traffic_predictions['predicted_traffic'].mean()
        traffic_predictions['traffic_normalized'] = (
            traffic_predictions['predicted_traffic'] / mean_traffic
        ).clip(0, 2)

        # Add traffic categories
        traffic_predictions['traffic_category'] = pd.cut(
            traffic_predictions['predicted_traffic'],
            bins=[0, 50, 100, 150, np.inf],
            labels=['low', 'medium', 'high', 'very_high']
        )

        return traffic_predictions

    def add_traffic_features_to_forecast_data(
        self,
        forecast_data: pd.DataFrame,
        traffic_predictions: pd.DataFrame
    ) -> pd.DataFrame:
        """
        Add traffic features to forecast input data.

        Args:
            forecast_data: Existing forecast data with 'date' column
            traffic_predictions: Traffic predictions from generate_traffic_features()

        Returns:
            Enhanced forecast data with traffic features
        """
        forecast_data = forecast_data.copy()
        forecast_data['date'] = pd.to_datetime(forecast_data['date'])
        traffic_predictions['date'] = pd.to_datetime(traffic_predictions['date'])

        # Merge traffic features
        enhanced_data = forecast_data.merge(
            traffic_predictions[['date', 'predicted_traffic', 'traffic_normalized']],
            on='date',
            how='left'
        )

        # Fill missing with defaults
        enhanced_data['predicted_traffic'] = enhanced_data['predicted_traffic'].fillna(100.0)
        enhanced_data['traffic_normalized'] = enhanced_data['traffic_normalized'].fillna(1.0)

        return enhanced_data
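
An illustrative sketch (not part of the diff): training the traffic model with weather regressors and generating normalized traffic features for a 7-day horizon. Column names follow the docstrings above; the data itself is synthetic.

    import numpy as np
    import pandas as pd

    # Synthetic 120-day traffic history with matching weather observations
    history = pd.DataFrame({
        "date": pd.date_range("2024-01-01", periods=120, freq="D"),
        "traffic_count": np.random.poisson(100, 120),
    })
    weather = pd.DataFrame({
        "date": history["date"],
        "temperature": np.random.normal(15, 5, 120),
        "precipitation": np.random.exponential(1.0, 120),
    })

    forecaster = TrafficForecaster()
    metrics = forecaster.train(history, weather_data=weather)

    # Predictions need the same regressors, so a weather forecast is supplied
    future_dates = pd.date_range("2024-05-01", periods=7, freq="D")
    weather_forecast = pd.DataFrame({
        "date": future_dates,
        "temperature": 18.0,
        "precipitation": 0.0,
    })
    feature_gen = TrafficFeatureGenerator(forecaster)
    traffic_features = feature_gen.generate_traffic_features(future_dates, weather_forecast)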
1375
services/training/app/ml/trainer.py
Normal file
File diff suppressed because it is too large
33
services/training/app/models/__init__.py
Normal file
@@ -0,0 +1,33 @@
"""
Training Service Models Package

Import all models to ensure they are registered with SQLAlchemy Base.
"""

# Import AuditLog model for this service
from shared.security import create_audit_log_model
from shared.database.base import Base

# Create audit log model for this service
AuditLog = create_audit_log_model(Base)

# Import all models to register them with the Base metadata
from .training import (
    TrainedModel,
    ModelTrainingLog,
    ModelPerformanceMetric,
    TrainingJobQueue,
    ModelArtifact,
    TrainingPerformanceMetrics,
)

# List all models for easier access
__all__ = [
    "TrainedModel",
    "ModelTrainingLog",
    "ModelPerformanceMetric",
    "TrainingJobQueue",
    "ModelArtifact",
    "TrainingPerformanceMetrics",
    "AuditLog",
]
services/training/app/models/training.py
Normal file
254
services/training/app/models/training.py
Normal file
@@ -0,0 +1,254 @@
|
# services/training/app/models/training.py
"""
Database models for training service
"""

from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, JSON, Float
from sqlalchemy.dialects.postgresql import UUID, ARRAY
from shared.database.base import Base
from datetime import datetime, timezone
import uuid


class ModelTrainingLog(Base):
    """
    Table to track training job execution and status.
    Replaces the old Celery task tracking.
    """
    __tablename__ = "model_training_logs"

    id = Column(Integer, primary_key=True, index=True)
    job_id = Column(String(255), unique=True, index=True, nullable=False)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    status = Column(String(50), nullable=False, default="pending")  # pending, running, completed, failed, cancelled
    progress = Column(Integer, default=0)  # 0-100 percentage
    current_step = Column(String(500), default="")

    # Timestamps
    start_time = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    end_time = Column(DateTime(timezone=True), nullable=True)

    # Configuration and results
    config = Column(JSON, nullable=True)  # Training job configuration
    results = Column(JSON, nullable=True)  # Training results
    error_message = Column(Text, nullable=True)

    # Metadata
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))


class ModelPerformanceMetric(Base):
    """
    Table to track model performance over time.
    """
    __tablename__ = "model_performance_metrics"

    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), index=True, nullable=False)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    inventory_product_id = Column(UUID(as_uuid=True), index=True, nullable=False)

    # Performance metrics
    mae = Column(Float, nullable=True)  # Mean Absolute Error
    mse = Column(Float, nullable=True)  # Mean Squared Error
    rmse = Column(Float, nullable=True)  # Root Mean Squared Error
    mape = Column(Float, nullable=True)  # Mean Absolute Percentage Error
    r2_score = Column(Float, nullable=True)  # R-squared score

    # Additional metrics
    accuracy_percentage = Column(Float, nullable=True)
    prediction_confidence = Column(Float, nullable=True)

    # Evaluation information
    evaluation_period_start = Column(DateTime, nullable=True)
    evaluation_period_end = Column(DateTime, nullable=True)
    evaluation_samples = Column(Integer, nullable=True)

    # Metadata
    measured_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))


class TrainingJobQueue(Base):
    """
    Table to manage training job queue and scheduling.
    """
    __tablename__ = "training_job_queue"

    id = Column(Integer, primary_key=True, index=True)
    job_id = Column(String(255), unique=True, index=True, nullable=False)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)

    # Job configuration
    job_type = Column(String(50), nullable=False)  # full_training, single_product, evaluation
    priority = Column(Integer, default=1)  # Higher number = higher priority
    config = Column(JSON, nullable=True)

    # Scheduling information
    scheduled_at = Column(DateTime, nullable=True)
    started_at = Column(DateTime, nullable=True)
    estimated_duration_minutes = Column(Integer, nullable=True)

    # Status
    status = Column(String(50), nullable=False, default="queued")  # queued, running, completed, failed
    retry_count = Column(Integer, default=0)
    max_retries = Column(Integer, default=3)

    # Metadata
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
    cancelled_by = Column(String, nullable=True)


class ModelArtifact(Base):
    """
    Table to track model files and artifacts.
    """
    __tablename__ = "model_artifacts"

    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), index=True, nullable=False)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)

    # Artifact information
    artifact_type = Column(String(50), nullable=False)  # model_file, metadata, training_data, etc.
    file_path = Column(String(1000), nullable=False)
    file_size_bytes = Column(Integer, nullable=True)
    checksum = Column(String(255), nullable=True)  # For file integrity

    # Storage information
    storage_location = Column(String(100), nullable=False, default="local")  # local, s3, gcs, etc.
    compression = Column(String(50), nullable=True)  # gzip, lz4, etc.

    # Metadata
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    expires_at = Column(DateTime(timezone=True), nullable=True)  # For automatic cleanup


class TrainedModel(Base):
    __tablename__ = "trained_models"

    # Primary identification - Updated to use UUID properly
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    inventory_product_id = Column(UUID(as_uuid=True), nullable=False, index=True)

    # Model information
    model_type = Column(String, default="prophet_optimized")
    model_version = Column(String, default="1.0")
    job_id = Column(String, nullable=False)

    # File storage
    model_path = Column(String, nullable=False)  # Path to the .pkl file
    metadata_path = Column(String)  # Path to metadata JSON

    # Training metrics
    mape = Column(Float)
    mae = Column(Float)
    rmse = Column(Float)
    r2_score = Column(Float)
    training_samples = Column(Integer)

    # Hyperparameters and features
    hyperparameters = Column(JSON)  # Store optimized parameters
    features_used = Column(JSON)  # List of regressor columns
    normalization_params = Column(JSON)  # Store feature normalization parameters for consistent predictions
    product_category = Column(String, nullable=True)  # Product category for category-specific forecasting

    # Model status
    is_active = Column(Boolean, default=True)
    is_production = Column(Boolean, default=False)

    # Timestamps - Updated to be timezone-aware with proper defaults
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
    last_used_at = Column(DateTime(timezone=True))

    # Training data info
    training_start_date = Column(DateTime(timezone=True))
    training_end_date = Column(DateTime(timezone=True))
    data_quality_score = Column(Float)

    # Additional metadata
    notes = Column(Text)
    created_by = Column(String)  # User who triggered training

    def to_dict(self):
        return {
            "id": str(self.id),
            "model_id": str(self.id),
            "tenant_id": str(self.tenant_id),
            "inventory_product_id": str(self.inventory_product_id),
            "model_type": self.model_type,
            "model_version": self.model_version,
            "model_path": self.model_path,
            "mape": self.mape,
            "mae": self.mae,
            "rmse": self.rmse,
            "r2_score": self.r2_score,
            "training_samples": self.training_samples,
            "hyperparameters": self.hyperparameters,
            "features_used": self.features_used,
            "features": self.features_used,  # Alias for frontend compatibility (ModelDetailsModal expects 'features')
            "product_category": self.product_category,
            "is_active": self.is_active,
            "is_production": self.is_production,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
            "training_start_date": self.training_start_date.isoformat() if self.training_start_date else None,
            "training_end_date": self.training_end_date.isoformat() if self.training_end_date else None,
            "data_quality_score": self.data_quality_score
        }


class TrainingPerformanceMetrics(Base):
    """
    Table to track historical training performance for time estimation.
    Stores aggregated metrics from completed training jobs.
    """
    __tablename__ = "training_performance_metrics"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    job_id = Column(String(255), nullable=False, index=True)

    # Training job statistics
    total_products = Column(Integer, nullable=False)
    successful_products = Column(Integer, nullable=False)
    failed_products = Column(Integer, nullable=False)

    # Time metrics
    total_duration_seconds = Column(Float, nullable=False)
    avg_time_per_product = Column(Float, nullable=False)  # Key metric for estimation
    data_analysis_time_seconds = Column(Float, nullable=True)
    training_time_seconds = Column(Float, nullable=True)
    finalization_time_seconds = Column(Float, nullable=True)

    # Job metadata
    completed_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))

    def __repr__(self):
        return (
            f"<TrainingPerformanceMetrics("
            f"tenant_id={self.tenant_id}, "
            f"job_id={self.job_id}, "
            f"total_products={self.total_products}, "
            f"avg_time_per_product={self.avg_time_per_product:.2f}s"
            f")>"
        )

    def to_dict(self):
        return {
            "id": str(self.id),
            "tenant_id": str(self.tenant_id),
            "job_id": self.job_id,
            "total_products": self.total_products,
            "successful_products": self.successful_products,
            "failed_products": self.failed_products,
            "total_duration_seconds": self.total_duration_seconds,
            "avg_time_per_product": self.avg_time_per_product,
            "data_analysis_time_seconds": self.data_analysis_time_seconds,
            "training_time_seconds": self.training_time_seconds,
            "finalization_time_seconds": self.finalization_time_seconds,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "created_at": self.created_at.isoformat() if self.created_at else None
        }
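
An illustrative sketch (not part of the diff) of building a TrainedModel row from a training result and serializing it for the API layer; the ids, file path, and metric values are placeholders, and real persistence goes through the repository layer and a database session.

    import uuid

    model_row = TrainedModel(
        id=uuid.uuid4(),
        tenant_id=uuid.uuid4(),
        inventory_product_id=uuid.uuid4(),
        job_id="job-2024-05-01-001",           # placeholder job id
        model_path="/app/models/prod-001/model.pkl",
        mape=12.4,
        training_samples=180,
        hyperparameters={"changepoint_prior_scale": 0.05},
        features_used=["temperature", "is_holiday"],
    )
    payload = model_row.to_dict()  # includes the 'features' alias expected by the frontend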
11
services/training/app/models/training_models.py
Normal file
@@ -0,0 +1,11 @@
# services/training/app/models/training_models.py
"""
Legacy file - TrainedModel has been moved to training.py
This file is deprecated and should be removed after migration.
"""

# Import the actual model from the correct location
from .training import TrainedModel

# For backward compatibility, re-export the model
__all__ = ["TrainedModel"]
20
services/training/app/repositories/__init__.py
Normal file
@@ -0,0 +1,20 @@
"""
Training Service Repositories
Repository implementations for training service
"""

from .base import TrainingBaseRepository
from .model_repository import ModelRepository
from .training_log_repository import TrainingLogRepository
from .performance_repository import PerformanceRepository
from .job_queue_repository import JobQueueRepository
from .artifact_repository import ArtifactRepository

__all__ = [
    "TrainingBaseRepository",
    "ModelRepository",
    "TrainingLogRepository",
    "PerformanceRepository",
    "JobQueueRepository",
    "ArtifactRepository"
]
560
services/training/app/repositories/artifact_repository.py
Normal file
@@ -0,0 +1,560 @@
|
||||
"""
|
||||
Artifact Repository
|
||||
Repository for model artifact operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text, desc
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import TrainingBaseRepository
|
||||
from app.models.training import ModelArtifact
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ArtifactRepository(TrainingBaseRepository):
|
||||
"""Repository for model artifact operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 1800):
|
||||
# Artifacts are stable, longer cache time (30 minutes)
|
||||
super().__init__(ModelArtifact, session, cache_ttl)
|
||||
|
||||
async def create_artifact(self, artifact_data: Dict[str, Any]) -> ModelArtifact:
|
||||
"""Create a new model artifact record"""
|
||||
try:
|
||||
# Validate artifact data
|
||||
validation_result = self._validate_training_data(
|
||||
artifact_data,
|
||||
["model_id", "tenant_id", "artifact_type", "file_path"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid artifact data: {validation_result['errors']}")
|
||||
|
||||
# Set default values
|
||||
if "storage_location" not in artifact_data:
|
||||
artifact_data["storage_location"] = "local"
|
||||
|
||||
# Create artifact record
|
||||
artifact = await self.create(artifact_data)
|
||||
|
||||
logger.info("Model artifact created",
|
||||
model_id=artifact.model_id,
|
||||
tenant_id=artifact.tenant_id,
|
||||
artifact_type=artifact.artifact_type,
|
||||
file_path=artifact.file_path)
|
||||
|
||||
return artifact
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create model artifact",
|
||||
model_id=artifact_data.get("model_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create artifact: {str(e)}")
|
||||
|
||||
async def get_artifacts_by_model(
|
||||
self,
|
||||
model_id: str,
|
||||
artifact_type: str = None
|
||||
) -> List[ModelArtifact]:
|
||||
"""Get all artifacts for a model"""
|
||||
try:
|
||||
filters = {"model_id": model_id}
|
||||
if artifact_type:
|
||||
filters["artifact_type"] = artifact_type
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get artifacts by model",
|
||||
model_id=model_id,
|
||||
artifact_type=artifact_type,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get artifacts: {str(e)}")
|
||||
|
||||
async def get_artifacts_by_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
artifact_type: str = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[ModelArtifact]:
|
||||
"""Get artifacts for a tenant"""
|
||||
try:
|
||||
filters = {"tenant_id": tenant_id}
|
||||
if artifact_type:
|
||||
filters["artifact_type"] = artifact_type
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get artifacts by tenant",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenant artifacts: {str(e)}")
|
||||
|
||||
async def get_artifact_by_path(self, file_path: str) -> Optional[ModelArtifact]:
|
||||
"""Get artifact by file path"""
|
||||
try:
|
||||
return await self.get_by_field("file_path", file_path)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get artifact by path",
|
||||
file_path=file_path,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get artifact: {str(e)}")
|
||||
|
||||
async def update_artifact_size(self, artifact_id: int, file_size_bytes: int) -> Optional[ModelArtifact]:
|
||||
"""Update artifact file size"""
|
||||
try:
|
||||
return await self.update(artifact_id, {"file_size_bytes": file_size_bytes})
|
||||
except Exception as e:
|
||||
logger.error("Failed to update artifact size",
|
||||
artifact_id=artifact_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def update_artifact_checksum(self, artifact_id: int, checksum: str) -> Optional[ModelArtifact]:
|
||||
"""Update artifact checksum for integrity verification"""
|
||||
try:
|
||||
return await self.update(artifact_id, {"checksum": checksum})
|
||||
except Exception as e:
|
||||
logger.error("Failed to update artifact checksum",
|
||||
artifact_id=artifact_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def mark_artifact_expired(self, artifact_id: int, expires_at: datetime = None) -> Optional[ModelArtifact]:
|
||||
"""Mark artifact for expiration/cleanup"""
|
||||
try:
|
||||
if not expires_at:
|
||||
expires_at = datetime.now()
|
||||
|
||||
return await self.update(artifact_id, {"expires_at": expires_at})
|
||||
except Exception as e:
|
||||
logger.error("Failed to mark artifact as expired",
|
||||
artifact_id=artifact_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def get_expired_artifacts(self, days_expired: int = 0) -> List[ModelArtifact]:
|
||||
"""Get artifacts that have expired"""
|
||||
try:
|
||||
cutoff_date = datetime.now() - timedelta(days=days_expired)
|
||||
|
||||
query_text = """
|
||||
SELECT * FROM model_artifacts
|
||||
WHERE expires_at IS NOT NULL
|
||||
AND expires_at <= :cutoff_date
|
||||
ORDER BY expires_at ASC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
|
||||
expired_artifacts = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
artifact = self.model(**record_dict)
|
||||
expired_artifacts.append(artifact)
|
||||
|
||||
return expired_artifacts
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get expired artifacts",
|
||||
days_expired=days_expired,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def cleanup_expired_artifacts(self, days_expired: int = 0) -> int:
|
||||
"""Clean up expired artifacts"""
|
||||
try:
|
||||
cutoff_date = datetime.now() - timedelta(days=days_expired)
|
||||
|
||||
query_text = """
|
||||
DELETE FROM model_artifacts
|
||||
WHERE expires_at IS NOT NULL
|
||||
AND expires_at <= :cutoff_date
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up expired artifacts",
|
||||
deleted_count=deleted_count,
|
||||
days_expired=days_expired)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired artifacts",
|
||||
days_expired=days_expired,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Artifact cleanup failed: {str(e)}")
|
||||
|
||||
async def get_large_artifacts(self, min_size_mb: int = 100) -> List[ModelArtifact]:
|
||||
"""Get artifacts larger than specified size"""
|
||||
try:
|
||||
min_size_bytes = min_size_mb * 1024 * 1024 # Convert MB to bytes
|
||||
|
||||
query_text = """
|
||||
SELECT * FROM model_artifacts
|
||||
WHERE file_size_bytes >= :min_size_bytes
|
||||
ORDER BY file_size_bytes DESC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"min_size_bytes": min_size_bytes})
|
||||
|
||||
large_artifacts = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
artifact = self.model(**record_dict)
|
||||
large_artifacts.append(artifact)
|
||||
|
||||
return large_artifacts
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get large artifacts",
|
||||
min_size_mb=min_size_mb,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_artifacts_by_storage_location(
|
||||
self,
|
||||
storage_location: str,
|
||||
tenant_id: str = None
|
||||
) -> List[ModelArtifact]:
|
||||
"""Get artifacts by storage location"""
|
||||
try:
|
||||
filters = {"storage_location": storage_location}
|
||||
if tenant_id:
|
||||
filters["tenant_id"] = tenant_id
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get artifacts by storage location",
|
||||
storage_location=storage_location,
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get artifacts: {str(e)}")
|
||||
|
||||
async def get_artifact_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
|
||||
"""Get artifact statistics"""
|
||||
try:
|
||||
base_filters = {}
|
||||
if tenant_id:
|
||||
base_filters["tenant_id"] = tenant_id
|
||||
|
||||
# Get basic counts
|
||||
total_artifacts = await self.count(filters=base_filters)
|
||||
|
||||
# Get artifacts by type
|
||||
type_query_params = {}
|
||||
type_query_filter = ""
|
||||
if tenant_id:
|
||||
type_query_filter = "WHERE tenant_id = :tenant_id"
|
||||
type_query_params["tenant_id"] = tenant_id
|
||||
|
||||
type_query = text(f"""
|
||||
SELECT artifact_type, COUNT(*) as count
|
||||
FROM model_artifacts
|
||||
{type_query_filter}
|
||||
GROUP BY artifact_type
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(type_query, type_query_params)
|
||||
artifacts_by_type = {row.artifact_type: row.count for row in result.fetchall()}
|
||||
|
||||
# Get storage location stats
|
||||
location_query = text(f"""
|
||||
SELECT
|
||||
storage_location,
|
||||
COUNT(*) as count,
|
||||
SUM(COALESCE(file_size_bytes, 0)) as total_size_bytes
|
||||
FROM model_artifacts
|
||||
{type_query_filter}
|
||||
GROUP BY storage_location
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
location_result = await self.session.execute(location_query, type_query_params)
|
||||
storage_stats = {}
|
||||
total_size_bytes = 0
|
||||
|
||||
for row in location_result.fetchall():
|
||||
storage_stats[row.storage_location] = {
|
||||
"artifact_count": row.count,
|
||||
"total_size_bytes": int(row.total_size_bytes or 0),
|
||||
"total_size_mb": round((row.total_size_bytes or 0) / (1024 * 1024), 2)
|
||||
}
|
||||
total_size_bytes += row.total_size_bytes or 0
|
||||
|
||||
# Get expired artifacts count
|
||||
expired_artifacts = len(await self.get_expired_artifacts())
|
||||
|
||||
return {
|
||||
"total_artifacts": total_artifacts,
|
||||
"expired_artifacts": expired_artifacts,
|
||||
"active_artifacts": total_artifacts - expired_artifacts,
|
||||
"artifacts_by_type": artifacts_by_type,
|
||||
"storage_statistics": storage_stats,
|
||||
"total_storage": {
|
||||
"total_size_bytes": total_size_bytes,
|
||||
"total_size_mb": round(total_size_bytes / (1024 * 1024), 2),
|
||||
"total_size_gb": round(total_size_bytes / (1024 * 1024 * 1024), 2)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get artifact statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_artifacts": 0,
|
||||
"expired_artifacts": 0,
|
||||
"active_artifacts": 0,
|
||||
"artifacts_by_type": {},
|
||||
"storage_statistics": {},
|
||||
"total_storage": {
|
||||
"total_size_bytes": 0,
|
||||
"total_size_mb": 0.0,
|
||||
"total_size_gb": 0.0
|
||||
}
|
||||
}
|
||||
|
||||
async def verify_artifact_integrity(self, artifact_id: int) -> Dict[str, Any]:
|
||||
"""Verify artifact file integrity with actual file system checks"""
|
||||
try:
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
artifact = await self.get_by_id(artifact_id)
|
||||
if not artifact:
|
||||
return {"exists": False, "error": "Artifact not found"}
|
||||
|
||||
# Check if file exists
|
||||
file_exists = os.path.exists(artifact.file_path)
|
||||
if not file_exists:
|
||||
return {
|
||||
"artifact_id": artifact_id,
|
||||
"file_path": artifact.file_path,
|
||||
"exists": False,
|
||||
"checksum_valid": False,
|
||||
"size_valid": False,
|
||||
"storage_location": artifact.storage_location,
|
||||
"last_verified": datetime.now().isoformat(),
|
||||
"error": "File does not exist on disk"
|
||||
}
|
||||
|
||||
# Verify file size
|
||||
actual_size = os.path.getsize(artifact.file_path)
|
||||
size_valid = True
|
||||
if artifact.file_size_bytes:
|
||||
size_valid = (actual_size == artifact.file_size_bytes)
|
||||
|
||||
# Verify checksum if stored
|
||||
checksum_valid = True
|
||||
actual_checksum = None
|
||||
if artifact.checksum:
|
||||
# Calculate checksum of actual file
|
||||
sha256_hash = hashlib.sha256()
|
||||
try:
|
||||
with open(artifact.file_path, "rb") as f:
|
||||
# Read file in chunks to handle large files
|
||||
for byte_block in iter(lambda: f.read(4096), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
actual_checksum = sha256_hash.hexdigest()
|
||||
checksum_valid = (actual_checksum == artifact.checksum)
|
||||
except Exception as checksum_error:
|
||||
logger.error(f"Failed to calculate checksum: {checksum_error}")
|
||||
checksum_valid = False
|
||||
actual_checksum = None
|
||||
|
||||
# Overall integrity status
|
||||
integrity_valid = file_exists and size_valid and checksum_valid
|
||||
|
||||
result = {
|
||||
"artifact_id": artifact_id,
|
||||
"file_path": artifact.file_path,
|
||||
"exists": file_exists,
|
||||
"checksum_valid": checksum_valid,
|
||||
"size_valid": size_valid,
|
||||
"integrity_valid": integrity_valid,
|
||||
"storage_location": artifact.storage_location,
|
||||
"last_verified": datetime.now().isoformat(),
|
||||
"details": {
|
||||
"stored_size_bytes": artifact.file_size_bytes,
|
||||
"actual_size_bytes": actual_size if file_exists else None,
|
||||
"stored_checksum": artifact.checksum,
|
||||
"actual_checksum": actual_checksum
|
||||
}
|
||||
}
|
||||
|
||||
if not integrity_valid:
|
||||
issues = []
|
||||
if not file_exists:
|
||||
issues.append("file_missing")
|
||||
if not size_valid:
|
||||
issues.append("size_mismatch")
|
||||
if not checksum_valid:
|
||||
issues.append("checksum_mismatch")
|
||||
result["issues"] = issues
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to verify artifact integrity",
|
||||
artifact_id=artifact_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"exists": False,
|
||||
"error": f"Verification failed: {str(e)}"
|
||||
}
|
||||
|
||||
async def migrate_artifacts_to_storage(
|
||||
self,
|
||||
from_location: str,
|
||||
to_location: str,
|
||||
tenant_id: str = None,
|
||||
copy_only: bool = False,
|
||||
verify: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""Migrate artifacts from one storage location to another with actual file operations"""
|
||||
try:
|
||||
import os
|
||||
import shutil
|
||||
import hashlib
|
||||
|
||||
# Get artifacts to migrate
|
||||
artifacts = await self.get_artifacts_by_storage_location(from_location, tenant_id)
|
||||
|
||||
migrated_count = 0
|
||||
failed_count = 0
|
||||
failed_artifacts = []
|
||||
verified_count = 0
|
||||
|
||||
for artifact in artifacts:
|
||||
try:
|
||||
# Determine new file path
|
||||
new_file_path = artifact.file_path.replace(from_location, to_location, 1)
|
||||
|
||||
# Create destination directory if it doesn't exist
|
||||
dest_dir = os.path.dirname(new_file_path)
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
# Check if source file exists
|
||||
if not os.path.exists(artifact.file_path):
|
||||
logger.warning(f"Source file not found: {artifact.file_path}")
|
||||
failed_count += 1
|
||||
failed_artifacts.append({
|
||||
"artifact_id": artifact.id,
|
||||
"file_path": artifact.file_path,
|
||||
"reason": "source_file_not_found"
|
||||
})
|
||||
continue
|
||||
|
||||
# Copy or move file
|
||||
if copy_only:
|
||||
shutil.copy2(artifact.file_path, new_file_path)
|
||||
logger.debug(f"Copied file from {artifact.file_path} to {new_file_path}")
|
||||
else:
|
||||
shutil.move(artifact.file_path, new_file_path)
|
||||
logger.debug(f"Moved file from {artifact.file_path} to {new_file_path}")
|
||||
|
||||
# Verify file was copied/moved successfully
|
||||
if verify and os.path.exists(new_file_path):
|
||||
# Verify file size
|
||||
new_size = os.path.getsize(new_file_path)
|
||||
if artifact.file_size_bytes and new_size != artifact.file_size_bytes:
|
||||
logger.warning(f"File size mismatch after migration: {new_file_path}")
|
||||
failed_count += 1
|
||||
failed_artifacts.append({
|
||||
"artifact_id": artifact.id,
|
||||
"file_path": new_file_path,
|
||||
"reason": "size_mismatch_after_migration"
|
||||
})
|
||||
continue
|
||||
|
||||
# Verify checksum if available
|
||||
if artifact.checksum:
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(new_file_path, "rb") as f:
|
||||
for byte_block in iter(lambda: f.read(4096), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
new_checksum = sha256_hash.hexdigest()
|
||||
|
||||
if new_checksum != artifact.checksum:
|
||||
logger.warning(f"Checksum mismatch after migration: {new_file_path}")
|
||||
failed_count += 1
|
||||
failed_artifacts.append({
|
||||
"artifact_id": artifact.id,
|
||||
"file_path": new_file_path,
|
||||
"reason": "checksum_mismatch_after_migration"
|
||||
})
|
||||
continue
|
||||
|
||||
verified_count += 1
|
||||
|
||||
# Update database with new location
|
||||
await self.update(artifact.id, {
|
||||
"storage_location": to_location,
|
||||
"file_path": new_file_path
|
||||
})
|
||||
|
||||
migrated_count += 1
|
||||
|
||||
except Exception as migration_error:
|
||||
logger.error("Failed to migrate artifact",
|
||||
artifact_id=artifact.id,
|
||||
error=str(migration_error))
|
||||
failed_count += 1
|
||||
failed_artifacts.append({
|
||||
"artifact_id": artifact.id,
|
||||
"file_path": artifact.file_path,
|
||||
"reason": str(migration_error)
|
||||
})
|
||||
|
||||
logger.info("Artifact migration completed",
|
||||
from_location=from_location,
|
||||
to_location=to_location,
|
||||
migrated_count=migrated_count,
|
||||
failed_count=failed_count,
|
||||
verified_count=verified_count)
|
||||
|
||||
return {
|
||||
"from_location": from_location,
|
||||
"to_location": to_location,
|
||||
"total_artifacts": len(artifacts),
|
||||
"migrated_count": migrated_count,
|
||||
"failed_count": failed_count,
|
||||
"verified_count": verified_count if verify else None,
|
||||
"success_rate": round((migrated_count / len(artifacts)) * 100, 2) if artifacts else 100,
|
||||
"copy_only": copy_only,
|
||||
"failed_artifacts": failed_artifacts if failed_artifacts else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to migrate artifacts",
|
||||
from_location=from_location,
|
||||
to_location=to_location,
|
||||
error=str(e))
|
||||
return {
|
||||
"error": f"Migration failed: {str(e)}"
|
||||
}
|
||||
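For orientation, a hedged usage sketch of the artifact helpers above follows. The repository class name (ModelArtifactRepository), its module path, the async_session_factory helper, and the "local"/"nfs" storage location values are assumptions for the example; the method calls mirror the code in this file.

import asyncio

from app.repositories.model_artifact_repository import ModelArtifactRepository  # assumed module/class name
from app.core.database import async_session_factory  # assumed session factory


async def migrate_local_artifacts(tenant_id: str) -> None:
    async with async_session_factory() as session:
        repo = ModelArtifactRepository(session)

        # Summarise storage usage before touching any files
        stats = await repo.get_artifact_statistics(tenant_id=tenant_id)
        print(f"{stats['total_artifacts']} artifacts, "
              f"{stats['total_storage']['total_size_mb']} MB on disk")

        # Copy first (copy_only=True) and verify checksums; re-run with
        # copy_only=False once the destination has been validated
        report = await repo.migrate_artifacts_to_storage(
            from_location="local",
            to_location="nfs",
            tenant_id=tenant_id,
            copy_only=True,
            verify=True,
        )
        print(f"migrated={report['migrated_count']} failed={report['failed_count']}")


if __name__ == "__main__":
    asyncio.run(migrate_local_artifacts("tenant-123"))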
179
services/training/app/repositories/base.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Base Repository for Training Service
|
||||
Service-specific repository base class with training service utilities
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Type
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import structlog
|
||||
|
||||
from shared.database.repository import BaseRepository
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TrainingBaseRepository(BaseRepository):
|
||||
"""Base repository for training service with common training operations"""
|
||||
|
||||
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
|
||||
# Training data changes frequently, shorter cache time (5 minutes)
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get records by tenant ID"""
|
||||
if hasattr(self.model, 'tenant_id'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"tenant_id": tenant_id},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get active records (if model has is_active field)"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def get_by_job_id(self, job_id: str) -> Optional:
|
||||
"""Get record by job ID (if model has job_id field)"""
|
||||
if hasattr(self.model, 'job_id'):
|
||||
return await self.get_by_field("job_id", job_id)
|
||||
return None
|
||||
|
||||
async def get_by_model_id(self, model_id: str) -> Optional:
|
||||
"""Get record by model ID (if model has model_id field)"""
|
||||
if hasattr(self.model, 'model_id'):
|
||||
return await self.get_by_field("model_id", model_id)
|
||||
return None
|
||||
|
||||
async def deactivate_record(self, record_id: Any) -> Optional:
|
||||
"""Deactivate a record instead of deleting it"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": False})
|
||||
return await self.delete(record_id)
|
||||
|
||||
async def activate_record(self, record_id: Any) -> Optional:
|
||||
"""Activate a record"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": True})
|
||||
return await self.get_by_id(record_id)
|
||||
|
||||
async def cleanup_old_records(self, days_old: int = 90, status_filter: str = None) -> int:
|
||||
"""Clean up old training records"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
|
||||
table_name = self.model.__tablename__
|
||||
|
||||
# Build query based on available fields
|
||||
conditions = ["created_at < :cutoff_date"]
|
||||
params = {"cutoff_date": cutoff_date}
|
||||
|
||||
if status_filter and hasattr(self.model, 'status'):
|
||||
conditions.append("status = :status")
|
||||
params["status"] = status_filter
|
||||
|
||||
query_text = f"""
|
||||
DELETE FROM {table_name}
|
||||
WHERE {' AND '.join(conditions)}
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info(f"Cleaned up old {self.model.__name__} records",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old,
|
||||
status_filter=status_filter)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old records",
|
||||
model=self.model.__name__,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
async def get_records_by_date_range(
|
||||
self,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List:
|
||||
"""Get records within date range"""
|
||||
if not hasattr(self.model, 'created_at'):
|
||||
logger.warning(f"Model {self.model.__name__} has no created_at field")
|
||||
return []
|
||||
|
||||
try:
|
||||
table_name = self.model.__tablename__
|
||||
|
||||
query_text = f"""
|
||||
SELECT * FROM {table_name}
|
||||
WHERE created_at >= :start_date
|
||||
AND created_at <= :end_date
|
||||
ORDER BY created_at DESC
|
||||
LIMIT :limit OFFSET :skip
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"limit": limit,
|
||||
"skip": skip
|
||||
})
|
||||
|
||||
# Convert rows to model objects
|
||||
records = []
|
||||
for row in result.fetchall():
|
||||
# Create model instance from row data
|
||||
record_dict = dict(row._mapping)
|
||||
record = self.model(**record_dict)
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get records by date range",
|
||||
model=self.model.__name__,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Date range query failed: {str(e)}")
|
||||
|
||||
def _validate_training_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
|
||||
"""Validate training-related data"""
|
||||
errors = []
|
||||
|
||||
for field in required_fields:
|
||||
if field not in data or not data[field]:
|
||||
errors.append(f"Missing required field: {field}")
|
||||
|
||||
# Validate tenant_id format if present
|
||||
if "tenant_id" in data and data["tenant_id"]:
|
||||
tenant_id = data["tenant_id"]
|
||||
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
|
||||
errors.append("Invalid tenant_id format")
|
||||
|
||||
# Validate job_id format if present
|
||||
if "job_id" in data and data["job_id"]:
|
||||
job_id = data["job_id"]
|
||||
if not isinstance(job_id, str) or len(job_id) < 1:
|
||||
errors.append("Invalid job_id format")
|
||||
|
||||
return {
|
||||
"is_valid": len(errors) == 0,
|
||||
"errors": errors
|
||||
}
|
||||
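A minimal sketch of how a concrete repository is expected to build on TrainingBaseRepository is shown below; the subclass name is illustrative, and the absolute import path app.repositories.base is assumed from the file layout.

from sqlalchemy.ext.asyncio import AsyncSession

from app.repositories.base import TrainingBaseRepository  # assumed absolute path for this file
from app.models.training import TrainingJobQueue


class ExampleQueueReadRepository(TrainingBaseRepository):
    """Illustrative read-only helper built on the shared training base repository."""

    def __init__(self, session: AsyncSession):
        # Short cache TTL because queue rows change often
        super().__init__(TrainingJobQueue, session, cache_ttl=60)

    async def recent_for_tenant(self, tenant_id: str, limit: int = 20):
        # get_by_tenant_id already orders by created_at DESC when the column exists
        return await self.get_by_tenant_id(tenant_id, skip=0, limit=limit)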
445
services/training/app/repositories/job_queue_repository.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
Job Queue Repository
|
||||
Repository for training job queue operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text, desc, bindparam
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import TrainingBaseRepository
|
||||
from app.models.training import TrainingJobQueue
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class JobQueueRepository(TrainingBaseRepository):
|
||||
"""Repository for training job queue operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 60):
|
||||
# Job queue changes frequently, very short cache time (1 minute)
|
||||
super().__init__(TrainingJobQueue, session, cache_ttl)
|
||||
|
||||
async def enqueue_job(self, job_data: Dict[str, Any]) -> TrainingJobQueue:
|
||||
"""Add a job to the training queue"""
|
||||
try:
|
||||
# Validate job data
|
||||
validation_result = self._validate_training_data(
|
||||
job_data,
|
||||
["job_id", "tenant_id", "job_type"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid job data: {validation_result['errors']}")
|
||||
|
||||
# Set default values
|
||||
if "priority" not in job_data:
|
||||
job_data["priority"] = 1
|
||||
if "status" not in job_data:
|
||||
job_data["status"] = "queued"
|
||||
if "max_retries" not in job_data:
|
||||
job_data["max_retries"] = 3
|
||||
|
||||
# Create queue entry
|
||||
queued_job = await self.create(job_data)
|
||||
|
||||
logger.info("Job enqueued",
|
||||
job_id=queued_job.job_id,
|
||||
tenant_id=queued_job.tenant_id,
|
||||
job_type=queued_job.job_type,
|
||||
priority=queued_job.priority)
|
||||
|
||||
return queued_job
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to enqueue job",
|
||||
job_id=job_data.get("job_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to enqueue job: {str(e)}")
|
||||
|
||||
async def get_next_job(self, job_types: List[str] = None) -> Optional[TrainingJobQueue]:
|
||||
"""Get the next job to process from the queue"""
|
||||
try:
|
||||
# Build filters for job types if specified
|
||||
filters = {"status": "queued"}
|
||||
|
||||
if job_types:
|
||||
# For multiple job types, bind the list safely with an expanding IN
# parameter instead of interpolating the values into the SQL string
query = text("""
    SELECT * FROM training_job_queue
    WHERE status = 'queued'
    AND job_type IN :job_types
    AND (scheduled_at IS NULL OR scheduled_at <= :now)
    ORDER BY priority DESC, created_at ASC
    LIMIT 1
""").bindparams(bindparam("job_types", expanding=True))

result = await self.session.execute(query, {"job_types": job_types, "now": datetime.now()})
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
record_dict = dict(row._mapping)
|
||||
return self.model(**record_dict)
|
||||
return None
|
||||
else:
|
||||
# Simple case - get any queued job
|
||||
jobs = await self.get_multi(
|
||||
filters=filters,
|
||||
limit=1,
|
||||
order_by="priority",
|
||||
order_desc=True
|
||||
)
|
||||
return jobs[0] if jobs else None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get next job from queue",
|
||||
job_types=job_types,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get next job: {str(e)}")
|
||||
|
||||
async def start_job(self, job_id: str) -> Optional[TrainingJobQueue]:
|
||||
"""Mark a job as started"""
|
||||
try:
|
||||
job = await self.get_by_job_id(job_id)
|
||||
if not job:
|
||||
logger.error(f"Job not found in queue: {job_id}")
|
||||
return None
|
||||
|
||||
if job.status != "queued":
|
||||
logger.warning(f"Job {job_id} is not queued (status: {job.status})")
|
||||
return job
|
||||
|
||||
updated_job = await self.update(job.id, {
|
||||
"status": "running",
|
||||
"started_at": datetime.now(),
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
|
||||
logger.info("Job started",
|
||||
job_id=job_id,
|
||||
job_type=job.job_type)
|
||||
|
||||
return updated_job
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to start job",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to start job: {str(e)}")
|
||||
|
||||
async def complete_job(self, job_id: str) -> Optional[TrainingJobQueue]:
|
||||
"""Mark a job as completed"""
|
||||
try:
|
||||
job = await self.get_by_job_id(job_id)
|
||||
if not job:
|
||||
logger.error(f"Job not found in queue: {job_id}")
|
||||
return None
|
||||
|
||||
updated_job = await self.update(job.id, {
|
||||
"status": "completed",
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
|
||||
logger.info("Job completed",
|
||||
job_id=job_id,
|
||||
job_type=job.job_type if job else "unknown")
|
||||
|
||||
return updated_job
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to complete job",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to complete job: {str(e)}")
|
||||
|
||||
async def fail_job(self, job_id: str, error_message: str = None) -> Optional[TrainingJobQueue]:
|
||||
"""Mark a job as failed and handle retries"""
|
||||
try:
|
||||
job = await self.get_by_job_id(job_id)
|
||||
if not job:
|
||||
logger.error(f"Job not found in queue: {job_id}")
|
||||
return None
|
||||
|
||||
# Increment retry count
|
||||
new_retry_count = job.retry_count + 1
|
||||
|
||||
# Check if we should retry
|
||||
if new_retry_count < job.max_retries:
|
||||
# Reset to queued for retry
|
||||
updated_job = await self.update(job.id, {
|
||||
"status": "queued",
|
||||
"retry_count": new_retry_count,
|
||||
"updated_at": datetime.now(),
|
||||
"started_at": None # Reset started_at for retry
|
||||
})
|
||||
|
||||
logger.info("Job failed, queued for retry",
|
||||
job_id=job_id,
|
||||
retry_count=new_retry_count,
|
||||
max_retries=job.max_retries)
|
||||
else:
|
||||
# Mark as permanently failed
|
||||
updated_job = await self.update(job.id, {
|
||||
"status": "failed",
|
||||
"retry_count": new_retry_count,
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
|
||||
logger.error("Job permanently failed",
|
||||
job_id=job_id,
|
||||
retry_count=new_retry_count,
|
||||
error_message=error_message)
|
||||
|
||||
return updated_job
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to handle job failure",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to handle job failure: {str(e)}")
|
||||
|
||||
async def cancel_job(self, job_id: str, cancelled_by: str = None) -> Optional[TrainingJobQueue]:
|
||||
"""Cancel a job"""
|
||||
try:
|
||||
job = await self.get_by_job_id(job_id)
|
||||
if not job:
|
||||
logger.error(f"Job not found in queue: {job_id}")
|
||||
return None
|
||||
|
||||
if job.status in ["completed", "failed"]:
|
||||
logger.warning(f"Cannot cancel job {job_id} with status {job.status}")
|
||||
return job
|
||||
|
||||
updated_job = await self.update(job.id, {
|
||||
"status": "cancelled",
|
||||
"cancelled_by": cancelled_by,
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
|
||||
logger.info("Job cancelled",
|
||||
job_id=job_id,
|
||||
cancelled_by=cancelled_by)
|
||||
|
||||
return updated_job
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cancel job",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to cancel job: {str(e)}")
|
||||
|
||||
async def get_queue_status(self, tenant_id: str = None) -> Dict[str, Any]:
|
||||
"""Get queue status and statistics"""
|
||||
try:
|
||||
base_filters = {}
|
||||
if tenant_id:
|
||||
base_filters["tenant_id"] = tenant_id
|
||||
|
||||
# Get counts by status
|
||||
queued_jobs = await self.count(filters={**base_filters, "status": "queued"})
|
||||
running_jobs = await self.count(filters={**base_filters, "status": "running"})
|
||||
completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
|
||||
failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
|
||||
cancelled_jobs = await self.count(filters={**base_filters, "status": "cancelled"})
|
||||
|
||||
# Get jobs by type
|
||||
type_query = text(f"""
|
||||
SELECT job_type, COUNT(*) as count
|
||||
FROM training_job_queue
|
||||
WHERE 1=1
|
||||
{' AND tenant_id = :tenant_id' if tenant_id else ''}
|
||||
GROUP BY job_type
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
params = {"tenant_id": tenant_id} if tenant_id else {}
|
||||
result = await self.session.execute(type_query, params)
|
||||
jobs_by_type = {row.job_type: row.count for row in result.fetchall()}
|
||||
|
||||
# Get average wait time for completed jobs
|
||||
wait_time_query = text(f"""
|
||||
SELECT
|
||||
AVG(EXTRACT(EPOCH FROM (started_at - created_at))/60) as avg_wait_minutes
|
||||
FROM training_job_queue
|
||||
WHERE status = 'completed'
|
||||
AND started_at IS NOT NULL
|
||||
AND created_at IS NOT NULL
|
||||
{' AND tenant_id = :tenant_id' if tenant_id else ''}
|
||||
""")
|
||||
|
||||
wait_result = await self.session.execute(wait_time_query, params)
|
||||
wait_row = wait_result.fetchone()
|
||||
avg_wait_time = float(wait_row.avg_wait_minutes) if wait_row and wait_row.avg_wait_minutes else 0.0
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"queue_counts": {
|
||||
"queued": queued_jobs,
|
||||
"running": running_jobs,
|
||||
"completed": completed_jobs,
|
||||
"failed": failed_jobs,
|
||||
"cancelled": cancelled_jobs,
|
||||
"total": queued_jobs + running_jobs + completed_jobs + failed_jobs + cancelled_jobs
|
||||
},
|
||||
"jobs_by_type": jobs_by_type,
|
||||
"avg_wait_time_minutes": round(avg_wait_time, 2),
|
||||
"queue_health": {
|
||||
"has_queued_jobs": queued_jobs > 0,
|
||||
"has_running_jobs": running_jobs > 0,
|
||||
"failure_rate": round((failed_jobs / max(completed_jobs + failed_jobs, 1)) * 100, 2)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get queue status",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"queue_counts": {
|
||||
"queued": 0, "running": 0, "completed": 0,
|
||||
"failed": 0, "cancelled": 0, "total": 0
|
||||
},
|
||||
"jobs_by_type": {},
|
||||
"avg_wait_time_minutes": 0.0,
|
||||
"queue_health": {
|
||||
"has_queued_jobs": False,
|
||||
"has_running_jobs": False,
|
||||
"failure_rate": 0.0
|
||||
}
|
||||
}
|
||||
|
||||
async def get_jobs_by_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
status: str = None,
|
||||
job_type: str = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[TrainingJobQueue]:
|
||||
"""Get jobs for a tenant with optional filtering"""
|
||||
try:
|
||||
filters = {"tenant_id": tenant_id}
|
||||
if status:
|
||||
filters["status"] = status
|
||||
if job_type:
|
||||
filters["job_type"] = job_type
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get jobs by tenant",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenant jobs: {str(e)}")
|
||||
|
||||
async def cleanup_old_jobs(self, days_old: int = 30, status_filter: str = None) -> int:
|
||||
"""Clean up old completed/failed/cancelled jobs"""
|
||||
try:
|
||||
cutoff_date = datetime.now() - timedelta(days=days_old)
|
||||
|
||||
# Only clean up finished jobs by default
|
||||
default_statuses = ["completed", "failed", "cancelled"]
|
||||
|
||||
if status_filter:
|
||||
status_condition = "status = :status"
|
||||
params = {"cutoff_date": cutoff_date, "status": status_filter}
|
||||
else:
|
||||
status_list = "', '".join(default_statuses)
|
||||
status_condition = f"status IN ('{status_list}')"
|
||||
params = {"cutoff_date": cutoff_date}
|
||||
|
||||
query_text = f"""
|
||||
DELETE FROM training_job_queue
|
||||
WHERE created_at < :cutoff_date
|
||||
AND {status_condition}
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up old queue jobs",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old,
|
||||
status_filter=status_filter)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old queue jobs",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Queue cleanup failed: {str(e)}")
|
||||
|
||||
async def get_stuck_jobs(self, hours_stuck: int = 2) -> List[TrainingJobQueue]:
|
||||
"""Get jobs that have been running for too long"""
|
||||
try:
|
||||
cutoff_time = datetime.now() - timedelta(hours=hours_stuck)
|
||||
|
||||
query_text = """
|
||||
SELECT * FROM training_job_queue
|
||||
WHERE status = 'running'
|
||||
AND started_at IS NOT NULL
|
||||
AND started_at < :cutoff_time
|
||||
ORDER BY started_at ASC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_time": cutoff_time})
|
||||
|
||||
stuck_jobs = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
job = self.model(**record_dict)
|
||||
stuck_jobs.append(job)
|
||||
|
||||
if stuck_jobs:
|
||||
logger.warning("Found stuck jobs",
|
||||
count=len(stuck_jobs),
|
||||
hours_stuck=hours_stuck)
|
||||
|
||||
return stuck_jobs
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get stuck jobs",
|
||||
hours_stuck=hours_stuck,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def reset_stuck_jobs(self, hours_stuck: int = 2) -> int:
|
||||
"""Reset stuck jobs back to queued status"""
|
||||
try:
|
||||
stuck_jobs = await self.get_stuck_jobs(hours_stuck)
|
||||
reset_count = 0
|
||||
|
||||
for job in stuck_jobs:
|
||||
# Reset job to queued status
|
||||
await self.update(job.id, {
|
||||
"status": "queued",
|
||||
"started_at": None,
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
reset_count += 1
|
||||
|
||||
if reset_count > 0:
|
||||
logger.info("Reset stuck jobs",
|
||||
reset_count=reset_count,
|
||||
hours_stuck=hours_stuck)
|
||||
|
||||
return reset_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to reset stuck jobs",
|
||||
hours_stuck=hours_stuck,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to reset stuck jobs: {str(e)}")
|
||||
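The queue repository above is written to be driven by a polling worker. A hedged sketch of that loop follows; run_training_job, async_session_factory, and the "prophet_training" job type are assumptions, while the repository calls mirror the methods defined in this file.

import asyncio

from app.repositories.job_queue_repository import JobQueueRepository
from app.core.database import async_session_factory  # assumed session factory


async def run_training_job(job) -> None:
    """Placeholder for the service's real training entry point (assumed)."""
    ...


async def poll_queue_forever(poll_interval: float = 5.0) -> None:
    while True:
        async with async_session_factory() as session:
            repo = JobQueueRepository(session)

            # Recover jobs whose worker died mid-run before picking up new work
            await repo.reset_stuck_jobs(hours_stuck=2)

            job = await repo.get_next_job(job_types=["prophet_training"])
            if job is None:
                await asyncio.sleep(poll_interval)
                continue

            await repo.start_job(job.job_id)
            try:
                await run_training_job(job)
                await repo.complete_job(job.job_id)
            except Exception as exc:
                # fail_job re-queues the job until max_retries is exhausted
                await repo.fail_job(job.job_id, error_message=str(exc))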
375
services/training/app/repositories/model_repository.py
Normal file
@@ -0,0 +1,375 @@
|
||||
"""
|
||||
Model Repository
|
||||
Repository for trained model operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text, desc
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import TrainingBaseRepository
|
||||
from app.models.training import TrainedModel
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ModelRepository(TrainingBaseRepository):
|
||||
"""Repository for trained model operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Models are relatively stable, longer cache time (10 minutes)
|
||||
super().__init__(TrainedModel, session, cache_ttl)
|
||||
|
||||
async def create_model(self, model_data: Dict[str, Any]) -> TrainedModel:
|
||||
"""Create a new trained model with validation"""
|
||||
try:
|
||||
# Validate model data
|
||||
validation_result = self._validate_training_data(
|
||||
model_data,
|
||||
["tenant_id", "inventory_product_id", "model_path", "job_id"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid model data: {validation_result['errors']}")
|
||||
|
||||
# Check for duplicate active models for same tenant+product
|
||||
existing_model = await self.get_active_model_for_product(
|
||||
model_data["tenant_id"],
|
||||
model_data["inventory_product_id"]
|
||||
)
|
||||
|
||||
# If there's an existing active model, we may want to deactivate it
|
||||
if existing_model and model_data.get("is_production", False):
|
||||
logger.info("Deactivating previous production model",
|
||||
previous_model_id=existing_model.id,
|
||||
tenant_id=model_data["tenant_id"],
|
||||
inventory_product_id=model_data["inventory_product_id"])
|
||||
await self.update(existing_model.id, {"is_production": False})
|
||||
|
||||
# Create new model
|
||||
model = await self.create(model_data)
|
||||
|
||||
logger.info("Trained model created successfully",
|
||||
model_id=model.id,
|
||||
tenant_id=model.tenant_id,
|
||||
inventory_product_id=str(model.inventory_product_id),
|
||||
model_type=model.model_type)
|
||||
|
||||
return model
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create trained model",
|
||||
tenant_id=model_data.get("tenant_id"),
|
||||
inventory_product_id=model_data.get("inventory_product_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create model: {str(e)}")
|
||||
|
||||
async def get_model_by_tenant_and_product(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str
|
||||
) -> List[TrainedModel]:
|
||||
"""Get all models for a tenant and product"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get models by tenant and product",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get models: {str(e)}")
|
||||
|
||||
async def get_active_model_for_product(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str
|
||||
) -> Optional[TrainedModel]:
|
||||
"""Get the active production model for a product"""
|
||||
try:
|
||||
models = await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"is_active": True,
|
||||
"is_production": True
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True,
|
||||
limit=1
|
||||
)
|
||||
return models[0] if models else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get active model for product",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get active model: {str(e)}")
|
||||
|
||||
async def get_models_by_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[TrainedModel]:
|
||||
"""Get all models for a tenant"""
|
||||
return await self.get_by_tenant_id(tenant_id, skip=skip, limit=limit)
|
||||
|
||||
async def promote_to_production(self, model_id: str) -> Optional[TrainedModel]:
|
||||
"""Promote a model to production"""
|
||||
try:
|
||||
# Get the model first
|
||||
model = await self.get_by_id(model_id)
|
||||
if not model:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
|
||||
# Deactivate other production models for the same tenant+product
|
||||
await self._deactivate_other_production_models(
|
||||
model.tenant_id,
|
||||
str(model.inventory_product_id),
|
||||
model_id
|
||||
)
|
||||
|
||||
# Promote this model
|
||||
updated_model = await self.update(model_id, {
|
||||
"is_production": True,
|
||||
"last_used_at": datetime.now(timezone.utc)
|
||||
})
|
||||
|
||||
logger.info("Model promoted to production",
|
||||
model_id=model_id,
|
||||
tenant_id=model.tenant_id,
|
||||
inventory_product_id=str(model.inventory_product_id))
|
||||
|
||||
return updated_model
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to promote model to production",
|
||||
model_id=model_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to promote model: {str(e)}")
|
||||
|
||||
async def update_model_usage(self, model_id: str) -> Optional[TrainedModel]:
|
||||
"""Update model last used timestamp"""
|
||||
try:
|
||||
return await self.update(model_id, {
|
||||
"last_used_at": datetime.now(timezone.utc)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Failed to update model usage",
|
||||
model_id=model_id,
|
||||
error=str(e))
|
||||
# Don't raise here - usage update is not critical
|
||||
return None
|
||||
|
||||
async def archive_old_models(self, tenant_id: str, days_old: int = 90) -> int:
|
||||
"""Archive old non-production models"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
|
||||
|
||||
query = text("""
|
||||
UPDATE trained_models
|
||||
SET is_active = false
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND is_production = false
|
||||
AND created_at < :cutoff_date
|
||||
AND is_active = true
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"tenant_id": tenant_id,
|
||||
"cutoff_date": cutoff_date
|
||||
})
|
||||
|
||||
archived_count = result.rowcount
|
||||
|
||||
logger.info("Archived old models",
|
||||
tenant_id=tenant_id,
|
||||
archived_count=archived_count,
|
||||
days_old=days_old)
|
||||
|
||||
return archived_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to archive old models",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Model archival failed: {str(e)}")
|
||||
|
||||
async def get_model_statistics(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Get model statistics for a tenant"""
|
||||
try:
|
||||
# Get basic counts
|
||||
total_models = await self.count(filters={"tenant_id": tenant_id})
|
||||
active_models = await self.count(filters={
|
||||
"tenant_id": tenant_id,
|
||||
"is_active": True
|
||||
})
|
||||
production_models = await self.count(filters={
|
||||
"tenant_id": tenant_id,
|
||||
"is_production": True
|
||||
})
|
||||
|
||||
# Get models by product using raw query
|
||||
product_query = text("""
|
||||
SELECT inventory_product_id, COUNT(*) as count
|
||||
FROM trained_models
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND is_active = true
|
||||
GROUP BY inventory_product_id
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
|
||||
product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}
|
||||
|
||||
# Recent activity (models created in last 30 days)
|
||||
thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
recent_models_query = text("""
|
||||
SELECT COUNT(*) as count
|
||||
FROM trained_models
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND created_at >= :thirty_days_ago
|
||||
""")
|
||||
|
||||
recent_result = await self.session.execute(
|
||||
recent_models_query,
|
||||
{"tenant_id": tenant_id, "thirty_days_ago": thirty_days_ago}
|
||||
)
|
||||
recent_models = recent_result.scalar() or 0
|
||||
|
||||
# Calculate average accuracy from model metrics
|
||||
accuracy_query = text("""
|
||||
SELECT AVG(mape) as average_mape, COUNT(*) as total_models_with_metrics
|
||||
FROM trained_models
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND mape IS NOT NULL
|
||||
AND is_active = true
|
||||
""")
|
||||
|
||||
accuracy_result = await self.session.execute(accuracy_query, {"tenant_id": tenant_id})
|
||||
accuracy_row = accuracy_result.fetchone()
|
||||
|
||||
average_mape = accuracy_row.average_mape if accuracy_row and accuracy_row.average_mape else 0
|
||||
total_models_with_metrics = accuracy_row.total_models_with_metrics if accuracy_row else 0
|
||||
|
||||
# Convert MAPE to accuracy percentage (lower MAPE = higher accuracy)
|
||||
# Use 100 - MAPE as a simple conversion, but cap it at reasonable bounds
|
||||
# Return None if no models have metrics (no data), rather than 0
|
||||
if total_models_with_metrics == 0:
|
||||
average_accuracy = None
|
||||
else:
|
||||
average_accuracy = max(0.0, min(100.0, 100.0 - float(average_mape)))
|
||||
|
||||
return {
|
||||
"total_models": total_models,
|
||||
"active_models": active_models,
|
||||
"inactive_models": total_models - active_models,
|
||||
"production_models": production_models,
|
||||
"models_by_product": product_stats,
|
||||
"recent_models_30d": recent_models,
|
||||
"average_accuracy": average_accuracy,
|
||||
"total_models_with_metrics": total_models_with_metrics,
|
||||
"average_mape": float(average_mape) if average_mape > 0 else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get model statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_models": 0,
|
||||
"active_models": 0,
|
||||
"inactive_models": 0,
|
||||
"production_models": 0,
|
||||
"models_by_product": {},
|
||||
"recent_models_30d": 0,
|
||||
"average_accuracy": 0,
|
||||
"total_models_with_metrics": 0,
|
||||
"average_mape": 0
|
||||
}
|
||||
|
||||
async def _deactivate_other_production_models(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
exclude_model_id: str
|
||||
) -> int:
|
||||
"""Deactivate other production models for the same tenant+product"""
|
||||
try:
|
||||
query = text("""
|
||||
UPDATE trained_models
|
||||
SET is_production = false
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND inventory_product_id = :inventory_product_id
|
||||
AND id != :exclude_model_id
|
||||
AND is_production = true
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"exclude_model_id": exclude_model_id
|
||||
})
|
||||
|
||||
return result.rowcount
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to deactivate other production models",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to deactivate models: {str(e)}")
|
||||
|
||||
async def get_model_performance_summary(self, model_id: str) -> Dict[str, Any]:
|
||||
"""Get performance summary for a model"""
|
||||
try:
|
||||
model = await self.get_by_id(model_id)
|
||||
if not model:
|
||||
return {}
|
||||
|
||||
return {
|
||||
"model_id": model.id,
|
||||
"tenant_id": model.tenant_id,
|
||||
"inventory_product_id": str(model.inventory_product_id),
|
||||
"model_type": model.model_type,
|
||||
"metrics": {
|
||||
"mape": model.mape,
|
||||
"mae": model.mae,
|
||||
"rmse": model.rmse,
|
||||
"r2_score": model.r2_score
|
||||
},
|
||||
"training_info": {
|
||||
"training_samples": model.training_samples,
|
||||
"training_start_date": model.training_start_date.isoformat() if model.training_start_date else None,
|
||||
"training_end_date": model.training_end_date.isoformat() if model.training_end_date else None,
|
||||
"data_quality_score": model.data_quality_score
|
||||
},
|
||||
"status": {
|
||||
"is_active": model.is_active,
|
||||
"is_production": model.is_production,
|
||||
"created_at": model.created_at.isoformat() if model.created_at else None,
|
||||
"last_used_at": model.last_used_at.isoformat() if model.last_used_at else None
|
||||
},
|
||||
"features": {
|
||||
"hyperparameters": model.hyperparameters,
|
||||
"features_used": model.features_used
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get model performance summary",
|
||||
model_id=model_id,
|
||||
error=str(e))
|
||||
return {}
|
||||
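As a rough illustration of the model lifecycle methods above, the sketch below promotes a newly trained model and reads back the tenant's statistics; async_session_factory is an assumed helper.

from app.repositories.model_repository import ModelRepository
from app.core.database import async_session_factory  # assumed session factory


async def promote_if_needed(tenant_id: str, inventory_product_id: str, model_id: str) -> None:
    async with async_session_factory() as session:
        repo = ModelRepository(session)

        current = await repo.get_active_model_for_product(tenant_id, inventory_product_id)
        if current and str(current.id) == str(model_id):
            return  # already the production model

        # promote_to_production also demotes any other production model for this product
        await repo.promote_to_production(model_id)

        stats = await repo.get_model_statistics(tenant_id)
        print(f"{stats['production_models']} production models, "
              f"average accuracy: {stats['average_accuracy']}")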
433
services/training/app/repositories/performance_repository.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Performance Repository
|
||||
Repository for model performance metrics operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text, desc
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import TrainingBaseRepository
|
||||
from app.models.training import ModelPerformanceMetric
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PerformanceRepository(TrainingBaseRepository):
|
||||
"""Repository for model performance metrics operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
|
||||
# Performance metrics are relatively stable, longer cache time (15 minutes)
|
||||
super().__init__(ModelPerformanceMetric, session, cache_ttl)
|
||||
|
||||
async def create_performance_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
|
||||
"""Create a new performance metric record"""
|
||||
try:
|
||||
# Validate metric data
|
||||
validation_result = self._validate_training_data(
|
||||
metric_data,
|
||||
["model_id", "tenant_id", "inventory_product_id"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid metric data: {validation_result['errors']}")
|
||||
|
||||
# Set measurement timestamp if not provided
|
||||
if "measured_at" not in metric_data:
|
||||
metric_data["measured_at"] = datetime.now()
|
||||
|
||||
# Create metric record
|
||||
metric = await self.create(metric_data)
|
||||
|
||||
logger.info("Performance metric created",
|
||||
model_id=metric.model_id,
|
||||
tenant_id=metric.tenant_id,
|
||||
inventory_product_id=str(metric.inventory_product_id))
|
||||
|
||||
return metric
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create performance metric",
|
||||
model_id=metric_data.get("model_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create metric: {str(e)}")
|
||||
|
||||
async def get_metrics_by_model(
|
||||
self,
|
||||
model_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[ModelPerformanceMetric]:
|
||||
"""Get all performance metrics for a model"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={"model_id": model_id},
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
order_by="measured_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get metrics by model",
|
||||
model_id=model_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get metrics: {str(e)}")
|
||||
|
||||
async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
|
||||
"""Get the latest performance metric for a model"""
|
||||
try:
|
||||
metrics = await self.get_multi(
|
||||
filters={"model_id": model_id},
|
||||
limit=1,
|
||||
order_by="measured_at",
|
||||
order_desc=True
|
||||
)
|
||||
return metrics[0] if metrics else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get latest metric for model",
|
||||
model_id=model_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get latest metric: {str(e)}")
|
||||
|
||||
async def get_metrics_by_tenant_and_product(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[ModelPerformanceMetric]:
|
||||
"""Get performance metrics for a tenant's product"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id
|
||||
},
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
order_by="measured_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get metrics by tenant and product",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get metrics: {str(e)}")
|
||||
|
||||
async def get_metrics_in_date_range(
|
||||
self,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
tenant_id: str = None,
|
||||
model_id: str = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[ModelPerformanceMetric]:
|
||||
"""Get performance metrics within a date range"""
|
||||
try:
|
||||
# Build filters
|
||||
table_name = self.model.__tablename__
|
||||
conditions = ["measured_at >= :start_date", "measured_at <= :end_date"]
|
||||
params = {"start_date": start_date, "end_date": end_date, "limit": limit, "skip": skip}
|
||||
|
||||
if tenant_id:
|
||||
conditions.append("tenant_id = :tenant_id")
|
||||
params["tenant_id"] = tenant_id
|
||||
|
||||
if model_id:
|
||||
conditions.append("model_id = :model_id")
|
||||
params["model_id"] = model_id
|
||||
|
||||
query_text = f"""
|
||||
SELECT * FROM {table_name}
|
||||
WHERE {' AND '.join(conditions)}
|
||||
ORDER BY measured_at DESC
|
||||
LIMIT :limit OFFSET :skip
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
|
||||
# Convert rows to model objects
|
||||
metrics = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
metric = self.model(**record_dict)
|
||||
metrics.append(metric)
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get metrics in date range",
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Date range query failed: {str(e)}")
|
||||
|
||||
async def get_performance_trends(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str = None,
|
||||
days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get performance trends for analysis"""
|
||||
try:
|
||||
start_date = datetime.now() - timedelta(days=days)
|
||||
end_date = datetime.now()
|
||||
|
||||
# Build query for performance trends
|
||||
conditions = ["tenant_id = :tenant_id", "measured_at >= :start_date"]
|
||||
params = {"tenant_id": tenant_id, "start_date": start_date}
|
||||
|
||||
if inventory_product_id:
|
||||
conditions.append("inventory_product_id = :inventory_product_id")
|
||||
params["inventory_product_id"] = inventory_product_id
|
||||
|
||||
query_text = f"""
|
||||
SELECT
|
||||
inventory_product_id,
|
||||
AVG(mae) as avg_mae,
|
||||
AVG(mse) as avg_mse,
|
||||
AVG(rmse) as avg_rmse,
|
||||
AVG(mape) as avg_mape,
|
||||
AVG(r2_score) as avg_r2_score,
|
||||
AVG(accuracy_percentage) as avg_accuracy,
|
||||
COUNT(*) as measurement_count,
|
||||
MIN(measured_at) as first_measurement,
|
||||
MAX(measured_at) as last_measurement
|
||||
FROM model_performance_metrics
|
||||
WHERE {' AND '.join(conditions)}
|
||||
GROUP BY inventory_product_id
|
||||
ORDER BY avg_accuracy DESC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
|
||||
trends = []
|
||||
for row in result.fetchall():
|
||||
trends.append({
|
||||
"inventory_product_id": row.inventory_product_id,
|
||||
"metrics": {
|
||||
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
|
||||
"avg_mse": float(row.avg_mse) if row.avg_mse else None,
|
||||
"avg_rmse": float(row.avg_rmse) if row.avg_rmse else None,
|
||||
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
|
||||
"avg_r2_score": float(row.avg_r2_score) if row.avg_r2_score else None,
|
||||
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
|
||||
},
|
||||
"measurement_count": int(row.measurement_count),
|
||||
"period": {
|
||||
"start": row.first_measurement.isoformat() if row.first_measurement else None,
|
||||
"end": row.last_measurement.isoformat() if row.last_measurement else None,
|
||||
"days": days
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"trends": trends,
|
||||
"period_days": days,
|
||||
"total_products": len(trends)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get performance trends",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"trends": [],
|
||||
"period_days": days,
|
||||
"total_products": 0
|
||||
}
|
||||
|
||||
async def get_best_performing_models(
|
||||
self,
|
||||
tenant_id: str,
|
||||
metric_type: str = "accuracy_percentage",
|
||||
limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get best performing models based on a specific metric"""
|
||||
try:
|
||||
# Validate metric type
|
||||
valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
|
||||
if metric_type not in valid_metrics:
|
||||
metric_type = "accuracy_percentage"
|
||||
|
||||
# For error metrics (mae, mse, rmse, mape), lower is better
|
||||
# For performance metrics (r2_score, accuracy_percentage), higher is better
|
||||
order_desc = metric_type in ["r2_score", "accuracy_percentage"]
|
||||
order_direction = "DESC" if order_desc else "ASC"
|
||||
|
||||
query_text = f"""
    SELECT * FROM (
        SELECT DISTINCT ON (inventory_product_id, model_id)
            model_id,
            inventory_product_id,
            {metric_type},
            measured_at,
            evaluation_samples
        FROM model_performance_metrics
        WHERE tenant_id = :tenant_id
        AND {metric_type} IS NOT NULL
        ORDER BY inventory_product_id, model_id, measured_at DESC
    ) latest_metrics
    ORDER BY {metric_type} {order_direction}
    LIMIT :limit
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"tenant_id": tenant_id,
|
||||
"limit": limit
|
||||
})
|
||||
|
||||
best_models = []
|
||||
for row in result.fetchall():
|
||||
best_models.append({
|
||||
"model_id": row.model_id,
|
||||
"inventory_product_id": row.inventory_product_id,
|
||||
"metric_value": float(getattr(row, metric_type)),
|
||||
"metric_type": metric_type,
|
||||
"measured_at": row.measured_at.isoformat() if row.measured_at else None,
|
||||
"evaluation_samples": int(row.evaluation_samples) if row.evaluation_samples else None
|
||||
})
|
||||
|
||||
return best_models
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get best performing models",
|
||||
tenant_id=tenant_id,
|
||||
metric_type=metric_type,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def cleanup_old_metrics(self, days_old: int = 180) -> int:
|
||||
"""Clean up old performance metrics"""
|
||||
return await self.cleanup_old_records(days_old=days_old)
|
||||
|
||||
async def get_metric_statistics(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Get performance metric statistics for a tenant"""
|
||||
try:
|
||||
# Get basic counts
|
||||
total_metrics = await self.count(filters={"tenant_id": tenant_id})
|
||||
|
||||
# Get metrics by product using raw query
|
||||
product_query = text("""
|
||||
SELECT
|
||||
inventory_product_id,
|
||||
COUNT(*) as metric_count,
|
||||
AVG(accuracy_percentage) as avg_accuracy
|
||||
FROM model_performance_metrics
|
||||
WHERE tenant_id = :tenant_id
|
||||
GROUP BY inventory_product_id
|
||||
ORDER BY avg_accuracy DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
|
||||
product_stats = {}
|
||||
|
||||
for row in result.fetchall():
|
||||
product_stats[row.inventory_product_id] = {
|
||||
"metric_count": row.metric_count,
|
||||
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
|
||||
}
|
||||
|
||||
# Recent activity (metrics in last 7 days)
|
||||
seven_days_ago = datetime.now() - timedelta(days=7)
|
||||
recent_metrics = len(await self.get_records_by_date_range(
|
||||
seven_days_ago,
|
||||
datetime.now(),
|
||||
limit=1000 # High limit to get accurate count
|
||||
))
|
||||
|
||||
return {
|
||||
"total_metrics": total_metrics,
|
||||
"products_tracked": len(product_stats),
|
||||
"metrics_by_product": product_stats,
|
||||
"recent_metrics_7d": recent_metrics
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get metric statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_metrics": 0,
|
||||
"products_tracked": 0,
|
||||
"metrics_by_product": {},
|
||||
"recent_metrics_7d": 0
|
||||
}
|
||||
|
||||
async def compare_model_performance(
|
||||
self,
|
||||
model_ids: List[str],
|
||||
metric_type: str = "accuracy_percentage"
|
||||
) -> Dict[str, Any]:
|
||||
"""Compare performance between multiple models"""
|
||||
try:
|
||||
if not model_ids or len(model_ids) < 2:
|
||||
return {"error": "At least 2 model IDs required for comparison"}
|
||||
|
||||
# Validate metric type
|
||||
valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
|
||||
if metric_type not in valid_metrics:
|
||||
metric_type = "accuracy_percentage"
|
||||
|
||||
model_ids_str = "', '".join(model_ids)
|
||||
|
||||
query_text = f"""
|
||||
SELECT
|
||||
model_id,
|
||||
inventory_product_id,
|
||||
AVG({metric_type}) as avg_metric,
|
||||
MIN({metric_type}) as min_metric,
|
||||
MAX({metric_type}) as max_metric,
|
||||
COUNT(*) as measurement_count,
|
||||
MAX(measured_at) as latest_measurement
|
||||
FROM model_performance_metrics
|
||||
WHERE model_id IN ('{model_ids_str}')
|
||||
AND {metric_type} IS NOT NULL
|
||||
GROUP BY model_id, inventory_product_id
|
||||
ORDER BY avg_metric DESC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text))
|
||||
|
||||
comparisons = []
|
||||
for row in result.fetchall():
|
||||
comparisons.append({
|
||||
"model_id": row.model_id,
|
||||
"inventory_product_id": row.inventory_product_id,
|
||||
"avg_metric": float(row.avg_metric),
|
||||
"min_metric": float(row.min_metric),
|
||||
"max_metric": float(row.max_metric),
|
||||
"measurement_count": int(row.measurement_count),
|
||||
"latest_measurement": row.latest_measurement.isoformat() if row.latest_measurement else None
|
||||
})
|
||||
|
||||
# Find best and worst performing models.
# For error metrics (mae, mse, rmse, mape) lower is better; for r2_score
# and accuracy_percentage higher is better.
if comparisons:
    higher_is_better = metric_type in ["r2_score", "accuracy_percentage"]
    if higher_is_better:
        best_model = max(comparisons, key=lambda x: x["avg_metric"])
        worst_model = min(comparisons, key=lambda x: x["avg_metric"])
    else:
        best_model = min(comparisons, key=lambda x: x["avg_metric"])
        worst_model = max(comparisons, key=lambda x: x["avg_metric"])
else:
    best_model = worst_model = None
|
||||
|
||||
return {
|
||||
"metric_type": metric_type,
|
||||
"models_compared": len(set(comp["model_id"] for comp in comparisons)),
|
||||
"comparisons": comparisons,
|
||||
"best_performing": best_model,
|
||||
"worst_performing": worst_model
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to compare model performance",
|
||||
model_ids=model_ids,
|
||||
metric_type=metric_type,
|
||||
error=str(e))
|
||||
return {"error": f"Comparison failed: {str(e)}"}
|
||||
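Finally, a hedged sketch of recording and ranking metrics with PerformanceRepository; async_session_factory and the literal IDs and metric values are assumptions for the example.

from datetime import datetime

from app.repositories.performance_repository import PerformanceRepository
from app.core.database import async_session_factory  # assumed session factory


async def record_and_rank(tenant_id: str) -> None:
    async with async_session_factory() as session:
        repo = PerformanceRepository(session)

        await repo.create_performance_metric({
            "model_id": "model-abc",             # assumed ID
            "tenant_id": tenant_id,
            "inventory_product_id": "prod-123",  # assumed ID
            "mape": 12.5,
            "accuracy_percentage": 87.5,
            "measured_at": datetime.now(),
        })

        # mape is an error metric, so the repository ranks it ascending (lower is better)
        best = await repo.get_best_performing_models(tenant_id, metric_type="mape", limit=5)
        for entry in best:
            print(entry["model_id"], entry["metric_value"])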
507
services/training/app/repositories/training_log_repository.py
Normal file
@@ -0,0 +1,507 @@
"""
Training Log Repository
Repository for model training log operations
"""

from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog

from .base import TrainingBaseRepository
from app.models.training import ModelTrainingLog
from shared.database.exceptions import DatabaseError, ValidationError

logger = structlog.get_logger()


class TrainingLogRepository(TrainingBaseRepository):
    """Repository for training log operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
        # Training logs change frequently, shorter cache time (5 minutes)
        super().__init__(ModelTrainingLog, session, cache_ttl)

async def create_training_log(self, log_data: Dict[str, Any]) -> ModelTrainingLog:
|
||||
"""Create a new training log entry"""
|
||||
try:
|
||||
# Validate log data
|
||||
validation_result = self._validate_training_data(
|
||||
log_data,
|
||||
["job_id", "tenant_id", "status"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid training log data: {validation_result['errors']}")
|
||||
|
||||
# Set default values
|
||||
if "progress" not in log_data:
|
||||
log_data["progress"] = 0
|
||||
if "current_step" not in log_data:
|
||||
log_data["current_step"] = "initializing"
|
||||
|
||||
# Create log entry
|
||||
log_entry = await self.create(log_data)
|
||||
|
||||
logger.info("Training log created",
|
||||
job_id=log_entry.job_id,
|
||||
tenant_id=log_entry.tenant_id,
|
||||
status=log_entry.status)
|
||||
|
||||
return log_entry
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create training log",
|
||||
job_id=log_data.get("job_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create training log: {str(e)}")
|
||||
|
||||
async def get_log_by_job_id(self, job_id: str) -> Optional[ModelTrainingLog]:
|
||||
"""Get training log by job ID"""
|
||||
return await self.get_by_job_id(job_id)
|
||||
|
||||
async def update_log_progress(
|
||||
self,
|
||||
job_id: str,
|
||||
progress: int,
|
||||
current_step: str = None,
|
||||
status: str = None
|
||||
) -> Optional[ModelTrainingLog]:
|
||||
"""Update training log progress"""
|
||||
try:
|
||||
update_data = {"progress": progress, "updated_at": datetime.now()}
|
||||
|
||||
if current_step:
|
||||
update_data["current_step"] = current_step
|
||||
if status:
|
||||
update_data["status"] = status
|
||||
|
||||
log_entry = await self.get_by_job_id(job_id)
|
||||
if not log_entry:
|
||||
logger.error(f"Training log not found for job {job_id}")
|
||||
return None
|
||||
|
||||
updated_log = await self.update(log_entry.id, update_data)
|
||||
|
||||
logger.debug("Training log progress updated",
|
||||
job_id=job_id,
|
||||
progress=progress,
|
||||
step=current_step)
|
||||
|
||||
return updated_log
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update training log progress",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update progress: {str(e)}")
|
||||
|
||||
async def complete_training_log(
|
||||
self,
|
||||
job_id: str,
|
||||
results: Dict[str, Any] = None,
|
||||
error_message: str = None
|
||||
) -> Optional[ModelTrainingLog]:
|
||||
"""Mark training log as completed or failed"""
|
||||
try:
|
||||
status = "failed" if error_message else "completed"
|
||||
|
||||
update_data = {
|
||||
"status": status,
|
||||
"progress": 100 if status == "completed" else None,
|
||||
"end_time": datetime.now(),
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
|
||||
if results:
|
||||
update_data["results"] = results
|
||||
if error_message:
|
||||
update_data["error_message"] = error_message
|
||||
|
||||
log_entry = await self.get_by_job_id(job_id)
|
||||
if not log_entry:
|
||||
logger.error(f"Training log not found for job {job_id}")
|
||||
return None
|
||||
|
||||
updated_log = await self.update(log_entry.id, update_data)
|
||||
|
||||
logger.info("Training log completed",
|
||||
job_id=job_id,
|
||||
status=status,
|
||||
has_results=bool(results))
|
||||
|
||||
return updated_log
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to complete training log",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to complete training log: {str(e)}")
|
||||
|
||||
async def get_logs_by_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
status: str = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[ModelTrainingLog]:
|
||||
"""Get training logs for a tenant"""
|
||||
try:
|
||||
filters = {"tenant_id": tenant_id}
|
||||
if status:
|
||||
filters["status"] = status
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get logs by tenant",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get training logs: {str(e)}")
|
||||
|
||||
async def get_active_jobs(self, tenant_id: str = None) -> List[ModelTrainingLog]:
|
||||
"""Get currently running training jobs"""
|
||||
try:
|
||||
filters = {"status": "running"}
|
||||
if tenant_id:
|
||||
filters["tenant_id"] = tenant_id
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="start_time",
|
||||
order_desc=True
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get active jobs",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get active jobs: {str(e)}")
|
||||
|
||||
async def cancel_job(self, job_id: str, cancelled_by: str = None) -> Optional[ModelTrainingLog]:
|
||||
"""Cancel a training job"""
|
||||
try:
|
||||
update_data = {
|
||||
"status": "cancelled",
|
||||
"end_time": datetime.now(),
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
|
||||
if cancelled_by:
|
||||
update_data["error_message"] = f"Cancelled by {cancelled_by}"
|
||||
|
||||
log_entry = await self.get_by_job_id(job_id)
|
||||
if not log_entry:
|
||||
logger.error(f"Training log not found for job {job_id}")
|
||||
return None
|
||||
|
||||
# Only cancel if job is still running
|
||||
if log_entry.status not in ["pending", "running"]:
|
||||
logger.warning(f"Cannot cancel job {job_id} with status {log_entry.status}")
|
||||
return log_entry
|
||||
|
||||
updated_log = await self.update(log_entry.id, update_data)
|
||||
|
||||
logger.info("Training job cancelled",
|
||||
job_id=job_id,
|
||||
cancelled_by=cancelled_by)
|
||||
|
||||
return updated_log
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cancel training job",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to cancel job: {str(e)}")
|
||||
|
||||
async def get_job_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
|
||||
"""Get training job statistics"""
|
||||
try:
|
||||
base_filters = {}
|
||||
if tenant_id:
|
||||
base_filters["tenant_id"] = tenant_id
|
||||
|
||||
# Get counts by status
|
||||
total_jobs = await self.count(filters=base_filters)
|
||||
completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
|
||||
failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
|
||||
running_jobs = await self.count(filters={**base_filters, "status": "running"})
|
||||
pending_jobs = await self.count(filters={**base_filters, "status": "pending"})
|
||||
|
||||
# Get recent activity (jobs in last 7 days)
|
||||
seven_days_ago = datetime.now() - timedelta(days=7)
|
||||
recent_jobs = len(await self.get_records_by_date_range(
|
||||
seven_days_ago,
|
||||
datetime.now(),
|
||||
limit=1000 # High limit to get accurate count
|
||||
))
|
||||
|
||||
# Calculate success rate
|
||||
finished_jobs = completed_jobs + failed_jobs
|
||||
success_rate = (completed_jobs / finished_jobs * 100) if finished_jobs > 0 else 0
|
||||
|
||||
return {
|
||||
"total_jobs": total_jobs,
|
||||
"completed_jobs": completed_jobs,
|
||||
"failed_jobs": failed_jobs,
|
||||
"running_jobs": running_jobs,
|
||||
"pending_jobs": pending_jobs,
|
||||
"cancelled_jobs": total_jobs - completed_jobs - failed_jobs - running_jobs - pending_jobs,
|
||||
"success_rate": round(success_rate, 2),
|
||||
"recent_jobs_7d": recent_jobs
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get job statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_jobs": 0,
|
||||
"completed_jobs": 0,
|
||||
"failed_jobs": 0,
|
||||
"running_jobs": 0,
|
||||
"pending_jobs": 0,
|
||||
"cancelled_jobs": 0,
|
||||
"success_rate": 0.0,
|
||||
"recent_jobs_7d": 0
|
||||
}
|
||||
|
||||
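    # --- Editor's note (illustrative, not part of the committed file) ---
    # Worked example for the success-rate formula above: with 40 completed and
    # 10 failed jobs, finished_jobs = 50 and success_rate = 40 / 50 * 100 = 80.0.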
async def cleanup_old_logs(self, days_old: int = 90) -> int:
|
||||
"""Clean up old completed/failed training logs"""
|
||||
return await self.cleanup_old_records(
|
||||
days_old=days_old,
|
||||
status_filter=None # Clean up all old records regardless of status
|
||||
)
|
||||
|
||||
async def get_job_duration_stats(self, tenant_id: str = None) -> Dict[str, Any]:
|
||||
"""Get job duration statistics"""
|
||||
try:
|
||||
# Use raw SQL for complex duration calculations
|
||||
tenant_filter = "AND tenant_id = :tenant_id" if tenant_id else ""
|
||||
params = {"tenant_id": tenant_id} if tenant_id else {}
|
||||
|
||||
query = text(f"""
|
||||
SELECT
|
||||
AVG(EXTRACT(EPOCH FROM (end_time - start_time))/60) as avg_duration_minutes,
|
||||
MIN(EXTRACT(EPOCH FROM (end_time - start_time))/60) as min_duration_minutes,
|
||||
MAX(EXTRACT(EPOCH FROM (end_time - start_time))/60) as max_duration_minutes,
|
||||
COUNT(*) as completed_jobs_with_duration
|
||||
FROM model_training_logs
|
||||
WHERE status = 'completed'
|
||||
AND start_time IS NOT NULL
|
||||
AND end_time IS NOT NULL
|
||||
{tenant_filter}
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, params)
|
||||
row = result.fetchone()
|
||||
|
||||
if row and row.completed_jobs_with_duration > 0:
|
||||
return {
|
||||
"avg_duration_minutes": round(float(row.avg_duration_minutes or 0), 2),
|
||||
"min_duration_minutes": round(float(row.min_duration_minutes or 0), 2),
|
||||
"max_duration_minutes": round(float(row.max_duration_minutes or 0), 2),
|
||||
"completed_jobs_with_duration": int(row.completed_jobs_with_duration)
|
||||
}
|
||||
|
||||
return {
|
||||
"avg_duration_minutes": 0.0,
|
||||
"min_duration_minutes": 0.0,
|
||||
"max_duration_minutes": 0.0,
|
||||
"completed_jobs_with_duration": 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get job duration statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"avg_duration_minutes": 0.0,
|
||||
"min_duration_minutes": 0.0,
|
||||
"max_duration_minutes": 0.0,
|
||||
"completed_jobs_with_duration": 0
|
||||
}
|
||||
|
||||
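    # --- Editor's note (illustrative, not part of the committed file) ---
    # EXTRACT(EPOCH FROM (end_time - start_time)) returns the duration in seconds,
    # so a job running from 10:00:00 to 10:02:30 yields 150 seconds, i.e. 150 / 60 = 2.5 minutes.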
async def get_start_time(self, job_id: str) -> Optional[datetime]:
|
||||
"""Get the start time for a training job"""
|
||||
try:
|
||||
log_entry = await self.get_by_job_id(job_id)
|
||||
if log_entry and log_entry.start_time:
|
||||
return log_entry.start_time
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get start time",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def create_job_atomic(
|
||||
self,
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
config: Dict[str, Any] = None
|
||||
) -> tuple[Optional[ModelTrainingLog], bool]:
|
||||
"""
|
||||
Atomically create a training job, respecting the unique constraint.
|
||||
|
||||
This method uses INSERT ... ON CONFLICT to handle race conditions
|
||||
when multiple pods try to create a job for the same tenant simultaneously.
|
||||
The database constraint (idx_unique_active_training_per_tenant) ensures
|
||||
only one active job per tenant can exist.
|
||||
|
||||
Args:
|
||||
job_id: Unique job identifier
|
||||
tenant_id: Tenant identifier
|
||||
config: Optional job configuration
|
||||
|
||||
Returns:
|
||||
Tuple of (job, created):
|
||||
- If created: (new_job, True)
|
||||
- If conflict (existing active job): (existing_job, False)
|
||||
- If error: raises DatabaseError
|
||||
"""
|
||||
try:
|
||||
# First, try to find an existing active job
|
||||
existing = await self.get_active_jobs(tenant_id=tenant_id)
|
||||
pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)
|
||||
|
||||
if existing or pending:
|
||||
# Return existing job
|
||||
active_job = existing[0] if existing else pending[0]
|
||||
logger.info("Found existing active job, skipping creation",
|
||||
existing_job_id=active_job.job_id,
|
||||
tenant_id=tenant_id,
|
||||
requested_job_id=job_id)
|
||||
return (active_job, False)
|
||||
|
||||
# Try to create the new job
|
||||
# If another pod created one in the meantime, the unique constraint will prevent this
|
||||
log_data = {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"status": "pending",
|
||||
"progress": 0,
|
||||
"current_step": "initializing",
|
||||
"config": config or {}
|
||||
}
|
||||
|
||||
try:
|
||||
new_job = await self.create_training_log(log_data)
|
||||
await self.session.commit()
|
||||
logger.info("Created new training job atomically",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
return (new_job, True)
|
||||
except Exception as create_error:
|
||||
error_str = str(create_error).lower()
|
||||
# Check if this is a unique constraint violation
|
||||
if "unique" in error_str or "duplicate" in error_str or "constraint" in error_str:
|
||||
await self.session.rollback()
|
||||
# Another pod created a job, fetch it
|
||||
logger.info("Unique constraint hit, fetching existing job",
|
||||
tenant_id=tenant_id,
|
||||
requested_job_id=job_id)
|
||||
existing = await self.get_active_jobs(tenant_id=tenant_id)
|
||||
pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)
|
||||
if existing or pending:
|
||||
active_job = existing[0] if existing else pending[0]
|
||||
return (active_job, False)
|
||||
# If still no job found, something went wrong
|
||||
raise DatabaseError(f"Constraint violation but no active job found: {create_error}")
|
||||
else:
|
||||
raise
|
||||
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create job atomically",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create training job atomically: {str(e)}")
|
||||
|
||||
async def recover_stale_jobs(self, stale_threshold_minutes: int = 60) -> List[ModelTrainingLog]:
|
||||
"""
|
||||
Find and mark stale running jobs as failed.
|
||||
|
||||
This is used during service startup to clean up jobs that were
|
||||
running when a pod crashed. With multiple replicas, only stale
|
||||
jobs (not updated recently) should be marked as failed.
|
||||
|
||||
Args:
|
||||
stale_threshold_minutes: Jobs not updated for this long are considered stale
|
||||
|
||||
Returns:
|
||||
List of jobs that were marked as failed
|
||||
"""
|
||||
try:
|
||||
stale_cutoff = datetime.now() - timedelta(minutes=stale_threshold_minutes)
|
||||
|
||||
# Find running jobs that haven't been updated recently
|
||||
query = text("""
|
||||
SELECT id, job_id, tenant_id, status, updated_at
|
||||
FROM model_training_logs
|
||||
WHERE status IN ('running', 'pending')
|
||||
AND updated_at < :stale_cutoff
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {"stale_cutoff": stale_cutoff})
|
||||
stale_jobs = result.fetchall()
|
||||
|
||||
recovered_jobs = []
|
||||
for row in stale_jobs:
|
||||
try:
|
||||
# Mark as failed
|
||||
update_query = text("""
|
||||
UPDATE model_training_logs
|
||||
SET status = 'failed',
|
||||
error_message = :error_msg,
|
||||
end_time = :end_time,
|
||||
updated_at = :updated_at
|
||||
WHERE id = :id AND status IN ('running', 'pending')
|
||||
""")
|
||||
|
||||
await self.session.execute(update_query, {
|
||||
"id": row.id,
|
||||
"error_msg": f"Job recovered as failed - not updated since {row.updated_at.isoformat()}. Pod may have crashed.",
|
||||
"end_time": datetime.now(),
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
|
||||
logger.warning("Recovered stale training job",
|
||||
job_id=row.job_id,
|
||||
tenant_id=str(row.tenant_id),
|
||||
last_updated=row.updated_at.isoformat() if row.updated_at else "unknown")
|
||||
|
||||
# Fetch the updated job to return
|
||||
job = await self.get_by_job_id(row.job_id)
|
||||
if job:
|
||||
recovered_jobs.append(job)
|
||||
|
||||
except Exception as job_error:
|
||||
logger.error("Failed to recover individual stale job",
|
||||
job_id=row.job_id,
|
||||
error=str(job_error))
|
||||
|
||||
if recovered_jobs:
|
||||
await self.session.commit()
|
||||
logger.info("Stale job recovery completed",
|
||||
recovered_count=len(recovered_jobs),
|
||||
stale_threshold_minutes=stale_threshold_minutes)
|
||||
|
||||
return recovered_jobs
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to recover stale jobs",
|
||||
error=str(e))
|
||||
await self.session.rollback()
|
||||
return []
|
||||
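# --- Illustrative usage sketch (editor's addition, not part of the committed file above) ---
# Shows how a request handler might combine the startup recovery and atomic creation
# helpers documented above, using only methods defined in TrainingLogRepository.

import uuid


async def start_training_for_tenant(session, tenant_id: str) -> str:
    repo = TrainingLogRepository(session)

    # Mark jobs orphaned by a crashed pod as failed before accepting new work.
    await repo.recover_stale_jobs(stale_threshold_minutes=60)

    # Atomically create (or reuse) the single active job allowed per tenant.
    job, created = await repo.create_job_atomic(
        job_id=str(uuid.uuid4()),
        tenant_id=tenant_id,
        config={"triggered_by": "api"},
    )
    return job.job_id if job else ""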
0
services/training/app/schemas/__init__.py
Normal file
384
services/training/app/schemas/training.py
Normal file
@@ -0,0 +1,384 @@
# services/training/app/schemas/training.py
"""
Complete schema definitions for training service
Includes all request/response schemas used by the API endpoints
"""

from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Any, Union, Tuple
from datetime import datetime
from enum import Enum
from uuid import UUID


class TrainingStatus(str, Enum):
|
||||
"""Training job status enumeration"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class TrainingJobRequest(BaseModel):
|
||||
"""Request schema for starting a training job"""
|
||||
products: Optional[List[str]] = Field(None, description="Specific products to train (if None, trains all)")
|
||||
start_date: Optional[datetime] = Field(None, description="Start date for training data")
|
||||
end_date: Optional[datetime] = Field(None, description="End date for training data")
|
||||
|
||||
|
||||
class SingleProductTrainingRequest(BaseModel):
|
||||
"""Request schema for training a single product"""
|
||||
start_date: Optional[datetime] = Field(None, description="Start date for training data")
|
||||
end_date: Optional[datetime] = Field(None, description="End date for training data")
|
||||
|
||||
# Prophet-specific parameters
|
||||
seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
|
||||
daily_seasonality: bool = Field(True, description="Enable daily seasonality")
|
||||
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
||||
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
||||
|
||||
# Location parameters
|
||||
bakery_location: Optional[Tuple[float, float]] = Field(None, description="Bakery coordinates (latitude, longitude)")
|
||||
|
||||
class DateRangeInfo(BaseModel):
|
||||
"""Schema for date range information"""
|
||||
start: str = Field(..., description="Start date in ISO format")
|
||||
end: str = Field(..., description="End date in ISO format")
|
||||
|
||||
class DataSummary(BaseModel):
|
||||
"""Schema for training data summary"""
|
||||
sales_records: int = Field(..., description="Number of sales records used")
|
||||
weather_records: int = Field(..., description="Number of weather records used")
|
||||
traffic_records: int = Field(..., description="Number of traffic records used")
|
||||
date_range: DateRangeInfo = Field(..., description="Date range of training data")
|
||||
data_sources_used: List[str] = Field(..., description="List of data sources used")
|
||||
constraints_applied: Dict[str, str] = Field(default_factory=dict, description="Constraints applied during data collection")
|
||||
|
||||
class ProductTrainingResult(BaseModel):
|
||||
"""Schema for individual product training results"""
|
||||
inventory_product_id: UUID = Field(..., description="Inventory product UUID")
|
||||
status: str = Field(..., description="Training status for this product")
|
||||
model_id: Optional[str] = Field(None, description="Trained model identifier")
|
||||
data_points: int = Field(..., description="Number of data points used for training")
|
||||
metrics: Optional[Dict[str, float]] = Field(None, description="Training metrics (MAE, MAPE, etc.)")
|
||||
training_time_seconds: Optional[float] = Field(None, description="Time taken to train this model")
|
||||
error_message: Optional[str] = Field(None, description="Error message if training failed")
|
||||
|
||||
class TrainingResults(BaseModel):
|
||||
"""Schema for overall training results"""
|
||||
total_products: int = Field(..., description="Total number of products")
|
||||
successful_trainings: int = Field(..., description="Number of successfully trained models")
|
||||
failed_trainings: int = Field(..., description="Number of failed trainings")
|
||||
products: List[ProductTrainingResult] = Field(..., description="Results for each product")
|
||||
overall_training_time_seconds: float = Field(..., description="Total training time")
|
||||
|
||||
class TrainingJobResponse(BaseModel):
|
||||
"""Enhanced response schema for training job with detailed results"""
|
||||
job_id: str = Field(..., description="Unique training job identifier")
|
||||
tenant_id: str = Field(..., description="Tenant identifier")
|
||||
status: TrainingStatus = Field(..., description="Overall job status")
|
||||
|
||||
# Required fields for basic response (backwards compatibility)
|
||||
message: str = Field(..., description="Status message")
|
||||
created_at: datetime = Field(..., description="Job creation timestamp")
|
||||
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
|
||||
|
||||
# New detailed fields (optional for backwards compatibility)
|
||||
training_results: Optional[TrainingResults] = Field(None, description="Detailed training results")
|
||||
data_summary: Optional[DataSummary] = Field(None, description="Summary of training data used")
|
||||
completed_at: Optional[str] = Field(None, description="Job completion timestamp in ISO format")
|
||||
|
||||
# Additional optional fields
|
||||
error_details: Optional[Dict[str, Any]] = Field(None, description="Detailed error information if failed")
|
||||
processing_metadata: Optional[Dict[str, Any]] = Field(None, description="Additional processing metadata")
|
||||
|
||||
@validator('tenant_id', 'job_id', pre=True)
|
||||
def convert_uuid_to_string(cls, v):
|
||||
"""Convert UUID objects to strings for JSON serialization"""
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class TrainingJobStatus(BaseModel):
|
||||
"""Response schema for training job status checks"""
|
||||
job_id: str = Field(..., description="Training job identifier")
|
||||
status: TrainingStatus = Field(..., description="Current job status")
|
||||
progress: int = Field(0, description="Progress percentage (0-100)")
|
||||
current_step: str = Field("", description="Current processing step")
|
||||
started_at: datetime = Field(..., description="Job start timestamp")
|
||||
completed_at: Optional[datetime] = Field(None, description="Job completion timestamp")
|
||||
products_total: int = Field(0, description="Total number of products to train")
|
||||
products_completed: int = Field(0, description="Number of products completed")
|
||||
products_failed: int = Field(0, description="Number of products that failed")
|
||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||
estimated_time_remaining_seconds: Optional[int] = Field(None, description="Estimated time remaining in seconds")
|
||||
message: Optional[str] = Field(None, description="Optional status message")
|
||||
|
||||
@validator('job_id', pre=True)
|
||||
def convert_uuid_to_string(cls, v):
|
||||
"""Convert UUID objects to strings for JSON serialization"""
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TrainingJobProgress(BaseModel):
|
||||
"""Schema for real-time training job progress updates"""
|
||||
job_id: str = Field(..., description="Training job identifier")
|
||||
status: TrainingStatus = Field(..., description="Current job status")
|
||||
progress: int = Field(0, description="Progress percentage (0-100)", ge=0, le=100)
|
||||
current_step: str = Field(..., description="Current processing step")
|
||||
current_product: Optional[str] = Field(None, description="Currently training product")
|
||||
products_completed: int = Field(0, description="Number of products completed")
|
||||
products_total: int = Field(0, description="Total number of products")
|
||||
estimated_time_remaining_minutes: Optional[int] = Field(None, description="Estimated time remaining")
|
||||
timestamp: datetime = Field(default_factory=datetime.now, description="Progress update timestamp")
|
||||
|
||||
@validator('job_id', pre=True)
|
||||
def convert_uuid_to_string(cls, v):
|
||||
"""Convert UUID objects to strings for JSON serialization"""
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class DataValidationRequest(BaseModel):
|
||||
"""Request schema for validating training data"""
|
||||
products: Optional[List[str]] = Field(None, description="Specific products to validate (if None, validates all)")
|
||||
min_data_points: int = Field(30, description="Minimum required data points per product", ge=10, le=1000)
|
||||
start_date: Optional[datetime] = Field(None, description="Start date for data validation")
|
||||
end_date: Optional[datetime] = Field(None, description="End date for data validation")
|
||||
|
||||
@validator('min_data_points')
|
||||
def validate_min_data_points(cls, v):
|
||||
if v < 10:
|
||||
raise ValueError('min_data_points must be at least 10')
|
||||
return v
|
||||
|
||||
|
||||
class DataValidationResponse(BaseModel):
|
||||
"""Response schema for data validation results"""
|
||||
is_valid: bool = Field(..., description="Whether the data is valid for training")
|
||||
issues: List[str] = Field(default_factory=list, description="List of data quality issues")
|
||||
recommendations: List[str] = Field(default_factory=list, description="Recommendations for improvement")
|
||||
estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
|
||||
products_analyzed: int = Field(..., description="Number of products analyzed")
|
||||
total_data_points: int = Field(..., description="Total data points available")
|
||||
products_with_insufficient_data: List[str] = Field(default_factory=list, description="Products with insufficient data")
|
||||
data_quality_score: float = Field(0.0, description="Overall data quality score (0-1)", ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Schema for trained model information"""
|
||||
model_id: str = Field(..., description="Unique model identifier")
|
||||
model_path: str = Field(..., description="Path to stored model")
|
||||
model_type: str = Field("prophet", description="Type of ML model")
|
||||
training_samples: int = Field(..., description="Number of training samples")
|
||||
features: List[str] = Field(..., description="List of features used")
|
||||
hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
|
||||
training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
|
||||
trained_at: datetime = Field(..., description="Training completion timestamp")
|
||||
data_period: Dict[str, str] = Field(..., description="Training data period")
|
||||
|
||||
|
||||
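# NOTE (editor): ProductTrainingResult is declared twice in this module. Classes defined
# after this point (e.g. TrainingResultsResponse) resolve to the second declaration below,
# while TrainingResults above was built against the first one.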
class ProductTrainingResult(BaseModel):
|
||||
"""Schema for individual product training result"""
|
||||
inventory_product_id: UUID = Field(..., description="Inventory product UUID")
|
||||
status: str = Field(..., description="Training status for this product")
|
||||
model_info: Optional[ModelInfo] = Field(None, description="Model information if successful")
|
||||
data_points: int = Field(..., description="Number of data points used")
|
||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||
trained_at: datetime = Field(..., description="Training completion timestamp")
|
||||
training_duration_seconds: Optional[float] = Field(None, description="Training duration in seconds")
|
||||
|
||||
|
||||
class TrainingResultsResponse(BaseModel):
|
||||
"""Response schema for complete training results"""
|
||||
job_id: str = Field(..., description="Training job identifier")
|
||||
tenant_id: str = Field(..., description="Tenant identifier")
|
||||
status: TrainingStatus = Field(..., description="Overall job status")
|
||||
products_trained: int = Field(..., description="Number of products successfully trained")
|
||||
products_failed: int = Field(..., description="Number of products that failed training")
|
||||
total_products: int = Field(..., description="Total number of products processed")
|
||||
training_results: Dict[str, ProductTrainingResult] = Field(..., description="Per-product results")
|
||||
summary: Dict[str, Any] = Field(..., description="Training summary statistics")
|
||||
completed_at: datetime = Field(..., description="Job completion timestamp")
|
||||
|
||||
@validator('tenant_id', 'job_id', pre=True)
|
||||
def convert_uuid_to_string(cls, v):
|
||||
"""Convert UUID objects to strings for JSON serialization"""
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TrainingValidationResult(BaseModel):
|
||||
"""Schema for training data validation results"""
|
||||
is_valid: bool = Field(..., description="Whether the data is valid for training")
|
||||
issues: List[str] = Field(default_factory=list, description="List of data quality issues")
|
||||
recommendations: List[str] = Field(default_factory=list, description="Recommendations for improvement")
|
||||
estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
|
||||
products_analyzed: int = Field(..., description="Number of products analyzed")
|
||||
total_data_points: int = Field(..., description="Total data points available")
|
||||
|
||||
|
||||
class TrainingMetrics(BaseModel):
|
||||
"""Schema for training performance metrics"""
|
||||
mae: float = Field(..., description="Mean Absolute Error")
|
||||
mse: float = Field(..., description="Mean Squared Error")
|
||||
rmse: float = Field(..., description="Root Mean Squared Error")
|
||||
mape: float = Field(..., description="Mean Absolute Percentage Error")
|
||||
r2_score: float = Field(..., description="R-squared score")
|
||||
mean_actual: float = Field(..., description="Mean of actual values")
|
||||
mean_predicted: float = Field(..., description="Mean of predicted values")
|
||||
|
||||
|
||||
class ExternalDataConfig(BaseModel):
|
||||
"""Configuration for external data sources"""
|
||||
weather_enabled: bool = Field(True, description="Enable weather data")
|
||||
traffic_enabled: bool = Field(True, description="Enable traffic data")
|
||||
weather_features: List[str] = Field(
|
||||
default_factory=lambda: ["temperature", "precipitation", "humidity"],
|
||||
description="Weather features to include"
|
||||
)
|
||||
traffic_features: List[str] = Field(
|
||||
default_factory=lambda: ["traffic_volume"],
|
||||
description="Traffic features to include"
|
||||
)
|
||||
|
||||
|
||||
class TrainingJobConfig(BaseModel):
|
||||
"""Complete training job configuration"""
|
||||
external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)
|
||||
prophet_params: Dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"seasonality_mode": "additive",
|
||||
"daily_seasonality": True,
|
||||
"weekly_seasonality": True,
|
||||
"yearly_seasonality": True
|
||||
},
|
||||
description="Prophet model parameters"
|
||||
)
|
||||
data_filters: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Data filtering parameters"
|
||||
)
|
||||
validation_params: Dict[str, Any] = Field(
|
||||
default_factory=lambda: {"min_data_points": 30},
|
||||
description="Data validation parameters"
|
||||
)
|
||||
|
||||
|
||||
class TrainedModelResponse(BaseModel):
|
||||
"""Response schema for trained model information"""
|
||||
model_id: str = Field(..., description="Unique model identifier")
|
||||
tenant_id: str = Field(..., description="Tenant identifier")
|
||||
inventory_product_id: UUID = Field(..., description="Inventory product UUID")
|
||||
model_type: str = Field(..., description="Type of ML model")
|
||||
model_path: str = Field(..., description="Path to stored model")
|
||||
version: int = Field(..., description="Model version")
|
||||
training_samples: int = Field(..., description="Number of training samples")
|
||||
features: List[str] = Field(..., description="List of features used")
|
||||
hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
|
||||
training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
|
||||
is_active: bool = Field(..., description="Whether model is active")
|
||||
created_at: datetime = Field(..., description="Model creation timestamp")
|
||||
data_period_start: Optional[datetime] = Field(None, description="Training data start date")
|
||||
data_period_end: Optional[datetime] = Field(None, description="Training data end date")
|
||||
|
||||
@validator('tenant_id', 'model_id', pre=True)
|
||||
def convert_uuid_to_string(cls, v):
|
||||
"""Convert UUID objects to strings for JSON serialization"""
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ModelTrainingStats(BaseModel):
|
||||
"""Schema for model training statistics"""
|
||||
total_models: int = Field(..., description="Total number of trained models")
|
||||
active_models: int = Field(..., description="Number of active models")
|
||||
last_training_date: Optional[datetime] = Field(None, description="Last training date")
|
||||
avg_training_time_minutes: float = Field(..., description="Average training time in minutes")
|
||||
success_rate: float = Field(..., description="Training success rate (0-1)")
|
||||
|
||||
|
||||
class BulkTrainingRequest(BaseModel):
|
||||
"""Request schema for bulk training operations"""
|
||||
tenant_ids: List[str] = Field(..., description="List of tenant IDs to train")
|
||||
config: TrainingJobConfig = Field(default_factory=TrainingJobConfig, description="Training configuration")
|
||||
priority: int = Field(1, description="Training priority (1-10)", ge=1, le=10)
|
||||
schedule_time: Optional[datetime] = Field(None, description="Schedule training for specific time")
|
||||
|
||||
|
||||
class TrainingScheduleResponse(BaseModel):
|
||||
"""Response schema for scheduled training jobs"""
|
||||
schedule_id: str = Field(..., description="Unique schedule identifier")
|
||||
tenant_ids: List[str] = Field(..., description="Scheduled tenant IDs")
|
||||
scheduled_time: datetime = Field(..., description="Scheduled execution time")
|
||||
status: str = Field(..., description="Schedule status")
|
||||
created_at: datetime = Field(..., description="Schedule creation timestamp")
|
||||
|
||||
|
||||
# WebSocket response schemas for real-time updates
|
||||
class TrainingProgressUpdate(BaseModel):
|
||||
"""WebSocket message for training progress updates"""
|
||||
type: str = Field("training_progress", description="Message type")
|
||||
job_id: str = Field(..., description="Training job identifier")
|
||||
progress: TrainingJobProgress = Field(..., description="Progress information")
|
||||
|
||||
|
||||
class TrainingCompletedUpdate(BaseModel):
|
||||
"""WebSocket message for training completion"""
|
||||
type: str = Field("training_completed", description="Message type")
|
||||
job_id: str = Field(..., description="Training job identifier")
|
||||
results: TrainingResultsResponse = Field(..., description="Training results")
|
||||
|
||||
|
||||
class TrainingErrorUpdate(BaseModel):
|
||||
"""WebSocket message for training errors"""
|
||||
type: str = Field("training_error", description="Message type")
|
||||
job_id: str = Field(..., description="Training job identifier")
|
||||
error: str = Field(..., description="Error message")
|
||||
timestamp: datetime = Field(default_factory=datetime.now, description="Error timestamp")
|
||||
|
||||
|
||||
class ModelMetricsResponse(BaseModel):
|
||||
"""Response schema for model performance metrics"""
|
||||
model_id: str = Field(..., description="Unique model identifier")
|
||||
accuracy: float = Field(..., description="Model accuracy (R2 score)")
|
||||
mape: float = Field(..., description="Mean Absolute Percentage Error")
|
||||
mae: float = Field(..., description="Mean Absolute Error")
|
||||
rmse: float = Field(..., description="Root Mean Square Error")
|
||||
r2_score: float = Field(..., description="R-squared score")
|
||||
training_samples: int = Field(..., description="Number of training samples used")
|
||||
features_used: List[str] = Field(..., description="List of features used in training")
|
||||
model_type: str = Field(..., description="Type of ML model")
|
||||
created_at: Optional[str] = Field(None, description="Model creation timestamp")
|
||||
last_used_at: Optional[str] = Field(None, description="Last time model was used")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
# Union type for all WebSocket messages
|
||||
TrainingWebSocketMessage = Union[
|
||||
TrainingProgressUpdate,
|
||||
TrainingCompletedUpdate,
|
||||
TrainingErrorUpdate
|
||||
]
|
||||
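# --- Illustrative usage sketch (editor's addition, not part of the committed file above) ---
# Shows how the WebSocket progress schemas defined above can be built and serialized
# before being pushed to connected clients.

def build_progress_message(job_id: str, pct: int, step: str) -> dict:
    progress = TrainingJobProgress(
        job_id=job_id,
        status=TrainingStatus.RUNNING,
        progress=pct,
        current_step=step,
    )
    update = TrainingProgressUpdate(job_id=job_id, progress=progress)
    # .dict() matches the Pydantic v1-style validators used throughout this module.
    return update.dict()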
317
services/training/app/schemas/validation.py
Normal file
@@ -0,0 +1,317 @@
"""
Comprehensive Input Validation Schemas
Ensures all API inputs are properly validated before processing
"""

from pydantic import BaseModel, Field, validator, root_validator
from typing import Optional, List, Dict, Any
from datetime import datetime, timedelta
from uuid import UUID
import re


class TrainingJobCreateRequest(BaseModel):
|
||||
"""Schema for creating a new training job"""
|
||||
|
||||
tenant_id: UUID = Field(..., description="Tenant identifier")
|
||||
start_date: Optional[str] = Field(
|
||||
None,
|
||||
description="Training data start date (ISO format: YYYY-MM-DD)",
|
||||
example="2024-01-01"
|
||||
)
|
||||
end_date: Optional[str] = Field(
|
||||
None,
|
||||
description="Training data end date (ISO format: YYYY-MM-DD)",
|
||||
example="2024-12-31"
|
||||
)
|
||||
product_ids: Optional[List[UUID]] = Field(
|
||||
None,
|
||||
description="Specific products to train (optional, trains all if not provided)"
|
||||
)
|
||||
force_retrain: bool = Field(
|
||||
default=False,
|
||||
description="Force retraining even if recent models exist"
|
||||
)
|
||||
|
||||
@validator('start_date', 'end_date')
|
||||
def validate_date_format(cls, v):
|
||||
"""Validate date is in ISO format"""
|
||||
if v is not None:
|
||||
try:
|
||||
datetime.fromisoformat(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid date format: {v}. Use YYYY-MM-DD format")
|
||||
return v
|
||||
|
||||
@root_validator
|
||||
def validate_date_range(cls, values):
|
||||
"""Validate date range is logical"""
|
||||
start = values.get('start_date')
|
||||
end = values.get('end_date')
|
||||
|
||||
if start and end:
|
||||
start_dt = datetime.fromisoformat(start)
|
||||
end_dt = datetime.fromisoformat(end)
|
||||
|
||||
if end_dt <= start_dt:
|
||||
raise ValueError("end_date must be after start_date")
|
||||
|
||||
# Check reasonable range (max 3 years)
|
||||
if (end_dt - start_dt).days > 1095:
|
||||
raise ValueError("Date range cannot exceed 3 years (1095 days)")
|
||||
|
||||
# Check not in future
|
||||
if end_dt > datetime.now():
|
||||
raise ValueError("end_date cannot be in the future")
|
||||
|
||||
return values
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"tenant_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-12-31",
|
||||
"product_ids": None,
|
||||
"force_retrain": False
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ForecastRequest(BaseModel):
|
||||
"""Schema for generating forecasts"""
|
||||
|
||||
tenant_id: UUID = Field(..., description="Tenant identifier")
|
||||
product_id: UUID = Field(..., description="Product identifier")
|
||||
forecast_days: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
le=365,
|
||||
description="Number of days to forecast (1-365)"
|
||||
)
|
||||
include_regressors: bool = Field(
|
||||
default=True,
|
||||
description="Include weather and traffic data in forecast"
|
||||
)
|
||||
confidence_level: float = Field(
|
||||
default=0.80,
|
||||
ge=0.5,
|
||||
le=0.99,
|
||||
description="Confidence interval (0.5-0.99)"
|
||||
)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"tenant_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"product_id": "223e4567-e89b-12d3-a456-426614174000",
|
||||
"forecast_days": 30,
|
||||
"include_regressors": True,
|
||||
"confidence_level": 0.80
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ModelEvaluationRequest(BaseModel):
|
||||
"""Schema for model evaluation"""
|
||||
|
||||
tenant_id: UUID = Field(..., description="Tenant identifier")
|
||||
product_id: Optional[UUID] = Field(None, description="Specific product (optional)")
|
||||
evaluation_start_date: str = Field(..., description="Evaluation period start")
|
||||
evaluation_end_date: str = Field(..., description="Evaluation period end")
|
||||
|
||||
@validator('evaluation_start_date', 'evaluation_end_date')
|
||||
def validate_date_format(cls, v):
|
||||
try:
|
||||
datetime.fromisoformat(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid date format: {v}")
|
||||
return v
|
||||
|
||||
@root_validator
|
||||
def validate_evaluation_period(cls, values):
|
||||
start = values.get('evaluation_start_date')
|
||||
end = values.get('evaluation_end_date')
|
||||
|
||||
if start and end:
|
||||
start_dt = datetime.fromisoformat(start)
|
||||
end_dt = datetime.fromisoformat(end)
|
||||
|
||||
if end_dt <= start_dt:
|
||||
raise ValueError("evaluation_end_date must be after evaluation_start_date")
|
||||
|
||||
# Minimum 7 days for meaningful evaluation
|
||||
if (end_dt - start_dt).days < 7:
|
||||
raise ValueError("Evaluation period must be at least 7 days")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class BulkTrainingRequest(BaseModel):
|
||||
"""Schema for bulk training operations"""
|
||||
|
||||
tenant_ids: List[UUID] = Field(
|
||||
...,
|
||||
min_items=1,
|
||||
max_items=100,
|
||||
description="List of tenant IDs (max 100)"
|
||||
)
|
||||
start_date: Optional[str] = Field(None, description="Common start date")
|
||||
end_date: Optional[str] = Field(None, description="Common end date")
|
||||
parallel: bool = Field(
|
||||
default=True,
|
||||
description="Execute training jobs in parallel"
|
||||
)
|
||||
|
||||
@validator('tenant_ids')
|
||||
def validate_unique_tenants(cls, v):
|
||||
if len(v) != len(set(v)):
|
||||
raise ValueError("Duplicate tenant IDs not allowed")
|
||||
return v
|
||||
|
||||
|
||||
class HyperparameterOverride(BaseModel):
|
||||
"""Schema for manual hyperparameter override"""
|
||||
|
||||
changepoint_prior_scale: Optional[float] = Field(
|
||||
None, ge=0.001, le=0.5,
|
||||
description="Flexibility of trend changes"
|
||||
)
|
||||
seasonality_prior_scale: Optional[float] = Field(
|
||||
None, ge=0.01, le=10.0,
|
||||
description="Strength of seasonality"
|
||||
)
|
||||
holidays_prior_scale: Optional[float] = Field(
|
||||
None, ge=0.01, le=10.0,
|
||||
description="Strength of holiday effects"
|
||||
)
|
||||
seasonality_mode: Optional[str] = Field(
|
||||
None,
|
||||
description="Seasonality mode",
|
||||
regex="^(additive|multiplicative)$"
|
||||
)
|
||||
daily_seasonality: Optional[bool] = None
|
||||
weekly_seasonality: Optional[bool] = None
|
||||
yearly_seasonality: Optional[bool] = None
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"changepoint_prior_scale": 0.05,
|
||||
"seasonality_prior_scale": 10.0,
|
||||
"holidays_prior_scale": 10.0,
|
||||
"seasonality_mode": "additive",
|
||||
"daily_seasonality": False,
|
||||
"weekly_seasonality": True,
|
||||
"yearly_seasonality": True
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class AdvancedTrainingRequest(TrainingJobCreateRequest):
|
||||
"""Extended training request with advanced options"""
|
||||
|
||||
hyperparameter_override: Optional[HyperparameterOverride] = Field(
|
||||
None,
|
||||
description="Manual hyperparameter settings (skips optimization)"
|
||||
)
|
||||
enable_cross_validation: bool = Field(
|
||||
default=True,
|
||||
description="Enable cross-validation during training"
|
||||
)
|
||||
cv_folds: int = Field(
|
||||
default=3,
|
||||
ge=2,
|
||||
le=10,
|
||||
description="Number of cross-validation folds"
|
||||
)
|
||||
optimization_trials: Optional[int] = Field(
|
||||
None,
|
||||
ge=5,
|
||||
le=100,
|
||||
description="Number of hyperparameter optimization trials (overrides defaults)"
|
||||
)
|
||||
save_diagnostics: bool = Field(
|
||||
default=False,
|
||||
description="Save detailed diagnostic plots and metrics"
|
||||
)
|
||||
|
||||
|
||||
class DataQualityCheckRequest(BaseModel):
|
||||
"""Schema for data quality validation"""
|
||||
|
||||
tenant_id: UUID = Field(..., description="Tenant identifier")
|
||||
start_date: str = Field(..., description="Check period start")
|
||||
end_date: str = Field(..., description="Check period end")
|
||||
product_ids: Optional[List[UUID]] = Field(
|
||||
None,
|
||||
description="Specific products to check"
|
||||
)
|
||||
include_recommendations: bool = Field(
|
||||
default=True,
|
||||
description="Include improvement recommendations"
|
||||
)
|
||||
|
||||
@validator('start_date', 'end_date')
|
||||
def validate_date(cls, v):
|
||||
try:
|
||||
datetime.fromisoformat(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid date format: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class ModelQueryParams(BaseModel):
|
||||
"""Query parameters for model listing"""
|
||||
|
||||
tenant_id: Optional[UUID] = None
|
||||
product_id: Optional[UUID] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_production: Optional[bool] = None
|
||||
model_type: Optional[str] = Field(None, regex="^(prophet|prophet_optimized|lstm|arima)$")
|
||||
min_accuracy: Optional[float] = Field(None, ge=0.0, le=1.0)
|
||||
created_after: Optional[datetime] = None
|
||||
created_before: Optional[datetime] = None
|
||||
limit: int = Field(default=100, ge=1, le=1000)
|
||||
offset: int = Field(default=0, ge=0)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"tenant_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"is_active": True,
|
||||
"is_production": True,
|
||||
"limit": 50,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def validate_uuid(value: str) -> UUID:
|
||||
"""Validate and convert string to UUID"""
|
||||
try:
|
||||
return UUID(value)
|
||||
except (ValueError, AttributeError):
|
||||
raise ValueError(f"Invalid UUID format: {value}")
|
||||
|
||||
|
||||
def validate_date_string(value: str) -> datetime:
|
||||
"""Validate and convert date string to datetime"""
|
||||
try:
|
||||
return datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid date format: {value}. Use ISO format (YYYY-MM-DD)")
|
||||
|
||||
|
||||
def validate_positive_integer(value: int, field_name: str = "value") -> int:
|
||||
"""Validate positive integer"""
|
||||
if value <= 0:
|
||||
raise ValueError(f"{field_name} must be positive, got {value}")
|
||||
return value
|
||||
|
||||
|
||||
def validate_probability(value: float, field_name: str = "value") -> float:
|
||||
"""Validate probability value (0.0-1.0)"""
|
||||
if not 0.0 <= value <= 1.0:
|
||||
raise ValueError(f"{field_name} must be between 0.0 and 1.0, got {value}")
|
||||
return value
|
||||
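# --- Illustrative usage sketch (editor's addition, not part of the committed file above) ---
# Demonstrates the date-range rules enforced by TrainingJobCreateRequest: end_date must
# come after start_date, the span is capped at 1095 days, and future end dates are rejected.

from uuid import uuid4
from pydantic import ValidationError as PydanticValidationError


def demo_training_request_validation() -> None:
    ok = TrainingJobCreateRequest(
        tenant_id=uuid4(),
        start_date="2024-01-01",
        end_date="2024-06-30",
    )
    assert ok.force_retrain is False

    try:
        TrainingJobCreateRequest(
            tenant_id=uuid4(),
            start_date="2024-06-30",
            end_date="2024-01-01",  # rejected: end_date is before start_date
        )
    except PydanticValidationError:
        pass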
16
services/training/app/services/__init__.py
Normal file
@@ -0,0 +1,16 @@
"""
Training Service Layer
Business logic services for ML training and model management
"""

from .training_service import EnhancedTrainingService
from .training_orchestrator import TrainingDataOrchestrator
from .date_alignment_service import DateAlignmentService
from .data_client import DataClient

__all__ = [
    "EnhancedTrainingService",
    "TrainingDataOrchestrator",
    "DateAlignmentService",
    "DataClient"
]
410
services/training/app/services/data_client.py
Normal file
@@ -0,0 +1,410 @@
# services/training/app/services/data_client.py
"""
Training Service Data Client
Migrated to use shared service clients with timeout configuration
"""

import structlog
from typing import Dict, Any, List, Optional
from datetime import datetime
import httpx

# Import the shared clients
from shared.clients import get_sales_client, get_external_client, get_service_clients
from app.core.config import settings
from app.core import constants as const
from app.utils.circuit_breaker import circuit_breaker_registry, CircuitBreakerError
from app.utils.retry import with_retry, HTTP_RETRY_STRATEGY, EXTERNAL_SERVICE_RETRY_STRATEGY

logger = structlog.get_logger()

class DataClient:
|
||||
"""
|
||||
Data client for training service
|
||||
Now uses the shared data service client under the hood
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Get the new specialized clients with timeout configuration
|
||||
self.sales_client = get_sales_client(settings, "training")
|
||||
self.external_client = get_external_client(settings, "training")
|
||||
|
||||
# ExternalServiceClient always has get_stored_traffic_data_for_training method
|
||||
self.supports_stored_traffic_data = True
|
||||
|
||||
# Configure timeouts for HTTP clients
|
||||
self._configure_timeouts()
|
||||
|
||||
# Initialize circuit breakers for external services
|
||||
self._init_circuit_breakers()
|
||||
|
||||
def _configure_timeouts(self):
|
||||
"""Configure appropriate timeouts for HTTP clients"""
|
||||
timeout = httpx.Timeout(
|
||||
connect=const.HTTP_TIMEOUT_DEFAULT,
|
||||
read=const.HTTP_TIMEOUT_LONG_RUNNING,
|
||||
write=const.HTTP_TIMEOUT_DEFAULT,
|
||||
pool=const.HTTP_TIMEOUT_DEFAULT
|
||||
)
|
||||
|
||||
# Apply timeout to clients if they have httpx clients
|
||||
# Note: BaseServiceClient manages its own HTTP client internally
|
||||
if hasattr(self.sales_client, 'client') and isinstance(self.sales_client.client, httpx.AsyncClient):
|
||||
self.sales_client.client.timeout = timeout
|
||||
|
||||
if hasattr(self.external_client, 'client') and isinstance(self.external_client.client, httpx.AsyncClient):
|
||||
self.external_client.client.timeout = timeout
|
||||
|
||||
def _init_circuit_breakers(self):
|
||||
"""Initialize circuit breakers for external service calls"""
|
||||
# Sales service circuit breaker
|
||||
self.sales_cb = circuit_breaker_registry.get_or_create(
|
||||
name="sales_service",
|
||||
failure_threshold=5,
|
||||
recovery_timeout=60.0,
|
||||
expected_exception=Exception
|
||||
)
|
||||
|
||||
# Weather service circuit breaker
|
||||
self.weather_cb = circuit_breaker_registry.get_or_create(
|
||||
name="weather_service",
|
||||
failure_threshold=3, # Weather is optional, fail faster
|
||||
recovery_timeout=30.0,
|
||||
expected_exception=Exception
|
||||
)
|
||||
|
||||
# Traffic service circuit breaker
|
||||
self.traffic_cb = circuit_breaker_registry.get_or_create(
|
||||
name="traffic_service",
|
||||
failure_threshold=3, # Traffic is optional, fail faster
|
||||
recovery_timeout=30.0,
|
||||
expected_exception=Exception
|
||||
)
|
||||
|
||||
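    # --- Editor's note (illustrative, not part of the committed file) ---
    # The breakers above are used via `self.<name>_cb.call(func, *args)` (see
    # fetch_sales_data below): after `failure_threshold` consecutive failures the
    # breaker opens and raises CircuitBreakerError immediately, and it is expected
    # to allow a trial call again once `recovery_timeout` seconds have elapsed
    # (standard half-open behaviour; the implementation lives in
    # app.utils.circuit_breaker). Weather and traffic use lower thresholds because
    # they are optional inputs and should fail fast.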
@with_retry(max_attempts=3, initial_delay=1.0, max_delay=10.0)
|
||||
async def _fetch_sales_data_internal(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
product_id: Optional[str] = None,
|
||||
fetch_all: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Internal method to fetch sales data with automatic retry"""
|
||||
if fetch_all:
|
||||
sales_data = await self.sales_client.get_all_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id,
|
||||
aggregation="daily",
|
||||
page_size=1000,
|
||||
max_pages=100
|
||||
)
|
||||
else:
|
||||
sales_data = await self.sales_client.get_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id,
|
||||
aggregation="daily"
|
||||
)
|
||||
sales_data = sales_data or []
|
||||
|
||||
if sales_data:
|
||||
logger.info(f"Fetched {len(sales_data)} sales records",
|
||||
tenant_id=tenant_id, product_id=product_id, fetch_all=fetch_all)
|
||||
return sales_data
|
||||
else:
|
||||
logger.error("No sales data returned", tenant_id=tenant_id)
|
||||
raise ValueError(f"No sales data available for tenant {tenant_id}")
|
||||
|
||||
async def fetch_sales_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
product_id: Optional[str] = None,
|
||||
fetch_all: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch sales data for training with circuit breaker protection
|
||||
"""
|
||||
try:
|
||||
return await self.sales_cb.call(
|
||||
self._fetch_sales_data_internal,
|
||||
tenant_id, start_date, end_date, product_id, fetch_all
|
||||
)
|
||||
except CircuitBreakerError as exc:
|
||||
logger.error("Sales service circuit breaker open", error_message=str(exc))
|
||||
raise RuntimeError(f"Sales service unavailable: {str(exc)}")
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("Error fetching sales data", tenant_id=tenant_id, error_message=str(exc))
|
||||
raise RuntimeError(f"Failed to fetch sales data: {str(exc)}")
|
||||
|
||||
async def fetch_weather_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
latitude: Optional[float] = None,
|
||||
longitude: Optional[float] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch weather data for training
|
||||
All the error handling and retry logic is now in the base client!
|
||||
"""
|
||||
try:
|
||||
weather_data = await self.external_client.get_weather_historical(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
latitude=latitude,
|
||||
longitude=longitude
|
||||
)
|
||||
|
||||
if weather_data:
|
||||
logger.info(f"Fetched {len(weather_data)} weather records",
|
||||
tenant_id=tenant_id)
|
||||
return weather_data
|
||||
else:
|
||||
logger.warning("No weather data returned, will use synthetic data", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Error fetching weather data, will use synthetic data", tenant_id=tenant_id, error_message=str(exc))
|
||||
return []
|
||||
|
||||
async def fetch_traffic_data_unified(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
latitude: Optional[float] = None,
|
||||
longitude: Optional[float] = None,
|
||||
force_refresh: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Unified traffic data fetching with intelligent cache-first strategy
|
||||
|
||||
Strategy:
|
||||
1. Check if stored/cached traffic data exists for the date range
|
||||
2. If exists and not force_refresh, return cached data
|
||||
3. If not exists or force_refresh, fetch fresh data
|
||||
4. Always return data without duplicate fetching
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
start_date: Start date string (ISO format)
|
||||
end_date: End date string (ISO format)
|
||||
latitude: Optional latitude for location-based data
|
||||
longitude: Optional longitude for location-based data
|
||||
force_refresh: If True, bypass cache and fetch fresh data
|
||||
"""
|
||||
cache_key = f"{tenant_id}_{start_date}_{end_date}_{latitude}_{longitude}"
|
||||
|
||||
try:
|
||||
# Step 1: Try to get stored/cached data first (unless force_refresh)
|
||||
if not force_refresh and self.supports_stored_traffic_data:
|
||||
logger.info("Attempting to fetch cached traffic data",
|
||||
tenant_id=tenant_id, cache_key=cache_key)
|
||||
|
||||
try:
|
||||
cached_data = await self.external_client.get_stored_traffic_data_for_training(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
latitude=latitude,
|
||||
longitude=longitude
|
||||
)
|
||||
|
||||
if cached_data and len(cached_data) > 0:
|
||||
logger.info(f"✅ Using cached traffic data: {len(cached_data)} records",
|
||||
tenant_id=tenant_id)
|
||||
return cached_data
|
||||
else:
|
||||
logger.info("No cached traffic data found, fetching fresh data",
|
||||
tenant_id=tenant_id)
|
||||
except Exception as cache_error:
|
||||
logger.warning(f"Cache fetch failed, falling back to fresh data: {cache_error}",
|
||||
tenant_id=tenant_id)
|
||||
|
||||
# Step 2: Fetch fresh data if no cache or force_refresh
|
||||
logger.info("Fetching fresh traffic data" + (" (force refresh)" if force_refresh else ""),
|
||||
tenant_id=tenant_id)
|
||||
|
||||
fresh_data = await self.external_client.get_traffic_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
latitude=latitude,
|
||||
longitude=longitude
|
||||
)
|
||||
|
||||
if fresh_data and len(fresh_data) > 0:
|
||||
logger.info(f"✅ Fetched fresh traffic data: {len(fresh_data)} records",
|
||||
tenant_id=tenant_id)
|
||||
return fresh_data
|
||||
else:
|
||||
logger.warning("No fresh traffic data available", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error in unified traffic data fetch",
|
||||
tenant_id=tenant_id, cache_key=cache_key, error_message=str(exc))
|
||||
return []
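# Illustrative usage sketch (added for documentation, not part of the original file):
# the cache-first strategy described above, followed by an explicit refresh. The tenant
# id and Madrid coordinates are hypothetical placeholders.
async def _traffic_fetch_example() -> None:
    client = DataClient()
    # Cache-first: returns stored records when present, otherwise fetches fresh data.
    traffic = await client.fetch_traffic_data_unified(
        tenant_id="tenant-demo",
        start_date="2024-01-01",
        end_date="2024-01-31",
        latitude=40.4168,
        longitude=-3.7038,
    )
    # Bypass the cache when up-to-date data is explicitly required.
    refreshed = await client.refresh_traffic_data(
        tenant_id="tenant-demo",
        start_date="2024-01-01",
        end_date="2024-01-31",
        latitude=40.4168,
        longitude=-3.7038,
    )
    print(len(traffic), len(refreshed))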
|
||||
|
||||
# Legacy methods for backward compatibility - now delegate to unified method
|
||||
async def fetch_traffic_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
latitude: Optional[float] = None,
|
||||
longitude: Optional[float] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Legacy method - delegates to unified fetcher with cache-first strategy"""
|
||||
logger.info("Legacy fetch_traffic_data called - delegating to unified method", tenant_id=tenant_id)
|
||||
return await self.fetch_traffic_data_unified(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
force_refresh=False # Use cache-first for legacy calls
|
||||
)
|
||||
|
||||
async def fetch_stored_traffic_data_for_training(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
latitude: Optional[float] = None,
|
||||
longitude: Optional[float] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Legacy method - delegates to unified fetcher with cache-first strategy"""
|
||||
logger.info("Legacy fetch_stored_traffic_data_for_training called - delegating to unified method", tenant_id=tenant_id)
|
||||
return await self.fetch_traffic_data_unified(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
force_refresh=False # Use cache-first for training calls
|
||||
)
|
||||
|
||||
async def refresh_traffic_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
latitude: Optional[float] = None,
|
||||
longitude: Optional[float] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convenience method to force refresh traffic data"""
|
||||
logger.info("Force refreshing traffic data (bypassing cache)", tenant_id=tenant_id)
|
||||
return await self.fetch_traffic_data_unified(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
force_refresh=True # Force fresh data
|
||||
)
|
||||
|
||||
async def validate_data_quality(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
sales_data: List[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate data quality before training with comprehensive checks
|
||||
"""
|
||||
try:
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
# If sales data provided, validate it directly
|
||||
if sales_data is not None:
|
||||
if not sales_data or len(sales_data) == 0:
|
||||
errors.append("No sales data available for the specified period")
|
||||
return {"is_valid": False, "errors": errors, "warnings": warnings}
|
||||
|
||||
# Check minimum data points
|
||||
if len(sales_data) < 30:
|
||||
errors.append(f"Insufficient data points: {len(sales_data)} (minimum 30 required)")
|
||||
elif len(sales_data) < 90:
|
||||
warnings.append(f"Limited data points: {len(sales_data)} (recommended 90+)")
|
||||
|
||||
# Check for required fields
|
||||
required_fields = ['date', 'inventory_product_id']
|
||||
for record in sales_data[:5]: # Sample check
|
||||
missing = [f for f in required_fields if f not in record or record[f] is None]
|
||||
if missing:
|
||||
errors.append(f"Missing required fields: {missing}")
|
||||
break
|
||||
|
||||
# Check for data quality issues
|
||||
zero_count = sum(1 for r in sales_data if r.get('quantity', 0) == 0)
|
||||
zero_ratio = zero_count / len(sales_data)
|
||||
if zero_ratio > 0.9:
|
||||
errors.append(f"Too many zero values: {zero_ratio:.1%} of records")
|
||||
elif zero_ratio > 0.7:
|
||||
warnings.append(f"High zero value ratio: {zero_ratio:.1%}")
|
||||
|
||||
# Check product diversity
|
||||
unique_products = set(r.get('inventory_product_id') for r in sales_data if r.get('inventory_product_id'))
|
||||
if len(unique_products) == 0:
|
||||
errors.append("No valid product IDs found in sales data")
|
||||
elif len(unique_products) == 1:
|
||||
warnings.append("Only one product found - consider adding more products")
|
||||
|
||||
else:
|
||||
# Fetch data for validation
|
||||
sales_data = await self.fetch_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
fetch_all=False
|
||||
)
|
||||
|
||||
if not sales_data:
|
||||
errors.append("Unable to fetch sales data for validation")
|
||||
return {"is_valid": False, "errors": errors, "warnings": warnings}
|
||||
|
||||
# Recursive call with fetched data
|
||||
return await self.validate_data_quality(
|
||||
tenant_id, start_date, end_date, sales_data
|
||||
)
|
||||
|
||||
is_valid = len(errors) == 0
|
||||
result = {
|
||||
"is_valid": is_valid,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
"data_points": len(sales_data) if sales_data else 0,
|
||||
"unique_products": len(unique_products) if sales_data else 0
|
||||
}
|
||||
|
||||
if is_valid:
|
||||
logger.info("Data validation passed",
|
||||
tenant_id=tenant_id,
|
||||
data_points=result["data_points"],
|
||||
warnings_count=len(warnings))
|
||||
else:
|
||||
logger.error("Data validation failed",
|
||||
tenant_id=tenant_id,
|
||||
errors=errors)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error validating data", tenant_id=tenant_id, error_message=str(exc))
|
||||
raise ValueError(f"Data validation failed: {str(exc)}")
|
||||
|
||||
# Global instance - same as before, but much simpler implementation
|
||||
data_client = DataClient()
|
||||
239
services/training/app/services/date_alignment_service.py
Normal file
@@ -0,0 +1,239 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
from app.utils.ml_datetime import ensure_timezone_aware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DataSourceType(Enum):
|
||||
BAKERY_SALES = "bakery_sales"
|
||||
MADRID_TRAFFIC = "madrid_traffic"
|
||||
WEATHER_FORECAST = "weather_forecast"
|
||||
|
||||
@dataclass
|
||||
class DateRange:
|
||||
start: datetime
|
||||
end: datetime
|
||||
source: DataSourceType
|
||||
|
||||
def duration_days(self) -> int:
|
||||
return (self.end - self.start).days
|
||||
|
||||
def overlaps_with(self, other: 'DateRange') -> bool:
|
||||
return self.start <= other.end and other.start <= self.end
|
||||
|
||||
@dataclass
|
||||
class AlignedDateRange:
|
||||
start: datetime
|
||||
end: datetime
|
||||
available_sources: List[DataSourceType]
|
||||
constraints: Dict[str, str]
|
||||
|
||||
class DateAlignmentService:
|
||||
"""
|
||||
Central service for managing and aligning dates across multiple data sources
|
||||
for the bakery sales prediction model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.MAX_TRAINING_RANGE_DAYS = 730 # Maximum training data range
|
||||
self.MIN_TRAINING_RANGE_DAYS = 30 # Minimum viable training data
|
||||
|
||||
def validate_and_align_dates(
|
||||
self,
|
||||
user_sales_range: DateRange,
|
||||
requested_start: Optional[datetime] = None,
|
||||
requested_end: Optional[datetime] = None
|
||||
) -> AlignedDateRange:
|
||||
"""
|
||||
Main method to validate and align dates across all data sources.
|
||||
|
||||
Args:
|
||||
user_sales_range: Date range of user-provided sales data
|
||||
requested_start: Optional explicit start date for training
|
||||
requested_end: Optional explicit end date for training
|
||||
|
||||
Returns:
|
||||
AlignedDateRange with validated start/end dates and available sources
|
||||
"""
|
||||
try:
|
||||
# Step 1: Determine the base date range
|
||||
base_range = self._determine_base_range(
|
||||
user_sales_range, requested_start, requested_end
|
||||
)
|
||||
|
||||
# Step 2: Apply data source constraints
|
||||
aligned_range = self._apply_data_source_constraints(base_range)
|
||||
|
||||
# Step 3: Validate final range
|
||||
self._validate_final_range(aligned_range)
|
||||
|
||||
logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
|
||||
return aligned_range
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Date alignment failed: {str(e)}")
|
||||
raise ValueError(f"Unable to align dates: {str(e)}")
|
||||
|
||||
def _determine_base_range(
|
||||
self,
|
||||
user_sales_range: DateRange,
|
||||
requested_start: Optional[datetime],
|
||||
requested_end: Optional[datetime]
|
||||
) -> DateRange:
|
||||
"""Determine the base date range for training."""
|
||||
|
||||
# Use explicit dates if provided
|
||||
if requested_start and requested_end:
|
||||
requested_start = ensure_timezone_aware(requested_start)
|
||||
requested_end = ensure_timezone_aware(requested_end)
|
||||
|
||||
if requested_end <= requested_start:
|
||||
raise ValueError("End date must be after start date")
|
||||
return DateRange(requested_start, requested_end, DataSourceType.BAKERY_SALES)
|
||||
|
||||
# Otherwise, use the user's sales data range as the foundation
|
||||
start_date = ensure_timezone_aware(requested_start or user_sales_range.start)
|
||||
end_date = ensure_timezone_aware(requested_end or user_sales_range.end)
|
||||
|
||||
# Ensure we don't exceed maximum training range
|
||||
if (end_date - start_date).days > self.MAX_TRAINING_RANGE_DAYS:
|
||||
start_date = end_date - timedelta(days=self.MAX_TRAINING_RANGE_DAYS)
|
||||
logger.warning(f"Limiting training range to {self.MAX_TRAINING_RANGE_DAYS} days")
|
||||
|
||||
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
|
||||
|
||||
def _apply_data_source_constraints(self, base_range: DateRange) -> AlignedDateRange:
|
||||
"""Apply constraints from each data source and determine final aligned range."""
|
||||
|
||||
current_month = datetime.now(timezone.utc).replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
available_sources = [DataSourceType.BAKERY_SALES] # Always have sales data
|
||||
constraints = {}
|
||||
|
||||
# Madrid Traffic Data Constraint
|
||||
madrid_end_date = self._get_madrid_traffic_end_date()
|
||||
if base_range.end > madrid_end_date:
|
||||
# If requested end date is in current month, adjust it
|
||||
new_end = madrid_end_date
|
||||
constraints["madrid_traffic"] = f"Adjusted end date to {new_end.date()} (latest available traffic data)"
|
||||
logger.info(f"Madrid traffic constraint: end date adjusted to {new_end.date()}")
|
||||
else:
|
||||
new_end = base_range.end
|
||||
available_sources.append(DataSourceType.MADRID_TRAFFIC)
|
||||
|
||||
# Weather Forecast Constraint
|
||||
# Weather data available from yesterday backward
|
||||
weather_end_date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
|
||||
if base_range.end > weather_end_date:
|
||||
if new_end > weather_end_date:
|
||||
new_end = weather_end_date
|
||||
constraints["weather"] = f"Adjusted end date to {new_end.date()} (latest available weather data)"
|
||||
logger.info(f"Weather constraint: end date adjusted to {new_end.date()}")
|
||||
|
||||
if new_end >= base_range.start:
|
||||
available_sources.append(DataSourceType.WEATHER_FORECAST)
|
||||
|
||||
# Ensure minimum training period
|
||||
final_start = base_range.start
|
||||
if (new_end - final_start).days < self.MIN_TRAINING_RANGE_DAYS:
|
||||
final_start = new_end - timedelta(days=self.MIN_TRAINING_RANGE_DAYS)
|
||||
constraints["minimum_period"] = f"Adjusted start date to ensure {self.MIN_TRAINING_RANGE_DAYS} day minimum training period"
|
||||
logger.info(f"Minimum period constraint: start date adjusted to {final_start.date()}")
|
||||
|
||||
return AlignedDateRange(
|
||||
start=final_start,
|
||||
end=new_end,
|
||||
available_sources=available_sources,
|
||||
constraints=constraints
|
||||
)
|
||||
|
||||
def _get_madrid_traffic_end_date(self) -> datetime:
|
||||
"""
|
||||
Get the latest available date for Madrid traffic data.
|
||||
Data for current month is not available until the following month.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Data up to the previous month is available
|
||||
# Go to first day of current month, then subtract 1 day to get last day of previous month
|
||||
last_day_of_previous_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
|
||||
|
||||
return last_day_of_previous_month
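# Worked example (illustrative, not part of the original file): for a "now" of
# 2024-03-15 UTC, replace(day=1) yields 2024-03-01 00:00 UTC and subtracting one day
# yields 2024-02-29, the last day of the previous month.
def _madrid_end_date_example() -> datetime:
    now = datetime(2024, 3, 15, tzinfo=timezone.utc)
    return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)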
|
||||
|
||||
def _validate_final_range(self, aligned_range: AlignedDateRange) -> None:
|
||||
"""Validate the final aligned date range."""
|
||||
|
||||
if aligned_range.start >= aligned_range.end:
|
||||
raise ValueError("Invalid date range: start date must be before end date")
|
||||
|
||||
duration = (aligned_range.end - aligned_range.start).days
|
||||
|
||||
if duration < self.MIN_TRAINING_RANGE_DAYS:
|
||||
raise ValueError(f"Insufficient training data: {duration} days (minimum: {self.MIN_TRAINING_RANGE_DAYS})")
|
||||
|
||||
if duration > self.MAX_TRAINING_RANGE_DAYS:
|
||||
raise ValueError(f"Training period too long: {duration} days (maximum: {self.MAX_TRAINING_RANGE_DAYS})")
|
||||
|
||||
# Ensure we have at least sales data
|
||||
if DataSourceType.BAKERY_SALES not in aligned_range.available_sources:
|
||||
raise ValueError("No sales data available for the aligned date range")
|
||||
|
||||
def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
|
||||
"""
|
||||
Generate a data collection plan based on the aligned date range.
|
||||
|
||||
Returns:
|
||||
Dictionary with collection plans for each data source
|
||||
"""
|
||||
plan = {}
|
||||
|
||||
# Bakery Sales Data
|
||||
if DataSourceType.BAKERY_SALES in aligned_range.available_sources:
|
||||
plan["sales_data"] = {
|
||||
"start_date": aligned_range.start,
|
||||
"end_date": aligned_range.end,
|
||||
"source": "user_upload",
|
||||
"required": True
|
||||
}
|
||||
|
||||
# Madrid Traffic Data
|
||||
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
|
||||
plan["traffic_data"] = {
|
||||
"start_date": aligned_range.start,
|
||||
"end_date": aligned_range.end,
|
||||
"source": "madrid_opendata",
|
||||
"required": False,
|
||||
"constraint": "Cannot request current month data"
|
||||
}
|
||||
|
||||
# Weather Data
|
||||
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
|
||||
plan["weather_data"] = {
|
||||
"start_date": aligned_range.start,
|
||||
"end_date": aligned_range.end,
|
||||
"source": "aemet_api",
|
||||
"required": False,
|
||||
"constraint": "Available from yesterday backward"
|
||||
}
|
||||
|
||||
return plan
|
||||
|
||||
def check_madrid_current_month_constraint(self, end_date: datetime) -> bool:
|
||||
"""
|
||||
Check if the end date violates the Madrid Open Data current month constraint.
|
||||
|
||||
Args:
|
||||
end_date: The requested end date
|
||||
|
||||
Returns:
|
||||
True if the constraint is violated (end date is in current month)
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
current_month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
# Debug logging
|
||||
logger.info(f"🔍 Madrid constraint check: end_date={end_date}, current_month_start={current_month_start}, violation={end_date >= current_month_start}")
|
||||
|
||||
return end_date >= current_month_start
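# Illustrative usage sketch (added for documentation, not part of the original file):
# aligning a hypothetical sales range and building the collection plan with the
# service defined above. Dates are placeholders.
def _date_alignment_example() -> None:
    service = DateAlignmentService()
    sales_range = DateRange(
        start=datetime(2023, 6, 1, tzinfo=timezone.utc),
        end=datetime(2024, 2, 15, tzinfo=timezone.utc),
        source=DataSourceType.BAKERY_SALES,
    )
    aligned = service.validate_and_align_dates(user_sales_range=sales_range)
    plan = service.get_data_collection_plan(aligned)
    logger.info(
        f"Aligned {aligned.start.date()} -> {aligned.end.date()}, "
        f"sources={[s.value for s in aligned.available_sources]}, plan={list(plan)}"
    )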
|
||||
120
services/training/app/services/progress_tracker.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Training Progress Tracker
|
||||
Manages progress calculation for parallel product training (20-80% range)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import structlog
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.services.training_events import publish_product_training_completed
|
||||
from app.utils.time_estimation import calculate_estimated_completion_time
|
||||
from app.core.training_constants import (
|
||||
PROGRESS_TRAINING_RANGE_START,
|
||||
PROGRESS_TRAINING_RANGE_END,
|
||||
PROGRESS_TRAINING_RANGE_WIDTH
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ParallelProductProgressTracker:
|
||||
"""
|
||||
Tracks parallel product training progress and emits events.
|
||||
|
||||
For N products training in parallel:
|
||||
- Each product completion contributes 60/N% to overall progress
|
||||
- Progress range: 20% (after data analysis) to 80% (before completion)
|
||||
- Thread-safe for concurrent product trainings
|
||||
- Calculates time estimates based on elapsed time and progress
|
||||
"""
|
||||
|
||||
def __init__(self, job_id: str, tenant_id: str, total_products: int):
|
||||
self.job_id = job_id
|
||||
self.tenant_id = tenant_id
|
||||
self.total_products = max(total_products, 1) # Ensure at least 1 to avoid division by zero
|
||||
self.products_completed = 0
|
||||
self._lock = asyncio.Lock()
|
||||
self.start_time = datetime.now(timezone.utc)
|
||||
|
||||
# Calculate progress increment per product
|
||||
# Training range (from PROGRESS_TRAINING_RANGE_START to PROGRESS_TRAINING_RANGE_END) divided by number of products
|
||||
self.progress_per_product = PROGRESS_TRAINING_RANGE_WIDTH / self.total_products if self.total_products > 0 else 0
|
||||
|
||||
if total_products == 0:
|
||||
logger.warning("ParallelProductProgressTracker initialized with zero products",
|
||||
job_id=job_id)
|
||||
|
||||
logger.info("ParallelProductProgressTracker initialized",
|
||||
job_id=job_id,
|
||||
total_products=self.total_products,
|
||||
progress_per_product=f"{self.progress_per_product:.2f}%")
|
||||
|
||||
async def mark_product_completed(self, product_name: str) -> int:
|
||||
"""
|
||||
Mark a product as completed and publish event with time estimates.
|
||||
Returns the current overall progress percentage.
|
||||
"""
|
||||
async with self._lock:
|
||||
self.products_completed += 1
|
||||
current_progress = self.products_completed
|
||||
|
||||
# Calculate time estimates based on elapsed time and progress
|
||||
elapsed_seconds = (datetime.now(timezone.utc) - self.start_time).total_seconds()
|
||||
products_remaining = self.total_products - current_progress
|
||||
|
||||
# Calculate estimated time remaining
|
||||
# Avg time per product * remaining products
|
||||
estimated_time_remaining_seconds = None
|
||||
estimated_completion_time = None
|
||||
|
||||
if current_progress > 0 and products_remaining > 0:
|
||||
avg_time_per_product = elapsed_seconds / current_progress
|
||||
estimated_time_remaining_seconds = int(avg_time_per_product * products_remaining)
|
||||
|
||||
# Calculate estimated completion time
|
||||
estimated_duration_minutes = estimated_time_remaining_seconds / 60
|
||||
completion_datetime = calculate_estimated_completion_time(estimated_duration_minutes)
|
||||
estimated_completion_time = completion_datetime.isoformat()
|
||||
|
||||
# Publish product completion event with time estimates
|
||||
await publish_product_training_completed(
|
||||
job_id=self.job_id,
|
||||
tenant_id=self.tenant_id,
|
||||
product_name=product_name,
|
||||
products_completed=current_progress,
|
||||
total_products=self.total_products,
|
||||
estimated_time_remaining_seconds=estimated_time_remaining_seconds,
|
||||
estimated_completion_time=estimated_completion_time
|
||||
)
|
||||
|
||||
# Calculate overall progress (PROGRESS_TRAINING_RANGE_START% base + progress from completed products)
|
||||
# The frontend/consumer performs the same calculation from the event data; it is repeated here for logging and the return value
|
||||
if self.total_products > 0:
|
||||
overall_progress = PROGRESS_TRAINING_RANGE_START + int((current_progress / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
|
||||
else:
|
||||
overall_progress = PROGRESS_TRAINING_RANGE_START
|
||||
|
||||
logger.info("Product training completed",
|
||||
job_id=self.job_id,
|
||||
product_name=product_name,
|
||||
products_completed=current_progress,
|
||||
total_products=self.total_products,
|
||||
overall_progress=overall_progress,
|
||||
estimated_time_remaining_seconds=estimated_time_remaining_seconds)
|
||||
|
||||
return overall_progress
|
||||
|
||||
def get_progress(self) -> dict:
|
||||
"""Get current progress summary"""
|
||||
if self.total_products > 0:
|
||||
progress_percentage = PROGRESS_TRAINING_RANGE_START + int((self.products_completed / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
|
||||
else:
|
||||
progress_percentage = PROGRESS_TRAINING_RANGE_START
|
||||
|
||||
return {
|
||||
"products_completed": self.products_completed,
|
||||
"total_products": self.total_products,
|
||||
"progress_percentage": progress_percentage
|
||||
}
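# Illustrative usage sketch (added for documentation, not part of the original file):
# with 5 products each completion adds 60 / 5 = 12 points, so after 3 completions the
# overall progress is 20 + 36 = 56%. Job and tenant ids are hypothetical; note that
# mark_product_completed publishes a RabbitMQ event, so this is documentation rather
# than a unit test.
async def _progress_tracker_example() -> None:
    tracker = ParallelProductProgressTracker(
        job_id="job-demo", tenant_id="tenant-demo", total_products=5
    )
    for product in ("baguette", "croissant", "sourdough"):
        overall = await tracker.mark_product_completed(product)
        logger.info("example_progress", overall_progress=overall)
    logger.info("example_summary", **tracker.get_progress())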
|
||||
339
services/training/app/services/tenant_deletion_service.py
Normal file
@@ -0,0 +1,339 @@
|
||||
# services/training/app/services/tenant_deletion_service.py
|
||||
"""
|
||||
Tenant Data Deletion Service for Training Service
|
||||
Handles deletion of all training-related data for a tenant
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
from sqlalchemy import select, func, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import structlog
|
||||
|
||||
from shared.services.tenant_deletion import (
|
||||
BaseTenantDataDeletionService,
|
||||
TenantDataDeletionResult
|
||||
)
|
||||
from app.models import (
|
||||
TrainedModel,
|
||||
ModelTrainingLog,
|
||||
ModelPerformanceMetric,
|
||||
TrainingJobQueue,
|
||||
ModelArtifact,
|
||||
AuditLog
|
||||
)
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class TrainingTenantDeletionService(BaseTenantDataDeletionService):
|
||||
"""Service for deleting all training-related data for a tenant"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.service_name = "training"
|
||||
|
||||
async def get_tenant_data_preview(self, tenant_id: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get counts of what would be deleted for a tenant (dry-run)
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID to preview deletion for
|
||||
|
||||
Returns:
|
||||
Dictionary with entity names and their counts
|
||||
"""
|
||||
logger.info("training.tenant_deletion.preview", tenant_id=tenant_id)
|
||||
preview = {}
|
||||
|
||||
try:
|
||||
# Count trained models
|
||||
model_count = await self.db.scalar(
|
||||
select(func.count(TrainedModel.id)).where(
|
||||
TrainedModel.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
preview["trained_models"] = model_count or 0
|
||||
|
||||
# Count model artifacts
|
||||
artifact_count = await self.db.scalar(
|
||||
select(func.count(ModelArtifact.id)).where(
|
||||
ModelArtifact.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
preview["model_artifacts"] = artifact_count or 0
|
||||
|
||||
# Count training logs
|
||||
log_count = await self.db.scalar(
|
||||
select(func.count(ModelTrainingLog.id)).where(
|
||||
ModelTrainingLog.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
preview["model_training_logs"] = log_count or 0
|
||||
|
||||
# Count performance metrics
|
||||
metric_count = await self.db.scalar(
|
||||
select(func.count(ModelPerformanceMetric.id)).where(
|
||||
ModelPerformanceMetric.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
preview["model_performance_metrics"] = metric_count or 0
|
||||
|
||||
# Count training job queue entries
|
||||
queue_count = await self.db.scalar(
|
||||
select(func.count(TrainingJobQueue.id)).where(
|
||||
TrainingJobQueue.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
preview["training_job_queue"] = queue_count or 0
|
||||
|
||||
# Count audit logs
|
||||
audit_count = await self.db.scalar(
|
||||
select(func.count(AuditLog.id)).where(
|
||||
AuditLog.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
preview["audit_logs"] = audit_count or 0
|
||||
|
||||
logger.info(
|
||||
"training.tenant_deletion.preview_complete",
|
||||
tenant_id=tenant_id,
|
||||
preview=preview
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"training.tenant_deletion.preview_error",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
return preview
|
||||
|
||||
async def delete_tenant_data(self, tenant_id: str) -> TenantDataDeletionResult:
|
||||
"""
|
||||
Permanently delete all training data for a tenant
|
||||
|
||||
Deletion order:
|
||||
1. ModelArtifact (references models)
|
||||
2. ModelPerformanceMetric (references models)
|
||||
3. ModelTrainingLog (independent job logs)
|
||||
4. TrainingJobQueue (independent queue entries)
|
||||
5. TrainedModel (parent model records)
|
||||
6. AuditLog (independent)
|
||||
|
||||
Note: This also deletes physical model files from disk/storage
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID to delete data for
|
||||
|
||||
Returns:
|
||||
TenantDataDeletionResult with deletion counts and any errors
|
||||
"""
|
||||
logger.info("training.tenant_deletion.started", tenant_id=tenant_id)
|
||||
result = TenantDataDeletionResult(tenant_id=tenant_id, service_name=self.service_name)
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
# Step 1: Delete model artifacts (references models)
|
||||
logger.info("training.tenant_deletion.deleting_artifacts", tenant_id=tenant_id)
|
||||
|
||||
# Delete physical files from storage before deleting DB records
|
||||
artifacts = await self.db.execute(
|
||||
select(ModelArtifact).where(ModelArtifact.tenant_id == tenant_id)
|
||||
)
|
||||
deleted_files = 0
|
||||
failed_files = 0
|
||||
for artifact in artifacts.scalars():
|
||||
try:
|
||||
if artifact.file_path and os.path.exists(artifact.file_path):
|
||||
os.remove(artifact.file_path)
|
||||
deleted_files += 1
|
||||
logger.info("Deleted artifact file",
|
||||
path=artifact.file_path,
|
||||
artifact_id=artifact.id)
|
||||
except Exception as e:
|
||||
failed_files += 1
|
||||
logger.warning("Failed to delete artifact file",
|
||||
path=artifact.file_path,
|
||||
artifact_id=artifact.id if hasattr(artifact, 'id') else 'unknown',
|
||||
error=str(e))
|
||||
|
||||
logger.info("Artifact files deletion complete",
|
||||
deleted_files=deleted_files,
|
||||
failed_files=failed_files)
|
||||
|
||||
# Now delete DB records
|
||||
artifacts_result = await self.db.execute(
|
||||
delete(ModelArtifact).where(
|
||||
ModelArtifact.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result.deleted_counts["model_artifacts"] = artifacts_result.rowcount
|
||||
result.deleted_counts["artifact_files_deleted"] = deleted_files
|
||||
result.deleted_counts["artifact_files_failed"] = failed_files
|
||||
logger.info(
|
||||
"training.tenant_deletion.artifacts_deleted",
|
||||
tenant_id=tenant_id,
|
||||
count=artifacts_result.rowcount
|
||||
)
|
||||
|
||||
# Step 2: Delete model performance metrics
|
||||
logger.info("training.tenant_deletion.deleting_metrics", tenant_id=tenant_id)
|
||||
metrics_result = await self.db.execute(
|
||||
delete(ModelPerformanceMetric).where(
|
||||
ModelPerformanceMetric.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result.deleted_counts["model_performance_metrics"] = metrics_result.rowcount
|
||||
logger.info(
|
||||
"training.tenant_deletion.metrics_deleted",
|
||||
tenant_id=tenant_id,
|
||||
count=metrics_result.rowcount
|
||||
)
|
||||
|
||||
# Step 3: Delete training logs
|
||||
logger.info("training.tenant_deletion.deleting_logs", tenant_id=tenant_id)
|
||||
logs_result = await self.db.execute(
|
||||
delete(ModelTrainingLog).where(
|
||||
ModelTrainingLog.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result.deleted_counts["model_training_logs"] = logs_result.rowcount
|
||||
logger.info(
|
||||
"training.tenant_deletion.logs_deleted",
|
||||
tenant_id=tenant_id,
|
||||
count=logs_result.rowcount
|
||||
)
|
||||
|
||||
# Step 4: Delete training job queue entries
|
||||
logger.info("training.tenant_deletion.deleting_queue", tenant_id=tenant_id)
|
||||
queue_result = await self.db.execute(
|
||||
delete(TrainingJobQueue).where(
|
||||
TrainingJobQueue.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result.deleted_counts["training_job_queue"] = queue_result.rowcount
|
||||
logger.info(
|
||||
"training.tenant_deletion.queue_deleted",
|
||||
tenant_id=tenant_id,
|
||||
count=queue_result.rowcount
|
||||
)
|
||||
|
||||
# Step 5: Delete trained models (parent records)
|
||||
logger.info("training.tenant_deletion.deleting_models", tenant_id=tenant_id)
|
||||
|
||||
# Delete physical model files (.pkl) before deleting DB records
|
||||
models = await self.db.execute(
|
||||
select(TrainedModel).where(TrainedModel.tenant_id == tenant_id)
|
||||
)
|
||||
deleted_model_files = 0
|
||||
failed_model_files = 0
|
||||
for model in models.scalars():
|
||||
try:
|
||||
# Delete .pkl file
|
||||
if hasattr(model, 'model_path') and model.model_path and os.path.exists(model.model_path):
|
||||
os.remove(model.model_path)
|
||||
deleted_model_files += 1
|
||||
logger.info("Deleted model file",
|
||||
path=model.model_path,
|
||||
model_id=model.id)
|
||||
# Delete model_file_path if it exists
|
||||
if hasattr(model, 'model_file_path') and model.model_file_path and os.path.exists(model.model_file_path):
|
||||
os.remove(model.model_file_path)
|
||||
deleted_model_files += 1
|
||||
logger.info("Deleted model file",
|
||||
path=model.model_file_path,
|
||||
model_id=model.id)
|
||||
# Delete metadata file if exists
|
||||
if hasattr(model, 'metadata_path') and model.metadata_path and os.path.exists(model.metadata_path):
|
||||
os.remove(model.metadata_path)
|
||||
logger.info("Deleted metadata file",
|
||||
path=model.metadata_path,
|
||||
model_id=model.id)
|
||||
except Exception as e:
|
||||
failed_model_files += 1
|
||||
logger.warning("Failed to delete model file",
|
||||
path=getattr(model, 'model_path', getattr(model, 'model_file_path', 'unknown')),
|
||||
model_id=model.id if hasattr(model, 'id') else 'unknown',
|
||||
error=str(e))
|
||||
|
||||
logger.info("Model files deletion complete",
|
||||
deleted_files=deleted_model_files,
|
||||
failed_files=failed_model_files)
|
||||
|
||||
# Now delete DB records
|
||||
models_result = await self.db.execute(
|
||||
delete(TrainedModel).where(
|
||||
TrainedModel.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result.deleted_counts["trained_models"] = models_result.rowcount
|
||||
result.deleted_counts["model_files_deleted"] = deleted_model_files
|
||||
result.deleted_counts["model_files_failed"] = failed_model_files
|
||||
logger.info(
|
||||
"training.tenant_deletion.models_deleted",
|
||||
tenant_id=tenant_id,
|
||||
count=models_result.rowcount
|
||||
)
|
||||
|
||||
# Step 6: Delete audit logs
|
||||
logger.info("training.tenant_deletion.deleting_audit_logs", tenant_id=tenant_id)
|
||||
audit_result = await self.db.execute(
|
||||
delete(AuditLog).where(
|
||||
AuditLog.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result.deleted_counts["audit_logs"] = audit_result.rowcount
|
||||
logger.info(
|
||||
"training.tenant_deletion.audit_logs_deleted",
|
||||
tenant_id=tenant_id,
|
||||
count=audit_result.rowcount
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
await self.db.commit()
|
||||
|
||||
# Calculate total deleted
|
||||
total_deleted = sum(result.deleted_counts.values())
|
||||
|
||||
logger.info(
|
||||
"training.tenant_deletion.completed",
|
||||
tenant_id=tenant_id,
|
||||
total_deleted=total_deleted,
|
||||
breakdown=result.deleted_counts,
|
||||
note="Physical model files should be cleaned up separately"
|
||||
)
|
||||
|
||||
result.success = True
|
||||
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
error_msg = f"Failed to delete training data for tenant {tenant_id}: {str(e)}"
|
||||
logger.error(
|
||||
"training.tenant_deletion.failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
result.errors.append(error_msg)
|
||||
result.success = False
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_training_tenant_deletion_service(
|
||||
db: AsyncSession
|
||||
) -> TrainingTenantDeletionService:
|
||||
"""
|
||||
Factory function to create TrainingTenantDeletionService instance
|
||||
|
||||
Args:
|
||||
db: AsyncSession database session
|
||||
|
||||
Returns:
|
||||
TrainingTenantDeletionService instance
|
||||
"""
|
||||
return TrainingTenantDeletionService(db)
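# Illustrative usage sketch (added for documentation, not part of the original file):
# preview first, then delete. `session` is assumed to be an AsyncSession obtained from
# this service's normal session factory.
async def _tenant_deletion_example(session: AsyncSession, tenant_id: str) -> None:
    service = get_training_tenant_deletion_service(session)
    preview = await service.get_tenant_data_preview(tenant_id)
    logger.info("training.tenant_deletion.example_preview", tenant_id=tenant_id, preview=preview)
    if sum(preview.values()) > 0:
        result = await service.delete_tenant_data(tenant_id)
        logger.info(
            "training.tenant_deletion.example_result",
            success=result.success,
            deleted=result.deleted_counts,
            errors=result.errors,
        )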
|
||||
330
services/training/app/services/training_events.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
Training Progress Events Publisher
|
||||
Simple, clean event publisher for the 4 main training steps
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from shared.messaging import RabbitMQClient
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Single global publisher instance
|
||||
training_publisher = RabbitMQClient(settings.RABBITMQ_URL, "training-service")
|
||||
|
||||
|
||||
async def setup_messaging():
|
||||
"""Initialize messaging"""
|
||||
success = await training_publisher.connect()
|
||||
if success:
|
||||
logger.info("Training messaging initialized")
|
||||
else:
|
||||
logger.warning("Training messaging failed to initialize")
|
||||
return success
|
||||
|
||||
|
||||
async def cleanup_messaging():
|
||||
"""Cleanup messaging"""
|
||||
await training_publisher.disconnect()
|
||||
logger.info("Training messaging cleaned up")
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 4 MAIN TRAINING PROGRESS EVENTS
|
||||
# ==========================================
|
||||
|
||||
async def publish_training_started(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
total_products: int,
|
||||
estimated_duration_minutes: Optional[int] = None,
|
||||
estimated_completion_time: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Event 1: Training Started (0% progress)
|
||||
|
||||
Args:
|
||||
job_id: Training job identifier
|
||||
tenant_id: Tenant identifier
|
||||
total_products: Number of products to train
|
||||
estimated_duration_minutes: Estimated time to completion in minutes
|
||||
estimated_completion_time: ISO timestamp of estimated completion
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.started",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"progress": 0,
|
||||
"current_step": "Training Started",
|
||||
"step_details": f"Starting training for {total_products} products",
|
||||
"total_products": total_products,
|
||||
"estimated_duration_minutes": estimated_duration_minutes,
|
||||
"estimated_completion_time": estimated_completion_time,
|
||||
"estimated_time_remaining_seconds": estimated_duration_minutes * 60 if estimated_duration_minutes else None
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.started",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published training started event",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
total_products=total_products,
|
||||
estimated_duration_minutes=estimated_duration_minutes)
|
||||
else:
|
||||
logger.error("Failed to publish training started event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def publish_data_analysis(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
analysis_details: Optional[str] = None,
|
||||
estimated_time_remaining_seconds: Optional[int] = None,
|
||||
estimated_completion_time: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Event 2: Data Analysis (20% progress)
|
||||
|
||||
Args:
|
||||
job_id: Training job identifier
|
||||
tenant_id: Tenant identifier
|
||||
analysis_details: Details about the analysis
|
||||
estimated_time_remaining_seconds: Estimated time remaining in seconds
|
||||
estimated_completion_time: ISO timestamp of estimated completion
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.progress",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"progress": 20,
|
||||
"current_step": "Data Analysis",
|
||||
"step_details": analysis_details or "Analyzing sales, weather, and traffic data",
|
||||
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
|
||||
"estimated_completion_time": estimated_completion_time
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.progress",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published data analysis event",
|
||||
job_id=job_id,
|
||||
progress=20)
|
||||
else:
|
||||
logger.error("Failed to publish data analysis event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def publish_training_progress(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
progress: int,
|
||||
current_step: str,
|
||||
step_details: Optional[str] = None,
|
||||
estimated_time_remaining_seconds: Optional[int] = None,
|
||||
estimated_completion_time: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Generic Training Progress Event (for any progress percentage)
|
||||
|
||||
Args:
|
||||
job_id: Training job identifier
|
||||
tenant_id: Tenant identifier
|
||||
progress: Progress percentage (0-100)
|
||||
current_step: Current step name
|
||||
step_details: Details about the current step
|
||||
estimated_time_remaining_seconds: Estimated time remaining in seconds
|
||||
estimated_completion_time: ISO timestamp of estimated completion
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.progress",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"progress": progress,
|
||||
"current_step": current_step,
|
||||
"step_details": step_details or current_step,
|
||||
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
|
||||
"estimated_completion_time": estimated_completion_time
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.progress",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published training progress event",
|
||||
job_id=job_id,
|
||||
progress=progress,
|
||||
current_step=current_step)
|
||||
else:
|
||||
logger.error("Failed to publish training progress event",
|
||||
job_id=job_id,
|
||||
progress=progress)
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def publish_product_training_completed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
products_completed: int,
|
||||
total_products: int,
|
||||
estimated_time_remaining_seconds: Optional[int] = None,
|
||||
estimated_completion_time: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Event 3: Product Training Completed (contributes to 20-80% progress)
|
||||
|
||||
This event is published each time a product training completes.
|
||||
The frontend/consumer will calculate the progress as:
|
||||
progress = 20 + (products_completed / total_products) * 60
|
||||
|
||||
Args:
|
||||
job_id: Training job identifier
|
||||
tenant_id: Tenant identifier
|
||||
product_name: Name of the product that was trained
|
||||
products_completed: Number of products completed so far
|
||||
total_products: Total number of products
|
||||
estimated_time_remaining_seconds: Estimated time remaining in seconds
|
||||
estimated_completion_time: ISO timestamp of estimated completion
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.product.completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"products_completed": products_completed,
|
||||
"total_products": total_products,
|
||||
"current_step": "Model Training",
|
||||
"step_details": f"Completed training for {product_name} ({products_completed}/{total_products})",
|
||||
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
|
||||
"estimated_completion_time": estimated_completion_time
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.product.completed",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published product training completed event",
|
||||
job_id=job_id,
|
||||
product_name=product_name,
|
||||
products_completed=products_completed,
|
||||
total_products=total_products)
|
||||
else:
|
||||
logger.error("Failed to publish product training completed event",
|
||||
job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def publish_training_completed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
successful_trainings: int,
|
||||
failed_trainings: int,
|
||||
total_duration_seconds: float
|
||||
) -> bool:
|
||||
"""
|
||||
Event 4: Training Completed (100% progress)
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"progress": 100,
|
||||
"current_step": "Training Completed",
|
||||
"step_details": f"Training completed: {successful_trainings} successful, {failed_trainings} failed",
|
||||
"successful_trainings": successful_trainings,
|
||||
"failed_trainings": failed_trainings,
|
||||
"total_duration_seconds": total_duration_seconds
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.completed",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published training completed event",
|
||||
job_id=job_id,
|
||||
successful_trainings=successful_trainings,
|
||||
failed_trainings=failed_trainings)
|
||||
else:
|
||||
logger.error("Failed to publish training completed event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def publish_training_failed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
error_message: str
|
||||
) -> bool:
|
||||
"""
|
||||
Event: Training Failed
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.failed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"current_step": "Training Failed",
|
||||
"error_message": error_message
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.failed",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published training failed event",
|
||||
job_id=job_id,
|
||||
error=error_message)
|
||||
else:
|
||||
logger.error("Failed to publish training failed event", job_id=job_id)
|
||||
|
||||
return success
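# Illustrative usage sketch (added for documentation, not part of the original file):
# the expected lifecycle of the main events for a hypothetical two-product job.
async def _training_events_example() -> None:
    if not await setup_messaging():
        return
    try:
        await publish_training_started("job-demo", "tenant-demo", total_products=2,
                                       estimated_duration_minutes=5)
        await publish_data_analysis("job-demo", "tenant-demo")
        await publish_product_training_completed("job-demo", "tenant-demo",
                                                 product_name="baguette",
                                                 products_completed=1, total_products=2)
        await publish_training_completed("job-demo", "tenant-demo",
                                         successful_trainings=2, failed_trainings=0,
                                         total_duration_seconds=180.0)
    finally:
        await cleanup_messaging()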
|
||||
971
services/training/app/services/training_orchestrator.py
Normal file
@@ -0,0 +1,971 @@
|
||||
# services/training/app/services/training_orchestrator.py
|
||||
"""
|
||||
Training Data Orchestrator - Enhanced Integration Layer
|
||||
Orchestrates data collection, date alignment, and preparation for ML training
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
import asyncio
|
||||
import structlog
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timezone
|
||||
import pandas as pd
|
||||
|
||||
from app.services.data_client import DataClient
|
||||
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
|
||||
from app.ml.poi_feature_integrator import POIFeatureIntegrator
|
||||
from app.services.training_events import publish_training_failed
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@dataclass
|
||||
class TrainingDataSet:
|
||||
"""Container for all training data with metadata"""
|
||||
sales_data: List[Dict[str, Any]]
|
||||
weather_data: List[Dict[str, Any]]
|
||||
traffic_data: List[Dict[str, Any]]
|
||||
poi_features: Dict[str, Any] # POI features for location-based forecasting
|
||||
date_range: AlignedDateRange
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
class TrainingDataOrchestrator:
|
||||
"""
|
||||
Enhanced orchestrator for data collection from multiple sources.
|
||||
Ensures date alignment, handles data source constraints, and prepares data for ML training.
|
||||
Uses the new abstracted traffic service layer for multi-city support.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
date_alignment_service: DateAlignmentService = None,
|
||||
poi_feature_integrator: POIFeatureIntegrator = None):
|
||||
self.data_client = DataClient()
|
||||
self.date_alignment_service = date_alignment_service or DateAlignmentService()
|
||||
self.poi_feature_integrator = poi_feature_integrator or POIFeatureIntegrator()
|
||||
self.max_concurrent_requests = 5 # Increased for better performance
|
||||
|
||||
async def prepare_training_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
bakery_location: Tuple[float, float], # (lat, lon)
|
||||
requested_start: Optional[datetime] = None,
|
||||
requested_end: Optional[datetime] = None,
|
||||
job_id: Optional[str] = None
|
||||
) -> TrainingDataSet:
|
||||
"""
|
||||
Main method to prepare all training data with comprehensive date alignment.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
|
||||
bakery_location: Bakery coordinates (lat, lon)
|
||||
requested_start: Optional explicit start date
|
||||
requested_end: Optional explicit end date
|
||||
job_id: Training job identifier for logging
|
||||
|
||||
Returns:
|
||||
TrainingDataSet with all aligned and validated data
|
||||
"""
|
||||
logger.info(f"Starting comprehensive training data preparation for tenant {tenant_id}, job {job_id}")
|
||||
|
||||
try:
|
||||
# Step 1: Fetch and validate sales data (unified approach)
|
||||
sales_data = await self.data_client.fetch_sales_data(tenant_id, fetch_all=True)
|
||||
|
||||
if not sales_data or len(sales_data) == 0:
|
||||
error_msg = f"No sales data available for tenant {tenant_id}. Please import sales data before starting training."
|
||||
logger.error("Training aborted - no sales data", tenant_id=tenant_id, job_id=job_id)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Debug: Analyze the sales data structure to understand product distribution
|
||||
sales_df_debug = pd.DataFrame(sales_data)
|
||||
if 'inventory_product_id' in sales_df_debug.columns:
|
||||
unique_products_found = sales_df_debug['inventory_product_id'].unique()
|
||||
product_counts = sales_df_debug['inventory_product_id'].value_counts().to_dict()
|
||||
|
||||
logger.info("Sales data analysis (moved from pre-flight)",
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id,
|
||||
total_sales_records=len(sales_data),
|
||||
unique_products_count=len(unique_products_found),
|
||||
unique_products=unique_products_found.tolist(),
|
||||
records_per_product=product_counts)
|
||||
|
||||
if len(unique_products_found) == 1:
|
||||
logger.warning("POTENTIAL ISSUE: Only ONE unique product found in all sales data",
|
||||
tenant_id=tenant_id,
|
||||
single_product=unique_products_found[0],
|
||||
record_count=len(sales_data))
|
||||
else:
|
||||
logger.warning("No 'inventory_product_id' column found in sales data",
|
||||
tenant_id=tenant_id,
|
||||
columns=list(sales_df_debug.columns))
|
||||
|
||||
logger.info(f"Sales data validation passed: {len(sales_data)} sales records found",
|
||||
tenant_id=tenant_id, job_id=job_id)
|
||||
|
||||
# Step 2: Extract and validate sales data date range
|
||||
sales_date_range = self._extract_sales_date_range(sales_data)
|
||||
logger.info(f"Sales data range detected: {sales_date_range.start} to {sales_date_range.end}")
|
||||
|
||||
# Step 3: Apply date alignment across all data sources
|
||||
aligned_range = self.date_alignment_service.validate_and_align_dates(
|
||||
user_sales_range=sales_date_range,
|
||||
requested_start=requested_start,
|
||||
requested_end=requested_end
|
||||
)
|
||||
|
||||
logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
|
||||
if aligned_range.constraints:
|
||||
logger.info(f"Applied constraints: {aligned_range.constraints}")
|
||||
|
||||
# Step 4: Filter sales data to aligned date range
|
||||
filtered_sales = self._filter_sales_data(sales_data, aligned_range)
|
||||
|
||||
# Step 5: Collect external data sources concurrently
|
||||
logger.info("Collecting external data sources...")
|
||||
weather_data, traffic_data, poi_features = await self._collect_external_data(
|
||||
aligned_range, bakery_location, tenant_id
|
||||
)
|
||||
|
||||
# Step 6: Validate data quality
|
||||
data_quality_results = self._validate_data_sources(
|
||||
filtered_sales, weather_data, traffic_data, aligned_range
|
||||
)
|
||||
|
||||
# Step 7: Create comprehensive training dataset
|
||||
training_dataset = TrainingDataSet(
|
||||
sales_data=filtered_sales,
|
||||
weather_data=weather_data,
|
||||
traffic_data=traffic_data,
|
||||
poi_features=poi_features or {}, # POI features (static, location-based)
|
||||
date_range=aligned_range,
|
||||
metadata={
|
||||
"tenant_id": tenant_id,
|
||||
"job_id": job_id,
|
||||
"bakery_location": bakery_location,
|
||||
"data_sources_used": aligned_range.available_sources,
|
||||
"constraints_applied": aligned_range.constraints,
|
||||
"data_quality": data_quality_results,
|
||||
"preparation_timestamp": datetime.now().isoformat(),
|
||||
"original_sales_range": {
|
||||
"start": sales_date_range.start.isoformat(),
|
||||
"end": sales_date_range.end.isoformat()
|
||||
},
|
||||
"poi_features_count": len(poi_features) if poi_features else 0
|
||||
}
|
||||
)
|
||||
|
||||
# Step 8: Final validation
|
||||
final_validation = self.validate_training_data_quality(training_dataset)
|
||||
training_dataset.metadata["final_validation"] = final_validation
|
||||
|
||||
logger.info(f"Training data preparation completed successfully:")
|
||||
logger.info(f" - Sales records: {len(filtered_sales)}")
|
||||
logger.info(f" - Weather records: {len(weather_data)}")
|
||||
logger.info(f" - Traffic records: {len(traffic_data)}")
|
||||
logger.info(f" - POI features: {len(poi_features) if poi_features else 0}")
|
||||
logger.info(f" - Data quality score: {final_validation.get('data_quality_score', 'N/A')}")
|
||||
|
||||
return training_dataset
|
||||
|
||||
except Exception as e:
|
||||
if job_id and tenant_id:
|
||||
await publish_training_failed(job_id, tenant_id, str(e))
|
||||
logger.error(f"Training data preparation failed: {str(e)}")
|
||||
raise ValueError(f"Failed to prepare training data: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def extract_sales_date_range_utc_localize(sales_data_df: pd.DataFrame):
|
||||
"""
|
||||
Extracts the UTC-aware date range from a sales DataFrame using tz_localize.
|
||||
|
||||
Args:
|
||||
sales_data_df: A pandas DataFrame containing a 'date' column.
|
||||
|
||||
Returns:
|
||||
A DateRange whose start and end dates are timezone-aware in UTC.
|
||||
"""
|
||||
if 'date' not in sales_data_df.columns:
|
||||
raise ValueError("DataFrame does not contain a 'date' column.")
|
||||
|
||||
# Convert the 'date' column to datetime objects
|
||||
sales_data_df['date'] = pd.to_datetime(sales_data_df['date'])
|
||||
|
||||
# Localize the naive datetime objects to UTC
|
||||
sales_data_df['date'] = sales_data_df['date'].dt.tz_localize('UTC')
|
||||
|
||||
# Find the minimum and maximum dates
|
||||
start_date = sales_data_df['date'].min()
|
||||
end_date = sales_data_df['date'].max()
|
||||
|
||||
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
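# Illustrative usage sketch (added for documentation, not part of the original file):
# the helper above expects a naive 'date' column, localizes it to UTC, and returns a
# DateRange. The frame contents are hypothetical.
def _sales_range_example() -> DateRange:
    frame = pd.DataFrame({
        "date": ["2024-01-01", "2024-01-15", "2024-02-01"],
        "quantity": [10, 12, 9],
    })
    date_range = TrainingDataOrchestrator.extract_sales_date_range_utc_localize(frame)
    logger.info(f"Sales range {date_range.start} -> {date_range.end}")
    return date_range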
|
||||
|
||||
def _extract_sales_date_range(self, sales_data: List[Dict[str, Any]]) -> 'DateRange':
|
||||
"""
|
||||
Extract date range from sales data with proper date parsing
|
||||
|
||||
Args:
|
||||
sales_data: List of sales records
|
||||
|
||||
Returns:
|
||||
DateRange object with timezone-aware start and end dates
|
||||
"""
|
||||
if not sales_data:
|
||||
raise ValueError("No sales data provided for date range extraction")
|
||||
|
||||
# Convert to DataFrame for easier processing
|
||||
sales_df = pd.DataFrame(sales_data)
|
||||
|
||||
if 'date' not in sales_df.columns:
|
||||
raise ValueError("Sales data does not contain a 'date' column")
|
||||
|
||||
# Convert dates to datetime with proper parsing
|
||||
# This will use the improved date parsing from the data import service
|
||||
sales_df['date'] = pd.to_datetime(sales_df['date'], utc=True, errors='coerce')
|
||||
|
||||
# Remove any rows with invalid dates
|
||||
sales_df = sales_df.dropna(subset=['date'])
|
||||
|
||||
if len(sales_df) == 0:
|
||||
raise ValueError("No valid dates found in sales data")
|
||||
|
||||
# Find the minimum and maximum dates
|
||||
start_date = sales_df['date'].min()
|
||||
end_date = sales_df['date'].max()
|
||||
|
||||
logger.info(f"Extracted sales date range: {start_date} to {end_date}")
|
||||
|
||||
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
|
||||
|
||||
def _filter_sales_data(
|
||||
self,
|
||||
sales_data: List[Dict[str, Any]],
|
||||
aligned_range: AlignedDateRange
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Filter sales data to the aligned date range with enhanced validation"""
|
||||
filtered_data = []
|
||||
filtered_count = 0
|
||||
|
||||
for record in sales_data:
|
||||
try:
|
||||
if 'date' in record:
|
||||
record_date = record['date']
|
||||
|
||||
# Parse dates with full timezone information; truncating to the date part previously caused records to be filtered out incorrectly
|
||||
if isinstance(record_date, str):
|
||||
# Parse complete ISO datetime string with timezone info intact
|
||||
# DO NOT truncate to date part only - this was causing the filtering issue
|
||||
if 'T' in record_date:
|
||||
record_date = record_date.replace('Z', '+00:00')
|
||||
# Parse with FULL datetime info, not just date part
|
||||
parsed_date = datetime.fromisoformat(record_date)
|
||||
# Ensure timezone-aware
|
||||
if parsed_date.tzinfo is None:
|
||||
parsed_date = parsed_date.replace(tzinfo=timezone.utc)
|
||||
record_date = parsed_date
|
||||
elif isinstance(record_date, datetime):
|
||||
# Ensure timezone-aware
|
||||
if record_date.tzinfo is None:
|
||||
record_date = record_date.replace(tzinfo=timezone.utc)
|
||||
# DO NOT normalize to start of day - keep actual datetime for proper filtering
|
||||
# Only normalize if needed for daily aggregation, but preserve original for filtering
|
||||
|
||||
# ✅ FIX: Ensure aligned_range dates are also timezone-aware for comparison
|
||||
aligned_start = aligned_range.start
|
||||
aligned_end = aligned_range.end
|
||||
|
||||
if aligned_start.tzinfo is None:
|
||||
aligned_start = aligned_start.replace(tzinfo=timezone.utc)
|
||||
if aligned_end.tzinfo is None:
|
||||
aligned_end = aligned_end.replace(tzinfo=timezone.utc)
|
||||
|
||||
# Check if date falls within aligned range (now both are timezone-aware)
|
||||
if aligned_start <= record_date <= aligned_end:
|
||||
# Validate that record has required fields
|
||||
if self._validate_sales_record(record):
|
||||
filtered_data.append(record)
|
||||
else:
|
||||
filtered_count += 1
|
||||
else:
|
||||
# Record outside date range
|
||||
filtered_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing sales record: {str(e)}")
|
||||
filtered_count += 1
|
||||
continue
|
||||
|
||||
logger.info(f"Filtered sales data: {len(filtered_data)} records in aligned range")
|
||||
if filtered_count > 0:
|
||||
logger.warning(f"Filtered out {filtered_count} invalid records")
|
||||
|
||||
return filtered_data
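# Illustrative sketch (not part of the original code): how the string branch above handles
# a hypothetical ISO-8601 timestamp with a trailing 'Z'. The suffix is rewritten to '+00:00'
# so datetime.fromisoformat() returns a timezone-aware value, and the full timestamp is kept
# so the aligned_start <= record_date <= aligned_end comparison does not silently drop records:
#
#   raw = "2024-03-15T08:30:00Z"                               # hypothetical input
#   parsed = datetime.fromisoformat(raw.replace('Z', '+00:00'))
#   assert parsed.tzinfo is not None                           # UTC offset preserved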
|
||||
|
||||
def _validate_sales_record(self, record: Dict[str, Any]) -> bool:
|
||||
"""Validate individual sales record"""
|
||||
required_fields = ['date', 'inventory_product_id']
|
||||
quantity_fields = ['quantity', 'quantity_sold', 'sales', 'units_sold']
|
||||
|
||||
# Check required fields
|
||||
for field in required_fields:
|
||||
if field not in record or record[field] is None:
|
||||
return False
|
||||
|
||||
# Check at least one quantity field exists
|
||||
has_quantity = any(field in record and record[field] is not None for field in quantity_fields)
|
||||
if not has_quantity:
|
||||
return False
|
||||
|
||||
# Validate quantity is numeric and non-negative
|
||||
for field in quantity_fields:
|
||||
if field in record and record[field] is not None:
|
||||
try:
|
||||
quantity = float(record[field])
|
||||
if quantity < 0:
|
||||
return False
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
break
|
||||
|
||||
return True
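# Illustrative sketch (hypothetical records, not part of the original code):
#   {"date": "2024-03-15", "inventory_product_id": "prod-123", "quantity_sold": 42}
# passes validation (required fields present, one non-negative quantity field), whereas
#   {"date": "2024-03-15", "inventory_product_id": "prod-123", "quantity": -3}
# is rejected by the non-negative quantity check above.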
|
||||
|
||||
async def _collect_external_data(
|
||||
self,
|
||||
aligned_range: AlignedDateRange,
|
||||
bakery_location: Tuple[float, float],
|
||||
tenant_id: str
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, Any]]:
|
||||
"""Collect weather, traffic, and POI data concurrently with enhanced error handling"""
|
||||
|
||||
lat, lon = bakery_location
|
||||
|
||||
# Create collection tasks with timeout
|
||||
tasks = []
|
||||
|
||||
# Weather data collection
|
||||
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
|
||||
weather_task = asyncio.create_task(
|
||||
self._collect_weather_data_with_timeout(lat, lon, aligned_range, tenant_id)
|
||||
)
|
||||
tasks.append(("weather", weather_task))
|
||||
|
||||
# Enhanced Traffic data collection (supports multiple cities)
|
||||
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
|
||||
logger.info(f"🚛 Traffic data source available for multiple cities, creating collection task for date range: {aligned_range.start} to {aligned_range.end}")
|
||||
traffic_task = asyncio.create_task(
|
||||
self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)
|
||||
)
|
||||
tasks.append(("traffic", traffic_task))
|
||||
else:
|
||||
logger.warning(f"🚫 Traffic data source NOT available in sources: {[s.value for s in aligned_range.available_sources]}")
|
||||
|
||||
# POI features collection (static, location-based)
|
||||
poi_task = asyncio.create_task(
|
||||
self._collect_poi_features(lat, lon, tenant_id)
|
||||
)
|
||||
tasks.append(("poi", poi_task))
|
||||
|
||||
# Execute tasks concurrently with proper error handling
|
||||
results = {}
|
||||
if tasks:
|
||||
try:
|
||||
completed_tasks = await asyncio.gather(
|
||||
*[task for _, task in tasks],
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
for i, (task_name, _) in enumerate(tasks):
|
||||
result = completed_tasks[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"{task_name} data collection failed: {result}")
|
||||
results[task_name] = [] if task_name != "poi" else {}
|
||||
else:
|
||||
results[task_name] = result
|
||||
if task_name == "poi":
|
||||
logger.info(f"{task_name} features collected: {len(result) if result else 0} features")
|
||||
else:
|
||||
logger.info(f"{task_name} data collection completed: {len(result)} records")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in concurrent data collection: {str(e)}")
|
||||
results = {"weather": [], "traffic": [], "poi": {}}
|
||||
|
||||
weather_data = results.get("weather", [])
|
||||
traffic_data = results.get("traffic", [])
|
||||
poi_features = results.get("poi", {})
|
||||
|
||||
return weather_data, traffic_data, poi_features
|
||||
|
||||
async def _collect_poi_features(
|
||||
self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
tenant_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Collect POI features for bakery location (non-blocking).
|
||||
|
||||
POI features are static (location-based, not time-varying).
|
||||
This method is non-blocking with a short timeout to prevent training delays.
|
||||
If POI detection hasn't been run yet, training continues without POI features.
|
||||
|
||||
Returns:
|
||||
Dictionary with POI features or empty dict if unavailable
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Collecting POI features (non-blocking)",
|
||||
tenant_id=tenant_id,
|
||||
location=(lat, lon)
|
||||
)
|
||||
|
||||
# Set a short timeout to prevent blocking training
|
||||
# POI detection should have been triggered during tenant registration
|
||||
poi_features = await asyncio.wait_for(
|
||||
self.poi_feature_integrator.fetch_poi_features(
|
||||
tenant_id=tenant_id,
|
||||
latitude=lat,
|
||||
longitude=lon,
|
||||
force_refresh=False
|
||||
),
|
||||
timeout=15.0 # 15 second timeout - POI should be cached from registration
|
||||
)
|
||||
|
||||
if poi_features:
|
||||
logger.info(
|
||||
"POI features collected successfully",
|
||||
tenant_id=tenant_id,
|
||||
feature_count=len(poi_features)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"No POI features collected (service may be unavailable or not yet detected)",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
return poi_features or {}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"POI collection timeout (15s) - continuing training without POI features. "
|
||||
"POI detection should be triggered during tenant registration for best results.",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to collect POI features (non-blocking) - continuing training without them",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
return {}
|
||||
|
||||
async def _collect_weather_data_with_timeout(
|
||||
self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
tenant_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Collect weather data with timeout and fallback"""
|
||||
try:
|
||||
|
||||
start_date_str = aligned_range.start.isoformat()
|
||||
end_date_str = aligned_range.end.isoformat()
|
||||
|
||||
weather_data = await self.data_client.fetch_weather_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date_str,
|
||||
end_date=end_date_str,
|
||||
latitude=lat,
|
||||
longitude=lon)
|
||||
|
||||
# Validate weather data
|
||||
if self._validate_weather_data(weather_data):
|
||||
logger.info(f"Collected {len(weather_data)} valid weather records")
|
||||
return weather_data
|
||||
else:
|
||||
logger.warning("Invalid weather data received, using synthetic data")
|
||||
return self._generate_synthetic_weather_data(aligned_range)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Weather data collection timed out, using synthetic data")
|
||||
return self._generate_synthetic_weather_data(aligned_range)
|
||||
except Exception as e:
|
||||
logger.warning(f"Weather data collection failed: {e}, using synthetic data")
|
||||
return self._generate_synthetic_weather_data(aligned_range)
|
||||
|
||||
async def _collect_traffic_data_with_timeout_enhanced(
|
||||
self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
tenant_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Enhanced traffic data collection with multi-city support and improved storage
|
||||
Uses the new abstracted traffic service layer
|
||||
"""
|
||||
try:
|
||||
# Double-check constraints before making request
|
||||
constraint_violated = self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end)
|
||||
if constraint_violated:
|
||||
logger.warning(f"🚫 Current month constraint violation: end_date={aligned_range.end}, no traffic data available")
|
||||
return []
|
||||
else:
|
||||
logger.info(f"✅ Date constraints passed: end_date={aligned_range.end}, proceeding with traffic data request")
|
||||
|
||||
start_date_str = aligned_range.start.isoformat()
|
||||
end_date_str = aligned_range.end.isoformat()
|
||||
|
||||
# Enhanced: Fetch traffic data using unified cache-first method
|
||||
# This automatically detects the appropriate city and uses the right client
|
||||
traffic_data = await self.data_client.fetch_traffic_data_unified(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date_str,
|
||||
end_date=end_date_str,
|
||||
latitude=lat,
|
||||
longitude=lon,
|
||||
force_refresh=False # Use cache-first strategy
|
||||
)
|
||||
|
||||
# Enhanced validation including pedestrian inference data
|
||||
if self._validate_traffic_data_enhanced(traffic_data):
|
||||
logger.info(f"Collected and stored {len(traffic_data)} valid enhanced traffic records for re-training")
|
||||
|
||||
# Log storage success with enhanced metadata
|
||||
self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, len(traffic_data), traffic_data)
|
||||
|
||||
return traffic_data
|
||||
else:
|
||||
logger.warning("Invalid enhanced traffic data received")
|
||||
return []
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Enhanced traffic data collection timed out")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Enhanced traffic data collection failed: {e}")
|
||||
return []
|
||||
|
||||
def _log_enhanced_traffic_data_storage(self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
record_count: int,
|
||||
traffic_data: List[Dict[str, Any]]):
|
||||
"""Enhanced logging for traffic data storage with detailed metadata"""
|
||||
cities_detected = set()
|
||||
has_pedestrian_data = 0
|
||||
data_sources = set()
|
||||
districts_covered = set()
|
||||
|
||||
for record in traffic_data:
|
||||
if 'city' in record and record['city']:
|
||||
cities_detected.add(record['city'])
|
||||
if 'pedestrian_count' in record and record['pedestrian_count'] is not None:
|
||||
has_pedestrian_data += 1
|
||||
if 'source' in record and record['source']:
|
||||
data_sources.add(record['source'])
|
||||
if 'district' in record and record['district']:
|
||||
districts_covered.add(record['district'])
|
||||
|
||||
logger.info(
|
||||
"Enhanced traffic data stored for re-training",
|
||||
location=f"{lat:.4f},{lon:.4f}",
|
||||
date_range=f"{aligned_range.start.isoformat()} to {aligned_range.end.isoformat()}",
|
||||
records_stored=record_count,
|
||||
cities_detected=list(cities_detected),
|
||||
pedestrian_inference_coverage=f"{has_pedestrian_data}/{record_count}",
|
||||
data_sources=list(data_sources),
|
||||
districts_covered=list(districts_covered),
|
||||
storage_timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
purpose="model_training_and_retraining"
|
||||
)
|
||||
|
||||
def _validate_weather_data(self, weather_data: List[Dict[str, Any]]) -> bool:
|
||||
"""Validate weather data quality"""
|
||||
if not weather_data:
|
||||
return False
|
||||
|
||||
required_fields = ['date']
|
||||
weather_fields = ['temperature', 'temp', 'temperatura', 'precipitation', 'rain', 'lluvia']
|
||||
|
||||
valid_records = 0
|
||||
for record in weather_data:
|
||||
# Check required fields
|
||||
if not all(field in record for field in required_fields):
|
||||
continue
|
||||
|
||||
# Check at least one weather field exists
|
||||
if any(field in record and record[field] is not None for field in weather_fields):
|
||||
valid_records += 1
|
||||
|
||||
# Consider valid if at least 50% of records are valid
|
||||
validity_threshold = 0.5
|
||||
is_valid = (valid_records / len(weather_data)) >= validity_threshold
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"Weather data validation failed: {valid_records}/{len(weather_data)} valid records")
|
||||
|
||||
return is_valid
|
||||
|
||||
def _validate_traffic_data_enhanced(self, traffic_data: List[Dict[str, Any]]) -> bool:
|
||||
"""Enhanced validation for traffic data including pedestrian inference and city-specific fields"""
|
||||
if not traffic_data:
|
||||
return False
|
||||
|
||||
required_fields = ['date']
|
||||
traffic_fields = ['traffic_volume', 'traffic_intensity', 'intensidad', 'trafico']
|
||||
enhanced_fields = ['pedestrian_count', 'congestion_level', 'source']
|
||||
city_specific_fields = ['city', 'measurement_point_id', 'district']
|
||||
|
||||
valid_records = 0
|
||||
enhanced_records = 0
|
||||
city_aware_records = 0
|
||||
|
||||
for record in traffic_data:
|
||||
record_score = 0
|
||||
|
||||
# Check required fields
|
||||
if all(field in record and record[field] is not None for field in required_fields):
|
||||
record_score += 1
|
||||
|
||||
# Check traffic data fields
|
||||
if any(field in record and record[field] is not None for field in traffic_fields):
|
||||
record_score += 1
|
||||
|
||||
# Check enhanced fields (pedestrian inference, etc.)
|
||||
enhanced_count = sum(1 for field in enhanced_fields
|
||||
if field in record and record[field] is not None)
|
||||
if enhanced_count >= 2: # At least 2 enhanced fields
|
||||
enhanced_records += 1
|
||||
record_score += 1
|
||||
|
||||
# Check city-specific awareness
|
||||
city_count = sum(1 for field in city_specific_fields
|
||||
if field in record and record[field] is not None)
|
||||
if city_count >= 1: # At least some city awareness
|
||||
city_aware_records += 1
|
||||
|
||||
# Record is valid if it has basic requirements (date + any traffic field)
|
||||
# Lowered requirement from >= 2 to >= 1 to accept records with just date or traffic data
|
||||
if record_score >= 1:
|
||||
valid_records += 1
|
||||
|
||||
total_records = len(traffic_data)
|
||||
validity_threshold = 0.1 # Reduced from 0.3 to 0.1 - accept if 10% of records are valid
|
||||
enhancement_threshold = 0.1 # Reduced threshold for enhanced features
|
||||
|
||||
basic_validity = (valid_records / total_records) >= validity_threshold
|
||||
has_enhancements = (enhanced_records / total_records) >= enhancement_threshold
|
||||
has_city_awareness = (city_aware_records / total_records) >= enhancement_threshold
|
||||
|
||||
logger.info("Enhanced traffic data validation results",
|
||||
total_records=total_records,
|
||||
valid_records=valid_records,
|
||||
enhanced_records=enhanced_records,
|
||||
city_aware_records=city_aware_records,
|
||||
basic_validity=basic_validity,
|
||||
has_enhancements=has_enhancements,
|
||||
has_city_awareness=has_city_awareness)
|
||||
|
||||
if not basic_validity:
|
||||
logger.warning(f"Traffic data basic validation failed: {valid_records}/{total_records} valid records")
|
||||
|
||||
return basic_validity
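# Worked example (illustrative, hypothetical counts): with 200 traffic records of which
# 30 have a usable date or traffic field (record_score >= 1), 25 carry two or more
# enhanced fields and 40 carry city metadata:
#   basic_validity     = 30 / 200 = 0.15  >= 0.10 -> data accepted
#   has_enhancements   = 25 / 200 = 0.125 >= 0.10 -> logged as enhanced
#   has_city_awareness = 40 / 200 = 0.20  >= 0.10 -> logged as city-aware
# Only basic_validity gates the return value; the other two flags are informational.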
|
||||
|
||||
def _validate_traffic_data(self, traffic_data: List[Dict[str, Any]]) -> bool:
|
||||
"""Legacy validation method - redirects to enhanced version"""
|
||||
return self._validate_traffic_data_enhanced(traffic_data)
|
||||
|
||||
def _validate_data_sources(
|
||||
self,
|
||||
sales_data: List[Dict[str, Any]],
|
||||
weather_data: List[Dict[str, Any]],
|
||||
traffic_data: List[Dict[str, Any]],
|
||||
aligned_range: AlignedDateRange
|
||||
) -> Dict[str, Any]:
|
||||
"""Validate all data sources and provide quality metrics"""
|
||||
|
||||
validation_results = {
|
||||
"sales_data": {
|
||||
"record_count": len(sales_data),
|
||||
"is_valid": len(sales_data) > 0,
|
||||
"coverage_days": (aligned_range.end - aligned_range.start).days,
|
||||
"quality_score": 0.0
|
||||
},
|
||||
"weather_data": {
|
||||
"record_count": len(weather_data),
|
||||
"is_valid": self._validate_weather_data(weather_data) if weather_data else False,
|
||||
"quality_score": 0.0
|
||||
},
|
||||
"traffic_data": {
|
||||
"record_count": len(traffic_data),
|
||||
"is_valid": self._validate_traffic_data(traffic_data) if traffic_data else False,
|
||||
"quality_score": 0.0
|
||||
},
|
||||
"overall_quality_score": 0.0
|
||||
}
|
||||
|
||||
# Calculate quality scores
|
||||
# Sales data quality (most important)
|
||||
if validation_results["sales_data"]["record_count"] > 0:
|
||||
coverage_ratio = min(1.0, validation_results["sales_data"]["record_count"] / max(1, validation_results["sales_data"]["coverage_days"]))
|
||||
validation_results["sales_data"]["quality_score"] = coverage_ratio * 100
|
||||
|
||||
# Weather data quality
|
||||
if validation_results["weather_data"]["record_count"] > 0:
|
||||
expected_weather_records = max(1, (aligned_range.end - aligned_range.start).days)
|
||||
coverage_ratio = min(1.0, validation_results["weather_data"]["record_count"] / expected_weather_records)
|
||||
validation_results["weather_data"]["quality_score"] = coverage_ratio * 100
|
||||
|
||||
# Traffic data quality
|
||||
if validation_results["traffic_data"]["record_count"] > 0:
|
||||
expected_traffic_records = max(1, (aligned_range.end - aligned_range.start).days)
|
||||
coverage_ratio = min(1.0, validation_results["traffic_data"]["record_count"] / expected_traffic_records)
|
||||
validation_results["traffic_data"]["quality_score"] = coverage_ratio * 100
|
||||
|
||||
# Overall quality score (weighted by importance)
|
||||
weights = {"sales_data": 0.7, "weather_data": 0.2, "traffic_data": 0.1}
|
||||
overall_score = sum(
|
||||
validation_results[source]["quality_score"] * weight
|
||||
for source, weight in weights.items()
|
||||
)
|
||||
validation_results["overall_quality_score"] = round(overall_score, 2)
|
||||
|
||||
return validation_results
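# Worked example (illustrative, hypothetical scores): with quality scores of
# sales=90, weather=60 and traffic=30, the weighted overall score is
#   0.7 * 90 + 0.2 * 60 + 0.1 * 30 = 63.0 + 12.0 + 3.0 = 78.0
# so "overall_quality_score" would be reported as 78.0.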
|
||||
|
||||
def _generate_synthetic_weather_data(
|
||||
self,
|
||||
aligned_range: AlignedDateRange
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Generate realistic synthetic weather data for Madrid"""
|
||||
import random  # used below for synthetic variation; imported once, outside the generation loop

synthetic_data = []
|
||||
current_date = aligned_range.start
|
||||
|
||||
# Madrid seasonal temperature patterns
|
||||
seasonal_temps = {
|
||||
1: 9, 2: 11, 3: 15, 4: 17, 5: 21, 6: 26,
|
||||
7: 29, 8: 28, 9: 24, 10: 18, 11: 12, 12: 9
|
||||
}
|
||||
|
||||
while current_date <= aligned_range.end:
|
||||
month = current_date.month
|
||||
base_temp = seasonal_temps.get(month, 15)
|
||||
|
||||
# Add some realistic variation
temp_variation = random.gauss(0, 3)  # ±3°C variation
|
||||
temperature = max(0, base_temp + temp_variation)
|
||||
|
||||
# Precipitation patterns (Madrid is relatively dry)
|
||||
precipitation = 0.0
|
||||
if random.random() < 0.15: # 15% chance of rain
|
||||
precipitation = random.uniform(0.1, 15.0)
|
||||
|
||||
synthetic_data.append({
|
||||
"date": current_date,
|
||||
"temperature": round(temperature, 1),
|
||||
"precipitation": round(precipitation, 1),
|
||||
"humidity": round(random.uniform(40, 80), 1),
|
||||
"wind_speed": round(random.uniform(2, 15), 1),
|
||||
"pressure": round(random.uniform(1005, 1025), 1),
|
||||
"source": "synthetic_madrid_pattern"
|
||||
})
|
||||
|
||||
current_date = current_date + timedelta(days=1)
|
||||
|
||||
logger.info(f"Generated {len(synthetic_data)} synthetic weather records with Madrid patterns")
|
||||
return synthetic_data
|
||||
|
||||
def validate_training_data_quality(self, dataset: TrainingDataSet) -> Dict[str, Any]:
|
||||
|
||||
"""Enhanced validation of training data quality"""
|
||||
validation_results = {
|
||||
"is_valid": True,
|
||||
"warnings": [],
|
||||
"errors": [],
|
||||
"data_quality_score": 100.0,
|
||||
"recommendations": []
|
||||
}
|
||||
|
||||
# Check sales data completeness
|
||||
sales_count = len(dataset.sales_data)
|
||||
if sales_count < 30:
|
||||
validation_results["warnings"].append(
|
||||
f"Limited sales data: {sales_count} records (recommended: 30+)"
|
||||
)
|
||||
validation_results["data_quality_score"] -= 20
|
||||
validation_results["recommendations"].append("Consider collecting more historical sales data")
|
||||
elif sales_count < 90:
|
||||
validation_results["warnings"].append(
|
||||
f"Moderate sales data: {sales_count} records (optimal: 90+)"
|
||||
)
|
||||
validation_results["data_quality_score"] -= 10
|
||||
|
||||
# Check date coverage
|
||||
date_coverage = (dataset.date_range.end - dataset.date_range.start).days
|
||||
if date_coverage < 90:
|
||||
validation_results["warnings"].append(
|
||||
f"Limited date coverage: {date_coverage} days (recommended: 90+)"
|
||||
)
|
||||
validation_results["data_quality_score"] -= 15
|
||||
validation_results["recommendations"].append("Extend date range for better seasonality detection")
|
||||
|
||||
# Check external data availability
|
||||
if not dataset.weather_data:
|
||||
validation_results["warnings"].append("No weather data available")
|
||||
validation_results["data_quality_score"] -= 10
|
||||
validation_results["recommendations"].append("Weather data improves forecast accuracy")
|
||||
elif len(dataset.weather_data) < date_coverage * 0.5:
|
||||
validation_results["warnings"].append("Sparse weather data coverage")
|
||||
validation_results["data_quality_score"] -= 5
|
||||
|
||||
if not dataset.traffic_data:
|
||||
validation_results["warnings"].append("No traffic data available")
|
||||
validation_results["data_quality_score"] -= 5
|
||||
validation_results["recommendations"].append("Traffic data can help with location-based patterns")
|
||||
|
||||
# Check data consistency
|
||||
unique_products = set()
|
||||
for record in dataset.sales_data:
|
||||
if 'inventory_product_id' in record:
|
||||
unique_products.add(record['inventory_product_id'])
|
||||
|
||||
if len(unique_products) == 0:
|
||||
validation_results["errors"].append("No product names found in sales data")
|
||||
validation_results["is_valid"] = False
|
||||
elif len(unique_products) > 50:
|
||||
validation_results["warnings"].append(
|
||||
f"Many products detected ({len(unique_products)}). Consider training models in batches."
|
||||
)
|
||||
validation_results["recommendations"].append("Group similar products for better training efficiency")
|
||||
|
||||
# Check for data source constraints
|
||||
if dataset.date_range.constraints:
|
||||
constraint_info = []
|
||||
for constraint_type, message in dataset.date_range.constraints.items():
|
||||
constraint_info.append(f"{constraint_type}: {message}")
|
||||
|
||||
validation_results["warnings"].append(
|
||||
f"Data source constraints applied: {'; '.join(constraint_info)}"
|
||||
)
|
||||
|
||||
# Final validation
|
||||
if validation_results["errors"]:
|
||||
validation_results["is_valid"] = False
|
||||
validation_results["data_quality_score"] = 0.0
|
||||
|
||||
# Ensure score doesn't go below 0
|
||||
validation_results["data_quality_score"] = max(0.0, validation_results["data_quality_score"])
|
||||
|
||||
# Add quality assessment
|
||||
score = validation_results["data_quality_score"]
|
||||
if score >= 80:
|
||||
validation_results["quality_assessment"] = "Excellent"
|
||||
elif score >= 60:
|
||||
validation_results["quality_assessment"] = "Good"
|
||||
elif score >= 40:
|
||||
validation_results["quality_assessment"] = "Fair"
|
||||
else:
|
||||
validation_results["quality_assessment"] = "Poor"
|
||||
|
||||
return validation_results
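# Worked example (illustrative): a dataset with 20 sales records (-20), 45 days of
# coverage (-15), no weather data (-10) and no traffic data (-5) scores
#   100 - 20 - 15 - 10 - 5 = 50.0
# which falls in the "Fair" band (40 <= score < 60).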
|
||||
|
||||
def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
|
||||
"""
|
||||
Generate an enhanced data collection plan based on the aligned date range.
|
||||
"""
|
||||
plan = {
|
||||
"collection_summary": {
|
||||
"start_date": aligned_range.start.isoformat(),
|
||||
"end_date": aligned_range.end.isoformat(),
|
||||
"duration_days": (aligned_range.end - aligned_range.start).days,
|
||||
"available_sources": [source.value for source in aligned_range.available_sources],
|
||||
"constraints": aligned_range.constraints
|
||||
},
|
||||
"data_sources": {}
|
||||
}
|
||||
|
||||
# Bakery Sales Data
|
||||
if DataSourceType.BAKERY_SALES in aligned_range.available_sources:
|
||||
plan["data_sources"]["sales_data"] = {
|
||||
"start_date": aligned_range.start.isoformat(),
|
||||
"end_date": aligned_range.end.isoformat(),
|
||||
"source": "user_upload",
|
||||
"required": True,
|
||||
"priority": "high",
|
||||
"expected_records": "variable",
|
||||
"data_points": ["date", "inventory_product_id", "quantity"],
|
||||
"validation": "required_fields_check"
|
||||
}
|
||||
|
||||
# Madrid Traffic Data
|
||||
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
|
||||
plan["data_sources"]["traffic_data"] = {
|
||||
"start_date": aligned_range.start.isoformat(),
|
||||
"end_date": aligned_range.end.isoformat(),
|
||||
"source": "madrid_opendata",
|
||||
"required": False,
|
||||
"priority": "medium",
|
||||
"expected_records": (aligned_range.end - aligned_range.start).days,
|
||||
"constraint": "Cannot request current month data",
|
||||
"data_points": ["date", "traffic_volume", "congestion_level"],
|
||||
"validation": "date_constraint_check"
|
||||
}
|
||||
|
||||
# Weather Data
|
||||
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
|
||||
plan["data_sources"]["weather_data"] = {
|
||||
"start_date": aligned_range.start.isoformat(),
|
||||
"end_date": aligned_range.end.isoformat(),
|
||||
"source": "aemet_api",
|
||||
"required": False,
|
||||
"priority": "high",
|
||||
"expected_records": (aligned_range.end - aligned_range.start).days,
|
||||
"constraint": "Available from yesterday backward",
|
||||
"data_points": ["date", "temperature", "precipitation", "humidity"],
|
||||
"validation": "temporal_constraint_check",
|
||||
"fallback": "synthetic_madrid_weather"
|
||||
}
|
||||
|
||||
return plan
|
||||
|
||||
def get_orchestration_summary(self, dataset: TrainingDataSet) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a comprehensive summary of the orchestration process.
|
||||
"""
|
||||
return {
|
||||
"tenant_id": dataset.metadata.get("tenant_id"),
|
||||
"job_id": dataset.metadata.get("job_id"),
|
||||
"orchestration_completed_at": dataset.metadata.get("preparation_timestamp"),
|
||||
"data_alignment": {
|
||||
"original_range": dataset.metadata.get("original_sales_range"),
|
||||
"aligned_range": {
|
||||
"start": dataset.date_range.start.isoformat(),
|
||||
"end": dataset.date_range.end.isoformat(),
|
||||
"duration_days": (dataset.date_range.end - dataset.date_range.start).days
|
||||
},
|
||||
"constraints_applied": dataset.date_range.constraints,
|
||||
"available_sources": [source.value for source in dataset.date_range.available_sources]
|
||||
},
|
||||
"data_collection_results": {
|
||||
"sales_records": len(dataset.sales_data),
|
||||
"weather_records": len(dataset.weather_data),
|
||||
"traffic_records": len(dataset.traffic_data),
|
||||
"total_records": len(dataset.sales_data) + len(dataset.weather_data) + len(dataset.traffic_data)
|
||||
},
|
||||
"data_quality": dataset.metadata.get("data_quality", {}),
|
||||
"validation_results": dataset.metadata.get("final_validation", {}),
|
||||
"processing_metadata": {
|
||||
"bakery_location": dataset.metadata.get("bakery_location"),
|
||||
"data_sources_requested": len(dataset.date_range.available_sources),
|
||||
"data_sources_successful": sum([
|
||||
1 if len(dataset.sales_data) > 0 else 0,
|
||||
1 if len(dataset.weather_data) > 0 else 0,
|
||||
1 if len(dataset.traffic_data) > 0 else 0
|
||||
])
|
||||
}
|
||||
}
|
||||
1076
services/training/app/services/training_service.py
Normal file
File diff suppressed because it is too large
92
services/training/app/utils/__init__.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
Training Service Utilities
|
||||
"""
|
||||
|
||||
from .ml_datetime import (
|
||||
ensure_timezone_aware,
|
||||
ensure_timezone_naive,
|
||||
normalize_datetime_to_utc,
|
||||
normalize_dataframe_datetime_column,
|
||||
prepare_prophet_datetime,
|
||||
safe_datetime_comparison,
|
||||
get_current_utc,
|
||||
convert_timestamp_to_datetime
|
||||
)
|
||||
|
||||
from .circuit_breaker import (
|
||||
CircuitBreaker,
|
||||
CircuitBreakerError,
|
||||
CircuitState,
|
||||
circuit_breaker_registry
|
||||
)
|
||||
|
||||
from .file_utils import (
|
||||
calculate_file_checksum,
|
||||
verify_file_checksum,
|
||||
get_file_size,
|
||||
ensure_directory_exists,
|
||||
safe_file_delete,
|
||||
get_file_metadata,
|
||||
ChecksummedFile
|
||||
)
|
||||
|
||||
from .distributed_lock import (
|
||||
DatabaseLock,
|
||||
SimpleDatabaseLock,
|
||||
LockAcquisitionError,
|
||||
get_training_lock
|
||||
)
|
||||
|
||||
from .retry import (
|
||||
RetryStrategy,
|
||||
RetryError,
|
||||
retry_async,
|
||||
with_retry,
|
||||
retry_with_timeout,
|
||||
AdaptiveRetryStrategy,
|
||||
TimeoutRetryStrategy,
|
||||
HTTP_RETRY_STRATEGY,
|
||||
DATABASE_RETRY_STRATEGY,
|
||||
EXTERNAL_SERVICE_RETRY_STRATEGY
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Timezone utilities
|
||||
'ensure_timezone_aware',
|
||||
'ensure_timezone_naive',
|
||||
'normalize_datetime_to_utc',
|
||||
'normalize_dataframe_datetime_column',
|
||||
'prepare_prophet_datetime',
|
||||
'safe_datetime_comparison',
|
||||
'get_current_utc',
|
||||
'convert_timestamp_to_datetime',
|
||||
# Circuit breaker
|
||||
'CircuitBreaker',
|
||||
'CircuitBreakerError',
|
||||
'CircuitState',
|
||||
'circuit_breaker_registry',
|
||||
# File utilities
|
||||
'calculate_file_checksum',
|
||||
'verify_file_checksum',
|
||||
'get_file_size',
|
||||
'ensure_directory_exists',
|
||||
'safe_file_delete',
|
||||
'get_file_metadata',
|
||||
'ChecksummedFile',
|
||||
# Distributed locking
|
||||
'DatabaseLock',
|
||||
'SimpleDatabaseLock',
|
||||
'LockAcquisitionError',
|
||||
'get_training_lock',
|
||||
# Retry mechanisms
|
||||
'RetryStrategy',
|
||||
'RetryError',
|
||||
'retry_async',
|
||||
'with_retry',
|
||||
'retry_with_timeout',
|
||||
'AdaptiveRetryStrategy',
|
||||
'TimeoutRetryStrategy',
|
||||
'HTTP_RETRY_STRATEGY',
|
||||
'DATABASE_RETRY_STRATEGY',
|
||||
'EXTERNAL_SERVICE_RETRY_STRATEGY'
|
||||
]
|
||||
198
services/training/app/utils/circuit_breaker.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Circuit Breaker Pattern Implementation
|
||||
Protects against cascading failures from external service calls
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Callable, Any, Optional
|
||||
import logging
|
||||
from functools import wraps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker states"""
|
||||
CLOSED = "closed" # Normal operation
|
||||
OPEN = "open" # Circuit is open, rejecting requests
|
||||
HALF_OPEN = "half_open" # Testing if service recovered
|
||||
|
||||
|
||||
class CircuitBreakerError(Exception):
|
||||
"""Raised when circuit breaker is open"""
|
||||
pass
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""
|
||||
Circuit breaker to prevent cascading failures.
|
||||
|
||||
States:
|
||||
- CLOSED: Normal operation, requests pass through
|
||||
- OPEN: Too many failures, rejecting all requests
|
||||
- HALF_OPEN: Testing if service recovered, allowing limited requests
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
failure_threshold: int = 5,
|
||||
recovery_timeout: float = 60.0,
|
||||
expected_exception: type = Exception,
|
||||
name: str = "circuit_breaker"
|
||||
):
|
||||
"""
|
||||
Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
failure_threshold: Number of failures before opening circuit
|
||||
recovery_timeout: Seconds to wait before attempting recovery
|
||||
expected_exception: Exception type to catch (others will pass through)
|
||||
name: Name for logging purposes
|
||||
"""
|
||||
self.failure_threshold = failure_threshold
|
||||
self.recovery_timeout = recovery_timeout
|
||||
self.expected_exception = expected_exception
|
||||
self.name = name
|
||||
|
||||
self.failure_count = 0
|
||||
self.last_failure_time: Optional[float] = None
|
||||
self.state = CircuitState.CLOSED
|
||||
|
||||
def _record_success(self):
|
||||
"""Record successful call"""
|
||||
self.failure_count = 0
|
||||
self.last_failure_time = None
|
||||
if self.state == CircuitState.HALF_OPEN:
|
||||
logger.info(f"Circuit breaker '{self.name}' recovered, closing circuit")
|
||||
self.state = CircuitState.CLOSED
|
||||
|
||||
def _record_failure(self):
|
||||
"""Record failed call"""
|
||||
self.failure_count += 1
|
||||
self.last_failure_time = time.time()
|
||||
|
||||
if self.failure_count >= self.failure_threshold:
|
||||
if self.state != CircuitState.OPEN:
|
||||
logger.warning(
|
||||
f"Circuit breaker '{self.name}' opened after {self.failure_count} failures"
|
||||
)
|
||||
self.state = CircuitState.OPEN
|
||||
|
||||
def _should_attempt_reset(self) -> bool:
|
||||
"""Check if we should attempt to reset circuit"""
|
||||
return (
|
||||
self.state == CircuitState.OPEN
|
||||
and self.last_failure_time is not None
|
||||
and time.time() - self.last_failure_time >= self.recovery_timeout
|
||||
)
|
||||
|
||||
async def call(self, func: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Execute function with circuit breaker protection.
|
||||
|
||||
Args:
|
||||
func: Async function to execute
|
||||
*args: Positional arguments for func
|
||||
**kwargs: Keyword arguments for func
|
||||
|
||||
Returns:
|
||||
Result from func
|
||||
|
||||
Raises:
|
||||
CircuitBreakerError: If circuit is open
|
||||
Exception: Original exception if not expected_exception type
|
||||
"""
|
||||
# Check if circuit is open
|
||||
if self.state == CircuitState.OPEN:
|
||||
if self._should_attempt_reset():
|
||||
logger.info(f"Circuit breaker '{self.name}' attempting recovery (half-open)")
|
||||
self.state = CircuitState.HALF_OPEN
|
||||
else:
|
||||
raise CircuitBreakerError(
|
||||
f"Circuit breaker '{self.name}' is open. "
|
||||
f"Service unavailable for {self.recovery_timeout}s after {self.failure_count} failures."
|
||||
)
|
||||
|
||||
try:
|
||||
# Execute the function
|
||||
result = await func(*args, **kwargs)
|
||||
self._record_success()
|
||||
return result
|
||||
|
||||
except self.expected_exception as e:
|
||||
self._record_failure()
|
||||
# Standard logging does not accept arbitrary keyword arguments, so fold the
# context into the message itself.
logger.error(
    f"Circuit breaker '{self.name}' caught failure: {e} "
    f"(failure_count={self.failure_count}, state={self.state.value})"
)
|
||||
raise
|
||||
|
||||
def __call__(self, func: Callable) -> Callable:
|
||||
"""Decorator interface for circuit breaker"""
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await self.call(func, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
def get_state(self) -> dict:
|
||||
"""Get current circuit breaker state for monitoring"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"state": self.state.value,
|
||||
"failure_count": self.failure_count,
|
||||
"failure_threshold": self.failure_threshold,
|
||||
"last_failure_time": self.last_failure_time,
|
||||
"recovery_timeout": self.recovery_timeout
|
||||
}
|
||||
|
||||
|
||||
class CircuitBreakerRegistry:
|
||||
"""Registry to manage multiple circuit breakers"""
|
||||
|
||||
def __init__(self):
|
||||
self._breakers: dict[str, CircuitBreaker] = {}
|
||||
|
||||
def get_or_create(
|
||||
self,
|
||||
name: str,
|
||||
failure_threshold: int = 5,
|
||||
recovery_timeout: float = 60.0,
|
||||
expected_exception: type = Exception
|
||||
) -> CircuitBreaker:
|
||||
"""Get existing circuit breaker or create new one"""
|
||||
if name not in self._breakers:
|
||||
self._breakers[name] = CircuitBreaker(
|
||||
failure_threshold=failure_threshold,
|
||||
recovery_timeout=recovery_timeout,
|
||||
expected_exception=expected_exception,
|
||||
name=name
|
||||
)
|
||||
return self._breakers[name]
|
||||
|
||||
def get(self, name: str) -> Optional[CircuitBreaker]:
|
||||
"""Get circuit breaker by name"""
|
||||
return self._breakers.get(name)
|
||||
|
||||
def get_all_states(self) -> dict:
|
||||
"""Get states of all circuit breakers"""
|
||||
return {
|
||||
name: breaker.get_state()
|
||||
for name, breaker in self._breakers.items()
|
||||
}
|
||||
|
||||
def reset(self, name: str):
|
||||
"""Manually reset a circuit breaker"""
|
||||
if name in self._breakers:
|
||||
breaker = self._breakers[name]
|
||||
breaker.failure_count = 0
|
||||
breaker.last_failure_time = None
|
||||
breaker.state = CircuitState.CLOSED
|
||||
logger.info(f"Circuit breaker '{name}' manually reset")
|
||||
|
||||
|
||||
# Global registry instance
|
||||
circuit_breaker_registry = CircuitBreakerRegistry()
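# ---------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the original module).
# The service name and flaky_call() helper below are hypothetical; only the
# CircuitBreaker / registry API defined above is assumed.
# ---------------------------------------------------------------------------

async def _example_breaker_usage() -> None:
    """Demonstrate the registry + call() interface with a fake external call."""
    breaker = circuit_breaker_registry.get_or_create(
        name="example-external-service",
        failure_threshold=3,
        recovery_timeout=30.0,
        expected_exception=ConnectionError,
    )

    async def flaky_call() -> str:
        # Hypothetical stand-in for an external HTTP or database call.
        raise ConnectionError("service unreachable")

    try:
        await breaker.call(flaky_call)
    except ConnectionError:
        # Early failures propagate; after failure_threshold of them the breaker
        # opens and call() raises CircuitBreakerError for recovery_timeout seconds.
        pass
    except CircuitBreakerError:
        # Circuit is open: fail fast instead of hammering the failing service.
        pass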
|
||||
250
services/training/app/utils/distributed_lock.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
Distributed Locking Mechanisms
|
||||
Prevents concurrent training jobs for the same product
|
||||
|
||||
HORIZONTAL SCALING FIX:
|
||||
- Uses SHA256 for stable hash across all Python processes/pods
|
||||
- Python's built-in hash() varies between processes due to hash randomization (Python 3.3+)
|
||||
- This ensures all pods compute the same lock ID for the same lock name
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LockAcquisitionError(Exception):
|
||||
"""Raised when lock cannot be acquired"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseLock:
|
||||
"""
|
||||
Database-based distributed lock using PostgreSQL advisory locks.
|
||||
Works across multiple service instances.
|
||||
"""
|
||||
|
||||
def __init__(self, lock_name: str, timeout: float = 30.0):
|
||||
"""
|
||||
Initialize database lock.
|
||||
|
||||
Args:
|
||||
lock_name: Unique identifier for the lock
|
||||
timeout: Maximum seconds to wait for lock acquisition
|
||||
"""
|
||||
self.lock_name = lock_name
|
||||
self.timeout = timeout
|
||||
self.lock_id = self._hash_lock_name(lock_name)
|
||||
|
||||
def _hash_lock_name(self, name: str) -> int:
|
||||
"""
|
||||
Convert lock name to integer ID for PostgreSQL advisory lock.
|
||||
|
||||
CRITICAL: Uses SHA256 for stable hash across all Python processes/pods.
|
||||
Python's built-in hash() varies between processes due to hash randomization
|
||||
(PYTHONHASHSEED, enabled by default since Python 3.3), which would cause
|
||||
different pods to compute different lock IDs for the same lock name,
|
||||
defeating the purpose of distributed locking.
|
||||
"""
|
||||
# Use SHA256 for stable, cross-process hash
|
||||
hash_bytes = hashlib.sha256(name.encode('utf-8')).digest()
|
||||
# Take first 4 bytes and convert to positive 31-bit integer
|
||||
# (PostgreSQL advisory locks use bigint, but we use 31-bit for safety)
|
||||
return int.from_bytes(hash_bytes[:4], 'big') % (2**31)
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire(self, session: AsyncSession):
|
||||
"""
|
||||
Acquire distributed lock as async context manager.
|
||||
|
||||
Args:
|
||||
session: Database session for lock operations
|
||||
|
||||
Raises:
|
||||
LockAcquisitionError: If lock cannot be acquired within timeout
|
||||
"""
|
||||
acquired = False
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Try to acquire lock with timeout
|
||||
while time.time() - start_time < self.timeout:
|
||||
# Try non-blocking lock acquisition
|
||||
result = await session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:lock_id)"),
|
||||
{"lock_id": self.lock_id}
|
||||
)
|
||||
acquired = result.scalar()
|
||||
|
||||
if acquired:
|
||||
logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
|
||||
break
|
||||
|
||||
# Wait a bit before retrying
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if not acquired:
|
||||
raise LockAcquisitionError(
|
||||
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
if acquired:
|
||||
# Release lock
|
||||
await session.execute(
|
||||
text("SELECT pg_advisory_unlock(:lock_id)"),
|
||||
{"lock_id": self.lock_id}
|
||||
)
|
||||
logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")
|
||||
|
||||
|
||||
class SimpleDatabaseLock:
|
||||
"""
|
||||
Simple table-based distributed lock.
|
||||
Alternative to advisory locks, uses a dedicated locks table.
|
||||
"""
|
||||
|
||||
def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
|
||||
"""
|
||||
Initialize simple database lock.
|
||||
|
||||
Args:
|
||||
lock_name: Unique identifier for the lock
|
||||
timeout: Maximum seconds to wait for lock acquisition
|
||||
ttl: Time-to-live for stale lock cleanup (seconds)
|
||||
"""
|
||||
self.lock_name = lock_name
|
||||
self.timeout = timeout
|
||||
self.ttl = ttl
|
||||
|
||||
async def _ensure_lock_table(self, session: AsyncSession):
|
||||
"""Ensure locks table exists"""
|
||||
create_table_sql = """
|
||||
CREATE TABLE IF NOT EXISTS distributed_locks (
|
||||
lock_name VARCHAR(255) PRIMARY KEY,
|
||||
acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
acquired_by VARCHAR(255),
|
||||
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
|
||||
)
|
||||
"""
|
||||
await session.execute(text(create_table_sql))
|
||||
await session.commit()
|
||||
|
||||
async def _cleanup_stale_locks(self, session: AsyncSession):
|
||||
"""Remove expired locks"""
|
||||
cleanup_sql = """
|
||||
DELETE FROM distributed_locks
|
||||
WHERE expires_at < :now
|
||||
"""
|
||||
await session.execute(
|
||||
text(cleanup_sql),
|
||||
{"now": datetime.now(timezone.utc)}
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire(self, session: AsyncSession, owner: str = "training-service"):
|
||||
"""
|
||||
Acquire simple database lock.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
owner: Identifier for lock owner
|
||||
|
||||
Raises:
|
||||
LockAcquisitionError: If lock cannot be acquired
|
||||
"""
|
||||
await self._ensure_lock_table(session)
|
||||
await self._cleanup_stale_locks(session)
|
||||
|
||||
acquired = False
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Try to acquire lock
|
||||
while time.time() - start_time < self.timeout:
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=self.ttl)
|
||||
|
||||
try:
|
||||
# Try to insert lock record
|
||||
insert_sql = """
|
||||
INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
|
||||
VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
|
||||
ON CONFLICT (lock_name) DO NOTHING
|
||||
RETURNING lock_name
|
||||
"""
|
||||
|
||||
result = await session.execute(
|
||||
text(insert_sql),
|
||||
{
|
||||
"lock_name": self.lock_name,
|
||||
"acquired_at": now,
|
||||
"acquired_by": owner,
|
||||
"expires_at": expires_at
|
||||
}
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
if result.rowcount > 0:
|
||||
acquired = True
|
||||
logger.info(f"Acquired simple lock: {self.lock_name}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Lock acquisition attempt failed: {e}")
|
||||
await session.rollback()
|
||||
|
||||
# Wait before retrying
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if not acquired:
|
||||
raise LockAcquisitionError(
|
||||
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
if acquired:
|
||||
# Release lock
|
||||
delete_sql = """
|
||||
DELETE FROM distributed_locks
|
||||
WHERE lock_name = :lock_name
|
||||
"""
|
||||
await session.execute(
|
||||
text(delete_sql),
|
||||
{"lock_name": self.lock_name}
|
||||
)
|
||||
await session.commit()
|
||||
logger.info(f"Released simple lock: {self.lock_name}")
|
||||
|
||||
|
||||
def get_training_lock(tenant_id: str, product_id: str, use_advisory: bool = True) -> DatabaseLock:
|
||||
"""
|
||||
Get distributed lock for training a specific product.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
product_id: Product identifier
|
||||
use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)
|
||||
|
||||
Returns:
|
||||
Lock instance
|
||||
"""
|
||||
lock_name = f"training:{tenant_id}:{product_id}"
|
||||
|
||||
if use_advisory:
|
||||
return DatabaseLock(lock_name, timeout=60.0)
|
||||
else:
|
||||
return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)
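# ---------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the original module).
# The training step is left as a placeholder; only get_training_lock() and the
# lock classes defined above are assumed.
# ---------------------------------------------------------------------------

async def _example_training_lock(session: AsyncSession, tenant_id: str, product_id: str) -> None:
    """Serialize training for one product across all pods using an advisory lock."""
    lock = get_training_lock(tenant_id, product_id, use_advisory=True)
    try:
        async with lock.acquire(session):
            # Only one pod reaches this block for a given (tenant_id, product_id)
            # at a time; run the training job here.
            ...
    except LockAcquisitionError:
        # Another pod is already training this product; skip or re-queue the job.
        pass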
|
||||
216
services/training/app/utils/file_utils.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
File Utility Functions
|
||||
Utilities for secure file operations including checksum verification
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_file_checksum(file_path: str, algorithm: str = "sha256") -> str:
|
||||
"""
|
||||
Calculate checksum of a file.
|
||||
|
||||
Args:
|
||||
file_path: Path to file
|
||||
algorithm: Hash algorithm (sha256, md5, etc.)
|
||||
|
||||
Returns:
|
||||
Hexadecimal checksum string
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
ValueError: If algorithm not supported
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
try:
|
||||
hash_func = hashlib.new(algorithm)
|
||||
except ValueError:
|
||||
raise ValueError(f"Unsupported hash algorithm: {algorithm}")
|
||||
|
||||
# Read file in chunks to handle large files efficiently
|
||||
with open(file_path, 'rb') as f:
|
||||
while chunk := f.read(8192):
|
||||
hash_func.update(chunk)
|
||||
|
||||
return hash_func.hexdigest()
|
||||
|
||||
|
||||
def verify_file_checksum(file_path: str, expected_checksum: str, algorithm: str = "sha256") -> bool:
|
||||
"""
|
||||
Verify file matches expected checksum.
|
||||
|
||||
Args:
|
||||
file_path: Path to file
|
||||
expected_checksum: Expected checksum value
|
||||
algorithm: Hash algorithm used
|
||||
|
||||
Returns:
|
||||
True if checksum matches, False otherwise
|
||||
"""
|
||||
try:
|
||||
actual_checksum = calculate_file_checksum(file_path, algorithm)
|
||||
matches = actual_checksum == expected_checksum
|
||||
|
||||
if matches:
|
||||
logger.debug(f"Checksum verified for {file_path}")
|
||||
else:
|
||||
logger.warning(
    f"Checksum mismatch for {file_path}: "
    f"expected {expected_checksum}, got {actual_checksum}"
)
|
||||
|
||||
return matches
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying checksum for {file_path}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_file_size(file_path: str) -> int:
|
||||
"""
|
||||
Get file size in bytes.
|
||||
|
||||
Args:
|
||||
file_path: Path to file
|
||||
|
||||
Returns:
|
||||
File size in bytes
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
return os.path.getsize(file_path)
|
||||
|
||||
|
||||
def ensure_directory_exists(directory: str) -> Path:
|
||||
"""
|
||||
Ensure directory exists, create if necessary.
|
||||
|
||||
Args:
|
||||
directory: Directory path
|
||||
|
||||
Returns:
|
||||
Path object for directory
|
||||
"""
|
||||
path = Path(directory)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def safe_file_delete(file_path: str) -> bool:
|
||||
"""
|
||||
Safely delete a file, logging any errors.
|
||||
|
||||
Args:
|
||||
file_path: Path to file
|
||||
|
||||
Returns:
|
||||
True if deleted successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
logger.info(f"Deleted file: {file_path}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"File not found for deletion: {file_path}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting file {file_path}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_file_metadata(file_path: str) -> dict:
|
||||
"""
|
||||
Get comprehensive file metadata.
|
||||
|
||||
Args:
|
||||
file_path: Path to file
|
||||
|
||||
Returns:
|
||||
Dictionary with file metadata
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
stat = os.stat(file_path)
|
||||
|
||||
return {
|
||||
"path": file_path,
|
||||
"size_bytes": stat.st_size,
|
||||
"created_at": stat.st_ctime,
|
||||
"modified_at": stat.st_mtime,
|
||||
"accessed_at": stat.st_atime,
|
||||
"is_file": os.path.isfile(file_path),
|
||||
"is_dir": os.path.isdir(file_path),
|
||||
"exists": True
|
||||
}
|
||||
|
||||
|
||||
class ChecksummedFile:
|
||||
"""
|
||||
Context manager for working with checksummed files.
|
||||
Automatically calculates and stores checksum when file is written.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, checksum_path: Optional[str] = None, algorithm: str = "sha256"):
|
||||
"""
|
||||
Initialize checksummed file handler.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
checksum_path: Path to store checksum (default: file_path + '.checksum')
|
||||
algorithm: Hash algorithm to use
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.checksum_path = checksum_path or f"{file_path}.checksum"
|
||||
self.algorithm = algorithm
|
||||
self.checksum: Optional[str] = None
|
||||
|
||||
def calculate_and_save_checksum(self) -> str:
|
||||
"""Calculate checksum and save to file"""
|
||||
self.checksum = calculate_file_checksum(self.file_path, self.algorithm)
|
||||
|
||||
with open(self.checksum_path, 'w') as f:
|
||||
f.write(f"{self.checksum} {os.path.basename(self.file_path)}\n")
|
||||
|
||||
logger.info(f"Saved checksum for {self.file_path}: {self.checksum}")
|
||||
return self.checksum
|
||||
|
||||
def load_and_verify_checksum(self) -> bool:
|
||||
"""Load expected checksum and verify file"""
|
||||
try:
|
||||
with open(self.checksum_path, 'r') as f:
|
||||
expected_checksum = f.read().strip().split()[0]
|
||||
|
||||
return verify_file_checksum(self.file_path, expected_checksum, self.algorithm)
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"Checksum file not found: {self.checksum_path}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading checksum: {e}")
|
||||
return False
|
||||
|
||||
def get_stored_checksum(self) -> Optional[str]:
|
||||
"""Get checksum from stored file"""
|
||||
try:
|
||||
with open(self.checksum_path, 'r') as f:
|
||||
return f.read().strip().split()[0]
|
||||
except FileNotFoundError:
|
||||
return None
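# ---------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the original module).
# The model path below is hypothetical; only the helpers defined above are assumed,
# and the file must exist before the checksum is calculated.
# ---------------------------------------------------------------------------

def _example_checksummed_model(model_path: str = "/app/models/example_model.pkl") -> bool:
    """Write side: store a checksum next to a model artifact; read side: verify it later."""
    handler = ChecksummedFile(model_path)      # checksum stored at model_path + ".checksum"
    handler.calculate_and_save_checksum()      # call right after the model file is written
    return handler.load_and_verify_checksum()  # call before loading the model elsewhere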
|
||||
270
services/training/app/utils/ml_datetime.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
ML-Specific DateTime Utilities
|
||||
|
||||
DateTime utilities for machine learning operations, specifically for:
|
||||
- Prophet forecasting model (requires timezone-naive datetimes)
|
||||
- Pandas DataFrame datetime operations
|
||||
- Time series data processing
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Union
|
||||
import pandas as pd
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def ensure_timezone_aware(dt: datetime, default_tz=timezone.utc) -> datetime:
|
||||
"""
|
||||
Ensure a datetime is timezone-aware.
|
||||
|
||||
Args:
|
||||
dt: Datetime to check
|
||||
default_tz: Timezone to apply if datetime is naive (default: UTC)
|
||||
|
||||
Returns:
|
||||
Timezone-aware datetime
|
||||
"""
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=default_tz)
|
||||
return dt
|
||||
|
||||
|
||||
def ensure_timezone_naive(dt: datetime) -> datetime:
|
||||
"""
|
||||
Remove timezone information from a datetime.
|
||||
|
||||
Args:
|
||||
dt: Datetime to process
|
||||
|
||||
Returns:
|
||||
Timezone-naive datetime
|
||||
"""
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
if dt.tzinfo is not None:
|
||||
return dt.replace(tzinfo=None)
|
||||
return dt
|
||||
|
||||
|
||||
def normalize_datetime_to_utc(dt: Union[datetime, pd.Timestamp]) -> datetime:
|
||||
"""
|
||||
Normalize any datetime to UTC timezone-aware datetime.
|
||||
|
||||
Args:
|
||||
dt: Datetime or pandas Timestamp to normalize
|
||||
|
||||
Returns:
|
||||
UTC timezone-aware datetime
|
||||
"""
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
if isinstance(dt, pd.Timestamp):
|
||||
dt = dt.to_pydatetime()
|
||||
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
|
||||
return dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def normalize_dataframe_datetime_column(
|
||||
df: pd.DataFrame,
|
||||
column: str,
|
||||
target_format: str = 'naive'
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Normalize a datetime column in a dataframe to consistent format.
|
||||
|
||||
Args:
|
||||
df: DataFrame to process
|
||||
column: Name of datetime column
|
||||
target_format: 'naive' or 'aware' (UTC)
|
||||
|
||||
Returns:
|
||||
DataFrame with normalized datetime column
|
||||
"""
|
||||
if column not in df.columns:
|
||||
logger.warning(f"Column {column} not found in dataframe")
|
||||
return df
|
||||
|
||||
df[column] = pd.to_datetime(df[column])
|
||||
|
||||
if target_format == 'naive':
|
||||
if df[column].dt.tz is not None:
|
||||
df[column] = df[column].dt.tz_localize(None)
|
||||
elif target_format == 'aware':
|
||||
if df[column].dt.tz is None:
|
||||
df[column] = df[column].dt.tz_localize(timezone.utc)
|
||||
else:
|
||||
df[column] = df[column].dt.tz_convert(timezone.utc)
|
||||
else:
|
||||
raise ValueError(f"Invalid target_format: {target_format}. Must be 'naive' or 'aware'")
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def prepare_prophet_datetime(df: pd.DataFrame, datetime_col: str = 'ds') -> pd.DataFrame:
|
||||
"""
|
||||
Prepare datetime column for Prophet (requires timezone-naive datetimes).
|
||||
|
||||
Args:
|
||||
df: DataFrame with datetime column
|
||||
datetime_col: Name of datetime column (default: 'ds')
|
||||
|
||||
Returns:
|
||||
DataFrame with Prophet-compatible datetime column
|
||||
"""
|
||||
df = df.copy()
|
||||
df = normalize_dataframe_datetime_column(df, datetime_col, target_format='naive')
|
||||
return df
|
||||
|
||||
|
||||
def safe_datetime_comparison(dt1: datetime, dt2: datetime) -> int:
|
||||
"""
|
||||
Safely compare two datetimes, handling timezone mismatches.
|
||||
|
||||
Args:
|
||||
dt1: First datetime
|
||||
dt2: Second datetime
|
||||
|
||||
Returns:
|
||||
-1 if dt1 < dt2, 0 if equal, 1 if dt1 > dt2
|
||||
"""
|
||||
dt1_utc = normalize_datetime_to_utc(dt1)
|
||||
dt2_utc = normalize_datetime_to_utc(dt2)
|
||||
|
||||
if dt1_utc < dt2_utc:
|
||||
return -1
|
||||
elif dt1_utc > dt2_utc:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def get_current_utc() -> datetime:
|
||||
"""
|
||||
Get current datetime in UTC with timezone awareness.
|
||||
|
||||
Returns:
|
||||
Current UTC datetime
|
||||
"""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def convert_timestamp_to_datetime(timestamp: Union[int, float, str]) -> datetime:
|
||||
"""
|
||||
Convert various timestamp formats to datetime.
|
||||
|
||||
Args:
|
||||
timestamp: Unix timestamp (seconds or milliseconds) or ISO string
|
||||
|
||||
Returns:
|
||||
UTC timezone-aware datetime
|
||||
"""
|
||||
if isinstance(timestamp, str):
|
||||
dt = pd.to_datetime(timestamp)
|
||||
return normalize_datetime_to_utc(dt)
|
||||
|
||||
if timestamp > 1e10:
|
||||
timestamp = timestamp / 1000
|
||||
|
||||
dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)
|
||||
return dt
|
||||
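As a quick illustration of the seconds-versus-milliseconds heuristic above, a minimal sketch (assuming `convert_timestamp_to_datetime` from this module is in scope; the sample values are hypothetical):

```python
from datetime import timezone

examples = [
    1_714_500_000,           # Unix seconds (< 1e10, used as-is)
    1_714_500_000_000,       # Unix milliseconds (> 1e10, divided by 1000)
    "2024-04-30T18:00:00Z",  # ISO 8601 string, parsed via pandas
]

for ts in examples:
    dt = convert_timestamp_to_datetime(ts)
    assert dt.tzinfo == timezone.utc  # always returned as UTC-aware
    print(repr(ts), "->", dt.isoformat())
```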
|
||||
|
||||
def align_dataframe_dates(
|
||||
dfs: list[pd.DataFrame],
|
||||
date_column: str = 'ds',
|
||||
method: str = 'inner'
|
||||
) -> list[pd.DataFrame]:
|
||||
"""
|
||||
Align multiple dataframes to have the same date range.
|
||||
|
||||
Args:
|
||||
dfs: List of DataFrames to align
|
||||
date_column: Name of the date column
|
||||
method: 'inner' (intersection) or 'outer' (union)
|
||||
|
||||
Returns:
|
||||
List of aligned DataFrames
|
||||
"""
|
||||
if not dfs:
|
||||
return []
|
||||
|
||||
if len(dfs) == 1:
|
||||
return dfs
|
||||
|
||||
all_dates = None
|
||||
|
||||
for df in dfs:
|
||||
if date_column not in df.columns:
|
||||
continue
|
||||
|
||||
dates = set(pd.to_datetime(df[date_column]).dt.date)
|
||||
|
||||
if all_dates is None:
|
||||
all_dates = dates
|
||||
else:
|
||||
if method == 'inner':
|
||||
all_dates = all_dates.intersection(dates)
|
||||
elif method == 'outer':
|
||||
all_dates = all_dates.union(dates)
|
||||
|
||||
aligned_dfs = []
|
||||
for df in dfs:
|
||||
if date_column not in df.columns:
|
||||
aligned_dfs.append(df)
|
||||
continue
|
||||
|
||||
df = df.copy()
|
||||
df[date_column] = pd.to_datetime(df[date_column])
|
||||
df['_date_only'] = df[date_column].dt.date
|
||||
df = df[df['_date_only'].isin(all_dates)]
|
||||
df = df.drop('_date_only', axis=1)
|
||||
aligned_dfs.append(df)
|
||||
|
||||
return aligned_dfs
|
||||
|
||||
|
||||
def fill_missing_dates(
|
||||
df: pd.DataFrame,
|
||||
date_column: str = 'ds',
|
||||
freq: str = 'D',
|
||||
fill_value: float = 0.0
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Fill missing dates in a DataFrame with a specified frequency.
|
||||
|
||||
Args:
|
||||
df: DataFrame with date column
|
||||
date_column: Name of the date column
|
||||
freq: Pandas frequency string ('D' for daily, 'H' for hourly, etc.)
|
||||
fill_value: Value to fill for missing dates
|
||||
|
||||
Returns:
|
||||
DataFrame with filled dates
|
||||
"""
|
||||
df = df.copy()
|
||||
df[date_column] = pd.to_datetime(df[date_column])
|
||||
|
||||
df = df.set_index(date_column)
|
||||
|
||||
full_range = pd.date_range(
|
||||
start=df.index.min(),
|
||||
end=df.index.max(),
|
||||
freq=freq
|
||||
)
|
||||
|
||||
df = df.reindex(full_range, fill_value=fill_value)
|
||||
|
||||
df = df.reset_index()
|
||||
df = df.rename(columns={'index': date_column})
|
||||
|
||||
return df
|
||||
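Taken together, these helpers cover the usual path from raw sales rows to a Prophet-ready daily series. A minimal usage sketch, assuming `prepare_prophet_datetime` and `fill_missing_dates` from this module are importable and using made-up sales figures:

```python
import pandas as pd
from datetime import timezone

# Hypothetical raw data: timezone-aware timestamps with a gap on 2024-01-03.
raw = pd.DataFrame({
    "ds": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-04"]).tz_localize(timezone.utc),
    "y": [120.0, 95.0, 130.0],
})

# Prophet requires timezone-naive datetimes.
df = prepare_prophet_datetime(raw)

# Fill the missing day with zero sales so the daily series is continuous.
df = fill_missing_dates(df, date_column="ds", freq="D", fill_value=0.0)

print(df)
# Expected: 4 rows covering 2024-01-01 .. 2024-01-04, with y == 0.0 on the filled day.
```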
316
services/training/app/utils/retry.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
Retry Mechanism with Exponential Backoff
|
||||
Handles transient failures with intelligent retry strategies
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import random
|
||||
from typing import Callable, Any, Optional, Type, Tuple
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetryError(Exception):
|
||||
"""Raised when all retry attempts are exhausted"""
|
||||
def __init__(self, message: str, attempts: int, last_exception: Exception):
|
||||
super().__init__(message)
|
||||
self.attempts = attempts
|
||||
self.last_exception = last_exception
|
||||
|
||||
|
||||
class RetryStrategy:
|
||||
"""Base retry strategy"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_attempts: int = 3,
|
||||
initial_delay: float = 1.0,
|
||||
max_delay: float = 60.0,
|
||||
exponential_base: float = 2.0,
|
||||
jitter: bool = True,
|
||||
retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
|
||||
):
|
||||
"""
|
||||
Initialize retry strategy.
|
||||
|
||||
Args:
|
||||
max_attempts: Maximum number of retry attempts
|
||||
initial_delay: Initial delay in seconds
|
||||
max_delay: Maximum delay between retries
|
||||
exponential_base: Base for exponential backoff
|
||||
jitter: Add random jitter to prevent thundering herd
|
||||
retriable_exceptions: Tuple of exception types to retry
|
||||
"""
|
||||
self.max_attempts = max_attempts
|
||||
self.initial_delay = initial_delay
|
||||
self.max_delay = max_delay
|
||||
self.exponential_base = exponential_base
|
||||
self.jitter = jitter
|
||||
self.retriable_exceptions = retriable_exceptions
|
||||
|
||||
def calculate_delay(self, attempt: int) -> float:
|
||||
"""Calculate delay for given attempt using exponential backoff"""
|
||||
delay = min(
|
||||
self.initial_delay * (self.exponential_base ** attempt),
|
||||
self.max_delay
|
||||
)
|
||||
|
||||
if self.jitter:
|
||||
# Add random jitter (0-100% of delay)
|
||||
delay = delay * (0.5 + random.random() * 0.5)
|
||||
|
||||
return delay
|
||||
|
||||
def is_retriable(self, exception: Exception) -> bool:
|
||||
"""Check if exception should trigger retry"""
|
||||
return isinstance(exception, self.retriable_exceptions)
|
||||
|
||||
|
||||
async def retry_async(
|
||||
func: Callable,
|
||||
*args,
|
||||
strategy: Optional[RetryStrategy] = None,
|
||||
**kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Retry async function with exponential backoff.
|
||||
|
||||
Args:
|
||||
func: Async function to retry
|
||||
*args: Positional arguments for func
|
||||
strategy: Retry strategy (uses default if None)
|
||||
**kwargs: Keyword arguments for func
|
||||
|
||||
Returns:
|
||||
Result from func
|
||||
|
||||
Raises:
|
||||
RetryError: When all attempts exhausted
|
||||
"""
|
||||
if strategy is None:
|
||||
strategy = RetryStrategy()
|
||||
|
||||
last_exception = None
|
||||
|
||||
    for attempt in range(strategy.max_attempts):
        try:
            result = await func(*args, **kwargs)

            if attempt > 0:
                # NOTE: this module uses the stdlib logger, which does not accept
                # structlog-style keyword fields; format the context into the message.
                logger.info(
                    "Retry of %s succeeded on attempt %d",
                    func.__name__,
                    attempt + 1,
                )

            return result

        except Exception as e:
            last_exception = e

            if not strategy.is_retriable(e):
                logger.error(
                    "Non-retriable exception in %s: %s",
                    func.__name__,
                    e,
                )
                raise

            if attempt < strategy.max_attempts - 1:
                delay = strategy.calculate_delay(attempt)
                logger.warning(
                    "Attempt %d/%d of %s failed (%s), retrying in %.2fs",
                    attempt + 1,
                    strategy.max_attempts,
                    func.__name__,
                    e,
                    delay,
                )
                await asyncio.sleep(delay)
            else:
                logger.error(
                    "All %d retry attempts of %s exhausted: %s",
                    strategy.max_attempts,
                    func.__name__,
                    e,
                )

    raise RetryError(
        f"Failed after {strategy.max_attempts} attempts: {str(last_exception)}",
        attempts=strategy.max_attempts,
        last_exception=last_exception
    )
|
||||
|
||||
|
||||
def with_retry(
|
||||
max_attempts: int = 3,
|
||||
initial_delay: float = 1.0,
|
||||
max_delay: float = 60.0,
|
||||
exponential_base: float = 2.0,
|
||||
jitter: bool = True,
|
||||
retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
|
||||
):
|
||||
"""
|
||||
Decorator to add retry logic to async functions.
|
||||
|
||||
Example:
|
||||
@with_retry(max_attempts=5, initial_delay=2.0)
|
||||
async def fetch_data():
|
||||
# Your code here
|
||||
pass
|
||||
"""
|
||||
strategy = RetryStrategy(
|
||||
max_attempts=max_attempts,
|
||||
initial_delay=initial_delay,
|
||||
max_delay=max_delay,
|
||||
exponential_base=exponential_base,
|
||||
jitter=jitter,
|
||||
retriable_exceptions=retriable_exceptions
|
||||
)
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await retry_async(func, *args, strategy=strategy, **kwargs)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class AdaptiveRetryStrategy(RetryStrategy):
|
||||
"""
|
||||
Adaptive retry strategy that adjusts based on success/failure patterns.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.success_count = 0
|
||||
self.failure_count = 0
|
||||
self.consecutive_failures = 0
|
||||
|
||||
def calculate_delay(self, attempt: int) -> float:
|
||||
"""Calculate delay with adaptation based on recent history"""
|
||||
base_delay = super().calculate_delay(attempt)
|
||||
|
||||
# Increase delay if seeing consecutive failures
|
||||
if self.consecutive_failures > 5:
|
||||
multiplier = min(2.0, 1.0 + (self.consecutive_failures - 5) * 0.2)
|
||||
base_delay *= multiplier
|
||||
|
||||
return min(base_delay, self.max_delay)
|
||||
|
||||
def record_success(self):
|
||||
"""Record successful attempt"""
|
||||
self.success_count += 1
|
||||
self.consecutive_failures = 0
|
||||
|
||||
def record_failure(self):
|
||||
"""Record failed attempt"""
|
||||
self.failure_count += 1
|
||||
self.consecutive_failures += 1
|
||||
|
||||
|
||||
class TimeoutRetryStrategy(RetryStrategy):
|
||||
"""
|
||||
Retry strategy with overall timeout across all attempts.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, timeout: float = 300.0, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
timeout: Total timeout in seconds for all attempts
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.timeout = timeout
|
||||
self.start_time: Optional[float] = None
|
||||
|
||||
def should_retry(self, attempt: int) -> bool:
|
||||
"""Check if should attempt another retry"""
|
||||
if self.start_time is None:
|
||||
self.start_time = time.time()
|
||||
return True
|
||||
|
||||
elapsed = time.time() - self.start_time
|
||||
return elapsed < self.timeout and attempt < self.max_attempts
|
||||
|
||||
|
||||
async def retry_with_timeout(
|
||||
func: Callable,
|
||||
*args,
|
||||
max_attempts: int = 3,
|
||||
timeout: float = 300.0,
|
||||
**kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Retry with overall timeout.
|
||||
|
||||
Args:
|
||||
func: Function to retry
|
||||
max_attempts: Maximum attempts
|
||||
timeout: Overall timeout in seconds
|
||||
|
||||
Returns:
|
||||
Result from func
|
||||
"""
|
||||
strategy = TimeoutRetryStrategy(
|
||||
max_attempts=max_attempts,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
strategy.start_time = start_time
|
||||
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(strategy.max_attempts):
|
||||
if time.time() - start_time >= timeout:
|
||||
raise RetryError(
|
||||
f"Timeout of {timeout}s exceeded",
|
||||
attempts=attempt + 1,
|
||||
last_exception=last_exception
|
||||
)
|
||||
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
if not strategy.is_retriable(e):
|
||||
raise
|
||||
|
||||
if attempt < strategy.max_attempts - 1:
|
||||
delay = strategy.calculate_delay(attempt)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
raise RetryError(
|
||||
f"Failed after {strategy.max_attempts} attempts",
|
||||
attempts=strategy.max_attempts,
|
||||
last_exception=last_exception
|
||||
)
|
||||
|
||||
|
||||
# Pre-configured strategies for common use cases
|
||||
HTTP_RETRY_STRATEGY = RetryStrategy(
|
||||
max_attempts=3,
|
||||
initial_delay=1.0,
|
||||
max_delay=10.0,
|
||||
exponential_base=2.0,
|
||||
jitter=True
|
||||
)
|
||||
|
||||
DATABASE_RETRY_STRATEGY = RetryStrategy(
|
||||
max_attempts=5,
|
||||
initial_delay=0.5,
|
||||
max_delay=5.0,
|
||||
exponential_base=1.5,
|
||||
jitter=True
|
||||
)
|
||||
|
||||
EXTERNAL_SERVICE_RETRY_STRATEGY = RetryStrategy(
|
||||
max_attempts=4,
|
||||
initial_delay=2.0,
|
||||
max_delay=30.0,
|
||||
exponential_base=2.5,
|
||||
jitter=True
|
||||
)
|
||||
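A minimal sketch of how these pieces fit together, assuming `retry_async`, `with_retry`, and `HTTP_RETRY_STRATEGY` from this module are in scope; `flaky_operation` is a made-up stand-in for a real network call:

```python
import asyncio
import random


# Hypothetical flaky coroutine used only to exercise the retry helpers.
async def flaky_operation() -> str:
    if random.random() < 0.5:
        raise ConnectionError("transient failure")
    return "ok"


async def main() -> None:
    # Explicit strategy: reuse the pre-configured HTTP profile defined above.
    # RetryError is still raised if every attempt fails.
    result = await retry_async(flaky_operation, strategy=HTTP_RETRY_STRATEGY)
    print("retry_async:", result)

    # Decorator form: parameters mirror RetryStrategy's constructor.
    @with_retry(max_attempts=5, initial_delay=0.5, retriable_exceptions=(ConnectionError,))
    async def decorated() -> str:
        return await flaky_operation()

    print("with_retry:", await decorated())


if __name__ == "__main__":
    asyncio.run(main())
```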
340
services/training/app/utils/time_estimation.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
Training Time Estimation Utilities
|
||||
Provides intelligent time estimation for training jobs based on:
|
||||
- Product count
|
||||
- Historical performance data
|
||||
- Current progress and throughput
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
def calculate_initial_estimate(
|
||||
total_products: int,
|
||||
avg_training_time_per_product: float = 60.0, # seconds, default 1 min/product
|
||||
data_analysis_overhead: float = 120.0, # seconds, data loading & analysis
|
||||
finalization_overhead: float = 60.0, # seconds, saving models & cleanup
|
||||
min_estimate_minutes: int = 5,
|
||||
max_estimate_minutes: int = 60
|
||||
) -> int:
|
||||
"""
|
||||
Calculate realistic initial time estimate for training job.
|
||||
|
||||
Formula:
|
||||
total_time = data_analysis + (products * avg_time_per_product) + finalization
|
||||
|
||||
Args:
|
||||
total_products: Number of products to train
|
||||
avg_training_time_per_product: Average time per product in seconds
|
||||
data_analysis_overhead: Time for data loading and analysis in seconds
|
||||
finalization_overhead: Time for saving models and cleanup in seconds
|
||||
min_estimate_minutes: Minimum estimate (prevents unrealistic low values)
|
||||
max_estimate_minutes: Maximum estimate (prevents unrealistic high values)
|
||||
|
||||
Returns:
|
||||
Estimated duration in minutes
|
||||
|
||||
Examples:
|
||||
        >>> calculate_initial_estimate(1)
        5  # 120 + 60 + 60 = 240s = 4min, raised to the 5-minute minimum
|
||||
|
||||
>>> calculate_initial_estimate(5)
|
||||
8 # 120 + 300 + 60 = 480s = 8min
|
||||
|
||||
>>> calculate_initial_estimate(10)
|
||||
13 # 120 + 600 + 60 = 780s = 13min
|
||||
|
||||
>>> calculate_initial_estimate(20)
|
||||
23 # 120 + 1200 + 60 = 1380s = 23min
|
||||
|
||||
>>> calculate_initial_estimate(100)
|
||||
60 # Capped at max (would be 103 min)
|
||||
"""
|
||||
# Calculate total estimated time in seconds
|
||||
estimated_seconds = (
|
||||
data_analysis_overhead +
|
||||
(total_products * avg_training_time_per_product) +
|
||||
finalization_overhead
|
||||
)
|
||||
|
||||
    # Convert to minutes, rounding to the nearest whole minute
    estimated_minutes = int((estimated_seconds / 60) + 0.5)
|
||||
|
||||
# Apply min/max bounds
|
||||
estimated_minutes = max(min_estimate_minutes, min(max_estimate_minutes, estimated_minutes))
|
||||
|
||||
logger.info(
|
||||
"Calculated initial time estimate",
|
||||
total_products=total_products,
|
||||
estimated_seconds=estimated_seconds,
|
||||
estimated_minutes=estimated_minutes,
|
||||
avg_time_per_product=avg_training_time_per_product
|
||||
)
|
||||
|
||||
return estimated_minutes
|
||||
|
||||
|
||||
def calculate_estimated_completion_time(
|
||||
estimated_duration_minutes: int,
|
||||
start_time: Optional[datetime] = None
|
||||
) -> datetime:
|
||||
"""
|
||||
Calculate estimated completion timestamp.
|
||||
|
||||
Args:
|
||||
estimated_duration_minutes: Estimated duration in minutes
|
||||
start_time: Job start time (defaults to now)
|
||||
|
||||
Returns:
|
||||
Estimated completion datetime (timezone-aware UTC)
|
||||
"""
|
||||
if start_time is None:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
completion_time = start_time + timedelta(minutes=estimated_duration_minutes)
|
||||
|
||||
return completion_time
|
||||
|
||||
|
||||
def calculate_remaining_time_smart(
|
||||
progress: int,
|
||||
elapsed_time: float,
|
||||
products_completed: int,
|
||||
total_products: int,
|
||||
recent_product_times: Optional[List[float]] = None,
|
||||
max_remaining_seconds: int = 1800 # 30 minutes
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
Calculate remaining time using smart algorithm that considers:
|
||||
- Current progress percentage
|
||||
- Actual throughput (products completed / elapsed time)
|
||||
- Recent performance (weighted moving average)
|
||||
|
||||
Args:
|
||||
progress: Current progress percentage (0-100)
|
||||
elapsed_time: Time elapsed since job start (seconds)
|
||||
products_completed: Number of products completed
|
||||
total_products: Total number of products
|
||||
recent_product_times: List of recent product training times (seconds)
|
||||
max_remaining_seconds: Maximum remaining time (safety cap)
|
||||
|
||||
Returns:
|
||||
Estimated remaining time in seconds, or None if can't calculate
|
||||
"""
|
||||
# Job completed or not started
|
||||
if progress >= 100 or progress <= 0:
|
||||
return None
|
||||
|
||||
# Early stage (0-20%): Use weighted estimate
|
||||
if progress <= 20:
|
||||
# In data analysis phase - estimate based on remaining products
|
||||
remaining_products = total_products - products_completed
|
||||
|
||||
if recent_product_times and len(recent_product_times) > 0:
|
||||
# Use recent performance if available
|
||||
avg_time_per_product = sum(recent_product_times) / len(recent_product_times)
|
||||
else:
|
||||
# Fallback to default
|
||||
avg_time_per_product = 60.0 # 1 minute per product
|
||||
|
||||
# Estimate: remaining products * avg time + overhead
|
||||
estimated_remaining = (remaining_products * avg_time_per_product) + 60.0 # +1 min overhead
|
||||
|
||||
logger.debug(
|
||||
"Early stage estimation",
|
||||
progress=progress,
|
||||
remaining_products=remaining_products,
|
||||
avg_time_per_product=avg_time_per_product,
|
||||
estimated_remaining=estimated_remaining
|
||||
)
|
||||
|
||||
# Mid/late stage (21-99%): Use actual throughput
|
||||
else:
|
||||
if products_completed > 0:
|
||||
# Calculate actual time per product from current run
|
||||
actual_time_per_product = elapsed_time / products_completed
|
||||
remaining_products = total_products - products_completed
|
||||
estimated_remaining = remaining_products * actual_time_per_product
|
||||
|
||||
logger.debug(
|
||||
"Mid/late stage estimation",
|
||||
progress=progress,
|
||||
products_completed=products_completed,
|
||||
total_products=total_products,
|
||||
actual_time_per_product=actual_time_per_product,
|
||||
estimated_remaining=estimated_remaining
|
||||
)
|
||||
else:
|
||||
# Fallback to linear extrapolation
|
||||
estimated_total = (elapsed_time / progress) * 100
|
||||
estimated_remaining = estimated_total - elapsed_time
|
||||
|
||||
logger.debug(
|
||||
"Fallback linear estimation",
|
||||
progress=progress,
|
||||
elapsed_time=elapsed_time,
|
||||
estimated_remaining=estimated_remaining
|
||||
)
|
||||
|
||||
# Apply safety cap
|
||||
estimated_remaining = min(estimated_remaining, max_remaining_seconds)
|
||||
|
||||
return int(estimated_remaining)
|
||||
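To make the two regimes concrete, a small worked sketch with hypothetical job numbers (assuming `calculate_remaining_time_smart` is in scope):

```python
# Early stage: 10% progress, 1 of 20 products done, recent per-product times known.
early = calculate_remaining_time_smart(
    progress=10,
    elapsed_time=90.0,
    products_completed=1,
    total_products=20,
    recent_product_times=[55.0, 65.0],  # average 60s per product
)
# 19 remaining products * 60s + 60s overhead = 1200s (under the 1800s cap)
print(early)  # -> 1200

# Mid stage: 50% progress, 10 of 20 products done in 600s -> 60s/product observed.
mid = calculate_remaining_time_smart(
    progress=50,
    elapsed_time=600.0,
    products_completed=10,
    total_products=20,
)
# 10 remaining products * 60s = 600s
print(mid)  # -> 600
```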
|
||||
|
||||
def calculate_average_product_time(
|
||||
products_completed: int,
|
||||
elapsed_time: float,
|
||||
min_products_threshold: int = 3
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Calculate average time per product from current job progress.
|
||||
|
||||
Args:
|
||||
products_completed: Number of products completed
|
||||
elapsed_time: Time elapsed since job start (seconds)
|
||||
min_products_threshold: Minimum products needed for reliable calculation
|
||||
|
||||
Returns:
|
||||
Average time per product in seconds, or None if insufficient data
|
||||
"""
|
||||
if products_completed < min_products_threshold:
|
||||
return None
|
||||
|
||||
avg_time = elapsed_time / products_completed
|
||||
|
||||
logger.debug(
|
||||
"Calculated average product time",
|
||||
products_completed=products_completed,
|
||||
elapsed_time=elapsed_time,
|
||||
avg_time=avg_time
|
||||
)
|
||||
|
||||
return avg_time
|
||||
|
||||
|
||||
def format_time_remaining(seconds: int) -> str:
|
||||
"""
|
||||
Format remaining time in human-readable format.
|
||||
|
||||
Args:
|
||||
seconds: Time in seconds
|
||||
|
||||
Returns:
|
||||
Formatted string (e.g., "5 minutes", "1 hour 23 minutes")
|
||||
|
||||
Examples:
|
||||
>>> format_time_remaining(45)
|
||||
"45 seconds"
|
||||
|
||||
>>> format_time_remaining(180)
|
||||
"3 minutes"
|
||||
|
||||
>>> format_time_remaining(5400)
|
||||
"1 hour 30 minutes"
|
||||
"""
|
||||
if seconds < 60:
|
||||
return f"{seconds} seconds"
|
||||
|
||||
minutes = seconds // 60
|
||||
remaining_seconds = seconds % 60
|
||||
|
||||
if minutes < 60:
|
||||
if remaining_seconds > 0:
|
||||
return f"{minutes} minutes {remaining_seconds} seconds"
|
||||
return f"{minutes} minutes"
|
||||
|
||||
hours = minutes // 60
|
||||
remaining_minutes = minutes % 60
|
||||
|
||||
if remaining_minutes > 0:
|
||||
return f"{hours} hour{'s' if hours > 1 else ''} {remaining_minutes} minutes"
|
||||
return f"{hours} hour{'s' if hours > 1 else ''}"
|
||||
|
||||
|
||||
async def get_historical_average_estimate(
|
||||
db_session: AsyncSession,
|
||||
tenant_id: str,
|
||||
lookback_days: int = 30,
|
||||
limit: int = 10
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Get historical average training time per product for a tenant.
|
||||
|
||||
This function queries the TrainingPerformanceMetrics table to get
|
||||
recent historical data and calculate an average.
|
||||
|
||||
Args:
|
||||
db_session: Async database session
|
||||
tenant_id: Tenant UUID
|
||||
lookback_days: How many days back to look
|
||||
limit: Maximum number of historical records to consider
|
||||
|
||||
Returns:
|
||||
Average time per product in seconds, or None if no historical data
|
||||
"""
|
||||
try:
|
||||
from app.models.training import TrainingPerformanceMetrics
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=lookback_days)
|
||||
|
||||
# Query recent training performance metrics using SQLAlchemy 2.0 async pattern
|
||||
query = (
|
||||
select(TrainingPerformanceMetrics)
|
||||
.where(
|
||||
TrainingPerformanceMetrics.tenant_id == tenant_id,
|
||||
TrainingPerformanceMetrics.completed_at >= cutoff
|
||||
)
|
||||
.order_by(TrainingPerformanceMetrics.completed_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
metrics = result.scalars().all()
|
||||
|
||||
if not metrics:
|
||||
logger.info(
|
||||
"No historical training data found",
|
||||
tenant_id=tenant_id,
|
||||
lookback_days=lookback_days
|
||||
)
|
||||
return None
|
||||
|
||||
# Calculate weighted average (more recent = higher weight)
|
||||
total_weight = 0
|
||||
weighted_sum = 0
|
||||
|
||||
for i, metric in enumerate(metrics):
|
||||
# Weight: newer records get higher weight
|
||||
weight = limit - i
|
||||
weighted_sum += metric.avg_time_per_product * weight
|
||||
total_weight += weight
|
||||
|
||||
if total_weight == 0:
|
||||
return None
|
||||
|
||||
weighted_avg = weighted_sum / total_weight
|
||||
|
||||
logger.info(
|
||||
"Calculated historical average",
|
||||
tenant_id=tenant_id,
|
||||
records_used=len(metrics),
|
||||
weighted_avg=weighted_avg
|
||||
)
|
||||
|
||||
return weighted_avg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error getting historical average",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
return None
|
||||
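The weighting gives the most recent record a weight of `limit`, the next `limit - 1`, and so on. A standalone sketch of the same calculation on hypothetical per-product averages:

```python
# Most recent first, matching the DESC ordering of the query above.
recent_avg_times = [58.0, 62.0, 75.0]  # seconds per product, hypothetical
limit = 10

weighted_sum = sum(t * (limit - i) for i, t in enumerate(recent_avg_times))
total_weight = sum(limit - i for i in range(len(recent_avg_times)))

weighted_avg = weighted_sum / total_weight
# (58*10 + 62*9 + 75*8) / (10 + 9 + 8) = 1738 / 27 ≈ 64.4 seconds per product
print(round(weighted_avg, 1))
```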
11
services/training/app/websocket/__init__.py
Normal file
@@ -0,0 +1,11 @@
"""WebSocket support for training service"""

from app.websocket.manager import websocket_manager, WebSocketConnectionManager
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers

__all__ = [
    'websocket_manager',
    'WebSocketConnectionManager',
    'setup_websocket_event_consumer',
    'cleanup_websocket_consumers'
]
148
services/training/app/websocket/events.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
RabbitMQ Event Consumer for WebSocket Broadcasting
|
||||
Listens to training events from RabbitMQ and broadcasts them to WebSocket clients
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, Set
|
||||
import structlog
|
||||
|
||||
from app.websocket.manager import websocket_manager
|
||||
from app.services.training_events import training_publisher
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Track active consumers
|
||||
_active_consumers: Set[asyncio.Task] = set()
|
||||
|
||||
|
||||
async def handle_training_event(message) -> None:
|
||||
"""
|
||||
Handle incoming RabbitMQ training events and broadcast to WebSocket clients.
|
||||
This is the bridge between RabbitMQ and WebSocket.
|
||||
"""
|
||||
try:
|
||||
# Parse message
|
||||
body = message.body.decode()
|
||||
data = json.loads(body)
|
||||
|
||||
event_type = data.get('event_type', 'unknown')
|
||||
event_data = data.get('data', {})
|
||||
job_id = event_data.get('job_id')
|
||||
|
||||
if not job_id:
|
||||
logger.warning("Received event without job_id, skipping", event_type=event_type)
|
||||
await message.ack()
|
||||
return
|
||||
|
||||
logger.info("Received training event from RabbitMQ",
|
||||
job_id=job_id,
|
||||
event_type=event_type,
|
||||
progress=event_data.get('progress'))
|
||||
|
||||
# Map RabbitMQ event types to WebSocket message types
|
||||
ws_message_type = _map_event_type(event_type)
|
||||
|
||||
# Create WebSocket message
|
||||
ws_message = {
|
||||
"type": ws_message_type,
|
||||
"job_id": job_id,
|
||||
"timestamp": data.get('timestamp'),
|
||||
"data": event_data
|
||||
}
|
||||
|
||||
# Broadcast to all WebSocket clients for this job
|
||||
sent_count = await websocket_manager.broadcast(job_id, ws_message)
|
||||
|
||||
logger.info("Broadcasted event to WebSocket clients",
|
||||
job_id=job_id,
|
||||
event_type=event_type,
|
||||
ws_message_type=ws_message_type,
|
||||
clients_notified=sent_count)
|
||||
|
||||
# Always acknowledge the message to avoid infinite redelivery loops
|
||||
# Progress events (started, progress, product_completed) are ephemeral and don't need redelivery
|
||||
# Final events (completed, failed) should always be acknowledged
|
||||
await message.ack()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error handling training event",
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
# Always acknowledge even on error to avoid infinite redelivery loops
|
||||
# The event is logged so we can debug issues
|
||||
        try:
            await message.ack()
        except Exception:
            pass  # Message already gone or connection closed
|
||||
|
||||
|
||||
def _map_event_type(rabbitmq_event_type: str) -> str:
|
||||
"""Map RabbitMQ event types to WebSocket message types"""
|
||||
mapping = {
|
||||
"training.started": "started",
|
||||
"training.progress": "progress",
|
||||
"training.step.completed": "step_completed",
|
||||
"training.product.completed": "product_completed",
|
||||
"training.completed": "completed",
|
||||
"training.failed": "failed",
|
||||
}
|
||||
return mapping.get(rabbitmq_event_type, "unknown")
|
||||
|
||||
|
||||
async def setup_websocket_event_consumer() -> bool:
|
||||
"""
|
||||
Set up a global RabbitMQ consumer that listens to all training events
|
||||
and broadcasts them to connected WebSocket clients.
|
||||
"""
|
||||
try:
|
||||
# Ensure publisher is connected
|
||||
if not training_publisher.connected:
|
||||
logger.info("Connecting training publisher for WebSocket event consumer")
|
||||
success = await training_publisher.connect()
|
||||
if not success:
|
||||
logger.error("Failed to connect training publisher")
|
||||
return False
|
||||
|
||||
# Create a unique queue for WebSocket broadcasting
|
||||
queue_name = "training_websocket_broadcast"
|
||||
|
||||
logger.info("Setting up WebSocket event consumer", queue_name=queue_name)
|
||||
|
||||
# Subscribe to all training events (routing key: training.#)
|
||||
success = await training_publisher.consume_events(
|
||||
exchange_name="training.events",
|
||||
queue_name=queue_name,
|
||||
routing_key="training.#", # Listen to all training events (multi-level)
|
||||
callback=handle_training_event
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("WebSocket event consumer set up successfully")
|
||||
return True
|
||||
else:
|
||||
logger.error("Failed to set up WebSocket event consumer")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error setting up WebSocket event consumer",
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
async def cleanup_websocket_consumers() -> None:
|
||||
"""Clean up WebSocket event consumers"""
|
||||
logger.info("Cleaning up WebSocket event consumers")
|
||||
|
||||
for task in _active_consumers:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
_active_consumers.clear()
|
||||
logger.info("WebSocket event consumers cleaned up")
|
||||
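The consumer above only bridges RabbitMQ to the connection manager; clients still need a WebSocket route that registers with it. A minimal FastAPI endpoint sketch (the route path and router wiring are assumptions, not part of this file):

```python
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from app.websocket.manager import websocket_manager

router = APIRouter()


@router.websocket("/ws/training/{job_id}")  # hypothetical path
async def training_progress_ws(websocket: WebSocket, job_id: str):
    # Register the connection; the manager accepts the socket and replays the
    # latest known event for this job as an "initial_state" message.
    await websocket_manager.connect(job_id, websocket)
    try:
        # Events are pushed by the RabbitMQ consumer via broadcast();
        # we only read to detect client disconnects (payloads are ignored).
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        await websocket_manager.disconnect(job_id, websocket)
```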
300
services/training/app/websocket/manager.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
WebSocket Connection Manager for Training Service
|
||||
Manages WebSocket connections and broadcasts RabbitMQ events to connected clients
|
||||
|
||||
HORIZONTAL SCALING:
|
||||
- Uses Redis pub/sub for cross-pod WebSocket broadcasting
|
||||
- Each pod subscribes to a Redis channel and broadcasts to its local connections
|
||||
- Events published to Redis are received by all pods, ensuring clients on any
|
||||
pod receive events from training jobs running on any other pod
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
from fastapi import WebSocket
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Redis pub/sub channel for WebSocket events
|
||||
REDIS_WEBSOCKET_CHANNEL = "training:websocket:events"
|
||||
|
||||
|
||||
class WebSocketConnectionManager:
|
||||
"""
|
||||
WebSocket connection manager with Redis pub/sub for horizontal scaling.
|
||||
|
||||
In a multi-pod deployment:
|
||||
1. Events are published to Redis pub/sub (not just local broadcast)
|
||||
2. Each pod subscribes to Redis and broadcasts to its local WebSocket connections
|
||||
3. This ensures clients connected to any pod receive events from any pod
|
||||
|
||||
Flow:
|
||||
- RabbitMQ event → Pod A receives → Pod A publishes to Redis
|
||||
- Redis pub/sub → All pods receive → Each pod broadcasts to local WebSockets
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Structure: {job_id: {websocket_id: WebSocket}}
|
||||
self._connections: Dict[str, Dict[int, WebSocket]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
# Store latest event for each job to provide initial state
|
||||
self._latest_events: Dict[str, dict] = {}
|
||||
# Redis client for pub/sub
|
||||
self._redis: Optional[object] = None
|
||||
self._pubsub: Optional[object] = None
|
||||
self._subscriber_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
self._instance_id = f"{os.environ.get('HOSTNAME', 'unknown')}:{os.getpid()}"
|
||||
|
||||
async def initialize_redis(self, redis_url: str) -> bool:
|
||||
"""
|
||||
Initialize Redis connection for cross-pod pub/sub.
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
import redis.asyncio as redis_async
|
||||
|
||||
self._redis = redis_async.from_url(redis_url, decode_responses=True)
|
||||
await self._redis.ping()
|
||||
|
||||
# Create pub/sub subscriber
|
||||
self._pubsub = self._redis.pubsub()
|
||||
await self._pubsub.subscribe(REDIS_WEBSOCKET_CHANNEL)
|
||||
|
||||
# Start subscriber task
|
||||
self._running = True
|
||||
self._subscriber_task = asyncio.create_task(self._redis_subscriber_loop())
|
||||
|
||||
logger.info("Redis pub/sub initialized for WebSocket broadcasting",
|
||||
instance_id=self._instance_id,
|
||||
channel=REDIS_WEBSOCKET_CHANNEL)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize Redis pub/sub",
|
||||
error=str(e),
|
||||
instance_id=self._instance_id)
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown Redis pub/sub connection"""
|
||||
self._running = False
|
||||
|
||||
if self._subscriber_task:
|
||||
self._subscriber_task.cancel()
|
||||
try:
|
||||
await self._subscriber_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._pubsub:
|
||||
await self._pubsub.unsubscribe(REDIS_WEBSOCKET_CHANNEL)
|
||||
await self._pubsub.close()
|
||||
|
||||
if self._redis:
|
||||
await self._redis.close()
|
||||
|
||||
logger.info("Redis pub/sub shutdown complete",
|
||||
instance_id=self._instance_id)
|
||||
|
||||
async def _redis_subscriber_loop(self):
|
||||
"""Background task to receive Redis pub/sub messages and broadcast locally"""
|
||||
try:
|
||||
while self._running:
|
||||
try:
|
||||
message = await self._pubsub.get_message(
|
||||
ignore_subscribe_messages=True,
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
if message and message['type'] == 'message':
|
||||
await self._handle_redis_message(message['data'])
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error in Redis subscriber loop",
|
||||
error=str(e),
|
||||
instance_id=self._instance_id)
|
||||
await asyncio.sleep(1) # Backoff on error
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("Redis subscriber loop stopped",
|
||||
instance_id=self._instance_id)
|
||||
|
||||
async def _handle_redis_message(self, data: str):
|
||||
"""Handle a message received from Redis pub/sub"""
|
||||
try:
|
||||
payload = json.loads(data)
|
||||
job_id = payload.get('job_id')
|
||||
message = payload.get('message')
|
||||
source_instance = payload.get('source_instance')
|
||||
|
||||
if not job_id or not message:
|
||||
return
|
||||
|
||||
# Log cross-pod message
|
||||
if source_instance != self._instance_id:
|
||||
logger.debug("Received cross-pod WebSocket event",
|
||||
job_id=job_id,
|
||||
source_instance=source_instance,
|
||||
local_instance=self._instance_id)
|
||||
|
||||
# Broadcast to local WebSocket connections
|
||||
await self._broadcast_local(job_id, message)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning("Invalid JSON in Redis message", error=str(e))
|
||||
except Exception as e:
|
||||
logger.error("Error handling Redis message", error=str(e))
|
||||
|
||||
async def connect(self, job_id: str, websocket: WebSocket) -> None:
|
||||
"""Register a new WebSocket connection for a job"""
|
||||
await websocket.accept()
|
||||
|
||||
async with self._lock:
|
||||
if job_id not in self._connections:
|
||||
self._connections[job_id] = {}
|
||||
|
||||
ws_id = id(websocket)
|
||||
self._connections[job_id][ws_id] = websocket
|
||||
|
||||
# Send initial state if available
|
||||
if job_id in self._latest_events:
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "initial_state",
|
||||
"job_id": job_id,
|
||||
"data": self._latest_events[job_id]
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning("Failed to send initial state to new connection", error=str(e))
|
||||
|
||||
logger.info("WebSocket connected",
|
||||
job_id=job_id,
|
||||
websocket_id=ws_id,
|
||||
total_connections=len(self._connections[job_id]),
|
||||
instance_id=self._instance_id)
|
||||
|
||||
async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
|
||||
"""Remove a WebSocket connection"""
|
||||
async with self._lock:
|
||||
if job_id in self._connections:
|
||||
ws_id = id(websocket)
|
||||
self._connections[job_id].pop(ws_id, None)
|
||||
|
||||
# Clean up empty job connections
|
||||
if not self._connections[job_id]:
|
||||
del self._connections[job_id]
|
||||
|
||||
logger.info("WebSocket disconnected",
|
||||
job_id=job_id,
|
||||
websocket_id=ws_id,
|
||||
remaining_connections=len(self._connections.get(job_id, {})),
|
||||
instance_id=self._instance_id)
|
||||
|
||||
async def broadcast(self, job_id: str, message: dict) -> int:
|
||||
"""
|
||||
Broadcast a message to all connections for a specific job across ALL pods.
|
||||
|
||||
If Redis is configured, publishes to Redis pub/sub which then broadcasts
|
||||
to all pods. Otherwise, falls back to local-only broadcast.
|
||||
|
||||
Returns the number of successful local broadcasts.
|
||||
"""
|
||||
# Store the latest event for this job to provide initial state to new connections
|
||||
if message.get('type') != 'initial_state':
|
||||
self._latest_events[job_id] = message
|
||||
|
||||
# If Redis is available, publish to Redis for cross-pod broadcast
|
||||
if self._redis:
|
||||
try:
|
||||
payload = json.dumps({
|
||||
'job_id': job_id,
|
||||
'message': message,
|
||||
'source_instance': self._instance_id
|
||||
})
|
||||
await self._redis.publish(REDIS_WEBSOCKET_CHANNEL, payload)
|
||||
logger.debug("Published WebSocket event to Redis",
|
||||
job_id=job_id,
|
||||
message_type=message.get('type'),
|
||||
instance_id=self._instance_id)
|
||||
# Return 0 here because the actual broadcast happens via subscriber
|
||||
# The count will be from _broadcast_local when the message is received
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish to Redis, falling back to local broadcast",
|
||||
error=str(e),
|
||||
job_id=job_id)
|
||||
# Fall through to local broadcast
|
||||
|
||||
# Local-only broadcast (when Redis is not available)
|
||||
return await self._broadcast_local(job_id, message)
|
||||
|
||||
async def _broadcast_local(self, job_id: str, message: dict) -> int:
|
||||
"""
|
||||
Broadcast a message to local WebSocket connections only.
|
||||
This is called either directly (no Redis) or from Redis subscriber.
|
||||
"""
|
||||
if job_id not in self._connections:
|
||||
logger.debug("No active local connections for job",
|
||||
job_id=job_id,
|
||||
instance_id=self._instance_id)
|
||||
return 0
|
||||
|
||||
connections = list(self._connections[job_id].values())
|
||||
successful_sends = 0
|
||||
failed_websockets = []
|
||||
|
||||
for websocket in connections:
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
successful_sends += 1
|
||||
except Exception as e:
|
||||
logger.warning("Failed to send message to WebSocket",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
failed_websockets.append(websocket)
|
||||
|
||||
# Clean up failed connections
|
||||
if failed_websockets:
|
||||
async with self._lock:
|
||||
for ws in failed_websockets:
|
||||
ws_id = id(ws)
|
||||
self._connections[job_id].pop(ws_id, None)
|
||||
|
||||
if successful_sends > 0:
|
||||
logger.info("Broadcasted message to local WebSocket clients",
|
||||
job_id=job_id,
|
||||
message_type=message.get('type'),
|
||||
successful_sends=successful_sends,
|
||||
failed_sends=len(failed_websockets),
|
||||
instance_id=self._instance_id)
|
||||
|
||||
return successful_sends
|
||||
|
||||
def get_connection_count(self, job_id: str) -> int:
|
||||
"""Get the number of active local connections for a job"""
|
||||
return len(self._connections.get(job_id, {}))
|
||||
|
||||
def get_total_connection_count(self) -> int:
|
||||
"""Get total number of active connections across all jobs"""
|
||||
return sum(len(conns) for conns in self._connections.values())
|
||||
|
||||
def is_redis_enabled(self) -> bool:
|
||||
"""Check if Redis pub/sub is enabled"""
|
||||
return self._redis is not None and self._running
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
websocket_manager = WebSocketConnectionManager()
|
||||
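For the cross-pod path to be active, the singleton needs a Redis URL at startup. A sketch of the application lifespan wiring, assuming a `REDIS_URL` environment variable (the actual startup hook is not shown in this file):

```python
import os
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.websocket.manager import websocket_manager
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers


@asynccontextmanager
async def lifespan(app: FastAPI):
    redis_url = os.getenv("REDIS_URL")  # assumption: e.g. redis://redis:6379/0
    if redis_url:
        # Without Redis the manager silently falls back to local-only broadcast.
        await websocket_manager.initialize_redis(redis_url)
    await setup_websocket_event_consumer()
    yield
    await cleanup_websocket_consumers()
    await websocket_manager.shutdown()


app = FastAPI(lifespan=lifespan)
```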
141
services/training/migrations/env.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Alembic environment configuration for training service"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
from alembic import context
|
||||
|
||||
# Add the service directory to the Python path
|
||||
service_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if service_path not in sys.path:
|
||||
sys.path.insert(0, service_path)
|
||||
|
||||
# Add shared modules to path
|
||||
shared_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "shared"))
|
||||
if shared_path not in sys.path:
|
||||
sys.path.insert(0, shared_path)
|
||||
|
||||
try:
|
||||
from app.core.config import settings
|
||||
from shared.database.base import Base
|
||||
|
||||
# Import all models to ensure they are registered with Base.metadata
|
||||
from app.models import * # noqa: F401, F403
|
||||
|
||||
except ImportError as e:
|
||||
print(f"Import error in migrations env.py: {e}")
|
||||
print(f"Current Python path: {sys.path}")
|
||||
raise
|
||||
|
||||
# this is the Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Determine service name from file path
|
||||
service_name = os.path.basename(os.path.dirname(os.path.dirname(__file__)))
|
||||
service_name_upper = service_name.upper().replace('-', '_')
|
||||
|
||||
# Set database URL from environment variables with multiple fallback strategies
|
||||
database_url = (
|
||||
os.getenv(f'{service_name_upper}_DATABASE_URL') or # Service-specific
|
||||
os.getenv('DATABASE_URL') # Generic fallback
|
||||
)
|
||||
|
||||
# If DATABASE_URL is not set, construct from individual components
|
||||
if not database_url:
|
||||
# Try generic PostgreSQL environment variables first
|
||||
postgres_host = os.getenv('POSTGRES_HOST')
|
||||
postgres_port = os.getenv('POSTGRES_PORT', '5432')
|
||||
postgres_db = os.getenv('POSTGRES_DB')
|
||||
postgres_user = os.getenv('POSTGRES_USER')
|
||||
postgres_password = os.getenv('POSTGRES_PASSWORD')
|
||||
|
||||
if all([postgres_host, postgres_db, postgres_user, postgres_password]):
|
||||
database_url = f"postgresql+asyncpg://{postgres_user}:{postgres_password}@{postgres_host}:{postgres_port}/{postgres_db}"
|
||||
else:
|
||||
# Try service-specific environment variables
|
||||
db_host = os.getenv(f'{service_name_upper}_DB_HOST', f'{service_name}-db-service')
|
||||
db_port = os.getenv(f'{service_name_upper}_DB_PORT', '5432')
|
||||
db_name = os.getenv(f'{service_name_upper}_DB_NAME', f'{service_name.replace("-", "_")}_db')
|
||||
db_user = os.getenv(f'{service_name_upper}_DB_USER', f'{service_name.replace("-", "_")}_user')
|
||||
db_password = os.getenv(f'{service_name_upper}_DB_PASSWORD')
|
||||
|
||||
if db_password:
|
||||
database_url = f"postgresql+asyncpg://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||
else:
|
||||
# Final fallback: try to get from settings object
|
||||
try:
|
||||
database_url = getattr(settings, 'DATABASE_URL', None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not database_url:
|
||||
error_msg = f"ERROR: No database URL configured for {service_name} service"
|
||||
print(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
config.set_main_option("sqlalchemy.url", database_url)
|
||||
|
||||
# Interpret the config file for Python logging
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Set target metadata
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
"""Execute migrations with the given connection."""
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""Run migrations in 'online' mode with async support."""
|
||||
connectable = async_engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
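Given the URL resolution order above, migrations can be run locally by exporting one of the recognized variables and invoking Alembic; a minimal sketch (credentials and host are placeholders, and it assumes an `alembic.ini` in `services/training/`):

```python
import os

from alembic import command
from alembic.config import Config

# Placeholder credentials; in the cluster these come from Kubernetes secrets.
os.environ.setdefault(
    "TRAINING_DATABASE_URL",
    "postgresql+asyncpg://training_user:changeme@localhost:5432/training_db",
)

# Assumes this is run from services/training/, where alembic.ini lives.
config = Config("alembic.ini")
command.upgrade(config, "head")
```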
26
services/training/migrations/script.py.mako
Normal file
@@ -0,0 +1,26 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
@@ -0,0 +1,250 @@
|
||||
"""Initial schema with all training tables and columns
|
||||
|
||||
Revision ID: 26a665cd5348
|
||||
Revises:
|
||||
Create Date: 2025-10-15 12:29:01.717552+02:00
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '26a665cd5348'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create audit_logs table
|
||||
op.create_table('audit_logs',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('tenant_id', sa.UUID(), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('action', sa.String(length=100), nullable=False),
|
||||
sa.Column('resource_type', sa.String(length=100), nullable=False),
|
||||
sa.Column('resource_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('severity', sa.String(length=20), nullable=False),
|
||||
sa.Column('service_name', sa.String(length=100), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('changes', postgresql.JSON(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('audit_metadata', postgresql.JSON(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.Text(), nullable=True),
|
||||
sa.Column('endpoint', sa.String(length=255), nullable=True),
|
||||
sa.Column('method', sa.String(length=10), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('idx_audit_resource_type_action', 'audit_logs', ['resource_type', 'action'], unique=False)
|
||||
op.create_index('idx_audit_service_created', 'audit_logs', ['service_name', 'created_at'], unique=False)
|
||||
op.create_index('idx_audit_severity_created', 'audit_logs', ['severity', 'created_at'], unique=False)
|
||||
op.create_index('idx_audit_tenant_created', 'audit_logs', ['tenant_id', 'created_at'], unique=False)
|
||||
op.create_index('idx_audit_user_created', 'audit_logs', ['user_id', 'created_at'], unique=False)
|
||||
op.create_index(op.f('ix_audit_logs_action'), 'audit_logs', ['action'], unique=False)
|
||||
op.create_index(op.f('ix_audit_logs_created_at'), 'audit_logs', ['created_at'], unique=False)
|
||||
op.create_index(op.f('ix_audit_logs_resource_id'), 'audit_logs', ['resource_id'], unique=False)
|
||||
op.create_index(op.f('ix_audit_logs_resource_type'), 'audit_logs', ['resource_type'], unique=False)
|
||||
op.create_index(op.f('ix_audit_logs_service_name'), 'audit_logs', ['service_name'], unique=False)
|
||||
op.create_index(op.f('ix_audit_logs_severity'), 'audit_logs', ['severity'], unique=False)
|
||||
op.create_index(op.f('ix_audit_logs_tenant_id'), 'audit_logs', ['tenant_id'], unique=False)
|
||||
op.create_index(op.f('ix_audit_logs_user_id'), 'audit_logs', ['user_id'], unique=False)
|
||||
|
||||
# Create trained_models table
|
||||
op.create_table('trained_models',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('tenant_id', sa.UUID(), nullable=False),
|
||||
sa.Column('inventory_product_id', sa.UUID(), nullable=False),
|
||||
sa.Column('model_type', sa.String(), nullable=True),
|
||||
sa.Column('model_version', sa.String(), nullable=True),
|
||||
sa.Column('job_id', sa.String(), nullable=False),
|
||||
sa.Column('model_path', sa.String(), nullable=False),
|
||||
sa.Column('metadata_path', sa.String(), nullable=True),
|
||||
sa.Column('mape', sa.Float(), nullable=True),
|
||||
sa.Column('mae', sa.Float(), nullable=True),
|
||||
sa.Column('rmse', sa.Float(), nullable=True),
|
||||
sa.Column('r2_score', sa.Float(), nullable=True),
|
||||
sa.Column('training_samples', sa.Integer(), nullable=True),
|
||||
sa.Column('hyperparameters', sa.JSON(), nullable=True),
|
||||
sa.Column('features_used', sa.JSON(), nullable=True),
|
||||
sa.Column('normalization_params', sa.JSON(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.Column('is_production', sa.Boolean(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('training_start_date', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('training_end_date', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('data_quality_score', sa.Float(), nullable=True),
|
||||
sa.Column('notes', sa.Text(), nullable=True),
|
||||
sa.Column('created_by', sa.String(), nullable=True),
|
||||
sa.Column('product_category', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_trained_models_inventory_product_id'), 'trained_models', ['inventory_product_id'], unique=False)
|
||||
op.create_index(op.f('ix_trained_models_tenant_id'), 'trained_models', ['tenant_id'], unique=False)
|
||||
|
||||
# Create model_training_logs table
|
||||
op.create_table('model_training_logs',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('job_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('tenant_id', sa.UUID(), nullable=False),
|
||||
sa.Column('status', sa.String(length=50), nullable=False),
|
||||
sa.Column('progress', sa.Integer(), nullable=True),
|
||||
sa.Column('current_step', sa.String(length=500), nullable=True),
|
||||
sa.Column('start_time', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('end_time', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('config', sa.JSON(), nullable=True),
|
||||
sa.Column('results', sa.JSON(), nullable=True),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_model_training_logs_id'), 'model_training_logs', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_model_training_logs_job_id'), 'model_training_logs', ['job_id'], unique=True)
|
||||
op.create_index(op.f('ix_model_training_logs_tenant_id'), 'model_training_logs', ['tenant_id'], unique=False)
|
||||
|
||||
# Create model_performance_metrics table
|
||||
op.create_table('model_performance_metrics',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('model_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('tenant_id', sa.UUID(), nullable=False),
|
||||
    sa.Column('inventory_product_id', sa.UUID(), nullable=False),
    sa.Column('mae', sa.Float(), nullable=True),
    sa.Column('mse', sa.Float(), nullable=True),
    sa.Column('rmse', sa.Float(), nullable=True),
    sa.Column('mape', sa.Float(), nullable=True),
    sa.Column('r2_score', sa.Float(), nullable=True),
    sa.Column('accuracy_percentage', sa.Float(), nullable=True),
    sa.Column('prediction_confidence', sa.Float(), nullable=True),
    sa.Column('evaluation_period_start', sa.DateTime(), nullable=True),
    sa.Column('evaluation_period_end', sa.DateTime(), nullable=True),
    sa.Column('evaluation_samples', sa.Integer(), nullable=True),
    sa.Column('measured_at', sa.DateTime(timezone=True), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_model_performance_metrics_id'), 'model_performance_metrics', ['id'], unique=False)
    op.create_index(op.f('ix_model_performance_metrics_inventory_product_id'), 'model_performance_metrics', ['inventory_product_id'], unique=False)
    op.create_index(op.f('ix_model_performance_metrics_model_id'), 'model_performance_metrics', ['model_id'], unique=False)
    op.create_index(op.f('ix_model_performance_metrics_tenant_id'), 'model_performance_metrics', ['tenant_id'], unique=False)

    # Create training_job_queue table
    op.create_table('training_job_queue',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('job_id', sa.String(length=255), nullable=False),
    sa.Column('tenant_id', sa.UUID(), nullable=False),
    sa.Column('job_type', sa.String(length=50), nullable=False),
    sa.Column('priority', sa.Integer(), nullable=True),
    sa.Column('config', sa.JSON(), nullable=True),
    sa.Column('scheduled_at', sa.DateTime(), nullable=True),
    sa.Column('started_at', sa.DateTime(), nullable=True),
    sa.Column('estimated_duration_minutes', sa.Integer(), nullable=True),
    sa.Column('status', sa.String(length=50), nullable=False),
    sa.Column('retry_count', sa.Integer(), nullable=True),
    sa.Column('max_retries', sa.Integer(), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True),
    sa.Column('cancelled_by', sa.String(), nullable=True),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_training_job_queue_id'), 'training_job_queue', ['id'], unique=False)
    op.create_index(op.f('ix_training_job_queue_job_id'), 'training_job_queue', ['job_id'], unique=True)
    op.create_index(op.f('ix_training_job_queue_tenant_id'), 'training_job_queue', ['tenant_id'], unique=False)

    # Create model_artifacts table
    op.create_table('model_artifacts',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('model_id', sa.String(length=255), nullable=False),
    sa.Column('tenant_id', sa.UUID(), nullable=False),
    sa.Column('artifact_type', sa.String(length=50), nullable=False),
    sa.Column('file_path', sa.String(length=1000), nullable=False),
    sa.Column('file_size_bytes', sa.Integer(), nullable=True),
    sa.Column('checksum', sa.String(length=255), nullable=True),
    sa.Column('storage_location', sa.String(length=100), nullable=False),
    sa.Column('compression', sa.String(length=50), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
    sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_model_artifacts_id'), 'model_artifacts', ['id'], unique=False)
    op.create_index(op.f('ix_model_artifacts_model_id'), 'model_artifacts', ['model_id'], unique=False)
    op.create_index(op.f('ix_model_artifacts_tenant_id'), 'model_artifacts', ['tenant_id'], unique=False)

    # Create training_performance_metrics table
    op.create_table('training_performance_metrics',
    sa.Column('id', sa.UUID(), nullable=False),
    sa.Column('tenant_id', sa.UUID(), nullable=False),
    sa.Column('job_id', sa.String(length=255), nullable=False),
    sa.Column('total_products', sa.Integer(), nullable=False),
    sa.Column('successful_products', sa.Integer(), nullable=False),
    sa.Column('failed_products', sa.Integer(), nullable=False),
    sa.Column('total_duration_seconds', sa.Float(), nullable=False),
    sa.Column('avg_time_per_product', sa.Float(), nullable=False),
    sa.Column('data_analysis_time_seconds', sa.Float(), nullable=True),
    sa.Column('training_time_seconds', sa.Float(), nullable=True),
    sa.Column('finalization_time_seconds', sa.Float(), nullable=True),
    sa.Column('completed_at', sa.DateTime(timezone=True), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_training_performance_metrics_job_id'), 'training_performance_metrics', ['job_id'], unique=False)
    op.create_index(op.f('ix_training_performance_metrics_tenant_id'), 'training_performance_metrics', ['tenant_id'], unique=False)


def downgrade() -> None:
    # Drop training_performance_metrics table
    op.drop_index(op.f('ix_training_performance_metrics_tenant_id'), table_name='training_performance_metrics')
    op.drop_index(op.f('ix_training_performance_metrics_job_id'), table_name='training_performance_metrics')
    op.drop_table('training_performance_metrics')

    # Drop model_artifacts table
    op.drop_index(op.f('ix_model_artifacts_tenant_id'), table_name='model_artifacts')
    op.drop_index(op.f('ix_model_artifacts_model_id'), table_name='model_artifacts')
    op.drop_index(op.f('ix_model_artifacts_id'), table_name='model_artifacts')
    op.drop_table('model_artifacts')

    # Drop training_job_queue table
    op.drop_index(op.f('ix_training_job_queue_tenant_id'), table_name='training_job_queue')
    op.drop_index(op.f('ix_training_job_queue_job_id'), table_name='training_job_queue')
    op.drop_index(op.f('ix_training_job_queue_id'), table_name='training_job_queue')
    op.drop_table('training_job_queue')

    # Drop model_performance_metrics table
    op.drop_index(op.f('ix_model_performance_metrics_tenant_id'), table_name='model_performance_metrics')
    op.drop_index(op.f('ix_model_performance_metrics_model_id'), table_name='model_performance_metrics')
    op.drop_index(op.f('ix_model_performance_metrics_inventory_product_id'), table_name='model_performance_metrics')
    op.drop_index(op.f('ix_model_performance_metrics_id'), table_name='model_performance_metrics')
    op.drop_table('model_performance_metrics')

    # Drop model_training_logs table
    op.drop_index(op.f('ix_model_training_logs_tenant_id'), table_name='model_training_logs')
    op.drop_index(op.f('ix_model_training_logs_job_id'), table_name='model_training_logs')
    op.drop_index(op.f('ix_model_training_logs_id'), table_name='model_training_logs')
    op.drop_table('model_training_logs')

    # Drop trained_models table (with the product_category column)
    op.drop_index(op.f('ix_trained_models_tenant_id'), table_name='trained_models')
    op.drop_index(op.f('ix_trained_models_inventory_product_id'), table_name='trained_models')
    op.drop_table('trained_models')

    # Drop audit_logs table
    op.drop_index(op.f('ix_audit_logs_user_id'), table_name='audit_logs')
    op.drop_index(op.f('ix_audit_logs_tenant_id'), table_name='audit_logs')
    op.drop_index(op.f('ix_audit_logs_severity'), table_name='audit_logs')
    op.drop_index(op.f('ix_audit_logs_service_name'), table_name='audit_logs')
    op.drop_index(op.f('ix_audit_logs_resource_type'), table_name='audit_logs')
    op.drop_index(op.f('ix_audit_logs_resource_id'), table_name='audit_logs')
    op.drop_index(op.f('ix_audit_logs_created_at'), table_name='audit_logs')
    op.drop_index(op.f('ix_audit_logs_action'), table_name='audit_logs')
    op.drop_index('idx_audit_user_created', table_name='audit_logs')
    op.drop_index('idx_audit_tenant_created', table_name='audit_logs')
    op.drop_index('idx_audit_severity_created', table_name='audit_logs')
    op.drop_index('idx_audit_service_created', table_name='audit_logs')
    op.drop_index('idx_audit_resource_type_action', table_name='audit_logs')
    op.drop_table('audit_logs')
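The mae, mse, rmse, mape and r2_score columns of model_performance_metrics map directly onto standard regression metrics. As a rough illustration of how a training run might fill such a row from a holdout set, here is a minimal sketch using scikit-learn and numpy (both pinned in requirements.txt below); the evaluate_forecast helper is a hypothetical name and is not part of this commit.

# Hypothetical helper illustrating how the metric columns of
# model_performance_metrics could be computed from a holdout set.
# Not part of this commit; assumes scikit-learn and numpy from requirements.txt.
from datetime import datetime, timezone

import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score


def evaluate_forecast(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """Return a dict whose keys mirror the metric columns above."""
    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    rmse = float(np.sqrt(mse))
    # MAPE is undefined for zero actuals; mask them out before dividing.
    nonzero = y_true != 0
    mape = float(np.mean(np.abs((y_true[nonzero] - y_pred[nonzero]) / y_true[nonzero])) * 100)
    return {
        "mae": float(mae),
        "mse": float(mse),
        "rmse": rmse,
        "mape": mape,
        "r2_score": float(r2_score(y_true, y_pred)),
        "accuracy_percentage": max(0.0, 100.0 - mape),
        "evaluation_samples": int(len(y_true)),
        "measured_at": datetime.now(timezone.utc),
    }

For example, evaluate_forecast(np.array([10.0, 12.0, 9.0]), np.array([11.0, 12.0, 8.0])) returns a dict ready to be written as one model_performance_metrics row.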
@@ -0,0 +1,60 @@
"""Add horizontal scaling constraints for multi-pod deployment

Revision ID: add_horizontal_scaling
Revises: 26a665cd5348
Create Date: 2025-01-18

This migration adds database-level constraints to prevent race conditions
when running multiple training service pods:

1. Partial unique index on model_training_logs to prevent duplicate active jobs per tenant
2. Index to speed up active job lookups
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'add_horizontal_scaling'
down_revision: Union[str, None] = '26a665cd5348'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Add partial unique index to prevent duplicate active training jobs per tenant.
    # This ensures only ONE job can be in 'pending' or 'running' status per tenant at a time.
    # The constraint is enforced at the database level, preventing race conditions
    # between multiple pods checking and creating jobs simultaneously.
    op.execute("""
        CREATE UNIQUE INDEX IF NOT EXISTS idx_unique_active_training_per_tenant
        ON model_training_logs (tenant_id)
        WHERE status IN ('pending', 'running')
    """)

    # Add index to speed up active job lookups (used by the deduplication check)
    op.create_index(
        'idx_training_logs_tenant_status',
        'model_training_logs',
        ['tenant_id', 'status'],
        unique=False,
        if_not_exists=True
    )

    # Add index for job recovery queries (find stale running jobs)
    op.create_index(
        'idx_training_logs_status_updated',
        'model_training_logs',
        ['status', 'updated_at'],
        unique=False,
        if_not_exists=True
    )


def downgrade() -> None:
    # Remove the indexes in reverse order
    op.execute("DROP INDEX IF EXISTS idx_training_logs_status_updated")
    op.execute("DROP INDEX IF EXISTS idx_training_logs_tenant_status")
    op.execute("DROP INDEX IF EXISTS idx_unique_active_training_per_tenant")
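The partial unique index above is only half of the pattern: application code has to treat a unique violation on insert as "another pod already has an active job for this tenant", and the status/updated_at index supports recovery of jobs whose pod died. A minimal sketch of both uses with async SQLAlchemy (sqlalchemy and asyncpg are pinned in requirements.txt below) follows; try_enqueue_job and find_stale_jobs are hypothetical names, and the INSERT assumes the remaining model_training_logs columns are nullable or defaulted.

# Hypothetical usage sketch for the constraints added above; not part of this
# commit. Assumes model_training_logs has at least job_id, tenant_id, status
# and updated_at columns (consistent with the indexes created above).
from datetime import datetime, timedelta, timezone

from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession


async def try_enqueue_job(session: AsyncSession, tenant_id: str, job_id: str) -> bool:
    """Insert a pending job; return False if the partial unique index reports
    that this tenant already has a job in 'pending' or 'running' status."""
    try:
        await session.execute(
            text(
                "INSERT INTO model_training_logs (job_id, tenant_id, status) "
                "VALUES (:job_id, :tenant_id, 'pending')"
            ),
            {"job_id": job_id, "tenant_id": tenant_id},
        )
        await session.commit()
        return True
    except IntegrityError:
        # idx_unique_active_training_per_tenant rejected the duplicate active job.
        await session.rollback()
        return False


async def find_stale_jobs(session: AsyncSession, max_age_minutes: int = 60):
    """Find 'running' jobs with no update within max_age_minutes, using
    idx_training_logs_status_updated to keep the scan cheap."""
    cutoff = datetime.now(timezone.utc) - timedelta(minutes=max_age_minutes)
    result = await session.execute(
        text(
            "SELECT job_id FROM model_training_logs "
            "WHERE status = 'running' AND updated_at < :cutoff"
        ),
        {"cutoff": cutoff},
    )
    return [row.job_id for row in result]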
64
services/training/requirements.txt
Normal file
64
services/training/requirements.txt
Normal file
@@ -0,0 +1,64 @@
# services/training/requirements.txt
# FastAPI and server
fastapi==0.119.0
uvicorn[standard]==0.32.1
python-multipart==0.0.6

# Database
sqlalchemy==2.0.44
asyncpg==0.30.0
alembic==1.17.0
psycopg2-binary==2.9.10

# ML libraries
prophet==1.2.1
cmdstanpy==1.2.4
scikit-learn==1.6.1
pandas==2.2.3
numpy==2.2.2
joblib==1.4.2
minio==7.2.2
xgboost==2.1.3

# HTTP client
httpx==0.28.1

# Validation
pydantic==2.12.3
pydantic-settings==2.7.1
email-validator==2.2.0

# Authentication
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
cryptography==44.0.0

# Messaging
aio-pika==9.4.3

# Monitoring and logging
structlog==25.4.0
opentelemetry-api==1.39.1
opentelemetry-sdk==1.39.1
opentelemetry-instrumentation-fastapi==0.60b1
opentelemetry-exporter-otlp-proto-grpc==1.39.1
opentelemetry-instrumentation-httpx==0.60b1
opentelemetry-instrumentation-redis==0.60b1
opentelemetry-instrumentation-sqlalchemy==0.60b1

# Development and testing
pytest==8.3.4
pytest-asyncio==0.25.2
pytest-mock==3.14.0
pytest-cov==6.0.0
coverage==7.6.9
psutil==6.1.1

# Utilities
python-dateutil==2.9.0.post0
pytz==2024.2
holidays==0.63

# Hyperparameter optimization
optuna==4.2.0
redis==6.4.0
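Because Prophet only works once cmdstanpy has a usable CmdStan toolchain, a quick smoke test after installing these pins can catch a broken build before the service starts taking training jobs. The script below is illustrative and not part of this commit; the two-column ds/y frame follows Prophet's documented input format.

# Illustrative smoke test (not part of this commit): verifies that the pinned
# forecasting stack imports and can fit a trivial model end to end.
import pandas as pd
from prophet import Prophet

# Prophet expects a dataframe with a 'ds' datetime column and a 'y' value column.
history = pd.DataFrame({
    "ds": pd.date_range("2024-01-01", periods=60, freq="D"),
    "y": [100 + (i % 7) * 5 for i in range(60)],
})

model = Prophet(weekly_seasonality=True, daily_seasonality=False)
model.fit(history)

future = model.make_future_dataframe(periods=7)
forecast = model.predict(future)
print(forecast[["ds", "yhat", "yhat_lower", "yhat_upper"]].tail(7))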