Initial commit - production deployment
This commit is contained in:
0
services/training/app/schemas/__init__.py
Normal file
0
services/training/app/schemas/__init__.py
Normal file
384
services/training/app/schemas/training.py
Normal file
384
services/training/app/schemas/training.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# services/training/app/schemas/training.py
|
||||
"""
|
||||
Complete schema definitions for training service
|
||||
Includes all request/response schemas used by the API endpoints
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import List, Optional, Dict, Any, Union, Tuple
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class TrainingStatus(str, Enum):
    """States a training job can report.

    The enum mixes in ``str``, so each member is itself a string equal to
    its lowercase value.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class TrainingJobRequest(BaseModel):
    """Body of a request that starts a training job.

    All three fields are optional and default to ``None``.
    """

    # Optional subset of products; per the description, None means "all".
    products: Optional[List[str]] = Field(
        None,
        description="Specific products to train (if None, trains all)",
    )
    # Optional bounds of the training-data window.
    start_date: Optional[datetime] = Field(
        None,
        description="Start date for training data",
    )
    end_date: Optional[datetime] = Field(
        None,
        description="End date for training data",
    )
|
||||
|
||||
|
||||
class SingleProductTrainingRequest(BaseModel):
    """Body of a request that trains one product.

    Carries an optional data window, the Prophet seasonality switches and an
    optional bakery location.
    """

    # --- training-data window -------------------------------------------
    start_date: Optional[datetime] = Field(
        None,
        description="Start date for training data",
    )
    end_date: Optional[datetime] = Field(
        None,
        description="End date for training data",
    )

    # --- Prophet settings (all enabled / "additive" by default) ----------
    seasonality_mode: str = Field(
        "additive",
        description="Prophet seasonality mode",
    )
    daily_seasonality: bool = Field(
        True,
        description="Enable daily seasonality",
    )
    weekly_seasonality: bool = Field(
        True,
        description="Enable weekly seasonality",
    )
    yearly_seasonality: bool = Field(
        True,
        description="Enable yearly seasonality",
    )

    # --- location ---------------------------------------------------------
    # Two floats; the description gives the order as (latitude, longitude).
    bakery_location: Optional[Tuple[float, float]] = Field(
        None,
        description="Bakery coordinates (latitude, longitude)",
    )
|
||||
|
||||
class DateRangeInfo(BaseModel):
    """A start/end pair of dates, both required and carried as strings."""

    # The field type is plain ``str``; the ISO format named in the
    # descriptions is not enforced by this schema.
    start: str = Field(
        ...,
        description="Start date in ISO format",
    )
    end: str = Field(
        ...,
        description="End date in ISO format",
    )
|
||||
|
||||
class DataSummary(BaseModel):
    """Summary of the data that went into a training run.

    Every field is required except ``constraints_applied``, which defaults
    to an empty dict.
    """

    # Record counts per data source.
    sales_records: int = Field(
        ...,
        description="Number of sales records used",
    )
    weather_records: int = Field(
        ...,
        description="Number of weather records used",
    )
    traffic_records: int = Field(
        ...,
        description="Number of traffic records used",
    )

    # Covered period and provenance.
    date_range: DateRangeInfo = Field(
        ...,
        description="Date range of training data",
    )
    data_sources_used: List[str] = Field(
        ...,
        description="List of data sources used",
    )
    constraints_applied: Dict[str, str] = Field(
        default_factory=dict,
        description="Constraints applied during data collection",
    )
|
||||
|
||||
class ProductTrainingResult(BaseModel):
    """Training outcome for one product (first definition in this module).

    NOTE(review): a second class named ``ProductTrainingResult`` with a
    different field set is defined later in this module. That later
    definition rebinds the module-level name, so importers receive it
    rather than this class. Confirm which one is intended and rename the
    other.
    """

    inventory_product_id: UUID = Field(
        ...,
        description="Inventory product UUID",
    )
    status: str = Field(
        ...,
        description="Training status for this product",
    )
    model_id: Optional[str] = Field(
        None,
        description="Trained model identifier",
    )
    data_points: int = Field(
        ...,
        description="Number of data points used for training",
    )
    metrics: Optional[Dict[str, float]] = Field(
        None,
        description="Training metrics (MAE, MAPE, etc.)",
    )
    training_time_seconds: Optional[float] = Field(
        None,
        description="Time taken to train this model",
    )
    error_message: Optional[str] = Field(
        None,
        description="Error message if training failed",
    )
|
||||
|
||||
class TrainingResults(BaseModel):
    """Aggregate outcome of a training run; every field is required."""

    # Counters.
    total_products: int = Field(
        ...,
        description="Total number of products",
    )
    successful_trainings: int = Field(
        ...,
        description="Number of successfully trained models",
    )
    failed_trainings: int = Field(
        ...,
        description="Number of failed trainings",
    )

    # Per-product detail; the annotation resolves to whichever
    # ``ProductTrainingResult`` is bound when this class body executes.
    products: List[ProductTrainingResult] = Field(
        ...,
        description="Results for each product",
    )
    overall_training_time_seconds: float = Field(
        ...,
        description="Total training time",
    )
|
||||
|
||||
class TrainingJobResponse(BaseModel):
    """Response describing a training job, optionally with detailed results.

    The identifier, status, message, creation time and estimated duration
    are required. The result, summary, completion and diagnostic fields are
    optional and default to ``None``.
    """

    # --- identity and state ----------------------------------------------
    job_id: str = Field(
        ...,
        description="Unique training job identifier",
    )
    tenant_id: str = Field(
        ...,
        description="Tenant identifier",
    )
    status: TrainingStatus = Field(
        ...,
        description="Overall job status",
    )

    # --- basic response: required (kept for backwards compatibility) -----
    message: str = Field(
        ...,
        description="Status message",
    )
    created_at: datetime = Field(
        ...,
        description="Job creation timestamp",
    )
    estimated_duration_minutes: int = Field(
        ...,
        description="Estimated completion time in minutes",
    )

    # --- detailed results: optional (kept for backwards compatibility) ---
    training_results: Optional[TrainingResults] = Field(
        None,
        description="Detailed training results",
    )
    data_summary: Optional[DataSummary] = Field(
        None,
        description="Summary of training data used",
    )
    # Unlike ``created_at`` this is typed as a string, not a datetime.
    completed_at: Optional[str] = Field(
        None,
        description="Job completion timestamp in ISO format",
    )

    # --- extra optional diagnostics ---------------------------------------
    error_details: Optional[Dict[str, Any]] = Field(
        None,
        description="Detailed error information if failed",
    )
    processing_metadata: Optional[Dict[str, Any]] = Field(
        None,
        description="Additional processing metadata",
    )

    @validator('tenant_id', 'job_id', pre=True)
    def convert_uuid_to_string(cls, v):
        """Turn ``UUID`` inputs into strings before field validation runs."""
        return str(v) if isinstance(v, UUID) else v

    class Config:
        # NOTE(review): ``from_attributes`` is the pydantic v2 config key
        # (v1 calls it ``orm_mode``) - confirm the pinned pydantic version.
        from_attributes = True
|
||||
|
||||
class TrainingJobStatus(BaseModel):
    """Response returned by training-job status checks.

    ``job_id``, ``status`` and ``started_at`` are required; the counters
    default to ``0``, ``current_step`` to ``""`` and the rest to ``None``.
    """

    job_id: str = Field(
        ...,
        description="Training job identifier",
    )
    status: TrainingStatus = Field(
        ...,
        description="Current job status",
    )

    # Progress reporting. No ge/le bound is declared on ``progress`` here.
    progress: int = Field(
        0,
        description="Progress percentage (0-100)",
    )
    current_step: str = Field(
        "",
        description="Current processing step",
    )

    # Timing.
    started_at: datetime = Field(
        ...,
        description="Job start timestamp",
    )
    completed_at: Optional[datetime] = Field(
        None,
        description="Job completion timestamp",
    )

    # Per-product counters.
    products_total: int = Field(
        0,
        description="Total number of products to train",
    )
    products_completed: int = Field(
        0,
        description="Number of products completed",
    )
    products_failed: int = Field(
        0,
        description="Number of products that failed",
    )

    # Optional extras.
    error_message: Optional[str] = Field(
        None,
        description="Error message if failed",
    )
    estimated_time_remaining_seconds: Optional[int] = Field(
        None,
        description="Estimated time remaining in seconds",
    )
    message: Optional[str] = Field(
        None,
        description="Optional status message",
    )

    @validator('job_id', pre=True)
    def convert_uuid_to_string(cls, v):
        """Turn a ``UUID`` input into a string before field validation runs."""
        return str(v) if isinstance(v, UUID) else v

    class Config:
        # NOTE(review): ``from_attributes`` is the pydantic v2 config key
        # (v1 calls it ``orm_mode``) - confirm the pinned pydantic version.
        from_attributes = True
|
||||
|
||||
|
||||
class TrainingJobProgress(BaseModel):
    """Progress snapshot for a running training job.

    ``job_id``, ``status`` and ``current_step`` are required. ``progress``
    is constrained to 0-100 and ``timestamp`` defaults to the time the
    instance is created.
    """

    job_id: str = Field(
        ...,
        description="Training job identifier",
    )
    status: TrainingStatus = Field(
        ...,
        description="Current job status",
    )
    progress: int = Field(
        0,
        description="Progress percentage (0-100)",
        ge=0,
        le=100,
    )
    current_step: str = Field(
        ...,
        description="Current processing step",
    )
    current_product: Optional[str] = Field(
        None,
        description="Currently training product",
    )
    products_completed: int = Field(
        0,
        description="Number of products completed",
    )
    products_total: int = Field(
        0,
        description="Total number of products",
    )
    estimated_time_remaining_minutes: Optional[int] = Field(
        None,
        description="Estimated time remaining",
    )
    # ``datetime.now`` with no tz argument yields a naive local timestamp.
    timestamp: datetime = Field(
        default_factory=datetime.now,
        description="Progress update timestamp",
    )

    @validator('job_id', pre=True)
    def convert_uuid_to_string(cls, v):
        """Turn a ``UUID`` input into a string before field validation runs."""
        return str(v) if isinstance(v, UUID) else v

    class Config:
        # NOTE(review): ``from_attributes`` is the pydantic v2 config key
        # (v1 calls it ``orm_mode``) - confirm the pinned pydantic version.
        from_attributes = True
|
||||
|
||||
|
||||
class DataValidationRequest(BaseModel):
    """Body of a request that validates training data before a run.

    ``min_data_points`` defaults to 30 and is bounded to 10-1000; the other
    fields are optional.
    """

    products: Optional[List[str]] = Field(
        None,
        description="Specific products to validate (if None, validates all)",
    )
    min_data_points: int = Field(
        30,
        description="Minimum required data points per product",
        ge=10,
        le=1000,
    )
    start_date: Optional[datetime] = Field(
        None,
        description="Start date for data validation",
    )
    end_date: Optional[datetime] = Field(
        None,
        description="End date for data validation",
    )

    @validator('min_data_points')
    def validate_min_data_points(cls, v):
        """Require at least 10 data points (same lower bound as ``ge=10``)."""
        if v >= 10:
            return v
        raise ValueError('min_data_points must be at least 10')
|
||||
|
||||
|
||||
class DataValidationResponse(BaseModel):
    """Result of a data-validation request.

    List fields default to empty lists and ``data_quality_score`` defaults
    to ``0.0`` (bounded to 0-1); the remaining fields are required.
    """

    is_valid: bool = Field(
        ...,
        description="Whether the data is valid for training",
    )
    issues: List[str] = Field(
        default_factory=list,
        description="List of data quality issues",
    )
    recommendations: List[str] = Field(
        default_factory=list,
        description="Recommendations for improvement",
    )
    estimated_time_minutes: int = Field(
        ...,
        description="Estimated training time in minutes",
    )
    products_analyzed: int = Field(
        ...,
        description="Number of products analyzed",
    )
    total_data_points: int = Field(
        ...,
        description="Total data points available",
    )
    products_with_insufficient_data: List[str] = Field(
        default_factory=list,
        description="Products with insufficient data",
    )
    data_quality_score: float = Field(
        0.0,
        description="Overall data quality score (0-1)",
        ge=0.0,
        le=1.0,
    )
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
    """Metadata describing one trained model.

    Everything is required except ``model_type``, which defaults to
    ``"prophet"``.
    """

    # Identity and storage.
    model_id: str = Field(
        ...,
        description="Unique model identifier",
    )
    model_path: str = Field(
        ...,
        description="Path to stored model",
    )
    model_type: str = Field(
        "prophet",
        description="Type of ML model",
    )

    # How the model was trained.
    training_samples: int = Field(
        ...,
        description="Number of training samples",
    )
    features: List[str] = Field(
        ...,
        description="List of features used",
    )
    hyperparameters: Dict[str, Any] = Field(
        ...,
        description="Model hyperparameters",
    )
    training_metrics: Dict[str, float] = Field(
        ...,
        description="Training performance metrics",
    )
    trained_at: datetime = Field(
        ...,
        description="Training completion timestamp",
    )
    # String-to-string mapping; the expected keys are not fixed here.
    data_period: Dict[str, str] = Field(
        ...,
        description="Training data period",
    )
|
||||
|
||||
|
||||
class ProductTrainingResult(BaseModel):
    """Training outcome for one product (second definition in this module).

    NOTE(review): an earlier class in this module has the same name but a
    different field set (``model_id``/``metrics`` rather than
    ``model_info``/``trained_at``). This definition rebinds the
    module-level name from here on. Confirm that the shadowing is intended
    or rename one of the two.
    """

    inventory_product_id: UUID = Field(
        ...,
        description="Inventory product UUID",
    )
    status: str = Field(
        ...,
        description="Training status for this product",
    )
    model_info: Optional[ModelInfo] = Field(
        None,
        description="Model information if successful",
    )
    data_points: int = Field(
        ...,
        description="Number of data points used",
    )
    error_message: Optional[str] = Field(
        None,
        description="Error message if failed",
    )
    trained_at: datetime = Field(
        ...,
        description="Training completion timestamp",
    )
    training_duration_seconds: Optional[float] = Field(
        None,
        description="Training duration in seconds",
    )
|
||||
|
||||
|
||||
class TrainingResultsResponse(BaseModel):
    """Full results of a training job; every field is required."""

    # Identity and state.
    job_id: str = Field(
        ...,
        description="Training job identifier",
    )
    tenant_id: str = Field(
        ...,
        description="Tenant identifier",
    )
    status: TrainingStatus = Field(
        ...,
        description="Overall job status",
    )

    # Counters.
    products_trained: int = Field(
        ...,
        description="Number of products successfully trained",
    )
    products_failed: int = Field(
        ...,
        description="Number of products that failed training",
    )
    total_products: int = Field(
        ...,
        description="Total number of products processed",
    )

    # Detail, keyed by string.
    training_results: Dict[str, ProductTrainingResult] = Field(
        ...,
        description="Per-product results",
    )
    summary: Dict[str, Any] = Field(
        ...,
        description="Training summary statistics",
    )
    completed_at: datetime = Field(
        ...,
        description="Job completion timestamp",
    )

    @validator('tenant_id', 'job_id', pre=True)
    def convert_uuid_to_string(cls, v):
        """Turn ``UUID`` inputs into strings before field validation runs."""
        return str(v) if isinstance(v, UUID) else v

    class Config:
        # NOTE(review): ``from_attributes`` is the pydantic v2 config key
        # (v1 calls it ``orm_mode``) - confirm the pinned pydantic version.
        from_attributes = True
|
||||
|
||||
|
||||
class TrainingValidationResult(BaseModel):
    """Outcome of validating training data.

    ``issues`` and ``recommendations`` default to empty lists; the other
    fields are required.
    """

    is_valid: bool = Field(
        ...,
        description="Whether the data is valid for training",
    )
    issues: List[str] = Field(
        default_factory=list,
        description="List of data quality issues",
    )
    recommendations: List[str] = Field(
        default_factory=list,
        description="Recommendations for improvement",
    )
    estimated_time_minutes: int = Field(
        ...,
        description="Estimated training time in minutes",
    )
    products_analyzed: int = Field(
        ...,
        description="Number of products analyzed",
    )
    total_data_points: int = Field(
        ...,
        description="Total data points available",
    )
|
||||
|
||||
|
||||
class TrainingMetrics(BaseModel):
    """Performance metrics of a trained model; all are required floats."""

    # Error measures.
    mae: float = Field(
        ...,
        description="Mean Absolute Error",
    )
    mse: float = Field(
        ...,
        description="Mean Squared Error",
    )
    rmse: float = Field(
        ...,
        description="Root Mean Squared Error",
    )
    mape: float = Field(
        ...,
        description="Mean Absolute Percentage Error",
    )

    # Goodness of fit.
    r2_score: float = Field(
        ...,
        description="R-squared score",
    )

    # Means of the actual and predicted series.
    mean_actual: float = Field(
        ...,
        description="Mean of actual values",
    )
    mean_predicted: float = Field(
        ...,
        description="Mean of predicted values",
    )
|
||||
|
||||
|
||||
class ExternalDataConfig(BaseModel):
    """Switches and feature lists for the external data sources.

    Both sources are enabled by default. The feature lists are built by
    factories, so each instance gets its own list object.
    """

    weather_enabled: bool = Field(
        True,
        description="Enable weather data",
    )
    traffic_enabled: bool = Field(
        True,
        description="Enable traffic data",
    )
    weather_features: List[str] = Field(
        description="Weather features to include",
        default_factory=lambda: ["temperature", "precipitation", "humidity"],
    )
    traffic_features: List[str] = Field(
        description="Traffic features to include",
        default_factory=lambda: ["traffic_volume"],
    )
|
||||
|
||||
|
||||
class TrainingJobConfig(BaseModel):
    """Full configuration of a training job.

    Every field has a factory default, so ``TrainingJobConfig()`` is valid
    with no arguments.
    """

    # Defaults to an ``ExternalDataConfig`` built with its own defaults.
    external_data: ExternalDataConfig = Field(
        default_factory=ExternalDataConfig,
    )

    # Default Prophet settings: additive mode, all seasonalities on.
    prophet_params: Dict[str, Any] = Field(
        description="Prophet model parameters",
        default_factory=lambda: {
            "seasonality_mode": "additive",
            "daily_seasonality": True,
            "weekly_seasonality": True,
            "yearly_seasonality": True
        },
    )

    # Empty by default.
    data_filters: Dict[str, Any] = Field(
        description="Data filtering parameters",
        default_factory=dict,
    )

    validation_params: Dict[str, Any] = Field(
        description="Data validation parameters",
        default_factory=lambda: {"min_data_points": 30},
    )
|
||||
|
||||
|
||||
class TrainedModelResponse(BaseModel):
    """Response describing one stored, trained model.

    Everything is required except the two ``data_period_*`` bounds, which
    default to ``None``.
    """

    # Identity.
    model_id: str = Field(
        ...,
        description="Unique model identifier",
    )
    tenant_id: str = Field(
        ...,
        description="Tenant identifier",
    )
    inventory_product_id: UUID = Field(
        ...,
        description="Inventory product UUID",
    )

    # Artifact.
    model_type: str = Field(
        ...,
        description="Type of ML model",
    )
    model_path: str = Field(
        ...,
        description="Path to stored model",
    )
    version: int = Field(
        ...,
        description="Model version",
    )

    # Training details.
    training_samples: int = Field(
        ...,
        description="Number of training samples",
    )
    features: List[str] = Field(
        ...,
        description="List of features used",
    )
    hyperparameters: Dict[str, Any] = Field(
        ...,
        description="Model hyperparameters",
    )
    training_metrics: Dict[str, float] = Field(
        ...,
        description="Training performance metrics",
    )

    # Lifecycle.
    is_active: bool = Field(
        ...,
        description="Whether model is active",
    )
    created_at: datetime = Field(
        ...,
        description="Model creation timestamp",
    )
    data_period_start: Optional[datetime] = Field(
        None,
        description="Training data start date",
    )
    data_period_end: Optional[datetime] = Field(
        None,
        description="Training data end date",
    )

    @validator('tenant_id', 'model_id', pre=True)
    def convert_uuid_to_string(cls, v):
        """Turn ``UUID`` inputs into strings before field validation runs."""
        return str(v) if isinstance(v, UUID) else v

    class Config:
        # NOTE(review): ``from_attributes`` is the pydantic v2 config key
        # (v1 calls it ``orm_mode``) - confirm the pinned pydantic version.
        from_attributes = True
|
||||
|
||||
|
||||
class ModelTrainingStats(BaseModel):
    """Aggregate statistics about trained models.

    Only ``last_training_date`` is optional.
    """

    total_models: int = Field(
        ...,
        description="Total number of trained models",
    )
    active_models: int = Field(
        ...,
        description="Number of active models",
    )
    last_training_date: Optional[datetime] = Field(
        None,
        description="Last training date",
    )
    avg_training_time_minutes: float = Field(
        ...,
        description="Average training time in minutes",
    )
    # Described as a 0-1 ratio; no bound is enforced by this field.
    success_rate: float = Field(
        ...,
        description="Training success rate (0-1)",
    )
|
||||
|
||||
|
||||
class BulkTrainingRequest(BaseModel):
    """Body of a request that trains several tenants in one call.

    Only ``tenant_ids`` is required. ``config`` defaults to a fresh
    ``TrainingJobConfig`` and ``priority`` to 1 (bounded to 1-10).
    """

    tenant_ids: List[str] = Field(
        ...,
        description="List of tenant IDs to train",
    )
    config: TrainingJobConfig = Field(
        default_factory=TrainingJobConfig,
        description="Training configuration",
    )
    priority: int = Field(
        1,
        description="Training priority (1-10)",
        ge=1,
        le=10,
    )
    schedule_time: Optional[datetime] = Field(
        None,
        description="Schedule training for specific time",
    )
|
||||
|
||||
|
||||
class TrainingScheduleResponse(BaseModel):
    """Response describing a scheduled training job; all fields required."""

    schedule_id: str = Field(
        ...,
        description="Unique schedule identifier",
    )
    tenant_ids: List[str] = Field(
        ...,
        description="Scheduled tenant IDs",
    )
    scheduled_time: datetime = Field(
        ...,
        description="Scheduled execution time",
    )
    # Free-form string here, not the ``TrainingStatus`` enum.
    status: str = Field(
        ...,
        description="Schedule status",
    )
    created_at: datetime = Field(
        ...,
        description="Schedule creation timestamp",
    )
|
||||
|
||||
|
||||
# WebSocket response schemas for real-time updates
|
||||
class TrainingProgressUpdate(BaseModel):
    """WebSocket message that carries a training progress snapshot."""

    # Discriminator string; defaults to "training_progress".
    type: str = Field(
        "training_progress",
        description="Message type",
    )
    job_id: str = Field(
        ...,
        description="Training job identifier",
    )
    progress: TrainingJobProgress = Field(
        ...,
        description="Progress information",
    )
|
||||
|
||||
|
||||
class TrainingCompletedUpdate(BaseModel):
    """WebSocket message sent when a training job completes."""

    # Discriminator string; defaults to "training_completed".
    type: str = Field(
        "training_completed",
        description="Message type",
    )
    job_id: str = Field(
        ...,
        description="Training job identifier",
    )
    results: TrainingResultsResponse = Field(
        ...,
        description="Training results",
    )
|
||||
|
||||
|
||||
class TrainingErrorUpdate(BaseModel):
    """WebSocket message that reports a training error."""

    # Discriminator string; defaults to "training_error".
    type: str = Field(
        "training_error",
        description="Message type",
    )
    job_id: str = Field(
        ...,
        description="Training job identifier",
    )
    error: str = Field(
        ...,
        description="Error message",
    )
    # ``datetime.now`` with no tz argument yields a naive local timestamp.
    timestamp: datetime = Field(
        default_factory=datetime.now,
        description="Error timestamp",
    )
|
||||
|
||||
|
||||
class ModelMetricsResponse(BaseModel):
    """Response carrying the performance metrics of one model.

    Both timestamps are optional strings; every other field is required.
    """

    model_id: str = Field(
        ...,
        description="Unique model identifier",
    )

    # Metrics. ``accuracy`` is described as the R2 score and ``r2_score``
    # is carried as a separate field.
    accuracy: float = Field(
        ...,
        description="Model accuracy (R2 score)",
    )
    mape: float = Field(
        ...,
        description="Mean Absolute Percentage Error",
    )
    mae: float = Field(
        ...,
        description="Mean Absolute Error",
    )
    rmse: float = Field(
        ...,
        description="Root Mean Square Error",
    )
    r2_score: float = Field(
        ...,
        description="R-squared score",
    )

    # Training context.
    training_samples: int = Field(
        ...,
        description="Number of training samples used",
    )
    features_used: List[str] = Field(
        ...,
        description="List of features used in training",
    )
    model_type: str = Field(
        ...,
        description="Type of ML model",
    )

    # Timestamps, typed as strings.
    created_at: Optional[str] = Field(
        None,
        description="Model creation timestamp",
    )
    last_used_at: Optional[str] = Field(
        None,
        description="Last time model was used",
    )

    class Config:
        # NOTE(review): ``from_attributes`` is the pydantic v2 config key
        # (v1 calls it ``orm_mode``) - confirm the pinned pydantic version.
        from_attributes = True
|
||||
|
||||
# Any message that can be sent over the training WebSocket: a progress
# update, a completion notice, or an error report.
TrainingWebSocketMessage = Union[TrainingProgressUpdate, TrainingCompletedUpdate, TrainingErrorUpdate]
|
||||
317
services/training/app/schemas/validation.py
Normal file
317
services/training/app/schemas/validation.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Comprehensive Input Validation Schemas
|
||||
Ensures all API inputs are properly validated before processing
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, validator, root_validator
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import UUID
|
||||
import re
|
||||
|
||||
|
||||
class TrainingJobCreateRequest(BaseModel):
    """Schema for creating a new training job.

    ``tenant_id`` is required. The date bounds are optional ISO strings that
    are checked individually (format) and together (ordering, maximum span,
    not in the future) when both are supplied.
    """

    tenant_id: UUID = Field(..., description="Tenant identifier")
    start_date: Optional[str] = Field(
        None,
        description="Training data start date (ISO format: YYYY-MM-DD)",
        example="2024-01-01"
    )
    end_date: Optional[str] = Field(
        None,
        description="Training data end date (ISO format: YYYY-MM-DD)",
        example="2024-12-31"
    )
    product_ids: Optional[List[UUID]] = Field(
        None,
        description="Specific products to train (optional, trains all if not provided)"
    )
    force_retrain: bool = Field(
        default=False,
        description="Force retraining even if recent models exist"
    )

    @validator('start_date', 'end_date')
    def validate_date_format(cls, v):
        """Validate that a supplied date parses with ``datetime.fromisoformat``.

        ``None`` is passed through untouched. The original string is
        returned, not the parsed datetime.

        Raises:
            ValueError: If the string cannot be parsed.
        """
        if v is not None:
            try:
                datetime.fromisoformat(v)
            except ValueError:
                raise ValueError(f"Invalid date format: {v}. Use YYYY-MM-DD format")
        return v

    @root_validator
    def validate_date_range(cls, values):
        """Validate that the date range is logical.

        The checks only run when both ``start_date`` and ``end_date`` are
        present in ``values``.

        Raises:
            ValueError: If only one bound carries a UTC offset, if the end
                is not after the start, if the span exceeds 1095 days, or
                if the end lies in the future.
        """
        start = values.get('start_date')
        end = values.get('end_date')

        if start and end:
            start_dt = datetime.fromisoformat(start)
            end_dt = datetime.fromisoformat(end)

            # fromisoformat accepts offsets such as "+00:00". Aware and naive
            # datetimes cannot be ordered (TypeError), so reject the mix
            # explicitly with a message the client can act on.
            if (start_dt.tzinfo is None) != (end_dt.tzinfo is None):
                raise ValueError(
                    "start_date and end_date must both include a UTC offset or both omit it"
                )

            if end_dt <= start_dt:
                raise ValueError("end_date must be after start_date")

            # Check reasonable range (max 3 years)
            if (end_dt - start_dt).days > 1095:
                raise ValueError("Date range cannot exceed 3 years (1095 days)")

            # Check not in future. "now" is taken in end_dt's own timezone
            # (naive local time when end_dt is naive) so the comparison is
            # always between like kinds.
            if end_dt > datetime.now(end_dt.tzinfo):
                raise ValueError("end_date cannot be in the future")

        return values

    class Config:
        schema_extra = {
            "example": {
                "tenant_id": "123e4567-e89b-12d3-a456-426614174000",
                "start_date": "2024-01-01",
                "end_date": "2024-12-31",
                "product_ids": None,
                "force_retrain": False
            }
        }
|
||||
|
||||
|
||||
class ForecastRequest(BaseModel):
    """Body of a request that generates a forecast for one product.

    Tenant and product identifiers are required. The horizon defaults to 30
    days (bounded to 1-365) and the confidence level to 0.80 (bounded to
    0.5-0.99).
    """

    tenant_id: UUID = Field(..., description="Tenant identifier")
    product_id: UUID = Field(..., description="Product identifier")

    # Forecast horizon in days.
    forecast_days: int = Field(
        description="Number of days to forecast (1-365)",
        default=30,
        ge=1,
        le=365,
    )

    # Whether external regressors are used.
    include_regressors: bool = Field(
        description="Include weather and traffic data in forecast",
        default=True,
    )

    # Width of the reported interval.
    confidence_level: float = Field(
        description="Confidence interval (0.5-0.99)",
        default=0.80,
        ge=0.5,
        le=0.99,
    )

    class Config:
        # Example payload attached to the generated schema.
        schema_extra = {
            "example": {
                "tenant_id": "123e4567-e89b-12d3-a456-426614174000",
                "product_id": "223e4567-e89b-12d3-a456-426614174000",
                "forecast_days": 30,
                "include_regressors": True,
                "confidence_level": 0.80
            }
        }
|
||||
|
||||
|
||||
class ModelEvaluationRequest(BaseModel):
    """Schema for model evaluation.

    ``tenant_id`` and both evaluation dates are required; ``product_id``
    optionally narrows the evaluation to one product. The dates are ISO
    strings checked for format, ordering and a minimum span of 7 days.
    """

    tenant_id: UUID = Field(..., description="Tenant identifier")
    product_id: Optional[UUID] = Field(None, description="Specific product (optional)")
    evaluation_start_date: str = Field(..., description="Evaluation period start")
    evaluation_end_date: str = Field(..., description="Evaluation period end")

    @validator('evaluation_start_date', 'evaluation_end_date')
    def validate_date_format(cls, v):
        """Validate that the value parses with ``datetime.fromisoformat``.

        Returns the original string unchanged.

        Raises:
            ValueError: If the string cannot be parsed.
        """
        try:
            datetime.fromisoformat(v)
        except ValueError:
            raise ValueError(f"Invalid date format: {v}")
        return v

    @root_validator
    def validate_evaluation_period(cls, values):
        """Validate the ordering and minimum length of the evaluation period.

        The checks only run when both dates are present in ``values``.

        Raises:
            ValueError: If only one bound carries a UTC offset, if the end
                is not after the start, or if the period is shorter than
                7 days.
        """
        start = values.get('evaluation_start_date')
        end = values.get('evaluation_end_date')

        if start and end:
            start_dt = datetime.fromisoformat(start)
            end_dt = datetime.fromisoformat(end)

            # Aware and naive datetimes cannot be compared or subtracted
            # (TypeError); reject the mix explicitly with a clear message.
            if (start_dt.tzinfo is None) != (end_dt.tzinfo is None):
                raise ValueError(
                    "evaluation_start_date and evaluation_end_date must both "
                    "include a UTC offset or both omit it"
                )

            if end_dt <= start_dt:
                raise ValueError("evaluation_end_date must be after evaluation_start_date")

            # Minimum 7 days for meaningful evaluation
            if (end_dt - start_dt).days < 7:
                raise ValueError("Evaluation period must be at least 7 days")

        return values
|
||||
|
||||
|
||||
class BulkTrainingRequest(BaseModel):
    """Schema for bulk training operations.

    ``tenant_ids`` is required, must hold between 1 and 100 entries and may
    not contain duplicates. The optional common date bounds must be ISO
    strings when supplied.
    """

    tenant_ids: List[UUID] = Field(
        ...,
        min_items=1,
        max_items=100,
        description="List of tenant IDs (max 100)"
    )
    start_date: Optional[str] = Field(None, description="Common start date")
    end_date: Optional[str] = Field(None, description="Common end date")
    parallel: bool = Field(
        default=True,
        description="Execute training jobs in parallel"
    )

    @validator('tenant_ids')
    def validate_unique_tenants(cls, v):
        """Reject lists that name the same tenant more than once.

        Raises:
            ValueError: If the list contains duplicate IDs.
        """
        if len(v) != len(set(v)):
            raise ValueError("Duplicate tenant IDs not allowed")
        return v

    @validator('start_date', 'end_date')
    def validate_date_format(cls, v):
        """Validate that a supplied date parses with ``datetime.fromisoformat``.

        Mirrors the format check the other request schemas in this module
        apply to their date strings, so a malformed common date is rejected
        here. ``None`` is passed through and the original string is returned.

        Raises:
            ValueError: If the string cannot be parsed.
        """
        if v is not None:
            try:
                datetime.fromisoformat(v)
            except ValueError:
                raise ValueError(f"Invalid date format: {v}. Use YYYY-MM-DD format")
        return v
|
||||
|
||||
|
||||
class HyperparameterOverride(BaseModel):
    """Optional manual overrides for the model hyperparameters.

    Every field defaults to ``None``. The numeric fields are range-bounded
    and ``seasonality_mode`` must be ``additive`` or ``multiplicative``.
    """

    # Prior scales.
    changepoint_prior_scale: Optional[float] = Field(
        None,
        description="Flexibility of trend changes",
        ge=0.001,
        le=0.5,
    )
    seasonality_prior_scale: Optional[float] = Field(
        None,
        description="Strength of seasonality",
        ge=0.01,
        le=10.0,
    )
    holidays_prior_scale: Optional[float] = Field(
        None,
        description="Strength of holiday effects",
        ge=0.01,
        le=10.0,
    )

    # NOTE(review): ``regex=`` is the pydantic v1 keyword (v2 renamed it to
    # ``pattern=``) - confirm the pinned pydantic version.
    seasonality_mode: Optional[str] = Field(
        None,
        description="Seasonality mode",
        regex="^(additive|multiplicative)$",
    )

    # Seasonality switches; None leaves the choice unset.
    daily_seasonality: Optional[bool] = None
    weekly_seasonality: Optional[bool] = None
    yearly_seasonality: Optional[bool] = None

    class Config:
        # Example payload attached to the generated schema.
        schema_extra = {
            "example": {
                "changepoint_prior_scale": 0.05,
                "seasonality_prior_scale": 10.0,
                "holidays_prior_scale": 10.0,
                "seasonality_mode": "additive",
                "daily_seasonality": False,
                "weekly_seasonality": True,
                "yearly_seasonality": True
            }
        }
|
||||
|
||||
|
||||
class AdvancedTrainingRequest(TrainingJobCreateRequest):
    """Training request extended with advanced tuning options.

    Inherits every field and validator of ``TrainingJobCreateRequest`` and
    adds the options below, all of which have defaults.
    """

    # Manual hyperparameters; None means no override is supplied.
    hyperparameter_override: Optional[HyperparameterOverride] = Field(
        None,
        description="Manual hyperparameter settings (skips optimization)",
    )

    # Cross-validation: on by default with 3 folds (bounded to 2-10).
    enable_cross_validation: bool = Field(
        description="Enable cross-validation during training",
        default=True,
    )
    cv_folds: int = Field(
        description="Number of cross-validation folds",
        default=3,
        ge=2,
        le=10,
    )

    # Optional trial count, bounded to 5-100 when given.
    optimization_trials: Optional[int] = Field(
        None,
        description="Number of hyperparameter optimization trials (overrides defaults)",
        ge=5,
        le=100,
    )

    save_diagnostics: bool = Field(
        description="Save detailed diagnostic plots and metrics",
        default=False,
    )
|
||||
|
||||
|
||||
class DataQualityCheckRequest(BaseModel):
    """Body of a request that runs a data quality check.

    The tenant and both period bounds are required. Each bound must parse
    with ``datetime.fromisoformat``; their ordering is not checked here.
    """

    tenant_id: UUID = Field(..., description="Tenant identifier")
    start_date: str = Field(..., description="Check period start")
    end_date: str = Field(..., description="Check period end")

    # Optional product filter.
    product_ids: Optional[List[UUID]] = Field(
        None,
        description="Specific products to check",
    )
    include_recommendations: bool = Field(
        description="Include improvement recommendations",
        default=True,
    )

    @validator('start_date', 'end_date')
    def validate_date(cls, v):
        """Return ``v`` unchanged when it parses as an ISO date/datetime.

        Raises:
            ValueError: If ``datetime.fromisoformat`` rejects the string.
        """
        try:
            datetime.fromisoformat(v)
        except ValueError:
            raise ValueError(f"Invalid date format: {v}")
        else:
            return v
|
||||
|
||||
|
||||
class ModelQueryParams(BaseModel):
    """Filter and paging parameters for listing models.

    Every filter defaults to ``None``. ``limit`` defaults to 100 (bounded to
    1-1000) and ``offset`` to 0 (non-negative).
    """

    # --- filters ----------------------------------------------------------
    tenant_id: Optional[UUID] = None
    product_id: Optional[UUID] = None
    is_active: Optional[bool] = None
    is_production: Optional[bool] = None
    # NOTE(review): ``regex=`` is the pydantic v1 keyword (v2 renamed it to
    # ``pattern=``) - confirm the pinned pydantic version.
    model_type: Optional[str] = Field(
        None,
        regex="^(prophet|prophet_optimized|lstm|arima)$",
    )
    min_accuracy: Optional[float] = Field(
        None,
        ge=0.0,
        le=1.0,
    )
    created_after: Optional[datetime] = None
    created_before: Optional[datetime] = None

    # --- paging -----------------------------------------------------------
    limit: int = Field(
        default=100,
        ge=1,
        le=1000,
    )
    offset: int = Field(
        default=0,
        ge=0,
    )

    class Config:
        # Example payload attached to the generated schema.
        schema_extra = {
            "example": {
                "tenant_id": "123e4567-e89b-12d3-a456-426614174000",
                "is_active": True,
                "is_production": True,
                "limit": 50,
                "offset": 0
            }
        }
|
||||
|
||||
|
||||
def validate_uuid(value: str) -> UUID:
    """Validate and convert a string to a ``UUID``.

    Args:
        value: Textual UUID in any form ``uuid.UUID`` accepts.

    Returns:
        The parsed ``UUID``.

    Raises:
        ValueError: If ``value`` is not a string holding a valid UUID. This
            includes ``None`` and other non-string inputs.
    """
    try:
        return UUID(value)
    except (ValueError, AttributeError, TypeError) as exc:
        # Malformed strings raise ValueError, non-string objects raise
        # AttributeError, and None raises TypeError. All are reported
        # uniformly as ValueError, with the original error kept as the cause.
        raise ValueError(f"Invalid UUID format: {value}") from exc
|
||||
|
||||
|
||||
def validate_date_string(value: str) -> datetime:
    """Validate and convert an ISO date string to a ``datetime``.

    Args:
        value: Date or datetime text accepted by
            ``datetime.fromisoformat``.

    Returns:
        The parsed ``datetime``.

    Raises:
        ValueError: If ``value`` is not a string or cannot be parsed.
    """
    try:
        return datetime.fromisoformat(value)
    except (ValueError, TypeError) as exc:
        # fromisoformat raises TypeError for non-string input such as None.
        # Report it the same way as an unparseable string, with the original
        # error kept as the cause.
        raise ValueError(
            f"Invalid date format: {value}. Use ISO format (YYYY-MM-DD)"
        ) from exc
|
||||
|
||||
|
||||
def validate_positive_integer(value: int, field_name: str = "value") -> int:
    """Validate that ``value`` is an integer greater than zero.

    Args:
        value: The number to check.
        field_name: Name used in the error message.

    Returns:
        ``value`` unchanged.

    Raises:
        ValueError: If ``value`` is not an integer (floats, bools, ``None``
            and strings are rejected) or is not strictly positive.
    """
    from numbers import Integral

    # bool is a subclass of int, so exclude it explicitly. Integral also
    # admits integer types registered by numeric libraries.
    if isinstance(value, bool) or not isinstance(value, Integral):
        raise ValueError(f"{field_name} must be a positive integer, got {value!r}")
    if value <= 0:
        raise ValueError(f"{field_name} must be positive, got {value}")
    return value
|
||||
|
||||
|
||||
def validate_probability(value: float, field_name: str = "value") -> float:
    """Return ``value`` if it lies in the closed interval [0.0, 1.0].

    Args:
        value: The number to check.
        field_name: Name used in the error message.

    Raises:
        ValueError: If ``value`` is outside the interval. NaN fails the
            range comparison and is therefore rejected too.
    """
    in_range = 0.0 <= value <= 1.0
    if in_range:
        return value
    raise ValueError(f"{field_name} must be between 0.0 and 1.0, got {value}")
|
||||
Reference in New Issue
Block a user