REFACTOR external service and improve WebSocket training
@@ -1,14 +1,16 @@
 """
 Training API Layer
-HTTP endpoints for ML training operations
+HTTP endpoints for ML training operations and WebSocket connections
 """
 
 from .training_jobs import router as training_jobs_router
 from .training_operations import router as training_operations_router
 from .models import router as models_router
+from .websocket_operations import router as websocket_operations_router
 
 __all__ = [
     "training_jobs_router",
     "training_operations_router",
-    "models_router"
+    "models_router",
+    "websocket_operations_router"
 ]
services/training/app/api/health.py (new file, 261 lines)
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
Enhanced Health Check Endpoints
|
||||
Comprehensive service health monitoring
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import text
|
||||
from typing import Dict, Any
|
||||
import psutil
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
import logging
|
||||
|
||||
from app.core.database import database_manager
|
||||
from app.utils.circuit_breaker import circuit_breaker_registry
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def check_database_health() -> Dict[str, Any]:
|
||||
"""Check database connectivity and performance"""
|
||||
try:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
async with database_manager.async_engine.begin() as conn:
|
||||
# Simple connectivity check
|
||||
await conn.execute(text("SELECT 1"))
|
||||
|
||||
# Check if we can access training tables
|
||||
result = await conn.execute(
|
||||
text("SELECT COUNT(*) FROM trained_models")
|
||||
)
|
||||
model_count = result.scalar()
|
||||
|
||||
# Check connection pool stats
|
||||
pool = database_manager.async_engine.pool
|
||||
pool_size = pool.size()
|
||||
pool_checked_out = pool.checkedout()  # Pool.checkedout() is the SQLAlchemy pool API
|
||||
|
||||
response_time = (datetime.now(timezone.utc) - start_time).total_seconds()
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"response_time_seconds": round(response_time, 3),
|
||||
"model_count": model_count,
|
||||
"connection_pool": {
|
||||
"size": pool_size,
|
||||
"checked_out": pool_checked_out,
|
||||
"available": pool_size - pool_checked_out
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {e}")
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
def check_system_resources() -> Dict[str, Any]:
|
||||
"""Check system resource usage"""
|
||||
try:
|
||||
cpu_percent = psutil.cpu_percent(interval=0.1)
|
||||
memory = psutil.virtual_memory()
|
||||
disk = psutil.disk_usage('/')
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"cpu": {
|
||||
"usage_percent": cpu_percent,
|
||||
"count": psutil.cpu_count()
|
||||
},
|
||||
"memory": {
|
||||
"total_mb": round(memory.total / 1024 / 1024, 2),
|
||||
"used_mb": round(memory.used / 1024 / 1024, 2),
|
||||
"available_mb": round(memory.available / 1024 / 1024, 2),
|
||||
"usage_percent": memory.percent
|
||||
},
|
||||
"disk": {
|
||||
"total_gb": round(disk.total / 1024 / 1024 / 1024, 2),
|
||||
"used_gb": round(disk.used / 1024 / 1024 / 1024, 2),
|
||||
"free_gb": round(disk.free / 1024 / 1024 / 1024, 2),
|
||||
"usage_percent": disk.percent
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"System resource check failed: {e}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
def check_model_storage() -> Dict[str, Any]:
|
||||
"""Check model storage health"""
|
||||
try:
|
||||
storage_path = settings.MODEL_STORAGE_PATH
|
||||
|
||||
if not os.path.exists(storage_path):
|
||||
return {
|
||||
"status": "warning",
|
||||
"message": f"Model storage path does not exist: {storage_path}"
|
||||
}
|
||||
|
||||
# Check if writable
|
||||
test_file = os.path.join(storage_path, ".health_check")
|
||||
try:
|
||||
with open(test_file, 'w') as f:
|
||||
f.write("test")
|
||||
os.remove(test_file)
|
||||
writable = True
|
||||
except Exception:
|
||||
writable = False
|
||||
|
||||
# Count model files
|
||||
model_files = 0
|
||||
total_size = 0
|
||||
for root, dirs, files in os.walk(storage_path):
|
||||
for file in files:
|
||||
if file.endswith('.pkl'):
|
||||
model_files += 1
|
||||
file_path = os.path.join(root, file)
|
||||
total_size += os.path.getsize(file_path)
|
||||
|
||||
return {
|
||||
"status": "healthy" if writable else "degraded",
|
||||
"path": storage_path,
|
||||
"writable": writable,
|
||||
"model_files": model_files,
|
||||
"total_size_mb": round(total_size / 1024 / 1024, 2)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Model storage check failed: {e}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check() -> Dict[str, Any]:
|
||||
"""
|
||||
Basic health check endpoint.
|
||||
Returns 200 if service is running.
|
||||
"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "training-service",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/health/detailed")
|
||||
async def detailed_health_check() -> Dict[str, Any]:
|
||||
"""
|
||||
Detailed health check with component status.
|
||||
Includes database, system resources, and dependencies.
|
||||
"""
|
||||
database_health = await check_database_health()
|
||||
system_health = check_system_resources()
|
||||
storage_health = check_model_storage()
|
||||
circuit_breakers = circuit_breaker_registry.get_all_states()
|
||||
|
||||
# Determine overall status
|
||||
component_statuses = [
|
||||
database_health.get("status"),
|
||||
system_health.get("status"),
|
||||
storage_health.get("status")
|
||||
]
|
||||
|
||||
if "unhealthy" in component_statuses or "error" in component_statuses:
|
||||
overall_status = "unhealthy"
|
||||
elif "degraded" in component_statuses or "warning" in component_statuses:
|
||||
overall_status = "degraded"
|
||||
else:
|
||||
overall_status = "healthy"
|
||||
|
||||
return {
|
||||
"status": overall_status,
|
||||
"service": "training-service",
|
||||
"version": "1.0.0",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"components": {
|
||||
"database": database_health,
|
||||
"system": system_health,
|
||||
"storage": storage_health
|
||||
},
|
||||
"circuit_breakers": circuit_breakers,
|
||||
"configuration": {
|
||||
"max_concurrent_jobs": settings.MAX_CONCURRENT_TRAINING_JOBS,
|
||||
"min_training_days": settings.MIN_TRAINING_DATA_DAYS,
|
||||
"pool_size": settings.DB_POOL_SIZE,
|
||||
"pool_max_overflow": settings.DB_MAX_OVERFLOW
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/health/ready")
|
||||
async def readiness_check() -> Dict[str, Any]:
|
||||
"""
|
||||
Readiness check for Kubernetes.
|
||||
Returns 200 only if service is ready to accept traffic.
|
||||
"""
|
||||
database_health = await check_database_health()
|
||||
|
||||
if database_health.get("status") != "healthy":
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Service not ready: database unavailable"
|
||||
)
|
||||
|
||||
storage_health = check_model_storage()
|
||||
if storage_health.get("status") == "error":
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Service not ready: model storage unavailable"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "ready",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/health/live")
|
||||
async def liveness_check() -> Dict[str, Any]:
|
||||
"""
|
||||
Liveness check for Kubernetes.
|
||||
Returns 200 if service process is alive.
|
||||
"""
|
||||
return {
|
||||
"status": "alive",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"pid": os.getpid()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/metrics/system")
|
||||
async def system_metrics() -> Dict[str, Any]:
|
||||
"""
|
||||
Detailed system metrics for monitoring.
|
||||
"""
|
||||
process = psutil.Process(os.getpid())
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"process": {
|
||||
"pid": os.getpid(),
|
||||
"cpu_percent": process.cpu_percent(interval=0.1),
|
||||
"memory_mb": round(process.memory_info().rss / 1024 / 1024, 2),
|
||||
"threads": process.num_threads(),
|
||||
"open_files": len(process.open_files()),
|
||||
"connections": len(process.connections())
|
||||
},
|
||||
"system": check_system_resources()
|
||||
}
|
||||
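The health.py module above exposes basic, detailed, readiness, and liveness probes. As a quick illustration (not part of the commit), a smoke test could exercise them from the outside; the base URL and the use of httpx are assumptions:

# Illustrative smoke test for the health endpoints above (not part of the commit).
import httpx

BASE_URL = "http://localhost:8000"  # assumption: service reachable locally


def check_training_service_health() -> None:
    with httpx.Client(base_url=BASE_URL, timeout=5.0) as client:
        # Liveness returns 200 as long as the process is up
        client.get("/health/live").raise_for_status()

        # Readiness returns 503 when the database or model storage is unavailable
        ready = client.get("/health/ready")
        if ready.status_code == 503:
            print("Service not ready:", ready.json().get("detail"))
            return

        # The detailed check aggregates component statuses into an overall status
        detailed = client.get("/health/detailed").json()
        print("Overall status:", detailed["status"])
        for name, component in detailed["components"].items():
            print(f"  {name}: {component.get('status')}")


if __name__ == "__main__":
    check_training_service_health()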
@@ -10,14 +10,12 @@ from sqlalchemy import text
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.schemas.training import TrainedModelResponse, ModelMetricsResponse
|
||||
from app.services.training_service import TrainingService
|
||||
from app.services.training_service import EnhancedTrainingService
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, delete, func
|
||||
import uuid
|
||||
import shutil
|
||||
|
||||
from app.services.messaging import publish_models_deleted_event
|
||||
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
require_admin_role
|
||||
@@ -38,7 +36,7 @@ route_builder = RouteBuilder('training')
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
training_service = TrainingService()
|
||||
training_service = EnhancedTrainingService()
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("models") + "/{inventory_product_id}/active"
|
||||
@@ -472,12 +470,7 @@ async def delete_tenant_models_complete(
|
||||
deletion_stats["errors"].append(error_msg)
|
||||
logger.warning(error_msg)
|
||||
|
||||
# Step 5: Publish deletion event
|
||||
try:
|
||||
await publish_models_deleted_event(tenant_id, deletion_stats)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish models deletion event", error=str(e))
|
||||
|
||||
# Models deleted successfully
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"All training data for tenant {tenant_id} deleted successfully",
|
||||
|
||||
services/training/app/api/monitoring.py (new file, 410 lines)
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Monitoring and Observability Endpoints
|
||||
Real-time service monitoring and diagnostics
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from sqlalchemy import text, func
|
||||
import logging
|
||||
|
||||
from app.core.database import database_manager
|
||||
from app.utils.circuit_breaker import circuit_breaker_registry
|
||||
from app.models.training import ModelTrainingLog, TrainingJobQueue, TrainedModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/monitoring/circuit-breakers")
|
||||
async def get_circuit_breaker_status() -> Dict[str, Any]:
|
||||
"""
|
||||
Get status of all circuit breakers.
|
||||
Useful for monitoring external service health.
|
||||
"""
|
||||
breakers = circuit_breaker_registry.get_all_states()
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"circuit_breakers": breakers,
|
||||
"summary": {
|
||||
"total": len(breakers),
|
||||
"open": sum(1 for b in breakers.values() if b["state"] == "open"),
|
||||
"half_open": sum(1 for b in breakers.values() if b["state"] == "half_open"),
|
||||
"closed": sum(1 for b in breakers.values() if b["state"] == "closed")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/monitoring/circuit-breakers/{name}/reset")
|
||||
async def reset_circuit_breaker(name: str) -> Dict[str, str]:
|
||||
"""
|
||||
Manually reset a circuit breaker.
|
||||
Use with caution - only reset if you know the service has recovered.
|
||||
"""
|
||||
circuit_breaker_registry.reset(name)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Circuit breaker '{name}' has been reset",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
|
||||
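The two circuit-breaker endpoints above only rely on get_all_states() and reset(name); the registry itself lives in app.utils.circuit_breaker and is not shown in this diff. A minimal sketch of the interface they appear to assume (illustrative names and fields, not the real implementation):

# Illustrative sketch of the registry interface used above (not the real module).
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class _BreakerState:
    state: str = "closed"            # "closed", "open", or "half_open"
    failure_count: int = 0
    last_failure: Optional[str] = None


class CircuitBreakerRegistry:
    """Tracks named circuit breakers so the monitoring endpoints can report on them."""

    def __init__(self) -> None:
        self._breakers: Dict[str, _BreakerState] = {}

    def get_all_states(self) -> Dict[str, Dict[str, Any]]:
        # Shape consumed by /monitoring/circuit-breakers: {name: {"state": ..., ...}}
        return {
            name: {
                "state": b.state,
                "failure_count": b.failure_count,
                "last_failure": b.last_failure,
            }
            for name, b in self._breakers.items()
        }

    def reset(self, name: str) -> None:
        # Used by the manual reset endpoint; re-closes the breaker
        self._breakers[name] = _BreakerState()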
@router.get("/monitoring/training-jobs")
|
||||
async def get_training_job_stats(
|
||||
hours: int = Query(default=24, ge=1, le=168, description="Look back period in hours")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get training job statistics for the specified period.
|
||||
"""
|
||||
try:
|
||||
since = datetime.now(timezone.utc) - timedelta(hours=hours)
|
||||
|
||||
async with database_manager.get_session() as session:
|
||||
# Get job counts by status
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT status, COUNT(*) as count
|
||||
FROM model_training_logs
|
||||
WHERE created_at >= :since
|
||||
GROUP BY status
|
||||
"""),
|
||||
{"since": since}
|
||||
)
|
||||
status_counts = dict(result.fetchall())
|
||||
|
||||
# Get average training time for completed jobs
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT AVG(EXTRACT(EPOCH FROM (end_time - start_time))) as avg_duration
|
||||
FROM model_training_logs
|
||||
WHERE status = 'completed'
|
||||
AND created_at >= :since
|
||||
AND end_time IS NOT NULL
|
||||
"""),
|
||||
{"since": since}
|
||||
)
|
||||
avg_duration = result.scalar()
|
||||
|
||||
# Get failure rate
|
||||
total = sum(status_counts.values())
|
||||
failed = status_counts.get('failed', 0)
|
||||
failure_rate = (failed / total * 100) if total > 0 else 0
|
||||
|
||||
# Get recent jobs
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT job_id, tenant_id, status, progress, start_time, end_time
|
||||
FROM model_training_logs
|
||||
WHERE created_at >= :since
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 10
|
||||
"""),
|
||||
{"since": since}
|
||||
)
|
||||
recent_jobs = [
|
||||
{
|
||||
"job_id": row.job_id,
|
||||
"tenant_id": str(row.tenant_id),
|
||||
"status": row.status,
|
||||
"progress": row.progress,
|
||||
"start_time": row.start_time.isoformat() if row.start_time else None,
|
||||
"end_time": row.end_time.isoformat() if row.end_time else None
|
||||
}
|
||||
for row in result.fetchall()
|
||||
]
|
||||
|
||||
return {
|
||||
"period_hours": hours,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"summary": {
|
||||
"total_jobs": total,
|
||||
"by_status": status_counts,
|
||||
"failure_rate_percent": round(failure_rate, 2),
|
||||
"avg_duration_seconds": round(avg_duration, 2) if avg_duration else None
|
||||
},
|
||||
"recent_jobs": recent_jobs
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get training job stats: {e}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/monitoring/models")
|
||||
async def get_model_stats() -> Dict[str, Any]:
|
||||
"""
|
||||
Get statistics about trained models.
|
||||
"""
|
||||
try:
|
||||
async with database_manager.get_session() as session:
|
||||
# Total models
|
||||
result = await session.execute(
|
||||
text("SELECT COUNT(*) FROM trained_models")
|
||||
)
|
||||
total_models = result.scalar()
|
||||
|
||||
# Active models
|
||||
result = await session.execute(
|
||||
text("SELECT COUNT(*) FROM trained_models WHERE is_active = true")
|
||||
)
|
||||
active_models = result.scalar()
|
||||
|
||||
# Production models
|
||||
result = await session.execute(
|
||||
text("SELECT COUNT(*) FROM trained_models WHERE is_production = true")
|
||||
)
|
||||
production_models = result.scalar()
|
||||
|
||||
# Models by type
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT model_type, COUNT(*) as count
|
||||
FROM trained_models
|
||||
GROUP BY model_type
|
||||
""")
|
||||
)
|
||||
models_by_type = dict(result.fetchall())
|
||||
|
||||
# Average model performance (MAPE)
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT AVG(mape) as avg_mape
|
||||
FROM trained_models
|
||||
WHERE mape IS NOT NULL
|
||||
AND is_active = true
|
||||
""")
|
||||
)
|
||||
avg_mape = result.scalar()
|
||||
|
||||
# Models created today
|
||||
today = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT COUNT(*) FROM trained_models
|
||||
WHERE created_at >= :today
|
||||
"""),
|
||||
{"today": today}
|
||||
)
|
||||
models_today = result.scalar()
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"summary": {
|
||||
"total_models": total_models,
|
||||
"active_models": active_models,
|
||||
"production_models": production_models,
|
||||
"models_created_today": models_today,
|
||||
"average_mape_percent": round(avg_mape, 2) if avg_mape else None
|
||||
},
|
||||
"by_type": models_by_type
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get model stats: {e}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/monitoring/queue")
|
||||
async def get_queue_status() -> Dict[str, Any]:
|
||||
"""
|
||||
Get training job queue status.
|
||||
"""
|
||||
try:
|
||||
async with database_manager.get_session() as session:
|
||||
# Queued jobs
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT COUNT(*) FROM training_job_queue
|
||||
WHERE status = 'queued'
|
||||
""")
|
||||
)
|
||||
queued = result.scalar()
|
||||
|
||||
# Running jobs
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT COUNT(*) FROM training_job_queue
|
||||
WHERE status = 'running'
|
||||
""")
|
||||
)
|
||||
running = result.scalar()
|
||||
|
||||
# Get oldest queued job
|
||||
result = await session.execute(
|
||||
text("""
|
||||
SELECT created_at FROM training_job_queue
|
||||
WHERE status = 'queued'
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
""")
|
||||
)
|
||||
oldest_queued = result.scalar()
|
||||
|
||||
# Calculate wait time
|
||||
if oldest_queued:
|
||||
wait_time_seconds = (datetime.now(timezone.utc) - oldest_queued).total_seconds()
|
||||
else:
|
||||
wait_time_seconds = 0
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"queue": {
|
||||
"queued": queued,
|
||||
"running": running,
|
||||
"oldest_wait_time_seconds": round(wait_time_seconds, 2) if oldest_queued else 0,
|
||||
"oldest_queued_at": oldest_queued.isoformat() if oldest_queued else None
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get queue status: {e}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/monitoring/performance")
|
||||
async def get_performance_metrics(
|
||||
tenant_id: Optional[str] = Query(None, description="Filter by tenant ID")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get model performance metrics.
|
||||
"""
|
||||
try:
|
||||
async with database_manager.get_session() as session:
|
||||
query_params = {}
|
||||
where_clause = ""
|
||||
|
||||
if tenant_id:
|
||||
where_clause = "WHERE tenant_id = :tenant_id"
|
||||
query_params["tenant_id"] = tenant_id
|
||||
|
||||
# Get performance distribution
|
||||
result = await session.execute(
|
||||
text(f"""
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
AVG(mape) as avg_mape,
|
||||
MIN(mape) as min_mape,
|
||||
MAX(mape) as max_mape,
|
||||
PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY mape) as median_mape,
|
||||
AVG(mae) as avg_mae,
|
||||
AVG(rmse) as avg_rmse
|
||||
FROM model_performance_metrics
|
||||
{where_clause}
|
||||
"""),
|
||||
query_params
|
||||
)
|
||||
stats = result.fetchone()
|
||||
|
||||
# Get accuracy distribution (buckets)
|
||||
result = await session.execute(
|
||||
text(f"""
|
||||
SELECT
|
||||
CASE
|
||||
WHEN mape <= 10 THEN 'excellent'
|
||||
WHEN mape <= 20 THEN 'good'
|
||||
WHEN mape <= 30 THEN 'acceptable'
|
||||
ELSE 'poor'
|
||||
END as accuracy_category,
|
||||
COUNT(*) as count
|
||||
FROM model_performance_metrics
|
||||
{where_clause}
|
||||
GROUP BY accuracy_category
|
||||
"""),
|
||||
query_params
|
||||
)
|
||||
distribution = dict(result.fetchall())
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"tenant_id": tenant_id,
|
||||
"statistics": {
|
||||
"total_metrics": stats.total if stats else 0,
|
||||
"avg_mape_percent": round(stats.avg_mape, 2) if stats and stats.avg_mape else None,
|
||||
"min_mape_percent": round(stats.min_mape, 2) if stats and stats.min_mape else None,
|
||||
"max_mape_percent": round(stats.max_mape, 2) if stats and stats.max_mape else None,
|
||||
"median_mape_percent": round(stats.median_mape, 2) if stats and stats.median_mape else None,
|
||||
"avg_mae": round(stats.avg_mae, 2) if stats and stats.avg_mae else None,
|
||||
"avg_rmse": round(stats.avg_rmse, 2) if stats and stats.avg_rmse else None
|
||||
},
|
||||
"distribution": distribution
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get performance metrics: {e}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/monitoring/alerts")
|
||||
async def get_alerts() -> Dict[str, Any]:
|
||||
"""
|
||||
Get active alerts and warnings based on system state.
|
||||
"""
|
||||
alerts = []
|
||||
warnings = []
|
||||
|
||||
try:
|
||||
# Check circuit breakers
|
||||
breakers = circuit_breaker_registry.get_all_states()
|
||||
for name, state in breakers.items():
|
||||
if state["state"] == "open":
|
||||
alerts.append({
|
||||
"type": "circuit_breaker_open",
|
||||
"severity": "high",
|
||||
"message": f"Circuit breaker '{name}' is OPEN - service unavailable",
|
||||
"details": state
|
||||
})
|
||||
elif state["state"] == "half_open":
|
||||
warnings.append({
|
||||
"type": "circuit_breaker_recovering",
|
||||
"severity": "medium",
|
||||
"message": f"Circuit breaker '{name}' is recovering",
|
||||
"details": state
|
||||
})
|
||||
|
||||
# Check queue backlog
|
||||
async with database_manager.get_session() as session:
|
||||
result = await session.execute(
|
||||
text("SELECT COUNT(*) FROM training_job_queue WHERE status = 'queued'")
|
||||
)
|
||||
queued = result.scalar()
|
||||
|
||||
if queued > 10:
|
||||
warnings.append({
|
||||
"type": "queue_backlog",
|
||||
"severity": "medium",
|
||||
"message": f"Training queue has {queued} pending jobs",
|
||||
"count": queued
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate alerts: {e}")
|
||||
alerts.append({
|
||||
"type": "monitoring_error",
|
||||
"severity": "high",
|
||||
"message": f"Failed to check system alerts: {str(e)}"
|
||||
})
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"summary": {
|
||||
"total_alerts": len(alerts),
|
||||
"total_warnings": len(warnings)
|
||||
},
|
||||
"alerts": alerts,
|
||||
"warnings": warnings
|
||||
}
|
||||
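Taken together, the monitoring endpoints above form a small operational API. A hypothetical polling script (not part of the commit) that watches /monitoring/alerts; the base URL and polling interval are assumptions:

# Hypothetical alert poller for the monitoring API above (not part of the commit).
import time

import httpx


def watch_training_alerts(base_url: str = "http://localhost:8000", interval: int = 60) -> None:
    """Print every alert and warning reported by /monitoring/alerts."""
    while True:
        payload = httpx.get(f"{base_url}/monitoring/alerts", timeout=10.0).json()
        for alert in payload.get("alerts", []):
            print(f"[ALERT] {alert['severity']}: {alert['message']}")
        for warning in payload.get("warnings", []):
            print(f"[warn]  {warning['severity']}: {warning['message']}")
        time.sleep(interval)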
@@ -1,21 +1,18 @@
|
||||
"""
|
||||
Training Operations API - BUSINESS logic
|
||||
Handles training job execution, metrics, and WebSocket live feed
|
||||
Handles training job execution and metrics
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path, WebSocket, WebSocketDisconnect
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path
|
||||
from typing import Optional, Dict, Any
|
||||
import structlog
|
||||
import asyncio
|
||||
import json
|
||||
import datetime
|
||||
from shared.auth.access_control import require_user_role, admin_role_required, analytics_tier_required
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.routing import RouteBuilder
|
||||
from shared.monitoring.decorators import track_execution_time
|
||||
from shared.monitoring.metrics import get_metrics_collector
|
||||
from shared.database.base import create_database_manager
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from app.services.training_service import EnhancedTrainingService
|
||||
from app.schemas.training import (
|
||||
@@ -23,15 +20,10 @@ from app.schemas.training import (
|
||||
SingleProductTrainingRequest,
|
||||
TrainingJobResponse
|
||||
)
|
||||
from app.services.messaging import (
|
||||
publish_job_progress,
|
||||
publish_data_validation_started,
|
||||
publish_data_validation_completed,
|
||||
publish_job_step_completed,
|
||||
publish_job_completed,
|
||||
publish_job_failed,
|
||||
publish_job_started,
|
||||
training_publisher
|
||||
from app.services.training_events import (
|
||||
publish_training_started,
|
||||
publish_training_completed,
|
||||
publish_training_failed
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -85,6 +77,14 @@ async def start_training_job(
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_training_jobs_created_total")
|
||||
|
||||
# Publish training.started event immediately so WebSocket clients
|
||||
# have initial state when they connect
|
||||
await publish_training_started(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
total_products=0 # Will be updated when actual training starts
|
||||
)
|
||||
|
||||
# Add enhanced background task
|
||||
background_tasks.add_task(
|
||||
execute_training_job_background,
|
||||
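The hunk above publishes a training.started event as soon as the job is accepted, so a WebSocket client that connects right away already has an initial state. publish_training_started comes from app.services.training_events, which is not shown in this diff; the sketch below only illustrates the event shape implied by the consumer code elsewhere in this commit, and the publisher object and its call signature are assumptions:

# Illustrative sketch only - the real publisher lives in app.services.training_events.
import json
from datetime import datetime, timezone

# 'training_publisher' is assumed to be a module-level RabbitMQ publisher instance.


async def publish_training_started(job_id: str, tenant_id: str, total_products: int) -> None:
    """Emit a training.started event for WebSocket consumers to pick up."""
    event = {
        "event_type": "training.started",          # routing key pattern used elsewhere: training.*
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "data": {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "total_products": total_products,       # 0 here; updated when training actually starts
        },
    }
    # Assumed publisher API: publish(exchange, routing_key, body) on the training.events exchange
    await training_publisher.publish("training.events", "training.started", json.dumps(event))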
@@ -190,12 +190,8 @@ async def execute_training_job_background(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Publish job started event
|
||||
await publish_job_started(job_id, tenant_id, {
|
||||
"enhanced_features": True,
|
||||
"repository_pattern": True,
|
||||
"job_type": "enhanced_training"
|
||||
})
|
||||
# This will be published by the training service itself
|
||||
# when it starts execution
|
||||
|
||||
training_config = {
|
||||
"job_id": job_id,
|
||||
@@ -241,16 +237,7 @@ async def execute_training_job_background(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Publish enhanced completion event
|
||||
await publish_job_completed(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
results={
|
||||
**result,
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
)
|
||||
# Completion event is published by the training service
|
||||
|
||||
logger.info("Enhanced background training job completed successfully",
|
||||
job_id=job_id,
|
||||
@@ -276,17 +263,8 @@ async def execute_training_job_background(
|
||||
job_id=job_id,
|
||||
status_error=str(status_error))
|
||||
|
||||
# Publish enhanced failure event
|
||||
await publish_job_failed(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
error=str(training_error),
|
||||
metadata={
|
||||
"enhanced_features": True,
|
||||
"repository_pattern": True,
|
||||
"error_type": type(training_error).__name__
|
||||
}
|
||||
)
|
||||
# Failure event is published by the training service
|
||||
await publish_training_failed(job_id, tenant_id, str(training_error))
|
||||
|
||||
except Exception as background_error:
|
||||
logger.error("Critical error in enhanced background training job",
|
||||
@@ -370,373 +348,19 @@ async def start_single_product_training(
|
||||
)
|
||||
|
||||
|
||||
# ============================================
|
||||
# WebSocket Live Feed
|
||||
# ============================================
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manage WebSocket connections for training progress"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, Dict[str, WebSocket]] = {}
|
||||
# Structure: {job_id: {connection_id: websocket}}
|
||||
|
||||
async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
|
||||
"""Accept WebSocket connection and register it"""
|
||||
await websocket.accept()
|
||||
|
||||
if job_id not in self.active_connections:
|
||||
self.active_connections[job_id] = {}
|
||||
|
||||
self.active_connections[job_id][connection_id] = websocket
|
||||
logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")
|
||||
|
||||
def disconnect(self, job_id: str, connection_id: str):
|
||||
"""Remove WebSocket connection"""
|
||||
if job_id in self.active_connections:
|
||||
self.active_connections[job_id].pop(connection_id, None)
|
||||
if not self.active_connections[job_id]:
|
||||
del self.active_connections[job_id]
|
||||
|
||||
logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")
|
||||
|
||||
async def send_to_job(self, job_id: str, message: dict):
|
||||
"""Send message to all connections for a specific job with better error handling"""
|
||||
if job_id not in self.active_connections:
|
||||
logger.debug(f"No active connections for job {job_id}")
|
||||
return
|
||||
|
||||
# Send to all connections for this job
|
||||
disconnected_connections = []
|
||||
|
||||
for connection_id, websocket in self.active_connections[job_id].items():
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
logger.debug(f"Sent {message.get('type', 'unknown')} to connection {connection_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send message to connection {connection_id}: {e}")
|
||||
disconnected_connections.append(connection_id)
|
||||
|
||||
# Clean up disconnected connections
|
||||
for connection_id in disconnected_connections:
|
||||
self.disconnect(job_id, connection_id)
|
||||
|
||||
# Log successful sends
|
||||
active_count = len(self.active_connections.get(job_id, {}))
|
||||
if active_count > 0:
|
||||
logger.info(f"Sent {message.get('type', 'unknown')} message to {active_count} connection(s) for job {job_id}")
|
||||
|
||||
|
||||
# Global connection manager
|
||||
connection_manager = ConnectionManager()
|
||||
|
||||
|
||||
@router.websocket(route_builder.build_nested_resource_route('jobs', 'job_id', 'live'))
|
||||
async def training_progress_websocket(
|
||||
websocket: WebSocket,
|
||||
tenant_id: str,
|
||||
job_id: str
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for real-time training progress updates
|
||||
"""
|
||||
# Validate token from query parameters
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
|
||||
await websocket.close(code=1008, reason="Authentication token required")
|
||||
return
|
||||
|
||||
# Validate the token
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if not payload:
|
||||
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
|
||||
await websocket.close(code=1008, reason="Invalid authentication token")
|
||||
return
|
||||
|
||||
# Verify user has access to this tenant
|
||||
user_id = payload.get('user_id')
|
||||
if not user_id:
|
||||
logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
|
||||
await websocket.close(code=1008, reason="Invalid token payload")
|
||||
return
|
||||
|
||||
logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
|
||||
await websocket.close(code=1008, reason="Token validation failed")
|
||||
return
|
||||
|
||||
connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
|
||||
|
||||
await connection_manager.connect(websocket, job_id, connection_id)
|
||||
logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
|
||||
|
||||
# Send immediate connection confirmation to prevent gateway timeout
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "connected",
|
||||
"job_id": job_id,
|
||||
"message": "WebSocket connection established",
|
||||
"timestamp": str(datetime.now())
|
||||
})
|
||||
logger.debug(f"Sent connection confirmation for job {job_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send connection confirmation for job {job_id}: {e}")
|
||||
|
||||
consumer_task = None
|
||||
training_completed = False
|
||||
|
||||
try:
|
||||
# Start RabbitMQ consumer
|
||||
consumer_task = asyncio.create_task(
|
||||
setup_rabbitmq_consumer_for_job(job_id, tenant_id)
|
||||
)
|
||||
|
||||
last_activity = asyncio.get_event_loop().time()
|
||||
|
||||
while not training_completed:
|
||||
try:
|
||||
try:
|
||||
data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
|
||||
last_activity = asyncio.get_event_loop().time()
|
||||
|
||||
# Handle different message types
|
||||
if data["type"] == "websocket.receive":
|
||||
if "text" in data:
|
||||
message_text = data["text"]
|
||||
if message_text == "ping":
|
||||
await websocket.send_text("pong")
|
||||
logger.debug(f"Text ping received from job {job_id}")
|
||||
elif message_text == "get_status":
|
||||
current_status = await get_current_job_status(job_id, tenant_id)
|
||||
if current_status:
|
||||
await websocket.send_json({
|
||||
"type": "current_status",
|
||||
"job_id": job_id,
|
||||
"data": current_status
|
||||
})
|
||||
elif message_text == "close":
|
||||
logger.info(f"Client requested connection close for job {job_id}")
|
||||
break
|
||||
|
||||
elif "bytes" in data:
|
||||
await websocket.send_text("pong")
|
||||
logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
|
||||
|
||||
elif data["type"] == "websocket.disconnect":
|
||||
logger.info(f"WebSocket disconnect message received for job {job_id}")
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
|
||||
if current_time - last_activity > 90:
|
||||
logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
|
||||
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "heartbeat",
|
||||
"job_id": job_id,
|
||||
"timestamp": str(datetime.now()),
|
||||
"message": "Training service heartbeat - frontend inactive",
|
||||
"inactivity_seconds": int(current_time - last_activity)
|
||||
})
|
||||
last_activity = current_time
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
|
||||
break
|
||||
else:
|
||||
logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
|
||||
continue
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket client disconnected for job {job_id}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error for job {job_id}: {e}")
|
||||
if "Cannot call" in str(e) and "disconnect message" in str(e):
|
||||
logger.error(f"FastAPI WebSocket disconnect error - connection already closed")
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Critical WebSocket error for job {job_id}: {e}")
|
||||
|
||||
finally:
|
||||
logger.info(f"Cleaning up WebSocket connection for job {job_id}")
|
||||
connection_manager.disconnect(job_id, connection_id)
|
||||
|
||||
if consumer_task and not consumer_task.done():
|
||||
if training_completed:
|
||||
logger.info(f"Training completed, cancelling consumer for job {job_id}")
|
||||
consumer_task.cancel()
|
||||
else:
|
||||
logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
|
||||
|
||||
try:
|
||||
await consumer_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Consumer task cancelled for job {job_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Consumer task error for job {job_id}: {e}")
|
||||
|
||||
|
||||
async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
|
||||
"""Set up RabbitMQ consumer to listen for training events for a specific job"""
|
||||
|
||||
logger.info(f"Setting up RabbitMQ consumer for job {job_id}")
|
||||
|
||||
try:
|
||||
# Create a unique queue for this WebSocket connection
|
||||
queue_name = f"websocket_training_{job_id}_{tenant_id}"
|
||||
|
||||
async def handle_training_message(message):
|
||||
"""Handle incoming RabbitMQ messages and forward to WebSocket"""
|
||||
try:
|
||||
# Parse the message
|
||||
body = message.body.decode()
|
||||
data = json.loads(body)
|
||||
|
||||
logger.debug(f"Received message for job {job_id}: {data.get('event_type', 'unknown')}")
|
||||
|
||||
# Extract event data
|
||||
event_type = data.get("event_type", "unknown")
|
||||
event_data = data.get("data", {})
|
||||
|
||||
# Only process messages for this specific job
|
||||
message_job_id = event_data.get("job_id") if event_data else None
|
||||
if message_job_id != job_id:
|
||||
logger.debug(f"Ignoring message for different job: {message_job_id}")
|
||||
await message.ack()
|
||||
return
|
||||
|
||||
# Transform RabbitMQ message to WebSocket message format
|
||||
websocket_message = {
|
||||
"type": map_event_type_to_websocket_type(event_type),
|
||||
"job_id": job_id,
|
||||
"timestamp": data.get("timestamp"),
|
||||
"data": event_data
|
||||
}
|
||||
|
||||
logger.info(f"Forwarding {event_type} message to WebSocket clients for job {job_id}")
|
||||
|
||||
# Send to all WebSocket connections for this job
|
||||
await connection_manager.send_to_job(job_id, websocket_message)
|
||||
|
||||
# Check if this is a completion message
|
||||
if event_type in ["training.completed", "training.failed"]:
|
||||
logger.info(f"Training completion detected for job {job_id}: {event_type}")
|
||||
|
||||
# Acknowledge the message
|
||||
await message.ack()
|
||||
|
||||
logger.debug(f"Successfully processed {event_type} for job {job_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling training message for job {job_id}: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
await message.nack(requeue=False)
|
||||
|
||||
# Check if training_publisher is connected
|
||||
if not training_publisher.connected:
|
||||
logger.warning(f"Training publisher not connected for job {job_id}, attempting to connect...")
|
||||
success = await training_publisher.connect()
|
||||
if not success:
|
||||
logger.error(f"Failed to connect training_publisher for job {job_id}")
|
||||
return
|
||||
|
||||
# Subscribe to training events
|
||||
logger.info(f"Subscribing to training events for job {job_id}")
|
||||
success = await training_publisher.consume_events(
|
||||
exchange_name="training.events",
|
||||
queue_name=queue_name,
|
||||
routing_key="training.*",
|
||||
callback=handle_training_message
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})")
|
||||
|
||||
# Keep the consumer running indefinitely until cancelled
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
logger.debug(f"Consumer heartbeat for job {job_id}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Consumer cancelled for job {job_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Consumer error for job {job_id}: {e}")
|
||||
raise
|
||||
else:
|
||||
logger.error(f"Failed to set up RabbitMQ consumer for job {job_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
|
||||
def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
|
||||
"""Map RabbitMQ event types to WebSocket message types"""
|
||||
mapping = {
|
||||
"training.started": "started",
|
||||
"training.progress": "progress",
|
||||
"training.completed": "completed",
|
||||
"training.failed": "failed",
|
||||
"training.cancelled": "cancelled",
|
||||
"training.step.completed": "step_completed",
|
||||
"training.product.started": "product_started",
|
||||
"training.product.completed": "product_completed",
|
||||
"training.product.failed": "product_failed",
|
||||
"training.model.trained": "model_trained",
|
||||
"training.data.validation.started": "validation_started",
|
||||
"training.data.validation.completed": "validation_completed"
|
||||
}
|
||||
|
||||
return mapping.get(rabbitmq_event_type, "unknown")
|
||||
|
||||
|
||||
async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Get current job status from database"""
|
||||
try:
|
||||
return {
|
||||
"job_id": job_id,
|
||||
"status": "running",
|
||||
"progress": 0,
|
||||
"current_step": "Starting...",
|
||||
"started_at": "2025-07-30T19:00:00Z"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get current job status: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint for the training operations"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "training-operations",
|
||||
"version": "2.0.0",
|
||||
"version": "3.0.0",
|
||||
"features": [
|
||||
"repository-pattern",
|
||||
"dependency-injection",
|
||||
"enhanced-error-handling",
|
||||
"metrics-tracking",
|
||||
"transactional-operations",
|
||||
"websocket-support"
|
||||
"transactional-operations"
|
||||
],
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
services/training/app/api/websocket_operations.py (new file, 109 lines)
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
WebSocket Operations for Training Service
|
||||
Simple WebSocket endpoint that connects clients and receives broadcasts from RabbitMQ
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query
|
||||
import structlog
|
||||
|
||||
from app.websocket.manager import websocket_manager
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter(tags=["websocket"])
|
||||
|
||||
|
||||
@router.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
|
||||
async def training_progress_websocket(
|
||||
websocket: WebSocket,
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
job_id: str = Path(..., description="Job ID"),
|
||||
token: str = Query(..., description="Authentication token")
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for real-time training progress updates.
|
||||
|
||||
This endpoint:
|
||||
1. Validates the authentication token
|
||||
2. Accepts the WebSocket connection
|
||||
3. Keeps the connection alive
|
||||
4. Receives broadcasts from RabbitMQ (via WebSocket manager)
|
||||
"""
|
||||
|
||||
# Validate token
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if not payload:
|
||||
await websocket.close(code=1008, reason="Invalid token")
|
||||
logger.warning("WebSocket connection rejected - invalid token",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
return
|
||||
|
||||
user_id = payload.get('user_id')
|
||||
if not user_id:
|
||||
await websocket.close(code=1008, reason="Invalid token payload")
|
||||
logger.warning("WebSocket connection rejected - no user_id in token",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
return
|
||||
|
||||
logger.info("WebSocket authentication successful",
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id)
|
||||
|
||||
except Exception as e:
|
||||
await websocket.close(code=1008, reason="Authentication failed")
|
||||
logger.warning("WebSocket authentication failed",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return
|
||||
|
||||
# Connect to WebSocket manager
|
||||
await websocket_manager.connect(job_id, websocket)
|
||||
|
||||
try:
|
||||
# Send connection confirmation
|
||||
await websocket.send_json({
|
||||
"type": "connected",
|
||||
"job_id": job_id,
|
||||
"message": "Connected to training progress stream"
|
||||
})
|
||||
|
||||
# Keep connection alive and handle client messages
|
||||
ping_count = 0
|
||||
while True:
|
||||
try:
|
||||
# Receive messages from client (ping, etc.)
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# Handle ping/pong
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
ping_count += 1
|
||||
logger.info("WebSocket ping/pong",
|
||||
job_id=job_id,
|
||||
ping_count=ping_count,
|
||||
connection_healthy=True)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info("Client disconnected", job_id=job_id)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error in WebSocket message loop",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
break
|
||||
|
||||
finally:
|
||||
# Disconnect from manager
|
||||
await websocket_manager.disconnect(job_id, websocket)
|
||||
logger.info("WebSocket connection closed",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
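A client of the endpoint above connects with the JWT passed as a token query parameter and can keep the connection warm with the text ping/pong the handler implements. A minimal sketch using the third-party websockets package (not part of the commit); the host, port, and message shape are assumptions based on the event types used elsewhere in this commit:

# Hypothetical client for the training progress WebSocket above (not part of the commit).
import asyncio
import json

import websockets  # assumption: the 'websockets' package is installed


async def follow_training_job(tenant_id: str, job_id: str, token: str) -> None:
    url = (
        f"ws://localhost:8000/api/v1/tenants/{tenant_id}"
        f"/training/jobs/{job_id}/live?token={token}"
    )
    async with websockets.connect(url) as ws:
        await ws.send("ping")                  # the handler answers text pings with "pong"
        async for raw in ws:
            if raw == "pong":
                continue
            message = json.loads(raw)
            print(message.get("type"), message.get("data", {}))
            if message.get("type") in ("completed", "failed"):
                break


# asyncio.run(follow_training_job("tenant-1", "job-1", "<jwt>"))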
@@ -41,25 +41,16 @@ class TrainingSettings(BaseServiceSettings):
|
||||
REDIS_DB: int = 1
|
||||
|
||||
# ML Model Storage
|
||||
MODEL_STORAGE_PATH: str = os.getenv("MODEL_STORAGE_PATH", "/app/models")
|
||||
MODEL_BACKUP_ENABLED: bool = os.getenv("MODEL_BACKUP_ENABLED", "true").lower() == "true"
|
||||
MODEL_VERSIONING_ENABLED: bool = os.getenv("MODEL_VERSIONING_ENABLED", "true").lower() == "true"
|
||||
|
||||
# Training Configuration
|
||||
MAX_TRAINING_TIME_MINUTES: int = int(os.getenv("MAX_TRAINING_TIME_MINUTES", "30"))
|
||||
MAX_CONCURRENT_TRAINING_JOBS: int = int(os.getenv("MAX_CONCURRENT_TRAINING_JOBS", "3"))
|
||||
MIN_TRAINING_DATA_DAYS: int = int(os.getenv("MIN_TRAINING_DATA_DAYS", "30"))
|
||||
TRAINING_BATCH_SIZE: int = int(os.getenv("TRAINING_BATCH_SIZE", "1000"))
|
||||
|
||||
# Prophet Specific Configuration
|
||||
PROPHET_SEASONALITY_MODE: str = os.getenv("PROPHET_SEASONALITY_MODE", "additive")
|
||||
PROPHET_CHANGEPOINT_PRIOR_SCALE: float = float(os.getenv("PROPHET_CHANGEPOINT_PRIOR_SCALE", "0.05"))
|
||||
PROPHET_SEASONALITY_PRIOR_SCALE: float = float(os.getenv("PROPHET_SEASONALITY_PRIOR_SCALE", "10.0"))
|
||||
PROPHET_HOLIDAYS_PRIOR_SCALE: float = float(os.getenv("PROPHET_HOLIDAYS_PRIOR_SCALE", "10.0"))
|
||||
|
||||
# Spanish Holiday Integration
|
||||
ENABLE_SPANISH_HOLIDAYS: bool = True
|
||||
ENABLE_MADRID_HOLIDAYS: bool = True
|
||||
ENABLE_CUSTOM_HOLIDAYS: bool = os.getenv("ENABLE_CUSTOM_HOLIDAYS", "true").lower() == "true"
|
||||
|
||||
# Data Processing
|
||||
@@ -79,6 +70,8 @@ class TrainingSettings(BaseServiceSettings):
|
||||
PROPHET_DAILY_SEASONALITY: bool = True
|
||||
PROPHET_WEEKLY_SEASONALITY: bool = True
|
||||
PROPHET_YEARLY_SEASONALITY: bool = True
|
||||
PROPHET_SEASONALITY_MODE: str = "additive"
|
||||
|
||||
settings = TrainingSettings()
|
||||
# Throttling settings for parallel training to prevent heartbeat blocking
|
||||
MAX_CONCURRENT_TRAININGS: int = int(os.getenv("MAX_CONCURRENT_TRAININGS", "3"))
|
||||
|
||||
settings = TrainingSettings()
|
||||
|
||||
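The new MAX_CONCURRENT_TRAININGS setting caps parallel per-product trainings so long-running Prophet fits do not starve heartbeats. A sketch (not from this commit) of how such a cap is typically enforced with an asyncio.Semaphore:

# Illustrative throttling sketch built on the new MAX_CONCURRENT_TRAININGS setting.
import asyncio

from app.core.config import settings

_training_slots = asyncio.Semaphore(settings.MAX_CONCURRENT_TRAININGS)


async def train_with_throttle(train_one_product, product_id: str):
    """Run one product training while respecting the global concurrency cap."""
    async with _training_slots:
        # Only MAX_CONCURRENT_TRAININGS coroutines pass this point at once, leaving
        # the event loop free to service WebSocket heartbeats and RabbitMQ acks.
        return await train_one_product(product_id)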
services/training/app/core/constants.py (new file, 97 lines)
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Training Service Constants
|
||||
Centralized constants to avoid magic numbers throughout the codebase
|
||||
"""
|
||||
|
||||
# Data Validation Thresholds
|
||||
MIN_DATA_POINTS_REQUIRED = 30
|
||||
RECOMMENDED_DATA_POINTS = 90
|
||||
MAX_ZERO_RATIO_ERROR = 0.9 # 90% zeros = error
|
||||
HIGH_ZERO_RATIO_WARNING = 0.7 # 70% zeros = warning
|
||||
MAX_ZERO_RATIO_INTERMITTENT = 0.8 # Products with >80% zeros are intermittent
|
||||
MODERATE_SPARSITY_THRESHOLD = 0.6 # 60% zeros = moderate sparsity
|
||||
|
||||
# Training Time Periods (in days)
|
||||
MIN_NON_ZERO_DAYS = 30 # Minimum days with non-zero sales
|
||||
DATA_QUALITY_DAY_THRESHOLD_LOW = 90
|
||||
DATA_QUALITY_DAY_THRESHOLD_HIGH = 365
|
||||
MAX_TRAINING_RANGE_DAYS = 730 # 2 years
|
||||
MIN_TRAINING_RANGE_DAYS = 30
|
||||
|
||||
# Product Classification Thresholds
|
||||
HIGH_VOLUME_MEAN_SALES = 10.0
|
||||
HIGH_VOLUME_ZERO_RATIO = 0.3
|
||||
MEDIUM_VOLUME_MEAN_SALES = 5.0
|
||||
MEDIUM_VOLUME_ZERO_RATIO = 0.5
|
||||
LOW_VOLUME_MEAN_SALES = 2.0
|
||||
LOW_VOLUME_ZERO_RATIO = 0.7
|
||||
|
||||
# Hyperparameter Optimization
|
||||
OPTUNA_TRIALS_HIGH_VOLUME = 30
|
||||
OPTUNA_TRIALS_MEDIUM_VOLUME = 25
|
||||
OPTUNA_TRIALS_LOW_VOLUME = 20
|
||||
OPTUNA_TRIALS_INTERMITTENT = 15
|
||||
OPTUNA_TIMEOUT_SECONDS = 600
|
||||
|
||||
# Prophet Uncertainty Sampling
|
||||
UNCERTAINTY_SAMPLES_SPARSE_MIN = 100
|
||||
UNCERTAINTY_SAMPLES_SPARSE_MAX = 200
|
||||
UNCERTAINTY_SAMPLES_LOW_MIN = 150
|
||||
UNCERTAINTY_SAMPLES_LOW_MAX = 300
|
||||
UNCERTAINTY_SAMPLES_MEDIUM_MIN = 200
|
||||
UNCERTAINTY_SAMPLES_MEDIUM_MAX = 500
|
||||
UNCERTAINTY_SAMPLES_HIGH_MIN = 300
|
||||
UNCERTAINTY_SAMPLES_HIGH_MAX = 800
|
||||
|
||||
# MAPE Calculation
|
||||
MAPE_LOW_VOLUME_THRESHOLD = 2.0
|
||||
MAPE_MEDIUM_VOLUME_THRESHOLD = 5.0
|
||||
MAPE_CALCULATION_MIN_THRESHOLD = 0.5
|
||||
MAPE_CALCULATION_MID_THRESHOLD = 1.0
|
||||
MAPE_MAX_CAP = 200.0 # Cap MAPE at 200%
|
||||
MAPE_MEDIUM_CAP = 150.0
|
||||
|
||||
# Baseline MAPE estimates for improvement calculation
|
||||
BASELINE_MAPE_VERY_SPARSE = 80.0
|
||||
BASELINE_MAPE_SPARSE = 60.0
|
||||
BASELINE_MAPE_HIGH_VOLUME = 25.0
|
||||
BASELINE_MAPE_MEDIUM_VOLUME = 35.0
|
||||
BASELINE_MAPE_LOW_VOLUME = 45.0
|
||||
IMPROVEMENT_SIGNIFICANCE_THRESHOLD = 0.8 # Only claim improvement if MAPE < 80% of baseline
|
||||
|
||||
# Cross-validation
|
||||
CV_N_SPLITS = 2
|
||||
CV_MIN_VALIDATION_DAYS = 7
|
||||
|
||||
# Progress tracking
|
||||
PROGRESS_DATA_PREPARATION_START = 0
|
||||
PROGRESS_DATA_PREPARATION_END = 45
|
||||
PROGRESS_MODEL_TRAINING_START = 45
|
||||
PROGRESS_MODEL_TRAINING_END = 85
|
||||
PROGRESS_FINALIZATION_START = 85
|
||||
PROGRESS_FINALIZATION_END = 100
|
||||
|
||||
# HTTP Client Configuration
|
||||
HTTP_TIMEOUT_DEFAULT = 30.0 # seconds
|
||||
HTTP_TIMEOUT_LONG_RUNNING = 60.0 # for training data fetches
|
||||
HTTP_MAX_RETRIES = 3
|
||||
HTTP_RETRY_BACKOFF_FACTOR = 2.0
|
||||
|
||||
# WebSocket Configuration
|
||||
WEBSOCKET_PING_TIMEOUT = 60.0 # seconds
|
||||
WEBSOCKET_ACTIVITY_WARNING_THRESHOLD = 90.0 # seconds
|
||||
WEBSOCKET_CONSUMER_HEARTBEAT_INTERVAL = 10.0 # seconds
|
||||
|
||||
# Synthetic Data Generation
|
||||
SYNTHETIC_TEMP_DEFAULT = 50.0
|
||||
SYNTHETIC_TEMP_VARIATION = 100.0
|
||||
SYNTHETIC_TRAFFIC_DEFAULT = 50.0
|
||||
SYNTHETIC_TRAFFIC_VARIATION = 100.0
|
||||
|
||||
# Model Storage
|
||||
MODEL_FILE_EXTENSION = ".pkl"
|
||||
METADATA_FILE_EXTENSION = ".json"
|
||||
|
||||
# Data Quality Scoring
|
||||
MIN_QUALITY_SCORE = 0.1
|
||||
MAX_QUALITY_SCORE = 1.0
|
||||
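Several of these constants encode policy rather than tuning, for example the volume-dependent MAPE caps. A small illustrative helper (not part of the commit) showing how they might be applied:

# Illustrative use of the MAPE-related constants defined above (not part of the commit).
from app.core import constants as const


def capped_mape(raw_mape: float, mean_daily_sales: float) -> float:
    """Clamp a raw MAPE value using the volume-dependent caps from constants.py."""
    if mean_daily_sales < const.MAPE_LOW_VOLUME_THRESHOLD:
        return min(raw_mape, const.MAPE_MAX_CAP)      # very low volume: allow up to 200%
    if mean_daily_sales < const.MAPE_MEDIUM_VOLUME_THRESHOLD:
        return min(raw_mape, const.MAPE_MEDIUM_CAP)   # medium volume: cap at 150%
    return raw_mape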
@@ -15,8 +15,16 @@ from app.core.config import settings
 
 logger = structlog.get_logger()
 
-# Initialize database manager using shared infrastructure
-database_manager = DatabaseManager(settings.DATABASE_URL)
+# Initialize database manager with connection pooling configuration
+database_manager = DatabaseManager(
+    settings.DATABASE_URL,
+    pool_size=settings.DB_POOL_SIZE,
+    max_overflow=settings.DB_MAX_OVERFLOW,
+    pool_timeout=settings.DB_POOL_TIMEOUT,
+    pool_recycle=settings.DB_POOL_RECYCLE,
+    pool_pre_ping=settings.DB_POOL_PRE_PING,
+    echo=settings.DB_ECHO
+)
 
 # Alias for convenience - matches the existing interface
 get_db = database_manager.get_db
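The pooling keywords passed to DatabaseManager mirror SQLAlchemy's async engine options. The sketch below only documents what each setting usually means by showing the call a manager like this typically makes internally; the real implementation is in the shared DatabaseManager and may differ:

# Sketch of the SQLAlchemy call these settings typically translate to (assumption).
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine


def build_engine(database_url: str, s) -> AsyncEngine:
    return create_async_engine(
        database_url,
        pool_size=s.DB_POOL_SIZE,          # persistent connections kept open
        max_overflow=s.DB_MAX_OVERFLOW,    # extra connections allowed under burst load
        pool_timeout=s.DB_POOL_TIMEOUT,    # seconds to wait for a free connection
        pool_recycle=s.DB_POOL_RECYCLE,    # recycle connections older than this many seconds
        pool_pre_ping=s.DB_POOL_PRE_PING,  # validate a connection before handing it out
        echo=s.DB_ECHO,                    # log SQL statements when enabled
    )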
@@ -11,35 +11,15 @@ from fastapi import FastAPI, Request
|
||||
from sqlalchemy import text
|
||||
from app.core.config import settings
|
||||
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
|
||||
from app.api import training_jobs, training_operations, models
|
||||
from app.services.messaging import setup_messaging, cleanup_messaging
|
||||
from app.api import training_jobs, training_operations, models, health, monitoring, websocket_operations
|
||||
from app.services.training_events import setup_messaging, cleanup_messaging
|
||||
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers
|
||||
from shared.service_base import StandardFastAPIService
|
||||
|
||||
|
||||
class TrainingService(StandardFastAPIService):
|
||||
"""Training Service with standardized setup"""
|
||||
|
||||
expected_migration_version = "00001"
|
||||
|
||||
async def on_startup(self, app):
|
||||
"""Custom startup logic including migration verification"""
|
||||
await self.verify_migrations()
|
||||
await super().on_startup(app)
|
||||
|
||||
async def verify_migrations(self):
|
||||
"""Verify database schema matches the latest migrations."""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
result = await session.execute(text("SELECT version_num FROM alembic_version"))
|
||||
version = result.scalar()
|
||||
if version != self.expected_migration_version:
|
||||
self.logger.error(f"Migration version mismatch: expected {self.expected_migration_version}, got {version}")
|
||||
raise RuntimeError(f"Migration version mismatch: expected {self.expected_migration_version}, got {version}")
|
||||
self.logger.info(f"Migration verification successful: {version}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Migration verification failed: {e}")
|
||||
raise
|
||||
|
||||
def __init__(self):
|
||||
# Define expected database tables for health checks
|
||||
training_expected_tables = [
|
||||
@@ -54,7 +34,7 @@ class TrainingService(StandardFastAPIService):
|
||||
version="1.0.0",
|
||||
log_level=settings.LOG_LEVEL,
|
||||
cors_origins=settings.CORS_ORIGINS_LIST,
|
||||
api_prefix="", # Empty because RouteBuilder already includes /api/v1
|
||||
api_prefix="",
|
||||
database_manager=database_manager,
|
||||
expected_tables=training_expected_tables,
|
||||
enable_messaging=True
|
||||
@@ -65,18 +45,42 @@ class TrainingService(StandardFastAPIService):
|
||||
await setup_messaging()
|
||||
self.logger.info("Messaging setup completed")
|
||||
|
||||
# Set up WebSocket event consumer (listens to RabbitMQ and broadcasts to WebSockets)
|
||||
success = await setup_websocket_event_consumer()
|
||||
if success:
|
||||
self.logger.info("WebSocket event consumer setup completed")
|
||||
else:
|
||||
self.logger.warning("WebSocket event consumer setup failed")
|
||||
|
||||
async def _cleanup_messaging(self):
|
||||
"""Cleanup messaging for training service"""
|
||||
await cleanup_websocket_consumers()
|
||||
await cleanup_messaging()
|
||||
|
||||
async def verify_migrations(self):
|
||||
"""Verify database schema matches the latest migrations dynamically."""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
result = await session.execute(text("SELECT version_num FROM alembic_version"))
|
||||
version = result.scalar()
|
||||
|
||||
if not version:
|
||||
self.logger.error("No migration version found in database")
|
||||
raise RuntimeError("Database not initialized - no alembic version found")
|
||||
|
||||
self.logger.info(f"Migration verification successful: {version}")
|
||||
return version
|
||||
except Exception as e:
|
||||
self.logger.error(f"Migration verification failed: {e}")
|
||||
raise
|
||||
|
||||
async def on_startup(self, app: FastAPI):
|
||||
"""Custom startup logic for training service"""
|
||||
pass
|
||||
"""Custom startup logic including migration verification"""
|
||||
await self.verify_migrations()
|
||||
self.logger.info("Training service startup completed")
|
||||
|
||||
async def on_shutdown(self, app: FastAPI):
|
||||
"""Custom shutdown logic for training service"""
|
||||
# Note: Database cleanup is handled by the base class
|
||||
# but training service has custom cleanup function
|
||||
await cleanup_training_database()
|
||||
self.logger.info("Training database cleanup completed")
|
||||
|
||||
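setup_websocket_event_consumer (from app.websocket.events, not shown in this diff) is what ties the two halves together: it subscribes to training events on RabbitMQ and pushes them to the sockets held by websocket_manager. A rough sketch of that bridge; the queue wiring and the broadcast call signature are assumptions:

# Illustrative bridge between RabbitMQ training events and WebSocket clients (assumption).
import json

from app.websocket.manager import websocket_manager


async def _on_training_event(message) -> None:
    """Forward one RabbitMQ training event to the WebSocket clients following its job."""
    event = json.loads(message.body.decode())
    job_id = event.get("data", {}).get("job_id")
    if job_id:
        # Assumed manager API: broadcast(job_id, payload) fans out to every socket for that job
        await websocket_manager.broadcast(job_id, {
            "type": event.get("event_type"),
            "timestamp": event.get("timestamp"),
            "data": event.get("data", {}),
        })
    await message.ack()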
@@ -162,6 +166,9 @@ service.setup_custom_endpoints()
|
||||
service.add_router(training_jobs.router, tags=["training-jobs"])
|
||||
service.add_router(training_operations.router, tags=["training-operations"])
|
||||
service.add_router(models.router, tags=["models"])
|
||||
service.add_router(health.router, tags=["health"])
|
||||
service.add_router(monitoring.router, tags=["monitoring"])
|
||||
service.add_router(websocket_operations.router, tags=["websocket"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
|
||||
@@ -3,16 +3,12 @@ ML Pipeline Components
|
||||
Machine learning training and prediction components
|
||||
"""
|
||||
|
||||
from .trainer import BakeryMLTrainer
|
||||
from .trainer import EnhancedBakeryMLTrainer
|
||||
from .data_processor import BakeryDataProcessor
|
||||
from .data_processor import EnhancedBakeryDataProcessor
|
||||
from .prophet_manager import BakeryProphetManager
|
||||
|
||||
__all__ = [
|
||||
"BakeryMLTrainer",
|
||||
"EnhancedBakeryMLTrainer",
|
||||
"BakeryDataProcessor",
|
||||
"EnhancedBakeryDataProcessor",
|
||||
"BakeryProphetManager"
|
||||
]
|
||||
@@ -865,8 +865,4 @@ class EnhancedBakeryDataProcessor:
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error generating data quality report", error=str(e))
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
# Legacy compatibility alias
|
||||
BakeryDataProcessor = EnhancedBakeryDataProcessor
|
||||
return {"error": str(e)}
|
||||
@@ -32,6 +32,10 @@ import optuna
|
||||
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core import constants as const
|
||||
from app.utils.timezone_utils import prepare_prophet_datetime
|
||||
from app.utils.file_utils import ChecksummedFile, calculate_file_checksum
|
||||
from app.utils.distributed_lock import get_training_lock, LockAcquisitionError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,72 +54,79 @@ class BakeryProphetManager:
|
||||
# Ensure model storage directory exists
|
||||
os.makedirs(settings.MODEL_STORAGE_PATH, exist_ok=True)
|
||||
|
||||
async def train_bakery_model(self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
async def train_bakery_model(self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
df: pd.DataFrame,
|
||||
job_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Train a Prophet model with automatic hyperparameter optimization.
|
||||
Same interface as before - optimization happens automatically.
|
||||
Train a Prophet model with automatic hyperparameter optimization and distributed locking.
|
||||
"""
|
||||
# Acquire distributed lock to prevent concurrent training of same product
|
||||
lock = get_training_lock(tenant_id, inventory_product_id, use_advisory=True)
|
||||
|
||||
try:
|
||||
logger.info(f"Training optimized bakery model for {inventory_product_id}")
|
||||
|
||||
# Validate input data
|
||||
await self._validate_training_data(df, inventory_product_id)
|
||||
|
||||
# Prepare data for Prophet
|
||||
prophet_data = await self._prepare_prophet_data(df)
|
||||
|
||||
# Get regressor columns
|
||||
regressor_columns = self._extract_regressor_columns(prophet_data)
|
||||
|
||||
# Automatically optimize hyperparameters (this is the new part)
|
||||
logger.info(f"Optimizing hyperparameters for {inventory_product_id}...")
|
||||
best_params = await self._optimize_hyperparameters(prophet_data, inventory_product_id, regressor_columns)
|
||||
|
||||
# Create optimized Prophet model
|
||||
model = self._create_optimized_prophet_model(best_params, regressor_columns)
|
||||
|
||||
# Add regressors to model
|
||||
for regressor in regressor_columns:
|
||||
if regressor in prophet_data.columns:
|
||||
model.add_regressor(regressor)
|
||||
|
||||
# Fit the model
|
||||
model.fit(prophet_data)
|
||||
|
||||
# Calculate enhanced training metrics first
|
||||
training_metrics = await self._calculate_training_metrics(model, prophet_data, best_params)
|
||||
|
||||
# Store model and metrics - Generate proper UUID for model_id
|
||||
model_id = str(uuid.uuid4())
|
||||
model_path = await self._store_model(
|
||||
tenant_id, inventory_product_id, model, model_id, prophet_data, regressor_columns, best_params, training_metrics
|
||||
)
|
||||
|
||||
# Return same format as before, but with optimization info
|
||||
model_info = {
|
||||
"model_id": model_id,
|
||||
"model_path": model_path,
|
||||
"type": "prophet_optimized", # Changed from "prophet"
|
||||
"training_samples": len(prophet_data),
|
||||
"features": regressor_columns,
|
||||
"hyperparameters": best_params, # Now contains optimized params
|
||||
"training_metrics": training_metrics,
|
||||
"trained_at": datetime.now().isoformat(),
|
||||
"data_period": {
|
||||
"start_date": prophet_data['ds'].min().isoformat(),
|
||||
"end_date": prophet_data['ds'].max().isoformat(),
|
||||
"total_days": len(prophet_data)
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(f"Optimized model trained successfully for {inventory_product_id}. "
|
||||
f"MAPE: {training_metrics.get('optimized_mape', 'N/A')}%")
|
||||
return model_info
|
||||
|
||||
async with self.database_manager.get_session() as session:
|
||||
async with lock.acquire(session):
|
||||
logger.info(f"Training optimized bakery model for {inventory_product_id} (lock acquired)")
|
||||
|
||||
# Validate input data
|
||||
await self._validate_training_data(df, inventory_product_id)
|
||||
|
||||
# Prepare data for Prophet
|
||||
prophet_data = await self._prepare_prophet_data(df)
|
||||
|
||||
# Get regressor columns
|
||||
regressor_columns = self._extract_regressor_columns(prophet_data)
|
||||
|
||||
# Automatically optimize hyperparameters
|
||||
logger.info(f"Optimizing hyperparameters for {inventory_product_id}...")
|
||||
best_params = await self._optimize_hyperparameters(prophet_data, inventory_product_id, regressor_columns)
|
||||
|
||||
# Create optimized Prophet model
|
||||
model = self._create_optimized_prophet_model(best_params, regressor_columns)
|
||||
|
||||
# Add regressors to model
|
||||
for regressor in regressor_columns:
|
||||
if regressor in prophet_data.columns:
|
||||
model.add_regressor(regressor)
|
||||
|
||||
# Fit the model
|
||||
model.fit(prophet_data)
|
||||
|
||||
# Calculate enhanced training metrics first
|
||||
training_metrics = await self._calculate_training_metrics(model, prophet_data, best_params)
|
||||
|
||||
# Store model and metrics - Generate proper UUID for model_id
|
||||
model_id = str(uuid.uuid4())
|
||||
model_path = await self._store_model(
|
||||
tenant_id, inventory_product_id, model, model_id, prophet_data, regressor_columns, best_params, training_metrics
|
||||
)
|
||||
|
||||
# Return same format as before, but with optimization info
|
||||
model_info = {
|
||||
"model_id": model_id,
|
||||
"model_path": model_path,
|
||||
"type": "prophet_optimized",
|
||||
"training_samples": len(prophet_data),
|
||||
"features": regressor_columns,
|
||||
"hyperparameters": best_params,
|
||||
"training_metrics": training_metrics,
|
||||
"trained_at": datetime.now().isoformat(),
|
||||
"data_period": {
|
||||
"start_date": prophet_data['ds'].min().isoformat(),
|
||||
"end_date": prophet_data['ds'].max().isoformat(),
|
||||
"total_days": len(prophet_data)
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(f"Optimized model trained successfully for {inventory_product_id}. "
|
||||
f"MAPE: {training_metrics.get('optimized_mape', 'N/A')}%")
|
||||
return model_info
|
||||
|
||||
except LockAcquisitionError as e:
|
||||
logger.warning(f"Could not acquire lock for {inventory_product_id}: {e}")
|
||||
raise RuntimeError(f"Training already in progress for product {inventory_product_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to train optimized bakery model for {inventory_product_id}: {str(e)}")
|
||||
raise
|
||||
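get_training_lock and LockAcquisitionError are imported from app.utils.distributed_lock, which is not shown in this diff. A minimal sketch of an advisory-lock helper with the same acquire(session) interface, assuming PostgreSQL advisory locks and a SQLAlchemy AsyncSession (key derivation and names are illustrative):

import hashlib
from contextlib import asynccontextmanager

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession


class LockAcquisitionError(Exception):
    """Raised when another worker already holds the lock for this product."""


class TrainingLock:
    def __init__(self, tenant_id: str, inventory_product_id: str):
        digest = hashlib.sha256(f"{tenant_id}:{inventory_product_id}".encode()).digest()
        # Deterministic signed 64-bit key (the built-in hash() is salted per process)
        self.key = int.from_bytes(digest[:8], "big", signed=True)

    @asynccontextmanager
    async def acquire(self, session: AsyncSession):
        result = await session.execute(
            text("SELECT pg_try_advisory_lock(:key)"), {"key": self.key}
        )
        if not result.scalar():
            raise LockAcquisitionError(f"Advisory lock {self.key} already held")
        try:
            yield self
        finally:
            await session.execute(
                text("SELECT pg_advisory_unlock(:key)"), {"key": self.key}
            )


def get_training_lock(tenant_id: str, inventory_product_id: str, use_advisory: bool = True) -> TrainingLock:
    return TrainingLock(tenant_id, inventory_product_id)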
@@ -134,11 +145,11 @@ class BakeryProphetManager:
|
||||
|
||||
# Set optimization parameters based on category
|
||||
n_trials = {
|
||||
'high_volume': 30, # Reduced from 75 for speed
|
||||
'medium_volume': 25, # Reduced from 50
|
||||
'low_volume': 20, # Reduced from 30
|
||||
'intermittent': 15 # Reduced from 25
|
||||
}.get(product_category, 25)
|
||||
'high_volume': const.OPTUNA_TRIALS_HIGH_VOLUME,
|
||||
'medium_volume': const.OPTUNA_TRIALS_MEDIUM_VOLUME,
|
||||
'low_volume': const.OPTUNA_TRIALS_LOW_VOLUME,
|
||||
'intermittent': const.OPTUNA_TRIALS_INTERMITTENT
|
||||
}.get(product_category, const.OPTUNA_TRIALS_MEDIUM_VOLUME)
|
||||
|
||||
logger.info(f"Product {inventory_product_id} classified as {product_category}, using {n_trials} trials")
|
||||
|
||||
@@ -152,7 +163,7 @@ class BakeryProphetManager:
|
||||
f"zero_ratio={zero_ratio:.2f}, mean_sales={mean_sales:.2f}, non_zero_days={non_zero_days}")
|
||||
|
||||
# Adjust strategy based on data characteristics
|
||||
if zero_ratio > 0.8 or non_zero_days < 30:
|
||||
if zero_ratio > const.MAX_ZERO_RATIO_INTERMITTENT or non_zero_days < const.MIN_NON_ZERO_DAYS:
|
||||
logger.warning(f"Very sparse data for {inventory_product_id}, using minimal optimization")
|
||||
return {
|
||||
'changepoint_prior_scale': 0.001,
|
||||
@@ -163,9 +174,9 @@ class BakeryProphetManager:
|
||||
'daily_seasonality': False,
|
||||
'weekly_seasonality': True,
|
||||
'yearly_seasonality': False,
|
||||
'uncertainty_samples': 100 # ✅ FIX: Minimal uncertainty sampling for very sparse data
|
||||
'uncertainty_samples': const.UNCERTAINTY_SAMPLES_SPARSE_MIN
|
||||
}
|
||||
elif zero_ratio > 0.6:
|
||||
elif zero_ratio > const.MODERATE_SPARSITY_THRESHOLD:
|
||||
logger.info(f"Moderate sparsity for {inventory_product_id}, using conservative optimization")
|
||||
return {
|
||||
'changepoint_prior_scale': 0.01,
|
||||
@@ -175,8 +186,8 @@ class BakeryProphetManager:
|
||||
'seasonality_mode': 'additive',
|
||||
'daily_seasonality': False,
|
||||
'weekly_seasonality': True,
|
||||
'yearly_seasonality': len(df) > 365, # Only if we have enough data
|
||||
'uncertainty_samples': 200 # ✅ FIX: Conservative uncertainty sampling for moderately sparse data
|
||||
'yearly_seasonality': len(df) > const.DATA_QUALITY_DAY_THRESHOLD_HIGH,
|
||||
'uncertainty_samples': const.UNCERTAINTY_SAMPLES_SPARSE_MAX
|
||||
}
|
||||
|
||||
# Use unique seed for each product to avoid identical results
|
||||
@@ -198,15 +209,15 @@ class BakeryProphetManager:
|
||||
changepoint_scale_range = (0.001, 0.5)
|
||||
seasonality_scale_range = (0.01, 10.0)
|
||||
|
||||
# ✅ FIX: Determine appropriate uncertainty samples range based on product category
|
||||
# Determine appropriate uncertainty samples range based on product category
|
||||
if product_category == 'high_volume':
|
||||
uncertainty_range = (300, 800) # More samples for stable high-volume products
|
||||
uncertainty_range = (const.UNCERTAINTY_SAMPLES_HIGH_MIN, const.UNCERTAINTY_SAMPLES_HIGH_MAX)
|
||||
elif product_category == 'medium_volume':
|
||||
uncertainty_range = (200, 500) # Moderate samples for medium volume
|
||||
uncertainty_range = (const.UNCERTAINTY_SAMPLES_MEDIUM_MIN, const.UNCERTAINTY_SAMPLES_MEDIUM_MAX)
|
||||
elif product_category == 'low_volume':
|
||||
uncertainty_range = (150, 300) # Fewer samples for low volume
|
||||
uncertainty_range = (const.UNCERTAINTY_SAMPLES_LOW_MIN, const.UNCERTAINTY_SAMPLES_LOW_MAX)
|
||||
else: # intermittent
|
||||
uncertainty_range = (100, 200) # Minimal samples for intermittent demand
|
||||
uncertainty_range = (const.UNCERTAINTY_SAMPLES_SPARSE_MIN, const.UNCERTAINTY_SAMPLES_SPARSE_MAX)
|
||||
|
||||
params = {
|
||||
'changepoint_prior_scale': trial.suggest_float(
|
||||
@@ -295,10 +306,10 @@ class BakeryProphetManager:
|
||||
|
||||
# Run optimization with product-specific seed
|
||||
study = optuna.create_study(
|
||||
direction='minimize',
|
||||
sampler=optuna.samplers.TPESampler(seed=product_seed) # Unique seed per product
|
||||
direction='minimize',
|
||||
sampler=optuna.samplers.TPESampler(seed=product_seed)
|
||||
)
|
||||
study.optimize(objective, n_trials=n_trials, timeout=600, show_progress_bar=False)
|
||||
study.optimize(objective, n_trials=n_trials, timeout=const.OPTUNA_TIMEOUT_SECONDS, show_progress_bar=False)
|
||||
|
||||
# Return best parameters
|
||||
best_params = study.best_params
|
||||
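The objective evaluated by this study is defined earlier in the file and returns a MAPE to minimize; the Optuna pattern itself can be illustrated standalone as below (the synthetic objective is a placeholder so the snippet runs on its own; the parameter ranges mirror the ones suggested above):

import optuna

optuna.logging.set_verbosity(optuna.logging.WARNING)


def objective(trial: optuna.Trial) -> float:
    # Sample the same families of Prophet hyperparameters tuned above
    changepoint = trial.suggest_float("changepoint_prior_scale", 0.001, 0.5, log=True)
    seasonality = trial.suggest_float("seasonality_prior_scale", 0.01, 10.0, log=True)
    uncertainty = trial.suggest_int("uncertainty_samples", 100, 800)
    # The real objective cross-validates a Prophet model and returns its MAPE;
    # a smooth stand-in keeps this example self-contained.
    return (changepoint - 0.05) ** 2 + (seasonality - 1.0) ** 2 + uncertainty * 1e-6


study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=42),  # the service derives a unique seed per product
)
study.optimize(objective, n_trials=25, timeout=600, show_progress_bar=False)
print(study.best_params)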
@@ -515,8 +526,12 @@ class BakeryProphetManager:
|
||||
# Store model file
|
||||
model_path = model_dir / f"{model_id}.pkl"
|
||||
joblib.dump(model, model_path)
|
||||
|
||||
# Enhanced metadata
|
||||
|
||||
# Calculate checksum for model file integrity
|
||||
checksummed_file = ChecksummedFile(str(model_path))
|
||||
model_checksum = checksummed_file.calculate_and_save_checksum()
|
||||
|
||||
# Enhanced metadata with checksum
|
||||
metadata = {
|
||||
"model_id": model_id,
|
||||
"tenant_id": tenant_id,
|
||||
@@ -531,9 +546,11 @@ class BakeryProphetManager:
|
||||
"optimized_parameters": optimized_params or {},
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"model_type": "prophet_optimized",
|
||||
"file_path": str(model_path)
|
||||
"file_path": str(model_path),
|
||||
"checksum": model_checksum,
|
||||
"checksum_algorithm": "sha256"
|
||||
}
|
||||
|
||||
|
||||
metadata_path = model_path.with_suffix('.json')
|
||||
with open(metadata_path, 'w') as f:
|
||||
json.dump(metadata, f, indent=2, default=str)
|
||||
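ChecksummedFile and calculate_file_checksum come from app.utils.file_utils, which is outside this diff. A plausible minimal version, assuming a SHA-256 sidecar file stored next to the model artifact:

import hashlib
from pathlib import Path


def calculate_file_checksum(path: str, chunk_size: int = 1024 * 1024) -> str:
    """Stream the file through SHA-256 so large model pickles never sit fully in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


class ChecksummedFile:
    def __init__(self, path: str):
        self.path = Path(path)
        self.sidecar = self.path.with_suffix(self.path.suffix + ".sha256")

    def calculate_and_save_checksum(self) -> str:
        checksum = calculate_file_checksum(str(self.path))
        self.sidecar.write_text(checksum)
        return checksum

    def load_and_verify_checksum(self) -> bool:
        if not self.sidecar.exists():
            return False
        return self.sidecar.read_text().strip() == calculate_file_checksum(str(self.path))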
@@ -609,23 +626,29 @@ class BakeryProphetManager:
|
||||
logger.error(f"Failed to deactivate previous models: {str(e)}")
|
||||
raise
|
||||
|
||||
# Keep all existing methods unchanged
|
||||
async def generate_forecast(self,
|
||||
async def generate_forecast(self,
|
||||
model_path: str,
|
||||
future_dates: pd.DataFrame,
|
||||
regressor_columns: List[str]) -> pd.DataFrame:
|
||||
"""Generate forecast using stored model (unchanged)"""
|
||||
"""Generate forecast using stored model with checksum verification"""
|
||||
try:
|
||||
# Verify model file integrity before loading
|
||||
checksummed_file = ChecksummedFile(model_path)
|
||||
if not checksummed_file.load_and_verify_checksum():
|
||||
logger.warning(f"Checksum verification failed for model: {model_path}")
|
||||
# Still load the model but log warning
|
||||
# In production, you might want to raise an exception instead
|
||||
|
||||
model = joblib.load(model_path)
|
||||
|
||||
|
||||
for regressor in regressor_columns:
|
||||
if regressor not in future_dates.columns:
|
||||
logger.warning(f"Missing regressor {regressor}, filling with median")
|
||||
future_dates[regressor] = 0
|
||||
|
||||
|
||||
forecast = model.predict(future_dates)
|
||||
return forecast
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate forecast: {str(e)}")
|
||||
raise
|
||||
@@ -655,34 +678,28 @@ class BakeryProphetManager:
|
||||
async def _prepare_prophet_data(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Prepare data for Prophet training with timezone handling"""
|
||||
prophet_data = df.copy()
|
||||
|
||||
|
||||
if 'ds' not in prophet_data.columns:
|
||||
raise ValueError("Missing 'ds' column in training data")
|
||||
if 'y' not in prophet_data.columns:
|
||||
raise ValueError("Missing 'y' column in training data")
|
||||
|
||||
# Convert to datetime and remove timezone information
|
||||
prophet_data['ds'] = pd.to_datetime(prophet_data['ds'])
|
||||
|
||||
# Remove timezone if present (Prophet doesn't support timezones)
|
||||
if prophet_data['ds'].dt.tz is not None:
|
||||
logger.info("Removing timezone information from 'ds' column for Prophet compatibility")
|
||||
prophet_data['ds'] = prophet_data['ds'].dt.tz_localize(None)
|
||||
|
||||
|
||||
# Use timezone utility to prepare Prophet-compatible datetime
|
||||
prophet_data = prepare_prophet_datetime(prophet_data, 'ds')
|
||||
|
||||
# Sort by date and clean data
|
||||
prophet_data = prophet_data.sort_values('ds').reset_index(drop=True)
|
||||
prophet_data['y'] = pd.to_numeric(prophet_data['y'], errors='coerce')
|
||||
prophet_data = prophet_data.dropna(subset=['y'])
|
||||
|
||||
# Additional data cleaning for Prophet
|
||||
|
||||
# Remove any duplicate dates (keep last occurrence)
|
||||
prophet_data = prophet_data.drop_duplicates(subset=['ds'], keep='last')
|
||||
|
||||
# Ensure y values are non-negative (Prophet works better with non-negative values)
|
||||
|
||||
# Ensure y values are non-negative
|
||||
prophet_data['y'] = prophet_data['y'].clip(lower=0)
|
||||
|
||||
|
||||
logger.info(f"Prepared Prophet data: {len(prophet_data)} rows, date range: {prophet_data['ds'].min()} to {prophet_data['ds'].max()}")
|
||||
|
||||
|
||||
return prophet_data
|
||||
|
||||
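prepare_prophet_datetime is imported from app.utils.timezone_utils and replaces the inline timezone handling removed above; a minimal implementation consistent with that behaviour might be:

import pandas as pd


def prepare_prophet_datetime(df: pd.DataFrame, column: str = "ds") -> pd.DataFrame:
    """Coerce the datetime column to tz-naive timestamps, since Prophet rejects tz-aware data."""
    out = df.copy()
    out[column] = pd.to_datetime(out[column])
    if out[column].dt.tz is not None:
        out[column] = out[column].dt.tz_localize(None)
    return out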
def _extract_regressor_columns(self, df: pd.DataFrame) -> List[str]:
|
||||
|
||||
@@ -10,6 +10,7 @@ from datetime import datetime
|
||||
import structlog
|
||||
import uuid
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from app.ml.data_processor import EnhancedBakeryDataProcessor
|
||||
from app.ml.prophet_manager import BakeryProphetManager
|
||||
@@ -28,7 +29,13 @@ from app.repositories import (
|
||||
ArtifactRepository
|
||||
)
|
||||
|
||||
from app.services.messaging import TrainingStatusPublisher
|
||||
from app.services.progress_tracker import ParallelProductProgressTracker
|
||||
from app.services.training_events import (
|
||||
publish_training_started,
|
||||
publish_data_analysis,
|
||||
publish_training_completed,
|
||||
publish_training_failed
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@@ -75,8 +82,6 @@ class EnhancedBakeryMLTrainer:
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
self.status_publisher = TrainingStatusPublisher(job_id, tenant_id)
|
||||
|
||||
try:
|
||||
# Get database session and repositories
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
@@ -113,8 +118,10 @@ class EnhancedBakeryMLTrainer:
|
||||
else:
|
||||
logger.info("Multiple products detected for training",
|
||||
products_count=len(products))
|
||||
|
||||
self.status_publisher.products_total = len(products)
|
||||
|
||||
# Event 1: Training Started (0%) - update with actual product count
|
||||
# Note: Initial event was already published by API endpoint, this updates with real count
|
||||
await publish_training_started(job_id, tenant_id, len(products))
|
||||
|
||||
# Create initial training log entry
|
||||
await repos['training_log'].update_log_progress(
|
||||
@@ -126,28 +133,45 @@ class EnhancedBakeryMLTrainer:
|
||||
processed_data = await self._process_all_products_enhanced(
|
||||
sales_df, weather_df, traffic_df, products, tenant_id, job_id
|
||||
)
|
||||
|
||||
await self.status_publisher.progress_update(
|
||||
progress=20,
|
||||
step="feature_engineering",
|
||||
step_details="Enhanced processing with repository tracking"
|
||||
|
||||
# Event 2: Data Analysis (20%)
|
||||
await publish_data_analysis(
|
||||
job_id,
|
||||
tenant_id,
|
||||
f"Data analysis completed for {len(processed_data)} products"
|
||||
)
|
||||
|
||||
# Train models for each processed product
|
||||
logger.info("Training models with repository integration")
|
||||
# Train models for each processed product with progress aggregation
|
||||
logger.info("Training models with repository integration and progress aggregation")
|
||||
|
||||
# Create progress tracker for parallel product training (20-80%)
|
||||
progress_tracker = ParallelProductProgressTracker(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
total_products=len(processed_data)
|
||||
)
|
||||
|
||||
training_results = await self._train_all_models_enhanced(
|
||||
tenant_id, processed_data, job_id, repos
|
||||
tenant_id, processed_data, job_id, repos, progress_tracker
|
||||
)
|
||||
|
||||
# Calculate overall training summary with enhanced metrics
|
||||
summary = await self._calculate_enhanced_training_summary(
|
||||
training_results, repos, tenant_id
|
||||
)
|
||||
|
||||
await self.status_publisher.progress_update(
|
||||
progress=90,
|
||||
step="model_validation",
|
||||
step_details="Enhanced validation with repository tracking"
|
||||
|
||||
# Calculate successful and failed trainings
|
||||
successful_trainings = len([r for r in training_results.values() if r.get('status') == 'success'])
|
||||
failed_trainings = len([r for r in training_results.values() if r.get('status') == 'error'])
|
||||
total_duration = sum([r.get('training_time_seconds', 0) for r in training_results.values()])
|
||||
|
||||
# Event 4: Training Completed (100%)
|
||||
await publish_training_completed(
|
||||
job_id,
|
||||
tenant_id,
|
||||
successful_trainings,
|
||||
failed_trainings,
|
||||
total_duration
|
||||
)
|
||||
|
||||
# Create comprehensive result with repository data
|
||||
@@ -189,6 +213,10 @@ class EnhancedBakeryMLTrainer:
|
||||
logger.error("Enhanced ML training pipeline failed",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
|
||||
# Publish training failed event
|
||||
await publish_training_failed(job_id, tenant_id, str(e))
|
||||
|
||||
raise
|
||||
|
||||
async def _process_all_products_enhanced(self,
|
||||
@@ -237,111 +265,158 @@ class EnhancedBakeryMLTrainer:
|
||||
|
||||
return processed_data
|
||||
|
||||
async def _train_single_product(self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
product_data: pd.DataFrame,
|
||||
job_id: str,
|
||||
repos: Dict,
|
||||
progress_tracker: ParallelProductProgressTracker) -> tuple[str, Dict[str, Any]]:
|
||||
"""Train a single product model - used for parallel execution with progress aggregation"""
|
||||
product_start_time = time.time()
|
||||
|
||||
try:
|
||||
logger.info("Training model", inventory_product_id=inventory_product_id)
|
||||
|
||||
# Check if we have enough data
|
||||
if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
|
||||
result = {
|
||||
'status': 'skipped',
|
||||
'reason': 'insufficient_data',
|
||||
'data_points': len(product_data),
|
||||
'min_required': settings.MIN_TRAINING_DATA_DAYS,
|
||||
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
|
||||
}
|
||||
logger.warning("Skipping product due to insufficient data",
|
||||
inventory_product_id=inventory_product_id,
|
||||
data_points=len(product_data),
|
||||
min_required=settings.MIN_TRAINING_DATA_DAYS)
|
||||
return inventory_product_id, result
|
||||
|
||||
# Train the model using Prophet manager
|
||||
model_info = await self.prophet_manager.train_bakery_model(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
df=product_data,
|
||||
job_id=job_id
|
||||
)
|
||||
|
||||
# Store model record using repository
|
||||
model_record = await self._create_model_record(
|
||||
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
|
||||
)
|
||||
|
||||
# Create performance metrics record
|
||||
if model_info.get('training_metrics'):
|
||||
await self._create_performance_metrics(
|
||||
repos, model_record.id if model_record else None,
|
||||
tenant_id, inventory_product_id, model_info['training_metrics']
|
||||
)
|
||||
|
||||
result = {
|
||||
'status': 'success',
|
||||
'model_info': model_info,
|
||||
'model_record_id': model_record.id if model_record else None,
|
||||
'data_points': len(product_data),
|
||||
'training_time_seconds': time.time() - product_start_time,
|
||||
'trained_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
logger.info("Successfully trained model",
|
||||
inventory_product_id=inventory_product_id,
|
||||
model_record_id=model_record.id if model_record else None)
|
||||
|
||||
# Report completion to progress tracker (emits Event 3: product_completed)
|
||||
await progress_tracker.mark_product_completed(inventory_product_id)
|
||||
|
||||
return inventory_product_id, result
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to train model",
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
result = {
|
||||
'status': 'error',
|
||||
'error_message': str(e),
|
||||
'data_points': len(product_data) if product_data is not None else 0,
|
||||
'training_time_seconds': time.time() - product_start_time,
|
||||
'failed_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Report failure to progress tracker (still emits Event 3: product_completed)
|
||||
await progress_tracker.mark_product_completed(inventory_product_id)
|
||||
|
||||
return inventory_product_id, result
|
||||
|
||||
async def _train_all_models_enhanced(self,
|
||||
tenant_id: str,
|
||||
processed_data: Dict[str, pd.DataFrame],
|
||||
job_id: str,
|
||||
repos: Dict) -> Dict[str, Any]:
|
||||
"""Train models with enhanced repository integration"""
|
||||
training_results = {}
|
||||
i = 0
|
||||
repos: Dict,
|
||||
progress_tracker: ParallelProductProgressTracker) -> Dict[str, Any]:
|
||||
"""Train models with throttled parallel execution and progress tracking"""
|
||||
total_products = len(processed_data)
|
||||
base_progress = 45
|
||||
max_progress = 85
|
||||
logger.info(f"Starting throttled parallel training for {total_products} products")
|
||||
|
||||
# Create training tasks for all products
|
||||
training_tasks = [
|
||||
self._train_single_product(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
product_data=product_data,
|
||||
job_id=job_id,
|
||||
repos=repos,
|
||||
progress_tracker=progress_tracker
|
||||
)
|
||||
for inventory_product_id, product_data in processed_data.items()
|
||||
]
|
||||
|
||||
# Execute training tasks with throttling to prevent heartbeat blocking
|
||||
# Limit concurrent operations to prevent CPU/memory exhaustion
|
||||
from app.core.config import settings
|
||||
max_concurrent = getattr(settings, 'MAX_CONCURRENT_TRAININGS', 3)
|
||||
|
||||
for inventory_product_id, product_data in processed_data.items():
|
||||
product_start_time = time.time()
|
||||
try:
|
||||
logger.info("Training enhanced model",
|
||||
inventory_product_id=inventory_product_id)
|
||||
|
||||
# Check if we have enough data
|
||||
if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
|
||||
training_results[inventory_product_id] = {
|
||||
'status': 'skipped',
|
||||
'reason': 'insufficient_data',
|
||||
'data_points': len(product_data),
|
||||
'min_required': settings.MIN_TRAINING_DATA_DAYS,
|
||||
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
|
||||
}
|
||||
logger.warning("Skipping product due to insufficient data",
|
||||
inventory_product_id=inventory_product_id,
|
||||
data_points=len(product_data),
|
||||
min_required=settings.MIN_TRAINING_DATA_DAYS)
|
||||
continue
|
||||
|
||||
# Train the model using Prophet manager
|
||||
model_info = await self.prophet_manager.train_bakery_model(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
df=product_data,
|
||||
job_id=job_id
|
||||
)
|
||||
|
||||
# Store model record using repository
|
||||
model_record = await self._create_model_record(
|
||||
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
|
||||
)
|
||||
|
||||
# Create performance metrics record
|
||||
if model_info.get('training_metrics'):
|
||||
await self._create_performance_metrics(
|
||||
repos, model_record.id if model_record else None,
|
||||
tenant_id, inventory_product_id, model_info['training_metrics']
|
||||
)
|
||||
|
||||
training_results[inventory_product_id] = {
|
||||
'status': 'success',
|
||||
'model_info': model_info,
|
||||
'model_record_id': model_record.id if model_record else None,
|
||||
'data_points': len(product_data),
|
||||
'training_time_seconds': time.time() - product_start_time,
|
||||
'trained_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
logger.info("Successfully trained enhanced model",
|
||||
inventory_product_id=inventory_product_id,
|
||||
model_record_id=model_record.id if model_record else None)
|
||||
|
||||
completed_products = i + 1
|
||||
i += 1
|
||||
progress = base_progress + int((completed_products / total_products) * (max_progress - base_progress))
|
||||
|
||||
if self.status_publisher:
|
||||
self.status_publisher.products_completed = completed_products
|
||||
|
||||
await self.status_publisher.progress_update(
|
||||
progress=progress,
|
||||
step="model_training",
|
||||
current_product=inventory_product_id,
|
||||
step_details=f"Enhanced training completed for {inventory_product_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to train enhanced model",
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
training_results[inventory_product_id] = {
|
||||
'status': 'error',
|
||||
'error_message': str(e),
|
||||
'data_points': len(product_data) if product_data is not None else 0,
|
||||
'training_time_seconds': time.time() - product_start_time,
|
||||
'failed_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
completed_products = i + 1
|
||||
i += 1
|
||||
progress = base_progress + int((completed_products / total_products) * (max_progress - base_progress))
|
||||
|
||||
if self.status_publisher:
|
||||
self.status_publisher.products_completed = completed_products
|
||||
await self.status_publisher.progress_update(
|
||||
progress=progress,
|
||||
step="model_training",
|
||||
current_product=inventory_product_id,
|
||||
step_details=f"Enhanced training failed for {inventory_product_id}: {str(e)}"
|
||||
)
|
||||
|
||||
logger.info(f"Executing training with max {max_concurrent} concurrent operations",
|
||||
total_products=total_products)
|
||||
|
||||
# Process tasks in batches to prevent blocking the event loop
|
||||
results_list = []
|
||||
for i in range(0, len(training_tasks), max_concurrent):
|
||||
batch = training_tasks[i:i + max_concurrent]
|
||||
batch_results = await asyncio.gather(*batch, return_exceptions=True)
|
||||
results_list.extend(batch_results)
|
||||
|
||||
# Yield control to event loop to allow heartbeat processing
|
||||
# Increased from 0.01s to 0.1s (100ms) to ensure WebSocket pings, RabbitMQ heartbeats,
|
||||
# and progress events can be processed during long training operations
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Log progress to verify event loop is responsive
|
||||
logger.debug(
|
||||
"Training batch completed, yielding to event loop",
|
||||
batch_num=(i // max_concurrent) + 1,
|
||||
total_batches=(len(training_tasks) + max_concurrent - 1) // max_concurrent,
|
||||
products_completed=len(results_list),
|
||||
total_products=len(training_tasks)
|
||||
)
|
||||
|
||||
# Log final summary
|
||||
summary = progress_tracker.get_progress()
|
||||
logger.info("Throttled parallel training completed",
|
||||
total=summary['total_products'],
|
||||
completed=summary['products_completed'])
|
||||
|
||||
# Convert results to dictionary
|
||||
training_results = {}
|
||||
for result in results_list:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Training task failed with exception: {result}")
|
||||
continue
|
||||
|
||||
product_id, product_result = result
|
||||
training_results[product_id] = product_result
|
||||
|
||||
logger.info(f"Throttled parallel training completed: {len(training_results)} products processed")
|
||||
return training_results
|
||||
|
||||
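ParallelProductProgressTracker lives in app.services.progress_tracker and is not included in this diff. The sketch below shows one way it could aggregate per-product completions into the 20-80% band referenced earlier; the event publishing is left as a comment because the real publisher is not shown:

import asyncio
from typing import Dict


class ParallelProductProgressTracker:
    """Aggregates per-product completions into a single job-level progress figure."""

    def __init__(self, job_id: str, tenant_id: str, total_products: int,
                 base_progress: int = 20, max_progress: int = 80):
        self.job_id = job_id
        self.tenant_id = tenant_id
        self.total_products = total_products
        self.base_progress = base_progress
        self.max_progress = max_progress
        self.completed = 0
        self._lock = asyncio.Lock()  # product tasks finish concurrently, so guard the counter

    async def mark_product_completed(self, inventory_product_id: str) -> int:
        async with self._lock:
            self.completed += 1
            span = self.max_progress - self.base_progress
            progress = self.base_progress + int(self.completed / self.total_products * span)
        # The real tracker publishes the "product_completed" (Event 3) message here
        return progress

    def get_progress(self) -> Dict[str, int]:
        return {"total_products": self.total_products, "products_completed": self.completed}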
async def _create_model_record(self,
|
||||
@@ -655,7 +730,3 @@ class EnhancedBakeryMLTrainer:
|
||||
except Exception as e:
|
||||
logger.error("Enhanced model evaluation failed", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
# Legacy compatibility alias
|
||||
BakeryMLTrainer = EnhancedBakeryMLTrainer
|
||||
317
services/training/app/schemas/validation.py
Normal file
317
services/training/app/schemas/validation.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Comprehensive Input Validation Schemas
|
||||
Ensures all API inputs are properly validated before processing
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, validator, root_validator
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import UUID
|
||||
import re
|
||||
|
||||
|
||||
class TrainingJobCreateRequest(BaseModel):
|
||||
"""Schema for creating a new training job"""
|
||||
|
||||
tenant_id: UUID = Field(..., description="Tenant identifier")
|
||||
start_date: Optional[str] = Field(
|
||||
None,
|
||||
description="Training data start date (ISO format: YYYY-MM-DD)",
|
||||
example="2024-01-01"
|
||||
)
|
||||
end_date: Optional[str] = Field(
|
||||
None,
|
||||
description="Training data end date (ISO format: YYYY-MM-DD)",
|
||||
example="2024-12-31"
|
||||
)
|
||||
product_ids: Optional[List[UUID]] = Field(
|
||||
None,
|
||||
description="Specific products to train (optional, trains all if not provided)"
|
||||
)
|
||||
force_retrain: bool = Field(
|
||||
default=False,
|
||||
description="Force retraining even if recent models exist"
|
||||
)
|
||||
|
||||
@validator('start_date', 'end_date')
|
||||
def validate_date_format(cls, v):
|
||||
"""Validate date is in ISO format"""
|
||||
if v is not None:
|
||||
try:
|
||||
datetime.fromisoformat(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid date format: {v}. Use YYYY-MM-DD format")
|
||||
return v
|
||||
|
||||
@root_validator
|
||||
def validate_date_range(cls, values):
|
||||
"""Validate date range is logical"""
|
||||
start = values.get('start_date')
|
||||
end = values.get('end_date')
|
||||
|
||||
if start and end:
|
||||
start_dt = datetime.fromisoformat(start)
|
||||
end_dt = datetime.fromisoformat(end)
|
||||
|
||||
if end_dt <= start_dt:
|
||||
raise ValueError("end_date must be after start_date")
|
||||
|
||||
# Check reasonable range (max 3 years)
|
||||
if (end_dt - start_dt).days > 1095:
|
||||
raise ValueError("Date range cannot exceed 3 years (1095 days)")
|
||||
|
||||
# Check not in future
|
||||
if end_dt > datetime.now():
|
||||
raise ValueError("end_date cannot be in the future")
|
||||
|
||||
return values
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"tenant_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-12-31",
|
||||
"product_ids": None,
|
||||
"force_retrain": False
|
||||
}
|
||||
}
|
||||
|
||||
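Usage is standard Pydantic: invalid payloads raise ValidationError before any training work starts. For example, an inverted date range is rejected by the root validator above:

from pydantic import ValidationError

try:
    TrainingJobCreateRequest(
        tenant_id="123e4567-e89b-12d3-a456-426614174000",
        start_date="2024-12-31",
        end_date="2024-01-01",  # before start_date, so validation fails
    )
except ValidationError as exc:
    print(exc.errors())  # includes "end_date must be after start_date"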
|
||||
class ForecastRequest(BaseModel):
|
||||
"""Schema for generating forecasts"""
|
||||
|
||||
tenant_id: UUID = Field(..., description="Tenant identifier")
|
||||
product_id: UUID = Field(..., description="Product identifier")
|
||||
forecast_days: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
le=365,
|
||||
description="Number of days to forecast (1-365)"
|
||||
)
|
||||
include_regressors: bool = Field(
|
||||
default=True,
|
||||
description="Include weather and traffic data in forecast"
|
||||
)
|
||||
confidence_level: float = Field(
|
||||
default=0.80,
|
||||
ge=0.5,
|
||||
le=0.99,
|
||||
description="Confidence interval (0.5-0.99)"
|
||||
)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"tenant_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"product_id": "223e4567-e89b-12d3-a456-426614174000",
|
||||
"forecast_days": 30,
|
||||
"include_regressors": True,
|
||||
"confidence_level": 0.80
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ModelEvaluationRequest(BaseModel):
|
||||
"""Schema for model evaluation"""
|
||||
|
||||
tenant_id: UUID = Field(..., description="Tenant identifier")
|
||||
product_id: Optional[UUID] = Field(None, description="Specific product (optional)")
|
||||
evaluation_start_date: str = Field(..., description="Evaluation period start")
|
||||
evaluation_end_date: str = Field(..., description="Evaluation period end")
|
||||
|
||||
@validator('evaluation_start_date', 'evaluation_end_date')
|
||||
def validate_date_format(cls, v):
|
||||
try:
|
||||
datetime.fromisoformat(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid date format: {v}")
|
||||
return v
|
||||
|
||||
@root_validator
|
||||
def validate_evaluation_period(cls, values):
|
||||
start = values.get('evaluation_start_date')
|
||||
end = values.get('evaluation_end_date')
|
||||
|
||||
if start and end:
|
||||
start_dt = datetime.fromisoformat(start)
|
||||
end_dt = datetime.fromisoformat(end)
|
||||
|
||||
if end_dt <= start_dt:
|
||||
raise ValueError("evaluation_end_date must be after evaluation_start_date")
|
||||
|
||||
# Minimum 7 days for meaningful evaluation
|
||||
if (end_dt - start_dt).days < 7:
|
||||
raise ValueError("Evaluation period must be at least 7 days")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class BulkTrainingRequest(BaseModel):
|
||||
"""Schema for bulk training operations"""
|
||||
|
||||
tenant_ids: List[UUID] = Field(
|
||||
...,
|
||||
min_items=1,
|
||||
max_items=100,
|
||||
description="List of tenant IDs (max 100)"
|
||||
)
|
||||
start_date: Optional[str] = Field(None, description="Common start date")
|
||||
end_date: Optional[str] = Field(None, description="Common end date")
|
||||
parallel: bool = Field(
|
||||
default=True,
|
||||
description="Execute training jobs in parallel"
|
||||
)
|
||||
|
||||
@validator('tenant_ids')
|
||||
def validate_unique_tenants(cls, v):
|
||||
if len(v) != len(set(v)):
|
||||
raise ValueError("Duplicate tenant IDs not allowed")
|
||||
return v
|
||||
|
||||
|
||||
class HyperparameterOverride(BaseModel):
|
||||
"""Schema for manual hyperparameter override"""
|
||||
|
||||
changepoint_prior_scale: Optional[float] = Field(
|
||||
None, ge=0.001, le=0.5,
|
||||
description="Flexibility of trend changes"
|
||||
)
|
||||
seasonality_prior_scale: Optional[float] = Field(
|
||||
None, ge=0.01, le=10.0,
|
||||
description="Strength of seasonality"
|
||||
)
|
||||
holidays_prior_scale: Optional[float] = Field(
|
||||
None, ge=0.01, le=10.0,
|
||||
description="Strength of holiday effects"
|
||||
)
|
||||
seasonality_mode: Optional[str] = Field(
|
||||
None,
|
||||
description="Seasonality mode",
|
||||
regex="^(additive|multiplicative)$"
|
||||
)
|
||||
daily_seasonality: Optional[bool] = None
|
||||
weekly_seasonality: Optional[bool] = None
|
||||
yearly_seasonality: Optional[bool] = None
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"changepoint_prior_scale": 0.05,
|
||||
"seasonality_prior_scale": 10.0,
|
||||
"holidays_prior_scale": 10.0,
|
||||
"seasonality_mode": "additive",
|
||||
"daily_seasonality": False,
|
||||
"weekly_seasonality": True,
|
||||
"yearly_seasonality": True
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class AdvancedTrainingRequest(TrainingJobCreateRequest):
|
||||
"""Extended training request with advanced options"""
|
||||
|
||||
hyperparameter_override: Optional[HyperparameterOverride] = Field(
|
||||
None,
|
||||
description="Manual hyperparameter settings (skips optimization)"
|
||||
)
|
||||
enable_cross_validation: bool = Field(
|
||||
default=True,
|
||||
description="Enable cross-validation during training"
|
||||
)
|
||||
cv_folds: int = Field(
|
||||
default=3,
|
||||
ge=2,
|
||||
le=10,
|
||||
description="Number of cross-validation folds"
|
||||
)
|
||||
optimization_trials: Optional[int] = Field(
|
||||
None,
|
||||
ge=5,
|
||||
le=100,
|
||||
description="Number of hyperparameter optimization trials (overrides defaults)"
|
||||
)
|
||||
save_diagnostics: bool = Field(
|
||||
default=False,
|
||||
description="Save detailed diagnostic plots and metrics"
|
||||
)
|
||||
|
||||
|
||||
class DataQualityCheckRequest(BaseModel):
|
||||
"""Schema for data quality validation"""
|
||||
|
||||
tenant_id: UUID = Field(..., description="Tenant identifier")
|
||||
start_date: str = Field(..., description="Check period start")
|
||||
end_date: str = Field(..., description="Check period end")
|
||||
product_ids: Optional[List[UUID]] = Field(
|
||||
None,
|
||||
description="Specific products to check"
|
||||
)
|
||||
include_recommendations: bool = Field(
|
||||
default=True,
|
||||
description="Include improvement recommendations"
|
||||
)
|
||||
|
||||
@validator('start_date', 'end_date')
|
||||
def validate_date(cls, v):
|
||||
try:
|
||||
datetime.fromisoformat(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid date format: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class ModelQueryParams(BaseModel):
|
||||
"""Query parameters for model listing"""
|
||||
|
||||
tenant_id: Optional[UUID] = None
|
||||
product_id: Optional[UUID] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_production: Optional[bool] = None
|
||||
model_type: Optional[str] = Field(None, regex="^(prophet|prophet_optimized|lstm|arima)$")
|
||||
min_accuracy: Optional[float] = Field(None, ge=0.0, le=1.0)
|
||||
created_after: Optional[datetime] = None
|
||||
created_before: Optional[datetime] = None
|
||||
limit: int = Field(default=100, ge=1, le=1000)
|
||||
offset: int = Field(default=0, ge=0)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"tenant_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"is_active": True,
|
||||
"is_production": True,
|
||||
"limit": 50,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
|
||||
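In FastAPI these query-parameter models are typically bound with Depends, so list endpoints get the range checks above for free; a sketch of that wiring (router path and handler name are illustrative):

from fastapi import APIRouter, Depends

router = APIRouter()


@router.get("/api/v1/models")
async def list_models(params: ModelQueryParams = Depends()):
    # limit/offset and the model_type regex are already validated by ModelQueryParams
    return {"tenant_id": params.tenant_id, "limit": params.limit, "offset": params.offset}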
|
||||
def validate_uuid(value: str) -> UUID:
|
||||
"""Validate and convert string to UUID"""
|
||||
try:
|
||||
return UUID(value)
|
||||
except (ValueError, AttributeError):
|
||||
raise ValueError(f"Invalid UUID format: {value}")
|
||||
|
||||
|
||||
def validate_date_string(value: str) -> datetime:
|
||||
"""Validate and convert date string to datetime"""
|
||||
try:
|
||||
return datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid date format: {value}. Use ISO format (YYYY-MM-DD)")
|
||||
|
||||
|
||||
def validate_positive_integer(value: int, field_name: str = "value") -> int:
|
||||
"""Validate positive integer"""
|
||||
if value <= 0:
|
||||
raise ValueError(f"{field_name} must be positive, got {value}")
|
||||
return value
|
||||
|
||||
|
||||
def validate_probability(value: float, field_name: str = "value") -> float:
|
||||
"""Validate probability value (0.0-1.0)"""
|
||||
if not 0.0 <= value <= 1.0:
|
||||
raise ValueError(f"{field_name} must be between 0.0 and 1.0, got {value}")
|
||||
return value
|
||||
@@ -3,32 +3,14 @@ Training Service Layer
|
||||
Business logic services for ML training and model management
|
||||
"""
|
||||
|
||||
from .training_service import TrainingService
|
||||
from .training_service import EnhancedTrainingService
|
||||
from .training_orchestrator import TrainingDataOrchestrator
|
||||
from .date_alignment_service import DateAlignmentService
|
||||
from .data_client import DataClient
|
||||
from .messaging import (
|
||||
publish_job_progress,
|
||||
publish_data_validation_started,
|
||||
publish_data_validation_completed,
|
||||
publish_job_step_completed,
|
||||
publish_job_completed,
|
||||
publish_job_failed,
|
||||
TrainingStatusPublisher
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TrainingService",
|
||||
"EnhancedTrainingService",
|
||||
"TrainingDataOrchestrator",
|
||||
"TrainingDataOrchestrator",
|
||||
"DateAlignmentService",
|
||||
"DataClient",
|
||||
"publish_job_progress",
|
||||
"publish_data_validation_started",
|
||||
"publish_data_validation_completed",
|
||||
"publish_job_step_completed",
|
||||
"publish_job_completed",
|
||||
"publish_job_failed",
|
||||
"TrainingStatusPublisher"
|
||||
"DataClient"
|
||||
]
|
||||
@@ -1,16 +1,20 @@
|
||||
# services/training/app/services/data_client.py
|
||||
"""
|
||||
Training Service Data Client
|
||||
Migrated to use shared service clients - much simpler now!
|
||||
Migrated to use shared service clients with timeout configuration
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
|
||||
# Import the shared clients
|
||||
from shared.clients import get_sales_client, get_external_client, get_service_clients
|
||||
from app.core.config import settings
|
||||
from app.core import constants as const
|
||||
from app.utils.circuit_breaker import circuit_breaker_registry, CircuitBreakerError
|
||||
from app.utils.retry import with_retry, HTTP_RETRY_STRATEGY, EXTERNAL_SERVICE_RETRY_STRATEGY
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@@ -21,21 +25,103 @@ class DataClient:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Get the new specialized clients
|
||||
# Get the new specialized clients with timeout configuration
|
||||
self.sales_client = get_sales_client(settings, "training")
|
||||
self.external_client = get_external_client(settings, "training")
|
||||
|
||||
|
||||
# Configure timeouts for HTTP clients
|
||||
self._configure_timeouts()
|
||||
|
||||
# Initialize circuit breakers for external services
|
||||
self._init_circuit_breakers()
|
||||
|
||||
# Check if the new method is available for stored traffic data
|
||||
if hasattr(self.external_client, 'get_stored_traffic_data_for_training'):
|
||||
self.supports_stored_traffic_data = True
|
||||
|
||||
def _configure_timeouts(self):
|
||||
"""Configure appropriate timeouts for HTTP clients"""
|
||||
timeout = httpx.Timeout(
|
||||
connect=const.HTTP_TIMEOUT_DEFAULT,
|
||||
read=const.HTTP_TIMEOUT_LONG_RUNNING,
|
||||
write=const.HTTP_TIMEOUT_DEFAULT,
|
||||
pool=const.HTTP_TIMEOUT_DEFAULT
|
||||
)
|
||||
|
||||
# Apply timeout to clients if they have httpx clients
|
||||
if hasattr(self.sales_client, 'client') and isinstance(self.sales_client.client, httpx.AsyncClient):
|
||||
self.sales_client.client.timeout = timeout
|
||||
|
||||
if hasattr(self.external_client, 'client') and isinstance(self.external_client.client, httpx.AsyncClient):
|
||||
self.external_client.client.timeout = timeout
|
||||
else:
|
||||
self.supports_stored_traffic_data = False
|
||||
logger.warning("Stored traffic data method not available in external client")
|
||||
|
||||
# Or alternatively, get all clients at once:
|
||||
# self.clients = get_service_clients(settings, "training")
|
||||
# Then use: self.clients.sales.get_sales_data(...) and self.clients.external.get_weather_forecast(...)
|
||||
|
||||
def _init_circuit_breakers(self):
|
||||
"""Initialize circuit breakers for external service calls"""
|
||||
# Sales service circuit breaker
|
||||
self.sales_cb = circuit_breaker_registry.get_or_create(
|
||||
name="sales_service",
|
||||
failure_threshold=5,
|
||||
recovery_timeout=60.0,
|
||||
expected_exception=Exception
|
||||
)
|
||||
|
||||
# Weather service circuit breaker
|
||||
self.weather_cb = circuit_breaker_registry.get_or_create(
|
||||
name="weather_service",
|
||||
failure_threshold=3, # Weather is optional, fail faster
|
||||
recovery_timeout=30.0,
|
||||
expected_exception=Exception
|
||||
)
|
||||
|
||||
# Traffic service circuit breaker
|
||||
self.traffic_cb = circuit_breaker_registry.get_or_create(
|
||||
name="traffic_service",
|
||||
failure_threshold=3, # Traffic is optional, fail faster
|
||||
recovery_timeout=30.0,
|
||||
expected_exception=Exception
|
||||
)
|
||||
|
||||
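circuit_breaker_registry comes from app.utils.circuit_breaker, which is outside this diff. A stripped-down async breaker with the same call(...) interface could look as follows; the thresholds mirror the registrations above, but the implementation is illustrative:

import time
from typing import Optional


class CircuitBreakerError(Exception):
    """Raised when calls are rejected because the breaker is open."""


class CircuitBreaker:
    def __init__(self, name: str, failure_threshold: int = 5,
                 recovery_timeout: float = 60.0, expected_exception: type = Exception):
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.failures = 0
        self.opened_at: Optional[float] = None

    async def call(self, func, *args, **kwargs):
        # Fail fast while open, until the recovery timeout lets a probe call through
        if self.opened_at is not None:
            if time.monotonic() - self.opened_at < self.recovery_timeout:
                raise CircuitBreakerError(f"{self.name} circuit is open")
            self.opened_at = None  # half-open: allow one trial call
        try:
            result = await func(*args, **kwargs)
        except self.expected_exception:
            self.failures += 1
            if self.failures >= self.failure_threshold:
                self.opened_at = time.monotonic()
            raise
        self.failures = 0
        return result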
@with_retry(max_attempts=3, initial_delay=1.0, max_delay=10.0)
|
||||
async def _fetch_sales_data_internal(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
product_id: Optional[str] = None,
|
||||
fetch_all: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Internal method to fetch sales data with automatic retry"""
|
||||
if fetch_all:
|
||||
sales_data = await self.sales_client.get_all_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id,
|
||||
aggregation="daily",
|
||||
page_size=1000,
|
||||
max_pages=100
|
||||
)
|
||||
else:
|
||||
sales_data = await self.sales_client.get_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id,
|
||||
aggregation="daily"
|
||||
)
|
||||
sales_data = sales_data or []
|
||||
|
||||
if sales_data:
|
||||
logger.info(f"Fetched {len(sales_data)} sales records",
|
||||
tenant_id=tenant_id, product_id=product_id, fetch_all=fetch_all)
|
||||
return sales_data
|
||||
else:
|
||||
logger.error("No sales data returned", tenant_id=tenant_id)
|
||||
raise ValueError(f"No sales data available for tenant {tenant_id}")
|
||||
|
||||
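with_retry and the retry strategies imported at the top of this file belong to app.utils.retry rather than this diff; an exponential-backoff decorator compatible with the usage above could be sketched as:

import asyncio
import functools
import random


def with_retry(max_attempts: int = 3, initial_delay: float = 1.0, max_delay: float = 10.0):
    """Retry an async callable with exponential backoff and a little jitter."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            delay = initial_delay
            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except Exception:
                    if attempt == max_attempts:
                        raise
                    await asyncio.sleep(delay + random.uniform(0, delay / 2))
                    delay = min(delay * 2, max_delay)
        return wrapper
    return decorator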
async def fetch_sales_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -45,50 +131,21 @@ class DataClient:
|
||||
fetch_all: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch sales data for training
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
start_date: Start date in ISO format
|
||||
end_date: End date in ISO format
|
||||
product_id: Optional product filter
|
||||
fetch_all: If True, fetches ALL records using pagination (original behavior)
|
||||
If False, fetches limited records (standard API response)
|
||||
Fetch sales data for training with circuit breaker protection
|
||||
"""
|
||||
try:
|
||||
if fetch_all:
|
||||
# Use paginated method to get ALL records (original behavior)
|
||||
sales_data = await self.sales_client.get_all_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id,
|
||||
aggregation="daily",
|
||||
page_size=1000, # Comply with API limit
|
||||
max_pages=100 # Safety limit (100,000 records max at 1,000 per page)
|
||||
)
|
||||
else:
|
||||
# Use standard method for limited results
|
||||
sales_data = await self.sales_client.get_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id,
|
||||
aggregation="daily"
|
||||
)
|
||||
sales_data = sales_data or []
|
||||
|
||||
if sales_data:
|
||||
logger.info(f"Fetched {len(sales_data)} sales records",
|
||||
tenant_id=tenant_id, product_id=product_id, fetch_all=fetch_all)
|
||||
return sales_data
|
||||
else:
|
||||
logger.warning("No sales data returned", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
return await self.sales_cb.call(
|
||||
self._fetch_sales_data_internal,
|
||||
tenant_id, start_date, end_date, product_id, fetch_all
|
||||
)
|
||||
except CircuitBreakerError as e:
|
||||
logger.error(f"Sales service circuit breaker open: {e}")
|
||||
raise RuntimeError(f"Sales service unavailable: {str(e)}")
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching sales data: {e}", tenant_id=tenant_id)
|
||||
return []
|
||||
raise RuntimeError(f"Failed to fetch sales data: {str(e)}")
|
||||
|
||||
async def fetch_weather_data(
|
||||
self,
|
||||
@@ -112,15 +169,15 @@ class DataClient:
|
||||
)
|
||||
|
||||
if weather_data:
|
||||
logger.info(f"Fetched {len(weather_data)} weather records",
|
||||
logger.info(f"Fetched {len(weather_data)} weather records",
|
||||
tenant_id=tenant_id)
|
||||
return weather_data
|
||||
else:
|
||||
logger.warning("No weather data returned", tenant_id=tenant_id)
|
||||
logger.warning("No weather data returned, will use synthetic data", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching weather data: {e}", tenant_id=tenant_id)
|
||||
logger.warning(f"Error fetching weather data, will use synthetic data: {e}", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
async def fetch_traffic_data_unified(
|
||||
@@ -264,34 +321,93 @@ class DataClient:
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
end_date: str,
|
||||
sales_data: List[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate data quality before training
|
||||
Validate data quality before training with comprehensive checks
|
||||
"""
|
||||
try:
|
||||
# Note: validate_data_quality may need to be implemented in one of the new services
|
||||
# validation_result = await self.sales_client.validate_data_quality(
|
||||
# tenant_id=tenant_id,
|
||||
# start_date=start_date,
|
||||
# end_date=end_date
|
||||
# )
|
||||
|
||||
# Temporary implementation - assume data is valid for now
|
||||
validation_result = {"is_valid": True, "message": "Validation temporarily disabled"}
|
||||
|
||||
if validation_result:
|
||||
logger.info("Data validation completed",
|
||||
tenant_id=tenant_id,
|
||||
is_valid=validation_result.get("is_valid", False))
|
||||
return validation_result
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
# If sales data provided, validate it directly
|
||||
if sales_data is not None:
|
||||
if not sales_data or len(sales_data) == 0:
|
||||
errors.append("No sales data available for the specified period")
|
||||
return {"is_valid": False, "errors": errors, "warnings": warnings}
|
||||
|
||||
# Check minimum data points
|
||||
if len(sales_data) < 30:
|
||||
errors.append(f"Insufficient data points: {len(sales_data)} (minimum 30 required)")
|
||||
elif len(sales_data) < 90:
|
||||
warnings.append(f"Limited data points: {len(sales_data)} (recommended 90+)")
|
||||
|
||||
# Check for required fields
|
||||
required_fields = ['date', 'inventory_product_id']
|
||||
for record in sales_data[:5]: # Sample check
|
||||
missing = [f for f in required_fields if f not in record or record[f] is None]
|
||||
if missing:
|
||||
errors.append(f"Missing required fields: {missing}")
|
||||
break
|
||||
|
||||
# Check for data quality issues
|
||||
zero_count = sum(1 for r in sales_data if r.get('quantity', 0) == 0)
|
||||
zero_ratio = zero_count / len(sales_data)
|
||||
if zero_ratio > 0.9:
|
||||
errors.append(f"Too many zero values: {zero_ratio:.1%} of records")
|
||||
elif zero_ratio > 0.7:
|
||||
warnings.append(f"High zero value ratio: {zero_ratio:.1%}")
|
||||
|
||||
# Check product diversity
|
||||
unique_products = set(r.get('inventory_product_id') for r in sales_data if r.get('inventory_product_id'))
|
||||
if len(unique_products) == 0:
|
||||
errors.append("No valid product IDs found in sales data")
|
||||
elif len(unique_products) == 1:
|
||||
warnings.append("Only one product found - consider adding more products")
|
||||
|
||||
else:
|
||||
logger.warning("Data validation failed", tenant_id=tenant_id)
|
||||
return {"is_valid": False, "errors": ["Validation service unavailable"]}
|
||||
|
||||
# Fetch data for validation
|
||||
sales_data = await self.fetch_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
fetch_all=False
|
||||
)
|
||||
|
||||
if not sales_data:
|
||||
errors.append("Unable to fetch sales data for validation")
|
||||
return {"is_valid": False, "errors": errors, "warnings": warnings}
|
||||
|
||||
# Recursive call with fetched data
|
||||
return await self.validate_data_quality(
|
||||
tenant_id, start_date, end_date, sales_data
|
||||
)
|
||||
|
||||
is_valid = len(errors) == 0
|
||||
result = {
|
||||
"is_valid": is_valid,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
"data_points": len(sales_data) if sales_data else 0,
|
||||
"unique_products": len(unique_products) if sales_data else 0
|
||||
}
|
||||
|
||||
if is_valid:
|
||||
logger.info("Data validation passed",
|
||||
tenant_id=tenant_id,
|
||||
data_points=result["data_points"],
|
||||
warnings_count=len(warnings))
|
||||
else:
|
||||
logger.error("Data validation failed",
|
||||
tenant_id=tenant_id,
|
||||
errors=errors)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating data: {e}", tenant_id=tenant_id)
|
||||
return {"is_valid": False, "errors": [str(e)]}
|
||||
raise ValueError(f"Data validation failed: {str(e)}")
|
||||
|
||||
# Global instance - same as before, but much simpler implementation
|
||||
data_client = DataClient()
|
||||
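A typical call site (inside an async function) runs this quality gate first and only proceeds to training when it passes:

result = await data_client.validate_data_quality(
    tenant_id="123e4567-e89b-12d3-a456-426614174000",
    start_date="2024-01-01",
    end_date="2024-06-30",
)
if not result["is_valid"]:
    raise ValueError(f"Training aborted: {result['errors']}")
for warning in result["warnings"]:
    logger.warning("Data quality warning", detail=warning)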
@@ -1,9 +1,9 @@
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from app.utils.timezone_utils import ensure_timezone_aware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -84,31 +84,25 @@ class DateAlignmentService:
|
||||
requested_end: Optional[datetime]
|
||||
) -> DateRange:
|
||||
"""Determine the base date range for training."""
|
||||
|
||||
# ✅ FIX: Ensure all datetimes are timezone-aware for comparison
|
||||
def ensure_timezone_aware(dt: datetime) -> datetime:
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
# Use explicit dates if provided
|
||||
if requested_start and requested_end:
|
||||
requested_start = ensure_timezone_aware(requested_start)
|
||||
requested_end = ensure_timezone_aware(requested_end)
|
||||
|
||||
|
||||
if requested_end <= requested_start:
|
||||
raise ValueError("End date must be after start date")
|
||||
return DateRange(requested_start, requested_end, DataSourceType.BAKERY_SALES)
|
||||
|
||||
|
||||
# Otherwise, use the user's sales data range as the foundation
|
||||
start_date = ensure_timezone_aware(requested_start or user_sales_range.start)
|
||||
end_date = ensure_timezone_aware(requested_end or user_sales_range.end)
|
||||
|
||||
|
||||
# Ensure we don't exceed maximum training range
|
||||
if (end_date - start_date).days > self.MAX_TRAINING_RANGE_DAYS:
|
||||
start_date = end_date - timedelta(days=self.MAX_TRAINING_RANGE_DAYS)
|
||||
logger.warning(f"Limiting training range to {self.MAX_TRAINING_RANGE_DAYS} days")
|
||||
|
||||
|
||||
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
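# Worked example of the clamp above (illustrative, assuming MAX_TRAINING_RANGE_DAYS = 365):
#
#   start_date = datetime(2023, 1, 1, tzinfo=timezone.utc)
#   end_date   = datetime(2024, 6, 30, tzinfo=timezone.utc)
#   (end_date - start_date).days                     # 546, exceeds the limit
#   start_date = end_date - timedelta(days=365)      # clamped to 2023-07-01 .. 2024-06-30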
|
||||
|
||||
def _apply_data_source_constraints(self, base_range: DateRange) -> AlignedDateRange:
|
||||
|
||||
@@ -1,603 +0,0 @@
|
||||
# services/training/app/services/messaging.py
|
||||
"""
|
||||
Enhanced training service messaging - Complete status publishing implementation
|
||||
Uses shared RabbitMQ infrastructure with comprehensive progress tracking
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from shared.messaging.rabbitmq import RabbitMQClient
|
||||
from shared.messaging.events import (
|
||||
TrainingStartedEvent,
|
||||
TrainingCompletedEvent,
|
||||
TrainingFailedEvent
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Single global instance
|
||||
training_publisher = RabbitMQClient(settings.RABBITMQ_URL, "training-service")
|
||||
|
||||
async def setup_messaging():
|
||||
"""Initialize messaging for training service"""
|
||||
success = await training_publisher.connect()
|
||||
if success:
|
||||
logger.info("Training service messaging initialized")
|
||||
else:
|
||||
logger.warning("Training service messaging failed to initialize")
|
||||
|
||||
async def cleanup_messaging():
|
||||
"""Cleanup messaging for training service"""
|
||||
await training_publisher.disconnect()
|
||||
logger.info("Training service messaging cleaned up")
|
||||
|
||||
def serialize_for_json(obj: Any) -> Any:
|
||||
"""
|
||||
Convert numpy types and other non-JSON serializable objects to JSON-compatible types
|
||||
"""
|
||||
if isinstance(obj, np.integer):
|
||||
return int(obj)
|
||||
elif isinstance(obj, np.floating):
|
||||
return float(obj)
|
||||
elif isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
elif isinstance(obj, np.bool_):
|
||||
return bool(obj)
|
||||
elif isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
elif isinstance(obj, dict):
|
||||
return {key: serialize_for_json(value) for key, value in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return [serialize_for_json(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
def safe_json_serialize(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Recursively clean data dictionary for JSON serialization
|
||||
"""
|
||||
return serialize_for_json(data)
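# Illustrative usage (a sketch): training results often carry numpy scalars and arrays
# that json.dumps cannot handle directly. The metric names and values below are made up
# for the example; np and json are already imported in this module.
#
#   metrics = {"mape": np.float64(12.5), "epochs": np.int64(40),
#              "loss_curve": np.array([0.9, 0.4, 0.2])}
#   json.dumps(safe_json_serialize(metrics))
#   # -> '{"mape": 12.5, "epochs": 40, "loss_curve": [0.9, 0.4, 0.2]}'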
|
||||
|
||||
async def setup_websocket_message_routing():
|
||||
"""Set up message routing for WebSocket connections"""
|
||||
try:
|
||||
# This will be called from the WebSocket endpoint
|
||||
# to set up the consumer for a specific job
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set up WebSocket message routing: {e}")
|
||||
|
||||
# =========================================
|
||||
# ENHANCED TRAINING JOB STATUS EVENTS
|
||||
# =========================================
|
||||
|
||||
async def publish_job_started(job_id: str, tenant_id: str, config: Dict[str, Any]) -> bool:
|
||||
"""Publish training job started event"""
|
||||
event = TrainingStartedEvent(
|
||||
service_name="training-service",
|
||||
data={
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"config": config,
|
||||
"started_at": datetime.now().isoformat(),
|
||||
"estimated_duration_minutes": config.get("estimated_duration_minutes", 15)
|
||||
}
|
||||
)
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.started",
|
||||
event_data=event.to_dict()
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Published job started event", job_id=job_id, tenant_id=tenant_id)
|
||||
else:
|
||||
logger.error(f"Failed to publish job started event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
async def publish_job_progress(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
progress: int,
|
||||
step: str,
|
||||
current_product: Optional[str] = None,
|
||||
products_completed: int = 0,
|
||||
products_total: int = 0,
|
||||
estimated_time_remaining_minutes: Optional[int] = None,
|
||||
step_details: Optional[str] = None
|
||||
) -> bool:
|
||||
"""Publish detailed training job progress event with safe serialization"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.progress",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"progress": min(max(int(progress), 0), 100), # Ensure int, not numpy.int64
|
||||
"current_step": step,
|
||||
"current_product": current_product,
|
||||
"products_completed": int(products_completed), # Convert numpy types
|
||||
"products_total": int(products_total),
|
||||
"estimated_time_remaining_minutes": int(estimated_time_remaining_minutes) if estimated_time_remaining_minutes else None,
|
||||
"step_details": step_details
|
||||
}
|
||||
}
|
||||
|
||||
# Clean the entire event data
|
||||
clean_event_data = safe_json_serialize(event_data)
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.progress",
|
||||
event_data=clean_event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Published progress update",
|
||||
job_id=job_id,
|
||||
progress=progress,
|
||||
step=step,
|
||||
current_product=current_product)
|
||||
else:
|
||||
logger.error(f"Failed to publish progress update", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
async def publish_job_step_completed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
step_name: str,
|
||||
step_result: Dict[str, Any],
|
||||
progress: int
|
||||
) -> bool:
|
||||
"""Publish when a major training step is completed"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.step.completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"step_name": step_name,
|
||||
"step_result": step_result,
|
||||
"progress": progress,
|
||||
"completed_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.step.completed",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
async def publish_job_completed(job_id: str, tenant_id: str, results: Dict[str, Any]) -> bool:
|
||||
"""Publish training job completed event with safe JSON serialization"""
|
||||
|
||||
# Clean the results data before creating the event
|
||||
clean_results = safe_json_serialize(results)
|
||||
|
||||
event = TrainingCompletedEvent(
|
||||
service_name="training-service",
|
||||
data={
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"results": clean_results, # Now safe for JSON
|
||||
"models_trained": clean_results.get("successful_trainings", 0),
|
||||
"success_rate": clean_results.get("success_rate", 0),
|
||||
"total_duration_seconds": clean_results.get("overall_training_time_seconds", 0),
|
||||
"completed_at": datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.completed",
|
||||
event_data=event.to_dict()
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Published job completed event",
|
||||
job_id=job_id,
|
||||
models_trained=clean_results.get("successful_trainings", 0))
|
||||
else:
|
||||
logger.error(f"Failed to publish job completed event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
async def publish_job_failed(job_id: str, tenant_id: str, error: str, error_details: Optional[Dict] = None) -> bool:
|
||||
"""Publish training job failed event"""
|
||||
event = TrainingFailedEvent(
|
||||
service_name="training-service",
|
||||
data={
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"error": error,
|
||||
"error_details": error_details or {},
|
||||
"failed_at": datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.failed",
|
||||
event_data=event.to_dict()
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Published job failed event", job_id=job_id, error=error)
|
||||
else:
|
||||
logger.error(f"Failed to publish job failed event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
async def publish_job_cancelled(job_id: str, tenant_id: str, reason: str = "User requested") -> bool:
|
||||
"""Publish training job cancelled event"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.cancelled",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"reason": reason,
|
||||
"cancelled_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.cancelled",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
# =========================================
|
||||
# PRODUCT-LEVEL TRAINING EVENTS
|
||||
# =========================================
|
||||
|
||||
async def publish_product_training_started(job_id: str, tenant_id: str, inventory_product_id: str) -> bool:
|
||||
"""Publish single product training started event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.product.started",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.product.started",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"started_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
async def publish_product_training_completed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
model_id: str,
|
||||
metrics: Optional[Dict[str, float]] = None
|
||||
) -> bool:
|
||||
"""Publish single product training completed event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.product.completed",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.product.completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"model_id": model_id,
|
||||
"metrics": metrics or {},
|
||||
"completed_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
async def publish_product_training_failed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
error: str
|
||||
) -> bool:
|
||||
"""Publish single product training failed event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.product.failed",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.product.failed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"error": error,
|
||||
"failed_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# =========================================
|
||||
# MODEL LIFECYCLE EVENTS
|
||||
# =========================================
|
||||
|
||||
async def publish_model_trained(model_id: str, tenant_id: str, inventory_product_id: str, metrics: Dict[str, float]) -> bool:
|
||||
"""Publish model trained event with safe metric serialization"""
|
||||
|
||||
# Clean metrics to ensure JSON serialization
|
||||
clean_metrics = safe_json_serialize(metrics) if metrics else {}
|
||||
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.model.trained",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"model_id": model_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"training_metrics": clean_metrics, # Now safe for JSON
|
||||
"trained_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.model.trained",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
|
||||
async def publish_model_validated(model_id: str, tenant_id: str, inventory_product_id: str, validation_results: Dict[str, Any]) -> bool:
|
||||
"""Publish model validation event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.model.validated",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.model.validated",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"model_id": model_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"validation_results": validation_results,
|
||||
"validated_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
async def publish_model_saved(model_id: str, tenant_id: str, inventory_product_id: str, model_path: str) -> bool:
|
||||
"""Publish model saved event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.model.saved",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.model.saved",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"model_id": model_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"model_path": model_path,
|
||||
"saved_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# =========================================
|
||||
# DATA PROCESSING EVENTS
|
||||
# =========================================
|
||||
|
||||
async def publish_data_validation_started(job_id: str, tenant_id: str, products: List[str]) -> bool:
|
||||
"""Publish data validation started event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.data.validation.started",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.data.validation.started",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"products": products,
|
||||
"started_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
async def publish_data_validation_completed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
validation_results: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Publish data validation completed event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.data.validation.completed",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.data.validation.completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"validation_results": validation_results,
|
||||
"completed_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def publish_models_deleted_event(tenant_id: str, deletion_stats: Dict[str, Any]):
|
||||
"""Publish models deletion event to message queue"""
|
||||
try:
|
||||
await training_publisher.publish_event(
|
||||
exchange="training_events",
|
||||
routing_key="training.tenant.models.deleted",
|
||||
message={
|
||||
"event_type": "tenant_models_deleted",
|
||||
"tenant_id": tenant_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"deletion_stats": deletion_stats
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to publish models deletion event", error=str(e))
|
||||
|
||||
|
||||
# =========================================
|
||||
# UTILITY FUNCTIONS FOR BATCH PUBLISHING
|
||||
# =========================================
|
||||
|
||||
async def publish_batch_status_update(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
updates: List[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Publish multiple status updates as a batch"""
|
||||
batch_event = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.batch.update",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"updates": updates,
|
||||
"batch_size": len(updates)
|
||||
}
|
||||
}
|
||||
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.batch.update",
|
||||
event_data=batch_event
|
||||
)
|
||||
|
||||
# =========================================
|
||||
# HELPER FUNCTIONS FOR TRAINING INTEGRATION
|
||||
# =========================================
|
||||
|
||||
class TrainingStatusPublisher:
|
||||
"""Helper class to manage training status publishing throughout the training process"""
|
||||
|
||||
def __init__(self, job_id: str, tenant_id: str):
|
||||
self.job_id = job_id
|
||||
self.tenant_id = tenant_id
|
||||
self.start_time = datetime.now()
|
||||
self.products_total = 0
|
||||
self.products_completed = 0
|
||||
|
||||
async def job_started(self, config: Dict[str, Any], products_total: int = 0):
|
||||
"""Publish job started with initial configuration"""
|
||||
self.products_total = products_total
|
||||
|
||||
# Clean config data
|
||||
clean_config = safe_json_serialize(config)
|
||||
|
||||
await publish_job_started(self.job_id, self.tenant_id, clean_config)
|
||||
|
||||
async def progress_update(
|
||||
self,
|
||||
progress: int,
|
||||
step: str,
|
||||
current_product: Optional[str] = None,
|
||||
step_details: Optional[str] = None
|
||||
):
|
||||
"""Publish progress update with improved time estimates"""
|
||||
elapsed_minutes = (datetime.now() - self.start_time).total_seconds() / 60
|
||||
|
||||
# Improved estimation based on training phases
|
||||
estimated_remaining = self._calculate_smart_time_remaining(progress, elapsed_minutes, step)
|
||||
|
||||
await publish_job_progress(
|
||||
job_id=self.job_id,
|
||||
tenant_id=self.tenant_id,
|
||||
progress=int(progress),
|
||||
step=step,
|
||||
current_product=current_product,
|
||||
products_completed=int(self.products_completed),
|
||||
products_total=int(self.products_total),
|
||||
estimated_time_remaining_minutes=int(estimated_remaining) if estimated_remaining else None,
|
||||
step_details=step_details
|
||||
)
|
||||
|
||||
def _calculate_smart_time_remaining(self, progress: int, elapsed_minutes: float, step: str) -> Optional[int]:
|
||||
"""Calculate estimated time remaining using phase-based estimation"""
|
||||
|
||||
# Define expected time distribution for each phase
|
||||
phase_durations = {
|
||||
"data_validation": 1.0, # 1 minute
|
||||
"feature_engineering": 2.0, # 2 minutes
|
||||
"model_training": 8.0, # 8 minutes (bulk of time)
|
||||
"model_validation": 1.0 # 1 minute
|
||||
}
|
||||
|
||||
total_expected_minutes = sum(phase_durations.values()) # 12 minutes
|
||||
|
||||
# Calculate progress through phases
|
||||
if progress <= 10: # data_validation phase
|
||||
remaining_in_phase = phase_durations["data_validation"] * (1 - (progress / 10))
|
||||
remaining_after_phase = sum(list(phase_durations.values())[1:])
|
||||
return int(remaining_in_phase + remaining_after_phase)
|
||||
|
||||
elif progress <= 20: # feature_engineering phase
|
||||
remaining_in_phase = phase_durations["feature_engineering"] * (1 - ((progress - 10) / 10))
|
||||
remaining_after_phase = sum(list(phase_durations.values())[2:])
|
||||
return int(remaining_in_phase + remaining_after_phase)
|
||||
|
||||
elif progress <= 90: # model_training phase (biggest chunk)
|
||||
remaining_in_phase = phase_durations["model_training"] * (1 - ((progress - 20) / 70))
|
||||
remaining_after_phase = phase_durations["model_validation"]
|
||||
return int(remaining_in_phase + remaining_after_phase)
|
||||
|
||||
elif progress <= 100: # model_validation phase
|
||||
remaining_in_phase = phase_durations["model_validation"] * (1 - ((progress - 90) / 10))
|
||||
return int(remaining_in_phase)
|
||||
|
||||
return 0
|
||||
|
||||
async def product_completed(self, inventory_product_id: str, model_id: str, metrics: Optional[Dict] = None):
|
||||
"""Mark a product as completed and update progress"""
|
||||
self.products_completed += 1
|
||||
|
||||
# Clean metrics before publishing
|
||||
clean_metrics = safe_json_serialize(metrics) if metrics else None
|
||||
|
||||
await publish_product_training_completed(
|
||||
self.job_id, self.tenant_id, inventory_product_id, model_id, clean_metrics
|
||||
)
|
||||
|
||||
# Update overall progress
|
||||
if self.products_total > 0:
|
||||
progress = int((self.products_completed / self.products_total) * 90) # Save 10% for final steps
|
||||
await self.progress_update(
|
||||
progress=progress,
|
||||
step=f"Completed training for {inventory_product_id}",
|
||||
current_product=None
|
||||
)
|
||||
|
||||
async def job_completed(self, results: Dict[str, Any]):
|
||||
"""Publish job completion with clean data"""
|
||||
clean_results = safe_json_serialize(results)
|
||||
await publish_job_completed(self.job_id, self.tenant_id, clean_results)
|
||||
|
||||
async def job_failed(self, error: str, error_details: Optional[Dict] = None):
|
||||
"""Publish job failure with clean error details"""
|
||||
clean_error_details = safe_json_serialize(error_details) if error_details else None
|
||||
await publish_job_failed(self.job_id, self.tenant_id, error, clean_error_details)
|
||||
|
||||
78
services/training/app/services/progress_tracker.py
Normal file
78
services/training/app/services/progress_tracker.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Training Progress Tracker
|
||||
Manages progress calculation for parallel product training (20-80% range)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import structlog
|
||||
from typing import Optional
|
||||
|
||||
from app.services.training_events import publish_product_training_completed
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ParallelProductProgressTracker:
|
||||
"""
|
||||
Tracks parallel product training progress and emits events.
|
||||
|
||||
For N products training in parallel:
|
||||
- Each product completion contributes 60/N% to overall progress
|
||||
- Progress range: 20% (after data analysis) to 80% (before completion)
|
||||
- Thread-safe for concurrent product trainings
|
||||
"""
|
||||
|
||||
def __init__(self, job_id: str, tenant_id: str, total_products: int):
|
||||
self.job_id = job_id
|
||||
self.tenant_id = tenant_id
|
||||
self.total_products = total_products
|
||||
self.products_completed = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Calculate progress increment per product
|
||||
# 60% of total progress (from 20% to 80%) divided by number of products
|
||||
self.progress_per_product = 60 / total_products if total_products > 0 else 0
|
||||
|
||||
logger.info("ParallelProductProgressTracker initialized",
|
||||
job_id=job_id,
|
||||
total_products=total_products,
|
||||
progress_per_product=f"{self.progress_per_product:.2f}%")
|
||||
|
||||
async def mark_product_completed(self, product_name: str) -> int:
|
||||
"""
|
||||
Mark a product as completed and publish event.
|
||||
Returns the current overall progress percentage.
|
||||
"""
|
||||
async with self._lock:
|
||||
self.products_completed += 1
|
||||
current_progress = self.products_completed
|
||||
|
||||
# Publish product completion event
|
||||
await publish_product_training_completed(
|
||||
job_id=self.job_id,
|
||||
tenant_id=self.tenant_id,
|
||||
product_name=product_name,
|
||||
products_completed=current_progress,
|
||||
total_products=self.total_products
|
||||
)
|
||||
|
||||
# Calculate overall progress (20% base + progress from completed products)
|
||||
# This calculation is done on the frontend/consumer side based on the event data
|
||||
overall_progress = 20 + int((current_progress / self.total_products) * 60)
|
||||
|
||||
logger.info("Product training completed",
|
||||
job_id=self.job_id,
|
||||
product_name=product_name,
|
||||
products_completed=current_progress,
|
||||
total_products=self.total_products,
|
||||
overall_progress=overall_progress)
|
||||
|
||||
return overall_progress
|
||||
|
||||
def get_progress(self) -> dict:
|
||||
"""Get current progress summary"""
|
||||
return {
|
||||
"products_completed": self.products_completed,
|
||||
"total_products": self.total_products,
|
||||
"progress_percentage": 20 + int((self.products_completed / self.total_products) * 60)
|
||||
}
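# Worked example (illustrative values): with total_products=5, each completion adds
# 60 / 5 = 12 percentage points on top of the 20% base, so after the second product
# finishes the overall progress is 20 + int((2 / 5) * 60) = 44%.
#
#   tracker = ParallelProductProgressTracker(job_id, tenant_id, total_products=5)
#   overall = await tracker.mark_product_completed("croissant")   # product name assumed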
|
||||
238
services/training/app/services/training_events.py
Normal file
238
services/training/app/services/training_events.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
Training Progress Events Publisher
|
||||
Simple, clean event publisher for the 4 main training steps
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from shared.messaging.rabbitmq import RabbitMQClient
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Single global publisher instance
|
||||
training_publisher = RabbitMQClient(settings.RABBITMQ_URL, "training-service")
|
||||
|
||||
|
||||
async def setup_messaging():
|
||||
"""Initialize messaging"""
|
||||
success = await training_publisher.connect()
|
||||
if success:
|
||||
logger.info("Training messaging initialized")
|
||||
else:
|
||||
logger.warning("Training messaging failed to initialize")
|
||||
return success
|
||||
|
||||
|
||||
async def cleanup_messaging():
|
||||
"""Cleanup messaging"""
|
||||
await training_publisher.disconnect()
|
||||
logger.info("Training messaging cleaned up")
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 4 MAIN TRAINING PROGRESS EVENTS
|
||||
# ==========================================
|
||||
|
||||
async def publish_training_started(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
total_products: int
|
||||
) -> bool:
|
||||
"""
|
||||
Event 1: Training Started (0% progress)
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.started",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"progress": 0,
|
||||
"current_step": "Training Started",
|
||||
"step_details": f"Starting training for {total_products} products",
|
||||
"total_products": total_products
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.started",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published training started event",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
total_products=total_products)
|
||||
else:
|
||||
logger.error("Failed to publish training started event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def publish_data_analysis(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
analysis_details: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Event 2: Data Analysis (20% progress)
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.progress",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"progress": 20,
|
||||
"current_step": "Data Analysis",
|
||||
"step_details": analysis_details or "Analyzing sales, weather, and traffic data"
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.progress",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published data analysis event",
|
||||
job_id=job_id,
|
||||
progress=20)
|
||||
else:
|
||||
logger.error("Failed to publish data analysis event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def publish_product_training_completed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
products_completed: int,
|
||||
total_products: int
|
||||
) -> bool:
|
||||
"""
|
||||
Event 3: Product Training Completed (contributes to 20-80% progress)
|
||||
|
||||
This event is published each time a product training completes.
|
||||
The frontend/consumer will calculate the progress as:
|
||||
progress = 20 + (products_completed / total_products) * 60
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.product.completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"products_completed": products_completed,
|
||||
"total_products": total_products,
|
||||
"current_step": "Model Training",
|
||||
"step_details": f"Completed training for {product_name} ({products_completed}/{total_products})"
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.product.completed",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published product training completed event",
|
||||
job_id=job_id,
|
||||
product_name=product_name,
|
||||
products_completed=products_completed,
|
||||
total_products=total_products)
|
||||
else:
|
||||
logger.error("Failed to publish product training completed event",
|
||||
job_id=job_id)
|
||||
|
||||
return success
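# Consumer-side sketch: the frontend or gateway that receives the payload built above
# (its "data" dict) can recover overall progress with the formula from the docstring.
# The helper name is assumed for the example.
#
#   def progress_from_event(data: dict) -> int:
#       return 20 + int((data["products_completed"] / data["total_products"]) * 60)
#
#   progress_from_event({"products_completed": 3, "total_products": 8})   # -> 42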
|
||||
|
||||
|
||||
async def publish_training_completed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
successful_trainings: int,
|
||||
failed_trainings: int,
|
||||
total_duration_seconds: float
|
||||
) -> bool:
|
||||
"""
|
||||
Event 4: Training Completed (100% progress)
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"progress": 100,
|
||||
"current_step": "Training Completed",
|
||||
"step_details": f"Training completed: {successful_trainings} successful, {failed_trainings} failed",
|
||||
"successful_trainings": successful_trainings,
|
||||
"failed_trainings": failed_trainings,
|
||||
"total_duration_seconds": total_duration_seconds
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.completed",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published training completed event",
|
||||
job_id=job_id,
|
||||
successful_trainings=successful_trainings,
|
||||
failed_trainings=failed_trainings)
|
||||
else:
|
||||
logger.error("Failed to publish training completed event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def publish_training_failed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
error_message: str
|
||||
) -> bool:
|
||||
"""
|
||||
Event: Training Failed
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.failed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"current_step": "Training Failed",
|
||||
"error_message": error_message
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.failed",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published training failed event",
|
||||
job_id=job_id,
|
||||
error=error_message)
|
||||
else:
|
||||
logger.error("Failed to publish training failed event", job_id=job_id)
|
||||
|
||||
return success
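# Illustrative publishing sequence for one job (a sketch; the calling code and its
# variables are assumed, not part of this module):
#
#   await publish_training_started(job_id, tenant_id, total_products=4)       # 0%
#   await publish_data_analysis(job_id, tenant_id)                            # 20%
#   for done, name in enumerate(["croissant", "baguette", "rye", "brioche"], 1):
#       # ... train one product ...
#       await publish_product_training_completed(job_id, tenant_id, name, done, 4)
#   await publish_training_completed(job_id, tenant_id, 4, 0, 540.0)          # 100%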
|
||||
@@ -16,13 +16,7 @@ import pandas as pd
|
||||
from app.services.data_client import DataClient
|
||||
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
|
||||
|
||||
from app.services.messaging import (
|
||||
publish_job_progress,
|
||||
publish_data_validation_started,
|
||||
publish_data_validation_completed,
|
||||
publish_job_step_completed,
|
||||
publish_job_failed
|
||||
)
|
||||
from app.services.training_events import publish_training_failed
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@@ -76,7 +70,6 @@ class TrainingDataOrchestrator:
|
||||
# Step 1: Fetch and validate sales data (unified approach)
|
||||
sales_data = await self.data_client.fetch_sales_data(tenant_id, fetch_all=True)
|
||||
|
||||
# Pre-flight validation moved here to eliminate duplicate fetching
|
||||
if not sales_data or len(sales_data) == 0:
|
||||
error_msg = f"No sales data available for tenant {tenant_id}. Please import sales data before starting training."
|
||||
logger.error("Training aborted - no sales data", tenant_id=tenant_id, job_id=job_id)
|
||||
@@ -172,7 +165,8 @@ class TrainingDataOrchestrator:
|
||||
return training_dataset
|
||||
|
||||
except Exception as e:
|
||||
publish_job_failed(job_id, tenant_id, str(e))
|
||||
if job_id and tenant_id:
|
||||
await publish_training_failed(job_id, tenant_id, str(e))
|
||||
logger.error(f"Training data preparation failed: {str(e)}")
|
||||
raise ValueError(f"Failed to prepare training data: {str(e)}")
|
||||
|
||||
@@ -472,30 +466,18 @@ class TrainingDataOrchestrator:
|
||||
logger.warning(f"Enhanced traffic data collection failed: {e}")
|
||||
return []
|
||||
|
||||
# Keep original method for backwards compatibility
|
||||
async def _collect_traffic_data_with_timeout(
|
||||
self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
tenant_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Legacy traffic data collection method - redirects to enhanced version"""
|
||||
return await self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)
|
||||
|
||||
def _log_enhanced_traffic_data_storage(self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
def _log_enhanced_traffic_data_storage(self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
record_count: int,
|
||||
traffic_data: List[Dict[str, Any]]):
|
||||
"""Enhanced logging for traffic data storage with detailed metadata"""
|
||||
# Analyze the stored data for additional insights
|
||||
cities_detected = set()
|
||||
has_pedestrian_data = 0
|
||||
data_sources = set()
|
||||
districts_covered = set()
|
||||
|
||||
|
||||
for record in traffic_data:
|
||||
if 'city' in record and record['city']:
|
||||
cities_detected.add(record['city'])
|
||||
@@ -505,7 +487,7 @@ class TrainingDataOrchestrator:
|
||||
data_sources.add(record['source'])
|
||||
if 'district' in record and record['district']:
|
||||
districts_covered.add(record['district'])
|
||||
|
||||
|
||||
logger.info(
|
||||
"Enhanced traffic data stored for re-training",
|
||||
location=f"{lat:.4f},{lon:.4f}",
|
||||
@@ -516,20 +498,9 @@ class TrainingDataOrchestrator:
|
||||
data_sources=list(data_sources),
|
||||
districts_covered=list(districts_covered),
|
||||
storage_timestamp=datetime.now().isoformat(),
|
||||
purpose="enhanced_model_training_and_retraining",
|
||||
architecture_version="2.0_abstracted"
|
||||
purpose="model_training_and_retraining"
|
||||
)
|
||||
|
||||
def _log_traffic_data_storage(self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
record_count: int):
|
||||
"""Legacy logging method - redirects to enhanced version"""
|
||||
# Create minimal traffic data structure for enhanced logging
|
||||
minimal_traffic_data = [{"city": "madrid", "source": "legacy"}] * min(record_count, 1)
|
||||
self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, record_count, minimal_traffic_data)
|
||||
|
||||
def _validate_weather_data(self, weather_data: List[Dict[str, Any]]) -> bool:
|
||||
"""Validate weather data quality"""
|
||||
if not weather_data:
|
||||
|
||||
@@ -13,10 +13,9 @@ import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from app.ml.trainer import BakeryMLTrainer
|
||||
from app.ml.trainer import EnhancedBakeryMLTrainer
|
||||
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
|
||||
from app.services.training_orchestrator import TrainingDataOrchestrator
|
||||
from app.services.messaging import TrainingStatusPublisher
|
||||
|
||||
# Import repositories
|
||||
from app.repositories import (
|
||||
@@ -119,7 +118,7 @@ class EnhancedTrainingService:
|
||||
self.artifact_repo = ArtifactRepository(session)
|
||||
|
||||
# Initialize training components
|
||||
self.trainer = BakeryMLTrainer(database_manager=self.database_manager)
|
||||
self.trainer = EnhancedBakeryMLTrainer(database_manager=self.database_manager)
|
||||
self.date_alignment_service = DateAlignmentService()
|
||||
self.orchestrator = TrainingDataOrchestrator(
|
||||
date_alignment_service=self.date_alignment_service
|
||||
@@ -164,10 +163,8 @@ class EnhancedTrainingService:
|
||||
# Get session and initialize repositories
|
||||
async with self.database_manager.get_session() as session:
|
||||
await self._init_repositories(session)
|
||||
|
||||
|
||||
try:
|
||||
# Pre-flight check moved to orchestrator to eliminate duplicate sales data fetching
|
||||
|
||||
# Check if training log already exists, create if not
|
||||
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
||||
|
||||
@@ -187,21 +184,12 @@ class EnhancedTrainingService:
|
||||
}
|
||||
training_log = await self.training_log_repo.create_training_log(log_data)
|
||||
|
||||
# Initialize status publisher
|
||||
status_publisher = TrainingStatusPublisher(job_id, tenant_id)
|
||||
|
||||
await status_publisher.progress_update(
|
||||
progress=10,
|
||||
step="data_validation",
|
||||
step_details="Data"
|
||||
)
|
||||
|
||||
# Step 1: Prepare training dataset (includes sales data validation)
|
||||
logger.info("Step 1: Preparing and aligning training data (with validation)")
|
||||
await self.training_log_repo.update_log_progress(
|
||||
job_id, 10, "data_validation", "running"
|
||||
)
|
||||
|
||||
|
||||
# Orchestrator now handles sales data validation to eliminate duplicate fetching
|
||||
training_dataset = await self.orchestrator.prepare_training_data(
|
||||
tenant_id=tenant_id,
|
||||
@@ -210,11 +198,11 @@ class EnhancedTrainingService:
|
||||
requested_end=requested_end,
|
||||
job_id=job_id
|
||||
)
|
||||
|
||||
|
||||
# Log the results from orchestrator's unified sales data fetch
|
||||
logger.info(f"Sales data validation completed: {len(training_dataset.sales_data)} records",
|
||||
logger.info(f"Sales data validation completed: {len(training_dataset.sales_data)} records",
|
||||
tenant_id=tenant_id, job_id=job_id)
|
||||
|
||||
|
||||
await self.training_log_repo.update_log_progress(
|
||||
job_id, 30, "data_preparation_complete", "running"
|
||||
)
|
||||
@@ -224,15 +212,15 @@ class EnhancedTrainingService:
|
||||
await self.training_log_repo.update_log_progress(
|
||||
job_id, 40, "ml_training", "running"
|
||||
)
|
||||
|
||||
|
||||
training_results = await self.trainer.train_tenant_models(
|
||||
tenant_id=tenant_id,
|
||||
training_dataset=training_dataset,
|
||||
job_id=job_id
|
||||
)
|
||||
|
||||
|
||||
await self.training_log_repo.update_log_progress(
|
||||
job_id, 80, "training_complete", "running"
|
||||
job_id, 85, "training_complete", "running"
|
||||
)
|
||||
|
||||
# Step 3: Store model records using repository
|
||||
@@ -240,19 +228,21 @@ class EnhancedTrainingService:
|
||||
logger.debug("Training results structure",
|
||||
keys=list(training_results.keys()) if isinstance(training_results, dict) else "not_dict",
|
||||
training_results_type=type(training_results).__name__)
|
||||
|
||||
stored_models = await self._store_trained_models(
|
||||
tenant_id, job_id, training_results
|
||||
)
|
||||
|
||||
|
||||
await self.training_log_repo.update_log_progress(
|
||||
job_id, 90, "storing_models", "running"
|
||||
job_id, 92, "storing_models", "running"
|
||||
)
|
||||
|
||||
|
||||
# Step 4: Create performance metrics
|
||||
|
||||
await self._create_performance_metrics(
|
||||
tenant_id, stored_models, training_results
|
||||
)
|
||||
|
||||
|
||||
# Step 5: Complete training log
|
||||
final_result = {
|
||||
"job_id": job_id,
|
||||
@@ -308,11 +298,11 @@ class EnhancedTrainingService:
|
||||
await self.training_log_repo.complete_training_log(
|
||||
job_id, results=json_safe_result
|
||||
)
|
||||
|
||||
|
||||
logger.info("Enhanced training job completed successfully",
|
||||
job_id=job_id,
|
||||
models_created=len(stored_models))
|
||||
|
||||
|
||||
return self._create_detailed_training_response(final_result)
|
||||
|
||||
except Exception as e:
|
||||
@@ -460,7 +450,7 @@ class EnhancedTrainingService:
|
||||
async def get_training_status(self, job_id: str) -> Dict[str, Any]:
|
||||
"""Get training job status using repository"""
|
||||
try:
|
||||
async with self.database_manager.get_session()() as session:
|
||||
async with self.database_manager.get_session() as session:
|
||||
await self._init_repositories(session)
|
||||
|
||||
log = await self.training_log_repo.get_log_by_job_id(job_id)
|
||||
@@ -761,8 +751,4 @@ class EnhancedTrainingService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create detailed response", error=str(e))
|
||||
return final_result
|
||||
|
||||
|
||||
# Legacy compatibility alias
|
||||
TrainingService = EnhancedTrainingService
|
||||
return final_result
|
||||
92
services/training/app/utils/__init__.py
Normal file
92
services/training/app/utils/__init__.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
Training Service Utilities
|
||||
"""
|
||||
|
||||
from .timezone_utils import (
|
||||
ensure_timezone_aware,
|
||||
ensure_timezone_naive,
|
||||
normalize_datetime_to_utc,
|
||||
normalize_dataframe_datetime_column,
|
||||
prepare_prophet_datetime,
|
||||
safe_datetime_comparison,
|
||||
get_current_utc,
|
||||
convert_timestamp_to_datetime
|
||||
)
|
||||
|
||||
from .circuit_breaker import (
|
||||
CircuitBreaker,
|
||||
CircuitBreakerError,
|
||||
CircuitState,
|
||||
circuit_breaker_registry
|
||||
)
|
||||
|
||||
from .file_utils import (
|
||||
calculate_file_checksum,
|
||||
verify_file_checksum,
|
||||
get_file_size,
|
||||
ensure_directory_exists,
|
||||
safe_file_delete,
|
||||
get_file_metadata,
|
||||
ChecksummedFile
|
||||
)
|
||||
|
||||
from .distributed_lock import (
|
||||
DatabaseLock,
|
||||
SimpleDatabaseLock,
|
||||
LockAcquisitionError,
|
||||
get_training_lock
|
||||
)
|
||||
|
||||
from .retry import (
|
||||
RetryStrategy,
|
||||
RetryError,
|
||||
retry_async,
|
||||
with_retry,
|
||||
retry_with_timeout,
|
||||
AdaptiveRetryStrategy,
|
||||
TimeoutRetryStrategy,
|
||||
HTTP_RETRY_STRATEGY,
|
||||
DATABASE_RETRY_STRATEGY,
|
||||
EXTERNAL_SERVICE_RETRY_STRATEGY
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Timezone utilities
|
||||
'ensure_timezone_aware',
|
||||
'ensure_timezone_naive',
|
||||
'normalize_datetime_to_utc',
|
||||
'normalize_dataframe_datetime_column',
|
||||
'prepare_prophet_datetime',
|
||||
'safe_datetime_comparison',
|
||||
'get_current_utc',
|
||||
'convert_timestamp_to_datetime',
|
||||
# Circuit breaker
|
||||
'CircuitBreaker',
|
||||
'CircuitBreakerError',
|
||||
'CircuitState',
|
||||
'circuit_breaker_registry',
|
||||
# File utilities
|
||||
'calculate_file_checksum',
|
||||
'verify_file_checksum',
|
||||
'get_file_size',
|
||||
'ensure_directory_exists',
|
||||
'safe_file_delete',
|
||||
'get_file_metadata',
|
||||
'ChecksummedFile',
|
||||
# Distributed locking
|
||||
'DatabaseLock',
|
||||
'SimpleDatabaseLock',
|
||||
'LockAcquisitionError',
|
||||
'get_training_lock',
|
||||
# Retry mechanisms
|
||||
'RetryStrategy',
|
||||
'RetryError',
|
||||
'retry_async',
|
||||
'with_retry',
|
||||
'retry_with_timeout',
|
||||
'AdaptiveRetryStrategy',
|
||||
'TimeoutRetryStrategy',
|
||||
'HTTP_RETRY_STRATEGY',
|
||||
'DATABASE_RETRY_STRATEGY',
|
||||
'EXTERNAL_SERVICE_RETRY_STRATEGY'
|
||||
]
|
||||
198
services/training/app/utils/circuit_breaker.py
Normal file
198
services/training/app/utils/circuit_breaker.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Circuit Breaker Pattern Implementation
|
||||
Protects against cascading failures from external service calls
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Callable, Any, Optional
|
||||
import logging
|
||||
from functools import wraps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker states"""
|
||||
CLOSED = "closed" # Normal operation
|
||||
OPEN = "open" # Circuit is open, rejecting requests
|
||||
HALF_OPEN = "half_open" # Testing if service recovered
|
||||
|
||||
|
||||
class CircuitBreakerError(Exception):
|
||||
"""Raised when circuit breaker is open"""
|
||||
pass
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""
|
||||
Circuit breaker to prevent cascading failures.
|
||||
|
||||
States:
|
||||
- CLOSED: Normal operation, requests pass through
|
||||
- OPEN: Too many failures, rejecting all requests
|
||||
- HALF_OPEN: Testing if service recovered, allowing limited requests
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
failure_threshold: int = 5,
|
||||
recovery_timeout: float = 60.0,
|
||||
expected_exception: type = Exception,
|
||||
name: str = "circuit_breaker"
|
||||
):
|
||||
"""
|
||||
Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
failure_threshold: Number of failures before opening circuit
|
||||
recovery_timeout: Seconds to wait before attempting recovery
|
||||
expected_exception: Exception type to catch (others will pass through)
|
||||
name: Name for logging purposes
|
||||
"""
|
||||
self.failure_threshold = failure_threshold
|
||||
self.recovery_timeout = recovery_timeout
|
||||
self.expected_exception = expected_exception
|
||||
self.name = name
|
||||
|
||||
self.failure_count = 0
|
||||
self.last_failure_time: Optional[float] = None
|
||||
self.state = CircuitState.CLOSED
|
||||
|
||||
def _record_success(self):
|
||||
"""Record successful call"""
|
||||
self.failure_count = 0
|
||||
self.last_failure_time = None
|
||||
if self.state == CircuitState.HALF_OPEN:
|
||||
logger.info(f"Circuit breaker '{self.name}' recovered, closing circuit")
|
||||
self.state = CircuitState.CLOSED
|
||||
|
||||
def _record_failure(self):
|
||||
"""Record failed call"""
|
||||
self.failure_count += 1
|
||||
self.last_failure_time = time.time()
|
||||
|
||||
if self.failure_count >= self.failure_threshold:
|
||||
if self.state != CircuitState.OPEN:
|
||||
logger.warning(
|
||||
f"Circuit breaker '{self.name}' opened after {self.failure_count} failures"
|
||||
)
|
||||
self.state = CircuitState.OPEN
|
||||
|
||||
def _should_attempt_reset(self) -> bool:
|
||||
"""Check if we should attempt to reset circuit"""
|
||||
return (
|
||||
self.state == CircuitState.OPEN
|
||||
and self.last_failure_time is not None
|
||||
and time.time() - self.last_failure_time >= self.recovery_timeout
|
||||
)
|
||||
|
||||
async def call(self, func: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Execute function with circuit breaker protection.
|
||||
|
||||
Args:
|
||||
func: Async function to execute
|
||||
*args: Positional arguments for func
|
||||
**kwargs: Keyword arguments for func
|
||||
|
||||
Returns:
|
||||
Result from func
|
||||
|
||||
Raises:
|
||||
CircuitBreakerError: If circuit is open
|
||||
Exception: Original exception if not expected_exception type
|
||||
"""
|
||||
# Check if circuit is open
|
||||
if self.state == CircuitState.OPEN:
|
||||
if self._should_attempt_reset():
|
||||
logger.info(f"Circuit breaker '{self.name}' attempting recovery (half-open)")
|
||||
self.state = CircuitState.HALF_OPEN
|
||||
else:
|
||||
raise CircuitBreakerError(
|
||||
f"Circuit breaker '{self.name}' is open. "
|
||||
f"Service unavailable for {self.recovery_timeout}s after {self.failure_count} failures."
|
||||
)
|
||||
|
||||
try:
|
||||
# Execute the function
|
||||
result = await func(*args, **kwargs)
|
||||
self._record_success()
|
||||
return result
|
||||
|
||||
except self.expected_exception as e:
|
||||
self._record_failure()
|
||||
logger.error(
    f"Circuit breaker '{self.name}' caught failure: {e} "
    f"(failure_count={self.failure_count}, state={self.state.value})"
)
|
||||
raise
|
||||
|
||||
def __call__(self, func: Callable) -> Callable:
|
||||
"""Decorator interface for circuit breaker"""
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await self.call(func, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
def get_state(self) -> dict:
|
||||
"""Get current circuit breaker state for monitoring"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"state": self.state.value,
|
||||
"failure_count": self.failure_count,
|
||||
"failure_threshold": self.failure_threshold,
|
||||
"last_failure_time": self.last_failure_time,
|
||||
"recovery_timeout": self.recovery_timeout
|
||||
}
|
||||
|
||||
|
||||
class CircuitBreakerRegistry:
|
||||
"""Registry to manage multiple circuit breakers"""
|
||||
|
||||
def __init__(self):
|
||||
self._breakers: dict[str, CircuitBreaker] = {}
|
||||
|
||||
def get_or_create(
|
||||
self,
|
||||
name: str,
|
||||
failure_threshold: int = 5,
|
||||
recovery_timeout: float = 60.0,
|
||||
expected_exception: type = Exception
|
||||
) -> CircuitBreaker:
|
||||
"""Get existing circuit breaker or create new one"""
|
||||
if name not in self._breakers:
|
||||
self._breakers[name] = CircuitBreaker(
|
||||
failure_threshold=failure_threshold,
|
||||
recovery_timeout=recovery_timeout,
|
||||
expected_exception=expected_exception,
|
||||
name=name
|
||||
)
|
||||
return self._breakers[name]
|
||||
|
||||
def get(self, name: str) -> Optional[CircuitBreaker]:
|
||||
"""Get circuit breaker by name"""
|
||||
return self._breakers.get(name)
|
||||
|
||||
def get_all_states(self) -> dict:
|
||||
"""Get states of all circuit breakers"""
|
||||
return {
|
||||
name: breaker.get_state()
|
||||
for name, breaker in self._breakers.items()
|
||||
}
|
||||
|
||||
def reset(self, name: str):
|
||||
"""Manually reset a circuit breaker"""
|
||||
if name in self._breakers:
|
||||
breaker = self._breakers[name]
|
||||
breaker.failure_count = 0
|
||||
breaker.last_failure_time = None
|
||||
breaker.state = CircuitState.CLOSED
|
||||
logger.info(f"Circuit breaker '{name}' manually reset")
|
||||
|
||||
|
||||
# Global registry instance
|
||||
circuit_breaker_registry = CircuitBreakerRegistry()
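# Illustrative usage (a sketch): protect calls to an external service. The client
# object and its fetch method are assumed for the example; only the breaker API
# defined above is real.
#
#   breaker = circuit_breaker_registry.get_or_create(
#       "data-service", failure_threshold=3, recovery_timeout=30.0
#   )
#
#   async def fetch_sales(client, tenant_id):
#       try:
#           return await breaker.call(client.fetch_sales_data, tenant_id)
#       except CircuitBreakerError:
#           return []   # degrade gracefully while the circuit is open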
|
||||
233
services/training/app/utils/distributed_lock.py
Normal file
233
services/training/app/utils/distributed_lock.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Distributed Locking Mechanisms
|
||||
Prevents concurrent training jobs for the same product
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
import time
|
||||
from typing import Optional
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LockAcquisitionError(Exception):
|
||||
"""Raised when lock cannot be acquired"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseLock:
|
||||
"""
|
||||
Database-based distributed lock using PostgreSQL advisory locks.
|
||||
Works across multiple service instances.
|
||||
"""
|
||||
|
||||
def __init__(self, lock_name: str, timeout: float = 30.0):
|
||||
"""
|
||||
Initialize database lock.
|
||||
|
||||
Args:
|
||||
lock_name: Unique identifier for the lock
|
||||
timeout: Maximum seconds to wait for lock acquisition
|
||||
"""
|
||||
self.lock_name = lock_name
|
||||
self.timeout = timeout
|
||||
self.lock_id = self._hash_lock_name(lock_name)
|
||||
|
||||
def _hash_lock_name(self, name: str) -> int:
    """Convert lock name to a stable integer ID for PostgreSQL advisory lock"""
    # Use a deterministic digest: Python's built-in hash() is salted per process,
    # so it would produce different lock IDs on different service instances.
    digest = hashlib.sha256(name.encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big") % (2**31)
|
||||

    @asynccontextmanager
    async def acquire(self, session: AsyncSession):
        """
        Acquire distributed lock as async context manager.

        Args:
            session: Database session for lock operations

        Raises:
            LockAcquisitionError: If lock cannot be acquired within timeout
        """
        acquired = False
        start_time = time.time()

        try:
            # Try to acquire lock with timeout
            while time.time() - start_time < self.timeout:
                # Try non-blocking lock acquisition
                result = await session.execute(
                    text("SELECT pg_try_advisory_lock(:lock_id)"),
                    {"lock_id": self.lock_id}
                )
                acquired = result.scalar()

                if acquired:
                    logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
                    break

                # Wait a bit before retrying
                await asyncio.sleep(0.1)

            if not acquired:
                raise LockAcquisitionError(
                    f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
                )

            yield

        finally:
            if acquired:
                # Release lock
                await session.execute(
                    text("SELECT pg_advisory_unlock(:lock_id)"),
                    {"lock_id": self.lock_id}
                )
                logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")


class SimpleDatabaseLock:
    """
    Simple table-based distributed lock.
    Alternative to advisory locks, uses a dedicated locks table.
    """

    def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
        """
        Initialize simple database lock.

        Args:
            lock_name: Unique identifier for the lock
            timeout: Maximum seconds to wait for lock acquisition
            ttl: Time-to-live for stale lock cleanup (seconds)
        """
        self.lock_name = lock_name
        self.timeout = timeout
        self.ttl = ttl

    async def _ensure_lock_table(self, session: AsyncSession):
        """Ensure locks table exists"""
        create_table_sql = """
            CREATE TABLE IF NOT EXISTS distributed_locks (
                lock_name VARCHAR(255) PRIMARY KEY,
                acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
                acquired_by VARCHAR(255),
                expires_at TIMESTAMP WITH TIME ZONE NOT NULL
            )
        """
        await session.execute(text(create_table_sql))
        await session.commit()

    async def _cleanup_stale_locks(self, session: AsyncSession):
        """Remove expired locks"""
        cleanup_sql = """
            DELETE FROM distributed_locks
            WHERE expires_at < :now
        """
        await session.execute(
            text(cleanup_sql),
            {"now": datetime.now(timezone.utc)}
        )
        await session.commit()

    @asynccontextmanager
    async def acquire(self, session: AsyncSession, owner: str = "training-service"):
        """
        Acquire simple database lock.

        Args:
            session: Database session
            owner: Identifier for lock owner

        Raises:
            LockAcquisitionError: If lock cannot be acquired
        """
        await self._ensure_lock_table(session)
        await self._cleanup_stale_locks(session)

        acquired = False
        start_time = time.time()

        try:
            # Try to acquire lock
            while time.time() - start_time < self.timeout:
                now = datetime.now(timezone.utc)
                expires_at = now + timedelta(seconds=self.ttl)

                try:
                    # Try to insert lock record
                    insert_sql = """
                        INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
                        VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
                        ON CONFLICT (lock_name) DO NOTHING
                        RETURNING lock_name
                    """

                    result = await session.execute(
                        text(insert_sql),
                        {
                            "lock_name": self.lock_name,
                            "acquired_at": now,
                            "acquired_by": owner,
                            "expires_at": expires_at
                        }
                    )
                    await session.commit()

                    if result.rowcount > 0:
                        acquired = True
                        logger.info(f"Acquired simple lock: {self.lock_name}")
                        break

                except Exception as e:
                    logger.debug(f"Lock acquisition attempt failed: {e}")
                    await session.rollback()

                # Wait before retrying
                await asyncio.sleep(0.5)

            if not acquired:
                raise LockAcquisitionError(
                    f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
                )

            yield

        finally:
            if acquired:
                # Release lock
                delete_sql = """
                    DELETE FROM distributed_locks
                    WHERE lock_name = :lock_name
                """
                await session.execute(
                    text(delete_sql),
                    {"lock_name": self.lock_name}
                )
                await session.commit()
                logger.info(f"Released simple lock: {self.lock_name}")


def get_training_lock(tenant_id: str, product_id: str, use_advisory: bool = True) -> DatabaseLock:
    """
    Get distributed lock for training a specific product.

    Args:
        tenant_id: Tenant identifier
        product_id: Product identifier
        use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)

    Returns:
        Lock instance
    """
    lock_name = f"training:{tenant_id}:{product_id}"

    if use_advisory:
        return DatabaseLock(lock_name, timeout=60.0)
    else:
        return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)
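

# Illustrative usage sketch (never called from this module): how get_training_lock()
# is meant to wrap a training run. It assumes the caller already holds an
# AsyncSession; the training call itself is only a placeholder comment.
async def _example_train_with_lock(session: AsyncSession, tenant_id: str, product_id: str) -> None:
    lock = get_training_lock(tenant_id, product_id, use_advisory=True)
    try:
        # Only one worker across all service instances enters this block per product
        async with lock.acquire(session):
            logger.info(f"Training {product_id} for tenant {tenant_id} under lock")
            # ... run the actual training pipeline here ...
    except LockAcquisitionError:
        # Another worker is already training this product; skip or reschedule
        logger.warning(f"Training already in progress for {tenant_id}:{product_id}")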
216
services/training/app/utils/file_utils.py
Normal file
216
services/training/app/utils/file_utils.py
Normal file
@@ -0,0 +1,216 @@
"""
File Utility Functions
Utilities for secure file operations including checksum verification
"""

import hashlib
import os
from pathlib import Path
from typing import Optional
import logging

logger = logging.getLogger(__name__)


def calculate_file_checksum(file_path: str, algorithm: str = "sha256") -> str:
    """
    Calculate checksum of a file.

    Args:
        file_path: Path to file
        algorithm: Hash algorithm (sha256, md5, etc.)

    Returns:
        Hexadecimal checksum string

    Raises:
        FileNotFoundError: If file doesn't exist
        ValueError: If algorithm not supported
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    try:
        hash_func = hashlib.new(algorithm)
    except ValueError:
        raise ValueError(f"Unsupported hash algorithm: {algorithm}")

    # Read file in chunks to handle large files efficiently
    with open(file_path, 'rb') as f:
        while chunk := f.read(8192):
            hash_func.update(chunk)

    return hash_func.hexdigest()


def verify_file_checksum(file_path: str, expected_checksum: str, algorithm: str = "sha256") -> bool:
    """
    Verify file matches expected checksum.

    Args:
        file_path: Path to file
        expected_checksum: Expected checksum value
        algorithm: Hash algorithm used

    Returns:
        True if checksum matches, False otherwise
    """
    try:
        actual_checksum = calculate_file_checksum(file_path, algorithm)
        matches = actual_checksum == expected_checksum

        if matches:
            logger.debug(f"Checksum verified for {file_path}")
        else:
            logger.warning(
                f"Checksum mismatch for {file_path}: "
                f"expected {expected_checksum}, actual {actual_checksum}"
            )

        return matches

    except Exception as e:
        logger.error(f"Error verifying checksum for {file_path}: {e}")
        return False


def get_file_size(file_path: str) -> int:
    """
    Get file size in bytes.

    Args:
        file_path: Path to file

    Returns:
        File size in bytes

    Raises:
        FileNotFoundError: If file doesn't exist
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    return os.path.getsize(file_path)


def ensure_directory_exists(directory: str) -> Path:
    """
    Ensure directory exists, create if necessary.

    Args:
        directory: Directory path

    Returns:
        Path object for directory
    """
    path = Path(directory)
    path.mkdir(parents=True, exist_ok=True)
    return path


def safe_file_delete(file_path: str) -> bool:
    """
    Safely delete a file, logging any errors.

    Args:
        file_path: Path to file

    Returns:
        True if deleted successfully, False otherwise
    """
    try:
        if os.path.exists(file_path):
            os.remove(file_path)
            logger.info(f"Deleted file: {file_path}")
            return True
        else:
            logger.warning(f"File not found for deletion: {file_path}")
            return False
    except Exception as e:
        logger.error(f"Error deleting file {file_path}: {e}")
        return False


def get_file_metadata(file_path: str) -> dict:
    """
    Get comprehensive file metadata.

    Args:
        file_path: Path to file

    Returns:
        Dictionary with file metadata

    Raises:
        FileNotFoundError: If file doesn't exist
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    stat = os.stat(file_path)

    return {
        "path": file_path,
        "size_bytes": stat.st_size,
        "created_at": stat.st_ctime,
        "modified_at": stat.st_mtime,
        "accessed_at": stat.st_atime,
        "is_file": os.path.isfile(file_path),
        "is_dir": os.path.isdir(file_path),
        "exists": True
    }


class ChecksummedFile:
    """
    Helper for working with checksummed files.
    Calculates and stores a sidecar checksum for a file and verifies it later.
    """

    def __init__(self, file_path: str, checksum_path: Optional[str] = None, algorithm: str = "sha256"):
        """
        Initialize checksummed file handler.

        Args:
            file_path: Path to the file
            checksum_path: Path to store checksum (default: file_path + '.checksum')
            algorithm: Hash algorithm to use
        """
        self.file_path = file_path
        self.checksum_path = checksum_path or f"{file_path}.checksum"
        self.algorithm = algorithm
        self.checksum: Optional[str] = None

    def calculate_and_save_checksum(self) -> str:
        """Calculate checksum and save to file"""
        self.checksum = calculate_file_checksum(self.file_path, self.algorithm)

        with open(self.checksum_path, 'w') as f:
            f.write(f"{self.checksum} {os.path.basename(self.file_path)}\n")

        logger.info(f"Saved checksum for {self.file_path}: {self.checksum}")
        return self.checksum

    def load_and_verify_checksum(self) -> bool:
        """Load expected checksum and verify file"""
        try:
            with open(self.checksum_path, 'r') as f:
                expected_checksum = f.read().strip().split()[0]

            return verify_file_checksum(self.file_path, expected_checksum, self.algorithm)

        except FileNotFoundError:
            logger.warning(f"Checksum file not found: {self.checksum_path}")
            return False
        except Exception as e:
            logger.error(f"Error loading checksum: {e}")
            return False

    def get_stored_checksum(self) -> Optional[str]:
        """Get checksum from stored file"""
        try:
            with open(self.checksum_path, 'r') as f:
                return f.read().strip().split()[0]
        except FileNotFoundError:
            return None
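

# Illustrative usage sketch (never executed here): writing a sidecar checksum for a
# model artifact and verifying it before loading. The path is an example value and
# the artifact is assumed to exist already.
def _example_checksum_roundtrip(model_path: str = "/tmp/example_model.pkl") -> bool:
    handler = ChecksummedFile(model_path)
    handler.calculate_and_save_checksum()      # writes <model_path>.checksum
    return handler.load_and_verify_checksum()  # re-reads the checksum and verifies the artifact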
316
services/training/app/utils/retry.py
Normal file
316
services/training/app/utils/retry.py
Normal file
@@ -0,0 +1,316 @@
"""
Retry Mechanism with Exponential Backoff
Handles transient failures with intelligent retry strategies
"""

import asyncio
import time
import random
from typing import Callable, Any, Optional, Type, Tuple
from functools import wraps
import logging

logger = logging.getLogger(__name__)


class RetryError(Exception):
    """Raised when all retry attempts are exhausted"""
    def __init__(self, message: str, attempts: int, last_exception: Exception):
        super().__init__(message)
        self.attempts = attempts
        self.last_exception = last_exception


class RetryStrategy:
    """Base retry strategy"""

    def __init__(
        self,
        max_attempts: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0,
        jitter: bool = True,
        retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
    ):
        """
        Initialize retry strategy.

        Args:
            max_attempts: Maximum number of retry attempts
            initial_delay: Initial delay in seconds
            max_delay: Maximum delay between retries
            exponential_base: Base for exponential backoff
            jitter: Add random jitter to prevent thundering herd
            retriable_exceptions: Tuple of exception types to retry
        """
        self.max_attempts = max_attempts
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base
        self.jitter = jitter
        self.retriable_exceptions = retriable_exceptions

    def calculate_delay(self, attempt: int) -> float:
        """Calculate delay for given attempt using exponential backoff"""
        delay = min(
            self.initial_delay * (self.exponential_base ** attempt),
            self.max_delay
        )

        if self.jitter:
            # Add random jitter (scales the delay to 50-100% of its value)
            delay = delay * (0.5 + random.random() * 0.5)

        return delay

    def is_retriable(self, exception: Exception) -> bool:
        """Check if exception should trigger retry"""
        return isinstance(exception, self.retriable_exceptions)
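
# Worked example of the backoff schedule (before jitter) with the defaults above
# (initial_delay=1.0, exponential_base=2.0, max_delay=60.0):
#   attempt 0 -> min(1.0 * 2**0, 60.0) = 1.0s
#   attempt 1 -> min(1.0 * 2**1, 60.0) = 2.0s
#   attempt 2 -> min(1.0 * 2**2, 60.0) = 4.0s
# With jitter enabled, each delay is then scaled by a random factor in [0.5, 1.0).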


async def retry_async(
    func: Callable,
    *args,
    strategy: Optional[RetryStrategy] = None,
    **kwargs
) -> Any:
    """
    Retry async function with exponential backoff.

    Args:
        func: Async function to retry
        *args: Positional arguments for func
        strategy: Retry strategy (uses default if None)
        **kwargs: Keyword arguments for func

    Returns:
        Result from func

    Raises:
        RetryError: When all attempts exhausted
    """
    if strategy is None:
        strategy = RetryStrategy()

    last_exception = None

    for attempt in range(strategy.max_attempts):
        try:
            result = await func(*args, **kwargs)

            if attempt > 0:
                logger.info(f"{func.__name__} succeeded on retry attempt {attempt + 1}")

            return result

        except Exception as e:
            last_exception = e

            if not strategy.is_retriable(e):
                logger.error(f"Non-retriable exception in {func.__name__}: {e}")
                raise

            if attempt < strategy.max_attempts - 1:
                delay = strategy.calculate_delay(attempt)
                logger.warning(
                    f"{func.__name__} attempt {attempt + 1}/{strategy.max_attempts} "
                    f"failed: {e}; retrying in {delay:.2f}s"
                )
                await asyncio.sleep(delay)
            else:
                logger.error(
                    f"All {strategy.max_attempts} retry attempts exhausted for {func.__name__}: {e}"
                )

    raise RetryError(
        f"Failed after {strategy.max_attempts} attempts: {str(last_exception)}",
        attempts=strategy.max_attempts,
        last_exception=last_exception
    )


def with_retry(
    max_attempts: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
):
    """
    Decorator to add retry logic to async functions.

    Example:
        @with_retry(max_attempts=5, initial_delay=2.0)
        async def fetch_data():
            # Your code here
            pass
    """
    strategy = RetryStrategy(
        max_attempts=max_attempts,
        initial_delay=initial_delay,
        max_delay=max_delay,
        exponential_base=exponential_base,
        jitter=jitter,
        retriable_exceptions=retriable_exceptions
    )

    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await retry_async(func, *args, strategy=strategy, **kwargs)
        return wrapper

    return decorator


class AdaptiveRetryStrategy(RetryStrategy):
    """
    Adaptive retry strategy that adjusts based on success/failure patterns.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.success_count = 0
        self.failure_count = 0
        self.consecutive_failures = 0

    def calculate_delay(self, attempt: int) -> float:
        """Calculate delay with adaptation based on recent history"""
        base_delay = super().calculate_delay(attempt)

        # Increase delay if seeing consecutive failures
        if self.consecutive_failures > 5:
            multiplier = min(2.0, 1.0 + (self.consecutive_failures - 5) * 0.2)
            base_delay *= multiplier

        return min(base_delay, self.max_delay)

    def record_success(self):
        """Record successful attempt"""
        self.success_count += 1
        self.consecutive_failures = 0

    def record_failure(self):
        """Record failed attempt"""
        self.failure_count += 1
        self.consecutive_failures += 1


class TimeoutRetryStrategy(RetryStrategy):
    """
    Retry strategy with overall timeout across all attempts.
    """

    def __init__(self, *args, timeout: float = 300.0, **kwargs):
        """
        Args:
            timeout: Total timeout in seconds for all attempts
        """
        super().__init__(*args, **kwargs)
        self.timeout = timeout
        self.start_time: Optional[float] = None

    def should_retry(self, attempt: int) -> bool:
        """Check if should attempt another retry"""
        if self.start_time is None:
            self.start_time = time.time()
            return True

        elapsed = time.time() - self.start_time
        return elapsed < self.timeout and attempt < self.max_attempts


async def retry_with_timeout(
    func: Callable,
    *args,
    max_attempts: int = 3,
    timeout: float = 300.0,
    **kwargs
) -> Any:
    """
    Retry with overall timeout.

    Args:
        func: Function to retry
        max_attempts: Maximum attempts
        timeout: Overall timeout in seconds

    Returns:
        Result from func
    """
    strategy = TimeoutRetryStrategy(
        max_attempts=max_attempts,
        timeout=timeout
    )

    start_time = time.time()
    strategy.start_time = start_time

    last_exception = None

    for attempt in range(strategy.max_attempts):
        if time.time() - start_time >= timeout:
            raise RetryError(
                f"Timeout of {timeout}s exceeded",
                attempts=attempt + 1,
                last_exception=last_exception
            )

        try:
            return await func(*args, **kwargs)
        except Exception as e:
            last_exception = e

            if not strategy.is_retriable(e):
                raise

            if attempt < strategy.max_attempts - 1:
                delay = strategy.calculate_delay(attempt)
                await asyncio.sleep(delay)

    raise RetryError(
        f"Failed after {strategy.max_attempts} attempts",
        attempts=strategy.max_attempts,
        last_exception=last_exception
    )


# Pre-configured strategies for common use cases
HTTP_RETRY_STRATEGY = RetryStrategy(
    max_attempts=3,
    initial_delay=1.0,
    max_delay=10.0,
    exponential_base=2.0,
    jitter=True
)

DATABASE_RETRY_STRATEGY = RetryStrategy(
    max_attempts=5,
    initial_delay=0.5,
    max_delay=5.0,
    exponential_base=1.5,
    jitter=True
)

EXTERNAL_SERVICE_RETRY_STRATEGY = RetryStrategy(
    max_attempts=4,
    initial_delay=2.0,
    max_delay=30.0,
    exponential_base=2.5,
    jitter=True
)
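

# Illustrative usage sketch (never called from this module): the decorator form for a
# flaky remote call, and the functional form with a pre-configured strategy. The
# function names and the empty body are placeholders, not real client calls.
@with_retry(max_attempts=4, initial_delay=0.5, retriable_exceptions=(ConnectionError, TimeoutError))
async def _example_fetch_remote_features(url: str) -> dict:
    # ... issue the real HTTP request here ...
    return {}


async def _example_db_call_with_retry(run_query: Callable) -> Any:
    # retry_async() wraps an existing coroutine function without decorating it
    return await retry_async(run_query, strategy=DATABASE_RETRY_STRATEGY)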
184
services/training/app/utils/timezone_utils.py
Normal file
184
services/training/app/utils/timezone_utils.py
Normal file
@@ -0,0 +1,184 @@
"""
Timezone Utility Functions
Centralized timezone handling to ensure consistency across the training service
"""

from datetime import datetime, timezone
from typing import Optional, Union
import pandas as pd
import logging

logger = logging.getLogger(__name__)


def ensure_timezone_aware(dt: Optional[datetime], default_tz=timezone.utc) -> Optional[datetime]:
    """
    Ensure a datetime is timezone-aware.

    Args:
        dt: Datetime to check
        default_tz: Timezone to apply if datetime is naive (default: UTC)

    Returns:
        Timezone-aware datetime, or None if dt is None
    """
    if dt is None:
        return None

    if dt.tzinfo is None:
        return dt.replace(tzinfo=default_tz)
    return dt


def ensure_timezone_naive(dt: Optional[datetime]) -> Optional[datetime]:
    """
    Remove timezone information from a datetime.

    Args:
        dt: Datetime to process

    Returns:
        Timezone-naive datetime, or None if dt is None
    """
    if dt is None:
        return None

    if dt.tzinfo is not None:
        return dt.replace(tzinfo=None)
    return dt


def normalize_datetime_to_utc(dt: Union[datetime, pd.Timestamp, None]) -> Optional[datetime]:
    """
    Normalize any datetime to UTC timezone-aware datetime.

    Args:
        dt: Datetime or pandas Timestamp to normalize

    Returns:
        UTC timezone-aware datetime, or None if dt is None
    """
    if dt is None:
        return None

    # Handle pandas Timestamp
    if isinstance(dt, pd.Timestamp):
        dt = dt.to_pydatetime()

    # If naive, assume UTC
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)

    # If aware but not UTC, convert to UTC
    return dt.astimezone(timezone.utc)


def normalize_dataframe_datetime_column(
    df: pd.DataFrame,
    column: str,
    target_format: str = 'naive'
) -> pd.DataFrame:
    """
    Normalize a datetime column in a dataframe to consistent format.

    Args:
        df: DataFrame to process
        column: Name of datetime column
        target_format: 'naive' or 'aware' (UTC)

    Returns:
        DataFrame with normalized datetime column
    """
    if column not in df.columns:
        logger.warning(f"Column {column} not found in dataframe")
        return df

    # Convert to datetime if not already
    df[column] = pd.to_datetime(df[column])

    if target_format == 'naive':
        # Remove timezone if present
        if df[column].dt.tz is not None:
            df[column] = df[column].dt.tz_localize(None)
    elif target_format == 'aware':
        # Add UTC timezone if not present
        if df[column].dt.tz is None:
            df[column] = df[column].dt.tz_localize(timezone.utc)
        else:
            # Convert to UTC if different timezone
            df[column] = df[column].dt.tz_convert(timezone.utc)
    else:
        raise ValueError(f"Invalid target_format: {target_format}. Must be 'naive' or 'aware'")

    return df


def prepare_prophet_datetime(df: pd.DataFrame, datetime_col: str = 'ds') -> pd.DataFrame:
    """
    Prepare datetime column for Prophet (requires timezone-naive datetimes).

    Args:
        df: DataFrame with datetime column
        datetime_col: Name of datetime column (default: 'ds')

    Returns:
        DataFrame with Prophet-compatible datetime column
    """
    df = df.copy()
    df = normalize_dataframe_datetime_column(df, datetime_col, target_format='naive')
    return df


def safe_datetime_comparison(dt1: datetime, dt2: datetime) -> int:
    """
    Safely compare two datetimes, handling timezone mismatches.

    Args:
        dt1: First datetime
        dt2: Second datetime

    Returns:
        -1 if dt1 < dt2, 0 if equal, 1 if dt1 > dt2
    """
    # Normalize both to UTC for comparison
    dt1_utc = normalize_datetime_to_utc(dt1)
    dt2_utc = normalize_datetime_to_utc(dt2)

    if dt1_utc < dt2_utc:
        return -1
    elif dt1_utc > dt2_utc:
        return 1
    else:
        return 0


def get_current_utc() -> datetime:
    """
    Get current datetime in UTC with timezone awareness.

    Returns:
        Current UTC datetime
    """
    return datetime.now(timezone.utc)


def convert_timestamp_to_datetime(timestamp: Union[int, float, str]) -> datetime:
    """
    Convert various timestamp formats to datetime.

    Args:
        timestamp: Unix timestamp (seconds or milliseconds) or ISO string

    Returns:
        UTC timezone-aware datetime
    """
    if isinstance(timestamp, str):
        dt = pd.to_datetime(timestamp)
        return normalize_datetime_to_utc(dt)

    # Values above ~1e10 are treated as milliseconds (typical JavaScript timestamps)
    if timestamp > 1e10:
        timestamp = timestamp / 1000

    dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    return dt
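

# Illustrative usage sketch (never executed here): preparing a Prophet training frame
# and comparing timestamps with mixed timezone awareness. Column names and values are
# examples only.
def _example_timezone_normalization() -> None:
    df = pd.DataFrame({"ds": ["2024-01-01T00:00:00Z", "2024-01-02T00:00:00Z"], "y": [1.0, 2.0]})
    prophet_df = prepare_prophet_datetime(df, datetime_col="ds")  # tz info stripped for Prophet

    naive_start = datetime(2024, 1, 1)            # naive, assumed UTC
    aware_end = get_current_utc()                  # timezone-aware
    order = safe_datetime_comparison(naive_start, aware_end)  # -1: start precedes now
    logger.debug(f"prophet rows={len(prophet_df)}, comparison={order}")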
11
services/training/app/websocket/__init__.py
Normal file
11
services/training/app/websocket/__init__.py
Normal file
@@ -0,0 +1,11 @@
"""WebSocket support for training service"""

from app.websocket.manager import websocket_manager, WebSocketConnectionManager
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers

__all__ = [
    'websocket_manager',
    'WebSocketConnectionManager',
    'setup_websocket_event_consumer',
    'cleanup_websocket_consumers'
]
148
services/training/app/websocket/events.py
Normal file
148
services/training/app/websocket/events.py
Normal file
@@ -0,0 +1,148 @@
"""
RabbitMQ Event Consumer for WebSocket Broadcasting
Listens to training events from RabbitMQ and broadcasts them to WebSocket clients
"""

import asyncio
import json
from typing import Dict, Set
import structlog

from app.websocket.manager import websocket_manager
from app.services.training_events import training_publisher

logger = structlog.get_logger()

# Track active consumers
_active_consumers: Set[asyncio.Task] = set()


async def handle_training_event(message) -> None:
    """
    Handle incoming RabbitMQ training events and broadcast to WebSocket clients.
    This is the bridge between RabbitMQ and WebSocket.
    """
    try:
        # Parse message
        body = message.body.decode()
        data = json.loads(body)

        event_type = data.get('event_type', 'unknown')
        event_data = data.get('data', {})
        job_id = event_data.get('job_id')

        if not job_id:
            logger.warning("Received event without job_id, skipping", event_type=event_type)
            await message.ack()
            return

        logger.info("Received training event from RabbitMQ",
                    job_id=job_id,
                    event_type=event_type,
                    progress=event_data.get('progress'))

        # Map RabbitMQ event types to WebSocket message types
        ws_message_type = _map_event_type(event_type)

        # Create WebSocket message
        ws_message = {
            "type": ws_message_type,
            "job_id": job_id,
            "timestamp": data.get('timestamp'),
            "data": event_data
        }

        # Broadcast to all WebSocket clients for this job
        sent_count = await websocket_manager.broadcast(job_id, ws_message)

        logger.info("Broadcasted event to WebSocket clients",
                    job_id=job_id,
                    event_type=event_type,
                    ws_message_type=ws_message_type,
                    clients_notified=sent_count)

        # Always acknowledge the message to avoid infinite redelivery loops.
        # Progress events (started, progress, product_completed) are ephemeral and don't
        # need redelivery; final events (completed, failed) should always be acknowledged.
        await message.ack()

    except Exception as e:
        logger.error("Error handling training event",
                     error=str(e),
                     exc_info=True)
        # Acknowledge even on error to avoid infinite redelivery loops;
        # the event is logged so issues can be debugged.
        try:
            await message.ack()
        except Exception:
            pass  # Message already gone or connection closed


def _map_event_type(rabbitmq_event_type: str) -> str:
    """Map RabbitMQ event types to WebSocket message types"""
    mapping = {
        "training.started": "started",
        "training.progress": "progress",
        "training.step.completed": "step_completed",
        "training.product.completed": "product_completed",
        "training.completed": "completed",
        "training.failed": "failed",
    }
    return mapping.get(rabbitmq_event_type, "unknown")


async def setup_websocket_event_consumer() -> bool:
    """
    Set up a global RabbitMQ consumer that listens to all training events
    and broadcasts them to connected WebSocket clients.
    """
    try:
        # Ensure publisher is connected
        if not training_publisher.connected:
            logger.info("Connecting training publisher for WebSocket event consumer")
            success = await training_publisher.connect()
            if not success:
                logger.error("Failed to connect training publisher")
                return False

        # Create a dedicated queue for WebSocket broadcasting
        queue_name = "training_websocket_broadcast"

        logger.info("Setting up WebSocket event consumer", queue_name=queue_name)

        # Subscribe to all training events (routing key: training.#)
        success = await training_publisher.consume_events(
            exchange_name="training.events",
            queue_name=queue_name,
            routing_key="training.#",  # Listen to all training events (multi-level)
            callback=handle_training_event
        )

        if success:
            logger.info("WebSocket event consumer set up successfully")
            return True
        else:
            logger.error("Failed to set up WebSocket event consumer")
            return False

    except Exception as e:
        logger.error("Error setting up WebSocket event consumer",
                     error=str(e),
                     exc_info=True)
        return False


async def cleanup_websocket_consumers() -> None:
    """Clean up WebSocket event consumers"""
    logger.info("Cleaning up WebSocket event consumers")

    for task in _active_consumers:
        if not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    _active_consumers.clear()
    logger.info("WebSocket event consumers cleaned up")
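

# Illustrative wiring sketch (assumed integration point, not part of this module): the
# consumer is expected to be started when the FastAPI app starts and torn down on
# shutdown, for example from a lifespan handler in the service entry point.
#
# from contextlib import asynccontextmanager
# from fastapi import FastAPI
#
# @asynccontextmanager
# async def lifespan(app: FastAPI):
#     await setup_websocket_event_consumer()   # begin bridging RabbitMQ -> WebSocket
#     yield
#     await cleanup_websocket_consumers()      # cancel consumer tasks on shutdown
#
# app = FastAPI(lifespan=lifespan)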
120
services/training/app/websocket/manager.py
Normal file
120
services/training/app/websocket/manager.py
Normal file
@@ -0,0 +1,120 @@
"""
WebSocket Connection Manager for Training Service
Manages WebSocket connections and broadcasts RabbitMQ events to connected clients
"""

import asyncio
import json
from typing import Dict, Set
from fastapi import WebSocket
import structlog

logger = structlog.get_logger()


class WebSocketConnectionManager:
    """
    Simple WebSocket connection manager.
    Manages connections per job_id and broadcasts messages to all connected clients.
    """

    def __init__(self):
        # Structure: {job_id: {websocket_id: WebSocket}}
        self._connections: Dict[str, Dict[int, WebSocket]] = {}
        self._lock = asyncio.Lock()
        # Store latest event for each job to provide initial state
        self._latest_events: Dict[str, dict] = {}

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        """Register a new WebSocket connection for a job"""
        await websocket.accept()

        async with self._lock:
            if job_id not in self._connections:
                self._connections[job_id] = {}

            ws_id = id(websocket)
            self._connections[job_id][ws_id] = websocket

            # Send initial state if available
            if job_id in self._latest_events:
                try:
                    await websocket.send_json({
                        "type": "initial_state",
                        "job_id": job_id,
                        "data": self._latest_events[job_id]
                    })
                except Exception as e:
                    logger.warning("Failed to send initial state to new connection", error=str(e))

        logger.info("WebSocket connected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    total_connections=len(self._connections.get(job_id, {})))

    async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        """Remove a WebSocket connection"""
        # Compute the id up front so it is always defined for logging below
        ws_id = id(websocket)

        async with self._lock:
            if job_id in self._connections:
                self._connections[job_id].pop(ws_id, None)

                # Clean up empty job connections
                if not self._connections[job_id]:
                    del self._connections[job_id]

        logger.info("WebSocket disconnected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    remaining_connections=len(self._connections.get(job_id, {})))

    async def broadcast(self, job_id: str, message: dict) -> int:
        """
        Broadcast a message to all connections for a specific job.
        Returns the number of successful broadcasts.
        """
        # Store the latest event for this job to provide initial state to new connections
        if message.get('type') != 'initial_state':  # Don't store initial_state messages
            self._latest_events[job_id] = message

        if job_id not in self._connections:
            logger.debug("No active connections for job", job_id=job_id)
            return 0

        connections = list(self._connections[job_id].values())
        successful_sends = 0
        failed_websockets = []

        for websocket in connections:
            try:
                await websocket.send_json(message)
                successful_sends += 1
            except Exception as e:
                logger.warning("Failed to send message to WebSocket",
                               job_id=job_id,
                               error=str(e))
                failed_websockets.append(websocket)

        # Clean up failed connections
        if failed_websockets:
            async with self._lock:
                for ws in failed_websockets:
                    ws_id = id(ws)
                    # Use .get() in case the job entry was removed by a concurrent disconnect
                    self._connections.get(job_id, {}).pop(ws_id, None)

        if successful_sends > 0:
            logger.info("Broadcasted message to WebSocket clients",
                        job_id=job_id,
                        message_type=message.get('type'),
                        successful_sends=successful_sends,
                        failed_sends=len(failed_websockets))

        return successful_sends

    def get_connection_count(self, job_id: str) -> int:
        """Get the number of active connections for a job"""
        return len(self._connections.get(job_id, {}))


# Global singleton instance
websocket_manager = WebSocketConnectionManager()
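

# Illustrative endpoint sketch (assumed route; the real route lives in the API layer,
# not in this module): how websocket_manager is expected to be used from a FastAPI
# WebSocket route. The path and handler name are placeholders.
#
# from fastapi import APIRouter, WebSocket, WebSocketDisconnect
#
# router = APIRouter()
#
# @router.websocket("/ws/training/{job_id}")
# async def training_progress_ws(websocket: WebSocket, job_id: str):
#     await websocket_manager.connect(job_id, websocket)   # accepts and registers the socket
#     try:
#         while True:
#             await websocket.receive_text()                # keep the connection alive
#     except WebSocketDisconnect:
#         await websocket_manager.disconnect(job_id, websocket)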