Files
bakery-ia/services/training/app/schemas/validation.py

318 lines
9.9 KiB
Python
Raw Normal View History

"""
Comprehensive Input Validation Schemas
Ensures all API inputs are properly validated before processing
"""
from pydantic import BaseModel, Field, validator, root_validator
from typing import Optional, List, Dict, Any
from datetime import datetime, timedelta
from uuid import UUID
import re
class TrainingJobCreateRequest(BaseModel):
    """Schema for creating a new training job.

    ``start_date`` and ``end_date`` are optional strings that must be
    parseable by ``datetime.fromisoformat``. ``end_date`` may not lie in the
    future, whether or not ``start_date`` is supplied. When both dates are
    present, ``end_date`` must be strictly after ``start_date`` and the span
    may not exceed 1095 days.
    """

    tenant_id: UUID = Field(..., description="Tenant identifier")
    start_date: Optional[str] = Field(
        None,
        description="Training data start date (ISO format: YYYY-MM-DD)",
        example="2024-01-01"
    )
    end_date: Optional[str] = Field(
        None,
        description="Training data end date (ISO format: YYYY-MM-DD)",
        example="2024-12-31"
    )
    product_ids: Optional[List[UUID]] = Field(
        None,
        description="Specific products to train (optional, trains all if not provided)"
    )
    force_retrain: bool = Field(
        default=False,
        description="Force retraining even if recent models exist"
    )

    @validator('start_date', 'end_date')
    def validate_date_format(cls, v):
        """Reject date strings that ``datetime.fromisoformat`` cannot parse.

        ``None`` is passed through untouched because both fields are optional.
        """
        if v is not None:
            try:
                datetime.fromisoformat(v)
            except ValueError:
                raise ValueError(f"Invalid date format: {v}. Use YYYY-MM-DD format")
        return v

    @root_validator
    def validate_date_range(cls, values):
        """Validate that the requested date range is logical.

        A date that failed ``validate_date_format`` is absent from ``values``,
        so ``.get`` returns ``None`` for it and the related checks are skipped.

        Raises:
            ValueError: If ``end_date`` is not after ``start_date``, the range
                exceeds 1095 days, or ``end_date`` is in the future.
        """
        start = values.get('start_date')
        end = values.get('end_date')
        start_dt = datetime.fromisoformat(start) if start else None
        end_dt = datetime.fromisoformat(end) if end else None

        if start_dt is not None and end_dt is not None:
            if end_dt <= start_dt:
                raise ValueError("end_date must be after start_date")
            # Check reasonable range (max 3 years)
            if (end_dt - start_dt).days > 1095:
                raise ValueError("Date range cannot exceed 3 years (1095 days)")

        # Check not in future. This runs even when start_date is omitted;
        # previously a lone future end_date slipped through unchecked.
        # Passing end_dt.tzinfo keeps the comparison valid for timezone-aware
        # inputs and is identical to datetime.now() for naive ones.
        if end_dt is not None and end_dt > datetime.now(end_dt.tzinfo):
            raise ValueError("end_date cannot be in the future")
        return values

    class Config:
        schema_extra = {
            "example": {
                "tenant_id": "123e4567-e89b-12d3-a456-426614174000",
                "start_date": "2024-01-01",
                "end_date": "2024-12-31",
                "product_ids": None,
                "force_retrain": False
            }
        }
class ForecastRequest(BaseModel):
    """Request payload for generating a forecast for a single product."""

    tenant_id: UUID = Field(..., description="Tenant identifier")
    product_id: UUID = Field(..., description="Product identifier")

    # Horizon length; bounds are inclusive on both ends.
    forecast_days: int = Field(
        30,
        description="Number of days to forecast (1-365)",
        ge=1,
        le=365,
    )

    include_regressors: bool = Field(
        True,
        description="Include weather and traffic data in forecast",
    )

    # Interval width; bounds are inclusive on both ends.
    confidence_level: float = Field(
        0.80,
        description="Confidence interval (0.5-0.99)",
        ge=0.5,
        le=0.99,
    )

    class Config:
        # Sample payload attached to the generated schema.
        schema_extra = {
            "example": dict(
                tenant_id="123e4567-e89b-12d3-a456-426614174000",
                product_id="223e4567-e89b-12d3-a456-426614174000",
                forecast_days=30,
                include_regressors=True,
                confidence_level=0.80,
            )
        }
class ModelEvaluationRequest(BaseModel):
    """Request payload for evaluating models over a date window.

    Both window bounds are required strings parseable by
    ``datetime.fromisoformat``; the end must be after the start and the
    window must span at least 7 days.
    """

    tenant_id: UUID = Field(..., description="Tenant identifier")
    product_id: Optional[UUID] = Field(None, description="Specific product (optional)")
    evaluation_start_date: str = Field(..., description="Evaluation period start")
    evaluation_end_date: str = Field(..., description="Evaluation period end")

    @validator('evaluation_start_date', 'evaluation_end_date')
    def validate_date_format(cls, v):
        """Ensure the value parses with ``datetime.fromisoformat``."""
        try:
            datetime.fromisoformat(v)
        except ValueError:
            raise ValueError(f"Invalid date format: {v}")
        else:
            return v

    @root_validator
    def validate_evaluation_period(cls, values):
        """Check ordering and minimum length of the evaluation window."""
        raw_start = values.get('evaluation_start_date')
        raw_end = values.get('evaluation_end_date')

        # Nothing to compare unless both bounds are available.
        if not (raw_start and raw_end):
            return values

        period_start = datetime.fromisoformat(raw_start)
        period_end = datetime.fromisoformat(raw_end)

        if period_end <= period_start:
            raise ValueError("evaluation_end_date must be after evaluation_start_date")

        # Minimum 7 days for meaningful evaluation
        span = period_end - period_start
        if span.days < 7:
            raise ValueError("Evaluation period must be at least 7 days")

        return values
class BulkTrainingRequest(BaseModel):
    """Schema for bulk training operations.

    Accepts between 1 and 100 distinct tenant IDs plus an optional common
    date range. Each supplied date must be parseable by
    ``datetime.fromisoformat``; when both are supplied, ``end_date`` must be
    strictly after ``start_date``.
    """

    tenant_ids: List[UUID] = Field(
        ...,
        min_items=1,
        max_items=100,
        description="List of tenant IDs (max 100)"
    )
    start_date: Optional[str] = Field(None, description="Common start date")
    end_date: Optional[str] = Field(None, description="Common end date")
    parallel: bool = Field(
        default=True,
        description="Execute training jobs in parallel"
    )

    @validator('tenant_ids')
    def validate_unique_tenants(cls, v):
        """Reject lists that contain the same tenant ID more than once."""
        if len(v) != len(set(v)):
            raise ValueError("Duplicate tenant IDs not allowed")
        return v

    @validator('start_date', 'end_date')
    def validate_date_format(cls, v):
        """Reject date strings that ``datetime.fromisoformat`` cannot parse.

        These fields were previously accepted without any check; ``None`` is
        still passed through because both dates are optional.
        """
        if v is not None:
            try:
                datetime.fromisoformat(v)
            except ValueError:
                raise ValueError(f"Invalid date format: {v}. Use YYYY-MM-DD format")
        return v

    @root_validator
    def validate_date_range(cls, values):
        """Require ``end_date`` to be after ``start_date`` when both are given.

        A date that failed ``validate_date_format`` is absent from ``values``,
        so the comparison is skipped for it.
        """
        start = values.get('start_date')
        end = values.get('end_date')
        if start and end:
            if datetime.fromisoformat(end) <= datetime.fromisoformat(start):
                raise ValueError("end_date must be after start_date")
        return values
class HyperparameterOverride(BaseModel):
    """Manually supplied hyperparameter values.

    Every field is optional and defaults to ``None``. Numeric fields carry
    inclusive lower and upper bounds.
    """

    changepoint_prior_scale: Optional[float] = Field(
        default=None,
        description="Flexibility of trend changes",
        ge=0.001,
        le=0.5,
    )
    seasonality_prior_scale: Optional[float] = Field(
        default=None,
        description="Strength of seasonality",
        ge=0.01,
        le=10.0,
    )
    holidays_prior_scale: Optional[float] = Field(
        default=None,
        description="Strength of holiday effects",
        ge=0.01,
        le=10.0,
    )
    # Restricted by pattern to the two listed mode names.
    seasonality_mode: Optional[str] = Field(
        default=None,
        regex="^(additive|multiplicative)$",
        description="Seasonality mode",
    )

    # Seasonality toggles; None means no override was supplied.
    daily_seasonality: Optional[bool] = None
    weekly_seasonality: Optional[bool] = None
    yearly_seasonality: Optional[bool] = None

    class Config:
        # Sample payload attached to the generated schema.
        schema_extra = {
            "example": dict(
                changepoint_prior_scale=0.05,
                seasonality_prior_scale=10.0,
                holidays_prior_scale=10.0,
                seasonality_mode="additive",
                daily_seasonality=False,
                weekly_seasonality=True,
                yearly_seasonality=True,
            )
        }
class AdvancedTrainingRequest(TrainingJobCreateRequest):
    """Training request that adds advanced options to the base schema.

    Inherits every field and validator of ``TrainingJobCreateRequest``.
    """

    hyperparameter_override: Optional[HyperparameterOverride] = Field(
        default=None,
        description="Manual hyperparameter settings (skips optimization)",
    )

    enable_cross_validation: bool = Field(
        True,
        description="Enable cross-validation during training",
    )
    # Fold count is bounded to 2..10 inclusive.
    cv_folds: int = Field(
        3,
        description="Number of cross-validation folds",
        ge=2,
        le=10,
    )

    # When given, must be within 5..100 inclusive.
    optimization_trials: Optional[int] = Field(
        default=None,
        description="Number of hyperparameter optimization trials (overrides defaults)",
        ge=5,
        le=100,
    )

    save_diagnostics: bool = Field(
        False,
        description="Save detailed diagnostic plots and metrics",
    )
class DataQualityCheckRequest(BaseModel):
    """Request payload for a data quality check over a required date period."""

    tenant_id: UUID = Field(..., description="Tenant identifier")
    start_date: str = Field(..., description="Check period start")
    end_date: str = Field(..., description="Check period end")
    product_ids: Optional[List[UUID]] = Field(
        default=None,
        description="Specific products to check",
    )
    include_recommendations: bool = Field(
        True,
        description="Include improvement recommendations",
    )

    @validator('start_date', 'end_date')
    def validate_date(cls, v):
        """Ensure the value parses with ``datetime.fromisoformat``."""
        try:
            datetime.fromisoformat(v)
        except ValueError:
            raise ValueError(f"Invalid date format: {v}")
        else:
            return v
class ModelQueryParams(BaseModel):
    """Filter and paging parameters used when listing models.

    All filters are optional. ``limit`` defaults to 100 and is bounded to
    1..1000; ``offset`` defaults to 0 and may not be negative.
    """

    # Identity filters
    tenant_id: Optional[UUID] = None
    product_id: Optional[UUID] = None

    # Status filters
    is_active: Optional[bool] = None
    is_production: Optional[bool] = None

    # Restricted by pattern to the listed model type names.
    model_type: Optional[str] = Field(
        default=None,
        regex="^(prophet|prophet_optimized|lstm|arima)$",
    )
    min_accuracy: Optional[float] = Field(default=None, ge=0.0, le=1.0)

    # Creation-time window
    created_after: Optional[datetime] = None
    created_before: Optional[datetime] = None

    # Paging
    limit: int = Field(100, ge=1, le=1000)
    offset: int = Field(0, ge=0)

    class Config:
        # Sample payload attached to the generated schema.
        schema_extra = {
            "example": dict(
                tenant_id="123e4567-e89b-12d3-a456-426614174000",
                is_active=True,
                is_production=True,
                limit=50,
                offset=0,
            )
        }
def validate_uuid(value: str) -> UUID:
    """Validate and convert a string to a UUID.

    Args:
        value: Textual UUID in any form accepted by ``uuid.UUID``. An
            existing ``UUID`` instance is returned unchanged.

    Returns:
        The parsed ``UUID``.

    Raises:
        ValueError: If ``value`` cannot be interpreted as a UUID, including
            non-string input such as ``None``.
    """
    if isinstance(value, UUID):
        return value
    try:
        return UUID(value)
    except (ValueError, AttributeError, TypeError) as exc:
        # UUID(None) raises TypeError and UUID(<non-str>) raises
        # AttributeError; normalise all of them to the documented ValueError.
        raise ValueError(f"Invalid UUID format: {value}") from exc
def validate_date_string(value: str) -> datetime:
    """Validate and convert an ISO-format date string to a ``datetime``.

    Args:
        value: String accepted by ``datetime.fromisoformat``.

    Returns:
        The parsed ``datetime``.

    Raises:
        ValueError: If ``value`` is not a parseable ISO string, including
            non-string input such as ``None``.
    """
    try:
        return datetime.fromisoformat(value)
    except (ValueError, TypeError) as exc:
        # fromisoformat raises TypeError for non-str arguments; report it
        # through the same ValueError as malformed strings.
        raise ValueError(
            f"Invalid date format: {value}. Use ISO format (YYYY-MM-DD)"
        ) from exc
def validate_positive_integer(value: int, field_name: str = "value") -> int:
    """Return ``value`` unchanged unless it is zero or negative.

    Args:
        value: Number to check.
        field_name: Label used in the error message.

    Raises:
        ValueError: If ``value <= 0``.
    """
    if not value <= 0:
        return value
    raise ValueError(f"{field_name} must be positive, got {value}")
def validate_probability(value: float, field_name: str = "value") -> float:
    """Return ``value`` unchanged if it lies within the closed range 0.0-1.0.

    Args:
        value: Number to check.
        field_name: Label used in the error message.

    Raises:
        ValueError: If ``value`` is outside ``[0.0, 1.0]``.
    """
    in_unit_interval = 0.0 <= value <= 1.0
    if in_unit_interval:
        return value
    raise ValueError(f"{field_name} must be between 0.0 and 1.0, got {value}")