Improve training code 3
This commit is contained in:
@@ -55,16 +55,58 @@ class SingleProductTrainingRequest(BaseModel):
|
||||
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
||||
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
||||
|
||||
class DateRangeInfo(BaseModel):
|
||||
"""Schema for date range information"""
|
||||
start: str = Field(..., description="Start date in ISO format")
|
||||
end: str = Field(..., description="End date in ISO format")
|
||||
|
||||
class DataSummary(BaseModel):
|
||||
"""Schema for training data summary"""
|
||||
sales_records: int = Field(..., description="Number of sales records used")
|
||||
weather_records: int = Field(..., description="Number of weather records used")
|
||||
traffic_records: int = Field(..., description="Number of traffic records used")
|
||||
date_range: DateRangeInfo = Field(..., description="Date range of training data")
|
||||
data_sources_used: List[str] = Field(..., description="List of data sources used")
|
||||
constraints_applied: Dict[str, str] = Field(default_factory=dict, description="Constraints applied during data collection")
|
||||
|
||||
class ProductTrainingResult(BaseModel):
|
||||
"""Schema for individual product training results"""
|
||||
product_name: str = Field(..., description="Product name")
|
||||
status: str = Field(..., description="Training status for this product")
|
||||
model_id: Optional[str] = Field(None, description="Trained model identifier")
|
||||
data_points: int = Field(..., description="Number of data points used for training")
|
||||
metrics: Optional[Dict[str, float]] = Field(None, description="Training metrics (MAE, MAPE, etc.)")
|
||||
training_time_seconds: Optional[float] = Field(None, description="Time taken to train this model")
|
||||
error_message: Optional[str] = Field(None, description="Error message if training failed")
|
||||
|
||||
class TrainingResults(BaseModel):
|
||||
"""Schema for overall training results"""
|
||||
total_products: int = Field(..., description="Total number of products")
|
||||
successful_trainings: int = Field(..., description="Number of successfully trained models")
|
||||
failed_trainings: int = Field(..., description="Number of failed trainings")
|
||||
products: List[ProductTrainingResult] = Field(..., description="Results for each product")
|
||||
overall_training_time_seconds: float = Field(..., description="Total training time")
|
||||
|
||||
class TrainingJobResponse(BaseModel):
|
||||
"""Response schema for training job creation"""
|
||||
"""Enhanced response schema for training job with detailed results"""
|
||||
job_id: str = Field(..., description="Unique training job identifier")
|
||||
status: TrainingStatus = Field(..., description="Current job status")
|
||||
message: str = Field(..., description="Status message")
|
||||
tenant_id: str = Field(..., description="Tenant identifier")
|
||||
status: TrainingStatus = Field(..., description="Overall job status")
|
||||
|
||||
# Required fields for basic response (backwards compatibility)
|
||||
message: str = Field(..., description="Status message")
|
||||
created_at: datetime = Field(..., description="Job creation timestamp")
|
||||
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
|
||||
|
||||
# ✅ FIX: Add custom validator to convert UUID to string
|
||||
# New detailed fields (optional for backwards compatibility)
|
||||
training_results: Optional[TrainingResults] = Field(None, description="Detailed training results")
|
||||
data_summary: Optional[DataSummary] = Field(None, description="Summary of training data used")
|
||||
completed_at: Optional[str] = Field(None, description="Job completion timestamp in ISO format")
|
||||
|
||||
# Additional optional fields
|
||||
error_details: Optional[Dict[str, Any]] = Field(None, description="Detailed error information if failed")
|
||||
processing_metadata: Optional[Dict[str, Any]] = Field(None, description="Additional processing metadata")
|
||||
|
||||
@validator('tenant_id', 'job_id', pre=True)
|
||||
def convert_uuid_to_string(cls, v):
|
||||
"""Convert UUID objects to strings for JSON serialization"""
|
||||
|
||||
Reference in New Issue
Block a user