Files
bakery-ia/services/training/app/schemas/training.py
2025-07-17 13:09:24 +02:00

91 lines
2.6 KiB
Python

"""
Training schemas
"""
from pydantic import BaseModel, Field, validator
from typing import Optional, Dict, Any, List
from datetime import datetime
from enum import Enum
class TrainingJobStatus(str, Enum):
    """Closed set of states a training job can be reported in.

    The ``str`` mixin makes every member a real string, so a member compares
    equal to its raw value (``TrainingJobStatus.QUEUED == "queued"``) and can
    be looked up from that value (``TrainingJobStatus("queued")``).
    """

    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class TrainingRequest(BaseModel):
    """Request payload for starting a training run.

    Every field carries a default, so an empty payload is accepted.
    """

    tenant_id: Optional[str] = None  # Will be set from auth
    force_retrain: bool = Field(
        default=False,
        description="Force retrain even if recent models exist",
    )
    products: Optional[List[str]] = Field(
        default=None,
        description="Specific products to train, or None for all",
    )
    training_days: Optional[int] = Field(
        default=730,
        ge=30,
        le=1095,
        description="Number of days of historical data to use",
    )

    @validator('training_days')
    def validate_training_days(cls, v):
        """Check that ``training_days`` lies within 30..1095 days.

        Args:
            v: The candidate value; may be ``None`` because the field is
                declared ``Optional[int]``.

        Returns:
            ``v`` unchanged when it is ``None`` or within bounds.

        Raises:
            ValueError: If ``v`` is below 30 or above 1095.
        """
        # The field is Optional, so an explicit null may reach this
        # validator. Comparing None with an int raises TypeError, which
        # would escape as an unhandled error instead of a validation error.
        if v is None:
            return v
        if v < 30:
            raise ValueError('Minimum training days is 30')
        if v > 1095:
            raise ValueError('Maximum training days is 1095 (3 years)')
        return v
class TrainingJobResponse(BaseModel):
    """Response schema describing one training job.

    Nullable fields default to ``None`` explicitly: under pydantic v2 an
    ``Optional[X]`` annotation without a default is still a required field,
    so omitting it would otherwise fail validation.
    """

    id: str
    tenant_id: str
    status: TrainingJobStatus
    # NOTE(review): no range is enforced here — presumably a percentage; confirm.
    progress: int
    current_step: Optional[str] = None
    started_at: datetime
    completed_at: Optional[datetime] = None
    duration_seconds: Optional[int] = None
    models_trained: Optional[Dict[str, Any]] = None
    metrics: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None

    class Config:
        # Allow the model to be populated from an object's attributes
        # (e.g. an ORM row), not only from a mapping.
        from_attributes = True
class TrainedModelResponse(BaseModel):
    """Response schema describing one trained model.

    Nullable fields default to ``None`` explicitly: under pydantic v2 an
    ``Optional[X]`` annotation without a default is still a required field,
    so omitting it would otherwise fail validation.
    """

    id: str
    product_name: str
    model_type: str
    model_version: str
    # Evaluation figures; each may be absent.
    mape: Optional[float] = None
    rmse: Optional[float] = None
    mae: Optional[float] = None
    r2_score: Optional[float] = None
    training_samples: Optional[int] = None
    features_used: Optional[List[str]] = None
    is_active: bool
    created_at: datetime
    last_used_at: Optional[datetime] = None

    class Config:
        # Allow the model to be populated from an object's attributes
        # (e.g. an ORM row), not only from a mapping.
        from_attributes = True
class TrainingProgress(BaseModel):
    """Progress update for a training job.

    ``estimated_completion`` defaults to ``None`` explicitly so an update can
    be built without it; under pydantic v2 an ``Optional[X]`` annotation
    without a default is otherwise a required field.
    """

    job_id: str
    # NOTE(review): no range is enforced here — presumably a percentage; confirm.
    progress: int
    current_step: str
    estimated_completion: Optional[datetime] = None
class TrainingMetrics(BaseModel):
    """Aggregate counters about training jobs and models.

    All fields are required; none has a default.
    """

    total_jobs: int
    successful_jobs: int
    failed_jobs: int
    # NOTE(review): unit is not stated here — presumably seconds; confirm.
    average_duration: float
    models_trained: int
    active_models: int
class ModelValidationResult(BaseModel):
    """Outcome of validating a model for one product.

    ``validation_error`` defaults to ``None`` explicitly so a result can be
    built without it; under pydantic v2 an ``Optional[X]`` annotation without
    a default is otherwise a required field.
    """

    product_name: str
    is_valid: bool
    # NOTE(review): scale of the score is not defined here; confirm with the producer.
    accuracy_score: float
    validation_error: Optional[str] = None
    # Still required; pass an empty list when there is nothing to recommend.
    recommendations: List[str]