# services/training/app/schemas/training.py
"""
Complete schema definitions for training service
Includes all request/response schemas used by the API endpoints
"""

from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Any, Union
from datetime import datetime
from enum import Enum


# Values accepted for the Prophet ``seasonality_mode`` parameter.
_SEASONALITY_MODES = ("additive", "multiplicative")


def _check_seasonality_mode(value: str) -> str:
    """Return *value* unchanged if it is a supported seasonality mode.

    Raises:
        ValueError: if *value* is not "additive" or "multiplicative".
    """
    if value not in _SEASONALITY_MODES:
        raise ValueError('seasonality_mode must be either "additive" or "multiplicative"')
    return value


def _check_date_range(end_date: Optional[datetime], values: Dict[str, Any]) -> Optional[datetime]:
    """Return *end_date* unchanged unless it is earlier than ``start_date``.

    ``values`` holds the fields validated before ``end_date``; every model using
    this check declares ``start_date`` first, so it is present when it validated.
    The range is not checked when either bound is missing, or when the two
    datetimes cannot be compared (one naive, one timezone-aware).

    Raises:
        ValueError: if ``end_date`` is earlier than ``start_date``.
    """
    start_date = values.get("start_date")
    if start_date is None or end_date is None:
        return end_date
    try:
        is_reversed = end_date < start_date
    except TypeError:
        # Naive vs. aware comparison. Pydantic v2 does not turn a TypeError raised
        # inside a validator into a validation error, so don't let it escape.
        return end_date
    if is_reversed:
        raise ValueError("end_date must not be earlier than start_date")
    return end_date


class TrainingStatus(str, Enum):
    """Training job status enumeration"""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class TrainingJobRequest(BaseModel):
    """Request schema for starting a training job"""
    products: Optional[List[str]] = Field(None, description="Specific products to train (if None, trains all)")
    include_weather: bool = Field(True, description="Include weather data in training")
    include_traffic: bool = Field(True, description="Include traffic data in training")
    start_date: Optional[datetime] = Field(None, description="Start date for training data")
    end_date: Optional[datetime] = Field(None, description="End date for training data")

    # Prophet-specific parameters
    seasonality_mode: str = Field(
        "additive",
        description="Prophet seasonality mode",
        pattern="^(additive|multiplicative)$",
    )
    daily_seasonality: bool = Field(True, description="Enable daily seasonality")
    weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
    yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")

    # Advanced configuration
    force_retrain: bool = Field(False, description="Force retraining even if recent model exists")
    parallel_training: bool = Field(True, description="Train products in parallel")
    max_workers: int = Field(4, description="Maximum parallel workers", ge=1, le=10)

    @validator('seasonality_mode')
    def validate_seasonality_mode(cls, v):
        # Kept alongside ``pattern``: ``pattern`` is the Pydantic v2 spelling
        # (v1 uses ``regex``), so under v1 this validator is what enforces it.
        return _check_seasonality_mode(v)

    @validator('end_date')
    def validate_date_range(cls, v, values):
        """Reject an ``end_date`` earlier than ``start_date``."""
        return _check_date_range(v, values)


class SingleProductTrainingRequest(BaseModel):
    """Request schema for training a single product"""
    include_weather: bool = Field(True, description="Include weather data in training")
    include_traffic: bool = Field(True, description="Include traffic data in training")
    start_date: Optional[datetime] = Field(None, description="Start date for training data")
    end_date: Optional[datetime] = Field(None, description="End date for training data")

    # Prophet-specific parameters (validated the same way as TrainingJobRequest)
    seasonality_mode: str = Field(
        "additive",
        description="Prophet seasonality mode",
        pattern="^(additive|multiplicative)$",
    )
    daily_seasonality: bool = Field(True, description="Enable daily seasonality")
    weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
    yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")

    @validator('seasonality_mode')
    def validate_seasonality_mode(cls, v):
        return _check_seasonality_mode(v)

    @validator('end_date')
    def validate_date_range(cls, v, values):
        """Reject an ``end_date`` earlier than ``start_date``."""
        return _check_date_range(v, values)


class TrainingJobResponse(BaseModel):
    """Response schema for training job creation"""
    job_id: str = Field(..., description="Unique training job identifier")
    status: TrainingStatus = Field(..., description="Current job status")
    message: str = Field(..., description="Status message")
    tenant_id: str = Field(..., description="Tenant identifier")
    created_at: datetime = Field(..., description="Job creation timestamp")
    estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")


class TrainingJobStatus(BaseModel):
    """Response schema for training job status checks"""
    job_id: str = Field(..., description="Training job identifier")
    status: TrainingStatus = Field(..., description="Current job status")
    # NOTE(review): unlike TrainingJobProgress.progress, no 0-100 bounds are enforced here.
    progress: int = Field(0, description="Progress percentage (0-100)")
    current_step: str = Field("", description="Current processing step")
    started_at: datetime = Field(..., description="Job start timestamp")
    completed_at: Optional[datetime] = Field(None, description="Job completion timestamp")
    products_total: int = Field(0, description="Total number of products to train")
    products_completed: int = Field(0, description="Number of products completed")
    products_failed: int = Field(0, description="Number of products that failed")
    error_message: Optional[str] = Field(None, description="Error message if failed")


class TrainingJobProgress(BaseModel):
    """Schema for real-time training job progress updates"""
    job_id: str = Field(..., description="Training job identifier")
    status: TrainingStatus = Field(..., description="Current job status")
    progress: int = Field(0, description="Progress percentage (0-100)", ge=0, le=100)
    current_step: str = Field(..., description="Current processing step")
    current_product: Optional[str] = Field(None, description="Currently training product")
    products_completed: int = Field(0, description="Number of products completed")
    products_total: int = Field(0, description="Total number of products")
    estimated_time_remaining_minutes: Optional[int] = Field(None, description="Estimated time remaining")
    # datetime.now yields a naive local-time timestamp.
    # NOTE(review): confirm consumers expect local time rather than UTC.
    timestamp: datetime = Field(default_factory=datetime.now, description="Progress update timestamp")


class DataValidationRequest(BaseModel):
    """Request schema for validating training data"""
    products: Optional[List[str]] = Field(None, description="Specific products to validate (if None, validates all)")
    # The ge=10 constraint already rejects values below 10, so no extra validator is needed.
    min_data_points: int = Field(30, description="Minimum required data points per product", ge=10, le=1000)
    start_date: Optional[datetime] = Field(None, description="Start date for data validation")
    end_date: Optional[datetime] = Field(None, description="End date for data validation")

    @validator('end_date')
    def validate_date_range(cls, v, values):
        """Reject an ``end_date`` earlier than ``start_date``."""
        return _check_date_range(v, values)


class DataValidationResponse(BaseModel):
    """Response schema for data validation results"""
    is_valid: bool = Field(..., description="Whether the data is valid for training")
    issues: List[str] = Field(default_factory=list, description="List of data quality issues")
    recommendations: List[str] = Field(default_factory=list, description="Recommendations for improvement")
    estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
    products_analyzed: int = Field(..., description="Number of products analyzed")
    total_data_points: int = Field(..., description="Total data points available")
    products_with_insufficient_data: List[str] = Field(default_factory=list, description="Products with insufficient data")
    data_quality_score: float = Field(0.0, description="Overall data quality score (0-1)", ge=0.0, le=1.0)


# NOTE(review): under Pydantic v2, field names starting with "model_" (model_id,
# model_path, model_type, model_info) collide with the protected "model_" namespace
# and emit a UserWarning at class creation; if on v2, consider
# ``model_config = ConfigDict(protected_namespaces=())`` on the affected models.
class ModelInfo(BaseModel):
    """Schema for trained model information"""
    model_id: str = Field(..., description="Unique model identifier")
    model_path: str = Field(..., description="Path to stored model")
    model_type: str = Field("prophet", description="Type of ML model")
    training_samples: int = Field(..., description="Number of training samples")
    features: List[str] = Field(..., description="List of features used")
    hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
    training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
    trained_at: datetime = Field(..., description="Training completion timestamp")
    data_period: Dict[str, str] = Field(..., description="Training data period")


class ProductTrainingResult(BaseModel):
    """Schema for individual product training result"""
    product_name: str = Field(..., description="Product name")
    status: str = Field(..., description="Training status for this product")
    model_info: Optional[ModelInfo] = Field(None, description="Model information if successful")
    data_points: int = Field(..., description="Number of data points used")
    error_message: Optional[str] = Field(None, description="Error message if failed")
    trained_at: datetime = Field(..., description="Training completion timestamp")
    training_duration_seconds: Optional[float] = Field(None, description="Training duration in seconds")


class TrainingResultsResponse(BaseModel):
    """Response schema for complete training results"""
    job_id: str = Field(..., description="Training job identifier")
    tenant_id: str = Field(..., description="Tenant identifier")
    status: TrainingStatus = Field(..., description="Overall job status")
    products_trained: int = Field(..., description="Number of products successfully trained")
    products_failed: int = Field(..., description="Number of products that failed training")
    total_products: int = Field(..., description="Total number of products processed")
    training_results: Dict[str, ProductTrainingResult] = Field(..., description="Per-product results")
    summary: Dict[str, Any] = Field(..., description="Training summary statistics")
    completed_at: datetime = Field(..., description="Job completion timestamp")


class TrainingValidationResult(BaseModel):
    """Schema for training data validation results"""
    is_valid: bool = Field(..., description="Whether the data is valid for training")
    issues: List[str] = Field(default_factory=list, description="List of data quality issues")
    recommendations: List[str] = Field(default_factory=list, description="Recommendations for improvement")
    estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
    products_analyzed: int = Field(..., description="Number of products analyzed")
    total_data_points: int = Field(..., description="Total data points available")


class TrainingMetrics(BaseModel):
    """Schema for training performance metrics"""
    mae: float = Field(..., description="Mean Absolute Error")
    mse: float = Field(..., description="Mean Squared Error")
    rmse: float = Field(..., description="Root Mean Squared Error")
    mape: float = Field(..., description="Mean Absolute Percentage Error")
    r2_score: float = Field(..., description="R-squared score")
    mean_actual: float = Field(..., description="Mean of actual values")
    mean_predicted: float = Field(..., description="Mean of predicted values")


class ExternalDataConfig(BaseModel):
    """Configuration for external data sources"""
    weather_enabled: bool = Field(True, description="Enable weather data")
    traffic_enabled: bool = Field(True, description="Enable traffic data")
    # default_factory gives each instance its own list rather than a shared one.
    weather_features: List[str] = Field(
        default_factory=lambda: ["temperature", "precipitation", "humidity"],
        description="Weather features to include"
    )
    traffic_features: List[str] = Field(
        default_factory=lambda: ["traffic_volume"],
        description="Traffic features to include"
    )


class TrainingJobConfig(BaseModel):
    """Complete training job configuration"""
    external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)
    # Defaults mirror the Prophet defaults declared on TrainingJobRequest.
    prophet_params: Dict[str, Any] = Field(
        default_factory=lambda: {
            "seasonality_mode": "additive",
            "daily_seasonality": True,
            "weekly_seasonality": True,
            "yearly_seasonality": True
        },
        description="Prophet model parameters"
    )
    data_filters: Dict[str, Any] = Field(
        default_factory=dict,
        description="Data filtering parameters"
    )
    validation_params: Dict[str, Any] = Field(
        default_factory=lambda: {"min_data_points": 30},
        description="Data validation parameters"
    )


class TrainedModelResponse(BaseModel):
    """Response schema for trained model information"""
    model_id: str = Field(..., description="Unique model identifier")
    tenant_id: str = Field(..., description="Tenant identifier")
    product_name: str = Field(..., description="Product name")
    model_type: str = Field(..., description="Type of ML model")
    model_path: str = Field(..., description="Path to stored model")
    version: int = Field(..., description="Model version")
    training_samples: int = Field(..., description="Number of training samples")
    features: List[str] = Field(..., description="List of features used")
    hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
    training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
    is_active: bool = Field(..., description="Whether model is active")
    created_at: datetime = Field(..., description="Model creation timestamp")
    data_period_start: Optional[datetime] = Field(None, description="Training data start date")
    data_period_end: Optional[datetime] = Field(None, description="Training data end date")


class ModelTrainingStats(BaseModel):
    """Schema for model training statistics"""
    total_models: int = Field(..., description="Total number of trained models")
    active_models: int = Field(..., description="Number of active models")
    last_training_date: Optional[datetime] = Field(None, description="Last training date")
    avg_training_time_minutes: float = Field(..., description="Average training time in minutes")
    success_rate: float = Field(..., description="Training success rate (0-1)")


class BulkTrainingRequest(BaseModel):
    """Request schema for bulk training operations"""
    tenant_ids: List[str] = Field(..., description="List of tenant IDs to train")
    config: TrainingJobConfig = Field(default_factory=TrainingJobConfig, description="Training configuration")
    priority: int = Field(1, description="Training priority (1-10)", ge=1, le=10)
    schedule_time: Optional[datetime] = Field(None, description="Schedule training for specific time")


class TrainingScheduleResponse(BaseModel):
    """Response schema for scheduled training jobs"""
    schedule_id: str = Field(..., description="Unique schedule identifier")
    tenant_ids: List[str] = Field(..., description="Scheduled tenant IDs")
    scheduled_time: datetime = Field(..., description="Scheduled execution time")
    status: str = Field(..., description="Schedule status")
    created_at: datetime = Field(..., description="Schedule creation timestamp")


# WebSocket response schemas for real-time updates
class TrainingProgressUpdate(BaseModel):
    """WebSocket message for training progress updates"""
    type: str = Field("training_progress", description="Message type")
    job_id: str = Field(..., description="Training job identifier")
    progress: TrainingJobProgress = Field(..., description="Progress information")


class TrainingCompletedUpdate(BaseModel):
    """WebSocket message for training completion"""
    type: str = Field("training_completed", description="Message type")
    job_id: str = Field(..., description="Training job identifier")
    results: TrainingResultsResponse = Field(..., description="Training results")


class TrainingErrorUpdate(BaseModel):
    """WebSocket message for training errors"""
    type: str = Field("training_error", description="Message type")
    job_id: str = Field(..., description="Training job identifier")
    error: str = Field(..., description="Error message")
    # datetime.now yields a naive local-time timestamp.
    timestamp: datetime = Field(default_factory=datetime.now, description="Error timestamp")


# Union type for all WebSocket messages
TrainingWebSocketMessage = Union[
    TrainingProgressUpdate,
    TrainingCompletedUpdate,
    TrainingErrorUpdate
]