Files
bakery-ia/services/training/app/schemas/validation.py

318 lines
9.9 KiB
Python

"""
Comprehensive Input Validation Schemas
Ensures all API inputs are properly validated before processing
"""
from pydantic import BaseModel, Field, validator, root_validator
from typing import Optional, List, Dict, Any
from datetime import datetime, timedelta
from uuid import UUID
import re
class TrainingJobCreateRequest(BaseModel):
    """Schema for creating a new training job.

    ``start_date`` and ``end_date`` are optional ISO-8601 strings that must be
    parseable by ``datetime.fromisoformat``. When ``end_date`` is given it must
    not lie in the future. When both are given, ``end_date`` must be strictly
    after ``start_date`` and at most 1095 days (about 3 years) later.
    """

    tenant_id: UUID = Field(..., description="Tenant identifier")
    start_date: Optional[str] = Field(
        None,
        description="Training data start date (ISO format: YYYY-MM-DD)",
        example="2024-01-01",
    )
    end_date: Optional[str] = Field(
        None,
        description="Training data end date (ISO format: YYYY-MM-DD)",
        example="2024-12-31",
    )
    product_ids: Optional[List[UUID]] = Field(
        None,
        description="Specific products to train (optional, trains all if not provided)",
    )
    force_retrain: bool = Field(
        default=False,
        description="Force retraining even if recent models exist",
    )

    @validator('start_date', 'end_date')
    def validate_date_format(cls, v):
        """Validate date is in ISO format (``None`` is allowed and passed through)."""
        if v is not None:
            try:
                datetime.fromisoformat(v)
            except ValueError:
                raise ValueError(f"Invalid date format: {v}. Use YYYY-MM-DD format")
        return v

    @root_validator
    def validate_date_range(cls, values):
        """Validate the requested date range is logical.

        A date that failed ``validate_date_format`` is absent from ``values``
        (pydantic v1 behaviour), so it is simply skipped here.

        Raises:
            ValueError: if end_date is not after start_date, the range exceeds
                1095 days, or end_date lies in the future.
        """
        start = values.get('start_date')
        end = values.get('end_date')
        end_dt = datetime.fromisoformat(end) if end else None

        if start and end_dt is not None:
            start_dt = datetime.fromisoformat(start)
            if end_dt <= start_dt:
                raise ValueError("end_date must be after start_date")
            # Check reasonable range (max 3 years)
            if (end_dt - start_dt).days > 1095:
                raise ValueError("Date range cannot exceed 3 years (1095 days)")

        # Check not in future. This applies to end_date on its own; it was
        # previously only enforced when start_date was also supplied.
        # datetime.now(tz) returns a naive value when tz is None and an aware
        # one otherwise, so the comparison matches end_dt's awareness.
        if end_dt is not None and end_dt > datetime.now(end_dt.tzinfo):
            raise ValueError("end_date cannot be in the future")
        return values

    class Config:
        schema_extra = {
            "example": {
                "tenant_id": "123e4567-e89b-12d3-a456-426614174000",
                "start_date": "2024-01-01",
                "end_date": "2024-12-31",
                "product_ids": None,
                "force_retrain": False,
            }
        }
class ForecastRequest(BaseModel):
    """Request body for generating a forecast for a single product.

    Attributes:
        tenant_id: Owning tenant.
        product_id: Product to forecast.
        forecast_days: Forecast horizon in days, constrained to 1..365 inclusive.
        include_regressors: Whether weather and traffic data are included.
        confidence_level: Confidence interval, constrained to 0.5..0.99 inclusive.
    """

    tenant_id: UUID = Field(..., description="Tenant identifier")
    product_id: UUID = Field(..., description="Product identifier")
    forecast_days: int = Field(
        30, ge=1, le=365, description="Number of days to forecast (1-365)"
    )
    include_regressors: bool = Field(
        True, description="Include weather and traffic data in forecast"
    )
    confidence_level: float = Field(
        0.80, ge=0.5, le=0.99, description="Confidence interval (0.5-0.99)"
    )

    class Config:
        schema_extra = dict(
            example=dict(
                tenant_id="123e4567-e89b-12d3-a456-426614174000",
                product_id="223e4567-e89b-12d3-a456-426614174000",
                forecast_days=30,
                include_regressors=True,
                confidence_level=0.80,
            )
        )
class ModelEvaluationRequest(BaseModel):
    """Request body for evaluating models over a historical window.

    Both window bounds are required strings that must parse with
    ``datetime.fromisoformat``. The end bound must be strictly after the
    start bound, and the window must span at least 7 whole days.
    """

    tenant_id: UUID = Field(..., description="Tenant identifier")
    product_id: Optional[UUID] = Field(None, description="Specific product (optional)")
    evaluation_start_date: str = Field(..., description="Evaluation period start")
    evaluation_end_date: str = Field(..., description="Evaluation period end")

    @validator('evaluation_start_date', 'evaluation_end_date')
    def validate_date_format(cls, v):
        """Ensure a window bound is parseable by ``datetime.fromisoformat``."""
        try:
            datetime.fromisoformat(v)
        except ValueError:
            raise ValueError(f"Invalid date format: {v}")
        else:
            return v

    @root_validator
    def validate_evaluation_period(cls, values):
        """Check ordering and minimum length of the evaluation window."""
        first = values.get('evaluation_start_date')
        last = values.get('evaluation_end_date')
        # A bound that failed its own validator is missing from ``values``;
        # there is nothing to cross-check in that case.
        if not (first and last):
            return values

        window_start = datetime.fromisoformat(first)
        window_end = datetime.fromisoformat(last)
        if window_end <= window_start:
            raise ValueError("evaluation_end_date must be after evaluation_start_date")
        # Shorter windows are rejected as too small for a meaningful evaluation.
        if (window_end - window_start).days < 7:
            raise ValueError("Evaluation period must be at least 7 days")
        return values
class BulkTrainingRequest(BaseModel):
    """Schema for bulk training operations.

    ``tenant_ids`` must contain 1-100 distinct tenant UUIDs. The optional
    ``start_date``/``end_date`` are now validated like the single-job schema:
    each must parse with ``datetime.fromisoformat``, and when both are given
    ``end_date`` must be strictly after ``start_date``. (Previously any string
    was accepted here.)
    """

    tenant_ids: List[UUID] = Field(
        ...,
        min_items=1,
        max_items=100,
        description="List of tenant IDs (max 100)",
    )
    start_date: Optional[str] = Field(None, description="Common start date")
    end_date: Optional[str] = Field(None, description="Common end date")
    parallel: bool = Field(
        default=True,
        description="Execute training jobs in parallel",
    )

    @validator('tenant_ids')
    def validate_unique_tenants(cls, v):
        """Reject lists that name the same tenant more than once."""
        if len(v) != len(set(v)):
            raise ValueError("Duplicate tenant IDs not allowed")
        return v

    @validator('start_date', 'end_date')
    def validate_date_format(cls, v):
        """Validate date is in ISO format (``None`` is allowed and passed through)."""
        if v is not None:
            try:
                datetime.fromisoformat(v)
            except ValueError:
                raise ValueError(f"Invalid date format: {v}. Use YYYY-MM-DD format")
        return v

    @root_validator
    def validate_date_order(cls, values):
        """Ensure end_date is after start_date when both are supplied.

        A date that failed ``validate_date_format`` is absent from ``values``
        and is skipped here.
        """
        start = values.get('start_date')
        end = values.get('end_date')
        if start and end and datetime.fromisoformat(end) <= datetime.fromisoformat(start):
            raise ValueError("end_date must be after start_date")
        return values
class HyperparameterOverride(BaseModel):
    """Manually supplied hyperparameters for a training run.

    Every field is optional and defaults to ``None``. Numeric fields are
    range-checked, and ``seasonality_mode`` must match
    ``^(additive|multiplicative)$``.
    """

    # Numeric priors, each bounded inclusively by ge/le.
    changepoint_prior_scale: Optional[float] = Field(
        default=None, ge=0.001, le=0.5, description="Flexibility of trend changes"
    )
    seasonality_prior_scale: Optional[float] = Field(
        default=None, ge=0.01, le=10.0, description="Strength of seasonality"
    )
    holidays_prior_scale: Optional[float] = Field(
        default=None, ge=0.01, le=10.0, description="Strength of holiday effects"
    )
    # Restricted to the two literal values allowed by the regex.
    seasonality_mode: Optional[str] = Field(
        default=None,
        regex="^(additive|multiplicative)$",
        description="Seasonality mode",
    )
    # Seasonality component toggles; None leaves the choice unset.
    daily_seasonality: Optional[bool] = None
    weekly_seasonality: Optional[bool] = None
    yearly_seasonality: Optional[bool] = None

    class Config:
        schema_extra = dict(
            example=dict(
                changepoint_prior_scale=0.05,
                seasonality_prior_scale=10.0,
                holidays_prior_scale=10.0,
                seasonality_mode="additive",
                daily_seasonality=False,
                weekly_seasonality=True,
                yearly_seasonality=True,
            )
        )
class AdvancedTrainingRequest(TrainingJobCreateRequest):
    """Training request extended with expert-level options.

    All fields and validators of ``TrainingJobCreateRequest`` are inherited.
    The extra fields are:

    * ``hyperparameter_override``: manual settings; per its description they
      skip optimization (enforced by the consumer, not by this schema).
    * ``enable_cross_validation`` / ``cv_folds``: cross-validation toggle and
      fold count (2..10 inclusive).
    * ``optimization_trials``: optional trial count (5..100 inclusive).
    * ``save_diagnostics``: whether to persist diagnostic plots and metrics.
    """

    hyperparameter_override: Optional[HyperparameterOverride] = Field(
        default=None,
        description="Manual hyperparameter settings (skips optimization)",
    )
    enable_cross_validation: bool = Field(
        True, description="Enable cross-validation during training"
    )
    cv_folds: int = Field(
        3, ge=2, le=10, description="Number of cross-validation folds"
    )
    optimization_trials: Optional[int] = Field(
        default=None,
        ge=5,
        le=100,
        description="Number of hyperparameter optimization trials (overrides defaults)",
    )
    save_diagnostics: bool = Field(
        False, description="Save detailed diagnostic plots and metrics"
    )
class DataQualityCheckRequest(BaseModel):
    """Schema for data quality validation.

    ``start_date`` and ``end_date`` are required strings that must parse with
    ``datetime.fromisoformat``. ``end_date`` must be strictly after
    ``start_date``; previously an inverted period was accepted.
    """

    tenant_id: UUID = Field(..., description="Tenant identifier")
    start_date: str = Field(..., description="Check period start")
    end_date: str = Field(..., description="Check period end")
    product_ids: Optional[List[UUID]] = Field(
        None,
        description="Specific products to check",
    )
    include_recommendations: bool = Field(
        default=True,
        description="Include improvement recommendations",
    )

    @validator('start_date', 'end_date')
    def validate_date(cls, v):
        """Ensure the date is parseable by ``datetime.fromisoformat``."""
        try:
            datetime.fromisoformat(v)
        except ValueError:
            raise ValueError(f"Invalid date format: {v}")
        return v

    @root_validator
    def validate_check_period(cls, values):
        """Reject periods whose end is not after their start.

        A date that failed ``validate_date`` is absent from ``values`` and is
        skipped here.
        """
        start = values.get('start_date')
        end = values.get('end_date')
        if start and end and datetime.fromisoformat(end) <= datetime.fromisoformat(start):
            raise ValueError("end_date must be after start_date")
        return values
class ModelQueryParams(BaseModel):
    """Filter and pagination parameters for listing models.

    Every filter is optional. ``model_type`` must be one of prophet,
    prophet_optimized, lstm or arima. ``min_accuracy`` is bounded to 0..1.
    ``limit`` is 1..1000 (default 100) and ``offset`` is non-negative
    (default 0).
    """

    # Identity / status filters
    tenant_id: Optional[UUID] = None
    product_id: Optional[UUID] = None
    is_active: Optional[bool] = None
    is_production: Optional[bool] = None
    # Model characteristics
    model_type: Optional[str] = Field(
        default=None, regex="^(prophet|prophet_optimized|lstm|arima)$"
    )
    min_accuracy: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    # Creation-time window
    created_after: Optional[datetime] = None
    created_before: Optional[datetime] = None
    # Pagination
    limit: int = Field(100, ge=1, le=1000)
    offset: int = Field(0, ge=0)

    class Config:
        schema_extra = dict(
            example=dict(
                tenant_id="123e4567-e89b-12d3-a456-426614174000",
                is_active=True,
                is_production=True,
                limit=50,
                offset=0,
            )
        )
def validate_uuid(value: str) -> UUID:
    """Validate and convert a value to a ``UUID``.

    Args:
        value: A UUID string in any form accepted by ``uuid.UUID``. An existing
            ``UUID`` instance is also accepted and returned unchanged.

    Returns:
        The parsed ``UUID``.

    Raises:
        ValueError: if ``value`` cannot be interpreted as a UUID. This covers
            non-string inputs such as ``None``, ints or bytes, which make
            ``uuid.UUID`` raise ``AttributeError``/``TypeError`` internally.
    """
    if isinstance(value, UUID):
        # UUID(UUID(...)) fails because UUID() calls str methods on its input.
        return value
    try:
        return UUID(value)
    except (ValueError, AttributeError, TypeError) as err:
        raise ValueError(f"Invalid UUID format: {value}") from err
def validate_date_string(value: str) -> datetime:
    """Validate and convert an ISO-8601 date string to a ``datetime``.

    Args:
        value: String accepted by ``datetime.fromisoformat`` (e.g. "2024-01-01").

    Returns:
        The parsed ``datetime``. It is naive unless the string carries an offset.

    Raises:
        ValueError: if ``value`` is not a parseable ISO string. This includes
            non-string inputs such as ``None``, for which ``fromisoformat``
            raises ``TypeError``.
    """
    try:
        return datetime.fromisoformat(value)
    except (ValueError, TypeError) as err:
        raise ValueError(
            f"Invalid date format: {value}. Use ISO format (YYYY-MM-DD)"
        ) from err
def validate_positive_integer(value: int, field_name: str = "value") -> int:
    """Return ``value`` unchanged unless it is zero or negative.

    Args:
        value: Number to check.
        field_name: Label used in the error message.

    Raises:
        ValueError: when ``value <= 0``.
    """
    is_non_positive = value <= 0
    if is_non_positive:
        raise ValueError(f"{field_name} must be positive, got {value}")
    return value
def validate_probability(value: float, field_name: str = "value") -> float:
    """Return ``value`` unchanged if it lies in the closed interval [0.0, 1.0].

    Args:
        value: Number to check.
        field_name: Label used in the error message.

    Raises:
        ValueError: when ``value`` is outside [0.0, 1.0]. NaN is also rejected,
            because every comparison with it is false.
    """
    in_range = 0.0 <= value <= 1.0
    if in_range:
        return value
    raise ValueError(f"{field_name} must be between 0.0 and 1.0, got {value}")