# Source snapshot: bakery-ia/services/training/app/schemas/training.py
# Captured 2025-07-29 07:53:30 +02:00 (362 lines, 18 KiB, Python)
# services/training/app/schemas/training.py
"""
Complete schema definitions for training service
Includes all request/response schemas used by the API endpoints
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from uuid import UUID

from pydantic import BaseModel, Field, validator
class TrainingStatus(str, Enum):
    """Training job status enumeration"""

    # Because of the ``str`` mixin every member is itself a string that
    # compares equal to its lowercase value, e.g.
    # ``TrainingStatus.PENDING == "pending"``.
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class TrainingJobRequest(BaseModel):
    """Request schema for starting a training job"""

    # Every field defaults to None, so an empty body ``{}`` is a valid request.
    # NOTE(review): start_date <= end_date is not checked by this schema.
    products: Optional[List[str]] = Field(
        default=None,
        description="Specific products to train (if None, trains all)",
    )
    start_date: Optional[datetime] = Field(
        default=None,
        description="Start date for training data",
    )
    end_date: Optional[datetime] = Field(
        default=None,
        description="End date for training data",
    )
class SingleProductTrainingRequest(BaseModel):
    """Request schema for training a single product"""

    start_date: Optional[datetime] = Field(None, description="Start date for training data")
    end_date: Optional[datetime] = Field(None, description="End date for training data")

    # Prophet-specific parameters.
    # Prophet only supports "additive" and "multiplicative" seasonality and
    # raises ValueError for any other value. Declaring the field as a Literal
    # rejects bad input during request validation instead of letting the
    # training run fail later. The value is still a plain ``str``, so code that
    # forwards it is unaffected.
    seasonality_mode: Literal["additive", "multiplicative"] = Field(
        "additive", description="Prophet seasonality mode"
    )
    daily_seasonality: bool = Field(True, description="Enable daily seasonality")
    weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
    yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
class DateRangeInfo(BaseModel):
    """Schema for date range information"""

    # Both bounds are typed as plain ``str``. "ISO format" is the documented
    # convention, but this schema neither parses nor enforces it.
    start: str = Field(
        ...,
        description="Start date in ISO format",
    )
    end: str = Field(
        ...,
        description="End date in ISO format",
    )
class DataSummary(BaseModel):
    """Schema for training data summary"""

    # Record counts per data source.
    sales_records: int = Field(..., description="Number of sales records used")
    weather_records: int = Field(..., description="Number of weather records used")
    traffic_records: int = Field(..., description="Number of traffic records used")

    # Provenance of the data set.
    date_range: DateRangeInfo = Field(
        ...,
        description="Date range of training data",
    )
    data_sources_used: List[str] = Field(
        ...,
        description="List of data sources used",
    )

    # Optional; each instance gets its own fresh empty dict by default.
    constraints_applied: Dict[str, str] = Field(
        default_factory=dict,
        description="Constraints applied during data collection",
    )
class ProductTrainingSummary(BaseModel):
    """Schema for individual product training results"""

    # This model used to be declared as ``ProductTrainingResult``. A second,
    # differently shaped ``ProductTrainingResult`` is defined further down in
    # this module and rebinds that name, so this first definition was only
    # reachable through ``TrainingResults.products``. Giving it its own name
    # removes the shadowing, and with it two models sharing one class name in
    # generated schemas. Field layout is unchanged.
    product_name: str = Field(..., description="Product name")
    status: str = Field(..., description="Training status for this product")
    model_id: Optional[str] = Field(None, description="Trained model identifier")
    data_points: int = Field(..., description="Number of data points used for training")
    metrics: Optional[Dict[str, float]] = Field(
        None, description="Training metrics (MAE, MAPE, etc.)"
    )
    training_time_seconds: Optional[float] = Field(
        None, description="Time taken to train this model"
    )
    error_message: Optional[str] = Field(None, description="Error message if training failed")

    class Config:
        # ``model_id`` starts with ``model_``, a prefix Pydantic v2 protects by
        # default (older 2.x releases emit a UserWarning when the class is
        # created). An empty tuple opts out; the field name is unchanged.
        protected_namespaces = ()


class TrainingResults(BaseModel):
    """Schema for overall training results"""

    total_products: int = Field(..., description="Total number of products")
    successful_trainings: int = Field(..., description="Number of successfully trained models")
    failed_trainings: int = Field(..., description="Number of failed trainings")
    # Each entry uses the per-product schema declared just above. This is the
    # same schema this field validated against before the rename.
    products: List[ProductTrainingSummary] = Field(..., description="Results for each product")
    overall_training_time_seconds: float = Field(..., description="Total training time")
class TrainingJobResponse(BaseModel):
    """Enhanced response schema for training job with detailed results"""

    # --- identity -----------------------------------------------------------
    job_id: str = Field(..., description="Unique training job identifier")
    tenant_id: str = Field(..., description="Tenant identifier")
    status: TrainingStatus = Field(..., description="Overall job status")

    # --- required basic response fields (kept for backwards compatibility) --
    message: str = Field(..., description="Status message")
    created_at: datetime = Field(..., description="Job creation timestamp")
    estimated_duration_minutes: int = Field(
        ...,
        description="Estimated completion time in minutes",
    )

    # --- detailed fields, optional for backwards compatibility --------------
    training_results: Optional[TrainingResults] = Field(
        default=None,
        description="Detailed training results",
    )
    data_summary: Optional[DataSummary] = Field(
        default=None,
        description="Summary of training data used",
    )
    # NOTE(review): typed here as an ISO string, whereas TrainingJobStatus
    # uses ``datetime`` for a field with the same name.
    completed_at: Optional[str] = Field(
        default=None,
        description="Job completion timestamp in ISO format",
    )

    # --- additional free-form extras ----------------------------------------
    error_details: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Detailed error information if failed",
    )
    processing_metadata: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional processing metadata",
    )

    @validator('tenant_id', 'job_id', pre=True)
    def convert_uuid_to_string(cls, value):
        """Convert UUID objects to strings for JSON serialization"""
        # pre=True runs this before the ``str`` type check, so UUID inputs are
        # stringified first. Every other value passes through unchanged.
        return str(value) if isinstance(value, UUID) else value

    class Config:
        # Allow population from an object's attributes, not only from dicts.
        from_attributes = True
class TrainingJobStatus(BaseModel):
    """Response schema for training job status checks"""

    job_id: str = Field(..., description="Training job identifier")
    status: TrainingStatus = Field(..., description="Current job status")

    # Progress reporting. NOTE(review): unlike TrainingJobProgress.progress,
    # this field has no ge/le bounds, so values outside 0-100 are accepted.
    progress: int = Field(default=0, description="Progress percentage (0-100)")
    current_step: str = Field(default="", description="Current processing step")

    # Timing.
    started_at: datetime = Field(..., description="Job start timestamp")
    completed_at: Optional[datetime] = Field(
        default=None,
        description="Job completion timestamp",
    )

    # Per-product counters, all starting at zero.
    products_total: int = Field(default=0, description="Total number of products to train")
    products_completed: int = Field(default=0, description="Number of products completed")
    products_failed: int = Field(default=0, description="Number of products that failed")

    error_message: Optional[str] = Field(default=None, description="Error message if failed")

    @validator('job_id', pre=True)
    def convert_uuid_to_string(cls, value):
        """Convert UUID objects to strings for JSON serialization"""
        # Runs before type validation (pre=True); non-UUID values are untouched.
        return str(value) if isinstance(value, UUID) else value

    class Config:
        # Allow population from an object's attributes, not only from dicts.
        from_attributes = True
class TrainingJobProgress(BaseModel):
    """Schema for real-time training job progress updates"""

    job_id: str = Field(..., description="Training job identifier")
    status: TrainingStatus = Field(..., description="Current job status")
    # Constrained to the inclusive range 0..100 by ge/le.
    progress: int = Field(
        default=0,
        description="Progress percentage (0-100)",
        ge=0,
        le=100,
    )
    current_step: str = Field(..., description="Current processing step")
    current_product: Optional[str] = Field(
        default=None,
        description="Currently training product",
    )
    products_completed: int = Field(default=0, description="Number of products completed")
    products_total: int = Field(default=0, description="Total number of products")
    estimated_time_remaining_minutes: Optional[int] = Field(
        default=None,
        description="Estimated time remaining",
    )
    # ``datetime.now`` is called once per instance, producing a naive
    # (timezone-less) local timestamp.
    timestamp: datetime = Field(
        default_factory=datetime.now,
        description="Progress update timestamp",
    )

    @validator('job_id', pre=True)
    def convert_uuid_to_string(cls, value):
        """Convert UUID objects to strings for JSON serialization"""
        # Runs before type validation (pre=True); non-UUID values are untouched.
        return str(value) if isinstance(value, UUID) else value

    class Config:
        # Allow population from an object's attributes, not only from dicts.
        from_attributes = True
class DataValidationRequest(BaseModel):
    """Request schema for validating training data"""

    products: Optional[List[str]] = Field(
        None,
        description="Specific products to validate (if None, validates all)",
    )
    # ge/le reject anything outside 10..1000 during field validation, before
    # any after-validator could run. A separate ``@validator`` re-checking the
    # lower bound would be unreachable, so none is declared.
    min_data_points: int = Field(
        30,
        description="Minimum required data points per product",
        ge=10,
        le=1000,
    )
    # NOTE(review): start_date <= end_date is not checked by this schema.
    start_date: Optional[datetime] = Field(None, description="Start date for data validation")
    end_date: Optional[datetime] = Field(None, description="End date for data validation")
class DataValidationResponse(BaseModel):
    """Response schema for data validation results"""

    # Verdict and human-readable findings; the list fields default to fresh
    # empty lists per instance.
    is_valid: bool = Field(..., description="Whether the data is valid for training")
    issues: List[str] = Field(
        default_factory=list,
        description="List of data quality issues",
    )
    recommendations: List[str] = Field(
        default_factory=list,
        description="Recommendations for improvement",
    )

    # Quantitative summary.
    estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
    products_analyzed: int = Field(..., description="Number of products analyzed")
    total_data_points: int = Field(..., description="Total data points available")
    products_with_insufficient_data: List[str] = Field(
        default_factory=list,
        description="Products with insufficient data",
    )
    # Constrained to the closed interval [0.0, 1.0] by ge/le.
    data_quality_score: float = Field(
        default=0.0,
        description="Overall data quality score (0-1)",
        ge=0.0,
        le=1.0,
    )
class ModelInfo(BaseModel):
    """Schema for trained model information"""

    model_id: str = Field(..., description="Unique model identifier")
    model_path: str = Field(..., description="Path to stored model")
    model_type: str = Field("prophet", description="Type of ML model")
    training_samples: int = Field(..., description="Number of training samples")
    features: List[str] = Field(..., description="List of features used")
    hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
    training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
    trained_at: datetime = Field(..., description="Training completion timestamp")
    data_period: Dict[str, str] = Field(..., description="Training data period")

    class Config:
        # model_id / model_path / model_type start with ``model_``, a prefix
        # Pydantic v2 protects by default (older 2.x releases emit one
        # UserWarning per such field when the class is created). An empty
        # tuple opts out while keeping the public field names unchanged.
        protected_namespaces = ()
class ProductTrainingResult(BaseModel):
    """Schema for individual product training result"""

    # Per-product value type of ``TrainingResultsResponse.training_results``.
    product_name: str = Field(..., description="Product name")
    status: str = Field(..., description="Training status for this product")
    # Only populated on success, per its description; None otherwise.
    model_info: Optional[ModelInfo] = Field(None, description="Model information if successful")
    data_points: int = Field(..., description="Number of data points used")
    error_message: Optional[str] = Field(None, description="Error message if failed")
    trained_at: datetime = Field(..., description="Training completion timestamp")
    training_duration_seconds: Optional[float] = Field(
        None, description="Training duration in seconds"
    )

    class Config:
        # ``model_info`` starts with ``model_``, a prefix Pydantic v2 protects
        # by default (older 2.x releases emit a UserWarning when the class is
        # created). An empty tuple opts out; the field name is unchanged.
        protected_namespaces = ()
class TrainingResultsResponse(BaseModel):
    """Response schema for complete training results"""

    job_id: str = Field(..., description="Training job identifier")
    tenant_id: str = Field(..., description="Tenant identifier")
    status: TrainingStatus = Field(..., description="Overall job status")

    # Aggregate counters.
    products_trained: int = Field(
        ...,
        description="Number of products successfully trained",
    )
    products_failed: int = Field(
        ...,
        description="Number of products that failed training",
    )
    total_products: int = Field(
        ...,
        description="Total number of products processed",
    )

    # Detail. The dict keys are presumably product names; TODO confirm
    # against the code that builds this response.
    training_results: Dict[str, ProductTrainingResult] = Field(
        ...,
        description="Per-product results",
    )
    summary: Dict[str, Any] = Field(..., description="Training summary statistics")
    completed_at: datetime = Field(..., description="Job completion timestamp")

    @validator('tenant_id', 'job_id', pre=True)
    def convert_uuid_to_string(cls, value):
        """Convert UUID objects to strings for JSON serialization"""
        # Runs before type validation (pre=True); non-UUID values are untouched.
        return str(value) if isinstance(value, UUID) else value

    class Config:
        # Allow population from an object's attributes, not only from dicts.
        from_attributes = True
class TrainingValidationResult(BaseModel):
    """Schema for training data validation results"""

    # NOTE(review): these six fields mirror the first six fields of
    # DataValidationResponse in this module (same names, types and
    # descriptions); a shared base class may be worth considering.
    is_valid: bool = Field(..., description="Whether the data is valid for training")
    issues: List[str] = Field(
        default_factory=list,
        description="List of data quality issues",
    )
    recommendations: List[str] = Field(
        default_factory=list,
        description="Recommendations for improvement",
    )
    estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
    products_analyzed: int = Field(..., description="Number of products analyzed")
    total_data_points: int = Field(..., description="Total data points available")
class TrainingMetrics(BaseModel):
    """Schema for training performance metrics"""

    # All fields are required floats.
    # Error metrics.
    mae: float = Field(..., description="Mean Absolute Error")
    mse: float = Field(..., description="Mean Squared Error")
    rmse: float = Field(..., description="Root Mean Squared Error")
    mape: float = Field(..., description="Mean Absolute Percentage Error")

    # Goodness of fit.
    r2_score: float = Field(..., description="R-squared score")

    # Means of the two series being compared.
    mean_actual: float = Field(..., description="Mean of actual values")
    mean_predicted: float = Field(..., description="Mean of predicted values")
class ExternalDataConfig(BaseModel):
    """Configuration for external data sources"""

    # Per-source switches; both default to enabled.
    weather_enabled: bool = Field(default=True, description="Enable weather data")
    traffic_enabled: bool = Field(default=True, description="Enable traffic data")

    # Feature lists. The factories build a new list for every instance, so
    # mutating one config's list never leaks into another's defaults.
    weather_features: List[str] = Field(
        default_factory=lambda: [
            "temperature",
            "precipitation",
            "humidity",
        ],
        description="Weather features to include",
    )
    traffic_features: List[str] = Field(
        default_factory=lambda: ["traffic_volume"],
        description="Traffic features to include",
    )
class TrainingJobConfig(BaseModel):
    """Complete training job configuration"""

    # A fresh ExternalDataConfig (with its own defaults) per instance.
    external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)

    # These defaults match the Prophet fields of SingleProductTrainingRequest
    # in this module. Each factory call returns a new dict.
    prophet_params: Dict[str, Any] = Field(
        default_factory=lambda: dict(
            seasonality_mode="additive",
            daily_seasonality=True,
            weekly_seasonality=True,
            yearly_seasonality=True,
        ),
        description="Prophet model parameters",
    )
    # Empty by default; contents are free-form.
    data_filters: Dict[str, Any] = Field(
        default_factory=dict,
        description="Data filtering parameters",
    )
    # Default threshold equals DataValidationRequest.min_data_points' default.
    validation_params: Dict[str, Any] = Field(
        default_factory=lambda: dict(min_data_points=30),
        description="Data validation parameters",
    )
class TrainedModelResponse(BaseModel):
    """Response schema for trained model information"""

    model_id: str = Field(..., description="Unique model identifier")
    tenant_id: str = Field(..., description="Tenant identifier")
    product_name: str = Field(..., description="Product name")
    model_type: str = Field(..., description="Type of ML model")
    model_path: str = Field(..., description="Path to stored model")
    version: int = Field(..., description="Model version")
    training_samples: int = Field(..., description="Number of training samples")
    features: List[str] = Field(..., description="List of features used")
    hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
    training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
    is_active: bool = Field(..., description="Whether model is active")
    created_at: datetime = Field(..., description="Model creation timestamp")
    data_period_start: Optional[datetime] = Field(None, description="Training data start date")
    data_period_end: Optional[datetime] = Field(None, description="Training data end date")

    @validator('tenant_id', 'model_id', pre=True)
    def convert_uuid_to_string(cls, v):
        """Convert UUID objects to strings for JSON serialization"""
        # pre=True runs this before the ``str`` type check, so UUID values are
        # accepted and stringified; anything else passes through unchanged.
        return str(v) if isinstance(v, UUID) else v

    class Config:
        # Allow population from an object's attributes, not only from dicts.
        from_attributes = True
        # model_id / model_type / model_path start with ``model_``, a prefix
        # Pydantic v2 protects by default (older 2.x releases emit one
        # UserWarning per such field when the class is created). An empty
        # tuple opts out while keeping the public field names unchanged.
        protected_namespaces = ()
class ModelTrainingStats(BaseModel):
    """Schema for model training statistics"""

    total_models: int = Field(..., description="Total number of trained models")
    active_models: int = Field(..., description="Number of active models")
    last_training_date: Optional[datetime] = Field(
        default=None,
        description="Last training date",
    )
    avg_training_time_minutes: float = Field(
        ...,
        description="Average training time in minutes",
    )
    # NOTE(review): documented as a 0-1 ratio but, unlike data_quality_score
    # in DataValidationResponse, not bounded with ge/le.
    success_rate: float = Field(..., description="Training success rate (0-1)")
class BulkTrainingRequest(BaseModel):
    """Request schema for bulk training operations"""

    tenant_ids: List[str] = Field(..., description="List of tenant IDs to train")
    # When omitted, a new TrainingJobConfig with all of its own defaults is
    # built for this request.
    config: TrainingJobConfig = Field(
        default_factory=TrainingJobConfig,
        description="Training configuration",
    )
    # Constrained to the inclusive range 1..10 by ge/le.
    priority: int = Field(
        default=1,
        description="Training priority (1-10)",
        ge=1,
        le=10,
    )
    schedule_time: Optional[datetime] = Field(
        default=None,
        description="Schedule training for specific time",
    )
class TrainingScheduleResponse(BaseModel):
    """Response schema for scheduled training jobs"""

    schedule_id: str = Field(..., description="Unique schedule identifier")
    tenant_ids: List[str] = Field(..., description="Scheduled tenant IDs")
    scheduled_time: datetime = Field(
        ...,
        description="Scheduled execution time",
    )
    # Free-form string; not tied to the TrainingStatus enum.
    status: str = Field(..., description="Schedule status")
    created_at: datetime = Field(
        ...,
        description="Schedule creation timestamp",
    )
# WebSocket response schemas for real-time updates
class TrainingProgressUpdate(BaseModel):
    """WebSocket message for training progress updates"""

    # ``type`` is an ordinary ``str`` field whose default tags the message
    # kind; other values are not rejected.
    type: str = Field(default="training_progress", description="Message type")
    job_id: str = Field(..., description="Training job identifier")
    progress: TrainingJobProgress = Field(
        ...,
        description="Progress information",
    )
class TrainingCompletedUpdate(BaseModel):
    """WebSocket message for training completion"""

    # Default message tag; a plain ``str`` field, not restricted to this value.
    type: str = Field(default="training_completed", description="Message type")
    job_id: str = Field(..., description="Training job identifier")
    # Full results payload for the finished job.
    results: TrainingResultsResponse = Field(
        ...,
        description="Training results",
    )
class TrainingErrorUpdate(BaseModel):
    """WebSocket message for training errors"""

    # Default message tag; a plain ``str`` field, not restricted to this value.
    type: str = Field(default="training_error", description="Message type")
    job_id: str = Field(..., description="Training job identifier")
    error: str = Field(..., description="Error message")
    # ``datetime.now`` is called once per instance: a naive local timestamp.
    timestamp: datetime = Field(
        default_factory=datetime.now,
        description="Error timestamp",
    )
# Union type for all WebSocket messages: any one of the three update models
# (progress, completion, error) declared above.
TrainingWebSocketMessage = Union[
    TrainingProgressUpdate, TrainingCompletedUpdate, TrainingErrorUpdate
]