Add all the code for training service
This commit is contained in:
@@ -1,91 +1,181 @@
|
||||
# services/training/app/schemas/training.py
|
||||
"""
|
||||
Training schemas
|
||||
Pydantic schemas for training service
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
class TrainingJobStatus(str, Enum):
|
||||
"""Training job status enum"""
|
||||
QUEUED = "queued"
|
||||
class TrainingStatus(str, Enum):
|
||||
"""Training job status enumeration"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
class TrainingRequest(BaseModel):
|
||||
"""Training request schema"""
|
||||
tenant_id: Optional[str] = None # Will be set from auth
|
||||
force_retrain: bool = Field(default=False, description="Force retrain even if recent models exist")
|
||||
products: Optional[List[str]] = Field(default=None, description="Specific products to train, or None for all")
|
||||
training_days: Optional[int] = Field(default=730, ge=30, le=1095, description="Number of days of historical data to use")
|
||||
class TrainingJobRequest(BaseModel):
|
||||
"""Request schema for starting a training job"""
|
||||
products: Optional[List[str]] = Field(None, description="Specific products to train (if None, train all)")
|
||||
include_weather: bool = Field(True, description="Include weather data in training")
|
||||
include_traffic: bool = Field(True, description="Include traffic data in training")
|
||||
start_date: Optional[datetime] = Field(None, description="Start date for training data")
|
||||
end_date: Optional[datetime] = Field(None, description="End date for training data")
|
||||
min_data_points: int = Field(30, description="Minimum data points required per product")
|
||||
estimated_duration: Optional[int] = Field(None, description="Estimated duration in minutes")
|
||||
|
||||
@validator('training_days')
|
||||
def validate_training_days(cls, v):
|
||||
if v < 30:
|
||||
raise ValueError('Minimum training days is 30')
|
||||
if v > 1095:
|
||||
raise ValueError('Maximum training days is 1095 (3 years)')
|
||||
# Prophet-specific parameters
|
||||
seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
|
||||
daily_seasonality: bool = Field(True, description="Enable daily seasonality")
|
||||
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
||||
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
||||
|
||||
@validator('seasonality_mode')
|
||||
def validate_seasonality_mode(cls, v):
|
||||
if v not in ['additive', 'multiplicative']:
|
||||
raise ValueError('seasonality_mode must be additive or multiplicative')
|
||||
return v
|
||||
|
||||
@validator('min_data_points')
|
||||
def validate_min_data_points(cls, v):
|
||||
if v < 7:
|
||||
raise ValueError('min_data_points must be at least 7')
|
||||
return v
|
||||
|
||||
class SingleProductTrainingRequest(BaseModel):
|
||||
"""Request schema for training a single product"""
|
||||
include_weather: bool = Field(True, description="Include weather data in training")
|
||||
include_traffic: bool = Field(True, description="Include traffic data in training")
|
||||
start_date: Optional[datetime] = Field(None, description="Start date for training data")
|
||||
end_date: Optional[datetime] = Field(None, description="End date for training data")
|
||||
|
||||
# Prophet-specific parameters
|
||||
seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
|
||||
daily_seasonality: bool = Field(True, description="Enable daily seasonality")
|
||||
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
||||
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
||||
|
||||
class TrainingJobResponse(BaseModel):
|
||||
"""Training job response schema"""
|
||||
id: str
|
||||
tenant_id: str
|
||||
status: TrainingJobStatus
|
||||
progress: int
|
||||
current_step: Optional[str]
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime]
|
||||
duration_seconds: Optional[int]
|
||||
models_trained: Optional[Dict[str, Any]]
|
||||
metrics: Optional[Dict[str, Any]]
|
||||
error_message: Optional[str]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
"""Response schema for training job creation"""
|
||||
job_id: str = Field(..., description="Unique training job identifier")
|
||||
status: TrainingStatus = Field(..., description="Current job status")
|
||||
message: str = Field(..., description="Status message")
|
||||
tenant_id: str = Field(..., description="Tenant identifier")
|
||||
created_at: datetime = Field(..., description="Job creation timestamp")
|
||||
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
|
||||
|
||||
class TrainedModelResponse(BaseModel):
|
||||
"""Trained model response schema"""
|
||||
id: str
|
||||
product_name: str
|
||||
model_type: str
|
||||
model_version: str
|
||||
mape: Optional[float]
|
||||
rmse: Optional[float]
|
||||
mae: Optional[float]
|
||||
r2_score: Optional[float]
|
||||
training_samples: Optional[int]
|
||||
features_used: Optional[List[str]]
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
class TrainingStatusResponse(BaseModel):
|
||||
"""Response schema for training job status"""
|
||||
job_id: str = Field(..., description="Training job identifier")
|
||||
status: TrainingStatus = Field(..., description="Current job status")
|
||||
progress: int = Field(0, description="Progress percentage (0-100)")
|
||||
current_step: str = Field("", description="Current processing step")
|
||||
started_at: datetime = Field(..., description="Job start timestamp")
|
||||
completed_at: Optional[datetime] = Field(None, description="Job completion timestamp")
|
||||
results: Optional[Dict[str, Any]] = Field(None, description="Training results")
|
||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Schema for trained model information"""
|
||||
model_id: str = Field(..., description="Unique model identifier")
|
||||
model_path: str = Field(..., description="Path to stored model")
|
||||
model_type: str = Field("prophet", description="Type of ML model")
|
||||
training_samples: int = Field(..., description="Number of training samples")
|
||||
features: List[str] = Field(..., description="List of features used")
|
||||
hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
|
||||
training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
|
||||
trained_at: datetime = Field(..., description="Training completion timestamp")
|
||||
data_period: Dict[str, str] = Field(..., description="Training data period")
|
||||
|
||||
class ProductTrainingResult(BaseModel):
|
||||
"""Schema for individual product training result"""
|
||||
product_name: str = Field(..., description="Product name")
|
||||
status: str = Field(..., description="Training status for this product")
|
||||
model_info: Optional[ModelInfo] = Field(None, description="Model information if successful")
|
||||
data_points: int = Field(..., description="Number of data points used")
|
||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||
trained_at: datetime = Field(..., description="Training completion timestamp")
|
||||
|
||||
class TrainingResultsResponse(BaseModel):
|
||||
"""Response schema for complete training results"""
|
||||
job_id: str = Field(..., description="Training job identifier")
|
||||
tenant_id: str = Field(..., description="Tenant identifier")
|
||||
status: TrainingStatus = Field(..., description="Overall job status")
|
||||
products_trained: int = Field(..., description="Number of products successfully trained")
|
||||
products_failed: int = Field(..., description="Number of products that failed training")
|
||||
total_products: int = Field(..., description="Total number of products processed")
|
||||
training_results: Dict[str, ProductTrainingResult] = Field(..., description="Per-product results")
|
||||
summary: Dict[str, Any] = Field(..., description="Training summary statistics")
|
||||
completed_at: datetime = Field(..., description="Job completion timestamp")
|
||||
|
||||
class TrainingValidationResult(BaseModel):
|
||||
"""Schema for training data validation results"""
|
||||
is_valid: bool = Field(..., description="Whether the data is valid for training")
|
||||
issues: List[str] = Field(default_factory=list, description="List of data quality issues")
|
||||
recommendations: List[str] = Field(default_factory=list, description="Recommendations for improvement")
|
||||
estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
|
||||
products_analyzed: int = Field(..., description="Number of products analyzed")
|
||||
total_data_points: int = Field(..., description="Total data points available")
|
||||
|
||||
class TrainingProgress(BaseModel):
|
||||
"""Training progress update schema"""
|
||||
job_id: str
|
||||
progress: int
|
||||
current_step: str
|
||||
estimated_completion: Optional[datetime]
|
||||
|
||||
class TrainingMetrics(BaseModel):
|
||||
"""Training metrics schema"""
|
||||
total_jobs: int
|
||||
successful_jobs: int
|
||||
failed_jobs: int
|
||||
average_duration: float
|
||||
models_trained: int
|
||||
active_models: int
|
||||
"""Schema for training performance metrics"""
|
||||
mae: float = Field(..., description="Mean Absolute Error")
|
||||
mse: float = Field(..., description="Mean Squared Error")
|
||||
rmse: float = Field(..., description="Root Mean Squared Error")
|
||||
mape: float = Field(..., description="Mean Absolute Percentage Error")
|
||||
r2_score: float = Field(..., description="R-squared score")
|
||||
mean_actual: float = Field(..., description="Mean of actual values")
|
||||
mean_predicted: float = Field(..., description="Mean of predicted values")
|
||||
|
||||
class ModelValidationResult(BaseModel):
|
||||
"""Model validation result schema"""
|
||||
product_name: str
|
||||
is_valid: bool
|
||||
accuracy_score: float
|
||||
validation_error: Optional[str]
|
||||
recommendations: List[str]
|
||||
class ExternalDataConfig(BaseModel):
|
||||
"""Configuration for external data sources"""
|
||||
weather_enabled: bool = Field(True, description="Enable weather data")
|
||||
traffic_enabled: bool = Field(True, description="Enable traffic data")
|
||||
weather_features: List[str] = Field(
|
||||
default_factory=lambda: ["temperature", "precipitation", "humidity"],
|
||||
description="Weather features to include"
|
||||
)
|
||||
traffic_features: List[str] = Field(
|
||||
default_factory=lambda: ["traffic_volume"],
|
||||
description="Traffic features to include"
|
||||
)
|
||||
|
||||
class TrainingJobConfig(BaseModel):
|
||||
"""Complete training job configuration"""
|
||||
external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)
|
||||
prophet_params: Dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"seasonality_mode": "additive",
|
||||
"daily_seasonality": True,
|
||||
"weekly_seasonality": True,
|
||||
"yearly_seasonality": True
|
||||
},
|
||||
description="Prophet model parameters"
|
||||
)
|
||||
data_filters: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Data filtering parameters"
|
||||
)
|
||||
validation_params: Dict[str, Any] = Field(
|
||||
default_factory=lambda: {"min_data_points": 30},
|
||||
description="Data validation parameters"
|
||||
)
|
||||
|
||||
class TrainedModelResponse(BaseModel):
|
||||
"""Response schema for trained model information"""
|
||||
model_id: str = Field(..., description="Unique model identifier")
|
||||
tenant_id: str = Field(..., description="Tenant identifier")
|
||||
product_name: str = Field(..., description="Product name")
|
||||
model_type: str = Field(..., description="Type of ML model")
|
||||
model_path: str = Field(..., description="Path to stored model")
|
||||
version: int = Field(..., description="Model version")
|
||||
training_samples: int = Field(..., description="Number of training samples")
|
||||
features: List[str] = Field(..., description="List of features used")
|
||||
hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
|
||||
training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
|
||||
is_active: bool = Field(..., description="Whether model is active")
|
||||
created_at: datetime = Field(..., description="Model creation timestamp")
|
||||
data_period_start: Optional[datetime] = Field(None, description="Training data start date")
|
||||
data_period_end: Optional[datetime] = Field(None, description="Training data end date")
|
||||
Reference in New Issue
Block a user