Refactor all main.py
This commit is contained in:
@@ -1,280 +1,147 @@
|
||||
# ================================================================
|
||||
# services/training/app/main.py - FIXED VERSION
|
||||
# services/training/app/main.py
|
||||
# ================================================================
|
||||
"""
|
||||
Training Service Main Application
|
||||
Enhanced with proper error handling, monitoring, and lifecycle management
|
||||
ML training service for bakery demand forecasting
|
||||
"""
|
||||
|
||||
import structlog
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
import uvicorn
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import initialize_training_database, cleanup_training_database, get_db_health, get_comprehensive_db_health
|
||||
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
|
||||
from app.api import training, models
|
||||
|
||||
from app.api.websocket import websocket_router
|
||||
from app.services.messaging import setup_messaging, cleanup_messaging
|
||||
from shared.monitoring.logging import setup_logging
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
# REMOVED: from shared.auth.decorators import require_auth
|
||||
from shared.service_base import StandardFastAPIService
|
||||
|
||||
# Setup structured logging
|
||||
setup_logging("training-service", settings.LOG_LEVEL)
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Initialize metrics collector
|
||||
metrics_collector = MetricsCollector("training-service")
|
||||
class TrainingService(StandardFastAPIService):
|
||||
"""Training Service with standardized setup"""
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
Application lifespan manager for startup and shutdown events
|
||||
"""
|
||||
# Startup
|
||||
logger.info("Starting Training Service", version="1.0.0")
|
||||
|
||||
try:
|
||||
# Initialize database
|
||||
logger.info("Initializing database connection")
|
||||
await initialize_training_database()
|
||||
logger.info("Database initialized successfully")
|
||||
|
||||
# Initialize messaging
|
||||
logger.info("Setting up messaging")
|
||||
def __init__(self):
|
||||
# Define expected database tables for health checks
|
||||
training_expected_tables = [
|
||||
'model_training_logs', 'trained_models', 'model_performance_metrics',
|
||||
'training_job_queue', 'model_artifacts'
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
service_name="training-service",
|
||||
app_name="Bakery Training Service",
|
||||
description="ML training service for bakery demand forecasting",
|
||||
version="1.0.0",
|
||||
log_level=settings.LOG_LEVEL,
|
||||
cors_origins=settings.CORS_ORIGINS_LIST,
|
||||
api_prefix="/api/v1",
|
||||
database_manager=database_manager,
|
||||
expected_tables=training_expected_tables,
|
||||
enable_messaging=True
|
||||
)
|
||||
|
||||
async def _setup_messaging(self):
|
||||
"""Setup messaging for training service"""
|
||||
await setup_messaging()
|
||||
logger.info("Messaging setup completed")
|
||||
|
||||
# Start metrics server
|
||||
logger.info("Starting metrics server")
|
||||
metrics_collector.start_metrics_server(8080)
|
||||
logger.info("Metrics server started on port 8080")
|
||||
|
||||
# Store metrics collector in app state
|
||||
app.state.metrics_collector = metrics_collector
|
||||
|
||||
# Mark service as ready
|
||||
app.state.ready = True
|
||||
logger.info("Training Service startup completed successfully")
|
||||
|
||||
yield
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to start Training Service", error=str(e))
|
||||
app.state.ready = False
|
||||
raise
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down Training Service")
|
||||
|
||||
try:
|
||||
# Stop metrics server
|
||||
if hasattr(app.state, 'metrics_collector'):
|
||||
await app.state.metrics_collector.shutdown()
|
||||
|
||||
# Cleanup messaging
|
||||
self.logger.info("Messaging setup completed")
|
||||
|
||||
async def _cleanup_messaging(self):
|
||||
"""Cleanup messaging for training service"""
|
||||
await cleanup_messaging()
|
||||
logger.info("Messaging cleanup completed")
|
||||
|
||||
# Close database connections
|
||||
|
||||
async def on_startup(self, app: FastAPI):
|
||||
"""Custom startup logic for training service"""
|
||||
pass
|
||||
|
||||
async def on_shutdown(self, app: FastAPI):
|
||||
"""Custom shutdown logic for training service"""
|
||||
# Note: Database cleanup is handled by the base class
|
||||
# but training service has custom cleanup function
|
||||
await cleanup_training_database()
|
||||
logger.info("Database connections closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error during shutdown", error=str(e))
|
||||
|
||||
logger.info("Training Service shutdown completed")
|
||||
self.logger.info("Training database cleanup completed")
|
||||
|
||||
# Create FastAPI application with lifespan
|
||||
app = FastAPI(
|
||||
title="Bakery Training Service",
|
||||
description="ML training service for bakery demand forecasting",
|
||||
version="1.0.0",
|
||||
def get_service_features(self):
|
||||
"""Return training-specific features"""
|
||||
return [
|
||||
"ml_model_training",
|
||||
"demand_forecasting",
|
||||
"model_performance_tracking",
|
||||
"training_job_queue",
|
||||
"model_artifacts_management",
|
||||
"websocket_support",
|
||||
"messaging_integration"
|
||||
]
|
||||
|
||||
def setup_custom_middleware(self):
|
||||
"""Setup custom middleware for training service"""
|
||||
# Request middleware for logging and metrics
|
||||
@self.app.middleware("http")
|
||||
async def process_request(request: Request, call_next):
|
||||
"""Process requests with logging and metrics"""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
duration = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
self.logger.info(
|
||||
"Request completed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
status_code=response.status_code,
|
||||
duration_ms=round(duration * 1000, 2)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
duration = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
self.logger.error(
|
||||
"Request failed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
error=str(e),
|
||||
duration_ms=round(duration * 1000, 2)
|
||||
)
|
||||
raise
|
||||
|
||||
def setup_custom_endpoints(self):
|
||||
"""Setup custom endpoints for training service"""
|
||||
@self.app.get("/metrics")
|
||||
async def get_metrics():
|
||||
"""Prometheus metrics endpoint"""
|
||||
if self.metrics_collector:
|
||||
return self.metrics_collector.get_metrics()
|
||||
return {"status": "metrics not available"}
|
||||
|
||||
@self.app.get("/")
|
||||
async def root():
|
||||
return {"service": "training-service", "version": "1.0.0"}
|
||||
|
||||
|
||||
# Create service instance
|
||||
service = TrainingService()
|
||||
|
||||
# Create FastAPI app with standardized setup
|
||||
app = service.create_app(
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
lifespan=lifespan
|
||||
redoc_url="/redoc"
|
||||
)
|
||||
|
||||
# Add middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS_LIST,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
# Setup standard endpoints
|
||||
service.setup_standard_endpoints()
|
||||
|
||||
# Setup custom middleware
|
||||
service.setup_custom_middleware()
|
||||
|
||||
# Request middleware for logging and metrics
|
||||
@app.middleware("http")
|
||||
async def process_request(request: Request, call_next):
|
||||
"""Process requests with logging and metrics"""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
duration = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
logger.info(
|
||||
"Request completed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
status_code=response.status_code,
|
||||
duration_ms=round(duration * 1000, 2)
|
||||
)
|
||||
|
||||
# Update metrics
|
||||
metrics_collector.record_request(
|
||||
method=request.method,
|
||||
endpoint=request.url.path,
|
||||
status_code=response.status_code,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
duration = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
logger.error(
|
||||
"Request failed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
error=str(e),
|
||||
duration_ms=round(duration * 1000, 2)
|
||||
)
|
||||
|
||||
metrics_collector.increment_counter("http_requests_failed_total")
|
||||
raise
|
||||
|
||||
# Exception handlers
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
"""Global exception handler for unhandled errors"""
|
||||
logger.error(
|
||||
"Unhandled exception",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
error=str(exc),
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
metrics_collector.increment_counter("unhandled_exceptions_total")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "Internal server error",
|
||||
"error_id": structlog.get_logger().new().info("Error logged", error=str(exc))
|
||||
}
|
||||
)
|
||||
# Setup custom endpoints
|
||||
service.setup_custom_endpoints()
|
||||
|
||||
# Include API routers
|
||||
app.include_router(training.router, prefix="/api/v1", tags=["training"])
|
||||
|
||||
app.include_router(models.router, prefix="/api/v1", tags=["models"])
|
||||
service.add_router(training.router, tags=["training"])
|
||||
service.add_router(models.router, tags=["models"])
|
||||
app.include_router(websocket_router, prefix="/api/v1/ws", tags=["websocket"])
|
||||
|
||||
|
||||
# Health check endpoints
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Basic health check endpoint"""
|
||||
return {
|
||||
"status": "healthy" if app.state.ready else "starting",
|
||||
"service": "training-service",
|
||||
"version": "1.0.0",
|
||||
"timestamp": structlog.get_logger().new().info("Health check")
|
||||
}
|
||||
|
||||
@app.get("/health/ready")
|
||||
async def readiness_check():
|
||||
"""Kubernetes readiness probe endpoint with comprehensive database checks"""
|
||||
try:
|
||||
# Get comprehensive database health including table verification
|
||||
db_health = await get_comprehensive_db_health()
|
||||
|
||||
checks = {
|
||||
"database_connectivity": db_health["connectivity"],
|
||||
"database_tables": db_health["tables_exist"],
|
||||
"application": getattr(app.state, 'ready', False)
|
||||
}
|
||||
|
||||
# Include detailed database info for debugging
|
||||
database_details = {
|
||||
"status": db_health["status"],
|
||||
"tables_verified": db_health["tables_verified"],
|
||||
"missing_tables": db_health["missing_tables"],
|
||||
"errors": db_health["errors"]
|
||||
}
|
||||
|
||||
# Service is ready only if all checks pass
|
||||
all_ready = all(checks.values()) and db_health["status"] == "healthy"
|
||||
|
||||
if all_ready:
|
||||
return {
|
||||
"status": "ready",
|
||||
"checks": checks,
|
||||
"database": database_details
|
||||
}
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"status": "not ready",
|
||||
"checks": checks,
|
||||
"database": database_details
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Readiness check failed", error=str(e))
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"status": "not ready",
|
||||
"error": f"Health check failed: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
@app.get("/health/database")
|
||||
async def database_health_check():
|
||||
"""Detailed database health endpoint for debugging"""
|
||||
try:
|
||||
db_health = await get_comprehensive_db_health()
|
||||
status_code = 200 if db_health["status"] == "healthy" else 503
|
||||
return JSONResponse(status_code=status_code, content=db_health)
|
||||
except Exception as e:
|
||||
logger.error("Database health check failed", error=str(e))
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"status": "unhealthy",
|
||||
"error": f"Health check failed: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
@app.get("/metrics")
|
||||
async def get_metrics():
|
||||
"""Prometheus metrics endpoint"""
|
||||
if hasattr(app.state, 'metrics_collector'):
|
||||
return app.state.metrics_collector.get_metrics()
|
||||
return {"status": "metrics not available"}
|
||||
|
||||
@app.get("/health/live")
|
||||
async def liveness_check():
|
||||
return {"status": "alive"}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"service": "training-service", "version": "1.0.0"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
|
||||
Reference in New Issue
Block a user