Files
bakery-ia/services/training/app/main.py
2026-01-18 09:02:27 +01:00

265 lines
10 KiB
Python

# ================================================================
# services/training/app/main.py
# ================================================================
"""
Training Service Main Application
ML training service for bakery demand forecasting
"""
import asyncio
from fastapi import FastAPI, Request
from sqlalchemy import text
from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
from app.api import training_jobs, training_operations, models, health, monitoring, websocket_operations, audit
from app.services.training_events import setup_messaging, cleanup_messaging
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers
from shared.service_base import StandardFastAPIService
from shared.monitoring.system_metrics import SystemMetricsCollector
class TrainingService(StandardFastAPIService):
    """Training service built on the shared ``StandardFastAPIService`` scaffold.

    Configures the training database, messaging, WebSocket event broadcasting
    (optionally fanned out across pods through Redis pub/sub), and the
    service-specific middleware and endpoints.
    """

    def __init__(self):
        # Tables the base class is told to expect; used for health checks.
        training_expected_tables = [
            'model_training_logs', 'trained_models', 'model_performance_metrics',
            'training_job_queue', 'model_artifacts',
        ]
        super().__init__(
            service_name="training-service",
            app_name="Bakery Training Service",
            description="ML training service for bakery demand forecasting",
            version="1.0.0",
            log_level=settings.LOG_LEVEL,
            cors_origins=settings.CORS_ORIGINS_LIST,
            api_prefix="",
            database_manager=database_manager,
            expected_tables=training_expected_tables,
            enable_messaging=True,
        )

    async def _setup_messaging(self) -> None:
        """Set up messaging, Redis-backed WebSocket fan-out and the WebSocket consumer.

        Presumably invoked by the base class because ``enable_messaging=True``
        (assumption -- confirm against StandardFastAPIService).
        A failed WebSocket consumer setup is logged as a warning, not raised.
        """
        await setup_messaging()
        self.logger.info("Messaging setup completed")

        # Initialize Redis pub/sub for cross-pod WebSocket broadcasting.
        await self._setup_websocket_redis()

        # WebSocket event consumer: listens to RabbitMQ and broadcasts to WebSockets.
        success = await setup_websocket_event_consumer()
        if success:
            self.logger.info("WebSocket event consumer setup completed")
        else:
            self.logger.warning("WebSocket event consumer setup failed")

    async def _setup_websocket_redis(self) -> None:
        """Initialize Redis pub/sub for WebSocket cross-pod broadcasting.

        CRITICAL FOR HORIZONTAL SCALING: without this, WebSocket clients on
        Pod A won't receive events from training jobs running on Pod B.

        This is best-effort. Any failure is logged and swallowed, so startup
        continues and WebSocket events are delivered to local connections only.
        """
        try:
            from app.websocket.manager import websocket_manager

            # ``settings`` is the module-level import; no need to re-import it here.
            success = await websocket_manager.initialize_redis(settings.REDIS_URL)
            if success:
                self.logger.info("WebSocket Redis pub/sub initialized for horizontal scaling")
            else:
                self.logger.warning(
                    "WebSocket Redis pub/sub failed to initialize. "
                    "WebSocket events will only be delivered to local connections."
                )
        except Exception as e:
            # Don't fail startup - WebSockets will work locally without Redis.
            self.logger.error("Failed to setup WebSocket Redis pub/sub",
                              error=str(e))

    async def _cleanup_messaging(self) -> None:
        """Tear down WebSocket Redis pub/sub, WebSocket consumers and messaging.

        Errors during the WebSocket manager shutdown are logged and swallowed.
        ``cleanup_messaging()`` always runs, even if consumer teardown raises.
        Any such error is still propagated afterwards.
        """
        try:
            from app.websocket.manager import websocket_manager
            await websocket_manager.shutdown()
            self.logger.info("WebSocket Redis pub/sub shutdown completed")
        except Exception as e:
            self.logger.warning("Error shutting down WebSocket Redis", error=str(e))

        try:
            await cleanup_websocket_consumers()
        finally:
            # Release the messaging connection regardless of consumer teardown outcome.
            await cleanup_messaging()

    async def verify_migrations(self):
        """Verify the database has an applied Alembic migration.

        Returns:
            The ``version_num`` read from the ``alembic_version`` table.

        Raises:
            RuntimeError: if the table exists but holds no version.
            Exception: any database error from the query is logged and re-raised.
        """
        try:
            async with self.database_manager.get_session() as session:
                result = await session.execute(text("SELECT version_num FROM alembic_version"))
                version = result.scalar()
        except Exception as e:
            self.logger.error("Migration verification failed", error=str(e))
            raise

        # Checked outside the try so this failure is logged once, not twice.
        if not version:
            self.logger.error("No migration version found in database")
            raise RuntimeError("Database not initialized - no alembic version found")

        self.logger.info("Migration verification successful", version=version)
        return version

    async def on_startup(self, app: FastAPI) -> None:
        """Run custom startup logic.

        Verifies migrations, creates the system metrics collector and recovers
        stale jobs. A migration verification failure propagates and aborts startup.
        """
        await self.verify_migrations()

        # Keep the collector referenced for the service's lifetime rather than
        # discarding it as an unused local.
        self._training_system_metrics = SystemMetricsCollector("training")
        self.logger.info("System metrics collection started")

        # Recover stale jobs from previous pod crashes. This matters for
        # horizontal scaling: a crashed pod can leave jobs in 'running' state,
        # and marking them failed lets users retry them.
        await self._recover_stale_jobs()
        self.logger.info("Training service startup completed")

    async def _recover_stale_jobs(self) -> None:
        """Recover training jobs stranded by a previous pod crash.

        Delegates to ``TrainingLogRepository.recover_stale_jobs`` with a
        60-minute staleness threshold. Per its intended use, jobs left in
        'running'/'pending' without recent updates are marked failed so users
        can retry them.

        Any failure is logged and swallowed so startup is never blocked.
        """
        try:
            from app.repositories.training_log_repository import TrainingLogRepository

            async with self.database_manager.get_session() as session:
                log_repo = TrainingLogRepository(session)
                # Conservative threshold: most training jobs complete within 30 minutes.
                recovered = await log_repo.recover_stale_jobs(stale_threshold_minutes=60)
                if recovered:
                    self.logger.warning(
                        "Recovered stale training jobs on startup",
                        recovered_count=len(recovered),
                        job_ids=[j.job_id for j in recovered],
                    )
                else:
                    self.logger.info("No stale training jobs to recover")
        except Exception as e:
            # Don't fail startup if recovery fails - just log the error.
            self.logger.error("Failed to recover stale jobs on startup", error=str(e))

    async def on_shutdown(self, app: FastAPI) -> None:
        """Run custom shutdown logic: release training database resources."""
        await cleanup_training_database()
        self.logger.info("Training database cleanup completed")

    def get_service_features(self):
        """Return the list of training-specific feature identifiers."""
        return [
            "ml_model_training",
            "demand_forecasting",
            "model_performance_tracking",
            "training_job_queue",
            "model_artifacts_management",
            "websocket_support",
            "messaging_integration",
        ]

    def setup_custom_middleware(self) -> None:
        """Register an HTTP middleware that logs each request's outcome and duration."""

        @self.app.middleware("http")
        async def process_request(request: Request, call_next):
            """Log method, path, status/error and duration of every HTTP request."""
            # Inside a coroutine the running loop always exists.
            # get_running_loop() is the supported call; get_event_loop() is not.
            loop = asyncio.get_running_loop()
            start_time = loop.time()
            try:
                response = await call_next(request)
            except Exception as e:
                duration = loop.time() - start_time
                self.logger.error(
                    "Request failed",
                    method=request.method,
                    path=request.url.path,
                    error=str(e),
                    duration_ms=round(duration * 1000, 2),
                )
                raise

            # Kept outside the try so a logging problem is not misreported as a request failure.
            duration = loop.time() - start_time
            self.logger.info(
                "Request completed",
                method=request.method,
                path=request.url.path,
                status_code=response.status_code,
                duration_ms=round(duration * 1000, 2),
            )
            return response

    def setup_custom_endpoints(self) -> None:
        """Register training-service-specific endpoints (currently only ``/``).

        No ``/metrics`` endpoint is exposed: metrics are exported via
        OpenTelemetry OTLP to SigNoz (push-based).
        """

        @self.app.get("/")
        async def root():
            return {"service": "training-service", "version": "1.0.0"}
# Create the service instance (all configuration is fixed in TrainingService.__init__).
service = TrainingService()
# Create the FastAPI app through the shared scaffold. docs_url/redoc_url are
# presumably forwarded to FastAPI's interactive docs routes (verify in StandardFastAPIService).
app = service.create_app(
    docs_url="/docs",
    redoc_url="/redoc"
)
# Setup standard endpoints provided by the base class.
service.setup_standard_endpoints()
# Setup custom middleware. This must run after create_app(), since it registers on service.app.
service.setup_custom_middleware()
# Setup custom endpoints (the "/" root endpoint).
service.setup_custom_endpoints()
# Include API routers. Registration order is significant.
# IMPORTANT: Register audit router FIRST to avoid route matching conflicts.
# Presumably another router declares a path pattern that would otherwise
# shadow audit routes -- confirm before reordering.
service.add_router(audit.router)
service.add_router(training_jobs.router, tags=["training-jobs"])
service.add_router(training_operations.router, tags=["training-operations"])
service.add_router(models.router, tags=["models"])
service.add_router(health.router, tags=["health"])
service.add_router(monitoring.router, tags=["monitoring"])
service.add_router(websocket_operations.router, tags=["websocket"])
if __name__ == "__main__":
    # Run the app directly with uvicorn (development entry point).
    # uvicorn is imported locally because only this direct-run path needs it.
    # It was previously referenced without any import, so running the module raised NameError.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=settings.PORT,
        # Auto-reload only in debug mode.
        reload=settings.DEBUG,
        log_level=settings.LOG_LEVEL.lower(),
    )