Add CI/CD and fix multi-pod issues
@@ -18,6 +18,28 @@ class OrchestratorService(StandardFastAPIService):

    expected_migration_version = "001_initial_schema"

    def __init__(self):
        # Define expected database tables for health checks
        orchestrator_expected_tables = [
            'orchestration_runs'
        ]

        self.rabbitmq_client = None
        self.event_publisher = None
        self.leader_election = None
        self.scheduler_service = None

        super().__init__(
            service_name="orchestrator-service",
            app_name=settings.APP_NAME,
            description=settings.DESCRIPTION,
            version=settings.VERSION,
            api_prefix="",  # Empty because RouteBuilder already includes /api/v1
            database_manager=database_manager,
            expected_tables=orchestrator_expected_tables,
            enable_messaging=True  # Enable RabbitMQ for event publishing
        )

    async def verify_migrations(self):
        """Verify database schema matches the latest migrations"""
        try:

@@ -32,26 +54,6 @@ class OrchestratorService(StandardFastAPIService):

            self.logger.error(f"Migration verification failed: {e}")
            raise

    def __init__(self):
        # Define expected database tables for health checks
        orchestrator_expected_tables = [
            'orchestration_runs'
        ]

        self.rabbitmq_client = None
        self.event_publisher = None

        super().__init__(
            service_name="orchestrator-service",
            app_name=settings.APP_NAME,
            description=settings.DESCRIPTION,
            version=settings.VERSION,
            api_prefix="",  # Empty because RouteBuilder already includes /api/v1
            database_manager=database_manager,
            expected_tables=orchestrator_expected_tables,
            enable_messaging=True  # Enable RabbitMQ for event publishing
        )

    async def _setup_messaging(self):
        """Setup messaging for orchestrator service"""
        from shared.messaging import UnifiedEventPublisher, RabbitMQClient

@@ -84,22 +86,91 @@ class OrchestratorService(StandardFastAPIService):

        self.logger.info("Orchestrator Service starting up...")

        # Initialize orchestrator scheduler service with EventPublisher
        from app.services.orchestrator_service import OrchestratorSchedulerService
        scheduler_service = OrchestratorSchedulerService(self.event_publisher, settings)
        await scheduler_service.start()
        app.state.scheduler_service = scheduler_service
        self.logger.info("Orchestrator scheduler service started")
        # Initialize leader election for horizontal scaling
        # Only the leader pod will run the scheduler
        await self._setup_leader_election(app)

        # REMOVED: Delivery tracking service - moved to procurement service (domain ownership)

    async def _setup_leader_election(self, app: FastAPI):
        """
        Setup leader election for scheduler.

        CRITICAL FOR HORIZONTAL SCALING:
        Without leader election, each pod would run the same scheduled jobs,
        causing duplicate forecasts, production schedules, and database contention.
        """
        from shared.leader_election import LeaderElectionService
        import redis.asyncio as redis

        try:
            # Create Redis connection for leader election
            redis_url = f"redis://:{settings.REDIS_PASSWORD}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}"
            if settings.REDIS_TLS_ENABLED.lower() == "true":
                redis_url = redis_url.replace("redis://", "rediss://")

            redis_client = redis.from_url(redis_url, decode_responses=False)
            await redis_client.ping()

            # Use shared leader election service
            self.leader_election = LeaderElectionService(
                redis_client,
                service_name="orchestrator"
            )

            # Define callbacks for leader state changes
            async def on_become_leader():
                self.logger.info("This pod became the leader - starting scheduler")
                from app.services.orchestrator_service import OrchestratorSchedulerService
                self.scheduler_service = OrchestratorSchedulerService(self.event_publisher, settings)
                await self.scheduler_service.start()
                app.state.scheduler_service = self.scheduler_service
                self.logger.info("Orchestrator scheduler service started (leader only)")

            async def on_lose_leader():
                self.logger.warning("This pod lost leadership - stopping scheduler")
                if self.scheduler_service:
                    await self.scheduler_service.stop()
                    self.scheduler_service = None
                if hasattr(app.state, 'scheduler_service'):
                    app.state.scheduler_service = None
                self.logger.info("Orchestrator scheduler service stopped (no longer leader)")

            # Start leader election
            await self.leader_election.start(
                on_become_leader=on_become_leader,
                on_lose_leader=on_lose_leader
            )

            # Store leader election in app state for health checks
            app.state.leader_election = self.leader_election

            self.logger.info("Leader election initialized",
                             is_leader=self.leader_election.is_leader,
                             instance_id=self.leader_election.instance_id)

        except Exception as e:
            self.logger.error("Failed to setup leader election, falling back to standalone mode",
                              error=str(e))
            # Fallback: start scheduler anyway (for single-pod deployments)
            from app.services.orchestrator_service import OrchestratorSchedulerService
            self.scheduler_service = OrchestratorSchedulerService(self.event_publisher, settings)
            await self.scheduler_service.start()
            app.state.scheduler_service = self.scheduler_service
            self.logger.warning("Scheduler started in standalone mode (no leader election)")

    async def on_shutdown(self, app: FastAPI):
        """Custom shutdown logic for orchestrator service"""
        self.logger.info("Orchestrator Service shutting down...")

        # Stop scheduler service
        if hasattr(app.state, 'scheduler_service'):
            await app.state.scheduler_service.stop()
        # Stop leader election (this will also stop scheduler if we're the leader)
        if self.leader_election:
            await self.leader_election.stop()
            self.logger.info("Leader election stopped")

        # Stop scheduler service if still running
        if self.scheduler_service:
            await self.scheduler_service.stop()
            self.logger.info("Orchestrator scheduler service stopped")

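The orchestrator (and the delivery tracker further down) relies on `shared.leader_election.LeaderElectionService`, whose implementation is not part of this commit. A minimal sketch of the pattern the callbacks above assume, a single Redis key acquired with SET NX and a TTL that the current leader keeps renewing; the class name, parameters, and timings here are illustrative, not the shared library's actual API.

```python
# Hypothetical sketch, not the real shared.leader_election implementation.
import asyncio
import uuid

import redis.asyncio as redis


class SimpleLeaderElection:
    """Single-key leader lock: SET NX with a TTL, renewed while held."""

    def __init__(self, redis_client: redis.Redis, service_name: str,
                 ttl_seconds: int = 15, renew_interval: float = 5.0):
        self._redis = redis_client
        self._key = f"leader:{service_name}"
        self.instance_id = str(uuid.uuid4())
        self._ttl = ttl_seconds
        self._renew_interval = renew_interval
        self.is_leader = False
        self._task = None

    async def start(self, on_become_leader, on_lose_leader):
        self._task = asyncio.create_task(self._loop(on_become_leader, on_lose_leader))

    async def stop(self):
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
        if self.is_leader:
            # Release the lock only if we still own it
            owner = await self._redis.get(self._key)
            if owner == self.instance_id.encode():
                await self._redis.delete(self._key)
            self.is_leader = False

    async def _loop(self, on_become_leader, on_lose_leader):
        while True:
            # SET NX acquires the lock only when no other pod holds it;
            # EX makes the lock expire if the leader pod dies.
            acquired = await self._redis.set(
                self._key, self.instance_id, nx=True, ex=self._ttl
            )
            if acquired and not self.is_leader:
                self.is_leader = True
                await on_become_leader()
            elif self.is_leader:
                owner = await self._redis.get(self._key)
                if owner == self.instance_id.encode():
                    await self._redis.expire(self._key, self._ttl)  # renew lease
                else:
                    self.is_leader = False
                    await on_lose_leader()
            await asyncio.sleep(self._renew_interval)
```

A pod that loses its Redis connection simply stops renewing; the key expires and another pod wins the next SET NX, which is why the callbacks above must be able to start and stop the scheduler repeatedly.
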
@@ -1,12 +1,12 @@

"""
Delivery Tracking Service - Simplified
Delivery Tracking Service - With Leader Election

Tracks purchase order deliveries and generates appropriate alerts using EventPublisher:
- DELIVERY_ARRIVING_SOON: 2 hours before delivery window
- DELIVERY_OVERDUE: 30 minutes after expected delivery time
- STOCK_RECEIPT_INCOMPLETE: If delivery not marked as received

Runs as internal scheduler with leader election.
Runs as internal scheduler with leader election for horizontal scaling.
Domain ownership: Procurement service owns all PO and delivery tracking.
"""

@@ -30,7 +30,7 @@ class DeliveryTrackingService:

    Monitors PO deliveries and generates time-based alerts using EventPublisher.

    Uses APScheduler with leader election to run hourly checks.
    Only one pod executes checks (others skip if not leader).
    Only one pod executes checks - leader election ensures no duplicate alerts.
    """

    def __init__(self, event_publisher: UnifiedEventPublisher, config, database_manager=None):

@@ -38,46 +38,121 @@ class DeliveryTrackingService:

        self.config = config
        self.database_manager = database_manager
        self.scheduler = AsyncIOScheduler()
        self.is_leader = False
        self._leader_election = None
        self._redis_client = None
        self._scheduler_started = False
        self.instance_id = str(uuid4())[:8]  # Short instance ID for logging

    async def start(self):
        """Start the delivery tracking scheduler"""
        # Initialize and start scheduler if not already running
        """Start the delivery tracking scheduler with leader election"""
        try:
            # Initialize leader election
            await self._setup_leader_election()
        except Exception as e:
            logger.error("Failed to setup leader election, starting in standalone mode",
                         error=str(e))
            # Fallback: start scheduler without leader election
            await self._start_scheduler()

    async def _setup_leader_election(self):
        """Setup Redis-based leader election for horizontal scaling"""
        from shared.leader_election import LeaderElectionService
        import redis.asyncio as redis

        # Build Redis URL from config
        redis_url = getattr(self.config, 'REDIS_URL', None)
        if not redis_url:
            redis_password = getattr(self.config, 'REDIS_PASSWORD', '')
            redis_host = getattr(self.config, 'REDIS_HOST', 'localhost')
            redis_port = getattr(self.config, 'REDIS_PORT', 6379)
            redis_db = getattr(self.config, 'REDIS_DB', 0)
            redis_url = f"redis://:{redis_password}@{redis_host}:{redis_port}/{redis_db}"

        self._redis_client = redis.from_url(redis_url, decode_responses=False)
        await self._redis_client.ping()

        # Create leader election service
        self._leader_election = LeaderElectionService(
            self._redis_client,
            service_name="procurement-delivery-tracking"
        )

        # Start leader election with callbacks
        await self._leader_election.start(
            on_become_leader=self._on_become_leader,
            on_lose_leader=self._on_lose_leader
        )

        logger.info("Leader election initialized for delivery tracking",
                    is_leader=self._leader_election.is_leader,
                    instance_id=self.instance_id)

    async def _on_become_leader(self):
        """Called when this instance becomes the leader"""
        logger.info("Became leader for delivery tracking - starting scheduler",
                    instance_id=self.instance_id)
        await self._start_scheduler()

    async def _on_lose_leader(self):
        """Called when this instance loses leadership"""
        logger.warning("Lost leadership for delivery tracking - stopping scheduler",
                       instance_id=self.instance_id)
        await self._stop_scheduler()

    async def _start_scheduler(self):
        """Start the APScheduler with delivery tracking jobs"""
        if self._scheduler_started:
            logger.debug("Scheduler already started", instance_id=self.instance_id)
            return

        if not self.scheduler.running:
            # Add hourly job to check deliveries
            self.scheduler.add_job(
                self._check_all_tenants,
                trigger=CronTrigger(minute=30),  # Run every hour at :30 (00:30, 01:30, 02:30, etc.)
                trigger=CronTrigger(minute=30),  # Run every hour at :30
                id='hourly_delivery_check',
                name='Hourly Delivery Tracking',
                replace_existing=True,
                max_instances=1,  # Ensure no overlapping runs
                coalesce=True  # Combine missed runs
                max_instances=1,
                coalesce=True
            )

            self.scheduler.start()
            self._scheduler_started = True

            # Log next run time
            next_run = self.scheduler.get_job('hourly_delivery_check').next_run_time
            logger.info(
                "Delivery tracking scheduler started with hourly checks",
                instance_id=self.instance_id,
                next_run=next_run.isoformat() if next_run else None
            )
        else:
            logger.info(
                "Delivery tracking scheduler already running",
                instance_id=self.instance_id
            )
        logger.info("Delivery tracking scheduler started",
                    instance_id=self.instance_id,
                    next_run=next_run.isoformat() if next_run else None)

    async def _stop_scheduler(self):
        """Stop the APScheduler"""
        if not self._scheduler_started:
            return

        if self.scheduler.running:
            self.scheduler.shutdown(wait=False)
            self._scheduler_started = False
            logger.info("Delivery tracking scheduler stopped", instance_id=self.instance_id)

    async def stop(self):
        """Stop the scheduler and release leader lock"""
        if self.scheduler.running:
            self.scheduler.shutdown(wait=True)  # Graceful shutdown
            logger.info("Delivery tracking scheduler stopped", instance_id=self.instance_id)
        else:
            logger.info("Delivery tracking scheduler already stopped", instance_id=self.instance_id)
        """Stop the scheduler and leader election"""
        # Stop leader election first
        if self._leader_election:
            await self._leader_election.stop()
            logger.info("Leader election stopped", instance_id=self.instance_id)

        # Stop scheduler
        await self._stop_scheduler()

        # Close Redis
        if self._redis_client:
            await self._redis_client.close()

    @property
    def is_leader(self) -> bool:
        """Check if this instance is the leader"""
        return self._leader_election.is_leader if self._leader_election else True

    async def _check_all_tenants(self):
        """

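The diff does not show where the procurement service constructs and starts DeliveryTrackingService. A hedged sketch of the expected wiring, with stand-in settings and event publisher; the import path, config object, and publisher are assumptions, not code from this commit.

```python
# Hypothetical wiring sketch for the procurement service startup.
from contextlib import asynccontextmanager
from types import SimpleNamespace

from fastapi import FastAPI

# from app.services.delivery_tracking import DeliveryTrackingService  # module path assumed

# Stand-ins for the real dependencies (illustrative only).
settings = SimpleNamespace(REDIS_URL="redis://localhost:6379/0")
event_publisher = None  # would be a UnifiedEventPublisher in the real service


@asynccontextmanager
async def lifespan(app: FastAPI):
    tracker = DeliveryTrackingService(event_publisher, settings)
    await tracker.start()     # joins leader election; only the leader runs the hourly job
    app.state.delivery_tracker = tracker
    try:
        yield
    finally:
        await tracker.stop()  # stops leader election, the scheduler, and the Redis client


app = FastAPI(lifespan=lifespan)
```
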
@@ -46,6 +46,9 @@ class TrainingService(StandardFastAPIService):

        await setup_messaging()
        self.logger.info("Messaging setup completed")

        # Initialize Redis pub/sub for cross-pod WebSocket broadcasting
        await self._setup_websocket_redis()

        # Set up WebSocket event consumer (listens to RabbitMQ and broadcasts to WebSockets)
        success = await setup_websocket_event_consumer()
        if success:

@@ -53,8 +56,44 @@ class TrainingService(StandardFastAPIService):

        else:
            self.logger.warning("WebSocket event consumer setup failed")

    async def _setup_websocket_redis(self):
        """
        Initialize Redis pub/sub for WebSocket cross-pod broadcasting.

        CRITICAL FOR HORIZONTAL SCALING:
        Without this, WebSocket clients on Pod A won't receive events
        from training jobs running on Pod B.
        """
        try:
            from app.websocket.manager import websocket_manager
            from app.core.config import settings

            redis_url = settings.REDIS_URL
            success = await websocket_manager.initialize_redis(redis_url)

            if success:
                self.logger.info("WebSocket Redis pub/sub initialized for horizontal scaling")
            else:
                self.logger.warning(
                    "WebSocket Redis pub/sub failed to initialize. "
                    "WebSocket events will only be delivered to local connections."
                )

        except Exception as e:
            self.logger.error("Failed to setup WebSocket Redis pub/sub",
                              error=str(e))
            # Don't fail startup - WebSockets will work locally without Redis

    async def _cleanup_messaging(self):
        """Cleanup messaging for training service"""
        # Shutdown WebSocket Redis pub/sub
        try:
            from app.websocket.manager import websocket_manager
            await websocket_manager.shutdown()
            self.logger.info("WebSocket Redis pub/sub shutdown completed")
        except Exception as e:
            self.logger.warning("Error shutting down WebSocket Redis", error=str(e))

        await cleanup_websocket_consumers()
        await cleanup_messaging()

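The consumer that bridges RabbitMQ events to WebSocket clients (`setup_websocket_event_consumer`) is referenced but not shown. A hypothetical handler illustrating how such a consumer would hand events to `websocket_manager.broadcast`, which, per the manager changes further down, publishes to Redis so every pod can relay the event to its own connections.

```python
# Hypothetical sketch of the RabbitMQ -> WebSocket bridge; the real
# setup_websocket_event_consumer() wiring is not part of this diff.
import json

from app.websocket.manager import websocket_manager


async def on_training_event(raw_body: bytes) -> None:
    """Handle one RabbitMQ message and fan it out to WebSocket clients."""
    event = json.loads(raw_body)
    job_id = event.get("job_id")
    if not job_id:
        return
    # broadcast() publishes to Redis when it is configured, so clients
    # connected to *other* pods also receive this event; otherwise it
    # falls back to the pod-local connections only.
    await websocket_manager.broadcast(job_id, {
        "type": event.get("type", "training_progress"),
        "data": event,
    })
```
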
@@ -78,13 +117,49 @@ class TrainingService(StandardFastAPIService):

    async def on_startup(self, app: FastAPI):
        """Custom startup logic including migration verification"""
        await self.verify_migrations()

        # Initialize system metrics collection
        system_metrics = SystemMetricsCollector("training")
        self.logger.info("System metrics collection started")

        # Recover stale jobs from previous pod crashes
        # This is important for horizontal scaling - jobs may be left in 'running'
        # state if a pod crashes. We mark them as failed so they can be retried.
        await self._recover_stale_jobs()

        self.logger.info("Training service startup completed")

    async def _recover_stale_jobs(self):
        """
        Recover stale training jobs on startup.

        When a pod crashes mid-training, jobs are left in 'running' or 'pending' state.
        This method finds jobs that haven't been updated in a while and marks them
        as failed so users can retry them.
        """
        try:
            from app.repositories.training_log_repository import TrainingLogRepository

            async with self.database_manager.get_session() as session:
                log_repo = TrainingLogRepository(session)

                # Recover jobs that haven't been updated in 60 minutes
                # This is conservative - most training jobs complete within 30 minutes
                recovered = await log_repo.recover_stale_jobs(stale_threshold_minutes=60)

                if recovered:
                    self.logger.warning(
                        "Recovered stale training jobs on startup",
                        recovered_count=len(recovered),
                        job_ids=[j.job_id for j in recovered]
                    )
                else:
                    self.logger.info("No stale training jobs to recover")

        except Exception as e:
            # Don't fail startup if recovery fails - just log the error
            self.logger.error("Failed to recover stale jobs on startup", error=str(e))

    async def on_shutdown(self, app: FastAPI):
        """Custom shutdown logic for training service"""
        await cleanup_training_database()

@@ -342,4 +342,166 @@ class TrainingLogRepository(TrainingBaseRepository):

            logger.error("Failed to get start time",
                         job_id=job_id,
                         error=str(e))
            return None
            return None

    async def create_job_atomic(
        self,
        job_id: str,
        tenant_id: str,
        config: Dict[str, Any] = None
    ) -> tuple[Optional[ModelTrainingLog], bool]:
        """
        Atomically create a training job, respecting the unique constraint.

        This method relies on the database's unique constraint to handle race conditions
        when multiple pods try to create a job for the same tenant simultaneously.
        The database constraint (idx_unique_active_training_per_tenant) ensures
        only one active job per tenant can exist.

        Args:
            job_id: Unique job identifier
            tenant_id: Tenant identifier
            config: Optional job configuration

        Returns:
            Tuple of (job, created):
            - If created: (new_job, True)
            - If conflict (existing active job): (existing_job, False)
            - If error: raises DatabaseError
        """
        try:
            # First, try to find an existing active job
            existing = await self.get_active_jobs(tenant_id=tenant_id)
            pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)

            if existing or pending:
                # Return existing job
                active_job = existing[0] if existing else pending[0]
                logger.info("Found existing active job, skipping creation",
                            existing_job_id=active_job.job_id,
                            tenant_id=tenant_id,
                            requested_job_id=job_id)
                return (active_job, False)

            # Try to create the new job
            # If another pod created one in the meantime, the unique constraint will prevent this
            log_data = {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "status": "pending",
                "progress": 0,
                "current_step": "initializing",
                "config": config or {}
            }

            try:
                new_job = await self.create_training_log(log_data)
                await self.session.commit()
                logger.info("Created new training job atomically",
                            job_id=job_id,
                            tenant_id=tenant_id)
                return (new_job, True)
            except Exception as create_error:
                error_str = str(create_error).lower()
                # Check if this is a unique constraint violation
                if "unique" in error_str or "duplicate" in error_str or "constraint" in error_str:
                    await self.session.rollback()
                    # Another pod created a job, fetch it
                    logger.info("Unique constraint hit, fetching existing job",
                                tenant_id=tenant_id,
                                requested_job_id=job_id)
                    existing = await self.get_active_jobs(tenant_id=tenant_id)
                    pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)
                    if existing or pending:
                        active_job = existing[0] if existing else pending[0]
                        return (active_job, False)
                    # If still no job found, something went wrong
                    raise DatabaseError(f"Constraint violation but no active job found: {create_error}")
                else:
                    raise

        except DatabaseError:
            raise
        except Exception as e:
            logger.error("Failed to create job atomically",
                         job_id=job_id,
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to create training job atomically: {str(e)}")

    async def recover_stale_jobs(self, stale_threshold_minutes: int = 60) -> List[ModelTrainingLog]:
        """
        Find and mark stale running jobs as failed.

        This is used during service startup to clean up jobs that were
        running when a pod crashed. With multiple replicas, only stale
        jobs (not updated recently) should be marked as failed.

        Args:
            stale_threshold_minutes: Jobs not updated for this long are considered stale

        Returns:
            List of jobs that were marked as failed
        """
        try:
            stale_cutoff = datetime.now() - timedelta(minutes=stale_threshold_minutes)

            # Find running jobs that haven't been updated recently
            query = text("""
                SELECT id, job_id, tenant_id, status, updated_at
                FROM model_training_logs
                WHERE status IN ('running', 'pending')
                AND updated_at < :stale_cutoff
            """)

            result = await self.session.execute(query, {"stale_cutoff": stale_cutoff})
            stale_jobs = result.fetchall()

            recovered_jobs = []
            for row in stale_jobs:
                try:
                    # Mark as failed
                    update_query = text("""
                        UPDATE model_training_logs
                        SET status = 'failed',
                            error_message = :error_msg,
                            end_time = :end_time,
                            updated_at = :updated_at
                        WHERE id = :id AND status IN ('running', 'pending')
                    """)

                    await self.session.execute(update_query, {
                        "id": row.id,
                        "error_msg": f"Job recovered as failed - not updated since {row.updated_at.isoformat()}. Pod may have crashed.",
                        "end_time": datetime.now(),
                        "updated_at": datetime.now()
                    })

                    logger.warning("Recovered stale training job",
                                   job_id=row.job_id,
                                   tenant_id=str(row.tenant_id),
                                   last_updated=row.updated_at.isoformat() if row.updated_at else "unknown")

                    # Fetch the updated job to return
                    job = await self.get_by_job_id(row.job_id)
                    if job:
                        recovered_jobs.append(job)

                except Exception as job_error:
                    logger.error("Failed to recover individual stale job",
                                 job_id=row.job_id,
                                 error=str(job_error))

            if recovered_jobs:
                await self.session.commit()
                logger.info("Stale job recovery completed",
                            recovered_count=len(recovered_jobs),
                            stale_threshold_minutes=stale_threshold_minutes)

            return recovered_jobs

        except Exception as e:
            logger.error("Failed to recover stale jobs",
                         error=str(e))
            await self.session.rollback()
            return []

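A sketch of how an API route might use create_job_atomic so that concurrent "start training" requests landing on different pods resolve to a single job. The route path and the module-level database_manager session factory are assumptions, not code from this commit.

```python
# Hypothetical caller; the actual training API route is not part of this diff.
import uuid

from fastapi import APIRouter

from app.repositories.training_log_repository import TrainingLogRepository

router = APIRouter()


@router.post("/tenants/{tenant_id}/training-jobs")
async def start_training(tenant_id: str):
    async with database_manager.get_session() as session:  # session factory assumed
        repo = TrainingLogRepository(session)
        job, created = await repo.create_job_atomic(
            job_id=str(uuid.uuid4()),
            tenant_id=tenant_id,
        )
    # With several replicas, two concurrent requests can both pass the
    # "no active job" check; the partial unique index makes the second
    # INSERT fail, and create_job_atomic returns the existing job instead.
    return {"job_id": job.job_id, "created": created}
```
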
@@ -1,10 +1,16 @@

"""
Distributed Locking Mechanisms
Prevents concurrent training jobs for the same product

HORIZONTAL SCALING FIX:
- Uses SHA256 for stable hash across all Python processes/pods
- Python's built-in hash() varies between processes due to hash randomization (Python 3.3+)
- This ensures all pods compute the same lock ID for the same lock name
"""

import asyncio
import time
import hashlib
from typing import Optional
import logging
from contextlib import asynccontextmanager

@@ -39,9 +45,20 @@ class DatabaseLock:

        self.lock_id = self._hash_lock_name(lock_name)

    def _hash_lock_name(self, name: str) -> int:
        """Convert lock name to integer ID for PostgreSQL advisory lock"""
        # Use hash and modulo to get a positive 32-bit integer
        return abs(hash(name)) % (2**31)
        """
        Convert lock name to integer ID for PostgreSQL advisory lock.

        CRITICAL: Uses SHA256 for stable hash across all Python processes/pods.
        Python's built-in hash() varies between processes due to hash randomization
        (PYTHONHASHSEED, enabled by default since Python 3.3), which would cause
        different pods to compute different lock IDs for the same lock name,
        defeating the purpose of distributed locking.
        """
        # Use SHA256 for stable, cross-process hash
        hash_bytes = hashlib.sha256(name.encode('utf-8')).digest()
        # Take first 4 bytes and convert to positive 31-bit integer
        # (PostgreSQL advisory locks use bigint, but we use 31-bit for safety)
        return int.from_bytes(hash_bytes[:4], 'big') % (2**31)

    @asynccontextmanager
    async def acquire(self, session: AsyncSession):

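The property the new _hash_lock_name guarantees can be checked in isolation. A small, self-contained sketch repeating the derivation and showing the typical advisory-lock SQL such an ID feeds into; the acquire() body is not shown in this diff, so the SQL here is illustrative rather than the actual implementation.

```python
# Self-contained illustration of the stable lock-ID derivation.
import hashlib


def stable_lock_id(name: str) -> int:
    """Same derivation as DatabaseLock._hash_lock_name: SHA256 -> 31-bit int."""
    digest = hashlib.sha256(name.encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big") % (2 ** 31)


# Every process gets the same value for the same name, unlike abs(hash(name)),
# which changes between processes when hash randomization is active.
assert stable_lock_id("train:product:42") == stable_lock_id("train:product:42")

# Typical use with PostgreSQL advisory locks (asyncpg-style, illustrative):
#   locked = await conn.fetchval("SELECT pg_try_advisory_lock($1)",
#                                stable_lock_id("train:product:42"))
#   ...
#   await conn.execute("SELECT pg_advisory_unlock($1)",
#                      stable_lock_id("train:product:42"))
```
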
@@ -1,21 +1,39 @@

"""
WebSocket Connection Manager for Training Service
Manages WebSocket connections and broadcasts RabbitMQ events to connected clients

HORIZONTAL SCALING:
- Uses Redis pub/sub for cross-pod WebSocket broadcasting
- Each pod subscribes to a Redis channel and broadcasts to its local connections
- Events published to Redis are received by all pods, ensuring clients on any
  pod receive events from training jobs running on any other pod
"""

import asyncio
import json
from typing import Dict, Set
import os
from typing import Dict, Optional
from fastapi import WebSocket
import structlog

logger = structlog.get_logger()

# Redis pub/sub channel for WebSocket events
REDIS_WEBSOCKET_CHANNEL = "training:websocket:events"


class WebSocketConnectionManager:
    """
    Simple WebSocket connection manager.
    Manages connections per job_id and broadcasts messages to all connected clients.
    WebSocket connection manager with Redis pub/sub for horizontal scaling.

    In a multi-pod deployment:
    1. Events are published to Redis pub/sub (not just local broadcast)
    2. Each pod subscribes to Redis and broadcasts to its local WebSocket connections
    3. This ensures clients connected to any pod receive events from any pod

    Flow:
    - RabbitMQ event → Pod A receives → Pod A publishes to Redis
    - Redis pub/sub → All pods receive → Each pod broadcasts to local WebSockets
    """

    def __init__(self):

@@ -24,6 +42,121 @@ class WebSocketConnectionManager:

        self._lock = asyncio.Lock()
        # Store latest event for each job to provide initial state
        self._latest_events: Dict[str, dict] = {}
        # Redis client for pub/sub
        self._redis: Optional[object] = None
        self._pubsub: Optional[object] = None
        self._subscriber_task: Optional[asyncio.Task] = None
        self._running = False
        self._instance_id = f"{os.environ.get('HOSTNAME', 'unknown')}:{os.getpid()}"

    async def initialize_redis(self, redis_url: str) -> bool:
        """
        Initialize Redis connection for cross-pod pub/sub.

        Args:
            redis_url: Redis connection URL

        Returns:
            True if successful, False otherwise
        """
        try:
            import redis.asyncio as redis_async

            self._redis = redis_async.from_url(redis_url, decode_responses=True)
            await self._redis.ping()

            # Create pub/sub subscriber
            self._pubsub = self._redis.pubsub()
            await self._pubsub.subscribe(REDIS_WEBSOCKET_CHANNEL)

            # Start subscriber task
            self._running = True
            self._subscriber_task = asyncio.create_task(self._redis_subscriber_loop())

            logger.info("Redis pub/sub initialized for WebSocket broadcasting",
                        instance_id=self._instance_id,
                        channel=REDIS_WEBSOCKET_CHANNEL)
            return True

        except Exception as e:
            logger.error("Failed to initialize Redis pub/sub",
                         error=str(e),
                         instance_id=self._instance_id)
            return False

    async def shutdown(self):
        """Shutdown Redis pub/sub connection"""
        self._running = False

        if self._subscriber_task:
            self._subscriber_task.cancel()
            try:
                await self._subscriber_task
            except asyncio.CancelledError:
                pass

        if self._pubsub:
            await self._pubsub.unsubscribe(REDIS_WEBSOCKET_CHANNEL)
            await self._pubsub.close()

        if self._redis:
            await self._redis.close()

        logger.info("Redis pub/sub shutdown complete",
                    instance_id=self._instance_id)

    async def _redis_subscriber_loop(self):
        """Background task to receive Redis pub/sub messages and broadcast locally"""
        try:
            while self._running:
                try:
                    message = await self._pubsub.get_message(
                        ignore_subscribe_messages=True,
                        timeout=1.0
                    )

                    if message and message['type'] == 'message':
                        await self._handle_redis_message(message['data'])

                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.error("Error in Redis subscriber loop",
                                 error=str(e),
                                 instance_id=self._instance_id)
                    await asyncio.sleep(1)  # Backoff on error

        except asyncio.CancelledError:
            pass

        logger.info("Redis subscriber loop stopped",
                    instance_id=self._instance_id)

    async def _handle_redis_message(self, data: str):
        """Handle a message received from Redis pub/sub"""
        try:
            payload = json.loads(data)
            job_id = payload.get('job_id')
            message = payload.get('message')
            source_instance = payload.get('source_instance')

            if not job_id or not message:
                return

            # Log cross-pod message
            if source_instance != self._instance_id:
                logger.debug("Received cross-pod WebSocket event",
                             job_id=job_id,
                             source_instance=source_instance,
                             local_instance=self._instance_id)

            # Broadcast to local WebSocket connections
            await self._broadcast_local(job_id, message)

        except json.JSONDecodeError as e:
            logger.warning("Invalid JSON in Redis message", error=str(e))
        except Exception as e:
            logger.error("Error handling Redis message", error=str(e))

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        """Register a new WebSocket connection for a job"""

@@ -50,7 +183,8 @@ class WebSocketConnectionManager:

        logger.info("WebSocket connected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    total_connections=len(self._connections[job_id]))
                    total_connections=len(self._connections[job_id]),
                    instance_id=self._instance_id)

    async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        """Remove a WebSocket connection"""

@@ -66,19 +200,56 @@ class WebSocketConnectionManager:

        logger.info("WebSocket disconnected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    remaining_connections=len(self._connections.get(job_id, {})))
                    remaining_connections=len(self._connections.get(job_id, {})),
                    instance_id=self._instance_id)

    async def broadcast(self, job_id: str, message: dict) -> int:
        """
        Broadcast a message to all connections for a specific job.
        Returns the number of successful broadcasts.
        Broadcast a message to all connections for a specific job across ALL pods.

        If Redis is configured, publishes to Redis pub/sub which then broadcasts
        to all pods. Otherwise, falls back to local-only broadcast.

        Returns the number of successful local broadcasts.
        """
        # Store the latest event for this job to provide initial state to new connections
        if message.get('type') != 'initial_state':  # Don't store initial_state messages
        if message.get('type') != 'initial_state':
            self._latest_events[job_id] = message

        # If Redis is available, publish to Redis for cross-pod broadcast
        if self._redis:
            try:
                payload = json.dumps({
                    'job_id': job_id,
                    'message': message,
                    'source_instance': self._instance_id
                })
                await self._redis.publish(REDIS_WEBSOCKET_CHANNEL, payload)
                logger.debug("Published WebSocket event to Redis",
                             job_id=job_id,
                             message_type=message.get('type'),
                             instance_id=self._instance_id)
                # Return 0 here because the actual broadcast happens via subscriber
                # The count will be from _broadcast_local when the message is received
                return 0
            except Exception as e:
                logger.warning("Failed to publish to Redis, falling back to local broadcast",
                               error=str(e),
                               job_id=job_id)
                # Fall through to local broadcast

        # Local-only broadcast (when Redis is not available)
        return await self._broadcast_local(job_id, message)

    async def _broadcast_local(self, job_id: str, message: dict) -> int:
        """
        Broadcast a message to local WebSocket connections only.
        This is called either directly (no Redis) or from Redis subscriber.
        """
        if job_id not in self._connections:
            logger.debug("No active connections for job", job_id=job_id)
            logger.debug("No active local connections for job",
                         job_id=job_id,
                         instance_id=self._instance_id)
            return 0

        connections = list(self._connections[job_id].values())

@@ -103,18 +274,27 @@ class WebSocketConnectionManager:

            self._connections[job_id].pop(ws_id, None)

        if successful_sends > 0:
            logger.info("Broadcasted message to WebSocket clients",
            logger.info("Broadcasted message to local WebSocket clients",
                        job_id=job_id,
                        message_type=message.get('type'),
                        successful_sends=successful_sends,
                        failed_sends=len(failed_websockets))
                        failed_sends=len(failed_websockets),
                        instance_id=self._instance_id)

        return successful_sends

    def get_connection_count(self, job_id: str) -> int:
        """Get the number of active connections for a job"""
        """Get the number of active local connections for a job"""
        return len(self._connections.get(job_id, {}))

    def get_total_connection_count(self) -> int:
        """Get total number of active connections across all jobs"""
        return sum(len(conns) for conns in self._connections.values())

    def is_redis_enabled(self) -> bool:
        """Check if Redis pub/sub is enabled"""
        return self._redis is not None and self._running


# Global singleton instance
websocket_manager = WebSocketConnectionManager()

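The WebSocket route that clients connect to is not part of this commit. A hypothetical FastAPI endpoint showing how connect/disconnect are typically paired with this manager; it assumes connect() does not call accept() itself.

```python
# Hypothetical WebSocket route; the training service's real endpoint is not
# included in this diff.
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from app.websocket.manager import websocket_manager

router = APIRouter()


@router.websocket("/ws/training/{job_id}")
async def training_progress(websocket: WebSocket, job_id: str):
    await websocket.accept()  # assumes the manager does not accept() for us
    await websocket_manager.connect(job_id, websocket)
    try:
        while True:
            # Keep the connection open; progress events arrive via
            # websocket_manager.broadcast() regardless of which pod
            # actually runs the training job.
            await websocket.receive_text()
    except WebSocketDisconnect:
        await websocket_manager.disconnect(job_id, websocket)
```
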
@@ -0,0 +1,60 @@

"""Add horizontal scaling constraints for multi-pod deployment

Revision ID: add_horizontal_scaling
Revises: 26a665cd5348
Create Date: 2025-01-18

This migration adds database-level constraints to prevent race conditions
when running multiple training service pods:

1. Partial unique index on model_training_logs to prevent duplicate active jobs per tenant
2. Index to speed up active job lookups
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'add_horizontal_scaling'
down_revision: Union[str, None] = '26a665cd5348'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Add partial unique index to prevent duplicate active training jobs per tenant
    # This ensures only ONE job can be in 'pending' or 'running' status per tenant at a time
    # The constraint is enforced at the database level, preventing race conditions
    # between multiple pods checking and creating jobs simultaneously
    op.execute("""
        CREATE UNIQUE INDEX IF NOT EXISTS idx_unique_active_training_per_tenant
        ON model_training_logs (tenant_id)
        WHERE status IN ('pending', 'running')
    """)

    # Add index to speed up active job lookups (used by deduplication check)
    op.create_index(
        'idx_training_logs_tenant_status',
        'model_training_logs',
        ['tenant_id', 'status'],
        unique=False,
        if_not_exists=True
    )

    # Add index for job recovery queries (find stale running jobs)
    op.create_index(
        'idx_training_logs_status_updated',
        'model_training_logs',
        ['status', 'updated_at'],
        unique=False,
        if_not_exists=True
    )


def downgrade() -> None:
    # Remove the indexes in reverse order
    op.execute("DROP INDEX IF EXISTS idx_training_logs_status_updated")
    op.execute("DROP INDEX IF EXISTS idx_training_logs_tenant_status")
    op.execute("DROP INDEX IF EXISTS idx_unique_active_training_per_tenant")

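A quick way to see what the partial unique index enforces, sketched with asyncpg against a local database. The column list is simplified and assumes the remaining columns have defaults and that tenant_id accepts a plain string, so treat this as pseudo-verification rather than a ready-to-run script.

```python
# Illustrative check of the new constraint (asyncpg, assumed local database).
import asyncio

import asyncpg


async def main():
    conn = await asyncpg.connect("postgresql://postgres:postgres@localhost/training")
    await conn.execute(
        "INSERT INTO model_training_logs (job_id, tenant_id, status) "
        "VALUES ($1, $2, 'pending')", "job-a", "tenant-1")
    try:
        # A second active job for the same tenant violates
        # idx_unique_active_training_per_tenant ...
        await conn.execute(
            "INSERT INTO model_training_logs (job_id, tenant_id, status) "
            "VALUES ($1, $2, 'running')", "job-b", "tenant-1")
    except asyncpg.exceptions.UniqueViolationError:
        print("second active job rejected, as intended")
    # ... while finished jobs are outside the WHERE clause and are not constrained:
    await conn.execute(
        "INSERT INTO model_training_logs (job_id, tenant_id, status) "
        "VALUES ($1, $2, 'completed')", "job-c", "tenant-1")
    await conn.close()


asyncio.run(main())
```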