Files
bakery-ia/services/training/app/main.py

174 lines
6.1 KiB
Python
Raw Normal View History

# ================================================================
# services/training/app/main.py
# ================================================================
"""
Training Service Main Application.

ML training service for bakery demand forecasting.
"""
import asyncio
from fastapi import FastAPI, Request
2025-09-30 08:12:45 +02:00
from sqlalchemy import text
from app.core.config import settings
2025-09-29 13:13:12 +02:00
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
from app.api import training, models
from app.api.websocket import websocket_router
2025-07-18 14:41:39 +02:00
from app.services.messaging import setup_messaging, cleanup_messaging
2025-09-29 13:13:12 +02:00
from shared.service_base import StandardFastAPIService
class TrainingService(StandardFastAPIService):
    """Training service application built on ``StandardFastAPIService``.

    Configures the shared service base for the bakery training service,
    verifies the database migration revision at startup, wires messaging
    setup/teardown, and registers request-logging middleware plus the
    ``/metrics`` and ``/`` endpoints.
    """

    # Alembic revision the database must report for the service to start.
    expected_migration_version = "001_initial_training"

    def __init__(self):
        """Pass the training service's static configuration to the base class."""
        # Define expected database tables for health checks
        training_expected_tables = [
            'model_training_logs', 'trained_models', 'model_performance_metrics',
            'training_job_queue', 'model_artifacts'
        ]
        super().__init__(
            service_name="training-service",
            app_name="Bakery Training Service",
            description="ML training service for bakery demand forecasting",
            version="1.0.0",
            log_level=settings.LOG_LEVEL,
            cors_origins=settings.CORS_ORIGINS_LIST,
            api_prefix="/api/v1",
            database_manager=database_manager,
            expected_tables=training_expected_tables,
            enable_messaging=True
        )

    async def verify_migrations(self):
        """Verify the database schema is at the expected migration revision.

        Reads ``version_num`` from the ``alembic_version`` table and compares
        it with ``expected_migration_version``.

        Raises:
            RuntimeError: If the stored revision differs from the expected one.
            Exception: Any error raised while querying the database is logged
                and re-raised unchanged.
        """
        # Only the query sits inside the try block so that the mismatch error
        # raised below is logged once rather than being caught and logged again.
        try:
            async with self.database_manager.get_session() as session:
                result = await session.execute(text("SELECT version_num FROM alembic_version"))
                version = result.scalar()
        except Exception as e:
            self.logger.error(f"Migration verification failed: {e}")
            raise

        if version != self.expected_migration_version:
            message = (
                f"Migration version mismatch: expected "
                f"{self.expected_migration_version}, got {version}"
            )
            self.logger.error(message)
            raise RuntimeError(message)

        self.logger.info(f"Migration verification successful: {version}")

    async def _setup_messaging(self):
        """Setup messaging for training service"""
        await setup_messaging()
        self.logger.info("Messaging setup completed")

    async def _cleanup_messaging(self):
        """Cleanup messaging for training service"""
        await cleanup_messaging()

    async def on_startup(self, app: FastAPI):
        """Run migration verification, then the base class startup hook.

        This is the only ``on_startup`` definition on the class. A second,
        later definition whose body was ``pass`` previously replaced this one
        at class-creation time, so verification never executed.
        """
        await self.verify_migrations()
        await super().on_startup(app)

    async def on_shutdown(self, app: FastAPI):
        """Custom shutdown logic for training service"""
        # Note: Database cleanup is handled by the base class
        # but training service has custom cleanup function
        await cleanup_training_database()
        self.logger.info("Training database cleanup completed")

    def get_service_features(self):
        """Return training-specific features"""
        return [
            "ml_model_training",
            "demand_forecasting",
            "model_performance_tracking",
            "training_job_queue",
            "model_artifacts_management",
            "websocket_support",
            "messaging_integration"
        ]

    def setup_custom_middleware(self):
        """Register an HTTP middleware that logs each request's outcome and duration."""

        @self.app.middleware("http")
        async def process_request(request: Request, call_next):
            """Log method, path, status (or error) and elapsed milliseconds."""
            # The middleware always runs inside the event loop, so the running
            # loop's monotonic clock is used for timing.
            clock = asyncio.get_running_loop().time
            start_time = clock()
            try:
                response = await call_next(request)
            except Exception as e:
                self.logger.error(
                    "Request failed",
                    method=request.method,
                    path=request.url.path,
                    error=str(e),
                    duration_ms=round((clock() - start_time) * 1000, 2)
                )
                raise

            self.logger.info(
                "Request completed",
                method=request.method,
                path=request.url.path,
                status_code=response.status_code,
                duration_ms=round((clock() - start_time) * 1000, 2)
            )
            return response

    def setup_custom_endpoints(self):
        """Setup custom endpoints for training service"""

        @self.app.get("/metrics")
        async def get_metrics():
            """Prometheus metrics endpoint"""
            if self.metrics_collector:
                return self.metrics_collector.get_metrics()
            return {"status": "metrics not available"}

        @self.app.get("/")
        async def root():
            """Return the service name and version."""
            return {"service": "training-service", "version": "1.0.0"}
# Module-level wiring: build the service, then the FastAPI app it manages.
service = TrainingService()
app = service.create_app(docs_url="/docs", redoc_url="/redoc")

# Standard endpoints first, then the service-specific middleware and endpoints.
service.setup_standard_endpoints()
service.setup_custom_middleware()
service.setup_custom_endpoints()

# REST routers are registered through the service, in the original order.
for _router, _tag in ((training.router, "training"), (models.router, "models")):
    service.add_router(_router, tags=[_tag])

# The websocket router is mounted directly on the app under its own prefix.
app.include_router(websocket_router, prefix="/api/v1/ws", tags=["websocket"])
if __name__ == "__main__":
    # uvicorn is only needed when this module is run directly as a script.
    # Importing it here fixes the NameError caused by the missing import
    # without adding a dependency for normal ASGI imports of ``app``.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=settings.PORT,
        reload=settings.DEBUG,
        log_level=settings.LOG_LEVEL.lower()
    )