# ================================================================
# services/training/app/main.py
# ================================================================
"""
Training Service Main Application
ML training service for bakery demand forecasting
"""

import asyncio

from fastapi import FastAPI, Request
from sqlalchemy import text

from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
from app.api import training_jobs, training_operations, models, health, monitoring, websocket_operations
from app.services.training_events import setup_messaging, cleanup_messaging
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers
from shared.service_base import StandardFastAPIService

# Service identity, shared by the base-class configuration and the root ("/")
# endpoint so the two values cannot drift apart.
_SERVICE_NAME = "training-service"
_SERVICE_VERSION = "1.0.0"


class TrainingService(StandardFastAPIService):
    """Training Service with standardized setup.

    Configures the shared ``StandardFastAPIService`` base with the training
    database manager, the tables expected to exist, and messaging enabled, and
    adds training-specific startup/shutdown hooks, middleware and endpoints.
    """

    def __init__(self):
        # Define expected database tables for health checks
        training_expected_tables = [
            'model_training_logs',
            'trained_models',
            'model_performance_metrics',
            'training_job_queue',
            'model_artifacts',
        ]

        super().__init__(
            service_name=_SERVICE_NAME,
            app_name="Bakery Training Service",
            description="ML training service for bakery demand forecasting",
            version=_SERVICE_VERSION,
            log_level=settings.LOG_LEVEL,
            cors_origins=settings.CORS_ORIGINS_LIST,
            api_prefix="",
            database_manager=database_manager,
            expected_tables=training_expected_tables,
            enable_messaging=True,
        )

    async def _setup_messaging(self):
        """Set up training event messaging and the WebSocket event consumer.

        A failure to set up the WebSocket consumer (a falsy return from
        ``setup_websocket_event_consumer``) is logged as a warning and does not
        abort startup.
        """
        await setup_messaging()
        self.logger.info("Messaging setup completed")

        # Set up WebSocket event consumer (listens to RabbitMQ and broadcasts to WebSockets)
        success = await setup_websocket_event_consumer()
        if success:
            self.logger.info("WebSocket event consumer setup completed")
        else:
            self.logger.warning("WebSocket event consumer setup failed")

    async def _cleanup_messaging(self):
        """Tear down messaging for the training service.

        Runs in the reverse order of ``_setup_messaging``: the WebSocket consumers
        are stopped first, then the training event messaging.
        """
        await cleanup_websocket_consumers()
        await cleanup_messaging()

    async def verify_migrations(self):
        """Check that the database has an Alembic migration version recorded.

        Reads ``version_num`` from the ``alembic_version`` table. This only
        confirms that *some* migration has been applied; it does not compare the
        value against the latest revision in the codebase.

        Returns:
            The migration version value stored in the database.

        Raises:
            RuntimeError: If no version is recorded (empty/NULL result).
            Exception: Any error raised while querying the database is logged
                and re-raised unchanged.
        """
        # Keep the try narrow: only the database read is wrapped, so the
        # "no version" RuntimeError below is logged once rather than twice.
        try:
            async with self.database_manager.get_session() as session:
                result = await session.execute(text("SELECT version_num FROM alembic_version"))
                version = result.scalar()
        except Exception as e:
            self.logger.error(f"Migration verification failed: {e}")
            raise

        if not version:
            self.logger.error("No migration version found in database")
            raise RuntimeError("Database not initialized - no alembic version found")

        self.logger.info(f"Migration verification successful: {version}")
        return version

    async def on_startup(self, app: FastAPI):
        """Custom startup hook: fail fast if no migration version is recorded.

        Args:
            app: The FastAPI application (unused here; part of the hook signature).
        """
        await self.verify_migrations()
        self.logger.info("Training service startup completed")

    async def on_shutdown(self, app: FastAPI):
        """Custom shutdown hook: release training database resources.

        Args:
            app: The FastAPI application (unused here; part of the hook signature).
        """
        await cleanup_training_database()
        self.logger.info("Training database cleanup completed")

    def get_service_features(self):
        """Return the list of training-specific feature identifiers."""
        return [
            "ml_model_training",
            "demand_forecasting",
            "model_performance_tracking",
            "training_job_queue",
            "model_artifacts_management",
            "websocket_support",
            "messaging_integration",
        ]

    def setup_custom_middleware(self):
        """Register an HTTP middleware that logs every request with its duration."""

        @self.app.middleware("http")
        async def process_request(request: Request, call_next):
            """Log method, path, outcome and wall-clock duration of a request.

            Exceptions from downstream handlers are logged and re-raised.
            """
            # We are inside a coroutine, so a loop is guaranteed to be running;
            # get_running_loop() is the non-deprecated way to obtain it.
            loop = asyncio.get_running_loop()
            start_time = loop.time()

            try:
                response = await call_next(request)
            except Exception as e:
                duration = loop.time() - start_time
                self.logger.error(
                    "Request failed",
                    method=request.method,
                    path=request.url.path,
                    error=str(e),
                    duration_ms=round(duration * 1000, 2),
                )
                raise

            duration = loop.time() - start_time
            self.logger.info(
                "Request completed",
                method=request.method,
                path=request.url.path,
                status_code=response.status_code,
                duration_ms=round(duration * 1000, 2),
            )
            return response

    def setup_custom_endpoints(self):
        """Register the ``/metrics`` and root (``/``) endpoints."""

        @self.app.get("/metrics")
        async def get_metrics():
            """Prometheus metrics endpoint.

            Returns the collector's output when a metrics collector is
            configured, otherwise a small status payload.
            """
            if self.metrics_collector:
                return self.metrics_collector.get_metrics()
            return {"status": "metrics not available"}

        @self.app.get("/")
        async def root():
            """Report the service name and version."""
            return {"service": _SERVICE_NAME, "version": _SERVICE_VERSION}


# Create service instance
service = TrainingService()

# Create FastAPI app with standardized setup
app = service.create_app(
    docs_url="/docs",
    redoc_url="/redoc",
)

# Setup standard endpoints
service.setup_standard_endpoints()

# Setup custom middleware
service.setup_custom_middleware()

# Setup custom endpoints
service.setup_custom_endpoints()

# Include API routers
service.add_router(training_jobs.router, tags=["training-jobs"])
service.add_router(training_operations.router, tags=["training-operations"])
service.add_router(models.router, tags=["models"])
service.add_router(health.router, tags=["health"])
service.add_router(monitoring.router, tags=["monitoring"])
service.add_router(websocket_operations.router, tags=["websocket"])

if __name__ == "__main__":
    # Imported here because the server runner is only needed when this module
    # is executed directly; importing ``app.main`` does not require it.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=settings.PORT,
        reload=settings.DEBUG,
        log_level=settings.LOG_LEVEL.lower(),
    )