REFACTOR external service and improve websocket training
services/training/app/api/monitoring.py (new file, 410 lines added)
@@ -0,0 +1,410 @@
"""
Monitoring and Observability Endpoints
Real-time service monitoring and diagnostics
"""

from fastapi import APIRouter, Query
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone, timedelta
from sqlalchemy import text, func
import logging

from app.core.database import database_manager
from app.utils.circuit_breaker import circuit_breaker_registry
from app.models.training import ModelTrainingLog, TrainingJobQueue, TrainedModel

logger = logging.getLogger(__name__)
router = APIRouter()


@router.get("/monitoring/circuit-breakers")
async def get_circuit_breaker_status() -> Dict[str, Any]:
    """
    Get status of all circuit breakers.
    Useful for monitoring external service health.
    """
    breakers = circuit_breaker_registry.get_all_states()

    return {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "circuit_breakers": breakers,
        "summary": {
            "total": len(breakers),
            "open": sum(1 for b in breakers.values() if b["state"] == "open"),
            "half_open": sum(1 for b in breakers.values() if b["state"] == "half_open"),
            "closed": sum(1 for b in breakers.values() if b["state"] == "closed")
        }
    }


@router.post("/monitoring/circuit-breakers/{name}/reset")
async def reset_circuit_breaker(name: str) -> Dict[str, str]:
    """
    Manually reset a circuit breaker.
    Use with caution - only reset if you know the service has recovered.
    """
    circuit_breaker_registry.reset(name)

    return {
        "status": "success",
        "message": f"Circuit breaker '{name}' has been reset",
        "timestamp": datetime.now(timezone.utc).isoformat()
    }


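The two endpoints above depend on only two calls from `app.utils.circuit_breaker`, which is not part of this diff: `get_all_states()` and `reset(name)`. A minimal sketch of the registry interface they assume is shown below; the class name, the singleton, and any fields beyond "state" are illustrative, only the "open" / "half_open" / "closed" state strings are taken from the code itself.

from typing import Any, Dict


class CircuitBreakerRegistry:
    """Sketch of the registry interface monitoring.py relies on (assumed)."""

    def __init__(self) -> None:
        # name -> per-breaker state dict; only the "state" key is required
        # by the monitoring endpoints ("open" / "half_open" / "closed").
        self._states: Dict[str, Dict[str, Any]] = {}

    def get_all_states(self) -> Dict[str, Dict[str, Any]]:
        # Return a snapshot so callers cannot mutate internal state.
        return {name: dict(state) for name, state in self._states.items()}

    def reset(self, name: str) -> None:
        # Force a breaker back to "closed"; a real implementation would
        # also clear failure counters and recovery timers.
        if name in self._states:
            self._states[name]["state"] = "closed"


# The real module is assumed to expose a module-level singleton:
circuit_breaker_registry = CircuitBreakerRegistry()
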
@router.get("/monitoring/training-jobs")
async def get_training_job_stats(
    hours: int = Query(default=24, ge=1, le=168, description="Look back period in hours")
) -> Dict[str, Any]:
    """
    Get training job statistics for the specified period.
    """
    try:
        since = datetime.now(timezone.utc) - timedelta(hours=hours)

        async with database_manager.get_session() as session:
            # Get job counts by status
            result = await session.execute(
                text("""
                    SELECT status, COUNT(*) as count
                    FROM model_training_logs
                    WHERE created_at >= :since
                    GROUP BY status
                """),
                {"since": since}
            )
            status_counts = dict(result.fetchall())

            # Get average training time for completed jobs
            result = await session.execute(
                text("""
                    SELECT AVG(EXTRACT(EPOCH FROM (end_time - start_time))) as avg_duration
                    FROM model_training_logs
                    WHERE status = 'completed'
                    AND created_at >= :since
                    AND end_time IS NOT NULL
                """),
                {"since": since}
            )
            avg_duration = result.scalar()

            # Get failure rate
            total = sum(status_counts.values())
            failed = status_counts.get('failed', 0)
            failure_rate = (failed / total * 100) if total > 0 else 0

            # Get recent jobs
            result = await session.execute(
                text("""
                    SELECT job_id, tenant_id, status, progress, start_time, end_time
                    FROM model_training_logs
                    WHERE created_at >= :since
                    ORDER BY created_at DESC
                    LIMIT 10
                """),
                {"since": since}
            )
            recent_jobs = [
                {
                    "job_id": row.job_id,
                    "tenant_id": str(row.tenant_id),
                    "status": row.status,
                    "progress": row.progress,
                    "start_time": row.start_time.isoformat() if row.start_time else None,
                    "end_time": row.end_time.isoformat() if row.end_time else None
                }
                for row in result.fetchall()
            ]

            return {
                "period_hours": hours,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "summary": {
                    "total_jobs": total,
                    "by_status": status_counts,
                    "failure_rate_percent": round(failure_rate, 2),
                    "avg_duration_seconds": round(avg_duration, 2) if avg_duration else None
                },
                "recent_jobs": recent_jobs
            }

    except Exception as e:
        logger.error(f"Failed to get training job stats: {e}")
        return {
            "error": str(e),
            "timestamp": datetime.now(timezone.utc).isoformat()
        }


@router.get("/monitoring/models")
async def get_model_stats() -> Dict[str, Any]:
    """
    Get statistics about trained models.
    """
    try:
        async with database_manager.get_session() as session:
            # Total models
            result = await session.execute(
                text("SELECT COUNT(*) FROM trained_models")
            )
            total_models = result.scalar()

            # Active models
            result = await session.execute(
                text("SELECT COUNT(*) FROM trained_models WHERE is_active = true")
            )
            active_models = result.scalar()

            # Production models
            result = await session.execute(
                text("SELECT COUNT(*) FROM trained_models WHERE is_production = true")
            )
            production_models = result.scalar()

            # Models by type
            result = await session.execute(
                text("""
                    SELECT model_type, COUNT(*) as count
                    FROM trained_models
                    GROUP BY model_type
                """)
            )
            models_by_type = dict(result.fetchall())

            # Average model performance (MAPE)
            result = await session.execute(
                text("""
                    SELECT AVG(mape) as avg_mape
                    FROM trained_models
                    WHERE mape IS NOT NULL
                    AND is_active = true
                """)
            )
            avg_mape = result.scalar()

            # Models created today
            today = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
            result = await session.execute(
                text("""
                    SELECT COUNT(*) FROM trained_models
                    WHERE created_at >= :today
                """),
                {"today": today}
            )
            models_today = result.scalar()

            return {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "summary": {
                    "total_models": total_models,
                    "active_models": active_models,
                    "production_models": production_models,
                    "models_created_today": models_today,
                    "average_mape_percent": round(avg_mape, 2) if avg_mape else None
                },
                "by_type": models_by_type
            }

    except Exception as e:
        logger.error(f"Failed to get model stats: {e}")
        return {
            "error": str(e),
            "timestamp": datetime.now(timezone.utc).isoformat()
        }


@router.get("/monitoring/queue")
async def get_queue_status() -> Dict[str, Any]:
    """
    Get training job queue status.
    """
    try:
        async with database_manager.get_session() as session:
            # Queued jobs
            result = await session.execute(
                text("""
                    SELECT COUNT(*) FROM training_job_queue
                    WHERE status = 'queued'
                """)
            )
            queued = result.scalar()

            # Running jobs
            result = await session.execute(
                text("""
                    SELECT COUNT(*) FROM training_job_queue
                    WHERE status = 'running'
                """)
            )
            running = result.scalar()

            # Get oldest queued job
            result = await session.execute(
                text("""
                    SELECT created_at FROM training_job_queue
                    WHERE status = 'queued'
                    ORDER BY created_at ASC
                    LIMIT 1
                """)
            )
            oldest_queued = result.scalar()

            # Calculate wait time
            if oldest_queued:
                wait_time_seconds = (datetime.now(timezone.utc) - oldest_queued).total_seconds()
            else:
                wait_time_seconds = 0

            return {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "queue": {
                    "queued": queued,
                    "running": running,
                    "oldest_wait_time_seconds": round(wait_time_seconds, 2) if oldest_queued else 0,
                    "oldest_queued_at": oldest_queued.isoformat() if oldest_queued else None
                }
            }

    except Exception as e:
        logger.error(f"Failed to get queue status: {e}")
        return {
            "error": str(e),
            "timestamp": datetime.now(timezone.utc).isoformat()
        }


@router.get("/monitoring/performance")
async def get_performance_metrics(
    tenant_id: Optional[str] = Query(None, description="Filter by tenant ID")
) -> Dict[str, Any]:
    """
    Get model performance metrics.
    """
    try:
        async with database_manager.get_session() as session:
            query_params = {}
            where_clause = ""

            if tenant_id:
                where_clause = "WHERE tenant_id = :tenant_id"
                query_params["tenant_id"] = tenant_id

            # Get performance distribution
            result = await session.execute(
                text(f"""
                    SELECT
                        COUNT(*) as total,
                        AVG(mape) as avg_mape,
                        MIN(mape) as min_mape,
                        MAX(mape) as max_mape,
                        PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY mape) as median_mape,
                        AVG(mae) as avg_mae,
                        AVG(rmse) as avg_rmse
                    FROM model_performance_metrics
                    {where_clause}
                """),
                query_params
            )
            stats = result.fetchone()

            # Get accuracy distribution (buckets)
            result = await session.execute(
                text(f"""
                    SELECT
                        CASE
                            WHEN mape <= 10 THEN 'excellent'
                            WHEN mape <= 20 THEN 'good'
                            WHEN mape <= 30 THEN 'acceptable'
                            ELSE 'poor'
                        END as accuracy_category,
                        COUNT(*) as count
                    FROM model_performance_metrics
                    {where_clause}
                    GROUP BY accuracy_category
                """),
                query_params
            )
            distribution = dict(result.fetchall())

            return {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "tenant_id": tenant_id,
                "statistics": {
                    "total_metrics": stats.total if stats else 0,
                    "avg_mape_percent": round(stats.avg_mape, 2) if stats and stats.avg_mape else None,
                    "min_mape_percent": round(stats.min_mape, 2) if stats and stats.min_mape else None,
                    "max_mape_percent": round(stats.max_mape, 2) if stats and stats.max_mape else None,
                    "median_mape_percent": round(stats.median_mape, 2) if stats and stats.median_mape else None,
                    "avg_mae": round(stats.avg_mae, 2) if stats and stats.avg_mae else None,
                    "avg_rmse": round(stats.avg_rmse, 2) if stats and stats.avg_rmse else None
                },
                "distribution": distribution
            }

    except Exception as e:
        logger.error(f"Failed to get performance metrics: {e}")
        return {
            "error": str(e),
            "timestamp": datetime.now(timezone.utc).isoformat()
        }


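Dashboards consuming `/monitoring/performance` can apply the same accuracy buckets used in the SQL CASE above (MAPE <= 10 excellent, <= 20 good, <= 30 acceptable, otherwise poor) on the client side. A tiny helper mirroring those thresholds; the function name is illustrative, only the cutoffs come from the endpoint:

def mape_bucket(mape: float) -> str:
    """Mirror the accuracy buckets used by /monitoring/performance."""
    if mape <= 10:
        return "excellent"
    if mape <= 20:
        return "good"
    if mape <= 30:
        return "acceptable"
    return "poor"


# Quick sanity checks against the thresholds above.
assert mape_bucket(8.5) == "excellent"
assert mape_bucket(27.0) == "acceptable"
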
@router.get("/monitoring/alerts")
async def get_alerts() -> Dict[str, Any]:
    """
    Get active alerts and warnings based on system state.
    """
    alerts = []
    warnings = []

    try:
        # Check circuit breakers
        breakers = circuit_breaker_registry.get_all_states()
        for name, state in breakers.items():
            if state["state"] == "open":
                alerts.append({
                    "type": "circuit_breaker_open",
                    "severity": "high",
                    "message": f"Circuit breaker '{name}' is OPEN - service unavailable",
                    "details": state
                })
            elif state["state"] == "half_open":
                warnings.append({
                    "type": "circuit_breaker_recovering",
                    "severity": "medium",
                    "message": f"Circuit breaker '{name}' is recovering",
                    "details": state
                })

        # Check queue backlog
        async with database_manager.get_session() as session:
            result = await session.execute(
                text("SELECT COUNT(*) FROM training_job_queue WHERE status = 'queued'")
            )
            queued = result.scalar()

            if queued > 10:
                warnings.append({
                    "type": "queue_backlog",
                    "severity": "medium",
                    "message": f"Training queue has {queued} pending jobs",
                    "count": queued
                })

    except Exception as e:
        logger.error(f"Failed to generate alerts: {e}")
        alerts.append({
            "type": "monitoring_error",
            "severity": "high",
            "message": f"Failed to check system alerts: {str(e)}"
        })

    return {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "summary": {
            "total_alerts": len(alerts),
            "total_warnings": len(warnings)
        },
        "alerts": alerts,
        "warnings": warnings
    }
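A quick way to exercise the new endpoints once the training service is running is a small httpx script, a sketch under a few assumptions: the base URL, the use of httpx itself, and the router being mounted without an extra path prefix are all placeholders to adjust for the actual deployment.

import httpx

# Assumed local address of the training service; adjust as needed.
BASE_URL = "http://localhost:8000"

MONITORING_PATHS = [
    "/monitoring/circuit-breakers",
    "/monitoring/training-jobs?hours=24",
    "/monitoring/models",
    "/monitoring/queue",
    "/monitoring/performance",
    "/monitoring/alerts",
]


def main() -> None:
    with httpx.Client(base_url=BASE_URL, timeout=10.0) as client:
        for path in MONITORING_PATHS:
            response = client.get(path)
            response.raise_for_status()
            payload = response.json()
            # Every endpoint returns a UTC timestamp; print it as a liveness check.
            print(f"{path}: ok at {payload.get('timestamp')}")


if __name__ == "__main__":
    main()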