Refactor all main.py

This commit is contained in:
Urtzi Alfaro
2025-09-29 13:13:12 +02:00
parent 4777e59e7a
commit befcc126b0
35 changed files with 2537 additions and 1993 deletions

View File

@@ -1,280 +1,147 @@
# ================================================================
# services/training/app/main.py - FIXED VERSION
# services/training/app/main.py
# ================================================================
"""
Training Service Main Application
Enhanced with proper error handling, monitoring, and lifecycle management
ML training service for bakery demand forecasting
"""
import structlog
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
import uvicorn
from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database, get_db_health, get_comprehensive_db_health
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
from app.api import training, models
from app.api.websocket import websocket_router
from app.services.messaging import setup_messaging, cleanup_messaging
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector
# REMOVED: from shared.auth.decorators import require_auth
from shared.service_base import StandardFastAPIService
# Setup structured logging
setup_logging("training-service", settings.LOG_LEVEL)
logger = structlog.get_logger()
# Initialize metrics collector
metrics_collector = MetricsCollector("training-service")
class TrainingService(StandardFastAPIService):
"""Training Service with standardized setup"""
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Application lifespan manager for startup and shutdown events
"""
# Startup
logger.info("Starting Training Service", version="1.0.0")
try:
# Initialize database
logger.info("Initializing database connection")
await initialize_training_database()
logger.info("Database initialized successfully")
# Initialize messaging
logger.info("Setting up messaging")
def __init__(self):
# Define expected database tables for health checks
training_expected_tables = [
'model_training_logs', 'trained_models', 'model_performance_metrics',
'training_job_queue', 'model_artifacts'
]
super().__init__(
service_name="training-service",
app_name="Bakery Training Service",
description="ML training service for bakery demand forecasting",
version="1.0.0",
log_level=settings.LOG_LEVEL,
cors_origins=settings.CORS_ORIGINS_LIST,
api_prefix="/api/v1",
database_manager=database_manager,
expected_tables=training_expected_tables,
enable_messaging=True
)
async def _setup_messaging(self):
"""Setup messaging for training service"""
await setup_messaging()
logger.info("Messaging setup completed")
# Start metrics server
logger.info("Starting metrics server")
metrics_collector.start_metrics_server(8080)
logger.info("Metrics server started on port 8080")
# Store metrics collector in app state
app.state.metrics_collector = metrics_collector
# Mark service as ready
app.state.ready = True
logger.info("Training Service startup completed successfully")
yield
except Exception as e:
logger.error("Failed to start Training Service", error=str(e))
app.state.ready = False
raise
# Shutdown
logger.info("Shutting down Training Service")
try:
# Stop metrics server
if hasattr(app.state, 'metrics_collector'):
await app.state.metrics_collector.shutdown()
# Cleanup messaging
self.logger.info("Messaging setup completed")
async def _cleanup_messaging(self):
"""Cleanup messaging for training service"""
await cleanup_messaging()
logger.info("Messaging cleanup completed")
# Close database connections
async def on_startup(self, app: FastAPI):
"""Custom startup logic for training service"""
pass
async def on_shutdown(self, app: FastAPI):
"""Custom shutdown logic for training service"""
# Note: Database cleanup is handled by the base class
# but training service has custom cleanup function
await cleanup_training_database()
logger.info("Database connections closed")
except Exception as e:
logger.error("Error during shutdown", error=str(e))
logger.info("Training Service shutdown completed")
self.logger.info("Training database cleanup completed")
# Create FastAPI application with lifespan
app = FastAPI(
title="Bakery Training Service",
description="ML training service for bakery demand forecasting",
version="1.0.0",
def get_service_features(self):
"""Return training-specific features"""
return [
"ml_model_training",
"demand_forecasting",
"model_performance_tracking",
"training_job_queue",
"model_artifacts_management",
"websocket_support",
"messaging_integration"
]
def setup_custom_middleware(self):
"""Setup custom middleware for training service"""
# Request middleware for logging and metrics
@self.app.middleware("http")
async def process_request(request: Request, call_next):
"""Process requests with logging and metrics"""
start_time = asyncio.get_event_loop().time()
try:
response = await call_next(request)
duration = asyncio.get_event_loop().time() - start_time
self.logger.info(
"Request completed",
method=request.method,
path=request.url.path,
status_code=response.status_code,
duration_ms=round(duration * 1000, 2)
)
return response
except Exception as e:
duration = asyncio.get_event_loop().time() - start_time
self.logger.error(
"Request failed",
method=request.method,
path=request.url.path,
error=str(e),
duration_ms=round(duration * 1000, 2)
)
raise
def setup_custom_endpoints(self):
"""Setup custom endpoints for training service"""
@self.app.get("/metrics")
async def get_metrics():
"""Prometheus metrics endpoint"""
if self.metrics_collector:
return self.metrics_collector.get_metrics()
return {"status": "metrics not available"}
@self.app.get("/")
async def root():
return {"service": "training-service", "version": "1.0.0"}
# Create service instance
service = TrainingService()
# Create FastAPI app with standardized setup
app = service.create_app(
docs_url="/docs",
redoc_url="/redoc",
lifespan=lifespan
redoc_url="/redoc"
)
# Add middleware
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS_LIST,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Setup standard endpoints
service.setup_standard_endpoints()
# Setup custom middleware
service.setup_custom_middleware()
# Request middleware for logging and metrics
@app.middleware("http")
async def process_request(request: Request, call_next):
"""Process requests with logging and metrics"""
start_time = asyncio.get_event_loop().time()
try:
response = await call_next(request)
duration = asyncio.get_event_loop().time() - start_time
logger.info(
"Request completed",
method=request.method,
path=request.url.path,
status_code=response.status_code,
duration_ms=round(duration * 1000, 2)
)
# Update metrics
metrics_collector.record_request(
method=request.method,
endpoint=request.url.path,
status_code=response.status_code,
duration=duration
)
return response
except Exception as e:
duration = asyncio.get_event_loop().time() - start_time
logger.error(
"Request failed",
method=request.method,
path=request.url.path,
error=str(e),
duration_ms=round(duration * 1000, 2)
)
metrics_collector.increment_counter("http_requests_failed_total")
raise
# Exception handlers
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
"""Global exception handler for unhandled errors"""
logger.error(
"Unhandled exception",
path=request.url.path,
method=request.method,
error=str(exc),
exc_info=True
)
metrics_collector.increment_counter("unhandled_exceptions_total")
return JSONResponse(
status_code=500,
content={
"detail": "Internal server error",
"error_id": structlog.get_logger().new().info("Error logged", error=str(exc))
}
)
# Setup custom endpoints
service.setup_custom_endpoints()
# Include API routers
app.include_router(training.router, prefix="/api/v1", tags=["training"])
app.include_router(models.router, prefix="/api/v1", tags=["models"])
service.add_router(training.router, tags=["training"])
service.add_router(models.router, tags=["models"])
app.include_router(websocket_router, prefix="/api/v1/ws", tags=["websocket"])
# Health check endpoints
@app.get("/health")
async def health_check():
"""Basic health check endpoint"""
return {
"status": "healthy" if app.state.ready else "starting",
"service": "training-service",
"version": "1.0.0",
"timestamp": structlog.get_logger().new().info("Health check")
}
@app.get("/health/ready")
async def readiness_check():
"""Kubernetes readiness probe endpoint with comprehensive database checks"""
try:
# Get comprehensive database health including table verification
db_health = await get_comprehensive_db_health()
checks = {
"database_connectivity": db_health["connectivity"],
"database_tables": db_health["tables_exist"],
"application": getattr(app.state, 'ready', False)
}
# Include detailed database info for debugging
database_details = {
"status": db_health["status"],
"tables_verified": db_health["tables_verified"],
"missing_tables": db_health["missing_tables"],
"errors": db_health["errors"]
}
# Service is ready only if all checks pass
all_ready = all(checks.values()) and db_health["status"] == "healthy"
if all_ready:
return {
"status": "ready",
"checks": checks,
"database": database_details
}
else:
return JSONResponse(
status_code=503,
content={
"status": "not ready",
"checks": checks,
"database": database_details
}
)
except Exception as e:
logger.error("Readiness check failed", error=str(e))
return JSONResponse(
status_code=503,
content={
"status": "not ready",
"error": f"Health check failed: {str(e)}"
}
)
@app.get("/health/database")
async def database_health_check():
"""Detailed database health endpoint for debugging"""
try:
db_health = await get_comprehensive_db_health()
status_code = 200 if db_health["status"] == "healthy" else 503
return JSONResponse(status_code=status_code, content=db_health)
except Exception as e:
logger.error("Database health check failed", error=str(e))
return JSONResponse(
status_code=503,
content={
"status": "unhealthy",
"error": f"Health check failed: {str(e)}"
}
)
@app.get("/metrics")
async def get_metrics():
"""Prometheus metrics endpoint"""
if hasattr(app.state, 'metrics_collector'):
return app.state.metrics_collector.get_metrics()
return {"status": "metrics not available"}
@app.get("/health/live")
async def liveness_check():
return {"status": "alive"}
@app.get("/")
async def root():
return {"service": "training-service", "version": "1.0.0"}
if __name__ == "__main__":
uvicorn.run(
"app.main:app",