REFACTOR external service and improve WebSocket training

Urtzi Alfaro
2025-10-09 14:11:02 +02:00
parent 7c72f83c51
commit 3c689b4f98
111 changed files with 13289 additions and 2374 deletions

View File

@@ -1,14 +1,16 @@
"""
Training API Layer
HTTP endpoints for ML training operations
HTTP endpoints for ML training operations and WebSocket connections
"""
from .training_jobs import router as training_jobs_router
from .training_operations import router as training_operations_router
from .models import router as models_router
from .websocket_operations import router as websocket_operations_router
__all__ = [
"training_jobs_router",
"training_operations_router",
"models_router"
"models_router",
"websocket_operations_router"
]
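
The new websocket_operations_router is re-exported so the application factory can mount it alongside the existing HTTP routers. A minimal wiring sketch (the FastAPI app construction below is illustrative; the actual factory is not part of this commit):

from fastapi import FastAPI
from app.api import training_jobs_router, websocket_operations_router

app = FastAPI(title="training-service")
app.include_router(training_jobs_router)
# WebSocket routes on an APIRouter mount the same way as HTTP routes
app.include_router(websocket_operations_router)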

View File

@@ -0,0 +1,261 @@
"""
Enhanced Health Check Endpoints
Comprehensive service health monitoring
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import text
from typing import Dict, Any
import psutil
import os
from datetime import datetime, timezone
import logging
from app.core.database import database_manager
from app.utils.circuit_breaker import circuit_breaker_registry
from app.core.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
async def check_database_health() -> Dict[str, Any]:
"""Check database connectivity and performance"""
try:
start_time = datetime.now(timezone.utc)
async with database_manager.async_engine.begin() as conn:
# Simple connectivity check
await conn.execute(text("SELECT 1"))
# Check if we can access training tables
result = await conn.execute(
text("SELECT COUNT(*) FROM trained_models")
)
model_count = result.scalar()
# Check connection pool stats
pool = database_manager.async_engine.pool
pool_size = pool.size()
pool_checked_out = pool.checkedout()  # SQLAlchemy pool API for the checked-out count
response_time = (datetime.now(timezone.utc) - start_time).total_seconds()
return {
"status": "healthy",
"response_time_seconds": round(response_time, 3),
"model_count": model_count,
"connection_pool": {
"size": pool_size,
"checked_out": pool_checked_out,
"available": pool_size - pool_checked_out
}
}
except Exception as e:
logger.error(f"Database health check failed: {e}")
return {
"status": "unhealthy",
"error": str(e)
}
def check_system_resources() -> Dict[str, Any]:
"""Check system resource usage"""
try:
cpu_percent = psutil.cpu_percent(interval=0.1)
memory = psutil.virtual_memory()
disk = psutil.disk_usage('/')
return {
"status": "healthy",
"cpu": {
"usage_percent": cpu_percent,
"count": psutil.cpu_count()
},
"memory": {
"total_mb": round(memory.total / 1024 / 1024, 2),
"used_mb": round(memory.used / 1024 / 1024, 2),
"available_mb": round(memory.available / 1024 / 1024, 2),
"usage_percent": memory.percent
},
"disk": {
"total_gb": round(disk.total / 1024 / 1024 / 1024, 2),
"used_gb": round(disk.used / 1024 / 1024 / 1024, 2),
"free_gb": round(disk.free / 1024 / 1024 / 1024, 2),
"usage_percent": disk.percent
}
}
except Exception as e:
logger.error(f"System resource check failed: {e}")
return {
"status": "error",
"error": str(e)
}
def check_model_storage() -> Dict[str, Any]:
"""Check model storage health"""
try:
storage_path = settings.MODEL_STORAGE_PATH
if not os.path.exists(storage_path):
return {
"status": "warning",
"message": f"Model storage path does not exist: {storage_path}"
}
# Check if writable
test_file = os.path.join(storage_path, ".health_check")
try:
with open(test_file, 'w') as f:
f.write("test")
os.remove(test_file)
writable = True
except Exception:
writable = False
# Count model files
model_files = 0
total_size = 0
for root, dirs, files in os.walk(storage_path):
for file in files:
if file.endswith('.pkl'):
model_files += 1
file_path = os.path.join(root, file)
total_size += os.path.getsize(file_path)
return {
"status": "healthy" if writable else "degraded",
"path": storage_path,
"writable": writable,
"model_files": model_files,
"total_size_mb": round(total_size / 1024 / 1024, 2)
}
except Exception as e:
logger.error(f"Model storage check failed: {e}")
return {
"status": "error",
"error": str(e)
}
@router.get("/health")
async def health_check() -> Dict[str, Any]:
"""
Basic health check endpoint.
Returns 200 if service is running.
"""
return {
"status": "healthy",
"service": "training-service",
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/health/detailed")
async def detailed_health_check() -> Dict[str, Any]:
"""
Detailed health check with component status.
Includes database, system resources, and dependencies.
"""
database_health = await check_database_health()
system_health = check_system_resources()
storage_health = check_model_storage()
circuit_breakers = circuit_breaker_registry.get_all_states()
# Determine overall status
component_statuses = [
database_health.get("status"),
system_health.get("status"),
storage_health.get("status")
]
if "unhealthy" in component_statuses or "error" in component_statuses:
overall_status = "unhealthy"
elif "degraded" in component_statuses or "warning" in component_statuses:
overall_status = "degraded"
else:
overall_status = "healthy"
return {
"status": overall_status,
"service": "training-service",
"version": "1.0.0",
"timestamp": datetime.now(timezone.utc).isoformat(),
"components": {
"database": database_health,
"system": system_health,
"storage": storage_health
},
"circuit_breakers": circuit_breakers,
"configuration": {
"max_concurrent_jobs": settings.MAX_CONCURRENT_TRAINING_JOBS,
"min_training_days": settings.MIN_TRAINING_DATA_DAYS,
"pool_size": settings.DB_POOL_SIZE,
"pool_max_overflow": settings.DB_MAX_OVERFLOW
}
}
@router.get("/health/ready")
async def readiness_check() -> Dict[str, Any]:
"""
Readiness check for Kubernetes.
Returns 200 only if service is ready to accept traffic.
"""
database_health = await check_database_health()
if database_health.get("status") != "healthy":
raise HTTPException(
status_code=503,
detail="Service not ready: database unavailable"
)
storage_health = check_model_storage()
if storage_health.get("status") == "error":
raise HTTPException(
status_code=503,
detail="Service not ready: model storage unavailable"
)
return {
"status": "ready",
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/health/live")
async def liveness_check() -> Dict[str, Any]:
"""
Liveness check for Kubernetes.
Returns 200 if service process is alive.
"""
return {
"status": "alive",
"timestamp": datetime.now(timezone.utc).isoformat(),
"pid": os.getpid()
}
@router.get("/metrics/system")
async def system_metrics() -> Dict[str, Any]:
"""
Detailed system metrics for monitoring.
"""
process = psutil.Process(os.getpid())
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"process": {
"pid": os.getpid(),
"cpu_percent": process.cpu_percent(interval=0.1),
"memory_mb": round(process.memory_info().rss / 1024 / 1024, 2),
"threads": process.num_threads(),
"open_files": len(process.open_files()),
"connections": len(process.connections())
},
"system": check_system_resources()
}
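
As a usage sketch for the probes above (the base URL, port, and the httpx dependency are assumptions, not part of this commit), a deployment script could block until /health/ready stops returning 503:

import asyncio
import httpx  # assumed HTTP client; any async client works

async def wait_until_ready(base_url: str = "http://localhost:8000") -> None:
    """Poll /health/ready until the service reports it can accept traffic."""
    async with httpx.AsyncClient(timeout=5.0) as client:
        while True:
            try:
                response = await client.get(f"{base_url}/health/ready")
                if response.status_code == 200:
                    return  # database reachable and model storage usable
            except httpx.HTTPError:
                pass  # service not reachable yet
            await asyncio.sleep(2.0)

asyncio.run(wait_until_ready())

A Kubernetes readinessProbe would hit the same endpoint, with /health/live backing the livenessProbe.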

View File

@@ -10,14 +10,12 @@ from sqlalchemy import text
from app.core.database import get_db
from app.schemas.training import TrainedModelResponse, ModelMetricsResponse
from app.services.training_service import TrainingService
from app.services.training_service import EnhancedTrainingService
from datetime import datetime
from sqlalchemy import select, delete, func
import uuid
import shutil
from app.services.messaging import publish_models_deleted_event
from shared.auth.decorators import (
get_current_user_dep,
require_admin_role
@@ -38,7 +36,7 @@ route_builder = RouteBuilder('training')
logger = structlog.get_logger()
router = APIRouter()
training_service = TrainingService()
training_service = EnhancedTrainingService()
@router.get(
route_builder.build_base_route("models") + "/{inventory_product_id}/active"
@@ -472,12 +470,7 @@ async def delete_tenant_models_complete(
deletion_stats["errors"].append(error_msg)
logger.warning(error_msg)
# Step 5: Publish deletion event
try:
await publish_models_deleted_event(tenant_id, deletion_stats)
except Exception as e:
logger.warning("Failed to publish models deletion event", error=str(e))
# Models deleted successfully
return {
"success": True,
"message": f"All training data for tenant {tenant_id} deleted successfully",

View File

@@ -0,0 +1,410 @@
"""
Monitoring and Observability Endpoints
Real-time service monitoring and diagnostics
"""
from fastapi import APIRouter, Query
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone, timedelta
from sqlalchemy import text, func
import logging
from app.core.database import database_manager
from app.utils.circuit_breaker import circuit_breaker_registry
from app.models.training import ModelTrainingLog, TrainingJobQueue, TrainedModel
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/monitoring/circuit-breakers")
async def get_circuit_breaker_status() -> Dict[str, Any]:
"""
Get status of all circuit breakers.
Useful for monitoring external service health.
"""
breakers = circuit_breaker_registry.get_all_states()
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"circuit_breakers": breakers,
"summary": {
"total": len(breakers),
"open": sum(1 for b in breakers.values() if b["state"] == "open"),
"half_open": sum(1 for b in breakers.values() if b["state"] == "half_open"),
"closed": sum(1 for b in breakers.values() if b["state"] == "closed")
}
}
@router.post("/monitoring/circuit-breakers/{name}/reset")
async def reset_circuit_breaker(name: str) -> Dict[str, str]:
"""
Manually reset a circuit breaker.
Use with caution - only reset if you know the service has recovered.
"""
circuit_breaker_registry.reset(name)
return {
"status": "success",
"message": f"Circuit breaker '{name}' has been reset",
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/monitoring/training-jobs")
async def get_training_job_stats(
hours: int = Query(default=24, ge=1, le=168, description="Look-back period in hours")
) -> Dict[str, Any]:
"""
Get training job statistics for the specified period.
"""
try:
since = datetime.now(timezone.utc) - timedelta(hours=hours)
async with database_manager.get_session() as session:
# Get job counts by status
result = await session.execute(
text("""
SELECT status, COUNT(*) as count
FROM model_training_logs
WHERE created_at >= :since
GROUP BY status
"""),
{"since": since}
)
status_counts = dict(result.fetchall())
# Get average training time for completed jobs
result = await session.execute(
text("""
SELECT AVG(EXTRACT(EPOCH FROM (end_time - start_time))) as avg_duration
FROM model_training_logs
WHERE status = 'completed'
AND created_at >= :since
AND end_time IS NOT NULL
"""),
{"since": since}
)
avg_duration = result.scalar()
# Get failure rate
total = sum(status_counts.values())
failed = status_counts.get('failed', 0)
failure_rate = (failed / total * 100) if total > 0 else 0
# Get recent jobs
result = await session.execute(
text("""
SELECT job_id, tenant_id, status, progress, start_time, end_time
FROM model_training_logs
WHERE created_at >= :since
ORDER BY created_at DESC
LIMIT 10
"""),
{"since": since}
)
recent_jobs = [
{
"job_id": row.job_id,
"tenant_id": str(row.tenant_id),
"status": row.status,
"progress": row.progress,
"start_time": row.start_time.isoformat() if row.start_time else None,
"end_time": row.end_time.isoformat() if row.end_time else None
}
for row in result.fetchall()
]
return {
"period_hours": hours,
"timestamp": datetime.now(timezone.utc).isoformat(),
"summary": {
"total_jobs": total,
"by_status": status_counts,
"failure_rate_percent": round(failure_rate, 2),
"avg_duration_seconds": round(avg_duration, 2) if avg_duration else None
},
"recent_jobs": recent_jobs
}
except Exception as e:
logger.error(f"Failed to get training job stats: {e}")
return {
"error": str(e),
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/monitoring/models")
async def get_model_stats() -> Dict[str, Any]:
"""
Get statistics about trained models.
"""
try:
async with database_manager.get_session() as session:
# Total models
result = await session.execute(
text("SELECT COUNT(*) FROM trained_models")
)
total_models = result.scalar()
# Active models
result = await session.execute(
text("SELECT COUNT(*) FROM trained_models WHERE is_active = true")
)
active_models = result.scalar()
# Production models
result = await session.execute(
text("SELECT COUNT(*) FROM trained_models WHERE is_production = true")
)
production_models = result.scalar()
# Models by type
result = await session.execute(
text("""
SELECT model_type, COUNT(*) as count
FROM trained_models
GROUP BY model_type
""")
)
models_by_type = dict(result.fetchall())
# Average model performance (MAPE)
result = await session.execute(
text("""
SELECT AVG(mape) as avg_mape
FROM trained_models
WHERE mape IS NOT NULL
AND is_active = true
""")
)
avg_mape = result.scalar()
# Models created today
today = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
result = await session.execute(
text("""
SELECT COUNT(*) FROM trained_models
WHERE created_at >= :today
"""),
{"today": today}
)
models_today = result.scalar()
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"summary": {
"total_models": total_models,
"active_models": active_models,
"production_models": production_models,
"models_created_today": models_today,
"average_mape_percent": round(avg_mape, 2) if avg_mape else None
},
"by_type": models_by_type
}
except Exception as e:
logger.error(f"Failed to get model stats: {e}")
return {
"error": str(e),
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/monitoring/queue")
async def get_queue_status() -> Dict[str, Any]:
"""
Get training job queue status.
"""
try:
async with database_manager.get_session() as session:
# Queued jobs
result = await session.execute(
text("""
SELECT COUNT(*) FROM training_job_queue
WHERE status = 'queued'
""")
)
queued = result.scalar()
# Running jobs
result = await session.execute(
text("""
SELECT COUNT(*) FROM training_job_queue
WHERE status = 'running'
""")
)
running = result.scalar()
# Get oldest queued job
result = await session.execute(
text("""
SELECT created_at FROM training_job_queue
WHERE status = 'queued'
ORDER BY created_at ASC
LIMIT 1
""")
)
oldest_queued = result.scalar()
# Calculate wait time
if oldest_queued:
wait_time_seconds = (datetime.now(timezone.utc) - oldest_queued).total_seconds()
else:
wait_time_seconds = 0
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"queue": {
"queued": queued,
"running": running,
"oldest_wait_time_seconds": round(wait_time_seconds, 2) if oldest_queued else 0,
"oldest_queued_at": oldest_queued.isoformat() if oldest_queued else None
}
}
except Exception as e:
logger.error(f"Failed to get queue status: {e}")
return {
"error": str(e),
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/monitoring/performance")
async def get_performance_metrics(
tenant_id: Optional[str] = Query(None, description="Filter by tenant ID")
) -> Dict[str, Any]:
"""
Get model performance metrics.
"""
try:
async with database_manager.get_session() as session:
query_params = {}
where_clause = ""
if tenant_id:
where_clause = "WHERE tenant_id = :tenant_id"
query_params["tenant_id"] = tenant_id
# Get performance distribution
result = await session.execute(
text(f"""
SELECT
COUNT(*) as total,
AVG(mape) as avg_mape,
MIN(mape) as min_mape,
MAX(mape) as max_mape,
PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY mape) as median_mape,
AVG(mae) as avg_mae,
AVG(rmse) as avg_rmse
FROM model_performance_metrics
{where_clause}
"""),
query_params
)
stats = result.fetchone()
# Get accuracy distribution (buckets)
result = await session.execute(
text(f"""
SELECT
CASE
WHEN mape <= 10 THEN 'excellent'
WHEN mape <= 20 THEN 'good'
WHEN mape <= 30 THEN 'acceptable'
ELSE 'poor'
END as accuracy_category,
COUNT(*) as count
FROM model_performance_metrics
{where_clause}
GROUP BY accuracy_category
"""),
query_params
)
distribution = dict(result.fetchall())
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"tenant_id": tenant_id,
"statistics": {
"total_metrics": stats.total if stats else 0,
"avg_mape_percent": round(stats.avg_mape, 2) if stats and stats.avg_mape else None,
"min_mape_percent": round(stats.min_mape, 2) if stats and stats.min_mape else None,
"max_mape_percent": round(stats.max_mape, 2) if stats and stats.max_mape else None,
"median_mape_percent": round(stats.median_mape, 2) if stats and stats.median_mape else None,
"avg_mae": round(stats.avg_mae, 2) if stats and stats.avg_mae else None,
"avg_rmse": round(stats.avg_rmse, 2) if stats and stats.avg_rmse else None
},
"distribution": distribution
}
except Exception as e:
logger.error(f"Failed to get performance metrics: {e}")
return {
"error": str(e),
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/monitoring/alerts")
async def get_alerts() -> Dict[str, Any]:
"""
Get active alerts and warnings based on system state.
"""
alerts = []
warnings = []
try:
# Check circuit breakers
breakers = circuit_breaker_registry.get_all_states()
for name, state in breakers.items():
if state["state"] == "open":
alerts.append({
"type": "circuit_breaker_open",
"severity": "high",
"message": f"Circuit breaker '{name}' is OPEN - service unavailable",
"details": state
})
elif state["state"] == "half_open":
warnings.append({
"type": "circuit_breaker_recovering",
"severity": "medium",
"message": f"Circuit breaker '{name}' is recovering",
"details": state
})
# Check queue backlog
async with database_manager.get_session() as session:
result = await session.execute(
text("SELECT COUNT(*) FROM training_job_queue WHERE status = 'queued'")
)
queued = result.scalar()
if queued > 10:
warnings.append({
"type": "queue_backlog",
"severity": "medium",
"message": f"Training queue has {queued} pending jobs",
"count": queued
})
except Exception as e:
logger.error(f"Failed to generate alerts: {e}")
alerts.append({
"type": "monitoring_error",
"severity": "high",
"message": f"Failed to check system alerts: {str(e)}"
})
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"summary": {
"total_alerts": len(alerts),
"total_warnings": len(warnings)
},
"alerts": alerts,
"warnings": warnings
}
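
The circuit_breaker_registry used throughout this file is imported from app.utils.circuit_breaker, which is not part of this hunk. A minimal sketch of the interface the endpoints above rely on — get_all_states() returning per-breaker state dicts and reset(name) — with illustrative internals:

from typing import Any, Dict

class CircuitBreaker:
    """Illustrative stand-in for the real breaker in app.utils.circuit_breaker."""
    def __init__(self, name: str) -> None:
        self.name = name
        self.state = "closed"  # closed -> open -> half_open -> closed
        self.failure_count = 0

class CircuitBreakerRegistry:
    def __init__(self) -> None:
        self._breakers: Dict[str, CircuitBreaker] = {}

    def get_all_states(self) -> Dict[str, Dict[str, Any]]:
        # Shape consumed by /monitoring/circuit-breakers and /monitoring/alerts
        return {
            name: {"state": breaker.state, "failure_count": breaker.failure_count}
            for name, breaker in self._breakers.items()
        }

    def reset(self, name: str) -> None:
        # Force a breaker back to closed; backs the manual reset endpoint
        breaker = self._breakers[name]
        breaker.state = "closed"
        breaker.failure_count = 0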

View File

@@ -1,21 +1,18 @@
"""
Training Operations API - BUSINESS logic
Handles training job execution, metrics, and WebSocket live feed
Handles training job execution and metrics
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path, WebSocket, WebSocketDisconnect
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path
from typing import Optional, Dict, Any
import structlog
import asyncio
import json
import datetime
from shared.auth.access_control import require_user_role, admin_role_required, analytics_tier_required
from datetime import datetime, timezone
import uuid
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from datetime import datetime, timezone
import uuid
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
@@ -23,15 +20,10 @@ from app.schemas.training import (
SingleProductTrainingRequest,
TrainingJobResponse
)
from app.services.messaging import (
publish_job_progress,
publish_data_validation_started,
publish_data_validation_completed,
publish_job_step_completed,
publish_job_completed,
publish_job_failed,
publish_job_started,
training_publisher
from app.services.training_events import (
publish_training_started,
publish_training_completed,
publish_training_failed
)
from app.core.config import settings
@@ -85,6 +77,14 @@ async def start_training_job(
if metrics:
metrics.increment_counter("enhanced_training_jobs_created_total")
# Publish training.started event immediately so WebSocket clients
# have initial state when they connect
await publish_training_started(
job_id=job_id,
tenant_id=tenant_id,
total_products=0 # Will be updated when actual training starts
)
# Add enhanced background task
background_tasks.add_task(
execute_training_job_background,
@@ -190,12 +190,8 @@ async def execute_training_job_background(
tenant_id=tenant_id
)
# Publish job started event
await publish_job_started(job_id, tenant_id, {
"enhanced_features": True,
"repository_pattern": True,
"job_type": "enhanced_training"
})
# This will be published by the training service itself
# when it starts execution
training_config = {
"job_id": job_id,
@@ -241,16 +237,7 @@ async def execute_training_job_background(
tenant_id=tenant_id
)
# Publish enhanced completion event
await publish_job_completed(
job_id=job_id,
tenant_id=tenant_id,
results={
**result,
"enhanced_features": True,
"repository_integration": True
}
)
# Completion event is published by the training service
logger.info("Enhanced background training job completed successfully",
job_id=job_id,
@@ -276,17 +263,8 @@ async def execute_training_job_background(
job_id=job_id,
status_error=str(status_error))
# Publish enhanced failure event
await publish_job_failed(
job_id=job_id,
tenant_id=tenant_id,
error=str(training_error),
metadata={
"enhanced_features": True,
"repository_pattern": True,
"error_type": type(training_error).__name__
}
)
# Publish the failure event directly so WebSocket clients are notified even if the training service never reached its own failure handler
await publish_training_failed(job_id, tenant_id, str(training_error))
except Exception as background_error:
logger.error("Critical error in enhanced background training job",
@@ -370,373 +348,19 @@ async def start_single_product_training(
)
# ============================================
# WebSocket Live Feed
# ============================================
class ConnectionManager:
"""Manage WebSocket connections for training progress"""
def __init__(self):
self.active_connections: Dict[str, Dict[str, WebSocket]] = {}
# Structure: {job_id: {connection_id: websocket}}
async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
"""Accept WebSocket connection and register it"""
await websocket.accept()
if job_id not in self.active_connections:
self.active_connections[job_id] = {}
self.active_connections[job_id][connection_id] = websocket
logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")
def disconnect(self, job_id: str, connection_id: str):
"""Remove WebSocket connection"""
if job_id in self.active_connections:
self.active_connections[job_id].pop(connection_id, None)
if not self.active_connections[job_id]:
del self.active_connections[job_id]
logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")
async def send_to_job(self, job_id: str, message: dict):
"""Send message to all connections for a specific job with better error handling"""
if job_id not in self.active_connections:
logger.debug(f"No active connections for job {job_id}")
return
# Send to all connections for this job
disconnected_connections = []
for connection_id, websocket in self.active_connections[job_id].items():
try:
await websocket.send_json(message)
logger.debug(f"Sent {message.get('type', 'unknown')} to connection {connection_id}")
except Exception as e:
logger.warning(f"Failed to send message to connection {connection_id}: {e}")
disconnected_connections.append(connection_id)
# Clean up disconnected connections
for connection_id in disconnected_connections:
self.disconnect(job_id, connection_id)
# Log successful sends
active_count = len(self.active_connections.get(job_id, {}))
if active_count > 0:
logger.info(f"Sent {message.get('type', 'unknown')} message to {active_count} connection(s) for job {job_id}")
# Global connection manager
connection_manager = ConnectionManager()
@router.websocket(route_builder.build_nested_resource_route('jobs', 'job_id', 'live'))
async def training_progress_websocket(
websocket: WebSocket,
tenant_id: str,
job_id: str
):
"""
WebSocket endpoint for real-time training progress updates
"""
# Validate token from query parameters
token = websocket.query_params.get("token")
if not token:
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
await websocket.close(code=1008, reason="Authentication token required")
return
# Validate the token
from shared.auth.jwt_handler import JWTHandler
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload:
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
await websocket.close(code=1008, reason="Invalid authentication token")
return
# Verify user has access to this tenant
user_id = payload.get('user_id')
if not user_id:
logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
await websocket.close(code=1008, reason="Invalid token payload")
return
logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
except Exception as e:
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
await websocket.close(code=1008, reason="Token validation failed")
return
connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
await connection_manager.connect(websocket, job_id, connection_id)
logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
# Send immediate connection confirmation to prevent gateway timeout
try:
await websocket.send_json({
"type": "connected",
"job_id": job_id,
"message": "WebSocket connection established",
"timestamp": str(datetime.now())
})
logger.debug(f"Sent connection confirmation for job {job_id}")
except Exception as e:
logger.error(f"Failed to send connection confirmation for job {job_id}: {e}")
consumer_task = None
training_completed = False
try:
# Start RabbitMQ consumer
consumer_task = asyncio.create_task(
setup_rabbitmq_consumer_for_job(job_id, tenant_id)
)
last_activity = asyncio.get_event_loop().time()
while not training_completed:
try:
try:
data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
last_activity = asyncio.get_event_loop().time()
# Handle different message types
if data["type"] == "websocket.receive":
if "text" in data:
message_text = data["text"]
if message_text == "ping":
await websocket.send_text("pong")
logger.debug(f"Text ping received from job {job_id}")
elif message_text == "get_status":
current_status = await get_current_job_status(job_id, tenant_id)
if current_status:
await websocket.send_json({
"type": "current_status",
"job_id": job_id,
"data": current_status
})
elif message_text == "close":
logger.info(f"Client requested connection close for job {job_id}")
break
elif "bytes" in data:
await websocket.send_text("pong")
logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
elif data["type"] == "websocket.disconnect":
logger.info(f"WebSocket disconnect message received for job {job_id}")
break
except asyncio.TimeoutError:
current_time = asyncio.get_event_loop().time()
if current_time - last_activity > 90:
logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
try:
await websocket.send_json({
"type": "heartbeat",
"job_id": job_id,
"timestamp": str(datetime.now()),
"message": "Training service heartbeat - frontend inactive",
"inactivity_seconds": int(current_time - last_activity)
})
last_activity = current_time
except Exception as e:
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
break
else:
logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
continue
except WebSocketDisconnect:
logger.info(f"WebSocket client disconnected for job {job_id}")
break
except Exception as e:
logger.error(f"WebSocket error for job {job_id}: {e}")
if "Cannot call" in str(e) and "disconnect message" in str(e):
logger.error(f"FastAPI WebSocket disconnect error - connection already closed")
break
await asyncio.sleep(1)
logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")
except Exception as e:
logger.error(f"Critical WebSocket error for job {job_id}: {e}")
finally:
logger.info(f"Cleaning up WebSocket connection for job {job_id}")
connection_manager.disconnect(job_id, connection_id)
if consumer_task and not consumer_task.done():
if training_completed:
logger.info(f"Training completed, cancelling consumer for job {job_id}")
consumer_task.cancel()
else:
logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
try:
await consumer_task
except asyncio.CancelledError:
logger.info(f"Consumer task cancelled for job {job_id}")
except Exception as e:
logger.error(f"Consumer task error for job {job_id}: {e}")
async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
"""Set up RabbitMQ consumer to listen for training events for a specific job"""
logger.info(f"Setting up RabbitMQ consumer for job {job_id}")
try:
# Create a unique queue for this WebSocket connection
queue_name = f"websocket_training_{job_id}_{tenant_id}"
async def handle_training_message(message):
"""Handle incoming RabbitMQ messages and forward to WebSocket"""
try:
# Parse the message
body = message.body.decode()
data = json.loads(body)
logger.debug(f"Received message for job {job_id}: {data.get('event_type', 'unknown')}")
# Extract event data
event_type = data.get("event_type", "unknown")
event_data = data.get("data", {})
# Only process messages for this specific job
message_job_id = event_data.get("job_id") if event_data else None
if message_job_id != job_id:
logger.debug(f"Ignoring message for different job: {message_job_id}")
await message.ack()
return
# Transform RabbitMQ message to WebSocket message format
websocket_message = {
"type": map_event_type_to_websocket_type(event_type),
"job_id": job_id,
"timestamp": data.get("timestamp"),
"data": event_data
}
logger.info(f"Forwarding {event_type} message to WebSocket clients for job {job_id}")
# Send to all WebSocket connections for this job
await connection_manager.send_to_job(job_id, websocket_message)
# Check if this is a completion message
if event_type in ["training.completed", "training.failed"]:
logger.info(f"Training completion detected for job {job_id}: {event_type}")
# Acknowledge the message
await message.ack()
logger.debug(f"Successfully processed {event_type} for job {job_id}")
except Exception as e:
logger.error(f"Error handling training message for job {job_id}: {e}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
await message.nack(requeue=False)
# Check if training_publisher is connected
if not training_publisher.connected:
logger.warning(f"Training publisher not connected for job {job_id}, attempting to connect...")
success = await training_publisher.connect()
if not success:
logger.error(f"Failed to connect training_publisher for job {job_id}")
return
# Subscribe to training events
logger.info(f"Subscribing to training events for job {job_id}")
success = await training_publisher.consume_events(
exchange_name="training.events",
queue_name=queue_name,
routing_key="training.*",
callback=handle_training_message
)
if success:
logger.info(f"Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})")
# Keep the consumer running indefinitely until cancelled
try:
while True:
await asyncio.sleep(10)
logger.debug(f"Consumer heartbeat for job {job_id}")
except asyncio.CancelledError:
logger.info(f"Consumer cancelled for job {job_id}")
raise
except Exception as e:
logger.error(f"Consumer error for job {job_id}: {e}")
raise
else:
logger.error(f"Failed to set up RabbitMQ consumer for job {job_id}")
except Exception as e:
logger.error(f"Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
"""Map RabbitMQ event types to WebSocket message types"""
mapping = {
"training.started": "started",
"training.progress": "progress",
"training.completed": "completed",
"training.failed": "failed",
"training.cancelled": "cancelled",
"training.step.completed": "step_completed",
"training.product.started": "product_started",
"training.product.completed": "product_completed",
"training.product.failed": "product_failed",
"training.model.trained": "model_trained",
"training.data.validation.started": "validation_started",
"training.data.validation.completed": "validation_completed"
}
return mapping.get(rabbitmq_event_type, "unknown")
async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
"""Get current job status from database"""
try:
return {
"job_id": job_id,
"status": "running",
"progress": 0,
"current_step": "Starting...",
"started_at": "2025-07-30T19:00:00Z"
}
except Exception as e:
logger.error(f"Failed to get current job status: {e}")
return None
@router.get("/health")
async def health_check():
"""Health check endpoint for the training operations"""
return {
"status": "healthy",
"service": "training-operations",
"version": "2.0.0",
"version": "3.0.0",
"features": [
"repository-pattern",
"dependency-injection",
"enhanced-error-handling",
"metrics-tracking",
"transactional-operations",
"websocket-support"
"transactional-operations"
],
"timestamp": datetime.now().isoformat()
}
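
The publish_* helpers above come from app.services.training_events, which replaces the old fine-grained messaging module but is not shown in this excerpt. A sketch of what publish_training_started might look like, assuming an aio-pika topic exchange named training.events (the exchange name and the event_type/timestamp/data envelope match the consumer code removed earlier in this diff; the per-call connection handling is illustrative only):

import json
from datetime import datetime, timezone

import aio_pika

async def publish_training_started(job_id: str, tenant_id: str, total_products: int) -> None:
    """Publish a training.started event for WebSocket subscribers (sketch)."""
    connection = await aio_pika.connect_robust("amqp://guest:guest@rabbitmq/")
    async with connection:
        channel = await connection.channel()
        exchange = await channel.declare_exchange(
            "training.events", aio_pika.ExchangeType.TOPIC, durable=True
        )
        payload = {
            "event_type": "training.started",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "data": {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "total_products": total_products,
            },
        }
        await exchange.publish(
            aio_pika.Message(body=json.dumps(payload).encode()),
            routing_key="training.started",
        )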

View File

@@ -0,0 +1,109 @@
"""
WebSocket Operations for Training Service
Simple WebSocket endpoint that connects clients and receives broadcasts from RabbitMQ
"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query
import structlog
from app.websocket.manager import websocket_manager
from shared.auth.jwt_handler import JWTHandler
from app.core.config import settings
logger = structlog.get_logger()
router = APIRouter(tags=["websocket"])
@router.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
websocket: WebSocket,
tenant_id: str = Path(..., description="Tenant ID"),
job_id: str = Path(..., description="Job ID"),
token: str = Query(..., description="Authentication token")
):
"""
WebSocket endpoint for real-time training progress updates.
This endpoint:
1. Validates the authentication token
2. Accepts the WebSocket connection
3. Keeps the connection alive
4. Receives broadcasts from RabbitMQ (via WebSocket manager)
"""
# Validate token
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload:
await websocket.close(code=1008, reason="Invalid token")
logger.warning("WebSocket connection rejected - invalid token",
job_id=job_id,
tenant_id=tenant_id)
return
user_id = payload.get('user_id')
if not user_id:
await websocket.close(code=1008, reason="Invalid token payload")
logger.warning("WebSocket connection rejected - no user_id in token",
job_id=job_id,
tenant_id=tenant_id)
return
logger.info("WebSocket authentication successful",
user_id=user_id,
tenant_id=tenant_id,
job_id=job_id)
except Exception as e:
await websocket.close(code=1008, reason="Authentication failed")
logger.warning("WebSocket authentication failed",
job_id=job_id,
tenant_id=tenant_id,
error=str(e))
return
# Connect to WebSocket manager
await websocket_manager.connect(job_id, websocket)
try:
# Send connection confirmation
await websocket.send_json({
"type": "connected",
"job_id": job_id,
"message": "Connected to training progress stream"
})
# Keep connection alive and handle client messages
ping_count = 0
while True:
try:
# Receive messages from client (ping, etc.)
data = await websocket.receive_text()
# Handle ping/pong
if data == "ping":
await websocket.send_text("pong")
ping_count += 1
logger.info("WebSocket ping/pong",
job_id=job_id,
ping_count=ping_count,
connection_healthy=True)
except WebSocketDisconnect:
logger.info("Client disconnected", job_id=job_id)
break
except Exception as e:
logger.error("Error in WebSocket message loop",
job_id=job_id,
error=str(e))
break
finally:
# Disconnect from manager
await websocket_manager.disconnect(job_id, websocket)
logger.info("WebSocket connection closed",
job_id=job_id,
tenant_id=tenant_id)
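
The websocket_manager imported from app.websocket.manager is not included in this excerpt. A minimal sketch of the interface this endpoint uses — async connect(job_id, websocket) that accepts the socket, disconnect(job_id, websocket), and a broadcast hook (named broadcast_to_job here purely as an assumption) that the RabbitMQ consumer would invoke:

from typing import Dict, Set

from fastapi import WebSocket

class WebSocketManager:
    """Illustrative sketch; the real manager lives in app.websocket.manager."""
    def __init__(self) -> None:
        self._connections: Dict[str, Set[WebSocket]] = {}  # job_id -> sockets

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        await websocket.accept()
        self._connections.setdefault(job_id, set()).add(websocket)

    async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        self._connections.get(job_id, set()).discard(websocket)

    async def broadcast_to_job(self, job_id: str, message: dict) -> None:
        # Called when a training event for this job arrives from RabbitMQ
        for socket in list(self._connections.get(job_id, set())):
            try:
                await socket.send_json(message)
            except Exception:
                await self.disconnect(job_id, socket)

websocket_manager = WebSocketManager()

On the client side, a connection follows the URL and token contract above (host and token are illustrative):

import asyncio
import websockets  # assumed client library

async def watch_job(tenant_id: str, job_id: str, token: str) -> None:
    url = (f"ws://localhost:8000/api/v1/tenants/{tenant_id}"
           f"/training/jobs/{job_id}/live?token={token}")
    async with websockets.connect(url) as ws:
        await ws.send("ping")  # server answers with a plain "pong"
        async for message in ws:
            print(message)  # "connected" confirmation, then broadcast events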