# services/training/app/schemas/training.py
"""
Pydantic schemas for training service
"""

from datetime import datetime
from enum import Enum
from typing import Dict, List, Any, Optional

from pydantic import BaseModel, Field, validator

# Seasonality modes accepted by every request schema in this module.
_SEASONALITY_MODES = ("additive", "multiplicative")


def _check_seasonality_mode(value: str) -> str:
    """Return *value* unchanged if it is an accepted seasonality mode.

    Shared by the request schemas so that the bulk and single-product
    endpoints reject exactly the same inputs with the same message.

    Raises:
        ValueError: if *value* is not one of ``_SEASONALITY_MODES``.
    """
    if value not in _SEASONALITY_MODES:
        raise ValueError('seasonality_mode must be additive or multiplicative')
    return value


class TrainingStatus(str, Enum):
    """Training job status enumeration"""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class TrainingJobRequest(BaseModel):
    """Request schema for starting a training job"""

    products: Optional[List[str]] = Field(
        None, description="Specific products to train (if None, train all)"
    )
    include_weather: bool = Field(True, description="Include weather data in training")
    include_traffic: bool = Field(True, description="Include traffic data in training")
    start_date: Optional[datetime] = Field(None, description="Start date for training data")
    end_date: Optional[datetime] = Field(None, description="End date for training data")
    min_data_points: int = Field(30, description="Minimum data points required per product")
    estimated_duration: Optional[int] = Field(None, description="Estimated duration in minutes")

    # Prophet-specific parameters
    seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
    daily_seasonality: bool = Field(True, description="Enable daily seasonality")
    weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
    yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")

    @validator('seasonality_mode')
    def validate_seasonality_mode(cls, v: str) -> str:
        """Reject any seasonality mode other than additive/multiplicative."""
        return _check_seasonality_mode(v)

    @validator('min_data_points')
    def validate_min_data_points(cls, v: int) -> int:
        """Require a minimum of 7 data points per product."""
        if v < 7:
            raise ValueError('min_data_points must be at least 7')
        return v


class SingleProductTrainingRequest(BaseModel):
    """Request schema for training a single product"""

    include_weather: bool = Field(True, description="Include weather data in training")
    include_traffic: bool = Field(True, description="Include traffic data in training")
    start_date: Optional[datetime] = Field(None, description="Start date for training data")
    end_date: Optional[datetime] = Field(None, description="End date for training data")

    # Prophet-specific parameters
    seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
    daily_seasonality: bool = Field(True, description="Enable daily seasonality")
    weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
    yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")

    @validator('seasonality_mode')
    def validate_seasonality_mode(cls, v: str) -> str:
        """Apply the same seasonality-mode check as TrainingJobRequest."""
        return _check_seasonality_mode(v)


class TrainingJobResponse(BaseModel):
    """Response schema for training job creation"""

    job_id: str = Field(..., description="Unique training job identifier")
    status: TrainingStatus = Field(..., description="Current job status")
    message: str = Field(..., description="Status message")
    tenant_id: str = Field(..., description="Tenant identifier")
    created_at: datetime = Field(..., description="Job creation timestamp")
    estimated_duration_minutes: int = Field(
        ..., description="Estimated completion time in minutes"
    )


class TrainingStatusResponse(BaseModel):
    """Response schema for training job status"""

    job_id: str = Field(..., description="Training job identifier")
    status: TrainingStatus = Field(..., description="Current job status")
    # NOTE(review): the 0-100 range is documented but not enforced here.
    progress: int = Field(0, description="Progress percentage (0-100)")
    current_step: str = Field("", description="Current processing step")
    started_at: datetime = Field(..., description="Job start timestamp")
    completed_at: Optional[datetime] = Field(None, description="Job completion timestamp")
    results: Optional[Dict[str, Any]] = Field(None, description="Training results")
    error_message: Optional[str] = Field(None, description="Error message if failed")


class ModelInfo(BaseModel):
    """Schema for trained model information"""

    model_id: str = Field(..., description="Unique model identifier")
    model_path: str = Field(..., description="Path to stored model")
    model_type: str = Field("prophet", description="Type of ML model")
    training_samples: int = Field(..., description="Number of training samples")
    features: List[str] = Field(..., description="List of features used")
    hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
    training_metrics: Dict[str, float] = Field(
        ..., description="Training performance metrics"
    )
    trained_at: datetime = Field(..., description="Training completion timestamp")
    data_period: Dict[str, str] = Field(..., description="Training data period")


class ProductTrainingResult(BaseModel):
    """Schema for individual product training result"""

    product_name: str = Field(..., description="Product name")
    status: str = Field(..., description="Training status for this product")
    model_info: Optional[ModelInfo] = Field(
        None, description="Model information if successful"
    )
    data_points: int = Field(..., description="Number of data points used")
    error_message: Optional[str] = Field(None, description="Error message if failed")
    trained_at: datetime = Field(..., description="Training completion timestamp")


class TrainingResultsResponse(BaseModel):
    """Response schema for complete training results"""

    job_id: str = Field(..., description="Training job identifier")
    tenant_id: str = Field(..., description="Tenant identifier")
    status: TrainingStatus = Field(..., description="Overall job status")
    products_trained: int = Field(
        ..., description="Number of products successfully trained"
    )
    products_failed: int = Field(
        ..., description="Number of products that failed training"
    )
    total_products: int = Field(..., description="Total number of products processed")
    training_results: Dict[str, ProductTrainingResult] = Field(
        ..., description="Per-product results"
    )
    summary: Dict[str, Any] = Field(..., description="Training summary statistics")
    completed_at: datetime = Field(..., description="Job completion timestamp")


class TrainingValidationResult(BaseModel):
    """Schema for training data validation results"""

    is_valid: bool = Field(..., description="Whether the data is valid for training")
    issues: List[str] = Field(
        default_factory=list, description="List of data quality issues"
    )
    recommendations: List[str] = Field(
        default_factory=list, description="Recommendations for improvement"
    )
    estimated_time_minutes: int = Field(
        ..., description="Estimated training time in minutes"
    )
    products_analyzed: int = Field(..., description="Number of products analyzed")
    total_data_points: int = Field(..., description="Total data points available")


class TrainingMetrics(BaseModel):
    """Schema for training performance metrics"""

    mae: float = Field(..., description="Mean Absolute Error")
    mse: float = Field(..., description="Mean Squared Error")
    rmse: float = Field(..., description="Root Mean Squared Error")
    mape: float = Field(..., description="Mean Absolute Percentage Error")
    r2_score: float = Field(..., description="R-squared score")
    mean_actual: float = Field(..., description="Mean of actual values")
    mean_predicted: float = Field(..., description="Mean of predicted values")


class ExternalDataConfig(BaseModel):
    """Configuration for external data sources"""

    weather_enabled: bool = Field(True, description="Enable weather data")
    traffic_enabled: bool = Field(True, description="Enable traffic data")
    # Factories give each instance its own list rather than a shared default.
    weather_features: List[str] = Field(
        default_factory=lambda: ["temperature", "precipitation", "humidity"],
        description="Weather features to include"
    )
    traffic_features: List[str] = Field(
        default_factory=lambda: ["traffic_volume"],
        description="Traffic features to include"
    )


class TrainingJobConfig(BaseModel):
    """Complete training job configuration"""

    external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)
    # Defaults mirror the Prophet-related field defaults on TrainingJobRequest.
    prophet_params: Dict[str, Any] = Field(
        default_factory=lambda: {
            "seasonality_mode": "additive",
            "daily_seasonality": True,
            "weekly_seasonality": True,
            "yearly_seasonality": True
        },
        description="Prophet model parameters"
    )
    data_filters: Dict[str, Any] = Field(
        default_factory=dict,
        description="Data filtering parameters"
    )
    validation_params: Dict[str, Any] = Field(
        default_factory=lambda: {"min_data_points": 30},
        description="Data validation parameters"
    )


class TrainedModelResponse(BaseModel):
    """Response schema for trained model information"""

    model_id: str = Field(..., description="Unique model identifier")
    tenant_id: str = Field(..., description="Tenant identifier")
    product_name: str = Field(..., description="Product name")
    model_type: str = Field(..., description="Type of ML model")
    model_path: str = Field(..., description="Path to stored model")
    version: int = Field(..., description="Model version")
    training_samples: int = Field(..., description="Number of training samples")
    features: List[str] = Field(..., description="List of features used")
    hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
    training_metrics: Dict[str, float] = Field(
        ..., description="Training performance metrics"
    )
    is_active: bool = Field(..., description="Whether model is active")
    created_at: datetime = Field(..., description="Model creation timestamp")
    data_period_start: Optional[datetime] = Field(
        None, description="Training data start date"
    )
    data_period_end: Optional[datetime] = Field(
        None, description="Training data end date"
    )