Add CI/CD and fix multi-pod issues

This commit is contained in:
Urtzi Alfaro
2026-01-18 09:02:27 +01:00
parent 3c4b5c2a06
commit 21d35ea92b
27 changed files with 3779 additions and 73 deletions

View File

@@ -18,6 +18,28 @@ class OrchestratorService(StandardFastAPIService):
expected_migration_version = "001_initial_schema"
def __init__(self):
# Define expected database tables for health checks
orchestrator_expected_tables = [
'orchestration_runs'
]
self.rabbitmq_client = None
self.event_publisher = None
self.leader_election = None
self.scheduler_service = None
super().__init__(
service_name="orchestrator-service",
app_name=settings.APP_NAME,
description=settings.DESCRIPTION,
version=settings.VERSION,
api_prefix="", # Empty because RouteBuilder already includes /api/v1
database_manager=database_manager,
expected_tables=orchestrator_expected_tables,
enable_messaging=True # Enable RabbitMQ for event publishing
)
async def verify_migrations(self):
"""Verify database schema matches the latest migrations"""
try:
@@ -32,26 +54,6 @@ class OrchestratorService(StandardFastAPIService):
self.logger.error(f"Migration verification failed: {e}")
raise
def __init__(self):
# Define expected database tables for health checks
orchestrator_expected_tables = [
'orchestration_runs'
]
self.rabbitmq_client = None
self.event_publisher = None
super().__init__(
service_name="orchestrator-service",
app_name=settings.APP_NAME,
description=settings.DESCRIPTION,
version=settings.VERSION,
api_prefix="", # Empty because RouteBuilder already includes /api/v1
database_manager=database_manager,
expected_tables=orchestrator_expected_tables,
enable_messaging=True # Enable RabbitMQ for event publishing
)
async def _setup_messaging(self):
"""Setup messaging for orchestrator service"""
from shared.messaging import UnifiedEventPublisher, RabbitMQClient
@@ -84,22 +86,91 @@ class OrchestratorService(StandardFastAPIService):
self.logger.info("Orchestrator Service starting up...")
# Initialize leader election for horizontal scaling
# Only the leader pod will run the scheduler
await self._setup_leader_election(app)
# REMOVED: Delivery tracking service - moved to procurement service (domain ownership)
async def _setup_leader_election(self, app: FastAPI):
"""
Setup leader election for scheduler.
CRITICAL FOR HORIZONTAL SCALING:
Without leader election, each pod would run the same scheduled jobs,
causing duplicate forecasts, duplicate production schedules, and database contention.
"""
from shared.leader_election import LeaderElectionService
import redis.asyncio as redis
try:
# Create Redis connection for leader election
redis_url = f"redis://:{settings.REDIS_PASSWORD}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}"
if settings.REDIS_TLS_ENABLED.lower() == "true":
redis_url = redis_url.replace("redis://", "rediss://")
redis_client = redis.from_url(redis_url, decode_responses=False)
await redis_client.ping()
# Use shared leader election service
self.leader_election = LeaderElectionService(
redis_client,
service_name="orchestrator"
)
# Define callbacks for leader state changes
async def on_become_leader():
self.logger.info("This pod became the leader - starting scheduler")
from app.services.orchestrator_service import OrchestratorSchedulerService
self.scheduler_service = OrchestratorSchedulerService(self.event_publisher, settings)
await self.scheduler_service.start()
app.state.scheduler_service = self.scheduler_service
self.logger.info("Orchestrator scheduler service started (leader only)")
async def on_lose_leader():
self.logger.warning("This pod lost leadership - stopping scheduler")
if self.scheduler_service:
await self.scheduler_service.stop()
self.scheduler_service = None
if hasattr(app.state, 'scheduler_service'):
app.state.scheduler_service = None
self.logger.info("Orchestrator scheduler service stopped (no longer leader)")
# Start leader election
await self.leader_election.start(
on_become_leader=on_become_leader,
on_lose_leader=on_lose_leader
)
# Store leader election in app state for health checks
app.state.leader_election = self.leader_election
self.logger.info("Leader election initialized",
is_leader=self.leader_election.is_leader,
instance_id=self.leader_election.instance_id)
except Exception as e:
self.logger.error("Failed to setup leader election, falling back to standalone mode",
error=str(e))
# Fallback: start scheduler anyway (for single-pod deployments)
from app.services.orchestrator_service import OrchestratorSchedulerService
self.scheduler_service = OrchestratorSchedulerService(self.event_publisher, settings)
await self.scheduler_service.start()
app.state.scheduler_service = self.scheduler_service
self.logger.warning("Scheduler started in standalone mode (no leader election)")
async def on_shutdown(self, app: FastAPI):
"""Custom shutdown logic for orchestrator service"""
self.logger.info("Orchestrator Service shutting down...")
# Stop scheduler service
if getattr(app.state, 'scheduler_service', None):
await app.state.scheduler_service.stop()
# Stop leader election (this will also stop scheduler if we're the leader)
if self.leader_election:
await self.leader_election.stop()
self.logger.info("Leader election stopped")
# Stop scheduler service if still running
if self.scheduler_service:
await self.scheduler_service.stop()
self.logger.info("Orchestrator scheduler service stopped")

View File

@@ -1,12 +1,12 @@
"""
Delivery Tracking Service - With Leader Election
Tracks purchase order deliveries and generates appropriate alerts using EventPublisher:
- DELIVERY_ARRIVING_SOON: 2 hours before delivery window
- DELIVERY_OVERDUE: 30 minutes after expected delivery time
- STOCK_RECEIPT_INCOMPLETE: If delivery not marked as received
Runs as internal scheduler with leader election for horizontal scaling.
Domain ownership: Procurement service owns all PO and delivery tracking.
"""
@@ -30,7 +30,7 @@ class DeliveryTrackingService:
Monitors PO deliveries and generates time-based alerts using EventPublisher.
Uses APScheduler with leader election to run hourly checks.
Only one pod executes checks - leader election ensures no duplicate alerts.
"""
def __init__(self, event_publisher: UnifiedEventPublisher, config, database_manager=None):
@@ -38,46 +38,121 @@ class DeliveryTrackingService:
self.config = config
self.database_manager = database_manager
self.scheduler = AsyncIOScheduler()
self._leader_election = None
self._redis_client = None
self._scheduler_started = False
self.instance_id = str(uuid4())[:8] # Short instance ID for logging
async def start(self):
"""Start the delivery tracking scheduler"""
# Initialize and start scheduler if not already running
"""Start the delivery tracking scheduler with leader election"""
try:
# Initialize leader election
await self._setup_leader_election()
except Exception as e:
logger.error("Failed to setup leader election, starting in standalone mode",
error=str(e))
# Fallback: start scheduler without leader election
await self._start_scheduler()
async def _setup_leader_election(self):
"""Setup Redis-based leader election for horizontal scaling"""
from shared.leader_election import LeaderElectionService
import redis.asyncio as redis
# Build Redis URL from config
redis_url = getattr(self.config, 'REDIS_URL', None)
if not redis_url:
redis_password = getattr(self.config, 'REDIS_PASSWORD', '')
redis_host = getattr(self.config, 'REDIS_HOST', 'localhost')
redis_port = getattr(self.config, 'REDIS_PORT', 6379)
redis_db = getattr(self.config, 'REDIS_DB', 0)
redis_url = f"redis://:{redis_password}@{redis_host}:{redis_port}/{redis_db}"
self._redis_client = redis.from_url(redis_url, decode_responses=False)
await self._redis_client.ping()
# Create leader election service
self._leader_election = LeaderElectionService(
self._redis_client,
service_name="procurement-delivery-tracking"
)
# Start leader election with callbacks
await self._leader_election.start(
on_become_leader=self._on_become_leader,
on_lose_leader=self._on_lose_leader
)
logger.info("Leader election initialized for delivery tracking",
is_leader=self._leader_election.is_leader,
instance_id=self.instance_id)
async def _on_become_leader(self):
"""Called when this instance becomes the leader"""
logger.info("Became leader for delivery tracking - starting scheduler",
instance_id=self.instance_id)
await self._start_scheduler()
async def _on_lose_leader(self):
"""Called when this instance loses leadership"""
logger.warning("Lost leadership for delivery tracking - stopping scheduler",
instance_id=self.instance_id)
await self._stop_scheduler()
async def _start_scheduler(self):
"""Start the APScheduler with delivery tracking jobs"""
if self._scheduler_started:
logger.debug("Scheduler already started", instance_id=self.instance_id)
return
if not self.scheduler.running:
# Add hourly job to check deliveries
self.scheduler.add_job(
self._check_all_tenants,
trigger=CronTrigger(minute=30), # Run every hour at :30
id='hourly_delivery_check',
name='Hourly Delivery Tracking',
replace_existing=True,
max_instances=1, # Ensure no overlapping runs
coalesce=True # Combine missed runs
)
self.scheduler.start()
self._scheduler_started = True
# Log next run time
next_run = self.scheduler.get_job('hourly_delivery_check').next_run_time
logger.info("Delivery tracking scheduler started",
instance_id=self.instance_id,
next_run=next_run.isoformat() if next_run else None)
async def _stop_scheduler(self):
"""Stop the APScheduler"""
if not self._scheduler_started:
return
if self.scheduler.running:
self.scheduler.shutdown(wait=False)
self._scheduler_started = False
logger.info("Delivery tracking scheduler stopped", instance_id=self.instance_id)
async def stop(self):
"""Stop the scheduler and release leader lock"""
if self.scheduler.running:
self.scheduler.shutdown(wait=True) # Graceful shutdown
logger.info("Delivery tracking scheduler stopped", instance_id=self.instance_id)
else:
logger.info("Delivery tracking scheduler already stopped", instance_id=self.instance_id)
"""Stop the scheduler and leader election"""
# Stop leader election first
if self._leader_election:
await self._leader_election.stop()
logger.info("Leader election stopped", instance_id=self.instance_id)
# Stop scheduler
await self._stop_scheduler()
# Close Redis
if self._redis_client:
await self._redis_client.close()
@property
def is_leader(self) -> bool:
"""Check if this instance is the leader"""
return self._leader_election.is_leader if self._leader_election else True
async def _check_all_tenants(self):
"""

View File

@@ -46,6 +46,9 @@ class TrainingService(StandardFastAPIService):
await setup_messaging()
self.logger.info("Messaging setup completed")
# Initialize Redis pub/sub for cross-pod WebSocket broadcasting
await self._setup_websocket_redis()
# Set up WebSocket event consumer (listens to RabbitMQ and broadcasts to WebSockets)
success = await setup_websocket_event_consumer()
if success:
@@ -53,8 +56,44 @@ class TrainingService(StandardFastAPIService):
else:
self.logger.warning("WebSocket event consumer setup failed")
async def _setup_websocket_redis(self):
"""
Initialize Redis pub/sub for WebSocket cross-pod broadcasting.
CRITICAL FOR HORIZONTAL SCALING:
Without this, WebSocket clients on Pod A won't receive events
from training jobs running on Pod B.
"""
try:
from app.websocket.manager import websocket_manager
from app.core.config import settings
redis_url = settings.REDIS_URL
success = await websocket_manager.initialize_redis(redis_url)
if success:
self.logger.info("WebSocket Redis pub/sub initialized for horizontal scaling")
else:
self.logger.warning(
"WebSocket Redis pub/sub failed to initialize. "
"WebSocket events will only be delivered to local connections."
)
except Exception as e:
self.logger.error("Failed to setup WebSocket Redis pub/sub",
error=str(e))
# Don't fail startup - WebSockets will work locally without Redis
async def _cleanup_messaging(self):
"""Cleanup messaging for training service"""
# Shutdown WebSocket Redis pub/sub
try:
from app.websocket.manager import websocket_manager
await websocket_manager.shutdown()
self.logger.info("WebSocket Redis pub/sub shutdown completed")
except Exception as e:
self.logger.warning("Error shutting down WebSocket Redis", error=str(e))
await cleanup_websocket_consumers()
await cleanup_messaging()
@@ -78,13 +117,49 @@ class TrainingService(StandardFastAPIService):
async def on_startup(self, app: FastAPI):
"""Custom startup logic including migration verification"""
await self.verify_migrations()
# Initialize system metrics collection
system_metrics = SystemMetricsCollector("training")
self.logger.info("System metrics collection started")
# Recover stale jobs from previous pod crashes
# This is important for horizontal scaling - jobs may be left in 'running'
# state if a pod crashes. We mark them as failed so they can be retried.
await self._recover_stale_jobs()
self.logger.info("Training service startup completed")
async def _recover_stale_jobs(self):
"""
Recover stale training jobs on startup.
When a pod crashes mid-training, jobs are left in 'running' or 'pending' state.
This method finds jobs that haven't been updated in a while and marks them
as failed so users can retry them.
"""
try:
from app.repositories.training_log_repository import TrainingLogRepository
async with self.database_manager.get_session() as session:
log_repo = TrainingLogRepository(session)
# Recover jobs that haven't been updated in 60 minutes
# This is conservative - most training jobs complete within 30 minutes
recovered = await log_repo.recover_stale_jobs(stale_threshold_minutes=60)
if recovered:
self.logger.warning(
"Recovered stale training jobs on startup",
recovered_count=len(recovered),
job_ids=[j.job_id for j in recovered]
)
else:
self.logger.info("No stale training jobs to recover")
except Exception as e:
# Don't fail startup if recovery fails - just log the error
self.logger.error("Failed to recover stale jobs on startup", error=str(e))
async def on_shutdown(self, app: FastAPI):
"""Custom shutdown logic for training service"""
await cleanup_training_database()

View File

@@ -342,4 +342,166 @@ class TrainingLogRepository(TrainingBaseRepository):
logger.error("Failed to get start time",
job_id=job_id,
error=str(e))
return None
async def create_job_atomic(
self,
job_id: str,
tenant_id: str,
config: Dict[str, Any] = None
) -> tuple[Optional[ModelTrainingLog], bool]:
"""
Atomically create a training job, respecting the unique constraint.
This method relies on the partial unique index to handle race conditions
when multiple pods try to create a job for the same tenant simultaneously:
if another pod wins the race, the constraint violation is caught and the
existing active job is returned instead.
The database constraint (idx_unique_active_training_per_tenant) ensures
only one active job per tenant can exist.
Args:
job_id: Unique job identifier
tenant_id: Tenant identifier
config: Optional job configuration
Returns:
Tuple of (job, created):
- If created: (new_job, True)
- If conflict (existing active job): (existing_job, False)
- If error: raises DatabaseError
"""
try:
# First, try to find an existing active job
existing = await self.get_active_jobs(tenant_id=tenant_id)
pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)
if existing or pending:
# Return existing job
active_job = existing[0] if existing else pending[0]
logger.info("Found existing active job, skipping creation",
existing_job_id=active_job.job_id,
tenant_id=tenant_id,
requested_job_id=job_id)
return (active_job, False)
# Try to create the new job
# If another pod created one in the meantime, the unique constraint will prevent this
log_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "pending",
"progress": 0,
"current_step": "initializing",
"config": config or {}
}
try:
new_job = await self.create_training_log(log_data)
await self.session.commit()
logger.info("Created new training job atomically",
job_id=job_id,
tenant_id=tenant_id)
return (new_job, True)
except Exception as create_error:
error_str = str(create_error).lower()
# Check if this is a unique constraint violation
if "unique" in error_str or "duplicate" in error_str or "constraint" in error_str:
await self.session.rollback()
# Another pod created a job, fetch it
logger.info("Unique constraint hit, fetching existing job",
tenant_id=tenant_id,
requested_job_id=job_id)
existing = await self.get_active_jobs(tenant_id=tenant_id)
pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)
if existing or pending:
active_job = existing[0] if existing else pending[0]
return (active_job, False)
# If still no job found, something went wrong
raise DatabaseError(f"Constraint violation but no active job found: {create_error}")
else:
raise
except DatabaseError:
raise
except Exception as e:
logger.error("Failed to create job atomically",
job_id=job_id,
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to create training job atomically: {str(e)}")
async def recover_stale_jobs(self, stale_threshold_minutes: int = 60) -> List[ModelTrainingLog]:
"""
Find and mark stale running jobs as failed.
This is used during service startup to clean up jobs that were
running when a pod crashed. With multiple replicas, only stale
jobs (not updated recently) should be marked as failed.
Args:
stale_threshold_minutes: Jobs not updated for this long are considered stale
Returns:
List of jobs that were marked as failed
"""
try:
stale_cutoff = datetime.now() - timedelta(minutes=stale_threshold_minutes)
# Find running jobs that haven't been updated recently
query = text("""
SELECT id, job_id, tenant_id, status, updated_at
FROM model_training_logs
WHERE status IN ('running', 'pending')
AND updated_at < :stale_cutoff
""")
result = await self.session.execute(query, {"stale_cutoff": stale_cutoff})
stale_jobs = result.fetchall()
recovered_jobs = []
for row in stale_jobs:
try:
# Mark as failed
update_query = text("""
UPDATE model_training_logs
SET status = 'failed',
error_message = :error_msg,
end_time = :end_time,
updated_at = :updated_at
WHERE id = :id AND status IN ('running', 'pending')
""")
await self.session.execute(update_query, {
"id": row.id,
"error_msg": f"Job recovered as failed - not updated since {row.updated_at.isoformat()}. Pod may have crashed.",
"end_time": datetime.now(),
"updated_at": datetime.now()
})
logger.warning("Recovered stale training job",
job_id=row.job_id,
tenant_id=str(row.tenant_id),
last_updated=row.updated_at.isoformat() if row.updated_at else "unknown")
# Fetch the updated job to return
job = await self.get_by_job_id(row.job_id)
if job:
recovered_jobs.append(job)
except Exception as job_error:
logger.error("Failed to recover individual stale job",
job_id=row.job_id,
error=str(job_error))
if recovered_jobs:
await self.session.commit()
logger.info("Stale job recovery completed",
recovered_count=len(recovered_jobs),
stale_threshold_minutes=stale_threshold_minutes)
return recovered_jobs
except Exception as e:
logger.error("Failed to recover stale jobs",
error=str(e))
await self.session.rollback()
return []
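A hypothetical caller sketch, to show how create_job_atomic is meant to be used from an API handler so that concurrent "start training" requests across pods resolve to a single job. The handler shape and the config payload below are assumptions, not code from this commit.

# Hypothetical usage of create_job_atomic (handler shape is an assumption).
import uuid

from app.repositories.training_log_repository import TrainingLogRepository


async def start_training(session, tenant_id: str) -> dict:
    repo = TrainingLogRepository(session)
    job, created = await repo.create_job_atomic(
        job_id=str(uuid.uuid4()),
        tenant_id=tenant_id,
        config={"trigger": "api"},  # illustrative config payload
    )
    # created=False means another pod (or an earlier request) already owns the
    # active job for this tenant; return that job instead of raising a conflict.
    return {"job_id": job.job_id, "status": job.status, "created": created}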

View File

@@ -1,10 +1,16 @@
"""
Distributed Locking Mechanisms
Prevents concurrent training jobs for the same product
HORIZONTAL SCALING FIX:
- Uses SHA256 for stable hash across all Python processes/pods
- Python's built-in hash() varies between processes due to hash randomization (Python 3.3+)
- This ensures all pods compute the same lock ID for the same lock name
"""
import asyncio
import time
import hashlib
from typing import Optional
import logging
from contextlib import asynccontextmanager
@@ -39,9 +45,20 @@ class DatabaseLock:
self.lock_id = self._hash_lock_name(lock_name)
def _hash_lock_name(self, name: str) -> int:
"""Convert lock name to integer ID for PostgreSQL advisory lock"""
# Use hash and modulo to get a positive 32-bit integer
return abs(hash(name)) % (2**31)
"""
Convert lock name to integer ID for PostgreSQL advisory lock.
CRITICAL: Uses SHA256 for stable hash across all Python processes/pods.
Python's built-in hash() varies between processes due to hash randomization
(PYTHONHASHSEED, enabled by default since Python 3.3), which would cause
different pods to compute different lock IDs for the same lock name,
defeating the purpose of distributed locking.
"""
# Use SHA256 for stable, cross-process hash
hash_bytes = hashlib.sha256(name.encode('utf-8')).digest()
# Take first 4 bytes and convert to positive 31-bit integer
# (PostgreSQL advisory locks use bigint, but we use 31-bit for safety)
return int.from_bytes(hash_bytes[:4], 'big') % (2**31)
@asynccontextmanager
async def acquire(self, session: AsyncSession):
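The body of acquire() is truncated by this hunk. As an illustration only, a session-level PostgreSQL advisory lock built on the SHA256-derived lock_id typically looks like the sketch below; the exact error handling in the real DatabaseLock may differ.

# Illustrative advisory-lock context manager (assumed shape, not the
# project's actual acquire() implementation).
from contextlib import asynccontextmanager

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession


@asynccontextmanager
async def advisory_lock(session: AsyncSession, lock_id: int):
    # pg_try_advisory_lock returns immediately instead of blocking.
    result = await session.execute(
        text("SELECT pg_try_advisory_lock(:lock_id)"), {"lock_id": lock_id}
    )
    if not result.scalar():
        raise RuntimeError(f"Lock {lock_id} is held by another pod")
    try:
        yield
    finally:
        # Always release; otherwise the lock lives for the whole session.
        await session.execute(
            text("SELECT pg_advisory_unlock(:lock_id)"), {"lock_id": lock_id}
        )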

View File

@@ -1,21 +1,39 @@
"""
WebSocket Connection Manager for Training Service
Manages WebSocket connections and broadcasts RabbitMQ events to connected clients
HORIZONTAL SCALING:
- Uses Redis pub/sub for cross-pod WebSocket broadcasting
- Each pod subscribes to a Redis channel and broadcasts to its local connections
- Events published to Redis are received by all pods, ensuring clients on any
pod receive events from training jobs running on any other pod
"""
import asyncio
import json
import os
from typing import Dict, Optional
from fastapi import WebSocket
import structlog
logger = structlog.get_logger()
# Redis pub/sub channel for WebSocket events
REDIS_WEBSOCKET_CHANNEL = "training:websocket:events"
class WebSocketConnectionManager:
"""
WebSocket connection manager with Redis pub/sub for horizontal scaling.
In a multi-pod deployment:
1. Events are published to Redis pub/sub (not just local broadcast)
2. Each pod subscribes to Redis and broadcasts to its local WebSocket connections
3. This ensures clients connected to any pod receive events from any pod
Flow:
- RabbitMQ event → Pod A receives → Pod A publishes to Redis
- Redis pub/sub → All pods receive → Each pod broadcasts to local WebSockets
"""
def __init__(self):
@@ -24,6 +42,121 @@ class WebSocketConnectionManager:
self._lock = asyncio.Lock()
# Store latest event for each job to provide initial state
self._latest_events: Dict[str, dict] = {}
# Redis client for pub/sub
self._redis: Optional[object] = None
self._pubsub: Optional[object] = None
self._subscriber_task: Optional[asyncio.Task] = None
self._running = False
self._instance_id = f"{os.environ.get('HOSTNAME', 'unknown')}:{os.getpid()}"
async def initialize_redis(self, redis_url: str) -> bool:
"""
Initialize Redis connection for cross-pod pub/sub.
Args:
redis_url: Redis connection URL
Returns:
True if successful, False otherwise
"""
try:
import redis.asyncio as redis_async
self._redis = redis_async.from_url(redis_url, decode_responses=True)
await self._redis.ping()
# Create pub/sub subscriber
self._pubsub = self._redis.pubsub()
await self._pubsub.subscribe(REDIS_WEBSOCKET_CHANNEL)
# Start subscriber task
self._running = True
self._subscriber_task = asyncio.create_task(self._redis_subscriber_loop())
logger.info("Redis pub/sub initialized for WebSocket broadcasting",
instance_id=self._instance_id,
channel=REDIS_WEBSOCKET_CHANNEL)
return True
except Exception as e:
logger.error("Failed to initialize Redis pub/sub",
error=str(e),
instance_id=self._instance_id)
return False
async def shutdown(self):
"""Shutdown Redis pub/sub connection"""
self._running = False
if self._subscriber_task:
self._subscriber_task.cancel()
try:
await self._subscriber_task
except asyncio.CancelledError:
pass
if self._pubsub:
await self._pubsub.unsubscribe(REDIS_WEBSOCKET_CHANNEL)
await self._pubsub.close()
if self._redis:
await self._redis.close()
logger.info("Redis pub/sub shutdown complete",
instance_id=self._instance_id)
async def _redis_subscriber_loop(self):
"""Background task to receive Redis pub/sub messages and broadcast locally"""
try:
while self._running:
try:
message = await self._pubsub.get_message(
ignore_subscribe_messages=True,
timeout=1.0
)
if message and message['type'] == 'message':
await self._handle_redis_message(message['data'])
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error in Redis subscriber loop",
error=str(e),
instance_id=self._instance_id)
await asyncio.sleep(1) # Backoff on error
except asyncio.CancelledError:
pass
logger.info("Redis subscriber loop stopped",
instance_id=self._instance_id)
async def _handle_redis_message(self, data: str):
"""Handle a message received from Redis pub/sub"""
try:
payload = json.loads(data)
job_id = payload.get('job_id')
message = payload.get('message')
source_instance = payload.get('source_instance')
if not job_id or not message:
return
# Log cross-pod message
if source_instance != self._instance_id:
logger.debug("Received cross-pod WebSocket event",
job_id=job_id,
source_instance=source_instance,
local_instance=self._instance_id)
# Broadcast to local WebSocket connections
await self._broadcast_local(job_id, message)
except json.JSONDecodeError as e:
logger.warning("Invalid JSON in Redis message", error=str(e))
except Exception as e:
logger.error("Error handling Redis message", error=str(e))
async def connect(self, job_id: str, websocket: WebSocket) -> None:
"""Register a new WebSocket connection for a job"""
@@ -50,7 +183,8 @@ class WebSocketConnectionManager:
logger.info("WebSocket connected",
job_id=job_id,
websocket_id=ws_id,
total_connections=len(self._connections[job_id]),
instance_id=self._instance_id)
async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
"""Remove a WebSocket connection"""
@@ -66,19 +200,56 @@ class WebSocketConnectionManager:
logger.info("WebSocket disconnected",
job_id=job_id,
websocket_id=ws_id,
remaining_connections=len(self._connections.get(job_id, {})),
instance_id=self._instance_id)
async def broadcast(self, job_id: str, message: dict) -> int:
"""
Broadcast a message to all connections for a specific job across ALL pods.
If Redis is configured, publishes to Redis pub/sub which then broadcasts
to all pods. Otherwise, falls back to local-only broadcast.
Returns the number of successful local broadcasts.
"""
# Store the latest event for this job to provide initial state to new connections
if message.get('type') != 'initial_state':
self._latest_events[job_id] = message
# If Redis is available, publish to Redis for cross-pod broadcast
if self._redis:
try:
payload = json.dumps({
'job_id': job_id,
'message': message,
'source_instance': self._instance_id
})
await self._redis.publish(REDIS_WEBSOCKET_CHANNEL, payload)
logger.debug("Published WebSocket event to Redis",
job_id=job_id,
message_type=message.get('type'),
instance_id=self._instance_id)
# Return 0 here because the actual broadcast happens via subscriber
# The count will be from _broadcast_local when the message is received
return 0
except Exception as e:
logger.warning("Failed to publish to Redis, falling back to local broadcast",
error=str(e),
job_id=job_id)
# Fall through to local broadcast
# Local-only broadcast (when Redis is not available)
return await self._broadcast_local(job_id, message)
async def _broadcast_local(self, job_id: str, message: dict) -> int:
"""
Broadcast a message to local WebSocket connections only.
This is called either directly (no Redis) or from Redis subscriber.
"""
if job_id not in self._connections:
logger.debug("No active connections for job", job_id=job_id)
logger.debug("No active local connections for job",
job_id=job_id,
instance_id=self._instance_id)
return 0
connections = list(self._connections[job_id].values())
@@ -103,18 +274,27 @@ class WebSocketConnectionManager:
self._connections[job_id].pop(ws_id, None)
if successful_sends > 0:
logger.info("Broadcasted message to WebSocket clients",
logger.info("Broadcasted message to local WebSocket clients",
job_id=job_id,
message_type=message.get('type'),
successful_sends=successful_sends,
failed_sends=len(failed_websockets),
instance_id=self._instance_id)
return successful_sends
def get_connection_count(self, job_id: str) -> int:
"""Get the number of active connections for a job"""
"""Get the number of active local connections for a job"""
return len(self._connections.get(job_id, {}))
def get_total_connection_count(self) -> int:
"""Get total number of active connections across all jobs"""
return sum(len(conns) for conns in self._connections.values())
def is_redis_enabled(self) -> bool:
"""Check if Redis pub/sub is enabled"""
return self._redis is not None and self._running
# Global singleton instance
websocket_manager = WebSocketConnectionManager()
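For context, a WebSocket route wiring a client into this manager might look like the sketch below. The route path is hypothetical, and it is assumed here that the endpoint accepts the socket before calling connect(); the training service's actual endpoint may handle that differently.

# Hypothetical endpoint using websocket_manager (route path and accept()
# placement are assumptions, not taken from this commit).
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from app.websocket.manager import websocket_manager

router = APIRouter()


@router.websocket("/ws/training/{job_id}")
async def training_events(websocket: WebSocket, job_id: str):
    await websocket.accept()
    await websocket_manager.connect(job_id, websocket)
    try:
        while True:
            # Keep the connection open; events are pushed by the manager,
            # locally or via Redis pub/sub when running with multiple pods.
            await websocket.receive_text()
    except WebSocketDisconnect:
        await websocket_manager.disconnect(job_id, websocket)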

View File

@@ -0,0 +1,60 @@
"""Add horizontal scaling constraints for multi-pod deployment
Revision ID: add_horizontal_scaling
Revises: 26a665cd5348
Create Date: 2025-01-18
This migration adds database-level constraints to prevent race conditions
when running multiple training service pods:
1. Partial unique index on model_training_logs to prevent duplicate active jobs per tenant
2. Index to speed up active job lookups
3. Index to speed up stale-job recovery queries (status, updated_at)
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'add_horizontal_scaling'
down_revision: Union[str, None] = '26a665cd5348'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add partial unique index to prevent duplicate active training jobs per tenant
# This ensures only ONE job can be in 'pending' or 'running' status per tenant at a time
# The constraint is enforced at the database level, preventing race conditions
# between multiple pods checking and creating jobs simultaneously
op.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS idx_unique_active_training_per_tenant
ON model_training_logs (tenant_id)
WHERE status IN ('pending', 'running')
""")
# Add index to speed up active job lookups (used by deduplication check)
op.create_index(
'idx_training_logs_tenant_status',
'model_training_logs',
['tenant_id', 'status'],
unique=False,
if_not_exists=True
)
# Add index for job recovery queries (find stale running jobs)
op.create_index(
'idx_training_logs_status_updated',
'model_training_logs',
['status', 'updated_at'],
unique=False,
if_not_exists=True
)
def downgrade() -> None:
# Remove the indexes in reverse order
op.execute("DROP INDEX IF EXISTS idx_training_logs_status_updated")
op.execute("DROP INDEX IF EXISTS idx_training_logs_tenant_status")
op.execute("DROP INDEX IF EXISTS idx_unique_active_training_per_tenant")