REFACTOR external service and improve websocket training

This commit is contained in:
Urtzi Alfaro
2025-10-09 14:11:02 +02:00
parent 7c72f83c51
commit 3c689b4f98
111 changed files with 13289 additions and 2374 deletions

View File

@@ -11,35 +11,15 @@ from fastapi import FastAPI, Request
from sqlalchemy import text
from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
from app.api import training_jobs, training_operations, models
from app.services.messaging import setup_messaging, cleanup_messaging
from app.api import training_jobs, training_operations, models, health, monitoring, websocket_operations
from app.services.training_events import setup_messaging, cleanup_messaging
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers
from shared.service_base import StandardFastAPIService
class TrainingService(StandardFastAPIService):
"""Training Service with standardized setup"""
expected_migration_version = "00001"
async def on_startup(self, app):
    """Check the migration version, then run the base-class startup hook.

    Args:
        app: The application object forwarded unchanged to the parent hook.
    """
    # The check is awaited first, so any exception it raises propagates
    # before the parent startup logic gets a chance to run.
    await self.verify_migrations()
    parent_startup = super().on_startup
    await parent_startup(app)
async def verify_migrations(self):
    """Verify that the database is at the expected migration version.

    Reads ``version_num`` from the ``alembic_version`` table and compares it
    with ``self.expected_migration_version``.

    Raises:
        RuntimeError: If the stored version differs from the expected one.
        Exception: Any error raised while opening the session or running
            the query is logged and re-raised unchanged.
    """
    # Keep the try body limited to the database access so that the mismatch
    # error raised below is not caught by this handler and logged a second
    # time under a different message.
    try:
        async with self.database_manager.get_session() as session:
            result = await session.execute(text("SELECT version_num FROM alembic_version"))
            version = result.scalar()
    except Exception as e:
        self.logger.error(f"Migration verification failed: {e}")
        raise

    if version != self.expected_migration_version:
        message = f"Migration version mismatch: expected {self.expected_migration_version}, got {version}"
        self.logger.error(message)
        raise RuntimeError(message)

    self.logger.info(f"Migration verification successful: {version}")
def __init__(self):
# Define expected database tables for health checks
training_expected_tables = [
@@ -54,7 +34,7 @@ class TrainingService(StandardFastAPIService):
version="1.0.0",
log_level=settings.LOG_LEVEL,
cors_origins=settings.CORS_ORIGINS_LIST,
api_prefix="", # Empty because RouteBuilder already includes /api/v1
api_prefix="",
database_manager=database_manager,
expected_tables=training_expected_tables,
enable_messaging=True
@@ -65,18 +45,42 @@ class TrainingService(StandardFastAPIService):
await setup_messaging()
self.logger.info("Messaging setup completed")
# Set up WebSocket event consumer (listens to RabbitMQ and broadcasts to WebSockets)
success = await setup_websocket_event_consumer()
if success:
self.logger.info("WebSocket event consumer setup completed")
else:
self.logger.warning("WebSocket event consumer setup failed")
async def _cleanup_messaging(self):
    """Tear down the WebSocket consumers and then the messaging layer."""
    # Order matters: the WebSocket consumers are cleaned up before the
    # messaging cleanup runs, exactly as in the sequential version.
    for teardown in (cleanup_websocket_consumers, cleanup_messaging):
        await teardown()
async def verify_migrations(self):
    """Verify that the database carries an Alembic migration version.

    Reads ``version_num`` from the ``alembic_version`` table. No specific
    revision is required; any non-empty value is accepted.

    Returns:
        The version value read from the database.

    Raises:
        RuntimeError: If the query yields no (or an empty) version value.
        Exception: Any error raised while opening the session or running
            the query is logged and re-raised unchanged.
    """
    # Keep the try body limited to the database access so that the
    # "not initialized" error raised below is not caught by this handler
    # and logged a second time under a different message.
    try:
        async with self.database_manager.get_session() as session:
            result = await session.execute(text("SELECT version_num FROM alembic_version"))
            version = result.scalar()
    except Exception as e:
        self.logger.error(f"Migration verification failed: {e}")
        raise

    if not version:
        self.logger.error("No migration version found in database")
        raise RuntimeError("Database not initialized - no alembic version found")

    self.logger.info(f"Migration verification successful: {version}")
    return version
async def on_startup(self, app: FastAPI):
    """Custom startup logic including migration verification.

    Args:
        app: The FastAPI application being started (not used directly here).

    Raises:
        Exception: Whatever ``verify_migrations`` raises propagates to the
            caller, so the completion message is only logged on success.
    """
    # The stale placeholder docstring and `pass` that preceded the real
    # logic were dropped; they were no-ops that made the function's
    # ``__doc__`` describe the old empty implementation.
    await self.verify_migrations()
    self.logger.info("Training service startup completed")
async def on_shutdown(self, app: FastAPI):
    """Run the training-specific database cleanup when the service stops.

    Args:
        app: The FastAPI application being shut down (not used directly here).
    """
    # Per the original note: generic database cleanup is handled by the
    # base class, while the training service has its own cleanup function,
    # which is what gets awaited here.
    completion_message = "Training database cleanup completed"
    await cleanup_training_database()
    self.logger.info(completion_message)
@@ -162,6 +166,9 @@ service.setup_custom_endpoints()
# Register every API router on the service, each under a single tag.
# The tuple order is the registration order.
_ROUTER_MODULES_AND_TAGS = (
    (training_jobs, "training-jobs"),
    (training_operations, "training-operations"),
    (models, "models"),
    (health, "health"),
    (monitoring, "monitoring"),
    (websocket_operations, "websocket"),
)
for _router_module, _router_tag in _ROUTER_MODULES_AND_TAGS:
    service.add_router(_router_module.router, tags=[_router_tag])
if __name__ == "__main__":
uvicorn.run(