318 lines
9.9 KiB
Python
318 lines
9.9 KiB
Python
"""
|
|
Comprehensive Input Validation Schemas
|
|
Ensures all API inputs are properly validated before processing
|
|
"""
|
|
|
|
from pydantic import BaseModel, Field, validator, root_validator
|
|
from typing import Optional, List, Dict, Any
|
|
from datetime import datetime, timedelta
|
|
from uuid import UUID
|
|
import re
|
|
|
|
|
|
class TrainingJobCreateRequest(BaseModel):
    """Schema for creating a new training job"""

    # Tenant that owns the training job (required).
    tenant_id: UUID = Field(..., description="Tenant identifier")
    # Optional bounds of the training window, given as ISO-format strings.
    # Each is format-checked individually and range-checked together below.
    start_date: Optional[str] = Field(
        None,
        description="Training data start date (ISO format: YYYY-MM-DD)",
        example="2024-01-01"
    )
    end_date: Optional[str] = Field(
        None,
        description="Training data end date (ISO format: YYYY-MM-DD)",
        example="2024-12-31"
    )
    # Optional subset of products; None means no subset was requested.
    product_ids: Optional[List[UUID]] = Field(
        None,
        description="Specific products to train (optional, trains all if not provided)"
    )
    force_retrain: bool = Field(
        default=False,
        description="Force retraining even if recent models exist"
    )

    @validator('start_date', 'end_date')
    def validate_date_format(cls, v):
        """Reject date strings that ``datetime.fromisoformat`` cannot parse.

        ``None`` (field omitted) is passed through unchanged.

        Raises:
            ValueError: if the string is not a valid ISO-format date.
        """
        if v is not None:
            try:
                datetime.fromisoformat(v)
            except ValueError as exc:
                raise ValueError(
                    f"Invalid date format: {v}. Use YYYY-MM-DD format"
                ) from exc
        return v

    @root_validator
    def validate_date_range(cls, values):
        """Validate that the requested date range is logical.

        Checks performed:
          * when both dates are present, ``end_date`` must be strictly after
            ``start_date`` and the span must not exceed 1095 days;
          * whenever ``end_date`` is present (with or without ``start_date``)
            it must not lie in the future.

        Raises:
            ValueError: if any of the checks above fails, or if one date
                carries a timezone offset and the other does not.
        """
        start = values.get('start_date')
        end = values.get('end_date')

        # Parse end_date independently so the future check below also covers
        # requests that omit start_date.
        end_dt = datetime.fromisoformat(end) if end else None

        if start and end_dt is not None:
            start_dt = datetime.fromisoformat(start)

            # Naive and offset-aware datetimes cannot be compared or
            # subtracted; report that as a validation problem rather than
            # letting a bare TypeError surface.
            if (start_dt.tzinfo is None) != (end_dt.tzinfo is None):
                raise ValueError(
                    "start_date and end_date must both include a timezone "
                    "offset or both omit it"
                )

            if end_dt <= start_dt:
                raise ValueError("end_date must be after start_date")

            # Check reasonable range (max 3 years)
            if (end_dt - start_dt).days > 1095:
                raise ValueError("Date range cannot exceed 3 years (1095 days)")

        # Check not in future. "now" is taken in end_dt's own timezone
        # (naive local time when end_dt is naive) so the comparison is valid.
        if end_dt is not None and end_dt > datetime.now(end_dt.tzinfo):
            raise ValueError("end_date cannot be in the future")

        return values

    class Config:
        schema_extra = {
            "example": {
                "tenant_id": "123e4567-e89b-12d3-a456-426614174000",
                "start_date": "2024-01-01",
                "end_date": "2024-12-31",
                "product_ids": None,
                "force_retrain": False
            }
        }
|
|
|
|
|
|
class ForecastRequest(BaseModel):
    """Schema for generating forecasts"""

    # Required identifiers: the tenant and the product to forecast for.
    tenant_id: UUID = Field(default=..., description="Tenant identifier")
    product_id: UUID = Field(default=..., description="Product identifier")

    # Forecast horizon in days; bounded to the inclusive range 1..365.
    forecast_days: int = Field(
        30, ge=1, le=365, description="Number of days to forecast (1-365)"
    )

    # Flag controlling use of external regressors; enabled by default.
    include_regressors: bool = Field(
        True, description="Include weather and traffic data in forecast"
    )

    # Confidence level; bounded to the inclusive range 0.5..0.99.
    confidence_level: float = Field(
        0.80, ge=0.5, le=0.99, description="Confidence interval (0.5-0.99)"
    )

    class Config:
        # Sample payload attached to the generated schema.
        schema_extra = {
            "example": {
                "tenant_id": "123e4567-e89b-12d3-a456-426614174000",
                "product_id": "223e4567-e89b-12d3-a456-426614174000",
                "forecast_days": 30,
                "include_regressors": True,
                "confidence_level": 0.80,
            },
        }
|
|
|
|
|
|
class ModelEvaluationRequest(BaseModel):
    """Schema for model evaluation"""

    tenant_id: UUID = Field(default=..., description="Tenant identifier")
    # Optional: restricts the request to a single product when given.
    product_id: Optional[UUID] = Field(default=None, description="Specific product (optional)")
    # Both period bounds are required ISO-format strings.
    evaluation_start_date: str = Field(default=..., description="Evaluation period start")
    evaluation_end_date: str = Field(default=..., description="Evaluation period end")

    @validator('evaluation_start_date', 'evaluation_end_date')
    def validate_date_format(cls, v):
        """Ensure the value parses with ``datetime.fromisoformat``.

        Returns the original string unchanged; raises ``ValueError`` when it
        cannot be parsed.
        """
        try:
            datetime.fromisoformat(v)
        except ValueError:
            raise ValueError(f"Invalid date format: {v}")
        else:
            return v

    @root_validator
    def validate_evaluation_period(cls, values):
        """Check that the evaluation period is ordered and at least 7 days long.

        The checks are skipped when either bound is missing or empty.
        """
        period_start = values.get('evaluation_start_date')
        period_end = values.get('evaluation_end_date')

        # Nothing to cross-check unless both bounds are available.
        if not (period_start and period_end):
            return values

        begins = datetime.fromisoformat(period_start)
        finishes = datetime.fromisoformat(period_end)

        if finishes <= begins:
            raise ValueError("evaluation_end_date must be after evaluation_start_date")

        # Minimum 7 days for meaningful evaluation
        span_days = (finishes - begins).days
        if span_days < 7:
            raise ValueError("Evaluation period must be at least 7 days")

        return values
|
|
|
|
|
|
class BulkTrainingRequest(BaseModel):
    """Schema for bulk training operations"""

    # Between 1 and 100 tenant IDs; duplicates are rejected below.
    tenant_ids: List[UUID] = Field(
        ...,
        min_items=1,
        max_items=100,
        description="List of tenant IDs (max 100)"
    )
    # Optional ISO-format date strings shared by the whole request.
    start_date: Optional[str] = Field(None, description="Common start date")
    end_date: Optional[str] = Field(None, description="Common end date")
    parallel: bool = Field(
        default=True,
        description="Execute training jobs in parallel"
    )

    @validator('tenant_ids')
    def validate_unique_tenants(cls, v):
        """Reject the list when any tenant ID appears more than once.

        Raises:
            ValueError: if the list contains duplicates.
        """
        if len(v) != len(set(v)):
            raise ValueError("Duplicate tenant IDs not allowed")
        return v

    @validator('start_date', 'end_date')
    def validate_date_format(cls, v):
        """Reject date strings that ``datetime.fromisoformat`` cannot parse.

        Mirrors the date-format check applied to the same fields on the
        single-job request schema, so malformed dates are caught at the API
        boundary. ``None`` (field omitted) is passed through unchanged.

        Raises:
            ValueError: if the string is not a valid ISO-format date.
        """
        if v is not None:
            try:
                datetime.fromisoformat(v)
            except ValueError as exc:
                raise ValueError(
                    f"Invalid date format: {v}. Use YYYY-MM-DD format"
                ) from exc
        return v
|
|
|
|
|
|
class HyperparameterOverride(BaseModel):
    """Schema for manual hyperparameter override"""

    # Every field is optional; None means "no override supplied".

    # Allowed range: 0.001 .. 0.5 (inclusive).
    changepoint_prior_scale: Optional[float] = Field(
        default=None,
        ge=0.001,
        le=0.5,
        description="Flexibility of trend changes",
    )
    # Allowed range: 0.01 .. 10.0 (inclusive).
    seasonality_prior_scale: Optional[float] = Field(
        default=None,
        ge=0.01,
        le=10.0,
        description="Strength of seasonality",
    )
    # Allowed range: 0.01 .. 10.0 (inclusive).
    holidays_prior_scale: Optional[float] = Field(
        default=None,
        ge=0.01,
        le=10.0,
        description="Strength of holiday effects",
    )
    # Must match exactly "additive" or "multiplicative".
    seasonality_mode: Optional[str] = Field(
        default=None,
        description="Seasonality mode",
        regex="^(additive|multiplicative)$",
    )

    # Tri-state switches: True / False / None (not specified).
    daily_seasonality: Optional[bool] = None
    weekly_seasonality: Optional[bool] = None
    yearly_seasonality: Optional[bool] = None

    class Config:
        # Sample payload attached to the generated schema.
        schema_extra = {
            "example": {
                "changepoint_prior_scale": 0.05,
                "seasonality_prior_scale": 10.0,
                "holidays_prior_scale": 10.0,
                "seasonality_mode": "additive",
                "daily_seasonality": False,
                "weekly_seasonality": True,
                "yearly_seasonality": True,
            },
        }
|
|
|
|
|
|
class AdvancedTrainingRequest(TrainingJobCreateRequest):
    """Extended training request with advanced options"""

    # Inherits every field and validator of TrainingJobCreateRequest and
    # adds the tuning options below.

    # Optional nested block of manually chosen hyperparameters.
    hyperparameter_override: Optional[HyperparameterOverride] = Field(
        default=None,
        description="Manual hyperparameter settings (skips optimization)",
    )

    # Cross-validation switch (on by default) and its fold count (2..10).
    enable_cross_validation: bool = Field(
        True, description="Enable cross-validation during training"
    )
    cv_folds: int = Field(
        3, ge=2, le=10, description="Number of cross-validation folds"
    )

    # Optional trial count; must be within 5..100 when supplied.
    optimization_trials: Optional[int] = Field(
        default=None,
        ge=5,
        le=100,
        description="Number of hyperparameter optimization trials (overrides defaults)",
    )

    # Diagnostics are not saved unless explicitly requested.
    save_diagnostics: bool = Field(
        False, description="Save detailed diagnostic plots and metrics"
    )
|
|
|
|
|
|
class DataQualityCheckRequest(BaseModel):
    """Schema for data quality validation"""

    tenant_id: UUID = Field(default=..., description="Tenant identifier")
    # Required ISO-format bounds of the period to check.
    start_date: str = Field(default=..., description="Check period start")
    end_date: str = Field(default=..., description="Check period end")
    # Optional list restricting the check to specific products.
    product_ids: Optional[List[UUID]] = Field(
        default=None, description="Specific products to check"
    )
    # Recommendations are included unless the caller turns them off.
    include_recommendations: bool = Field(
        True, description="Include improvement recommendations"
    )

    @validator('start_date', 'end_date')
    def validate_date(cls, v):
        """Ensure the value parses with ``datetime.fromisoformat``.

        Returns the original string unchanged; raises ``ValueError`` when it
        cannot be parsed.
        """
        try:
            datetime.fromisoformat(v)
        except ValueError:
            raise ValueError(f"Invalid date format: {v}")
        else:
            return v
|
|
|
|
|
|
class ModelQueryParams(BaseModel):
    """Query parameters for model listing"""

    # Optional filters; each defaults to None (filter not supplied).
    tenant_id: Optional[UUID] = None
    product_id: Optional[UUID] = None
    is_active: Optional[bool] = None
    is_production: Optional[bool] = None
    # Must be one of: prophet, prophet_optimized, lstm, arima.
    model_type: Optional[str] = Field(
        default=None, regex="^(prophet|prophet_optimized|lstm|arima)$"
    )
    # Bounded to the inclusive range 0.0..1.0.
    min_accuracy: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    created_after: Optional[datetime] = None
    created_before: Optional[datetime] = None

    # Pagination: limit is 1..1000 (default 100); offset is >= 0 (default 0).
    limit: int = Field(100, ge=1, le=1000)
    offset: int = Field(0, ge=0)

    class Config:
        # Sample payload attached to the generated schema.
        schema_extra = {
            "example": {
                "tenant_id": "123e4567-e89b-12d3-a456-426614174000",
                "is_active": True,
                "is_production": True,
                "limit": 50,
                "offset": 0,
            },
        }
|
|
|
|
|
|
def validate_uuid(value: str) -> UUID:
    """Validate and convert a string to a UUID.

    Args:
        value: Textual UUID. A value that is already a ``UUID`` instance is
            returned as-is.

    Returns:
        The parsed ``UUID``.

    Raises:
        ValueError: if ``value`` cannot be interpreted as a UUID. This
            includes ``None`` and other non-string inputs.
    """
    # Already a UUID: nothing to parse. UUID(UUID(...)) would otherwise
    # fail and be reported as an invalid format.
    if isinstance(value, UUID):
        return value

    try:
        return UUID(value)
    except (ValueError, AttributeError, TypeError) as exc:
        # ValueError: malformed hex string.
        # AttributeError: non-string input without str methods (e.g. an int).
        # TypeError: UUID(None) - no usable constructor argument.
        raise ValueError(f"Invalid UUID format: {value}") from exc
|
|
|
|
|
|
def validate_date_string(value: str) -> datetime:
    """Validate and convert an ISO-format date string to a ``datetime``.

    Args:
        value: Date (or date-time) string accepted by
            ``datetime.fromisoformat``.

    Returns:
        The parsed ``datetime``.

    Raises:
        ValueError: if ``value`` is not a string or is not a valid ISO-format
            date.
    """
    try:
        return datetime.fromisoformat(value)
    except (ValueError, TypeError) as exc:
        # fromisoformat raises TypeError for non-string input (e.g. None);
        # fold that into the same ValueError callers already handle.
        raise ValueError(
            f"Invalid date format: {value}. Use ISO format (YYYY-MM-DD)"
        ) from exc
|
|
|
|
|
|
def validate_positive_integer(value: int, field_name: str = "value") -> int:
    """Return ``value`` unchanged when it is greater than zero.

    Args:
        value: Number to check.
        field_name: Label used in the error message.

    Raises:
        ValueError: if ``value`` is zero or negative.
    """
    is_non_positive = value <= 0
    if is_non_positive:
        raise ValueError(f"{field_name} must be positive, got {value}")
    return value
|
|
|
|
|
|
def validate_probability(value: float, field_name: str = "value") -> float:
    """Return ``value`` unchanged when it lies in the closed interval [0.0, 1.0].

    Args:
        value: Number to check.
        field_name: Label used in the error message.

    Raises:
        ValueError: if ``value`` is outside [0.0, 1.0].
    """
    # Kept as a chained comparison so a value that satisfies neither bound
    # (e.g. NaN) is rejected.
    within_unit_interval = 0.0 <= value <= 1.0
    if within_unit_interval:
        return value
    raise ValueError(f"{field_name} must be between 0.0 and 1.0, got {value}")
|