""" Training Event Consumer Processes ML model retraining requests from RabbitMQ Queues training jobs and manages model lifecycle """ import json import structlog from typing import Dict, Any, Optional from datetime import datetime from uuid import UUID from shared.messaging import RabbitMQClient from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select logger = structlog.get_logger() class TrainingEventConsumer: """ Consumes training retraining events and queues ML training jobs Ensures no duplicate training jobs and manages priorities """ def __init__(self, db_session: AsyncSession): self.db_session = db_session async def consume_training_events( self, rabbitmq_client: RabbitMQClient ): """ Start consuming training events from RabbitMQ """ async def process_message(message): """Process a single training event message""" try: async with message.process(): # Parse event data event_data = json.loads(message.body.decode()) logger.info( "Received training event", event_id=event_data.get('event_id'), event_type=event_data.get('event_type'), tenant_id=event_data.get('tenant_id') ) # Process the event await self.process_training_event(event_data) except Exception as e: logger.error( "Error processing training event", error=str(e), exc_info=True ) # Start consuming events await rabbitmq_client.consume_events( exchange_name="training.events", queue_name="training.retraining.queue", routing_key="training.retrain.*", callback=process_message ) logger.info("Started consuming training events") async def process_training_event(self, event_data: Dict[str, Any]) -> bool: """ Process a training event based on type Args: event_data: Full event payload from RabbitMQ Returns: bool: True if processed successfully """ try: event_type = event_data.get('event_type') data = event_data.get('data', {}) tenant_id = event_data.get('tenant_id') if not tenant_id: logger.warning("Training event missing tenant_id", event_data=event_data) return False # Route to appropriate handler if event_type == 'training.retrain.requested': success = await self._handle_retrain_requested(tenant_id, data, event_data) elif event_type == 'training.retrain.scheduled': success = await self._handle_retrain_scheduled(tenant_id, data) else: logger.warning("Unknown training event type", event_type=event_type) success = True # Mark as processed to avoid retry if success: logger.info( "Training event processed successfully", event_type=event_type, tenant_id=tenant_id ) else: logger.error( "Training event processing failed", event_type=event_type, tenant_id=tenant_id ) return success except Exception as e: logger.error( "Error in process_training_event", error=str(e), event_id=event_data.get('event_id'), exc_info=True ) return False async def _handle_retrain_requested( self, tenant_id: str, data: Dict[str, Any], event_data: Dict[str, Any] ) -> bool: """ Handle retraining request event Validates model, checks for existing jobs, queues training job Args: tenant_id: Tenant ID data: Retraining request data event_data: Full event payload Returns: bool: True if handled successfully """ try: model_id = data.get('model_id') product_id = data.get('product_id') trigger_reason = data.get('trigger_reason', 'unknown') priority = data.get('priority', 'normal') event_id = event_data.get('event_id') if not model_id: logger.warning("Retraining request missing model_id", data=data) return False # Validate model exists from app.models import TrainedModel stmt = select(TrainedModel).where( TrainedModel.id == UUID(model_id), TrainedModel.tenant_id 
== UUID(tenant_id) ) result = await self.db_session.execute(stmt) model = result.scalar_one_or_none() if not model: logger.error( "Model not found for retraining", model_id=model_id, tenant_id=tenant_id ) return False # Check if model is already in training if model.status in ['training', 'retraining_queued']: logger.info( "Model already in training, skipping duplicate request", model_id=model_id, current_status=model.status ) return True # Consider successful (idempotent) # Check for existing job in queue from app.models import TrainingJobQueue existing_job_stmt = select(TrainingJobQueue).where( TrainingJobQueue.model_id == UUID(model_id), TrainingJobQueue.status.in_(['pending', 'running']) ) existing_job_result = await self.db_session.execute(existing_job_stmt) existing_job = existing_job_result.scalar_one_or_none() if existing_job: logger.info( "Training job already queued, skipping duplicate", model_id=model_id, job_id=str(existing_job.id) ) return True # Idempotent # Queue training job job_id = await self._queue_training_job( tenant_id=tenant_id, model_id=model_id, product_id=product_id, trigger_reason=trigger_reason, priority=priority, event_id=event_id, metadata=data ) if not job_id: logger.error("Failed to queue training job", model_id=model_id) return False # Update model status model.status = 'retraining_queued' model.updated_at = datetime.utcnow() await self.db_session.commit() # Publish job queued event await self._publish_job_queued_event( tenant_id=tenant_id, model_id=model_id, job_id=job_id, priority=priority ) logger.info( "Retraining job queued successfully", model_id=model_id, job_id=job_id, trigger_reason=trigger_reason, priority=priority ) return True except Exception as e: await self.db_session.rollback() logger.error( "Error handling retrain requested", error=str(e), model_id=data.get('model_id'), exc_info=True ) return False async def _handle_retrain_scheduled( self, tenant_id: str, data: Dict[str, Any] ) -> bool: """ Handle scheduled retraining event Similar to retrain_requested but for scheduled/batch retraining Args: tenant_id: Tenant ID data: Scheduled retraining data Returns: bool: True if handled successfully """ try: # Similar logic to _handle_retrain_requested # but may have different priority or batching logic logger.info( "Handling scheduled retraining", tenant_id=tenant_id, model_count=len(data.get('models', [])) ) # For now, redirect to retrain_requested handler success_count = 0 for model_data in data.get('models', []): if await self._handle_retrain_requested( tenant_id, model_data, {'event_id': data.get('schedule_id'), 'tenant_id': tenant_id} ): success_count += 1 logger.info( "Scheduled retraining processed", tenant_id=tenant_id, successful=success_count, total=len(data.get('models', [])) ) return success_count > 0 except Exception as e: logger.error( "Error handling retrain scheduled", error=str(e), tenant_id=tenant_id, exc_info=True ) return False async def _queue_training_job( self, tenant_id: str, model_id: str, product_id: str, trigger_reason: str, priority: str, event_id: str, metadata: Dict[str, Any] ) -> Optional[str]: """ Queue a training job in the database Args: tenant_id: Tenant ID model_id: Model ID to retrain product_id: Product ID trigger_reason: Why retraining was triggered priority: Job priority (low, normal, high) event_id: Originating event ID metadata: Additional job metadata Returns: Job ID if successful, None otherwise """ try: from app.models import TrainingJobQueue import uuid # Map priority to numeric value for sorting 
priority_map = { 'low': 1, 'normal': 2, 'high': 3, 'critical': 4 } job = TrainingJobQueue( id=uuid.uuid4(), tenant_id=UUID(tenant_id), model_id=UUID(model_id), product_id=UUID(product_id) if product_id else None, job_type='retrain', status='pending', priority=priority, priority_score=priority_map.get(priority, 2), trigger_reason=trigger_reason, event_id=event_id, metadata=metadata, created_at=datetime.utcnow(), scheduled_at=datetime.utcnow() ) self.db_session.add(job) await self.db_session.commit() logger.info( "Training job created", job_id=str(job.id), model_id=model_id, priority=priority, trigger_reason=trigger_reason ) return str(job.id) except Exception as e: await self.db_session.rollback() logger.error( "Failed to queue training job", model_id=model_id, error=str(e), exc_info=True ) return None async def _publish_job_queued_event( self, tenant_id: str, model_id: str, job_id: str, priority: str ): """ Publish event that training job was queued Args: tenant_id: Tenant ID model_id: Model ID job_id: Training job ID priority: Job priority """ try: from shared.messaging import get_rabbitmq_client import uuid rabbitmq_client = get_rabbitmq_client() if not rabbitmq_client: logger.warning("RabbitMQ client not available for event publishing") return event_payload = { "event_id": str(uuid.uuid4()), "event_type": "training.retrain.queued", "timestamp": datetime.utcnow().isoformat(), "tenant_id": tenant_id, "data": { "job_id": job_id, "model_id": model_id, "priority": priority, "status": "queued" } } await rabbitmq_client.publish_event( exchange_name="training.events", routing_key="training.retrain.queued", event_data=event_payload ) logger.info( "Published job queued event", job_id=job_id, event_id=event_payload["event_id"] ) except Exception as e: logger.error( "Failed to publish job queued event", job_id=job_id, error=str(e) ) # Don't fail the main operation if event publishing fails # Factory function for creating consumer instance def create_training_event_consumer(db_session: AsyncSession) -> TrainingEventConsumer: """Create training event consumer instance""" return TrainingEventConsumer(db_session)
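

# Example wiring (a minimal sketch, not the service's real entrypoint): the
# DATABASE_URL environment variable, the async_sessionmaker setup, and the
# keep-alive wait below are assumptions about the surrounding application;
# substitute your own engine configuration and broker bootstrap as needed.
if __name__ == "__main__":
    import asyncio
    import os

    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

    from shared.messaging import get_rabbitmq_client

    async def _run() -> None:
        # DATABASE_URL is an assumed variable name, e.g. an asyncpg DSN.
        engine = create_async_engine(os.environ["DATABASE_URL"])
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        async with session_factory() as session:
            consumer = create_training_event_consumer(session)
            # Reuse the same helper the consumer uses when publishing events.
            rabbitmq_client = get_rabbitmq_client()
            await consumer.consume_training_events(rabbitmq_client)
            # Keep the process alive while messages arrive (assuming
            # consume_events returns after registering the callback).
            await asyncio.Event().wait()

    asyncio.run(_run())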