Files
bakery-ia/services/training/app/schemas/training.py
2025-07-28 19:28:39 +02:00

337 lines
16 KiB
Python

# services/training/app/schemas/training.py
"""
Complete schema definitions for training service
Includes all request/response schemas used by the API endpoints
"""
from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Any, Union
from datetime import datetime
from enum import Enum
from uuid import UUID
class TrainingStatus(str, Enum):
    """Lifecycle state of a training job.

    The enum mixes in ``str``, so every member compares equal to its
    lowercase string value and can be used wherever a plain string is expected.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class TrainingJobRequest(BaseModel):
    """Request body for starting a training job.

    Selects which products and which date window to train on, the Prophet
    seasonality options, and how the per-product work is executed.
    """

    # --- data selection ----------------------------------------------------
    products: Optional[List[str]] = Field(
        default=None,
        description="Specific products to train (if None, trains all)",
    )
    start_date: Optional[datetime] = Field(
        default=None,
        description="Start date for training data",
    )
    end_date: Optional[datetime] = Field(
        default=None,
        description="End date for training data",
    )

    # --- Prophet options ---------------------------------------------------
    seasonality_mode: str = Field(
        default="additive",
        pattern="^(additive|multiplicative)$",
        description="Prophet seasonality mode",
    )
    daily_seasonality: bool = Field(default=True, description="Enable daily seasonality")
    weekly_seasonality: bool = Field(default=True, description="Enable weekly seasonality")
    yearly_seasonality: bool = Field(default=True, description="Enable yearly seasonality")

    # --- execution options -------------------------------------------------
    force_retrain: bool = Field(
        default=False,
        description="Force retraining even if recent model exists",
    )
    parallel_training: bool = Field(default=True, description="Train products in parallel")
    max_workers: int = Field(
        default=4,
        ge=1,
        le=10,
        description="Maximum parallel workers",
    )

    @validator('seasonality_mode')
    def validate_seasonality_mode(cls, value):
        """Accept only the two seasonality modes Prophet supports.

        This repeats the ``pattern`` constraint in code, so the rule still holds
        however ``pattern`` is treated by the installed pydantic version.
        """
        if value in ('additive', 'multiplicative'):
            return value
        raise ValueError('seasonality_mode must be either "additive" or "multiplicative"')
class SingleProductTrainingRequest(BaseModel):
    """Request body for training a single product.

    Carries the same date window and Prophet seasonality options as a full
    training job request, and validates ``seasonality_mode`` the same way.
    """

    start_date: Optional[datetime] = Field(None, description="Start date for training data")
    end_date: Optional[datetime] = Field(None, description="End date for training data")
    # Prophet-specific parameters.
    # seasonality_mode is constrained exactly like TrainingJobRequest.seasonality_mode.
    # Previously any string was accepted here and passed through unchecked.
    seasonality_mode: str = Field(
        "additive",
        description="Prophet seasonality mode",
        pattern="^(additive|multiplicative)$",
    )
    daily_seasonality: bool = Field(True, description="Enable daily seasonality")
    weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
    yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")

    @validator('seasonality_mode')
    def validate_seasonality_mode(cls, v):
        """Reject any seasonality mode other than "additive" or "multiplicative".

        Raises:
            ValueError: if ``v`` is not one of the two supported modes.
        """
        if v not in ['additive', 'multiplicative']:
            raise ValueError('seasonality_mode must be either "additive" or "multiplicative"')
        return v
class TrainingJobResponse(BaseModel):
    """Response returned when a training job has been created."""

    job_id: str = Field(default=..., description="Unique training job identifier")
    status: TrainingStatus = Field(default=..., description="Current job status")
    message: str = Field(default=..., description="Status message")
    tenant_id: str = Field(default=..., description="Tenant identifier")
    created_at: datetime = Field(default=..., description="Job creation timestamp")
    estimated_duration_minutes: int = Field(
        default=...,
        description="Estimated completion time in minutes",
    )

    # The identifiers may be supplied as uuid.UUID objects (for instance when the
    # model is built from an object's attributes). pre=True lets this run before
    # the ``str`` type check.
    @validator('tenant_id', 'job_id', pre=True)
    def convert_uuid_to_string(cls, value):
        """Return ``str(value)`` for UUID inputs and pass anything else through."""
        return str(value) if isinstance(value, UUID) else value

    class Config:
        # Permit construction from arbitrary objects' attributes.
        from_attributes = True
class TrainingJobStatus(BaseModel):
    """Response schema for training job status checks."""

    job_id: str = Field(..., description="Training job identifier")
    status: TrainingStatus = Field(..., description="Current job status")
    # Bounded to 0-100 to match the documented range and the identical field on
    # TrainingJobProgress. Previously out-of-range values were accepted silently.
    progress: int = Field(0, description="Progress percentage (0-100)", ge=0, le=100)
    current_step: str = Field("", description="Current processing step")
    started_at: datetime = Field(..., description="Job start timestamp")
    completed_at: Optional[datetime] = Field(None, description="Job completion timestamp")
    products_total: int = Field(0, description="Total number of products to train")
    products_completed: int = Field(0, description="Number of products completed")
    products_failed: int = Field(0, description="Number of products that failed")
    error_message: Optional[str] = Field(None, description="Error message if failed")

    @validator('job_id', pre=True)
    def convert_uuid_to_string(cls, v):
        """Convert UUID objects to strings for JSON serialization.

        Runs with ``pre=True`` so a UUID is turned into ``str`` before the
        field's ``str`` type is checked; other values are returned unchanged.
        """
        if isinstance(v, UUID):
            return str(v)
        return v

    class Config:
        # Permit construction from arbitrary objects' attributes.
        from_attributes = True
class TrainingJobProgress(BaseModel):
    """Real-time progress snapshot for a running training job."""

    job_id: str = Field(default=..., description="Training job identifier")
    status: TrainingStatus = Field(default=..., description="Current job status")
    progress: int = Field(
        default=0,
        ge=0,
        le=100,
        description="Progress percentage (0-100)",
    )
    current_step: str = Field(default=..., description="Current processing step")
    current_product: Optional[str] = Field(
        default=None,
        description="Currently training product",
    )
    products_completed: int = Field(default=0, description="Number of products completed")
    products_total: int = Field(default=0, description="Total number of products")
    estimated_time_remaining_minutes: Optional[int] = Field(
        default=None,
        description="Estimated time remaining",
    )
    # datetime.now() with no tz argument yields a naive local-time timestamp.
    timestamp: datetime = Field(
        default_factory=datetime.now,
        description="Progress update timestamp",
    )

    @validator('job_id', pre=True)
    def convert_uuid_to_string(cls, value):
        """Return ``str(value)`` for UUID inputs and pass anything else through."""
        return str(value) if isinstance(value, UUID) else value

    class Config:
        # Permit construction from arbitrary objects' attributes.
        from_attributes = True
class DataValidationRequest(BaseModel):
    """Request body for checking whether training data is sufficient."""

    products: Optional[List[str]] = Field(
        default=None,
        description="Specific products to validate (if None, validates all)",
    )
    min_data_points: int = Field(
        default=30,
        ge=10,
        le=1000,
        description="Minimum required data points per product",
    )
    start_date: Optional[datetime] = Field(
        default=None,
        description="Start date for data validation",
    )
    end_date: Optional[datetime] = Field(
        default=None,
        description="End date for data validation",
    )

    # NOTE(review): this lower bound duplicates ``ge=10`` on the field above.
    @validator('min_data_points')
    def validate_min_data_points(cls, value):
        """Require at least 10 data points per product."""
        if value >= 10:
            return value
        raise ValueError('min_data_points must be at least 10')
class DataValidationResponse(BaseModel):
    """Outcome of a training-data validation request."""

    is_valid: bool = Field(default=..., description="Whether the data is valid for training")
    issues: List[str] = Field(
        default_factory=list,
        description="List of data quality issues",
    )
    recommendations: List[str] = Field(
        default_factory=list,
        description="Recommendations for improvement",
    )
    estimated_time_minutes: int = Field(
        default=...,
        description="Estimated training time in minutes",
    )
    products_analyzed: int = Field(default=..., description="Number of products analyzed")
    total_data_points: int = Field(default=..., description="Total data points available")
    products_with_insufficient_data: List[str] = Field(
        default_factory=list,
        description="Products with insufficient data",
    )
    data_quality_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Overall data quality score (0-1)",
    )
class ModelInfo(BaseModel):
    """Metadata describing one trained model artifact."""

    # NOTE(review): several field names start with "model_". Some pydantic 2.x
    # releases warn about that prefix; confirm against the installed version.
    model_id: str = Field(default=..., description="Unique model identifier")
    model_path: str = Field(default=..., description="Path to stored model")
    model_type: str = Field(default="prophet", description="Type of ML model")
    training_samples: int = Field(default=..., description="Number of training samples")
    features: List[str] = Field(default=..., description="List of features used")
    hyperparameters: Dict[str, Any] = Field(default=..., description="Model hyperparameters")
    training_metrics: Dict[str, float] = Field(
        default=...,
        description="Training performance metrics",
    )
    trained_at: datetime = Field(default=..., description="Training completion timestamp")
    data_period: Dict[str, str] = Field(default=..., description="Training data period")
class ProductTrainingResult(BaseModel):
    """Training outcome for a single product within a job."""

    product_name: str = Field(default=..., description="Product name")
    status: str = Field(default=..., description="Training status for this product")
    model_info: Optional[ModelInfo] = Field(
        default=None,
        description="Model information if successful",
    )
    data_points: int = Field(default=..., description="Number of data points used")
    error_message: Optional[str] = Field(default=None, description="Error message if failed")
    trained_at: datetime = Field(default=..., description="Training completion timestamp")
    training_duration_seconds: Optional[float] = Field(
        default=None,
        description="Training duration in seconds",
    )
class TrainingResultsResponse(BaseModel):
    """Aggregated results of a finished training job, keyed per product."""

    job_id: str = Field(default=..., description="Training job identifier")
    tenant_id: str = Field(default=..., description="Tenant identifier")
    status: TrainingStatus = Field(default=..., description="Overall job status")
    products_trained: int = Field(
        default=...,
        description="Number of products successfully trained",
    )
    products_failed: int = Field(
        default=...,
        description="Number of products that failed training",
    )
    total_products: int = Field(default=..., description="Total number of products processed")
    training_results: Dict[str, ProductTrainingResult] = Field(
        default=...,
        description="Per-product results",
    )
    summary: Dict[str, Any] = Field(default=..., description="Training summary statistics")
    completed_at: datetime = Field(default=..., description="Job completion timestamp")

    @validator('tenant_id', 'job_id', pre=True)
    def convert_uuid_to_string(cls, value):
        """Return ``str(value)`` for UUID inputs and pass anything else through."""
        return str(value) if isinstance(value, UUID) else value

    class Config:
        # Permit construction from arbitrary objects' attributes.
        from_attributes = True
class TrainingValidationResult(BaseModel):
    """Summary of whether a tenant's data can be used for training."""

    is_valid: bool = Field(default=..., description="Whether the data is valid for training")
    issues: List[str] = Field(
        default_factory=list,
        description="List of data quality issues",
    )
    recommendations: List[str] = Field(
        default_factory=list,
        description="Recommendations for improvement",
    )
    estimated_time_minutes: int = Field(
        default=...,
        description="Estimated training time in minutes",
    )
    products_analyzed: int = Field(default=..., description="Number of products analyzed")
    total_data_points: int = Field(default=..., description="Total data points available")
class TrainingMetrics(BaseModel):
    """Error and fit statistics reported for a trained model."""

    # Error magnitudes.
    mae: float = Field(default=..., description="Mean Absolute Error")
    mse: float = Field(default=..., description="Mean Squared Error")
    rmse: float = Field(default=..., description="Root Mean Squared Error")
    mape: float = Field(default=..., description="Mean Absolute Percentage Error")
    # Goodness of fit and level comparison.
    r2_score: float = Field(default=..., description="R-squared score")
    mean_actual: float = Field(default=..., description="Mean of actual values")
    mean_predicted: float = Field(default=..., description="Mean of predicted values")
class ExternalDataConfig(BaseModel):
    """Switches and feature lists for the weather and traffic data sources."""

    weather_enabled: bool = Field(default=True, description="Enable weather data")
    traffic_enabled: bool = Field(default=True, description="Enable traffic data")
    # default_factory gives each instance its own fresh list.
    weather_features: List[str] = Field(
        description="Weather features to include",
        default_factory=lambda: ["temperature", "precipitation", "humidity"],
    )
    traffic_features: List[str] = Field(
        description="Traffic features to include",
        default_factory=lambda: ["traffic_volume"],
    )
class TrainingJobConfig(BaseModel):
    """Full configuration for a training job: data sources, model, filters, validation."""

    external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)
    # Each default is built by a factory, so instances never share the same dict.
    prophet_params: Dict[str, Any] = Field(
        description="Prophet model parameters",
        default_factory=lambda: dict(
            seasonality_mode="additive",
            daily_seasonality=True,
            weekly_seasonality=True,
            yearly_seasonality=True,
        ),
    )
    data_filters: Dict[str, Any] = Field(
        description="Data filtering parameters",
        default_factory=dict,
    )
    validation_params: Dict[str, Any] = Field(
        description="Data validation parameters",
        default_factory=lambda: dict(min_data_points=30),
    )
class TrainedModelResponse(BaseModel):
    """Response schema for trained model information."""

    model_id: str = Field(..., description="Unique model identifier")
    tenant_id: str = Field(..., description="Tenant identifier")
    product_name: str = Field(..., description="Product name")
    model_type: str = Field(..., description="Type of ML model")
    model_path: str = Field(..., description="Path to stored model")
    version: int = Field(..., description="Model version")
    training_samples: int = Field(..., description="Number of training samples")
    features: List[str] = Field(..., description="List of features used")
    hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
    training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
    is_active: bool = Field(..., description="Whether model is active")
    created_at: datetime = Field(..., description="Model creation timestamp")
    data_period_start: Optional[datetime] = Field(None, description="Training data start date")
    data_period_end: Optional[datetime] = Field(None, description="Training data end date")

    @validator('tenant_id', 'model_id', pre=True)
    def convert_uuid_to_string(cls, v):
        """Convert UUID objects to strings for JSON serialization.

        Runs with ``pre=True`` so a UUID is turned into ``str`` before the
        field's ``str`` type is checked; other values are returned unchanged.
        """
        if isinstance(v, UUID):
            return str(v)
        return v

    class Config:
        # Permit construction from arbitrary objects' attributes.
        from_attributes = True
        # model_id / model_type / model_path start with "model_", a prefix that
        # pydantic 2.x reserves by default (and warns about in several releases).
        # An empty tuple opts this schema out of that protected namespace.
        protected_namespaces = ()
class ModelTrainingStats(BaseModel):
    """Schema for model training statistics."""

    total_models: int = Field(..., description="Total number of trained models")
    active_models: int = Field(..., description="Number of active models")
    last_training_date: Optional[datetime] = Field(None, description="Last training date")
    avg_training_time_minutes: float = Field(..., description="Average training time in minutes")
    # The documented 0-1 range is now enforced, matching how
    # DataValidationResponse.data_quality_score bounds its 0-1 score.
    success_rate: float = Field(..., description="Training success rate (0-1)", ge=0.0, le=1.0)
class BulkTrainingRequest(BaseModel):
    """Request body for training several tenants with one shared configuration."""

    tenant_ids: List[str] = Field(default=..., description="List of tenant IDs to train")
    config: TrainingJobConfig = Field(
        description="Training configuration",
        default_factory=TrainingJobConfig,
    )
    priority: int = Field(
        default=1,
        ge=1,
        le=10,
        description="Training priority (1-10)",
    )
    schedule_time: Optional[datetime] = Field(
        default=None,
        description="Schedule training for specific time",
    )
class TrainingScheduleResponse(BaseModel):
    """Confirmation returned for a scheduled (deferred) training run."""

    schedule_id: str = Field(default=..., description="Unique schedule identifier")
    tenant_ids: List[str] = Field(default=..., description="Scheduled tenant IDs")
    scheduled_time: datetime = Field(default=..., description="Scheduled execution time")
    status: str = Field(default=..., description="Schedule status")
    created_at: datetime = Field(default=..., description="Schedule creation timestamp")
# WebSocket response schemas for real-time updates
class TrainingProgressUpdate(BaseModel):
    """WebSocket envelope carrying a progress snapshot for a job."""

    # Message tag; defaults to "training_progress".
    type: str = Field(default="training_progress", description="Message type")
    job_id: str = Field(default=..., description="Training job identifier")
    progress: TrainingJobProgress = Field(default=..., description="Progress information")
class TrainingCompletedUpdate(BaseModel):
    """WebSocket envelope carrying the final results of a job."""

    # Message tag; defaults to "training_completed".
    type: str = Field(default="training_completed", description="Message type")
    job_id: str = Field(default=..., description="Training job identifier")
    results: TrainingResultsResponse = Field(default=..., description="Training results")
class TrainingErrorUpdate(BaseModel):
    """WebSocket envelope reporting an error for a job."""

    # Message tag; defaults to "training_error".
    type: str = Field(default="training_error", description="Message type")
    job_id: str = Field(default=..., description="Training job identifier")
    error: str = Field(default=..., description="Error message")
    # datetime.now() with no tz argument yields a naive local-time timestamp.
    timestamp: datetime = Field(
        description="Error timestamp",
        default_factory=datetime.now,
    )
# Union type for all WebSocket messages
TrainingWebSocketMessage = Union[TrainingProgressUpdate, TrainingCompletedUpdate, TrainingErrorUpdate]