"""
Comprehensive Input Validation Schemas
Ensures all API inputs are properly validated before processing
"""

from pydantic import BaseModel, Field, validator, root_validator
from typing import Optional, List, Dict, Any
from datetime import datetime, timedelta
from uuid import UUID
import re

# Longest training window accepted by the training schemas (3 years).
_MAX_TRAINING_RANGE_DAYS = 1095
# Shortest evaluation window accepted by ModelEvaluationRequest.
_MIN_EVALUATION_DAYS = 7


def _ensure_iso_date(value: Optional[str], hint: str = "") -> Optional[str]:
    """Return ``value`` unchanged if it is ``None`` or parseable by
    ``datetime.fromisoformat``; otherwise raise ``ValueError``.

    ``hint`` is appended verbatim to the error message so each schema can
    keep its own wording.
    """
    if value is None:
        return value
    try:
        datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError covers non-string input; report both the same way.
        raise ValueError(f"Invalid date format: {value}{hint}") from None
    return value


def _parse_date_pair(start: str, end: str, start_field: str, end_field: str):
    """Parse two ISO strings and return ``(start_dt, end_dt)``.

    Raises ``ValueError`` when exactly one of the two carries a UTC offset:
    Python refuses to compare or subtract such a pair (``TypeError``), which
    would otherwise surface as an unhelpful validation message.
    """
    start_dt = datetime.fromisoformat(start)
    end_dt = datetime.fromisoformat(end)
    if (start_dt.tzinfo is None) != (end_dt.tzinfo is None):
        raise ValueError(
            f"{start_field} and {end_field} must either both include "
            "a UTC offset or both omit it"
        )
    return start_dt, end_dt


class TrainingJobCreateRequest(BaseModel):
    """Schema for creating a new training job"""

    tenant_id: UUID = Field(..., description="Tenant identifier")
    start_date: Optional[str] = Field(
        None,
        description="Training data start date (ISO format: YYYY-MM-DD)",
        example="2024-01-01"
    )
    end_date: Optional[str] = Field(
        None,
        description="Training data end date (ISO format: YYYY-MM-DD)",
        example="2024-12-31"
    )
    product_ids: Optional[List[UUID]] = Field(
        None,
        description="Specific products to train (optional, trains all if not provided)"
    )
    force_retrain: bool = Field(
        default=False,
        description="Force retraining even if recent models exist"
    )

    @validator('start_date', 'end_date')
    def validate_date_format(cls, v):
        """Validate date is in ISO format"""
        return _ensure_iso_date(v, ". Use YYYY-MM-DD format")

    # skip_on_failure: the range check only makes sense once every field
    # has passed its own validation.
    @root_validator(skip_on_failure=True)
    def validate_date_range(cls, values):
        """Validate date range is logical.

        When both dates are given, the end must be strictly after the start,
        the span must not exceed 1095 days, and the end must not lie in the
        future.
        """
        start = values.get('start_date')
        end = values.get('end_date')

        if start and end:
            start_dt, end_dt = _parse_date_pair(start, end, 'start_date', 'end_date')

            if end_dt <= start_dt:
                raise ValueError("end_date must be after start_date")

            # Check reasonable range (max 3 years)
            if (end_dt - start_dt).days > _MAX_TRAINING_RANGE_DAYS:
                raise ValueError(
                    f"Date range cannot exceed 3 years ({_MAX_TRAINING_RANGE_DAYS} days)"
                )

            # Check not in future. "Now" is taken in the end date's own
            # timezone (naive local time when the input has no offset) so
            # the comparison never mixes aware and naive datetimes.
            if end_dt > datetime.now(end_dt.tzinfo):
                raise ValueError("end_date cannot be in the future")

        return values

    class Config:
        schema_extra = {
            "example": {
                "tenant_id": "123e4567-e89b-12d3-a456-426614174000",
                "start_date": "2024-01-01",
                "end_date": "2024-12-31",
                "product_ids": None,
                "force_retrain": False
            }
        }


class ForecastRequest(BaseModel):
    """Schema for generating forecasts"""

    tenant_id: UUID = Field(..., description="Tenant identifier")
    product_id: UUID = Field(..., description="Product identifier")
    forecast_days: int = Field(
        default=30,
        ge=1,
        le=365,
        description="Number of days to forecast (1-365)"
    )
    include_regressors: bool = Field(
        default=True,
        description="Include weather and traffic data in forecast"
    )
    confidence_level: float = Field(
        default=0.80,
        ge=0.5,
        le=0.99,
        description="Confidence interval (0.5-0.99)"
    )

    class Config:
        schema_extra = {
            "example": {
                "tenant_id": "123e4567-e89b-12d3-a456-426614174000",
                "product_id": "223e4567-e89b-12d3-a456-426614174000",
                "forecast_days": 30,
                "include_regressors": True,
                "confidence_level": 0.80
            }
        }


class ModelEvaluationRequest(BaseModel):
    """Schema for model evaluation"""

    tenant_id: UUID = Field(..., description="Tenant identifier")
    product_id: Optional[UUID] = Field(None, description="Specific product (optional)")
    evaluation_start_date: str = Field(..., description="Evaluation period start")
    evaluation_end_date: str = Field(..., description="Evaluation period end")

    @validator('evaluation_start_date', 'evaluation_end_date')
    def validate_date_format(cls, v):
        """Validate date is in ISO format"""
        return _ensure_iso_date(v)

    @root_validator(skip_on_failure=True)
    def validate_evaluation_period(cls, values):
        """Validate the evaluation period is ordered and at least 7 days long."""
        start = values.get('evaluation_start_date')
        end = values.get('evaluation_end_date')

        if start and end:
            start_dt, end_dt = _parse_date_pair(
                start, end, 'evaluation_start_date', 'evaluation_end_date'
            )

            if end_dt <= start_dt:
                raise ValueError("evaluation_end_date must be after evaluation_start_date")

            # Minimum 7 days for meaningful evaluation
            if (end_dt - start_dt).days < _MIN_EVALUATION_DAYS:
                raise ValueError("Evaluation period must be at least 7 days")

        return values


class BulkTrainingRequest(BaseModel):
    """Schema for bulk training operations"""

    tenant_ids: List[UUID] = Field(
        ...,
        min_items=1,
        max_items=100,
        description="List of tenant IDs (max 100)"
    )
    start_date: Optional[str] = Field(None, description="Common start date")
    end_date: Optional[str] = Field(None, description="Common end date")
    parallel: bool = Field(
        default=True,
        description="Execute training jobs in parallel"
    )

    @validator('tenant_ids')
    def validate_unique_tenants(cls, v):
        """Reject lists that contain the same tenant ID more than once."""
        if len(v) != len(set(v)):
            raise ValueError("Duplicate tenant IDs not allowed")
        return v

    @validator('start_date', 'end_date')
    def validate_date_format(cls, v):
        """Validate date is in ISO format (same check as the other schemas)."""
        return _ensure_iso_date(v)

    @root_validator(skip_on_failure=True)
    def validate_date_range(cls, values):
        """When both dates are given, require end_date after start_date."""
        start = values.get('start_date')
        end = values.get('end_date')

        if start and end:
            start_dt, end_dt = _parse_date_pair(start, end, 'start_date', 'end_date')
            if end_dt <= start_dt:
                raise ValueError("end_date must be after start_date")

        return values


class HyperparameterOverride(BaseModel):
    """Schema for manual hyperparameter override"""

    changepoint_prior_scale: Optional[float] = Field(
        None, ge=0.001, le=0.5,
        description="Flexibility of trend changes"
    )
    seasonality_prior_scale: Optional[float] = Field(
        None, ge=0.01, le=10.0,
        description="Strength of seasonality"
    )
    holidays_prior_scale: Optional[float] = Field(
        None, ge=0.01, le=10.0,
        description="Strength of holiday effects"
    )
    seasonality_mode: Optional[str] = Field(
        None,
        description="Seasonality mode",
        regex="^(additive|multiplicative)$"
    )
    daily_seasonality: Optional[bool] = None
    weekly_seasonality: Optional[bool] = None
    yearly_seasonality: Optional[bool] = None

    class Config:
        schema_extra = {
            "example": {
                "changepoint_prior_scale": 0.05,
                "seasonality_prior_scale": 10.0,
                "holidays_prior_scale": 10.0,
                "seasonality_mode": "additive",
                "daily_seasonality": False,
                "weekly_seasonality": True,
                "yearly_seasonality": True
            }
        }


class AdvancedTrainingRequest(TrainingJobCreateRequest):
    """Extended training request with advanced options"""

    hyperparameter_override: Optional[HyperparameterOverride] = Field(
        None,
        description="Manual hyperparameter settings (skips optimization)"
    )
    enable_cross_validation: bool = Field(
        default=True,
        description="Enable cross-validation during training"
    )
    cv_folds: int = Field(
        default=3,
        ge=2,
        le=10,
        description="Number of cross-validation folds"
    )
    optimization_trials: Optional[int] = Field(
        None,
        ge=5,
        le=100,
        description="Number of hyperparameter optimization trials (overrides defaults)"
    )
    save_diagnostics: bool = Field(
        default=False,
        description="Save detailed diagnostic plots and metrics"
    )


class DataQualityCheckRequest(BaseModel):
    """Schema for data quality validation"""

    tenant_id: UUID = Field(..., description="Tenant identifier")
    start_date: str = Field(..., description="Check period start")
    end_date: str = Field(..., description="Check period end")
    product_ids: Optional[List[UUID]] = Field(
        None,
        description="Specific products to check"
    )
    include_recommendations: bool = Field(
        default=True,
        description="Include improvement recommendations"
    )

    @validator('start_date', 'end_date')
    def validate_date(cls, v):
        """Validate date is in ISO format"""
        return _ensure_iso_date(v)

    @root_validator(skip_on_failure=True)
    def validate_check_period(cls, values):
        """Reject a check period whose end precedes its start.

        Equal start and end dates are accepted.
        """
        start = values.get('start_date')
        end = values.get('end_date')

        if start and end:
            start_dt, end_dt = _parse_date_pair(start, end, 'start_date', 'end_date')
            if end_dt < start_dt:
                raise ValueError("end_date must not be before start_date")

        return values


class ModelQueryParams(BaseModel):
    """Query parameters for model listing"""

    tenant_id: Optional[UUID] = None
    product_id: Optional[UUID] = None
    is_active: Optional[bool] = None
    is_production: Optional[bool] = None
    model_type: Optional[str] = Field(None, regex="^(prophet|prophet_optimized|lstm|arima)$")
    min_accuracy: Optional[float] = Field(None, ge=0.0, le=1.0)
    created_after: Optional[datetime] = None
    created_before: Optional[datetime] = None
    limit: int = Field(default=100, ge=1, le=1000)
    offset: int = Field(default=0, ge=0)

    class Config:
        schema_extra = {
            "example": {
                "tenant_id": "123e4567-e89b-12d3-a456-426614174000",
                "is_active": True,
                "is_production": True,
                "limit": 50,
                "offset": 0
            }
        }


def validate_uuid(value: str) -> UUID:
    """Validate and convert string to UUID.

    A value that is already a ``UUID`` is returned unchanged. Any other
    input that cannot be converted raises ``ValueError``.
    """
    if isinstance(value, UUID):
        return value
    try:
        return UUID(value)
    except (ValueError, AttributeError, TypeError) as exc:
        # AttributeError/TypeError arise for non-string input such as
        # integers or None; report them uniformly as ValueError.
        raise ValueError(f"Invalid UUID format: {value}") from exc


def validate_date_string(value: str) -> datetime:
    """Validate and convert date string to datetime.

    Raises ``ValueError`` when ``value`` is not a string or is not parseable
    by ``datetime.fromisoformat``.
    """
    try:
        return datetime.fromisoformat(value)
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Invalid date format: {value}. Use ISO format (YYYY-MM-DD)"
        ) from exc


def validate_positive_integer(value: int, field_name: str = "value") -> int:
    """Validate positive integer.

    Returns ``value`` unchanged when it is greater than zero; raises
    ``ValueError`` naming ``field_name`` otherwise.
    """
    if value <= 0:
        raise ValueError(f"{field_name} must be positive, got {value}")
    return value


def validate_probability(value: float, field_name: str = "value") -> float:
    """Validate probability value (0.0-1.0).

    Returns ``value`` unchanged when it lies in the closed interval
    [0.0, 1.0]; raises ``ValueError`` naming ``field_name`` otherwise.
    """
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"{field_name} must be between 0.0 and 1.0, got {value}")
    return value