From c788c7e4068d2a53eefc407eddfd276ebfa74396 Mon Sep 17 00:00:00 2001 From: Urtzi Alfaro Date: Mon, 28 Jul 2025 21:30:49 +0200 Subject: [PATCH] Improve training code 3 --- services/training/app/schemas/training.py | 50 +++++++++- .../training/app/services/training_service.py | 96 ++++++++++++++++--- 2 files changed, 129 insertions(+), 17 deletions(-) diff --git a/services/training/app/schemas/training.py b/services/training/app/schemas/training.py index 94ff5478..153e492d 100644 --- a/services/training/app/schemas/training.py +++ b/services/training/app/schemas/training.py @@ -55,16 +55,58 @@ class SingleProductTrainingRequest(BaseModel): weekly_seasonality: bool = Field(True, description="Enable weekly seasonality") yearly_seasonality: bool = Field(True, description="Enable yearly seasonality") +class DateRangeInfo(BaseModel): + """Schema for date range information""" + start: str = Field(..., description="Start date in ISO format") + end: str = Field(..., description="End date in ISO format") + +class DataSummary(BaseModel): + """Schema for training data summary""" + sales_records: int = Field(..., description="Number of sales records used") + weather_records: int = Field(..., description="Number of weather records used") + traffic_records: int = Field(..., description="Number of traffic records used") + date_range: DateRangeInfo = Field(..., description="Date range of training data") + data_sources_used: List[str] = Field(..., description="List of data sources used") + constraints_applied: Dict[str, str] = Field(default_factory=dict, description="Constraints applied during data collection") + +class ProductTrainingResult(BaseModel): + """Schema for individual product training results""" + product_name: str = Field(..., description="Product name") + status: str = Field(..., description="Training status for this product") + model_id: Optional[str] = Field(None, description="Trained model identifier") + data_points: int = Field(..., description="Number of data 
points used for training") + metrics: Optional[Dict[str, float]] = Field(None, description="Training metrics (MAE, MAPE, etc.)") + training_time_seconds: Optional[float] = Field(None, description="Time taken to train this model") + error_message: Optional[str] = Field(None, description="Error message if training failed") + +class TrainingResults(BaseModel): + """Schema for overall training results""" + total_products: int = Field(..., description="Total number of products") + successful_trainings: int = Field(..., description="Number of successfully trained models") + failed_trainings: int = Field(..., description="Number of failed trainings") + products: List[ProductTrainingResult] = Field(..., description="Results for each product") + overall_training_time_seconds: float = Field(..., description="Total training time") + class TrainingJobResponse(BaseModel): - """Response schema for training job creation""" + """Enhanced response schema for training job with detailed results""" job_id: str = Field(..., description="Unique training job identifier") - status: TrainingStatus = Field(..., description="Current job status") - message: str = Field(..., description="Status message") tenant_id: str = Field(..., description="Tenant identifier") + status: TrainingStatus = Field(..., description="Overall job status") + + # Required fields for basic response (backwards compatibility) + message: str = Field(..., description="Status message") created_at: datetime = Field(..., description="Job creation timestamp") estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes") - # ✅ FIX: Add custom validator to convert UUID to string + # New detailed fields (optional for backwards compatibility) + training_results: Optional[TrainingResults] = Field(None, description="Detailed training results") + data_summary: Optional[DataSummary] = Field(None, description="Summary of training data used") + completed_at: Optional[str] = Field(None, 
description="Job completion timestamp in ISO format") + + # Additional optional fields + error_details: Optional[Dict[str, Any]] = Field(None, description="Detailed error information if failed") + processing_metadata: Optional[Dict[str, Any]] = Field(None, description="Additional processing metadata") + @validator('tenant_id', 'job_id', pre=True) def convert_uuid_to_string(cls, v): """Convert UUID objects to strings for JSON serialization""" diff --git a/services/training/app/services/training_service.py b/services/training/app/services/training_service.py index 83ddd8c0..58d1c586 100644 --- a/services/training/app/services/training_service.py +++ b/services/training/app/services/training_service.py @@ -81,27 +81,55 @@ class TrainingService: ) # Step 3: Compile final results - logger.info(f"Training job {job_id} completed successfully") - return { + final_result = { "job_id": job_id, - "status": "completed", # or "running" if async - "message": "Training job completed successfully", "tenant_id": tenant_id, - "created_at": datetime.now(), - "estimated_duration_minutes": 5 # reasonable estimate + "status": "completed", + "training_results": training_results, + "data_summary": { + "sales_records": len(training_dataset.sales_data), + "weather_records": len(training_dataset.weather_data), + "traffic_records": len(training_dataset.traffic_data), + "date_range": { + "start": training_dataset.date_range.start.isoformat(), + "end": training_dataset.date_range.end.isoformat() + }, + "data_sources_used": [source.value for source in training_dataset.date_range.available_sources], + "constraints_applied": training_dataset.date_range.constraints + }, + "completed_at": datetime.now().isoformat() } + logger.info(f"Training job {job_id} completed successfully") + return TrainingService.create_detailed_training_response(final_result) + except Exception as e: logger.error(f"Training job {job_id} failed: {str(e)}") - # Return error response that still matches schema - return { + # 
def create_detailed_training_response(final_result: Dict[str, Any]) -> Dict[str, Any]:
    """Build the detailed training-job response dict from a raw result dict.

    The function takes no ``self``; the visible call sites invoke it through
    the class (``TrainingService.create_detailed_training_response(...)``).

    Args:
        final_result: Raw job outcome. ``"job_id"``, ``"tenant_id"`` and
            ``"status"`` are required. Optional keys that are read:
            ``"training_results"`` (a mapping that may hold
            ``"models_trained"`` -- product name -> per-product result
            mapping -- and ``"total_training_time"``), ``"data_summary"``,
            ``"completed_at"`` and ``"error_message"``.

    Returns:
        A dict with the keys ``job_id``, ``tenant_id``, ``status``,
        ``message``, ``created_at``, ``estimated_duration_minutes``,
        ``training_results``, ``data_summary`` and ``completed_at``.
        ``error_details`` is added only when ``final_result`` carries a
        non-empty ``"error_message"``.

    Raises:
        KeyError: If ``"job_id"``, ``"tenant_id"`` or ``"status"`` is missing.
    """
    # ``or {}`` also covers an explicit ``None`` value, which a plain
    # ``.get(key, {})`` default would let through and crash on below.
    training_results_data = final_result.get("training_results") or {}
    models_trained = training_results_data.get("models_trained") or {}

    # Flatten the per-product mapping into the list-of-records layout.
    products = [
        {
            "product_name": product_name,
            "status": result.get("status", "completed"),
            "model_id": result.get("model_id"),
            "data_points": result.get("data_points", 0),
            "metrics": result.get("metrics"),
            "training_time_seconds": result.get("training_time_seconds"),
            "error_message": result.get("error_message"),
        }
        for product_name, result in models_trained.items()
    ]

    status = final_result["status"]
    error_message = final_result.get("error_message")

    # A failed job must not be reported as "... failed successfully".
    if status == "failed":
        message = f"Training failed: {error_message}" if error_message else "Training failed"
    else:
        message = f"Training {status} successfully"

    response_data = {
        "job_id": final_result["job_id"],
        "tenant_id": final_result["tenant_id"],
        "status": status,
        "message": message,
        # NOTE(review): this is the time the response is built, not the time
        # the job was created -- final_result carries no creation timestamp.
        "created_at": datetime.now(),
        "estimated_duration_minutes": 0,  # Already completed
        "training_results": {
            "total_products": len(products),
            "successful_trainings": sum(1 for p in products if p["status"] == "completed"),
            "failed_trainings": sum(1 for p in products if p["status"] == "failed"),
            "products": products,
            "overall_training_time_seconds": training_results_data.get("total_training_time", 0),
        },
        "data_summary": final_result.get("data_summary", {}),
        "completed_at": final_result.get("completed_at"),
    }

    # Surface the failure reason instead of silently dropping it; the key is
    # only added when there is something to report, so the success payload
    # keeps its original shape.
    if error_message:
        response_data["error_details"] = {"error_message": error_message}

    return response_data