Improve training code 3
This commit is contained in:
@@ -81,27 +81,55 @@ class TrainingService:
|
||||
)
|
||||
|
||||
# Step 3: Compile final results
|
||||
logger.info(f"Training job {job_id} completed successfully")
|
||||
return {
|
||||
final_result = {
|
||||
"job_id": job_id,
|
||||
"status": "completed", # or "running" if async
|
||||
"message": "Training job completed successfully",
|
||||
"tenant_id": tenant_id,
|
||||
"created_at": datetime.now(),
|
||||
"estimated_duration_minutes": 5 # reasonable estimate
|
||||
"status": "completed",
|
||||
"training_results": training_results,
|
||||
"data_summary": {
|
||||
"sales_records": len(training_dataset.sales_data),
|
||||
"weather_records": len(training_dataset.weather_data),
|
||||
"traffic_records": len(training_dataset.traffic_data),
|
||||
"date_range": {
|
||||
"start": training_dataset.date_range.start.isoformat(),
|
||||
"end": training_dataset.date_range.end.isoformat()
|
||||
},
|
||||
"data_sources_used": [source.value for source in training_dataset.date_range.available_sources],
|
||||
"constraints_applied": training_dataset.date_range.constraints
|
||||
},
|
||||
"completed_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
logger.info(f"Training job {job_id} completed successfully")
|
||||
return TrainingService.create_detailed_training_response(final_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training job {job_id} failed: {str(e)}")
|
||||
# Return error response that still matches schema
|
||||
return {
|
||||
# Return error response in same detailed format
|
||||
final_result = {
|
||||
"job_id": job_id,
|
||||
"status": "failed",
|
||||
"message": f"Training job failed: {str(e)}",
|
||||
"tenant_id": tenant_id,
|
||||
"created_at": datetime.now(),
|
||||
"estimated_duration_minutes": 0
|
||||
"status": "failed",
|
||||
"training_results": {
|
||||
"total_products": 0,
|
||||
"successful_trainings": 0,
|
||||
"failed_trainings": 0,
|
||||
"models_trained": {},
|
||||
"total_training_time": 0
|
||||
},
|
||||
"data_summary": {
|
||||
"sales_records": 0,
|
||||
"weather_records": 0,
|
||||
"traffic_records": 0,
|
||||
"date_range": {"start": "", "end": ""},
|
||||
"data_sources_used": [],
|
||||
"constraints_applied": {}
|
||||
},
|
||||
"completed_at": datetime.now().isoformat(),
|
||||
"error_message": str(e)
|
||||
}
|
||||
|
||||
return TrainingService.create_detailed_training_response(final_result)
|
||||
|
||||
async def start_single_product_training(
|
||||
self,
|
||||
@@ -290,4 +318,46 @@ class TrainingService:
|
||||
"reasons": [f"Error analyzing data: {str(e)}"],
|
||||
"recommended_products": [],
|
||||
"optimal_config": {}
|
||||
}
|
||||
}
|
||||
|
||||
def create_detailed_training_response(final_result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert your final_result structure to match the TrainingJobResponse schema
|
||||
"""
|
||||
# Extract training results and convert to schema format
|
||||
training_results_data = final_result.get("training_results", {})
|
||||
|
||||
# Convert product results to schema format
|
||||
products = []
|
||||
if "models_trained" in training_results_data:
|
||||
for product_name, result in training_results_data["models_trained"].items():
|
||||
products.append({
|
||||
"product_name": product_name,
|
||||
"status": result.get("status", "completed"),
|
||||
"model_id": result.get("model_id"),
|
||||
"data_points": result.get("data_points", 0),
|
||||
"metrics": result.get("metrics"),
|
||||
"training_time_seconds": result.get("training_time_seconds"),
|
||||
"error_message": result.get("error_message")
|
||||
})
|
||||
|
||||
# Build the response matching your structure
|
||||
response_data = {
|
||||
"job_id": final_result["job_id"],
|
||||
"tenant_id": final_result["tenant_id"],
|
||||
"status": final_result["status"],
|
||||
"message": f"Training {final_result['status']} successfully",
|
||||
"created_at": datetime.now(),
|
||||
"estimated_duration_minutes": 0, # Already completed
|
||||
"training_results": {
|
||||
"total_products": len(products),
|
||||
"successful_trainings": len([p for p in products if p["status"] == "completed"]),
|
||||
"failed_trainings": len([p for p in products if p["status"] == "failed"]),
|
||||
"products": products,
|
||||
"overall_training_time_seconds": training_results_data.get("total_training_time", 0)
|
||||
},
|
||||
"data_summary": final_result.get("data_summary", {}),
|
||||
"completed_at": final_result.get("completed_at")
|
||||
}
|
||||
|
||||
return response_data
|
||||
Reference in New Issue
Block a user