Improve training code 3
This commit is contained in:
@@ -55,16 +55,58 @@ class SingleProductTrainingRequest(BaseModel):
|
|||||||
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
||||||
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
||||||
|
|
||||||
|
class DateRangeInfo(BaseModel):
    """Schema for date range information"""

    # Both bounds are required and are carried as strings, not datetime objects.
    start: str = Field(
        default=...,
        description="Start date in ISO format",
    )
    end: str = Field(
        default=...,
        description="End date in ISO format",
    )
|
||||||
|
|
||||||
|
class DataSummary(BaseModel):
    """Schema for training data summary"""

    # Required record counts, one per data source.
    sales_records: int = Field(
        default=...,
        description="Number of sales records used",
    )
    weather_records: int = Field(
        default=...,
        description="Number of weather records used",
    )
    traffic_records: int = Field(
        default=...,
        description="Number of traffic records used",
    )

    # Nested model carrying the start/end strings of the data window.
    date_range: DateRangeInfo = Field(
        default=...,
        description="Date range of training data",
    )
    data_sources_used: List[str] = Field(
        default=...,
        description="List of data sources used",
    )

    # The only optional field: defaults to a fresh empty dict per instance.
    constraints_applied: Dict[str, str] = Field(
        default_factory=dict,
        description="Constraints applied during data collection",
    )
|
||||||
|
|
||||||
|
class ProductTrainingResult(BaseModel):
    """Schema for individual product training results"""

    # Required fields: which product, how it ended, and how much data it used.
    product_name: str = Field(
        default=...,
        description="Product name",
    )
    status: str = Field(
        default=...,
        description="Training status for this product",
    )
    data_points: int = Field(
        default=...,
        description="Number of data points used for training",
    )

    # Optional fields: each defaults to None when not supplied.
    model_id: Optional[str] = Field(
        default=None,
        description="Trained model identifier",
    )
    metrics: Optional[Dict[str, float]] = Field(
        default=None,
        description="Training metrics (MAE, MAPE, etc.)",
    )
    training_time_seconds: Optional[float] = Field(
        default=None,
        description="Time taken to train this model",
    )
    error_message: Optional[str] = Field(
        default=None,
        description="Error message if training failed",
    )
|
||||||
|
|
||||||
|
class TrainingResults(BaseModel):
    """Schema for overall training results"""

    # Aggregate counters across all products in the job (all required).
    total_products: int = Field(
        default=...,
        description="Total number of products",
    )
    successful_trainings: int = Field(
        default=...,
        description="Number of successfully trained models",
    )
    failed_trainings: int = Field(
        default=...,
        description="Number of failed trainings",
    )

    # One ProductTrainingResult entry per product.
    products: List[ProductTrainingResult] = Field(
        default=...,
        description="Results for each product",
    )
    overall_training_time_seconds: float = Field(
        default=...,
        description="Total training time",
    )
|
||||||
|
|
||||||
class TrainingJobResponse(BaseModel):
|
class TrainingJobResponse(BaseModel):
|
||||||
"""Response schema for training job creation"""
|
"""Enhanced response schema for training job with detailed results"""
|
||||||
job_id: str = Field(..., description="Unique training job identifier")
|
job_id: str = Field(..., description="Unique training job identifier")
|
||||||
status: TrainingStatus = Field(..., description="Current job status")
|
|
||||||
message: str = Field(..., description="Status message")
|
|
||||||
tenant_id: str = Field(..., description="Tenant identifier")
|
tenant_id: str = Field(..., description="Tenant identifier")
|
||||||
|
status: TrainingStatus = Field(..., description="Overall job status")
|
||||||
|
|
||||||
|
# Required fields for basic response (backwards compatibility)
|
||||||
|
message: str = Field(..., description="Status message")
|
||||||
created_at: datetime = Field(..., description="Job creation timestamp")
|
created_at: datetime = Field(..., description="Job creation timestamp")
|
||||||
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
|
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
|
||||||
|
|
||||||
# ✅ FIX: Add custom validator to convert UUID to string
|
# New detailed fields (optional for backwards compatibility)
|
||||||
|
training_results: Optional[TrainingResults] = Field(None, description="Detailed training results")
|
||||||
|
data_summary: Optional[DataSummary] = Field(None, description="Summary of training data used")
|
||||||
|
completed_at: Optional[str] = Field(None, description="Job completion timestamp in ISO format")
|
||||||
|
|
||||||
|
# Additional optional fields
|
||||||
|
error_details: Optional[Dict[str, Any]] = Field(None, description="Detailed error information if failed")
|
||||||
|
processing_metadata: Optional[Dict[str, Any]] = Field(None, description="Additional processing metadata")
|
||||||
|
|
||||||
@validator('tenant_id', 'job_id', pre=True)
|
@validator('tenant_id', 'job_id', pre=True)
|
||||||
def convert_uuid_to_string(cls, v):
|
def convert_uuid_to_string(cls, v):
|
||||||
"""Convert UUID objects to strings for JSON serialization"""
|
"""Convert UUID objects to strings for JSON serialization"""
|
||||||
|
|||||||
@@ -81,27 +81,55 @@ class TrainingService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Step 3: Compile final results
|
# Step 3: Compile final results
|
||||||
logger.info(f"Training job {job_id} completed successfully")
|
final_result = {
|
||||||
return {
|
|
||||||
"job_id": job_id,
|
"job_id": job_id,
|
||||||
"status": "completed", # or "running" if async
|
|
||||||
"message": "Training job completed successfully",
|
|
||||||
"tenant_id": tenant_id,
|
"tenant_id": tenant_id,
|
||||||
"created_at": datetime.now(),
|
"status": "completed",
|
||||||
"estimated_duration_minutes": 5 # reasonable estimate
|
"training_results": training_results,
|
||||||
|
"data_summary": {
|
||||||
|
"sales_records": len(training_dataset.sales_data),
|
||||||
|
"weather_records": len(training_dataset.weather_data),
|
||||||
|
"traffic_records": len(training_dataset.traffic_data),
|
||||||
|
"date_range": {
|
||||||
|
"start": training_dataset.date_range.start.isoformat(),
|
||||||
|
"end": training_dataset.date_range.end.isoformat()
|
||||||
|
},
|
||||||
|
"data_sources_used": [source.value for source in training_dataset.date_range.available_sources],
|
||||||
|
"constraints_applied": training_dataset.date_range.constraints
|
||||||
|
},
|
||||||
|
"completed_at": datetime.now().isoformat()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.info(f"Training job {job_id} completed successfully")
|
||||||
|
return TrainingService.create_detailed_training_response(final_result)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Training job {job_id} failed: {str(e)}")
|
logger.error(f"Training job {job_id} failed: {str(e)}")
|
||||||
# Return error response that still matches schema
|
# Return error response in same detailed format
|
||||||
return {
|
final_result = {
|
||||||
"job_id": job_id,
|
"job_id": job_id,
|
||||||
"status": "failed",
|
|
||||||
"message": f"Training job failed: {str(e)}",
|
|
||||||
"tenant_id": tenant_id,
|
"tenant_id": tenant_id,
|
||||||
"created_at": datetime.now(),
|
"status": "failed",
|
||||||
"estimated_duration_minutes": 0
|
"training_results": {
|
||||||
|
"total_products": 0,
|
||||||
|
"successful_trainings": 0,
|
||||||
|
"failed_trainings": 0,
|
||||||
|
"models_trained": {},
|
||||||
|
"total_training_time": 0
|
||||||
|
},
|
||||||
|
"data_summary": {
|
||||||
|
"sales_records": 0,
|
||||||
|
"weather_records": 0,
|
||||||
|
"traffic_records": 0,
|
||||||
|
"date_range": {"start": "", "end": ""},
|
||||||
|
"data_sources_used": [],
|
||||||
|
"constraints_applied": {}
|
||||||
|
},
|
||||||
|
"completed_at": datetime.now().isoformat(),
|
||||||
|
"error_message": str(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return TrainingService.create_detailed_training_response(final_result)
|
||||||
|
|
||||||
async def start_single_product_training(
|
async def start_single_product_training(
|
||||||
self,
|
self,
|
||||||
@@ -290,4 +318,46 @@ class TrainingService:
|
|||||||
"reasons": [f"Error analyzing data: {str(e)}"],
|
"reasons": [f"Error analyzing data: {str(e)}"],
|
||||||
"recommended_products": [],
|
"recommended_products": [],
|
||||||
"optimal_config": {}
|
"optimal_config": {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def create_detailed_training_response(final_result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert a ``final_result`` dict into the shape of the TrainingJobResponse schema.

    Args:
        final_result: Must contain ``job_id``, ``tenant_id`` and ``status``.
            May contain ``training_results`` (a dict with a ``models_trained``
            mapping of product name -> per-product result dict and a
            ``total_training_time`` value), ``data_summary``, ``completed_at``
            and ``error_message``.

    Returns:
        A dict with the basic job fields, a ``training_results`` section whose
        counters are derived from the per-product entries, ``data_summary`` and
        ``completed_at``. ``error_details`` is added only when the input
        carries an ``error_message``.

    Raises:
        KeyError: If ``job_id``, ``tenant_id`` or ``status`` is missing.
    """
    # `or {}` also covers an explicit None, which .get(..., {}) would pass through.
    training_results_data = final_result.get("training_results") or {}
    models_trained = training_results_data.get("models_trained") or {}

    # Convert per-product results to the ProductTrainingResult layout.
    products = [
        {
            "product_name": product_name,
            "status": result.get("status", "completed"),
            "model_id": result.get("model_id"),
            "data_points": result.get("data_points", 0),
            "metrics": result.get("metrics"),
            "training_time_seconds": result.get("training_time_seconds"),
            "error_message": result.get("error_message"),
        }
        for product_name, result in models_trained.items()
    ]

    status = final_result["status"]
    # Accept either a plain string or an enum-like object exposing `.value`.
    status_text = getattr(status, "value", status)
    error_message = final_result.get("error_message")

    # A failed job must not be reported as "failed successfully".
    if status_text == "failed":
        message = (
            f"Training job failed: {error_message}"
            if error_message
            else "Training job failed"
        )
    else:
        message = f"Training {status_text} successfully"

    response_data = {
        "job_id": final_result["job_id"],
        "tenant_id": final_result["tenant_id"],
        "status": status,
        "message": message,
        "created_at": datetime.now(),
        "estimated_duration_minutes": 0,  # Already completed
        "training_results": {
            "total_products": len(products),
            "successful_trainings": sum(
                1 for p in products if p["status"] == "completed"
            ),
            "failed_trainings": sum(1 for p in products if p["status"] == "failed"),
            "products": products,
            "overall_training_time_seconds": training_results_data.get(
                "total_training_time", 0
            ),
        },
        # None rather than {}: an empty dict cannot satisfy the required
        # fields of DataSummary, while the response field itself is optional.
        "data_summary": final_result.get("data_summary"),
        "completed_at": final_result.get("completed_at"),
    }

    # Keep the caller's error text instead of dropping it.
    if error_message:
        response_data["error_details"] = {"error_message": error_message}

    return response_data
|
||||||
Reference in New Issue
Block a user