Files
bakery-ia/services/training/app/schemas/training.py
2025-07-19 16:59:37 +02:00

181 lines
9.8 KiB
Python

# services/training/app/schemas/training.py
"""
Pydantic schemas for training service
"""
from pydantic import BaseModel, Field, validator
from typing import Dict, List, Any, Optional
from datetime import datetime
from enum import Enum
class TrainingStatus(str, Enum):
    """Enumerate the states a training job can be reported in.

    Because the enum also subclasses ``str``, each member compares equal
    to its lowercase string value, e.g. ``TrainingStatus.RUNNING == "running"``.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class TrainingJobRequest(BaseModel):
    """Request schema for starting a training job.

    Every field has a default, so an empty body is a valid request.
    Validation rules:

    * ``seasonality_mode`` must be ``"additive"`` or ``"multiplicative"``;
    * ``min_data_points`` must be at least 7;
    * when both ``start_date`` and ``end_date`` are given, ``end_date`` must
      not be earlier than ``start_date``, and the two must be either both
      timezone-aware or both naive.
    """

    products: Optional[List[str]] = Field(None, description="Specific products to train (if None, train all)")
    include_weather: bool = Field(True, description="Include weather data in training")
    include_traffic: bool = Field(True, description="Include traffic data in training")
    start_date: Optional[datetime] = Field(None, description="Start date for training data")
    end_date: Optional[datetime] = Field(None, description="End date for training data")
    min_data_points: int = Field(30, description="Minimum data points required per product")
    estimated_duration: Optional[int] = Field(None, description="Estimated duration in minutes")

    # Prophet-specific parameters
    seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
    daily_seasonality: bool = Field(True, description="Enable daily seasonality")
    weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
    yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")

    @validator('seasonality_mode')
    def validate_seasonality_mode(cls, v):
        """Accept only the two seasonality modes listed below."""
        if v not in ['additive', 'multiplicative']:
            raise ValueError('seasonality_mode must be additive or multiplicative')
        return v

    @validator('min_data_points')
    def validate_min_data_points(cls, v):
        """Enforce a lower bound of 7 data points per product."""
        if v < 7:
            raise ValueError('min_data_points must be at least 7')
        return v

    @validator('end_date')
    def validate_date_range(cls, v, values):
        """Reject an inverted training window.

        ``values`` contains the fields validated so far. ``start_date`` is
        declared before ``end_date``, so it is present here unless it was
        omitted or failed its own validation.
        """
        start = values.get('start_date')
        if v is None or start is None:
            return v
        try:
            inverted = v < start
        except TypeError:
            # Mixing naive and timezone-aware datetimes raises TypeError;
            # report it as a validation error rather than an unhandled exception.
            raise ValueError(
                'start_date and end_date must both be timezone-aware or both naive'
            ) from None
        if inverted:
            raise ValueError('end_date must not be earlier than start_date')
        return v
class SingleProductTrainingRequest(BaseModel):
    """Request schema for training a single product.

    Every field has a default. This class applies the same checks that
    ``TrainingJobRequest`` applies to the fields both classes share:

    * ``seasonality_mode`` must be ``"additive"`` or ``"multiplicative"``;
    * when both ``start_date`` and ``end_date`` are given, ``end_date`` must
      not be earlier than ``start_date``, and the two must be either both
      timezone-aware or both naive.
    """

    include_weather: bool = Field(True, description="Include weather data in training")
    include_traffic: bool = Field(True, description="Include traffic data in training")
    start_date: Optional[datetime] = Field(None, description="Start date for training data")
    end_date: Optional[datetime] = Field(None, description="End date for training data")

    # Prophet-specific parameters
    seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
    daily_seasonality: bool = Field(True, description="Enable daily seasonality")
    weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
    yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")

    @validator('seasonality_mode')
    def validate_seasonality_mode(cls, v):
        """Accept only the two seasonality modes listed below."""
        if v not in ['additive', 'multiplicative']:
            raise ValueError('seasonality_mode must be additive or multiplicative')
        return v

    @validator('end_date')
    def validate_date_range(cls, v, values):
        """Reject an inverted training window.

        ``values`` contains the fields validated so far. ``start_date`` is
        declared before ``end_date``, so it is present here unless it was
        omitted or failed its own validation.
        """
        start = values.get('start_date')
        if v is None or start is None:
            return v
        try:
            inverted = v < start
        except TypeError:
            # Mixing naive and timezone-aware datetimes raises TypeError;
            # report it as a validation error rather than an unhandled exception.
            raise ValueError(
                'start_date and end_date must both be timezone-aware or both naive'
            ) from None
        if inverted:
            raise ValueError('end_date must not be earlier than start_date')
        return v
class TrainingJobResponse(BaseModel):
    """Payload returned once a training job has been created.

    All six fields are required (declared with an Ellipsis default).
    """

    job_id: str = Field(default=..., description="Unique training job identifier")
    status: TrainingStatus = Field(default=..., description="Current job status")
    message: str = Field(default=..., description="Status message")
    tenant_id: str = Field(default=..., description="Tenant identifier")
    created_at: datetime = Field(default=..., description="Job creation timestamp")
    estimated_duration_minutes: int = Field(
        default=...,
        description="Estimated completion time in minutes",
    )
class TrainingStatusResponse(BaseModel):
    """Snapshot of a training job's progress, returned by status queries.

    ``job_id``, ``status`` and ``started_at`` are required; the remaining
    fields have defaults.
    """

    job_id: str = Field(default=..., description="Training job identifier")
    status: TrainingStatus = Field(default=..., description="Current job status")
    # The 0-100 range appears only in the description; the schema itself
    # does not enforce bounds.
    progress: int = Field(default=0, description="Progress percentage (0-100)")
    current_step: str = Field(default="", description="Current processing step")
    started_at: datetime = Field(default=..., description="Job start timestamp")
    completed_at: Optional[datetime] = Field(
        default=None,
        description="Job completion timestamp",
    )
    results: Optional[Dict[str, Any]] = Field(default=None, description="Training results")
    error_message: Optional[str] = Field(
        default=None,
        description="Error message if failed",
    )
class ModelInfo(BaseModel):
    """Metadata describing one trained model artifact.

    Only ``model_type`` has a default (``"prophet"``); every other field
    is required.
    """

    # NOTE(review): under Pydantic v2, field names starting with ``model_``
    # collide with the protected ``model_`` namespace and emit warnings.
    # Confirm which Pydantic major version this service pins.
    model_id: str = Field(default=..., description="Unique model identifier")
    model_path: str = Field(default=..., description="Path to stored model")
    model_type: str = Field(default="prophet", description="Type of ML model")
    training_samples: int = Field(default=..., description="Number of training samples")
    features: List[str] = Field(default=..., description="List of features used")
    hyperparameters: Dict[str, Any] = Field(default=..., description="Model hyperparameters")
    training_metrics: Dict[str, float] = Field(
        default=...,
        description="Training performance metrics",
    )
    trained_at: datetime = Field(default=..., description="Training completion timestamp")
    data_period: Dict[str, str] = Field(default=..., description="Training data period")
class ProductTrainingResult(BaseModel):
    """Outcome of training the model for one product.

    ``model_info`` and ``error_message`` are optional. The descriptions
    say they are filled on success and on failure respectively.
    """

    product_name: str = Field(default=..., description="Product name")
    # A plain string, not the TrainingStatus enum used by the job-level schemas.
    status: str = Field(default=..., description="Training status for this product")
    model_info: Optional[ModelInfo] = Field(
        default=None,
        description="Model information if successful",
    )
    data_points: int = Field(default=..., description="Number of data points used")
    error_message: Optional[str] = Field(default=None, description="Error message if failed")
    trained_at: datetime = Field(default=..., description="Training completion timestamp")
class TrainingResultsResponse(BaseModel):
    """Aggregate results of a finished training job; every field is required.

    The schema does not check that ``products_trained + products_failed``
    equals ``total_products``.
    """

    job_id: str = Field(default=..., description="Training job identifier")
    tenant_id: str = Field(default=..., description="Tenant identifier")
    status: TrainingStatus = Field(default=..., description="Overall job status")

    # Per-job counters
    products_trained: int = Field(
        default=...,
        description="Number of products successfully trained",
    )
    products_failed: int = Field(
        default=...,
        description="Number of products that failed training",
    )
    total_products: int = Field(
        default=...,
        description="Total number of products processed",
    )

    # Presumably keyed by product name — confirm against the producer of this payload.
    training_results: Dict[str, ProductTrainingResult] = Field(
        default=...,
        description="Per-product results",
    )
    summary: Dict[str, Any] = Field(default=..., description="Training summary statistics")
    completed_at: datetime = Field(default=..., description="Job completion timestamp")
class TrainingValidationResult(BaseModel):
    """Report produced by checking whether data is fit for training.

    ``issues`` and ``recommendations`` default to fresh empty lists (via
    ``default_factory``); every other field is required.
    """

    is_valid: bool = Field(default=..., description="Whether the data is valid for training")
    issues: List[str] = Field(
        default_factory=list,
        description="List of data quality issues",
    )
    recommendations: List[str] = Field(
        default_factory=list,
        description="Recommendations for improvement",
    )
    estimated_time_minutes: int = Field(
        default=...,
        description="Estimated training time in minutes",
    )
    products_analyzed: int = Field(default=..., description="Number of products analyzed")
    total_data_points: int = Field(default=..., description="Total data points available")
class TrainingMetrics(BaseModel):
    """Regression error metrics reported for a trained model.

    Every field is a required float.
    """

    # Error magnitudes
    mae: float = Field(default=..., description="Mean Absolute Error")
    mse: float = Field(default=..., description="Mean Squared Error")
    rmse: float = Field(default=..., description="Root Mean Squared Error")
    mape: float = Field(default=..., description="Mean Absolute Percentage Error")

    # Goodness of fit
    r2_score: float = Field(default=..., description="R-squared score")

    # Central tendency of actual vs predicted series
    mean_actual: float = Field(default=..., description="Mean of actual values")
    mean_predicted: float = Field(default=..., description="Mean of predicted values")
class ExternalDataConfig(BaseModel):
    """Toggles and feature lists for the weather and traffic data sources.

    The feature-list defaults are built by ``default_factory`` lambdas, so
    each instance gets its own list object.
    """

    weather_enabled: bool = Field(default=True, description="Enable weather data")
    traffic_enabled: bool = Field(default=True, description="Enable traffic data")

    weather_features: List[str] = Field(
        description="Weather features to include",
        default_factory=lambda: [
            "temperature",
            "precipitation",
            "humidity",
        ],
    )
    traffic_features: List[str] = Field(
        description="Traffic features to include",
        default_factory=lambda: ["traffic_volume"],
    )
class TrainingJobConfig(BaseModel):
    """Complete training job configuration, with every section defaulted.

    Each dict default comes from a factory, so instances never share
    mutable state.
    """

    external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)

    # Defaults match the Prophet-related field defaults of TrainingJobRequest.
    prophet_params: Dict[str, Any] = Field(
        description="Prophet model parameters",
        default_factory=lambda: dict(
            seasonality_mode="additive",
            daily_seasonality=True,
            weekly_seasonality=True,
            yearly_seasonality=True,
        ),
    )

    data_filters: Dict[str, Any] = Field(
        description="Data filtering parameters",
        default_factory=dict,
    )

    # 30 matches the min_data_points default of TrainingJobRequest.
    validation_params: Dict[str, Any] = Field(
        description="Data validation parameters",
        default_factory=lambda: dict(min_data_points=30),
    )
class TrainedModelResponse(BaseModel):
    """Response schema describing a stored, versioned trained model.

    Only the two ``data_period_*`` bounds are optional (default ``None``);
    all other fields are required.
    """

    # Identity
    model_id: str = Field(default=..., description="Unique model identifier")
    tenant_id: str = Field(default=..., description="Tenant identifier")
    product_name: str = Field(default=..., description="Product name")
    model_type: str = Field(default=..., description="Type of ML model")
    model_path: str = Field(default=..., description="Path to stored model")
    version: int = Field(default=..., description="Model version")

    # Training details
    training_samples: int = Field(default=..., description="Number of training samples")
    features: List[str] = Field(default=..., description="List of features used")
    hyperparameters: Dict[str, Any] = Field(default=..., description="Model hyperparameters")
    training_metrics: Dict[str, float] = Field(
        default=...,
        description="Training performance metrics",
    )

    # Lifecycle
    is_active: bool = Field(default=..., description="Whether model is active")
    created_at: datetime = Field(default=..., description="Model creation timestamp")
    data_period_start: Optional[datetime] = Field(
        default=None,
        description="Training data start date",
    )
    data_period_end: Optional[datetime] = Field(
        default=None,
        description="Training data end date",
    )