Initial commit - production deployment

services/training/app/main.py (new file, 265 lines)

# ================================================================
# services/training/app/main.py
# ================================================================
"""
Training Service Main Application

ML training service for bakery demand forecasting
"""

import asyncio

import uvicorn
from fastapi import FastAPI, Request
from sqlalchemy import text

from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
from app.api import training_jobs, training_operations, models, health, monitoring, websocket_operations, audit
from app.services.training_events import setup_messaging, cleanup_messaging
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers
from shared.service_base import StandardFastAPIService
from shared.monitoring.system_metrics import SystemMetricsCollector


class TrainingService(StandardFastAPIService):
    """Training Service with standardized setup"""

    def __init__(self):
        # Define expected database tables for health checks
        training_expected_tables = [
            'model_training_logs', 'trained_models', 'model_performance_metrics',
            'training_job_queue', 'model_artifacts'
        ]

        super().__init__(
            service_name="training-service",
            app_name="Bakery Training Service",
            description="ML training service for bakery demand forecasting",
            version="1.0.0",
            log_level=settings.LOG_LEVEL,
            cors_origins=settings.CORS_ORIGINS_LIST,
            api_prefix="",
            database_manager=database_manager,
            expected_tables=training_expected_tables,
            enable_messaging=True
        )

    async def _setup_messaging(self):
        """Setup messaging for training service"""
        await setup_messaging()
        self.logger.info("Messaging setup completed")

        # Initialize Redis pub/sub for cross-pod WebSocket broadcasting
        await self._setup_websocket_redis()

        # Set up WebSocket event consumer (listens to RabbitMQ and broadcasts to WebSockets)
        success = await setup_websocket_event_consumer()
        if success:
            self.logger.info("WebSocket event consumer setup completed")
        else:
            self.logger.warning("WebSocket event consumer setup failed")
    async def _setup_websocket_redis(self):
        """
        Initialize Redis pub/sub for WebSocket cross-pod broadcasting.

        CRITICAL FOR HORIZONTAL SCALING:
        Without this, WebSocket clients on Pod A won't receive events
        from training jobs running on Pod B.
        """
        try:
            from app.websocket.manager import websocket_manager

            redis_url = settings.REDIS_URL
            success = await websocket_manager.initialize_redis(redis_url)

            if success:
                self.logger.info("WebSocket Redis pub/sub initialized for horizontal scaling")
            else:
                self.logger.warning(
                    "WebSocket Redis pub/sub failed to initialize. "
                    "WebSocket events will only be delivered to local connections."
                )

        except Exception as e:
            self.logger.error("Failed to setup WebSocket Redis pub/sub",
                              error=str(e))
            # Don't fail startup - WebSockets will work locally without Redis
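
    # Illustrative sketch (not part of this service - websocket_manager is not
    # shown in this commit): the cross-pod pattern assumed above is that every
    # pod subscribes to a shared Redis channel and fans incoming events out to
    # its local WebSocket connections. With redis-py's asyncio client, the
    # consumer side could look roughly like this (the channel name and the
    # broadcast_to_local_websockets() helper are hypothetical):
    #
    #   import redis.asyncio as redis
    #
    #   r = redis.from_url(redis_url)
    #   pubsub = r.pubsub()
    #   await pubsub.subscribe("training:events")
    #   async for message in pubsub.listen():
    #       if message["type"] == "message":
    #           await broadcast_to_local_websockets(message["data"])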

    async def _cleanup_messaging(self):
        """Cleanup messaging for training service"""
        # Shutdown WebSocket Redis pub/sub
        try:
            from app.websocket.manager import websocket_manager
            await websocket_manager.shutdown()
            self.logger.info("WebSocket Redis pub/sub shutdown completed")
        except Exception as e:
            self.logger.warning("Error shutting down WebSocket Redis", error=str(e))

        await cleanup_websocket_consumers()
        await cleanup_messaging()

    async def verify_migrations(self):
        """Verify that the database has been initialized with an Alembic migration version."""
        try:
            async with self.database_manager.get_session() as session:
                result = await session.execute(text("SELECT version_num FROM alembic_version"))
                version = result.scalar()

                if not version:
                    self.logger.error("No migration version found in database")
                    raise RuntimeError("Database not initialized - no alembic version found")

                self.logger.info(f"Migration verification successful: {version}")
                return version
        except Exception as e:
            self.logger.error(f"Migration verification failed: {e}")
            raise
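
    # Illustrative sketch (an assumption, not code in this commit): the check
    # above only confirms that *some* revision is stamped. To verify the
    # database is at the *latest* migration, the stored revision can be
    # compared against the head of the migration scripts with Alembic's
    # public API:
    #
    #   from alembic.config import Config
    #   from alembic.script import ScriptDirectory
    #
    #   script = ScriptDirectory.from_config(Config("alembic.ini"))
    #   if version != script.get_current_head():
    #       raise RuntimeError(f"Database at {version}, expected {script.get_current_head()}")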

    async def on_startup(self, app: FastAPI):
        """Custom startup logic including migration verification"""
        await self.verify_migrations()

        # Initialize system metrics collection; keep a reference on the
        # service so the collector is not garbage-collected
        self.system_metrics = SystemMetricsCollector("training")
        self.logger.info("System metrics collection started")

        # Recover stale jobs from previous pod crashes.
        # This is important for horizontal scaling - jobs may be left in 'running'
        # state if a pod crashes. We mark them as failed so they can be retried.
        await self._recover_stale_jobs()

        self.logger.info("Training service startup completed")
    async def _recover_stale_jobs(self):
        """
        Recover stale training jobs on startup.

        When a pod crashes mid-training, jobs are left in 'running' or 'pending'
        state. This method finds jobs that haven't been updated recently and
        marks them as failed so users can retry them.
        """
        try:
            from app.repositories.training_log_repository import TrainingLogRepository

            async with self.database_manager.get_session() as session:
                log_repo = TrainingLogRepository(session)

                # Recover jobs that haven't been updated in 60 minutes.
                # This is conservative - most training jobs complete within 30 minutes.
                recovered = await log_repo.recover_stale_jobs(stale_threshold_minutes=60)

                if recovered:
                    self.logger.warning(
                        "Recovered stale training jobs on startup",
                        recovered_count=len(recovered),
                        job_ids=[j.job_id for j in recovered]
                    )
                else:
                    self.logger.info("No stale training jobs to recover")

        except Exception as e:
            # Don't fail startup if recovery fails - just log the error
            self.logger.error("Failed to recover stale jobs on startup", error=str(e))
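
    # Illustrative sketch (hypothetical - the repository is not shown in this
    # commit, and the column names are assumptions): recover_stale_jobs() is
    # presumed to run an update along these lines and return the affected rows:
    #
    #   UPDATE model_training_logs
    #   SET status = 'failed',
    #       error_message = 'Job marked stale after pod restart'
    #   WHERE status IN ('running', 'pending')
    #     AND updated_at < NOW() - INTERVAL '60 minutes'
    #   RETURNING job_id;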

    async def on_shutdown(self, app: FastAPI):
        """Custom shutdown logic for training service"""
        await cleanup_training_database()
        self.logger.info("Training database cleanup completed")

    def get_service_features(self):
        """Return training-specific features"""
        return [
            "ml_model_training",
            "demand_forecasting",
            "model_performance_tracking",
            "training_job_queue",
            "model_artifacts_management",
            "websocket_support",
            "messaging_integration"
        ]

    def setup_custom_middleware(self):
        """Setup custom middleware for training service"""
        # Request middleware for logging and metrics
        @self.app.middleware("http")
        async def process_request(request: Request, call_next):
            """Process requests with logging and metrics"""
            start_time = asyncio.get_event_loop().time()

            try:
                response = await call_next(request)
                duration = asyncio.get_event_loop().time() - start_time

                self.logger.info(
                    "Request completed",
                    method=request.method,
                    path=request.url.path,
                    status_code=response.status_code,
                    duration_ms=round(duration * 1000, 2)
                )

                return response

            except Exception as e:
                duration = asyncio.get_event_loop().time() - start_time

                self.logger.error(
                    "Request failed",
                    method=request.method,
                    path=request.url.path,
                    error=str(e),
                    duration_ms=round(duration * 1000, 2)
                )
                raise
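
    # Design note (an observation, not a change): asyncio.get_event_loop().time()
    # above is the event loop's monotonic clock; time.perf_counter() is the more
    # common choice for request timing, but both are monotonic and safe for
    # computing durations.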

    def setup_custom_endpoints(self):
        """Setup custom endpoints for training service"""
        # Note: metrics are exported via OpenTelemetry OTLP to SigNoz.
        # A pull-based /metrics endpoint is not needed because metrics are
        # pushed automatically; the endpoint below is kept for reference only.
        # @self.app.get("/metrics")
        # async def get_metrics():
        #     """Prometheus metrics endpoint"""
        #     if self.metrics_collector:
        #         return self.metrics_collector.get_metrics()
        #     return {"status": "metrics not available"}
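
        # Illustrative sketch (an assumption about the shared setup, not code in
        # this commit): push-based OTLP export with the OpenTelemetry SDK is
        # typically wired up like this, with the endpoint pointing at the SigNoz
        # collector:
        #
        #   from opentelemetry.sdk.metrics import MeterProvider
        #   from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
        #   from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
        #
        #   reader = PeriodicExportingMetricReader(OTLPMetricExporter(endpoint="<collector>:4317"))
        #   provider = MeterProvider(metric_readers=[reader])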

        @self.app.get("/")
        async def root():
            return {"service": "training-service", "version": "1.0.0"}


# Create service instance
service = TrainingService()

# Create FastAPI app with standardized setup
app = service.create_app(
    docs_url="/docs",
    redoc_url="/redoc"
)

# Setup standard endpoints
service.setup_standard_endpoints()

# Setup custom middleware
service.setup_custom_middleware()

# Setup custom endpoints
service.setup_custom_endpoints()

# Include API routers
# IMPORTANT: Register audit router FIRST to avoid route matching conflicts
service.add_router(audit.router)
service.add_router(training_jobs.router, tags=["training-jobs"])
service.add_router(training_operations.router, tags=["training-operations"])
service.add_router(models.router, tags=["models"])
service.add_router(health.router, tags=["health"])
service.add_router(monitoring.router, tags=["monitoring"])
service.add_router(websocket_operations.router, tags=["websocket"])
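
# (Hypothetical example of the conflict avoided above: FastAPI matches routes in
# registration order, so if training_jobs exposed GET /jobs/{job_id} and audit
# exposed GET /jobs/audit-log, registering training_jobs first would capture
# "audit-log" as a job_id. The actual paths live in the routers, which are not
# part of this file.)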

if __name__ == "__main__":
    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=settings.PORT,
        reload=settings.DEBUG,
        log_level=settings.LOG_LEVEL.lower()
    )
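
# Usage note: in a deployment the service would typically be launched via the
# uvicorn CLI rather than this __main__ block, along these lines (the port
# value depends on settings.PORT):
#
#   uvicorn app.main:app --host 0.0.0.0 --port <PORT>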