91 lines
2.6 KiB
Python
91 lines
2.6 KiB
Python
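"""
Training schemas: request/response models for training jobs, trained models,
progress updates, aggregate metrics, and model validation results.
"""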
|
|
|
|
from pydantic import BaseModel, Field, validator
|
|
from typing import Optional, Dict, Any, List
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
|
|
class TrainingJobStatus(str, Enum):
    """Status values a training job can take.

    Mixing in ``str`` makes every member an actual string, so it compares
    equal to its lowercase value (``TrainingJobStatus.QUEUED == "queued"``)
    and can be looked up from it (``TrainingJobStatus("queued")``).
    """

    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
|
|
class TrainingRequest(BaseModel):
    """Request payload for starting a training run.

    Every field has a default, so an empty body is a valid request.
    """

    tenant_id: Optional[str] = None  # Will be set from auth
    force_retrain: bool = Field(default=False, description="Force retrain even if recent models exist")
    products: Optional[List[str]] = Field(default=None, description="Specific products to train, or None for all")
    training_days: Optional[int] = Field(default=730, ge=30, le=1095, description="Number of days of historical data to use")

    @validator('training_days')
    def validate_training_days(cls, v):
        """Enforce the 30..1095 day window on ``training_days``.

        Args:
            v: The submitted value, or ``None``.

        Returns:
            ``v`` unchanged when it is ``None`` or within range.

        Raises:
            ValueError: If ``v`` is below 30 or above 1095.
        """
        # The field is declared Optional[int], so None must be accepted.
        # Comparing None with an int raises TypeError, which Pydantic v2
        # does not turn into a ValidationError, so guard it explicitly.
        if v is None:
            return v
        if v < 30:
            raise ValueError('Minimum training days is 30')
        if v > 1095:
            raise ValueError('Maximum training days is 1095 (3 years)')
        return v
|
|
|
|
class TrainingJobResponse(BaseModel):
    """Response representation of a single training job.

    ``Config.from_attributes`` lets the model be populated from an object's
    attributes (not only from a dict).
    """

    id: str
    tenant_id: str
    status: TrainingJobStatus
    progress: int  # scale not defined here; presumably a 0-100 percentage — TODO confirm
    started_at: datetime
    # Nullable fields get an explicit None default: under Pydantic v2 a bare
    # Optional[...] annotation is still *required*, merely allowing None.
    # Pydantic permits these defaulted fields in any position among required ones.
    current_step: Optional[str] = None
    completed_at: Optional[datetime] = None
    duration_seconds: Optional[int] = None
    models_trained: Optional[Dict[str, Any]] = None
    metrics: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None

    class Config:
        from_attributes = True
|
|
|
|
class TrainedModelResponse(BaseModel):
    """Response representation of a trained model for one product.

    ``Config.from_attributes`` lets the model be populated from an object's
    attributes (not only from a dict).
    """

    id: str
    product_name: str
    model_type: str
    model_version: str
    # Evaluation metrics. They are nullable and default to None: under
    # Pydantic v2 a bare Optional[...] annotation would still be required.
    mape: Optional[float] = None
    rmse: Optional[float] = None
    mae: Optional[float] = None
    r2_score: Optional[float] = None
    training_samples: Optional[int] = None
    features_used: Optional[List[str]] = None
    is_active: bool
    created_at: datetime
    last_used_at: Optional[datetime] = None

    class Config:
        from_attributes = True
|
|
|
|
class TrainingProgress(BaseModel):
    """Progress update for a running training job."""

    job_id: str
    progress: int  # scale not defined here; presumably a 0-100 percentage — TODO confirm
    current_step: str
    # Defaults to None so updates may omit it; under Pydantic v2 a bare
    # Optional[datetime] would still be a required field.
    estimated_completion: Optional[datetime] = None
|
|
|
|
class TrainingMetrics(BaseModel):
    """Aggregate counters and averages over training jobs and models.

    All fields are required.
    """

    # Job counts.
    total_jobs: int
    successful_jobs: int
    failed_jobs: int
    # Mean job duration; units not fixed here (presumably seconds, like
    # TrainingJobResponse.duration_seconds) — TODO confirm.
    average_duration: float
    # Model counts.
    models_trained: int
    active_models: int
|
|
|
|
class ModelValidationResult(BaseModel):
    """Outcome of validating the model for one product."""

    product_name: str
    is_valid: bool
    accuracy_score: float
    # Defaults to None so a result with no error need not pass it explicitly;
    # under Pydantic v2 a bare Optional[str] would still be required.
    validation_error: Optional[str] = None
    recommendations: List[str]