Improve the event messaging for training service 2
This commit is contained in:
@@ -63,6 +63,15 @@ def safe_json_serialize(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
return serialize_for_json(data)
|
||||
|
||||
async def setup_websocket_message_routing():
|
||||
"""Set up message routing for WebSocket connections"""
|
||||
try:
|
||||
# This will be called from the WebSocket endpoint
|
||||
# to set up the consumer for a specific job
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set up WebSocket message routing: {e}")
|
||||
|
||||
# =========================================
|
||||
# ENHANCED TRAINING JOB STATUS EVENTS
|
||||
# =========================================
|
||||
|
||||
@@ -16,6 +16,15 @@ from app.services.training_orchestrator import TrainingDataOrchestrator
|
||||
|
||||
from app.core.database import get_db_session
|
||||
|
||||
from app.services.messaging import (
|
||||
publish_job_progress,
|
||||
publish_data_validation_started,
|
||||
publish_data_validation_completed,
|
||||
publish_job_step_completed,
|
||||
publish_job_completed,
|
||||
publish_job_failed
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TrainingService:
|
||||
@@ -61,18 +70,12 @@ class TrainingService:
|
||||
|
||||
logger.info(f"Starting training job {job_id} for tenant {tenant_id}")
|
||||
|
||||
from app.services.messaging import TrainingStatusPublisher
|
||||
status_publisher = TrainingStatusPublisher(job_id, tenant_id)
|
||||
|
||||
try:
|
||||
|
||||
await status_publisher.job_started({
|
||||
"bakery_location": bakery_location,
|
||||
"has_custom_dates": bool(requested_start or requested_end)
|
||||
}, 0) # Will be updated when we know product count
|
||||
|
||||
# Step 1: Prepare training dataset with date alignment and orchestration
|
||||
logger.info("Step 1: Preparing and aligning training data")
|
||||
await publish_job_progress(job_id, tenant_id, 0, "Extrayendo datos de ventas")
|
||||
training_dataset = await self.orchestrator.prepare_training_data(
|
||||
tenant_id=tenant_id,
|
||||
bakery_location=bakery_location,
|
||||
@@ -83,6 +86,7 @@ class TrainingService:
|
||||
|
||||
# Step 2: Execute ML training pipeline
|
||||
logger.info("Step 2: Starting ML training pipeline")
|
||||
await publish_job_progress(job_id, tenant_id, 35, "Starting ML training pipeline")
|
||||
training_results = await self.trainer.train_tenant_models(
|
||||
tenant_id=tenant_id,
|
||||
training_dataset=training_dataset,
|
||||
@@ -110,12 +114,11 @@ class TrainingService:
|
||||
}
|
||||
|
||||
logger.info(f"Training job {job_id} completed successfully")
|
||||
await status_publisher.job_completed(final_result)
|
||||
await publish_job_completed(job_id, tenant_id, final_result);
|
||||
return TrainingService.create_detailed_training_response(final_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training job {job_id} failed: {str(e)}")
|
||||
await status_publisher.job_failed(str(e))
|
||||
# Return error response in same detailed format
|
||||
final_result = {
|
||||
"job_id": job_id,
|
||||
@@ -139,7 +142,7 @@ class TrainingService:
|
||||
"completed_at": datetime.now().isoformat(),
|
||||
"error_message": str(e)
|
||||
}
|
||||
|
||||
await publish_job_failed(job_id, tenant_id, str(e), final_result)
|
||||
return TrainingService.create_detailed_training_response(final_result)
|
||||
|
||||
async def start_single_product_training(
|
||||
|
||||
Reference in New Issue
Block a user