# Files
# bakery-ia/services/training/app/main.py
# 2025-09-30 08:12:45 +02:00
#
# 174 lines
# 6.1 KiB
# Python

# ================================================================
# services/training/app/main.py
# ================================================================
"""
Training Service Main Application
ML training service for bakery demand forecasting
"""
import asyncio
from fastapi import FastAPI, Request
from sqlalchemy import text
from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
from app.api import training, models
from app.api.websocket import websocket_router
from app.services.messaging import setup_messaging, cleanup_messaging
from shared.service_base import StandardFastAPIService
class TrainingService(StandardFastAPIService):
    """Training service built on the shared ``StandardFastAPIService`` base.

    Configures the base service for the training domain (name, CORS, API
    prefix, expected tables, messaging) and adds the training-specific
    lifecycle hooks, request-logging middleware and extra endpoints.
    """

    # Alembic revision this build expects to find in ``alembic_version``.
    # NOTE(review): must be bumped whenever a new migration is added,
    # otherwise startup will fail in ``verify_migrations``.
    expected_migration_version = "001_initial_training"

    def __init__(self):
        """Pass the training-service configuration to the base class."""
        # Define expected database tables for health checks
        training_expected_tables = [
            'model_training_logs', 'trained_models', 'model_performance_metrics',
            'training_job_queue', 'model_artifacts'
        ]
        super().__init__(
            service_name="training-service",
            app_name="Bakery Training Service",
            description="ML training service for bakery demand forecasting",
            version="1.0.0",
            log_level=settings.LOG_LEVEL,
            cors_origins=settings.CORS_ORIGINS_LIST,
            api_prefix="/api/v1",
            database_manager=database_manager,
            expected_tables=training_expected_tables,
            enable_messaging=True
        )

    async def on_startup(self, app: FastAPI):
        """Verify the database migration level, then run the base startup hook.

        The class previously defined ``on_startup`` twice; the later no-op
        definition replaced this one, so the migration check never ran.
        There is now a single definition.

        NOTE(review): assumes the base lifecycle has made
        ``self.database_manager`` usable before this hook is called - confirm
        against ``StandardFastAPIService``.

        Raises:
            RuntimeError: if the stored Alembic revision differs from
                ``expected_migration_version``.
            Exception: any error raised while querying the database.
        """
        await self.verify_migrations()
        await super().on_startup(app)

    async def verify_migrations(self):
        """Check that ``alembic_version`` holds the expected revision.

        Reads ``version_num`` from the ``alembic_version`` table and compares
        it with ``expected_migration_version``.

        Raises:
            RuntimeError: on a revision mismatch.
            Exception: re-raised unchanged if the query itself fails.
        """
        # Keep the try body limited to the database access so that a version
        # mismatch is logged once, not a second time as a generic failure.
        try:
            async with self.database_manager.get_session() as session:
                result = await session.execute(text("SELECT version_num FROM alembic_version"))
                version = result.scalar()
        except Exception as e:
            self.logger.error(f"Migration verification failed: {e}")
            raise

        if version != self.expected_migration_version:
            message = (
                f"Migration version mismatch: expected {self.expected_migration_version}, got {version}"
            )
            self.logger.error(message)
            raise RuntimeError(message)

        self.logger.info(f"Migration verification successful: {version}")

    async def _setup_messaging(self):
        """Setup messaging for training service"""
        await setup_messaging()
        self.logger.info("Messaging setup completed")

    async def _cleanup_messaging(self):
        """Cleanup messaging for training service"""
        await cleanup_messaging()

    async def on_shutdown(self, app: FastAPI):
        """Run the training-specific database cleanup on shutdown."""
        # Note: Database cleanup is handled by the base class
        # but training service has custom cleanup function
        await cleanup_training_database()
        self.logger.info("Training database cleanup completed")

    def get_service_features(self):
        """Return the list of feature identifiers this service advertises."""
        return [
            "ml_model_training",
            "demand_forecasting",
            "model_performance_tracking",
            "training_job_queue",
            "model_artifacts_management",
            "websocket_support",
            "messaging_integration"
        ]

    def setup_custom_middleware(self):
        """Register an HTTP middleware that logs every request and its duration."""

        @self.app.middleware("http")
        async def process_request(request: Request, call_next):
            """Time the downstream handler and log the outcome.

            Exceptions from the handler are logged and re-raised unchanged.
            """
            # The middleware always runs inside the event loop, so the running
            # loop's monotonic clock is used for timing.
            loop = asyncio.get_running_loop()
            start_time = loop.time()
            try:
                response = await call_next(request)
            except Exception as e:
                duration = loop.time() - start_time
                self.logger.error(
                    "Request failed",
                    method=request.method,
                    path=request.url.path,
                    error=str(e),
                    duration_ms=round(duration * 1000, 2)
                )
                raise

            duration = loop.time() - start_time
            self.logger.info(
                "Request completed",
                method=request.method,
                path=request.url.path,
                status_code=response.status_code,
                duration_ms=round(duration * 1000, 2)
            )
            return response

    def setup_custom_endpoints(self):
        """Register the ``/metrics`` and ``/`` endpoints on the application."""

        @self.app.get("/metrics")
        async def get_metrics():
            """Return the collector's metrics, or a placeholder if none is set."""
            if self.metrics_collector:
                return self.metrics_collector.get_metrics()
            return {"status": "metrics not available"}

        @self.app.get("/")
        async def root():
            """Return the service name and version."""
            return {"service": "training-service", "version": "1.0.0"}
# Module-level wiring: instantiate the service and build the ASGI application
# object that the server imports as ``app.main:app``.
service = TrainingService()

# Interactive API documentation is served at /docs and /redoc.
app = service.create_app(docs_url="/docs", redoc_url="/redoc")

# Registration order is kept as-is: standard endpoints first, then the
# request-logging middleware, then the service-specific endpoints.
service.setup_standard_endpoints()
service.setup_custom_middleware()
service.setup_custom_endpoints()

# REST routers are registered through the service helper.
service.add_router(training.router, tags=["training"])
service.add_router(models.router, tags=["models"])

# The WebSocket router is mounted directly on the app with an explicit prefix.
app.include_router(websocket_router, prefix="/api/v1/ws", tags=["websocket"])
if __name__ == "__main__":
    # Development entry point: run the app with uvicorn when this file is
    # executed directly. ``uvicorn`` was referenced without being imported,
    # which raised NameError here; it is imported locally so that importing
    # this module as ``app.main`` does not depend on it.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=settings.PORT,
        reload=settings.DEBUG,
        log_level=settings.LOG_LEVEL.lower(),
    )