REFACTOR external service and improve websocket training
services/training/app/api/__init__.py
@@ -1,14 +1,16 @@
 """
 Training API Layer
-HTTP endpoints for ML training operations
+HTTP endpoints for ML training operations and WebSocket connections
 """
 
 from .training_jobs import router as training_jobs_router
 from .training_operations import router as training_operations_router
 from .models import router as models_router
+from .websocket_operations import router as websocket_operations_router
 
 __all__ = [
     "training_jobs_router",
     "training_operations_router",
-    "models_router"
+    "models_router",
+    "websocket_operations_router"
 ]
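The aggregated routers are presumably mounted by the service's FastAPI app; that wiring is not part of this diff. A minimal sketch of how an entrypoint might include them (the app setup and the absence of a route prefix are assumptions):

# Hypothetical wiring sketch - the actual application entrypoint is not in this commit.
from fastapi import FastAPI

from app.api import (
    training_jobs_router,
    training_operations_router,
    models_router,
    websocket_operations_router,
)

app = FastAPI(title="training-service")

# The WebSocket router below declares an absolute path, so it is mounted without a prefix.
app.include_router(training_jobs_router)
app.include_router(training_operations_router)
app.include_router(models_router)
app.include_router(websocket_operations_router)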
services/training/app/api/health.py (new file, 261 lines)
@@ -0,0 +1,261 @@
"""
Enhanced Health Check Endpoints
Comprehensive service health monitoring
"""

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import text
from typing import Dict, Any
import psutil
import os
from datetime import datetime, timezone
import logging

from app.core.database import database_manager
from app.utils.circuit_breaker import circuit_breaker_registry
from app.core.config import settings

logger = logging.getLogger(__name__)
router = APIRouter()


async def check_database_health() -> Dict[str, Any]:
    """Check database connectivity and performance"""
    try:
        start_time = datetime.now(timezone.utc)

        async with database_manager.async_engine.begin() as conn:
            # Simple connectivity check
            await conn.execute(text("SELECT 1"))

            # Check if we can access training tables
            result = await conn.execute(
                text("SELECT COUNT(*) FROM trained_models")
            )
            model_count = result.scalar()

            # Check connection pool stats
            pool = database_manager.async_engine.pool
            pool_size = pool.size()
            pool_checked_out = pool.checkedout()

            response_time = (datetime.now(timezone.utc) - start_time).total_seconds()

            return {
                "status": "healthy",
                "response_time_seconds": round(response_time, 3),
                "model_count": model_count,
                "connection_pool": {
                    "size": pool_size,
                    "checked_out": pool_checked_out,
                    "available": pool_size - pool_checked_out
                }
            }

    except Exception as e:
        logger.error(f"Database health check failed: {e}")
        return {
            "status": "unhealthy",
            "error": str(e)
        }


def check_system_resources() -> Dict[str, Any]:
    """Check system resource usage"""
    try:
        cpu_percent = psutil.cpu_percent(interval=0.1)
        memory = psutil.virtual_memory()
        disk = psutil.disk_usage('/')

        return {
            "status": "healthy",
            "cpu": {
                "usage_percent": cpu_percent,
                "count": psutil.cpu_count()
            },
            "memory": {
                "total_mb": round(memory.total / 1024 / 1024, 2),
                "used_mb": round(memory.used / 1024 / 1024, 2),
                "available_mb": round(memory.available / 1024 / 1024, 2),
                "usage_percent": memory.percent
            },
            "disk": {
                "total_gb": round(disk.total / 1024 / 1024 / 1024, 2),
                "used_gb": round(disk.used / 1024 / 1024 / 1024, 2),
                "free_gb": round(disk.free / 1024 / 1024 / 1024, 2),
                "usage_percent": disk.percent
            }
        }

    except Exception as e:
        logger.error(f"System resource check failed: {e}")
        return {
            "status": "error",
            "error": str(e)
        }


def check_model_storage() -> Dict[str, Any]:
    """Check model storage health"""
    try:
        storage_path = settings.MODEL_STORAGE_PATH

        if not os.path.exists(storage_path):
            return {
                "status": "warning",
                "message": f"Model storage path does not exist: {storage_path}"
            }

        # Check if writable
        test_file = os.path.join(storage_path, ".health_check")
        try:
            with open(test_file, 'w') as f:
                f.write("test")
            os.remove(test_file)
            writable = True
        except Exception:
            writable = False

        # Count model files
        model_files = 0
        total_size = 0
        for root, dirs, files in os.walk(storage_path):
            for file in files:
                if file.endswith('.pkl'):
                    model_files += 1
                    file_path = os.path.join(root, file)
                    total_size += os.path.getsize(file_path)

        return {
            "status": "healthy" if writable else "degraded",
            "path": storage_path,
            "writable": writable,
            "model_files": model_files,
            "total_size_mb": round(total_size / 1024 / 1024, 2)
        }

    except Exception as e:
        logger.error(f"Model storage check failed: {e}")
        return {
            "status": "error",
            "error": str(e)
        }


@router.get("/health")
async def health_check() -> Dict[str, Any]:
    """
    Basic health check endpoint.
    Returns 200 if service is running.
    """
    return {
        "status": "healthy",
        "service": "training-service",
        "timestamp": datetime.now(timezone.utc).isoformat()
    }


@router.get("/health/detailed")
async def detailed_health_check() -> Dict[str, Any]:
    """
    Detailed health check with component status.
    Includes database, system resources, and dependencies.
    """
    database_health = await check_database_health()
    system_health = check_system_resources()
    storage_health = check_model_storage()
    circuit_breakers = circuit_breaker_registry.get_all_states()

    # Determine overall status
    component_statuses = [
        database_health.get("status"),
        system_health.get("status"),
        storage_health.get("status")
    ]

    if "unhealthy" in component_statuses or "error" in component_statuses:
        overall_status = "unhealthy"
    elif "degraded" in component_statuses or "warning" in component_statuses:
        overall_status = "degraded"
    else:
        overall_status = "healthy"

    return {
        "status": overall_status,
        "service": "training-service",
        "version": "1.0.0",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "components": {
            "database": database_health,
            "system": system_health,
            "storage": storage_health
        },
        "circuit_breakers": circuit_breakers,
        "configuration": {
            "max_concurrent_jobs": settings.MAX_CONCURRENT_TRAINING_JOBS,
            "min_training_days": settings.MIN_TRAINING_DATA_DAYS,
            "pool_size": settings.DB_POOL_SIZE,
            "pool_max_overflow": settings.DB_MAX_OVERFLOW
        }
    }


@router.get("/health/ready")
async def readiness_check() -> Dict[str, Any]:
    """
    Readiness check for Kubernetes.
    Returns 200 only if service is ready to accept traffic.
    """
    database_health = await check_database_health()

    if database_health.get("status") != "healthy":
        raise HTTPException(
            status_code=503,
            detail="Service not ready: database unavailable"
        )

    storage_health = check_model_storage()
    if storage_health.get("status") == "error":
        raise HTTPException(
            status_code=503,
            detail="Service not ready: model storage unavailable"
        )

    return {
        "status": "ready",
        "timestamp": datetime.now(timezone.utc).isoformat()
    }


@router.get("/health/live")
async def liveness_check() -> Dict[str, Any]:
    """
    Liveness check for Kubernetes.
    Returns 200 if service process is alive.
    """
    return {
        "status": "alive",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "pid": os.getpid()
    }


@router.get("/metrics/system")
async def system_metrics() -> Dict[str, Any]:
    """
    Detailed system metrics for monitoring.
    """
    process = psutil.Process(os.getpid())

    return {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "process": {
            "pid": os.getpid(),
            "cpu_percent": process.cpu_percent(interval=0.1),
            "memory_mb": round(process.memory_info().rss / 1024 / 1024, 2),
            "threads": process.num_threads(),
            "open_files": len(process.open_files()),
            "connections": len(process.connections())
        },
        "system": check_system_resources()
    }
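These endpoints map naturally onto Kubernetes liveness and readiness probes. A small smoke test against them, assuming the service is reachable at http://localhost:8000 with the router mounted at the root (both assumptions, since the wiring is not in this diff):

# Hypothetical smoke test for the probe endpoints; the base URL is an assumption.
import httpx

def check_probes(base: str = "http://localhost:8000") -> None:
    with httpx.Client(base_url=base, timeout=5.0) as client:
        for path in ("/health/live", "/health/ready", "/health/detailed"):
            resp = client.get(path)
            # /health/ready answers 503 until the database is reachable
            print(path, resp.status_code, resp.json().get("status"))

if __name__ == "__main__":
    check_probes()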
services/training/app/api/models.py
@@ -10,14 +10,12 @@ from sqlalchemy import text
 from app.core.database import get_db
 from app.schemas.training import TrainedModelResponse, ModelMetricsResponse
-from app.services.training_service import TrainingService
+from app.services.training_service import EnhancedTrainingService
 from datetime import datetime
 from sqlalchemy import select, delete, func
 import uuid
 import shutil
 
-from app.services.messaging import publish_models_deleted_event
-
 from shared.auth.decorators import (
     get_current_user_dep,
     require_admin_role
@@ -38,7 +36,7 @@ route_builder = RouteBuilder('training')
 logger = structlog.get_logger()
 router = APIRouter()
 
-training_service = TrainingService()
+training_service = EnhancedTrainingService()
 
 @router.get(
     route_builder.build_base_route("models") + "/{inventory_product_id}/active"
@@ -472,12 +470,7 @@ async def delete_tenant_models_complete(
             deletion_stats["errors"].append(error_msg)
             logger.warning(error_msg)
 
-        # Step 5: Publish deletion event
-        try:
-            await publish_models_deleted_event(tenant_id, deletion_stats)
-        except Exception as e:
-            logger.warning("Failed to publish models deletion event", error=str(e))
-
+        # Models deleted successfully
        return {
             "success": True,
             "message": f"All training data for tenant {tenant_id} deleted successfully",
services/training/app/api/monitoring.py (new file, 410 lines)
@@ -0,0 +1,410 @@
"""
Monitoring and Observability Endpoints
Real-time service monitoring and diagnostics
"""

from fastapi import APIRouter, Query
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone, timedelta
from sqlalchemy import text, func
import logging

from app.core.database import database_manager
from app.utils.circuit_breaker import circuit_breaker_registry
from app.models.training import ModelTrainingLog, TrainingJobQueue, TrainedModel

logger = logging.getLogger(__name__)
router = APIRouter()


@router.get("/monitoring/circuit-breakers")
async def get_circuit_breaker_status() -> Dict[str, Any]:
    """
    Get status of all circuit breakers.
    Useful for monitoring external service health.
    """
    breakers = circuit_breaker_registry.get_all_states()

    return {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "circuit_breakers": breakers,
        "summary": {
            "total": len(breakers),
            "open": sum(1 for b in breakers.values() if b["state"] == "open"),
            "half_open": sum(1 for b in breakers.values() if b["state"] == "half_open"),
            "closed": sum(1 for b in breakers.values() if b["state"] == "closed")
        }
    }


@router.post("/monitoring/circuit-breakers/{name}/reset")
async def reset_circuit_breaker(name: str) -> Dict[str, str]:
    """
    Manually reset a circuit breaker.
    Use with caution - only reset if you know the service has recovered.
    """
    circuit_breaker_registry.reset(name)

    return {
        "status": "success",
        "message": f"Circuit breaker '{name}' has been reset",
        "timestamp": datetime.now(timezone.utc).isoformat()
    }


@router.get("/monitoring/training-jobs")
async def get_training_job_stats(
    hours: int = Query(default=24, ge=1, le=168, description="Look back period in hours")
) -> Dict[str, Any]:
    """
    Get training job statistics for the specified period.
    """
    try:
        since = datetime.now(timezone.utc) - timedelta(hours=hours)

        async with database_manager.get_session() as session:
            # Get job counts by status
            result = await session.execute(
                text("""
                    SELECT status, COUNT(*) as count
                    FROM model_training_logs
                    WHERE created_at >= :since
                    GROUP BY status
                """),
                {"since": since}
            )
            status_counts = dict(result.fetchall())

            # Get average training time for completed jobs
            result = await session.execute(
                text("""
                    SELECT AVG(EXTRACT(EPOCH FROM (end_time - start_time))) as avg_duration
                    FROM model_training_logs
                    WHERE status = 'completed'
                      AND created_at >= :since
                      AND end_time IS NOT NULL
                """),
                {"since": since}
            )
            avg_duration = result.scalar()

            # Get failure rate
            total = sum(status_counts.values())
            failed = status_counts.get('failed', 0)
            failure_rate = (failed / total * 100) if total > 0 else 0

            # Get recent jobs
            result = await session.execute(
                text("""
                    SELECT job_id, tenant_id, status, progress, start_time, end_time
                    FROM model_training_logs
                    WHERE created_at >= :since
                    ORDER BY created_at DESC
                    LIMIT 10
                """),
                {"since": since}
            )
            recent_jobs = [
                {
                    "job_id": row.job_id,
                    "tenant_id": str(row.tenant_id),
                    "status": row.status,
                    "progress": row.progress,
                    "start_time": row.start_time.isoformat() if row.start_time else None,
                    "end_time": row.end_time.isoformat() if row.end_time else None
                }
                for row in result.fetchall()
            ]

            return {
                "period_hours": hours,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "summary": {
                    "total_jobs": total,
                    "by_status": status_counts,
                    "failure_rate_percent": round(failure_rate, 2),
                    "avg_duration_seconds": round(avg_duration, 2) if avg_duration else None
                },
                "recent_jobs": recent_jobs
            }

    except Exception as e:
        logger.error(f"Failed to get training job stats: {e}")
        return {
            "error": str(e),
            "timestamp": datetime.now(timezone.utc).isoformat()
        }


@router.get("/monitoring/models")
async def get_model_stats() -> Dict[str, Any]:
    """
    Get statistics about trained models.
    """
    try:
        async with database_manager.get_session() as session:
            # Total models
            result = await session.execute(
                text("SELECT COUNT(*) FROM trained_models")
            )
            total_models = result.scalar()

            # Active models
            result = await session.execute(
                text("SELECT COUNT(*) FROM trained_models WHERE is_active = true")
            )
            active_models = result.scalar()

            # Production models
            result = await session.execute(
                text("SELECT COUNT(*) FROM trained_models WHERE is_production = true")
            )
            production_models = result.scalar()

            # Models by type
            result = await session.execute(
                text("""
                    SELECT model_type, COUNT(*) as count
                    FROM trained_models
                    GROUP BY model_type
                """)
            )
            models_by_type = dict(result.fetchall())

            # Average model performance (MAPE)
            result = await session.execute(
                text("""
                    SELECT AVG(mape) as avg_mape
                    FROM trained_models
                    WHERE mape IS NOT NULL
                      AND is_active = true
                """)
            )
            avg_mape = result.scalar()

            # Models created today
            today = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
            result = await session.execute(
                text("""
                    SELECT COUNT(*) FROM trained_models
                    WHERE created_at >= :today
                """),
                {"today": today}
            )
            models_today = result.scalar()

            return {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "summary": {
                    "total_models": total_models,
                    "active_models": active_models,
                    "production_models": production_models,
                    "models_created_today": models_today,
                    "average_mape_percent": round(avg_mape, 2) if avg_mape else None
                },
                "by_type": models_by_type
            }

    except Exception as e:
        logger.error(f"Failed to get model stats: {e}")
        return {
            "error": str(e),
            "timestamp": datetime.now(timezone.utc).isoformat()
        }


@router.get("/monitoring/queue")
async def get_queue_status() -> Dict[str, Any]:
    """
    Get training job queue status.
    """
    try:
        async with database_manager.get_session() as session:
            # Queued jobs
            result = await session.execute(
                text("""
                    SELECT COUNT(*) FROM training_job_queue
                    WHERE status = 'queued'
                """)
            )
            queued = result.scalar()

            # Running jobs
            result = await session.execute(
                text("""
                    SELECT COUNT(*) FROM training_job_queue
                    WHERE status = 'running'
                """)
            )
            running = result.scalar()

            # Get oldest queued job
            result = await session.execute(
                text("""
                    SELECT created_at FROM training_job_queue
                    WHERE status = 'queued'
                    ORDER BY created_at ASC
                    LIMIT 1
                """)
            )
            oldest_queued = result.scalar()

            # Calculate wait time
            if oldest_queued:
                wait_time_seconds = (datetime.now(timezone.utc) - oldest_queued).total_seconds()
            else:
                wait_time_seconds = 0

            return {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "queue": {
                    "queued": queued,
                    "running": running,
                    "oldest_wait_time_seconds": round(wait_time_seconds, 2) if oldest_queued else 0,
                    "oldest_queued_at": oldest_queued.isoformat() if oldest_queued else None
                }
            }

    except Exception as e:
        logger.error(f"Failed to get queue status: {e}")
        return {
            "error": str(e),
            "timestamp": datetime.now(timezone.utc).isoformat()
        }


@router.get("/monitoring/performance")
async def get_performance_metrics(
    tenant_id: Optional[str] = Query(None, description="Filter by tenant ID")
) -> Dict[str, Any]:
    """
    Get model performance metrics.
    """
    try:
        async with database_manager.get_session() as session:
            query_params = {}
            where_clause = ""

            if tenant_id:
                where_clause = "WHERE tenant_id = :tenant_id"
                query_params["tenant_id"] = tenant_id

            # Get performance distribution
            result = await session.execute(
                text(f"""
                    SELECT
                        COUNT(*) as total,
                        AVG(mape) as avg_mape,
                        MIN(mape) as min_mape,
                        MAX(mape) as max_mape,
                        PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY mape) as median_mape,
                        AVG(mae) as avg_mae,
                        AVG(rmse) as avg_rmse
                    FROM model_performance_metrics
                    {where_clause}
                """),
                query_params
            )
            stats = result.fetchone()

            # Get accuracy distribution (buckets)
            result = await session.execute(
                text(f"""
                    SELECT
                        CASE
                            WHEN mape <= 10 THEN 'excellent'
                            WHEN mape <= 20 THEN 'good'
                            WHEN mape <= 30 THEN 'acceptable'
                            ELSE 'poor'
                        END as accuracy_category,
                        COUNT(*) as count
                    FROM model_performance_metrics
                    {where_clause}
                    GROUP BY accuracy_category
                """),
                query_params
            )
            distribution = dict(result.fetchall())

            return {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "tenant_id": tenant_id,
                "statistics": {
                    "total_metrics": stats.total if stats else 0,
                    "avg_mape_percent": round(stats.avg_mape, 2) if stats and stats.avg_mape else None,
                    "min_mape_percent": round(stats.min_mape, 2) if stats and stats.min_mape else None,
                    "max_mape_percent": round(stats.max_mape, 2) if stats and stats.max_mape else None,
                    "median_mape_percent": round(stats.median_mape, 2) if stats and stats.median_mape else None,
                    "avg_mae": round(stats.avg_mae, 2) if stats and stats.avg_mae else None,
                    "avg_rmse": round(stats.avg_rmse, 2) if stats and stats.avg_rmse else None
                },
                "distribution": distribution
            }

    except Exception as e:
        logger.error(f"Failed to get performance metrics: {e}")
        return {
            "error": str(e),
            "timestamp": datetime.now(timezone.utc).isoformat()
        }


@router.get("/monitoring/alerts")
async def get_alerts() -> Dict[str, Any]:
    """
    Get active alerts and warnings based on system state.
    """
    alerts = []
    warnings = []

    try:
        # Check circuit breakers
        breakers = circuit_breaker_registry.get_all_states()
        for name, state in breakers.items():
            if state["state"] == "open":
                alerts.append({
                    "type": "circuit_breaker_open",
                    "severity": "high",
                    "message": f"Circuit breaker '{name}' is OPEN - service unavailable",
                    "details": state
                })
            elif state["state"] == "half_open":
                warnings.append({
                    "type": "circuit_breaker_recovering",
                    "severity": "medium",
                    "message": f"Circuit breaker '{name}' is recovering",
                    "details": state
                })

        # Check queue backlog
        async with database_manager.get_session() as session:
            result = await session.execute(
                text("SELECT COUNT(*) FROM training_job_queue WHERE status = 'queued'")
            )
            queued = result.scalar()

            if queued > 10:
                warnings.append({
                    "type": "queue_backlog",
                    "severity": "medium",
                    "message": f"Training queue has {queued} pending jobs",
                    "count": queued
                })

    except Exception as e:
        logger.error(f"Failed to generate alerts: {e}")
        alerts.append({
            "type": "monitoring_error",
            "severity": "high",
            "message": f"Failed to check system alerts: {str(e)}"
        })

    return {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "summary": {
            "total_alerts": len(alerts),
            "total_warnings": len(warnings)
        },
        "alerts": alerts,
        "warnings": warnings
    }
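A small polling loop can turn the alerts endpoint into a crude pager. A sketch, again assuming the localhost base URL used above:

# Hypothetical poller for /monitoring/alerts; URL and interval are assumptions.
import time
import httpx

def poll_alerts(base: str = "http://localhost:8000", interval_s: int = 60) -> None:
    while True:
        data = httpx.get(f"{base}/monitoring/alerts", timeout=10.0).json()
        for alert in data.get("alerts", []):
            print(f"[{alert['severity']}] {alert['message']}")
        time.sleep(interval_s)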
services/training/app/api/training_operations.py
@@ -1,21 +1,18 @@
 """
 Training Operations API - BUSINESS logic
-Handles training job execution, metrics, and WebSocket live feed
+Handles training job execution and metrics
 """
 
-from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path, WebSocket, WebSocketDisconnect
-from typing import List, Optional, Dict, Any
+from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path
+from typing import Optional, Dict, Any
 import structlog
-import asyncio
-import json
-import datetime
 from shared.auth.access_control import require_user_role, admin_role_required, analytics_tier_required
+from datetime import datetime, timezone
+import uuid
 
 from shared.routing import RouteBuilder
 from shared.monitoring.decorators import track_execution_time
 from shared.monitoring.metrics import get_metrics_collector
 from shared.database.base import create_database_manager
-from datetime import datetime, timezone
-import uuid
 
 from app.services.training_service import EnhancedTrainingService
 from app.schemas.training import (
@@ -23,15 +20,10 @@ from app.schemas.training import (
     SingleProductTrainingRequest,
     TrainingJobResponse
 )
-from app.services.messaging import (
-    publish_job_progress,
-    publish_data_validation_started,
-    publish_data_validation_completed,
-    publish_job_step_completed,
-    publish_job_completed,
-    publish_job_failed,
-    publish_job_started,
-    training_publisher
-)
+from app.services.training_events import (
+    publish_training_started,
+    publish_training_completed,
+    publish_training_failed
+)
 from app.core.config import settings
@@ -85,6 +77,14 @@ async def start_training_job(
         if metrics:
             metrics.increment_counter("enhanced_training_jobs_created_total")
 
+        # Publish training.started event immediately so WebSocket clients
+        # have initial state when they connect
+        await publish_training_started(
+            job_id=job_id,
+            tenant_id=tenant_id,
+            total_products=0  # Will be updated when actual training starts
+        )
+
         # Add enhanced background task
         background_tasks.add_task(
             execute_training_job_background,
@@ -190,12 +190,8 @@ async def execute_training_job_background(
             tenant_id=tenant_id
         )
 
-        # Publish job started event
-        await publish_job_started(job_id, tenant_id, {
-            "enhanced_features": True,
-            "repository_pattern": True,
-            "job_type": "enhanced_training"
-        })
+        # This will be published by the training service itself
+        # when it starts execution
 
         training_config = {
             "job_id": job_id,
@@ -241,16 +237,7 @@ async def execute_training_job_background(
             tenant_id=tenant_id
         )
 
-        # Publish enhanced completion event
-        await publish_job_completed(
-            job_id=job_id,
-            tenant_id=tenant_id,
-            results={
-                **result,
-                "enhanced_features": True,
-                "repository_integration": True
-            }
-        )
+        # Completion event is published by the training service
 
         logger.info("Enhanced background training job completed successfully",
                     job_id=job_id,
@@ -276,17 +263,8 @@ async def execute_training_job_background(
                 job_id=job_id,
                 status_error=str(status_error))
 
-            # Publish enhanced failure event
-            await publish_job_failed(
-                job_id=job_id,
-                tenant_id=tenant_id,
-                error=str(training_error),
-                metadata={
-                    "enhanced_features": True,
-                    "repository_pattern": True,
-                    "error_type": type(training_error).__name__
-                }
-            )
+            # Failure event is published by the training service
+            await publish_training_failed(job_id, tenant_id, str(training_error))
 
     except Exception as background_error:
         logger.error("Critical error in enhanced background training job",
@@ -370,373 +348,19 @@ async def start_single_product_training(
     )
 
 
-# ============================================
-# WebSocket Live Feed
-# ============================================
-
-class ConnectionManager:
-    """Manage WebSocket connections for training progress"""
-
-    def __init__(self):
-        self.active_connections: Dict[str, Dict[str, WebSocket]] = {}
-        # Structure: {job_id: {connection_id: websocket}}
-
-    async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
-        """Accept WebSocket connection and register it"""
-        await websocket.accept()
-
-        if job_id not in self.active_connections:
-            self.active_connections[job_id] = {}
-
-        self.active_connections[job_id][connection_id] = websocket
-        logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")
-
-    def disconnect(self, job_id: str, connection_id: str):
-        """Remove WebSocket connection"""
-        if job_id in self.active_connections:
-            self.active_connections[job_id].pop(connection_id, None)
-            if not self.active_connections[job_id]:
-                del self.active_connections[job_id]
-
-        logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")
-
-    async def send_to_job(self, job_id: str, message: dict):
-        """Send message to all connections for a specific job with better error handling"""
-        if job_id not in self.active_connections:
-            logger.debug(f"No active connections for job {job_id}")
-            return
-
-        # Send to all connections for this job
-        disconnected_connections = []
-
-        for connection_id, websocket in self.active_connections[job_id].items():
-            try:
-                await websocket.send_json(message)
-                logger.debug(f"Sent {message.get('type', 'unknown')} to connection {connection_id}")
-            except Exception as e:
-                logger.warning(f"Failed to send message to connection {connection_id}: {e}")
-                disconnected_connections.append(connection_id)
-
-        # Clean up disconnected connections
-        for connection_id in disconnected_connections:
-            self.disconnect(job_id, connection_id)
-
-        # Log successful sends
-        active_count = len(self.active_connections.get(job_id, {}))
-        if active_count > 0:
-            logger.info(f"Sent {message.get('type', 'unknown')} message to {active_count} connection(s) for job {job_id}")
-
-
-# Global connection manager
-connection_manager = ConnectionManager()
-
-
-@router.websocket(route_builder.build_nested_resource_route('jobs', 'job_id', 'live'))
-async def training_progress_websocket(
-    websocket: WebSocket,
-    tenant_id: str,
-    job_id: str
-):
-    """
-    WebSocket endpoint for real-time training progress updates
-    """
-    # Validate token from query parameters
-    token = websocket.query_params.get("token")
-    if not token:
-        logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
-        await websocket.close(code=1008, reason="Authentication token required")
-        return
-
-    # Validate the token
-    from shared.auth.jwt_handler import JWTHandler
-
-    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
-
-    try:
-        payload = jwt_handler.verify_token(token)
-        if not payload:
-            logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
-            await websocket.close(code=1008, reason="Invalid authentication token")
-            return
-
-        # Verify user has access to this tenant
-        user_id = payload.get('user_id')
-        if not user_id:
-            logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
-            await websocket.close(code=1008, reason="Invalid token payload")
-            return
-
-        logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
-
-    except Exception as e:
-        logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
-        await websocket.close(code=1008, reason="Token validation failed")
-        return
-
-    connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
-
-    await connection_manager.connect(websocket, job_id, connection_id)
-    logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
-
-    # Send immediate connection confirmation to prevent gateway timeout
-    try:
-        await websocket.send_json({
-            "type": "connected",
-            "job_id": job_id,
-            "message": "WebSocket connection established",
-            "timestamp": str(datetime.now())
-        })
-        logger.debug(f"Sent connection confirmation for job {job_id}")
-    except Exception as e:
-        logger.error(f"Failed to send connection confirmation for job {job_id}: {e}")
-
-    consumer_task = None
-    training_completed = False
-
-    try:
-        # Start RabbitMQ consumer
-        consumer_task = asyncio.create_task(
-            setup_rabbitmq_consumer_for_job(job_id, tenant_id)
-        )
-
-        last_activity = asyncio.get_event_loop().time()
-
-        while not training_completed:
-            try:
-                try:
-                    data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
-                    last_activity = asyncio.get_event_loop().time()
-
-                    # Handle different message types
-                    if data["type"] == "websocket.receive":
-                        if "text" in data:
-                            message_text = data["text"]
-                            if message_text == "ping":
-                                await websocket.send_text("pong")
-                                logger.debug(f"Text ping received from job {job_id}")
-                            elif message_text == "get_status":
-                                current_status = await get_current_job_status(job_id, tenant_id)
-                                if current_status:
-                                    await websocket.send_json({
-                                        "type": "current_status",
-                                        "job_id": job_id,
-                                        "data": current_status
-                                    })
-                            elif message_text == "close":
-                                logger.info(f"Client requested connection close for job {job_id}")
-                                break
-
-                        elif "bytes" in data:
-                            await websocket.send_text("pong")
-                            logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
-
-                    elif data["type"] == "websocket.disconnect":
-                        logger.info(f"WebSocket disconnect message received for job {job_id}")
-                        break
-
-                except asyncio.TimeoutError:
-                    current_time = asyncio.get_event_loop().time()
-
-                    if current_time - last_activity > 90:
-                        logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
-
-                        try:
-                            await websocket.send_json({
-                                "type": "heartbeat",
-                                "job_id": job_id,
-                                "timestamp": str(datetime.now()),
-                                "message": "Training service heartbeat - frontend inactive",
-                                "inactivity_seconds": int(current_time - last_activity)
-                            })
-                            last_activity = current_time
-                        except Exception as e:
-                            logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
-                            break
-                    else:
-                        logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
-                        continue
-
-            except WebSocketDisconnect:
-                logger.info(f"WebSocket client disconnected for job {job_id}")
-                break
-            except Exception as e:
-                logger.error(f"WebSocket error for job {job_id}: {e}")
-                if "Cannot call" in str(e) and "disconnect message" in str(e):
-                    logger.error(f"FastAPI WebSocket disconnect error - connection already closed")
-                    break
-                await asyncio.sleep(1)
-
-        logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")
-
-    except Exception as e:
-        logger.error(f"Critical WebSocket error for job {job_id}: {e}")
-
-    finally:
-        logger.info(f"Cleaning up WebSocket connection for job {job_id}")
-        connection_manager.disconnect(job_id, connection_id)
-
-        if consumer_task and not consumer_task.done():
-            if training_completed:
-                logger.info(f"Training completed, cancelling consumer for job {job_id}")
-                consumer_task.cancel()
-            else:
-                logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
-
-            try:
-                await consumer_task
-            except asyncio.CancelledError:
-                logger.info(f"Consumer task cancelled for job {job_id}")
-            except Exception as e:
-                logger.error(f"Consumer task error for job {job_id}: {e}")
-
-
-async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
-    """Set up RabbitMQ consumer to listen for training events for a specific job"""
-
-    logger.info(f"Setting up RabbitMQ consumer for job {job_id}")
-
-    try:
-        # Create a unique queue for this WebSocket connection
-        queue_name = f"websocket_training_{job_id}_{tenant_id}"
-
-        async def handle_training_message(message):
-            """Handle incoming RabbitMQ messages and forward to WebSocket"""
-            try:
-                # Parse the message
-                body = message.body.decode()
-                data = json.loads(body)
-
-                logger.debug(f"Received message for job {job_id}: {data.get('event_type', 'unknown')}")
-
-                # Extract event data
-                event_type = data.get("event_type", "unknown")
-                event_data = data.get("data", {})
-
-                # Only process messages for this specific job
-                message_job_id = event_data.get("job_id") if event_data else None
-                if message_job_id != job_id:
-                    logger.debug(f"Ignoring message for different job: {message_job_id}")
-                    await message.ack()
-                    return
-
-                # Transform RabbitMQ message to WebSocket message format
-                websocket_message = {
-                    "type": map_event_type_to_websocket_type(event_type),
-                    "job_id": job_id,
-                    "timestamp": data.get("timestamp"),
-                    "data": event_data
-                }
-
-                logger.info(f"Forwarding {event_type} message to WebSocket clients for job {job_id}")
-
-                # Send to all WebSocket connections for this job
-                await connection_manager.send_to_job(job_id, websocket_message)
-
-                # Check if this is a completion message
-                if event_type in ["training.completed", "training.failed"]:
-                    logger.info(f"Training completion detected for job {job_id}: {event_type}")
-
-                # Acknowledge the message
-                await message.ack()
-
-                logger.debug(f"Successfully processed {event_type} for job {job_id}")
-
-            except Exception as e:
-                logger.error(f"Error handling training message for job {job_id}: {e}")
-                import traceback
-                logger.error(f"Traceback: {traceback.format_exc()}")
-                await message.nack(requeue=False)
-
-        # Check if training_publisher is connected
-        if not training_publisher.connected:
-            logger.warning(f"Training publisher not connected for job {job_id}, attempting to connect...")
-            success = await training_publisher.connect()
-            if not success:
-                logger.error(f"Failed to connect training_publisher for job {job_id}")
-                return
-
-        # Subscribe to training events
-        logger.info(f"Subscribing to training events for job {job_id}")
-        success = await training_publisher.consume_events(
-            exchange_name="training.events",
-            queue_name=queue_name,
-            routing_key="training.*",
-            callback=handle_training_message
-        )
-
-        if success:
-            logger.info(f"Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})")
-
-            # Keep the consumer running indefinitely until cancelled
-            try:
-                while True:
-                    await asyncio.sleep(10)
-                    logger.debug(f"Consumer heartbeat for job {job_id}")
-
-            except asyncio.CancelledError:
-                logger.info(f"Consumer cancelled for job {job_id}")
-                raise
-            except Exception as e:
-                logger.error(f"Consumer error for job {job_id}: {e}")
-                raise
-        else:
-            logger.error(f"Failed to set up RabbitMQ consumer for job {job_id}")

-    except Exception as e:
-        logger.error(f"Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}")
-        import traceback
-        logger.error(f"Traceback: {traceback.format_exc()}")
-
-
-def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
-    """Map RabbitMQ event types to WebSocket message types"""
-    mapping = {
-        "training.started": "started",
-        "training.progress": "progress",
-        "training.completed": "completed",
-        "training.failed": "failed",
-        "training.cancelled": "cancelled",
-        "training.step.completed": "step_completed",
-        "training.product.started": "product_started",
-        "training.product.completed": "product_completed",
-        "training.product.failed": "product_failed",
-        "training.model.trained": "model_trained",
-        "training.data.validation.started": "validation_started",
-        "training.data.validation.completed": "validation_completed"
-    }
-
-    return mapping.get(rabbitmq_event_type, "unknown")
-
-
-async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
-    """Get current job status from database"""
-    try:
-        return {
-            "job_id": job_id,
-            "status": "running",
-            "progress": 0,
-            "current_step": "Starting...",
-            "started_at": "2025-07-30T19:00:00Z"
-        }
-    except Exception as e:
-        logger.error(f"Failed to get current job status: {e}")
-        return None
-
-
 @router.get("/health")
 async def health_check():
     """Health check endpoint for the training operations"""
     return {
         "status": "healthy",
         "service": "training-operations",
-        "version": "2.0.0",
+        "version": "3.0.0",
         "features": [
             "repository-pattern",
             "dependency-injection",
             "enhanced-error-handling",
             "metrics-tracking",
-            "transactional-operations",
-            "websocket-support"
+            "transactional-operations"
         ],
         "timestamp": datetime.now().isoformat()
     }
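The inline ConnectionManager and per-connection RabbitMQ consumer removed above are superseded by a shared manager that the new websocket_operations.py (below) imports as app.websocket.manager.websocket_manager. That module is not part of this diff; a minimal sketch consistent with the connect/disconnect calls used below (the broadcast method and its use by a RabbitMQ consumer are assumptions) might look like:

# Hypothetical sketch of app/websocket/manager.py - not part of this commit.
from collections import defaultdict
from typing import Dict, Set

from fastapi import WebSocket


class WebSocketManager:
    """Tracks live sockets per job and fans out events received from RabbitMQ."""

    def __init__(self) -> None:
        self._connections: Dict[str, Set[WebSocket]] = defaultdict(set)

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        # Accepts here, since the endpoint below never calls accept() itself.
        await websocket.accept()
        self._connections[job_id].add(websocket)

    async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        self._connections[job_id].discard(websocket)
        if not self._connections[job_id]:
            del self._connections[job_id]

    async def broadcast(self, job_id: str, message: dict) -> None:
        # Assumed entry point for the RabbitMQ consumer; drops dead sockets.
        dead = []
        for ws in self._connections.get(job_id, set()):
            try:
                await ws.send_json(message)
            except Exception:
                dead.append(ws)
        for ws in dead:
            await self.disconnect(job_id, ws)


websocket_manager = WebSocketManager()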
services/training/app/api/websocket_operations.py (new file, 109 lines)
@@ -0,0 +1,109 @@
"""
WebSocket Operations for Training Service
Simple WebSocket endpoint that connects clients and receives broadcasts from RabbitMQ
"""

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query
import structlog

from app.websocket.manager import websocket_manager
from shared.auth.jwt_handler import JWTHandler
from app.core.config import settings

logger = structlog.get_logger()

router = APIRouter(tags=["websocket"])


@router.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    token: str = Query(..., description="Authentication token")
):
    """
    WebSocket endpoint for real-time training progress updates.

    This endpoint:
    1. Validates the authentication token
    2. Accepts the WebSocket connection
    3. Keeps the connection alive
    4. Receives broadcasts from RabbitMQ (via WebSocket manager)
    """

    # Validate token
    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

    try:
        payload = jwt_handler.verify_token(token)
        if not payload:
            await websocket.close(code=1008, reason="Invalid token")
            logger.warning("WebSocket connection rejected - invalid token",
                           job_id=job_id,
                           tenant_id=tenant_id)
            return

        user_id = payload.get('user_id')
        if not user_id:
            await websocket.close(code=1008, reason="Invalid token payload")
            logger.warning("WebSocket connection rejected - no user_id in token",
                           job_id=job_id,
                           tenant_id=tenant_id)
            return

        logger.info("WebSocket authentication successful",
                    user_id=user_id,
                    tenant_id=tenant_id,
                    job_id=job_id)

    except Exception as e:
        await websocket.close(code=1008, reason="Authentication failed")
        logger.warning("WebSocket authentication failed",
                       job_id=job_id,
                       tenant_id=tenant_id,
                       error=str(e))
        return

    # Connect to WebSocket manager
    await websocket_manager.connect(job_id, websocket)

    try:
        # Send connection confirmation
        await websocket.send_json({
            "type": "connected",
            "job_id": job_id,
            "message": "Connected to training progress stream"
        })

        # Keep connection alive and handle client messages
        ping_count = 0
        while True:
            try:
                # Receive messages from client (ping, etc.)
                data = await websocket.receive_text()

                # Handle ping/pong
                if data == "ping":
                    await websocket.send_text("pong")
                    ping_count += 1
                    logger.info("WebSocket ping/pong",
                                job_id=job_id,
                                ping_count=ping_count,
                                connection_healthy=True)

            except WebSocketDisconnect:
                logger.info("Client disconnected", job_id=job_id)
                break
            except Exception as e:
                logger.error("Error in WebSocket message loop",
                             job_id=job_id,
                             error=str(e))
                break

    finally:
        # Disconnect from manager
        await websocket_manager.disconnect(job_id, websocket)
        logger.info("WebSocket connection closed",
                    job_id=job_id,
                    tenant_id=tenant_id)
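From a client's perspective, the flow implied by this commit is: create the job over HTTP, then open the socket with the JWT passed as the token query parameter and read JSON events. Since training.started is now published before the background task runs, a prompt connect should find initial state waiting. A sketch using the websockets library (host, port, and the event type values are assumptions):

# Hypothetical client for the live-progress endpoint; URL and token are assumptions.
import asyncio
import json

import websockets


async def follow_job(tenant_id: str, job_id: str, token: str) -> None:
    url = (
        f"ws://localhost:8000/api/v1/tenants/{tenant_id}"
        f"/training/jobs/{job_id}/live?token={token}"
    )
    async with websockets.connect(url) as ws:
        async for raw in ws:
            if raw == "pong":
                continue
            event = json.loads(raw)
            print(event["type"], event.get("data", {}))
            if event["type"] in ("completed", "failed"):
                break


# asyncio.run(follow_job("tenant-1", "job-123", "<jwt>"))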