Add CI/CD and fix multiple-pod issues
@@ -46,6 +46,9 @@ class TrainingService(StandardFastAPIService):
        await setup_messaging()
        self.logger.info("Messaging setup completed")

        # Initialize Redis pub/sub for cross-pod WebSocket broadcasting
        await self._setup_websocket_redis()

        # Set up WebSocket event consumer (listens to RabbitMQ and broadcasts to WebSockets)
        success = await setup_websocket_event_consumer()
        if success:
@@ -53,8 +56,44 @@ class TrainingService(StandardFastAPIService):
        else:
            self.logger.warning("WebSocket event consumer setup failed")

    async def _setup_websocket_redis(self):
        """
        Initialize Redis pub/sub for WebSocket cross-pod broadcasting.

        CRITICAL FOR HORIZONTAL SCALING:
        Without this, WebSocket clients on Pod A won't receive events
        from training jobs running on Pod B.
        """
        try:
            from app.websocket.manager import websocket_manager
            from app.core.config import settings

            redis_url = settings.REDIS_URL
            success = await websocket_manager.initialize_redis(redis_url)

            if success:
                self.logger.info("WebSocket Redis pub/sub initialized for horizontal scaling")
            else:
                self.logger.warning(
                    "WebSocket Redis pub/sub failed to initialize. "
                    "WebSocket events will only be delivered to local connections."
                )

        except Exception as e:
            self.logger.error("Failed to setup WebSocket Redis pub/sub", error=str(e))
            # Don't fail startup - WebSockets will work locally without Redis
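
Note: app.websocket.manager is not included in this diff. As a rough sketch of what initialize_redis and the cross-pod fan-out could look like on top of redis-py's asyncio client (the channel name, attribute names, and the send_json transport are assumptions, not the actual implementation):

    # Hypothetical sketch of app.websocket.manager, for context only.
    import asyncio
    import json

    import redis.asyncio as redis

    WS_CHANNEL = "websocket:events"  # assumed channel name


    class WebSocketManager:
        def __init__(self):
            self.local_connections = set()  # WebSockets attached to *this* pod
            self._redis = None
            self._listener_task = None

        async def initialize_redis(self, redis_url: str) -> bool:
            try:
                self._redis = redis.from_url(redis_url)
                await self._redis.ping()
                self._listener_task = asyncio.create_task(self._listen())
                return True
            except Exception:
                return False

        async def _listen(self):
            # Every pod subscribes; each relays incoming events to its own
            # local connections. This is what makes broadcasts cross-pod.
            pubsub = self._redis.pubsub()
            await pubsub.subscribe(WS_CHANNEL)
            async for message in pubsub.listen():
                if message["type"] == "message":
                    await self.broadcast_local(json.loads(message["data"]))

        async def publish(self, event: dict):
            # Producers publish through Redis instead of writing to local
            # sockets directly, so clients on other pods see the event too.
            await self._redis.publish(WS_CHANNEL, json.dumps(event))

        async def broadcast_local(self, event: dict):
            for ws in list(self.local_connections):
                await ws.send_json(event)  # Starlette WebSocket API

        async def shutdown(self):
            if self._listener_task:
                self._listener_task.cancel()
            if self._redis:
                await self._redis.aclose()  # redis-py >= 5; close() on older versions


    websocket_manager = WebSocketManager()

The key point is that producers call publish(), which routes through Redis rather than straight to local sockets, so every pod's listener relays the event to its own clients.
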
    async def _cleanup_messaging(self):
        """Cleanup messaging for training service"""
        # Shutdown WebSocket Redis pub/sub
        try:
            from app.websocket.manager import websocket_manager
            await websocket_manager.shutdown()
            self.logger.info("WebSocket Redis pub/sub shutdown completed")
        except Exception as e:
            self.logger.warning("Error shutting down WebSocket Redis", error=str(e))

        await cleanup_websocket_consumers()
        await cleanup_messaging()
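
Likewise, setup_websocket_event_consumer and cleanup_websocket_consumers are defined elsewhere. One plausible shape for the RabbitMQ side, assuming aio_pika and a shared queue so each event is consumed by exactly one pod and then fanned out through the Redis channel sketched above (the exchange, queue, and routing-key names and settings.RABBITMQ_URL are invented for illustration):

    # Hypothetical sketch of the WebSocket event consumer, for context only.
    import json

    import aio_pika

    from app.core.config import settings
    from app.websocket.manager import websocket_manager


    async def setup_websocket_event_consumer() -> bool:
        try:
            connection = await aio_pika.connect_robust(settings.RABBITMQ_URL)
            channel = await connection.channel()

            exchange = await channel.declare_exchange(
                "training.events", aio_pika.ExchangeType.TOPIC, durable=True
            )
            # Shared durable queue: pods compete for messages, so each event
            # is consumed once, then reaches all pods via the Redis channel.
            queue = await channel.declare_queue("websocket-events", durable=True)
            await queue.bind(exchange, routing_key="training.#")

            async def on_message(message: aio_pika.abc.AbstractIncomingMessage):
                async with message.process():  # ack on success
                    event = json.loads(message.body)
                    await websocket_manager.publish(event)

            await queue.consume(on_message)
            return True
        except Exception:
            return False
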
@@ -78,13 +117,49 @@ class TrainingService(StandardFastAPIService):
    async def on_startup(self, app: FastAPI):
        """Custom startup logic including migration verification"""
        await self.verify_migrations()

        # Initialize system metrics collection
        system_metrics = SystemMetricsCollector("training")
        self.logger.info("System metrics collection started")

        # Recover stale jobs from previous pod crashes.
        # This is important for horizontal scaling - jobs may be left in 'running'
        # state if a pod crashes. We mark them as failed so they can be retried.
        await self._recover_stale_jobs()

        self.logger.info("Training service startup completed")

    async def _recover_stale_jobs(self):
        """
        Recover stale training jobs on startup.

        When a pod crashes mid-training, jobs are left in 'running' or 'pending' state.
        This method finds jobs that haven't been updated in a while and marks them
        as failed so users can retry them.
        """
        try:
            from app.repositories.training_log_repository import TrainingLogRepository

            async with self.database_manager.get_session() as session:
                log_repo = TrainingLogRepository(session)

                # Recover jobs that haven't been updated in 60 minutes.
                # This is conservative - most training jobs complete within 30 minutes.
                recovered = await log_repo.recover_stale_jobs(stale_threshold_minutes=60)

                if recovered:
                    self.logger.warning(
                        "Recovered stale training jobs on startup",
                        recovered_count=len(recovered),
                        job_ids=[j.job_id for j in recovered]
                    )
                else:
                    self.logger.info("No stale training jobs to recover")

        except Exception as e:
            # Don't fail startup if recovery fails - just log the error
            self.logger.error("Failed to recover stale jobs on startup", error=str(e))
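
TrainingLogRepository.recover_stale_jobs is also outside this diff. A minimal sketch of what it might do with a SQLAlchemy async session, assuming a TrainingLog model with status, updated_at, error_message, and job_id columns (all assumed names):

    # Hypothetical sketch of the repository method, for context only.
    from datetime import datetime, timedelta, timezone

    from sqlalchemy import select
    from sqlalchemy.ext.asyncio import AsyncSession

    from app.models.training_log import TrainingLog  # assumed model path


    class TrainingLogRepository:
        def __init__(self, session: AsyncSession):
            self.session = session

        async def recover_stale_jobs(self, stale_threshold_minutes: int = 60):
            cutoff = datetime.now(timezone.utc) - timedelta(minutes=stale_threshold_minutes)

            # Jobs still marked active but untouched since the cutoff are
            # assumed to belong to a crashed or restarted pod.
            result = await self.session.execute(
                select(TrainingLog)
                .where(TrainingLog.status.in_(("running", "pending")))
                .where(TrainingLog.updated_at < cutoff)
            )
            stale_jobs = list(result.scalars())

            for job in stale_jobs:
                job.status = "failed"
                job.error_message = "Recovered: pod crashed or restarted mid-training"

            await self.session.commit()
            return stale_jobs

Marking the jobs as failed, rather than silently re-queuing them, matches the startup comment above: users can see what happened and retry.
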
    async def on_shutdown(self, app: FastAPI):
        """Custom shutdown logic for training service"""
        await cleanup_training_database()