Initial commit - production deployment

2026-01-21 17:17:16 +01:00
commit c23d00dd92
2289 changed files with 638440 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
"""
Training API Layer
HTTP endpoints for ML training operations and WebSocket connections
"""
from .training_jobs import router as training_jobs_router
from .training_operations import router as training_operations_router
from .models import router as models_router
from .websocket_operations import router as websocket_operations_router
__all__ = [
"training_jobs_router",
"training_operations_router",
"models_router",
"websocket_operations_router"
]
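
A minimal sketch of how these routers could be mounted in the service's FastAPI application; the entrypoint module and the absence of extra prefixes are assumptions, since the actual app factory is not part of this file.

# Illustrative wiring only; import path and app settings are assumed.
from fastapi import FastAPI

from app.api import (
    training_jobs_router,
    training_operations_router,
    models_router,
    websocket_operations_router,
)

app = FastAPI(title="training-service")
for router in (
    training_jobs_router,
    training_operations_router,
    models_router,
    websocket_operations_router,
):
    app.include_router(router)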

View File

@@ -0,0 +1,237 @@
# services/training/app/api/audit.py
"""
Audit Logs API - Retrieve audit trail for training service
"""
from fastapi import APIRouter, Depends, HTTPException, Query, Path, status
from typing import Optional, Dict, Any
from uuid import UUID
from datetime import datetime
import structlog
from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import AuditLog
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role
from shared.routing import RouteBuilder
from shared.models.audit_log_schemas import (
AuditLogResponse,
AuditLogListResponse,
AuditLogStatsResponse
)
from app.core.database import database_manager
route_builder = RouteBuilder('training')
router = APIRouter(tags=["audit-logs"])
logger = structlog.get_logger()
async def get_db():
"""Database session dependency"""
async with database_manager.get_session() as session:
yield session
@router.get(
route_builder.build_base_route("audit-logs"),
response_model=AuditLogListResponse
)
@require_user_role(['admin', 'owner'])
async def get_audit_logs(
tenant_id: UUID = Path(..., description="Tenant ID"),
start_date: Optional[datetime] = Query(None, description="Filter logs from this date"),
end_date: Optional[datetime] = Query(None, description="Filter logs until this date"),
user_id: Optional[UUID] = Query(None, description="Filter by user ID"),
action: Optional[str] = Query(None, description="Filter by action type"),
resource_type: Optional[str] = Query(None, description="Filter by resource type"),
severity: Optional[str] = Query(None, description="Filter by severity level"),
search: Optional[str] = Query(None, description="Search in description field"),
limit: int = Query(100, ge=1, le=1000, description="Number of records to return"),
offset: int = Query(0, ge=0, description="Number of records to skip"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get audit logs for training service.
Requires admin or owner role.
"""
try:
logger.info(
"Retrieving audit logs",
tenant_id=tenant_id,
user_id=current_user.get("user_id"),
filters={
"start_date": start_date,
"end_date": end_date,
"action": action,
"resource_type": resource_type,
"severity": severity
}
)
# Build query filters
filters = [AuditLog.tenant_id == tenant_id]
if start_date:
filters.append(AuditLog.created_at >= start_date)
if end_date:
filters.append(AuditLog.created_at <= end_date)
if user_id:
filters.append(AuditLog.user_id == user_id)
if action:
filters.append(AuditLog.action == action)
if resource_type:
filters.append(AuditLog.resource_type == resource_type)
if severity:
filters.append(AuditLog.severity == severity)
if search:
filters.append(AuditLog.description.ilike(f"%{search}%"))
# Count total matching records
count_query = select(func.count()).select_from(AuditLog).where(and_(*filters))
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# Fetch paginated results
query = (
select(AuditLog)
.where(and_(*filters))
.order_by(AuditLog.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await db.execute(query)
audit_logs = result.scalars().all()
# Convert to response models
items = [AuditLogResponse.from_orm(log) for log in audit_logs]
logger.info(
"Successfully retrieved audit logs",
tenant_id=tenant_id,
total=total,
returned=len(items)
)
return AuditLogListResponse(
items=items,
total=total,
limit=limit,
offset=offset,
has_more=(offset + len(items)) < total
)
except Exception as e:
logger.error(
"Failed to retrieve audit logs",
error=str(e),
tenant_id=tenant_id
)
raise HTTPException(
status_code=500,
detail=f"Failed to retrieve audit logs: {str(e)}"
)
@router.get(
route_builder.build_base_route("audit-logs/stats"),
response_model=AuditLogStatsResponse
)
@require_user_role(['admin', 'owner'])
async def get_audit_log_stats(
tenant_id: UUID = Path(..., description="Tenant ID"),
start_date: Optional[datetime] = Query(None, description="Filter logs from this date"),
end_date: Optional[datetime] = Query(None, description="Filter logs until this date"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get audit log statistics for training service.
Requires admin or owner role.
"""
try:
logger.info(
"Retrieving audit log statistics",
tenant_id=tenant_id,
user_id=current_user.get("user_id")
)
# Build base filters
filters = [AuditLog.tenant_id == tenant_id]
if start_date:
filters.append(AuditLog.created_at >= start_date)
if end_date:
filters.append(AuditLog.created_at <= end_date)
# Total events
count_query = select(func.count()).select_from(AuditLog).where(and_(*filters))
total_result = await db.execute(count_query)
total_events = total_result.scalar() or 0
# Events by action
action_query = (
select(AuditLog.action, func.count().label('count'))
.where(and_(*filters))
.group_by(AuditLog.action)
)
action_result = await db.execute(action_query)
events_by_action = {row.action: row.count for row in action_result}
# Events by severity
severity_query = (
select(AuditLog.severity, func.count().label('count'))
.where(and_(*filters))
.group_by(AuditLog.severity)
)
severity_result = await db.execute(severity_query)
events_by_severity = {row.severity: row.count for row in severity_result}
# Events by resource type
resource_query = (
select(AuditLog.resource_type, func.count().label('count'))
.where(and_(*filters))
.group_by(AuditLog.resource_type)
)
resource_result = await db.execute(resource_query)
events_by_resource_type = {row.resource_type: row.count for row in resource_result}
# Date range
date_range_query = (
select(
func.min(AuditLog.created_at).label('min_date'),
func.max(AuditLog.created_at).label('max_date')
)
.where(and_(*filters))
)
date_result = await db.execute(date_range_query)
date_row = date_result.one()
logger.info(
"Successfully retrieved audit log statistics",
tenant_id=tenant_id,
total_events=total_events
)
return AuditLogStatsResponse(
total_events=total_events,
events_by_action=events_by_action,
events_by_severity=events_by_severity,
events_by_resource_type=events_by_resource_type,
date_range={
"min": date_row.min_date,
"max": date_row.max_date
}
)
except Exception as e:
logger.error(
"Failed to retrieve audit log statistics",
error=str(e),
tenant_id=tenant_id
)
raise HTTPException(
status_code=500,
detail=f"Failed to retrieve audit log statistics: {str(e)}"
)
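
For illustration, a client could page through the listing endpoint above roughly as follows; the base URL, path prefix, and bearer-token header are assumptions, since the concrete path comes from RouteBuilder('training') and the shared auth layer.

import httpx

# Assumed route shape; adjust to the prefix RouteBuilder('training') actually produces.
AUDIT_URL = "http://localhost:8000/api/v1/training/tenants/{tenant_id}/audit-logs"

def fetch_all_audit_logs(tenant_id: str, token: str, page_size: int = 100) -> list:
    """Collect every audit log page for a tenant (sketch, not production code)."""
    items, offset = [], 0
    with httpx.Client(headers={"Authorization": f"Bearer {token}"}) as client:
        while True:
            resp = client.get(
                AUDIT_URL.format(tenant_id=tenant_id),
                params={"limit": page_size, "offset": offset},
            )
            resp.raise_for_status()
            page = resp.json()
            items.extend(page["items"])
            if not page["has_more"]:
                break
            offset += page_size
    return items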

View File

@@ -0,0 +1,261 @@
"""
Enhanced Health Check Endpoints
Comprehensive service health monitoring
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import text
from typing import Dict, Any
import psutil
import os
from datetime import datetime, timezone
import logging
from app.core.database import database_manager
from app.utils.circuit_breaker import circuit_breaker_registry
from app.core.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
async def check_database_health() -> Dict[str, Any]:
"""Check database connectivity and performance"""
try:
start_time = datetime.now(timezone.utc)
async with database_manager.async_engine.begin() as conn:
# Simple connectivity check
await conn.execute(text("SELECT 1"))
# Check if we can access training tables
result = await conn.execute(
text("SELECT COUNT(*) FROM trained_models")
)
model_count = result.scalar()
# Check connection pool stats
pool = database_manager.async_engine.pool
pool_size = pool.size()
pool_checked_out = pool.checkedout()
response_time = (datetime.now(timezone.utc) - start_time).total_seconds()
return {
"status": "healthy",
"response_time_seconds": round(response_time, 3),
"model_count": model_count,
"connection_pool": {
"size": pool_size,
"checked_out": pool_checked_out,
"available": pool_size - pool_checked_out
}
}
except Exception as e:
logger.error(f"Database health check failed: {e}")
return {
"status": "unhealthy",
"error": str(e)
}
def check_system_resources() -> Dict[str, Any]:
"""Check system resource usage"""
try:
cpu_percent = psutil.cpu_percent(interval=0.1)
memory = psutil.virtual_memory()
disk = psutil.disk_usage('/')
return {
"status": "healthy",
"cpu": {
"usage_percent": cpu_percent,
"count": psutil.cpu_count()
},
"memory": {
"total_mb": round(memory.total / 1024 / 1024, 2),
"used_mb": round(memory.used / 1024 / 1024, 2),
"available_mb": round(memory.available / 1024 / 1024, 2),
"usage_percent": memory.percent
},
"disk": {
"total_gb": round(disk.total / 1024 / 1024 / 1024, 2),
"used_gb": round(disk.used / 1024 / 1024 / 1024, 2),
"free_gb": round(disk.free / 1024 / 1024 / 1024, 2),
"usage_percent": disk.percent
}
}
except Exception as e:
logger.error(f"System resource check failed: {e}")
return {
"status": "error",
"error": str(e)
}
def check_model_storage() -> Dict[str, Any]:
"""Check MinIO model storage health"""
try:
from shared.clients.minio_client import minio_client
# Check MinIO connectivity
if not minio_client.health_check():
return {
"status": "unhealthy",
"message": "MinIO service is not reachable",
"storage_type": "minio"
}
bucket_name = settings.MINIO_MODEL_BUCKET
# Check if bucket exists
bucket_exists = minio_client.bucket_exists(bucket_name)
if not bucket_exists:
return {
"status": "warning",
"message": f"MinIO bucket does not exist: {bucket_name}",
"storage_type": "minio"
}
# Count model files in MinIO
model_objects = minio_client.list_objects(bucket_name, prefix="models/")
model_files = [obj for obj in model_objects if obj.endswith('.pkl')]
return {
"status": "healthy",
"storage_type": "minio",
"endpoint": settings.MINIO_ENDPOINT,
"bucket": bucket_name,
"use_ssl": settings.MINIO_USE_SSL,
"model_files": len(model_files),
"bucket_exists": bucket_exists
}
except Exception as e:
logger.error(f"MinIO storage check failed: {e}")
return {
"status": "error",
"storage_type": "minio",
"error": str(e)
}
@router.get("/health")
async def health_check() -> Dict[str, Any]:
"""
Basic health check endpoint.
Returns 200 if service is running.
"""
return {
"status": "healthy",
"service": "training-service",
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/health/detailed")
async def detailed_health_check() -> Dict[str, Any]:
"""
Detailed health check with component status.
Includes database, system resources, and dependencies.
"""
database_health = await check_database_health()
system_health = check_system_resources()
storage_health = check_model_storage()
circuit_breakers = circuit_breaker_registry.get_all_states()
# Determine overall status
component_statuses = [
database_health.get("status"),
system_health.get("status"),
storage_health.get("status")
]
if "unhealthy" in component_statuses or "error" in component_statuses:
overall_status = "unhealthy"
elif "degraded" in component_statuses or "warning" in component_statuses:
overall_status = "degraded"
else:
overall_status = "healthy"
return {
"status": overall_status,
"service": "training-service",
"version": "1.0.0",
"timestamp": datetime.now(timezone.utc).isoformat(),
"components": {
"database": database_health,
"system": system_health,
"storage": storage_health
},
"circuit_breakers": circuit_breakers,
"configuration": {
"max_concurrent_jobs": settings.MAX_CONCURRENT_TRAINING_JOBS,
"min_training_days": settings.MIN_TRAINING_DATA_DAYS,
"pool_size": settings.DB_POOL_SIZE,
"pool_max_overflow": settings.DB_MAX_OVERFLOW
}
}
@router.get("/health/ready")
async def readiness_check() -> Dict[str, Any]:
"""
Readiness check for Kubernetes.
Returns 200 only if service is ready to accept traffic.
"""
database_health = await check_database_health()
if database_health.get("status") != "healthy":
raise HTTPException(
status_code=503,
detail="Service not ready: database unavailable"
)
storage_health = check_model_storage()
if storage_health.get("status") == "error":
raise HTTPException(
status_code=503,
detail="Service not ready: model storage unavailable"
)
return {
"status": "ready",
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/health/live")
async def liveness_check() -> Dict[str, Any]:
"""
Liveness check for Kubernetes.
Returns 200 if service process is alive.
"""
return {
"status": "alive",
"timestamp": datetime.now(timezone.utc).isoformat(),
"pid": os.getpid()
}
@router.get("/metrics/system")
async def system_metrics() -> Dict[str, Any]:
"""
Detailed system metrics for monitoring.
"""
process = psutil.Process(os.getpid())
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"process": {
"pid": os.getpid(),
"cpu_percent": process.cpu_percent(interval=0.1),
"memory_mb": round(process.memory_info().rss / 1024 / 1024, 2),
"threads": process.num_threads(),
"open_files": len(process.open_files()),
"connections": len(process.connections())
},
"system": check_system_resources()
}
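
As a rough illustration, an external monitor could consume the probes above like this; the host, port, and lack of a path prefix are assumptions about how the router is mounted.

import sys
import httpx

HEALTH_URL = "http://localhost:8000/health/detailed"  # assumed service address

def check_once() -> int:
    """Map the detailed health payload to a shell-style exit code."""
    try:
        payload = httpx.get(HEALTH_URL, timeout=5.0).json()
    except httpx.HTTPError:
        return 2  # unreachable
    status = payload.get("status")
    if status == "healthy":
        return 0
    if status == "degraded":
        return 1
    return 2

if __name__ == "__main__":
    sys.exit(check_once())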

View File

@@ -0,0 +1,464 @@
"""
Models API endpoints
"""
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
import structlog
from sqlalchemy import text
from app.core.database import get_db
from app.schemas.training import TrainedModelResponse, ModelMetricsResponse
from app.services.training_service import EnhancedTrainingService
from datetime import datetime, timezone
from sqlalchemy import select, delete, func
import uuid
from shared.auth.decorators import (
get_current_user_dep,
require_admin_role
)
from shared.routing import RouteBuilder
from shared.auth.access_control import (
require_user_role,
admin_role_required,
owner_role_required,
require_subscription_tier,
analytics_tier_required,
enterprise_tier_required
)
# Create route builder for consistent URL structure
route_builder = RouteBuilder('training')
logger = structlog.get_logger()
router = APIRouter()
training_service = EnhancedTrainingService()
@router.get(
route_builder.build_base_route("models") + "/{inventory_product_id}/active",
response_model=TrainedModelResponse
)
async def get_active_model(
tenant_id: str = Path(..., description="Tenant ID"),
inventory_product_id: str = Path(..., description="Inventory product UUID"),
db: AsyncSession = Depends(get_db)
):
"""
Get the active model for a product - used by forecasting service
"""
try:
logger.debug("Getting active model", tenant_id=tenant_id, inventory_product_id=inventory_product_id)
# Raw SQL is wrapped with text() for SQLAlchemy 2.0; lookup is by inventory_product_id only
query = text("""
SELECT * FROM trained_models
WHERE tenant_id = :tenant_id
AND inventory_product_id = :inventory_product_id
AND is_active = true
AND is_production = true
ORDER BY created_at DESC
LIMIT 1
""")
result = await db.execute(query, {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id
})
model_record = result.fetchone()
if not model_record:
logger.info("No active model found", tenant_id=tenant_id, inventory_product_id=inventory_product_id)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No active model found for product {inventory_product_id}"
)
# ✅ FIX: Wrap update query with text() too
update_query = text("""
UPDATE trained_models
SET last_used_at = :now
WHERE id = :model_id
""")
await db.execute(update_query, {
"now": datetime.now(timezone.utc),
"model_id": model_record.id
})
await db.commit()
return {
"model_id": str(model_record.id),
"tenant_id": str(model_record.tenant_id),
"inventory_product_id": str(model_record.inventory_product_id),
"model_type": model_record.model_type,
"model_path": model_record.model_path,
"version": 1, # Default version
"training_samples": model_record.training_samples or 0,
"features": model_record.features_used or [],
"hyperparameters": model_record.hyperparameters or {},
"training_metrics": {
"mape": model_record.mape or 0.0,
"mae": model_record.mae or 0.0,
"rmse": model_record.rmse or 0.0,
"r2_score": model_record.r2_score or 0.0
},
"is_active": model_record.is_active,
"created_at": model_record.created_at,
"data_period_start": model_record.training_start_date,
"data_period_end": model_record.training_end_date
}
except HTTPException:
raise
except Exception as e:
error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
logger.error(f"Failed to get active model: {error_msg}", tenant_id=tenant_id, inventory_product_id=inventory_product_id)
# Handle client disconnection gracefully
if "EndOfStream" in str(type(e)) or "WouldBlock" in str(type(e)):
logger.info("Client disconnected during model retrieval", tenant_id=tenant_id, inventory_product_id=inventory_product_id)
raise HTTPException(
status_code=status.HTTP_408_REQUEST_TIMEOUT,
detail="Request connection closed"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve model"
)
@router.get(
route_builder.build_nested_resource_route("models", "model_id", "metrics"),
response_model=ModelMetricsResponse
)
async def get_model_metrics(
model_id: str = Path(..., description="Model ID"),
db: AsyncSession = Depends(get_db)
):
"""
Get performance metrics for a specific model - used by forecasting service
"""
try:
# Query the model by ID
query = text("""
SELECT * FROM trained_models
WHERE id = :model_id
""")
result = await db.execute(query, {"model_id": model_id})
model_record = result.fetchone()
if not model_record:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model {model_id} not found"
)
# Return metrics in the format expected by forecasting service
metrics = {
"model_id": str(model_record.id),
"accuracy": model_record.r2_score or 0.0, # Use R2 as accuracy measure
"mape": model_record.mape or 0.0,
"mae": model_record.mae or 0.0,
"rmse": model_record.rmse or 0.0,
"r2_score": model_record.r2_score or 0.0,
"training_samples": model_record.training_samples or 0,
"features_used": model_record.features_used or [],
"model_type": model_record.model_type,
"created_at": model_record.created_at.isoformat() if model_record.created_at else None,
"last_used_at": model_record.last_used_at.isoformat() if model_record.last_used_at else None
}
logger.info(f"Retrieved metrics for model {model_id}",
mape=metrics["mape"],
accuracy=metrics["accuracy"])
return metrics
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get model metrics: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve model metrics"
)
@router.get(
route_builder.build_base_route("models"),
response_model=List[TrainedModelResponse]
)
async def list_models(
tenant_id: str = Path(..., description="Tenant ID"),
status_filter: Optional[str] = Query(None, alias="status", description="Filter by status (active/inactive)"),  # aliased to avoid shadowing fastapi.status used below
model_type: Optional[str] = Query(None, description="Filter by model type"),
limit: int = Query(50, ge=1, le=100, description="Maximum number of models to return"),
db: AsyncSession = Depends(get_db)
):
"""
List models for a tenant - used by forecasting service for model discovery
"""
try:
# Build query with filters
query_parts = ["SELECT * FROM trained_models WHERE tenant_id = :tenant_id"]
params = {"tenant_id": tenant_id}
if status_filter in ("deployed", "active"):
query_parts.append("AND is_active = true AND is_production = true")
elif status_filter == "inactive":
query_parts.append("AND (is_active = false OR is_production = false)")
if model_type:
query_parts.append("AND model_type = :model_type")
params["model_type"] = model_type
query_parts.append("ORDER BY created_at DESC LIMIT :limit")
params["limit"] = limit
query = text(" ".join(query_parts))
result = await db.execute(query, params)
model_records = result.fetchall()
models = []
for record in model_records:
models.append({
"model_id": str(record.id),
"tenant_id": str(record.tenant_id),
"inventory_product_id": str(record.inventory_product_id),
"model_type": record.model_type,
"model_path": record.model_path,
"version": 1, # Default version
"training_samples": record.training_samples or 0,
"features": record.features_used or [],
"hyperparameters": record.hyperparameters or {},
"training_metrics": {
"mape": record.mape or 0.0,
"mae": record.mae or 0.0,
"rmse": record.rmse or 0.0,
"r2_score": record.r2_score or 0.0
},
"is_active": record.is_active,
"created_at": record.created_at,
"data_period_start": record.training_start_date,
"data_period_end": record.training_end_date
})
logger.info(f"Retrieved {len(models)} models for tenant {tenant_id}")
return models
except Exception as e:
logger.error(f"Failed to list models: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve models"
)
@router.delete("/models/tenant/{tenant_id}")
@require_user_role(['admin', 'owner'])
async def delete_tenant_models_complete(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""
Delete all trained models and artifacts for a tenant.
**WARNING: This operation is irreversible!**
This endpoint:
1. Cancels any active training jobs for the tenant
2. Deletes all model artifacts (files) from storage
3. Deletes model records from database
4. Deletes training logs and performance metrics
5. Publishes deletion event
Used by admin user deletion process to clean up all training data.
"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.training import (
ModelTrainingLog,
TrainedModel,
ModelArtifact,
ModelPerformanceMetric,
TrainingJobQueue
)
from app.core.config import settings
deletion_stats = {
"tenant_id": tenant_id,
"deleted_at": datetime.now(timezone.utc).isoformat(),
"jobs_cancelled": 0,
"models_deleted": 0,
"artifacts_deleted": 0,
"minio_objects_deleted": 0,
"training_logs_deleted": 0,
"performance_metrics_deleted": 0,
"errors": []
}
# Step 1: Cancel active training jobs
try:
active_jobs_query = select(TrainingJobQueue).where(
TrainingJobQueue.tenant_id == tenant_uuid,
TrainingJobQueue.status.in_(["queued", "running", "pending"])
)
active_jobs_result = await db.execute(active_jobs_query)
active_jobs = active_jobs_result.scalars().all()
for job in active_jobs:
job.status = "cancelled"
job.updated_at = datetime.now(timezone.utc)
deletion_stats["jobs_cancelled"] += 1
if active_jobs:
await db.commit()
logger.info("Cancelled active training jobs",
tenant_id=tenant_id,
count=len(active_jobs))
except Exception as e:
error_msg = f"Error cancelling training jobs: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 2: Delete model artifact files from MinIO storage
try:
from shared.clients.minio_client import minio_client
bucket_name = settings.MINIO_MODEL_BUCKET
prefix = f"models/{tenant_id}/"
# List all objects for this tenant
objects_to_delete = minio_client.list_objects(bucket_name, prefix=prefix)
files_deleted = 0
for obj_name in objects_to_delete:
try:
minio_client.delete_object(bucket_name, obj_name)
files_deleted += 1
logger.debug("Deleted MinIO object", object_name=obj_name)
except Exception as e:
error_msg = f"Error deleting MinIO object {obj_name}: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.warning(error_msg)
deletion_stats["minio_objects_deleted"] = files_deleted
logger.info("Deleted MinIO objects",
tenant_id=tenant_id,
files_deleted=files_deleted)
except Exception as e:
error_msg = f"Error processing MinIO objects: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 3: Delete database records
try:
# Delete model performance metrics
metrics_count_query = select(func.count(ModelPerformanceMetric.id)).where(
ModelPerformanceMetric.tenant_id == tenant_uuid
)
metrics_count_result = await db.execute(metrics_count_query)
metrics_count = metrics_count_result.scalar()
metrics_delete_query = delete(ModelPerformanceMetric).where(
ModelPerformanceMetric.tenant_id == tenant_uuid
)
await db.execute(metrics_delete_query)
deletion_stats["performance_metrics_deleted"] = metrics_count
# Delete model artifacts records
artifacts_count_query = select(func.count(ModelArtifact.id)).where(
ModelArtifact.tenant_id == tenant_uuid
)
artifacts_count_result = await db.execute(artifacts_count_query)
artifacts_count = artifacts_count_result.scalar()
artifacts_delete_query = delete(ModelArtifact).where(
ModelArtifact.tenant_id == tenant_uuid
)
await db.execute(artifacts_delete_query)
deletion_stats["artifacts_deleted"] = artifacts_count
# Delete trained models
models_count_query = select(func.count(TrainedModel.id)).where(
TrainedModel.tenant_id == tenant_uuid
)
models_count_result = await db.execute(models_count_query)
models_count = models_count_result.scalar()
models_delete_query = delete(TrainedModel).where(
TrainedModel.tenant_id == tenant_uuid
)
await db.execute(models_delete_query)
deletion_stats["models_deleted"] = models_count
# Delete training logs
logs_count_query = select(func.count(ModelTrainingLog.id)).where(
ModelTrainingLog.tenant_id == tenant_uuid
)
logs_count_result = await db.execute(logs_count_query)
logs_count = logs_count_result.scalar()
logs_delete_query = delete(ModelTrainingLog).where(
ModelTrainingLog.tenant_id == tenant_uuid
)
await db.execute(logs_delete_query)
deletion_stats["training_logs_deleted"] = logs_count
# Delete job queue entries
queue_delete_query = delete(TrainingJobQueue).where(
TrainingJobQueue.tenant_id == tenant_uuid
)
await db.execute(queue_delete_query)
await db.commit()
logger.info("Deleted training database records",
tenant_id=tenant_id,
models=models_count,
artifacts=artifacts_count,
logs=logs_count,
metrics=metrics_count)
except Exception as e:
await db.rollback()
error_msg = f"Error deleting database records: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=error_msg
)
# Step 4: Models deleted successfully (MinIO cleanup already done in Step 2)
return {
"success": True,
"message": f"All training data for tenant {tenant_id} deleted successfully",
"deletion_details": deletion_stats
}
except HTTPException:
raise
except Exception as e:
logger.error("Unexpected error deleting tenant models",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete tenant models: {str(e)}"
)
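
A sketch of how the forecasting service might consume the read endpoints above; the URLs are assumptions inferred from the route names, and error handling is reduced to the not-found case.

from typing import Optional
import httpx

# Assumed route shapes; the real prefixes come from RouteBuilder('training').
ACTIVE_MODEL_URL = "http://training-service:8000/api/v1/training/tenants/{tenant_id}/models/{product_id}/active"
MODEL_METRICS_URL = "http://training-service:8000/api/v1/training/tenants/{tenant_id}/models/{model_id}/metrics"

async def load_active_model(tenant_id: str, product_id: str) -> Optional[dict]:
    """Return the active model with its metrics attached, or None if nothing is deployed."""
    async with httpx.AsyncClient(timeout=10.0) as client:
        resp = await client.get(
            ACTIVE_MODEL_URL.format(tenant_id=tenant_id, product_id=product_id)
        )
        if resp.status_code == 404:
            return None
        resp.raise_for_status()
        model = resp.json()
        metrics_resp = await client.get(
            MODEL_METRICS_URL.format(tenant_id=tenant_id, model_id=model["model_id"])
        )
        metrics_resp.raise_for_status()
        model["metrics"] = metrics_resp.json()
        return model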

View File

@@ -0,0 +1,410 @@
"""
Monitoring and Observability Endpoints
Real-time service monitoring and diagnostics
"""
from fastapi import APIRouter, Query
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone, timedelta
from sqlalchemy import text, func
import logging
from app.core.database import database_manager
from app.utils.circuit_breaker import circuit_breaker_registry
from app.models.training import ModelTrainingLog, TrainingJobQueue, TrainedModel
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/monitoring/circuit-breakers")
async def get_circuit_breaker_status() -> Dict[str, Any]:
"""
Get status of all circuit breakers.
Useful for monitoring external service health.
"""
breakers = circuit_breaker_registry.get_all_states()
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"circuit_breakers": breakers,
"summary": {
"total": len(breakers),
"open": sum(1 for b in breakers.values() if b["state"] == "open"),
"half_open": sum(1 for b in breakers.values() if b["state"] == "half_open"),
"closed": sum(1 for b in breakers.values() if b["state"] == "closed")
}
}
@router.post("/monitoring/circuit-breakers/{name}/reset")
async def reset_circuit_breaker(name: str) -> Dict[str, str]:
"""
Manually reset a circuit breaker.
Use with caution - only reset if you know the service has recovered.
"""
circuit_breaker_registry.reset(name)
return {
"status": "success",
"message": f"Circuit breaker '{name}' has been reset",
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/monitoring/training-jobs")
async def get_training_job_stats(
hours: int = Query(default=24, ge=1, le=168, description="Look back period in hours")
) -> Dict[str, Any]:
"""
Get training job statistics for the specified period.
"""
try:
since = datetime.now(timezone.utc) - timedelta(hours=hours)
async with database_manager.get_session() as session:
# Get job counts by status
result = await session.execute(
text("""
SELECT status, COUNT(*) as count
FROM model_training_logs
WHERE created_at >= :since
GROUP BY status
"""),
{"since": since}
)
status_counts = dict(result.fetchall())
# Get average training time for completed jobs
result = await session.execute(
text("""
SELECT AVG(EXTRACT(EPOCH FROM (end_time - start_time))) as avg_duration
FROM model_training_logs
WHERE status = 'completed'
AND created_at >= :since
AND end_time IS NOT NULL
"""),
{"since": since}
)
avg_duration = result.scalar()
# Get failure rate
total = sum(status_counts.values())
failed = status_counts.get('failed', 0)
failure_rate = (failed / total * 100) if total > 0 else 0
# Get recent jobs
result = await session.execute(
text("""
SELECT job_id, tenant_id, status, progress, start_time, end_time
FROM model_training_logs
WHERE created_at >= :since
ORDER BY created_at DESC
LIMIT 10
"""),
{"since": since}
)
recent_jobs = [
{
"job_id": row.job_id,
"tenant_id": str(row.tenant_id),
"status": row.status,
"progress": row.progress,
"start_time": row.start_time.isoformat() if row.start_time else None,
"end_time": row.end_time.isoformat() if row.end_time else None
}
for row in result.fetchall()
]
return {
"period_hours": hours,
"timestamp": datetime.now(timezone.utc).isoformat(),
"summary": {
"total_jobs": total,
"by_status": status_counts,
"failure_rate_percent": round(failure_rate, 2),
"avg_duration_seconds": round(avg_duration, 2) if avg_duration else None
},
"recent_jobs": recent_jobs
}
except Exception as e:
logger.error(f"Failed to get training job stats: {e}")
return {
"error": str(e),
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/monitoring/models")
async def get_model_stats() -> Dict[str, Any]:
"""
Get statistics about trained models.
"""
try:
async with database_manager.get_session() as session:
# Total models
result = await session.execute(
text("SELECT COUNT(*) FROM trained_models")
)
total_models = result.scalar()
# Active models
result = await session.execute(
text("SELECT COUNT(*) FROM trained_models WHERE is_active = true")
)
active_models = result.scalar()
# Production models
result = await session.execute(
text("SELECT COUNT(*) FROM trained_models WHERE is_production = true")
)
production_models = result.scalar()
# Models by type
result = await session.execute(
text("""
SELECT model_type, COUNT(*) as count
FROM trained_models
GROUP BY model_type
""")
)
models_by_type = dict(result.fetchall())
# Average model performance (MAPE)
result = await session.execute(
text("""
SELECT AVG(mape) as avg_mape
FROM trained_models
WHERE mape IS NOT NULL
AND is_active = true
""")
)
avg_mape = result.scalar()
# Models created today
today = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
result = await session.execute(
text("""
SELECT COUNT(*) FROM trained_models
WHERE created_at >= :today
"""),
{"today": today}
)
models_today = result.scalar()
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"summary": {
"total_models": total_models,
"active_models": active_models,
"production_models": production_models,
"models_created_today": models_today,
"average_mape_percent": round(avg_mape, 2) if avg_mape else None
},
"by_type": models_by_type
}
except Exception as e:
logger.error(f"Failed to get model stats: {e}")
return {
"error": str(e),
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/monitoring/queue")
async def get_queue_status() -> Dict[str, Any]:
"""
Get training job queue status.
"""
try:
async with database_manager.get_session() as session:
# Queued jobs
result = await session.execute(
text("""
SELECT COUNT(*) FROM training_job_queue
WHERE status = 'queued'
""")
)
queued = result.scalar()
# Running jobs
result = await session.execute(
text("""
SELECT COUNT(*) FROM training_job_queue
WHERE status = 'running'
""")
)
running = result.scalar()
# Get oldest queued job
result = await session.execute(
text("""
SELECT created_at FROM training_job_queue
WHERE status = 'queued'
ORDER BY created_at ASC
LIMIT 1
""")
)
oldest_queued = result.scalar()
# Calculate wait time
if oldest_queued:
wait_time_seconds = (datetime.now(timezone.utc) - oldest_queued).total_seconds()
else:
wait_time_seconds = 0
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"queue": {
"queued": queued,
"running": running,
"oldest_wait_time_seconds": round(wait_time_seconds, 2) if oldest_queued else 0,
"oldest_queued_at": oldest_queued.isoformat() if oldest_queued else None
}
}
except Exception as e:
logger.error(f"Failed to get queue status: {e}")
return {
"error": str(e),
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/monitoring/performance")
async def get_performance_metrics(
tenant_id: Optional[str] = Query(None, description="Filter by tenant ID")
) -> Dict[str, Any]:
"""
Get model performance metrics.
"""
try:
async with database_manager.get_session() as session:
query_params = {}
where_clause = ""
if tenant_id:
where_clause = "WHERE tenant_id = :tenant_id"
query_params["tenant_id"] = tenant_id
# Get performance distribution
result = await session.execute(
text(f"""
SELECT
COUNT(*) as total,
AVG(mape) as avg_mape,
MIN(mape) as min_mape,
MAX(mape) as max_mape,
PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY mape) as median_mape,
AVG(mae) as avg_mae,
AVG(rmse) as avg_rmse
FROM model_performance_metrics
{where_clause}
"""),
query_params
)
stats = result.fetchone()
# Get accuracy distribution (buckets)
result = await session.execute(
text(f"""
SELECT
CASE
WHEN mape <= 10 THEN 'excellent'
WHEN mape <= 20 THEN 'good'
WHEN mape <= 30 THEN 'acceptable'
ELSE 'poor'
END as accuracy_category,
COUNT(*) as count
FROM model_performance_metrics
{where_clause}
GROUP BY accuracy_category
"""),
query_params
)
distribution = dict(result.fetchall())
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"tenant_id": tenant_id,
"statistics": {
"total_metrics": stats.total if stats else 0,
"avg_mape_percent": round(stats.avg_mape, 2) if stats and stats.avg_mape else None,
"min_mape_percent": round(stats.min_mape, 2) if stats and stats.min_mape else None,
"max_mape_percent": round(stats.max_mape, 2) if stats and stats.max_mape else None,
"median_mape_percent": round(stats.median_mape, 2) if stats and stats.median_mape else None,
"avg_mae": round(stats.avg_mae, 2) if stats and stats.avg_mae else None,
"avg_rmse": round(stats.avg_rmse, 2) if stats and stats.avg_rmse else None
},
"distribution": distribution
}
except Exception as e:
logger.error(f"Failed to get performance metrics: {e}")
return {
"error": str(e),
"timestamp": datetime.now(timezone.utc).isoformat()
}
@router.get("/monitoring/alerts")
async def get_alerts() -> Dict[str, Any]:
"""
Get active alerts and warnings based on system state.
"""
alerts = []
warnings = []
try:
# Check circuit breakers
breakers = circuit_breaker_registry.get_all_states()
for name, state in breakers.items():
if state["state"] == "open":
alerts.append({
"type": "circuit_breaker_open",
"severity": "high",
"message": f"Circuit breaker '{name}' is OPEN - service unavailable",
"details": state
})
elif state["state"] == "half_open":
warnings.append({
"type": "circuit_breaker_recovering",
"severity": "medium",
"message": f"Circuit breaker '{name}' is recovering",
"details": state
})
# Check queue backlog
async with database_manager.get_session() as session:
result = await session.execute(
text("SELECT COUNT(*) FROM training_job_queue WHERE status = 'queued'")
)
queued = result.scalar()
if queued > 10:
warnings.append({
"type": "queue_backlog",
"severity": "medium",
"message": f"Training queue has {queued} pending jobs",
"count": queued
})
except Exception as e:
logger.error(f"Failed to generate alerts: {e}")
alerts.append({
"type": "monitoring_error",
"severity": "high",
"message": f"Failed to check system alerts: {str(e)}"
})
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"summary": {
"total_alerts": len(alerts),
"total_warnings": len(warnings)
},
"alerts": alerts,
"warnings": warnings
}
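
As a rough sketch, the alert feed above could drive a simple notifier like the one below; the URL and the notify() hook are placeholders.

import time
import httpx

ALERTS_URL = "http://localhost:8000/monitoring/alerts"  # assumed mount point

def notify(message: str) -> None:
    print(f"[ALERT] {message}")  # stand-in for a real pager or chat webhook

def poll_alerts(interval_seconds: int = 60) -> None:
    """Poll the alerts endpoint and surface every active alert message."""
    while True:
        try:
            payload = httpx.get(ALERTS_URL, timeout=5.0).json()
            for alert in payload.get("alerts", []):
                notify(alert["message"])
        except httpx.HTTPError as exc:
            notify(f"monitoring endpoint unreachable: {exc}")
        time.sleep(interval_seconds)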

View File

@@ -0,0 +1,123 @@
"""
Training Jobs API - ATOMIC CRUD operations
Handles basic training job creation and retrieval
"""
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query, Request
from typing import List, Optional
import structlog
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from datetime import datetime
import uuid
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import TrainingJobResponse
from shared.database.base import create_database_manager
from app.core.config import settings
logger = structlog.get_logger()
route_builder = RouteBuilder('training')
router = APIRouter(tags=["training-jobs"])
def get_enhanced_training_service():
"""Dependency injection for EnhancedTrainingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
return EnhancedTrainingService(database_manager)
@router.get(
route_builder.build_nested_resource_route("jobs", "job_id", "status")
)
@track_execution_time("enhanced_job_status_duration_seconds", "training-service")
async def get_training_job_status(
tenant_id: str = Path(..., description="Tenant ID"),
job_id: str = Path(..., description="Job ID"),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Get training job status using repository pattern.
"""
metrics = get_metrics_collector(request_obj)
try:
# Get status using enhanced service
status_info = await enhanced_training_service.get_training_status(job_id)
if not status_info or status_info.get("error"):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Training job not found"
)
if metrics:
metrics.increment_counter("enhanced_status_requests_total")
return {
**status_info,
"enhanced_features": True,
"repository_integration": True
}
except HTTPException:
raise
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_status_errors_total")
logger.error("Failed to get training status",
job_id=job_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get training status"
)
@router.get(
route_builder.build_base_route("statistics")
)
@track_execution_time("enhanced_tenant_statistics_duration_seconds", "training-service")
async def get_tenant_statistics(
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Get comprehensive tenant statistics using repository pattern.
"""
metrics = get_metrics_collector(request_obj)
try:
# Get statistics using enhanced service
statistics = await enhanced_training_service.get_tenant_statistics(tenant_id)
if statistics.get("error"):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=statistics["error"]
)
if metrics:
metrics.increment_counter("enhanced_statistics_requests_total")
return {
**statistics,
"enhanced_features": True,
"repository_integration": True
}
except HTTPException:
raise
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_statistics_errors_total")
logger.error("Failed to get tenant statistics",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get tenant statistics"
)
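
For illustration, a caller could poll the status route above until a job settles; the URL shape, auth header, and the set of terminal states are assumptions inferred from the handlers in this service.

import time
import httpx

STATUS_URL = "http://localhost:8000/api/v1/training/tenants/{tenant_id}/jobs/{job_id}/status"
TERMINAL_STATES = {"completed", "failed", "cancelled"}

def wait_for_job(tenant_id: str, job_id: str, token: str, poll_seconds: float = 10.0) -> dict:
    """Block until the training job reaches a terminal state, then return its status payload."""
    headers = {"Authorization": f"Bearer {token}"}
    while True:
        resp = httpx.get(
            STATUS_URL.format(tenant_id=tenant_id, job_id=job_id), headers=headers
        )
        resp.raise_for_status()
        info = resp.json()
        if info.get("status") in TERMINAL_STATES:
            return info
        time.sleep(poll_seconds)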

View File

@@ -0,0 +1,821 @@
"""
Training Operations API - BUSINESS logic
Handles training job execution and metrics
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path
from typing import Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid
import shared.redis_utils
from sqlalchemy.ext.asyncio import AsyncSession
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role, admin_role_required, service_only_access
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
from shared.subscription.plans import (
get_training_job_quota,
get_dataset_size_limit
)
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
TrainingJobRequest,
SingleProductTrainingRequest,
TrainingJobResponse
)
from app.utils.time_estimation import (
calculate_initial_estimate,
calculate_estimated_completion_time,
get_historical_average_estimate
)
from app.services.training_events import (
publish_training_started,
publish_training_completed,
publish_training_failed
)
from app.core.config import settings
from app.core.database import get_db
from app.models import AuditLog
logger = structlog.get_logger()
route_builder = RouteBuilder('training')
router = APIRouter(tags=["training-operations"])
# Initialize audit logger
audit_logger = create_audit_logger("training-service", AuditLog)
# Redis client for rate limiting
_redis_client = None
async def get_training_redis_client():
"""Get or create Redis client for rate limiting"""
global _redis_client
if _redis_client is None:
# Initialize Redis if not already done
try:
_redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
except Exception:
# Fall back to a client that was already initialized elsewhere
_redis_client = await shared.redis_utils.get_redis_client()
return _redis_client
async def get_rate_limiter():
"""Dependency for rate limiter"""
redis_client = await get_training_redis_client()
return create_rate_limiter(redis_client)
def get_enhanced_training_service():
"""Dependency injection for EnhancedTrainingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
return EnhancedTrainingService(database_manager)
@router.post(
route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
request: TrainingJobRequest,
tenant_id: str = Path(..., description="Tenant ID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
request_obj: Request = None,
current_user: Dict[str, Any] = Depends(get_current_user_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service),
rate_limiter = Depends(get_rate_limiter),
db: AsyncSession = Depends(get_db)
):
"""
Start a new training job for all tenant products (Admin+ only, quota enforced).
**RBAC:** Admin or Owner role required
**Quotas:**
- Starter: 1 training job/day, max 1,000 rows
- Professional: 5 training jobs/day, max 10,000 rows
- Enterprise: Unlimited jobs, unlimited rows
Enhanced immediate response pattern:
1. Validate subscription tier and quotas
2. Validate request with enhanced validation
3. Create job record using repository pattern
4. Return 200 with enhanced job details
5. Execute enhanced training in background with repository tracking
Enhanced features:
- Repository pattern for data access
- Quota enforcement by subscription tier
- Audit logging for all operations
- Enhanced error handling and logging
- Metrics tracking and monitoring
- Transactional operations
"""
metrics = get_metrics_collector(request_obj)
# Get subscription tier and enforce quotas
tier = current_user.get('subscription_tier', 'starter')
# Estimate dataset size (this should come from the request or be calculated)
# For now, we'll assume a reasonable estimate
estimated_dataset_size = request.estimated_rows if hasattr(request, 'estimated_rows') else 500
# Initialize variables for later use
quota_result = None
quota_limit = None
try:
# Validate dataset size limits
await rate_limiter.validate_dataset_size(
tenant_id, estimated_dataset_size, tier
)
# Check daily training job quota
quota_limit = get_training_job_quota(tier)
quota_result = await rate_limiter.check_and_increment_quota(
tenant_id,
"training_jobs",
quota_limit,
period=86400 # 24 hours
)
logger.info("Training job quota check passed",
tenant_id=tenant_id,
tier=tier,
current_usage=quota_result.get('current', 0) if quota_result else 0,
limit=quota_limit)
except HTTPException:
# Quota or validation error - re-raise
raise
except Exception as quota_error:
logger.error("Quota validation failed", error=str(quota_error))
# Continue with job creation but log the error
try:
# CRITICAL FIX: Check for existing running jobs before starting new one
# This prevents duplicate tenant-level training jobs
async with enhanced_training_service.database_manager.get_session() as check_session:
from app.repositories.training_log_repository import TrainingLogRepository
log_repo = TrainingLogRepository(check_session)
# Check for active jobs (running or pending)
active_jobs = await log_repo.get_active_jobs(tenant_id=tenant_id)
pending_jobs = await log_repo.get_logs_by_tenant(
tenant_id=tenant_id,
status="pending",
limit=10
)
all_active = active_jobs + pending_jobs
if all_active:
# Training job already in progress, return existing job info
existing_job = all_active[0]
logger.info("Training job already in progress, returning existing job",
existing_job_id=existing_job.job_id,
tenant_id=tenant_id,
status=existing_job.status)
return TrainingJobResponse(
job_id=existing_job.job_id,
tenant_id=tenant_id,
status=existing_job.status,
message=f"Training job already in progress (started {existing_job.created_at.isoformat() if existing_job.created_at else 'recently'})",
created_at=existing_job.created_at or datetime.now(timezone.utc),
estimated_duration_minutes=existing_job.config.get("estimated_duration_minutes", 15) if existing_job.config else 15,
training_results={
"total_products": 0,
"successful_trainings": 0,
"failed_trainings": 0,
"products": [],
"overall_training_time_seconds": 0.0
},
data_summary=None,
completed_at=None,
error_details=None,
processing_metadata={
"background_task": True,
"async_execution": True,
"existing_job": True,
"deduplication": True
}
)
# No existing job, proceed with creating new one
# Generate enhanced job ID
job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info("Creating enhanced training job using repository pattern",
job_id=job_id,
tenant_id=tenant_id)
# Record job creation metrics
if metrics:
metrics.increment_counter("enhanced_training_jobs_created_total")
# Calculate intelligent time estimate
# We don't know exact product count yet, so use historical average or estimate
try:
# Try to get historical average for this tenant
historical_avg = await get_historical_average_estimate(db, tenant_id)
# If no historical data, estimate based on typical product count (10-20 products)
estimated_products = 15 # Conservative estimate
estimated_duration_minutes = calculate_initial_estimate(
total_products=estimated_products,
avg_training_time_per_product=historical_avg if historical_avg else 60.0
)
except Exception as est_error:
logger.warning("Could not calculate intelligent estimate, using default",
error=str(est_error))
estimated_duration_minutes = 15 # Default fallback
# Calculate estimated completion time
estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
# Note: training.started event will be published by the trainer with accurate product count
# We don't publish here to avoid duplicate events
# Add enhanced background task
background_tasks.add_task(
execute_training_job_background,
tenant_id=tenant_id,
job_id=job_id,
bakery_location=(40.4168, -3.7038),
requested_start=request.start_date,
requested_end=request.end_date,
estimated_duration_minutes=estimated_duration_minutes
)
# Return enhanced immediate success response
response_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "pending",
"message": "Enhanced training job started successfully using repository pattern",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": estimated_duration_minutes,
"training_results": {
"total_products": 0,
"successful_trainings": 0,
"failed_trainings": 0,
"products": [],
"overall_training_time_seconds": 0.0
},
"data_summary": None,
"completed_at": None,
"error_details": None,
"processing_metadata": {
"background_task": True,
"async_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"dependency_injection": True
}
}
logger.info("Enhanced training job queued successfully",
job_id=job_id,
features=["repository-pattern", "dependency-injection", "enhanced-tracking"])
# Log audit event for training job creation
try:
from app.core.database import database_manager
async with database_manager.get_session() as db:
await audit_logger.log_event(
db_session=db,
tenant_id=tenant_id,
user_id=current_user["user_id"],
action=AuditAction.CREATE.value,
resource_type="training_job",
resource_id=job_id,
severity=AuditSeverity.MEDIUM.value,
description=f"Started training job (tier: {tier})",
audit_metadata={
"job_id": job_id,
"tier": tier,
"estimated_dataset_size": estimated_dataset_size,
"quota_usage": quota_result.get('current', 0) if quota_result else 0,
"quota_limit": quota_limit if quota_limit else "unlimited"
},
endpoint="/jobs",
method="POST"
)
except Exception as audit_error:
logger.warning("Failed to log audit event", error=str(audit_error))
return TrainingJobResponse(**response_data)
except HTTPException:
raise
except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_training_validation_errors_total")
logger.error("Enhanced training job validation error",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_training_job_errors_total")
logger.error("Failed to queue enhanced training job",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to start enhanced training job"
)
async def execute_training_job_background(
tenant_id: str,
job_id: str,
bakery_location: tuple,
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None,
estimated_duration_minutes: int = 15
):
"""
Enhanced background task that executes the training job using repository pattern.
Enhanced features:
- Repository pattern for all data operations
- Enhanced error handling with structured logging
- Transactional operations for data consistency
- Comprehensive metrics tracking
- Database connection pooling
- Enhanced progress reporting
"""
logger.info("Enhanced background training job started",
job_id=job_id,
tenant_id=tenant_id,
features=["repository-pattern", "enhanced-tracking"])
# Get enhanced training service with dependency injection
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
enhanced_training_service = EnhancedTrainingService(database_manager)
try:
# Create initial training log entry first
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="pending",
progress=0,
current_step="Starting enhanced training job",
tenant_id=tenant_id
)
# This will be published by the training service itself
# when it starts execution
training_config = {
"job_id": job_id,
"tenant_id": tenant_id,
"bakery_location": {
"latitude": bakery_location[0],
"longitude": bakery_location[1]
},
"requested_start": requested_start.isoformat() if requested_start else None,
"requested_end": requested_end.isoformat() if requested_end else None,
"estimated_duration_minutes": estimated_duration_minutes,
"background_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"api_version": "enhanced_v1"
}
# Update job status using repository pattern
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Initializing enhanced training pipeline",
tenant_id=tenant_id
)
# Execute the enhanced training pipeline with repository pattern
result = await enhanced_training_service.start_training_job(
tenant_id=tenant_id,
job_id=job_id,
bakery_location=bakery_location,
requested_start=requested_start,
requested_end=requested_end
)
# Note: Final status is already updated by start_training_job() via complete_training_log()
# No need for redundant update here - it was causing duplicate log entries
# Completion event is published by the training service
logger.info("Enhanced background training job completed successfully",
job_id=job_id,
models_created=result.get('products_trained', 0),
features=["repository-pattern", "enhanced-tracking"])
except Exception as training_error:
logger.error("Enhanced training pipeline failed",
job_id=job_id,
error=str(training_error))
try:
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Enhanced training failed",
error_message=str(training_error),
tenant_id=tenant_id
)
except Exception as status_error:
logger.error("Failed to update job status after training error",
job_id=job_id,
status_error=str(status_error))
# Failure event is published by the training service
await publish_training_failed(job_id, tenant_id, str(training_error))
finally:
logger.info("Enhanced background training job cleanup completed",
job_id=job_id)
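
A rough client-side sketch of the submit-and-deduplicate flow implemented above; the URL shape, payload fields, and handling of quota failures as generic 4xx responses are assumptions.

import httpx

JOBS_URL = "http://localhost:8000/api/v1/training/tenants/{tenant_id}/jobs"  # assumed route shape

def submit_training_job(tenant_id: str, token: str, start_date: str, end_date: str) -> dict:
    """Submit a tenant-wide training job; an in-flight job is returned instead of a duplicate."""
    resp = httpx.post(
        JOBS_URL.format(tenant_id=tenant_id),
        headers={"Authorization": f"Bearer {token}"},
        json={"start_date": start_date, "end_date": end_date},
    )
    resp.raise_for_status()  # quota and validation problems surface as 4xx errors
    job = resp.json()
    if job.get("processing_metadata", {}).get("existing_job"):
        print(f"Reusing in-flight job {job['job_id']}")
    return job
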
@router.post(
route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
request: SingleProductTrainingRequest,
tenant_id: str = Path(..., description="Tenant ID"),
inventory_product_id: str = Path(..., description="Inventory product UUID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
request_obj: Request = None,
current_user: Dict[str, Any] = Depends(get_current_user_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start enhanced training for a single product (Admin+ only).
**RBAC:** Admin or Owner role required
Enhanced features:
- Repository pattern for data access
- Enhanced error handling and validation
- Metrics tracking
- Transactional operations
- Background execution to prevent blocking
"""
metrics = get_metrics_collector(request_obj)
try:
logger.info("Starting enhanced single product training",
inventory_product_id=inventory_product_id,
tenant_id=tenant_id)
# CRITICAL FIX: Check if this product is currently being trained
# This prevents duplicate training from rapid-click scenarios
async with enhanced_training_service.database_manager.get_session() as check_session:
from app.repositories.training_log_repository import TrainingLogRepository
log_repo = TrainingLogRepository(check_session)
# Check for active jobs for this specific product
active_jobs = await log_repo.get_active_jobs(tenant_id=tenant_id)
pending_jobs = await log_repo.get_logs_by_tenant(
tenant_id=tenant_id,
status="pending",
limit=20
)
all_active = active_jobs + pending_jobs
# Filter for jobs that include this specific product
product_jobs = [
job for job in all_active
if job.config and (
# Single product job for this product
job.config.get("product_id") == inventory_product_id or
# Tenant-wide job that would include this product
job.config.get("job_type") == "tenant_training"
)
]
if product_jobs:
existing_job = product_jobs[0]
logger.warning("Product training already in progress, rejecting duplicate request",
existing_job_id=existing_job.job_id,
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
status=existing_job.status)
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"error": "Product training already in progress",
"message": f"Product {inventory_product_id} is currently being trained in job {existing_job.job_id}",
"existing_job_id": existing_job.job_id,
"status": existing_job.status,
"started_at": existing_job.created_at.isoformat() if existing_job.created_at else None
}
)
# No existing job, proceed with training
# Record metrics
if metrics:
metrics.increment_counter("enhanced_single_product_training_total")
# Generate enhanced job ID
job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"
# CRITICAL FIX: Add initial training log entry
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="pending",
progress=0,
current_step="Initializing single product training",
tenant_id=tenant_id
)
# Add enhanced background task for single product training
background_tasks.add_task(
execute_single_product_training_background,
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
job_id=job_id,
bakery_location=request.bakery_location or (40.4168, -3.7038),
database_manager=enhanced_training_service.database_manager
)
# Return immediate response with job info
response_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "pending",
"message": "Enhanced single product training started successfully",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": 15, # Default estimate for single product
"training_results": {
"total_products": 1,
"successful_trainings": 0,
"failed_trainings": 0,
"products": [],
"overall_training_time_seconds": 0.0
},
"data_summary": None,
"completed_at": None,
"error_details": None,
"processing_metadata": {
"background_task": True,
"async_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"dependency_injection": True
}
}
logger.info("Enhanced single product training queued successfully",
inventory_product_id=inventory_product_id,
job_id=job_id)
if metrics:
metrics.increment_counter("enhanced_single_product_training_queued_total")
return TrainingJobResponse(**response_data)
except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_single_product_validation_errors_total")
logger.error("Enhanced single product training validation error",
error=str(e),
inventory_product_id=inventory_product_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_single_product_training_errors_total")
logger.error("Enhanced single product training failed",
error=str(e),
inventory_product_id=inventory_product_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Enhanced single product training failed"
)
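# ---------------------------------------------------------------------------
# Illustrative client sketch (not part of this service): how a caller might
# start single-product training and handle the 409 returned when a job for the
# same product is already active. The URL path, base_url, header names and the
# payload shape are assumptions for illustration only; the real path is
# produced by route_builder.build_resource_detail_route("products", "inventory_product_id")
# and the payload mirrors SingleProductTrainingRequest.
# ---------------------------------------------------------------------------
async def _example_start_single_product_training(
    base_url: str, token: str, tenant_id: str, inventory_product_id: str
) -> dict:
    import httpx  # example-only dependency

    # Assumed path shape; consult the route builder output for the real one.
    url = f"{base_url}/api/v1/tenants/{tenant_id}/training/products/{inventory_product_id}"
    payload = {"bakery_location": [40.4168, -3.7038]}  # optional; service falls back to Madrid
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(
            url, json=payload, headers={"Authorization": f"Bearer {token}"}
        )
        if response.status_code == 409:
            # A job already covers this product; reuse the existing job id from the detail.
            return response.json()["detail"]
        response.raise_for_status()
        return response.json()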
async def execute_single_product_training_background(
tenant_id: str,
inventory_product_id: str,
job_id: str,
bakery_location: tuple,
database_manager
):
"""
    Enhanced background task that executes single-product training using the repository pattern.
Uses a separate service instance to avoid session conflicts.
"""
logger.info("Enhanced background single product training started",
job_id=job_id,
tenant_id=tenant_id,
inventory_product_id=inventory_product_id)
# Create a new service instance with a fresh database session to avoid conflicts
from app.services.training_service import EnhancedTrainingService
fresh_training_service = EnhancedTrainingService(database_manager)
try:
# Update job status to running
await fresh_training_service._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Starting single product training",
tenant_id=tenant_id
)
# Execute the enhanced single product training with repository pattern
result = await fresh_training_service.start_single_product_training(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
job_id=job_id,
bakery_location=bakery_location
)
logger.info("Enhanced background single product training completed successfully",
job_id=job_id,
inventory_product_id=inventory_product_id)
except Exception as training_error:
logger.error("Enhanced single product training failed",
job_id=job_id,
inventory_product_id=inventory_product_id,
error=str(training_error))
try:
await fresh_training_service._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Single product training failed",
error_message=str(training_error),
tenant_id=tenant_id
)
except Exception as status_error:
logger.error("Failed to update job status after training error",
job_id=job_id,
status_error=str(status_error))
finally:
logger.info("Enhanced background single product training cleanup completed",
job_id=job_id,
inventory_product_id=inventory_product_id)
@router.get("/health")
async def health_check():
"""Health check endpoint for the training operations"""
return {
"status": "healthy",
"service": "training-operations",
"version": "3.0.0",
"features": [
"repository-pattern",
"dependency-injection",
"enhanced-error-handling",
"metrics-tracking",
"transactional-operations"
],
"timestamp": datetime.now().isoformat()
}
# ============================================================================
# Tenant Data Deletion Operations (Internal Service Only)
# ============================================================================
@router.delete(
route_builder.build_base_route("tenant/{tenant_id}", include_tenant_prefix=False),
response_model=dict
)
@service_only_access
async def delete_tenant_data(
tenant_id: str = Path(..., description="Tenant ID to delete data for"),
current_user: dict = Depends(get_current_user_dep)
):
"""
Delete all training data for a tenant (Internal service only)
This endpoint is called by the orchestrator during tenant deletion.
It permanently deletes all training-related data including:
- Trained models (all versions)
- Model artifacts (files and metadata)
- Training logs and job history
- Model performance metrics
- Training job queue entries
- Audit logs
**WARNING**: This operation is irreversible!
**NOTE**: Physical model files (.pkl) should be cleaned up separately
Returns:
Deletion summary with counts of deleted records
"""
from app.services.tenant_deletion_service import TrainingTenantDeletionService
from app.core.config import settings
try:
logger.info("training.tenant_deletion.api_called", tenant_id=tenant_id)
db_manager = create_database_manager(settings.DATABASE_URL, "training")
async with db_manager.get_session() as session:
deletion_service = TrainingTenantDeletionService(session)
result = await deletion_service.safe_delete_tenant_data(tenant_id)
if not result.success:
raise HTTPException(
status_code=500,
detail=f"Tenant data deletion failed: {', '.join(result.errors)}"
)
return {
"message": "Tenant data deletion completed successfully",
"note": "Physical model files should be cleaned up separately from storage",
"summary": result.to_dict()
}
except HTTPException:
raise
except Exception as e:
logger.error("training.tenant_deletion.api_error",
tenant_id=tenant_id,
error=str(e),
exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Failed to delete tenant data: {str(e)}"
)
@router.get(
route_builder.build_base_route("tenant/{tenant_id}/deletion-preview", include_tenant_prefix=False),
response_model=dict
)
@service_only_access
async def preview_tenant_data_deletion(
tenant_id: str = Path(..., description="Tenant ID to preview deletion for"),
current_user: dict = Depends(get_current_user_dep)
):
"""
Preview what data would be deleted for a tenant (dry-run)
This endpoint shows counts of all data that would be deleted
without actually deleting anything. Useful for:
- Confirming deletion scope before execution
- Auditing and compliance
- Troubleshooting
Returns:
Dictionary with entity names and their counts
"""
from app.services.tenant_deletion_service import TrainingTenantDeletionService
from app.core.config import settings
try:
logger.info("training.tenant_deletion.preview_called", tenant_id=tenant_id)
db_manager = create_database_manager(settings.DATABASE_URL, "training")
async with db_manager.get_session() as session:
deletion_service = TrainingTenantDeletionService(session)
preview = await deletion_service.get_tenant_data_preview(tenant_id)
total_records = sum(preview.values())
return {
"tenant_id": tenant_id,
"service": "training",
"preview": preview,
"total_records": total_records,
"note": "Physical model files (.pkl, metadata) are not counted here",
"warning": "These records will be permanently deleted and cannot be recovered"
}
except Exception as e:
logger.error("training.tenant_deletion.preview_error",
tenant_id=tenant_id,
error=str(e),
exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Failed to preview tenant data deletion: {str(e)}"
)
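# ---------------------------------------------------------------------------
# Illustrative orchestrator-side sketch (assumption, not part of this service):
# preview the deletion scope first, then perform the irreversible delete. The
# base_url, path shape and service-credential header are assumptions; the real
# paths come from route_builder.build_base_route(..., include_tenant_prefix=False).
# ---------------------------------------------------------------------------
async def _example_delete_tenant_training_data(
    base_url: str, service_token: str, tenant_id: str
) -> dict:
    import httpx  # example-only dependency

    headers = {"Authorization": f"Bearer {service_token}"}  # assumed service credential
    async with httpx.AsyncClient(timeout=60.0) as client:
        preview = await client.get(
            f"{base_url}/tenant/{tenant_id}/deletion-preview", headers=headers
        )
        preview.raise_for_status()
        if preview.json().get("total_records", 0) == 0:
            return {"skipped": True, "reason": "no training data for tenant"}

        deletion = await client.delete(f"{base_url}/tenant/{tenant_id}", headers=headers)
        deletion.raise_for_status()
        return deletion.json()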

View File

@@ -0,0 +1,163 @@
"""
WebSocket Operations for Training Service
Simple WebSocket endpoint that connects clients and receives broadcasts from RabbitMQ
"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query
import structlog
from app.websocket.manager import websocket_manager
from shared.auth.jwt_handler import JWTHandler
from app.core.config import settings
from app.services.training_service import EnhancedTrainingService
from shared.database.base import create_database_manager
logger = structlog.get_logger()
router = APIRouter(tags=["websocket"])
def get_enhanced_training_service():
"""Create EnhancedTrainingService instance"""
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
return EnhancedTrainingService(database_manager)
@router.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
websocket: WebSocket,
tenant_id: str = Path(..., description="Tenant ID"),
job_id: str = Path(..., description="Job ID"),
token: str = Query(..., description="Authentication token")
):
"""
WebSocket endpoint for real-time training progress updates.
This endpoint:
1. Validates the authentication token
2. Accepts the WebSocket connection
3. Keeps the connection alive
4. Receives broadcasts from RabbitMQ (via WebSocket manager)
"""
# Validate token
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload:
await websocket.close(code=1008, reason="Invalid token")
logger.warning("WebSocket connection rejected - invalid token",
job_id=job_id,
tenant_id=tenant_id)
return
user_id = payload.get('user_id')
if not user_id:
await websocket.close(code=1008, reason="Invalid token payload")
logger.warning("WebSocket connection rejected - no user_id in token",
job_id=job_id,
tenant_id=tenant_id)
return
logger.info("WebSocket authentication successful",
user_id=user_id,
tenant_id=tenant_id,
job_id=job_id)
except Exception as e:
await websocket.close(code=1008, reason="Authentication failed")
logger.warning("WebSocket authentication failed",
job_id=job_id,
tenant_id=tenant_id,
error=str(e))
return
# Connect to WebSocket manager
await websocket_manager.connect(job_id, websocket)
# Helper function to send current job status
async def send_current_status():
"""Fetch and send the current job status to the client"""
try:
training_service = get_enhanced_training_service()
status_info = await training_service.get_training_status(job_id)
if status_info and not status_info.get("error"):
# Map status to WebSocket message type
ws_type = "progress"
if status_info.get("status") == "completed":
ws_type = "completed"
elif status_info.get("status") == "failed":
ws_type = "failed"
await websocket.send_json({
"type": ws_type,
"job_id": job_id,
"data": {
"progress": status_info.get("progress", 0),
"current_step": status_info.get("current_step"),
"status": status_info.get("status"),
"products_total": status_info.get("products_total", 0),
"products_completed": status_info.get("products_completed", 0),
"products_failed": status_info.get("products_failed", 0),
"estimated_time_remaining_seconds": status_info.get("estimated_time_remaining_seconds"),
"message": status_info.get("message")
}
})
logger.info("Sent current job status to client",
job_id=job_id,
status=status_info.get("status"),
progress=status_info.get("progress"))
except Exception as e:
logger.error("Failed to send current job status",
job_id=job_id,
error=str(e))
try:
# Send connection confirmation
await websocket.send_json({
"type": "connected",
"job_id": job_id,
"message": "Connected to training progress stream"
})
# Immediately send current job status after connection
# This handles the race condition where training completes before WebSocket connects
await send_current_status()
# Keep connection alive and handle client messages
ping_count = 0
while True:
try:
# Receive messages from client (ping, get_status, etc.)
data = await websocket.receive_text()
# Handle ping/pong
if data == "ping":
await websocket.send_text("pong")
ping_count += 1
logger.debug("WebSocket ping/pong",
job_id=job_id,
ping_count=ping_count,
connection_healthy=True)
# Handle get_status request
elif data == "get_status":
await send_current_status()
logger.info("Status requested by client", job_id=job_id)
except WebSocketDisconnect:
logger.info("Client disconnected", job_id=job_id)
break
except Exception as e:
logger.error("Error in WebSocket message loop",
job_id=job_id,
error=str(e))
break
finally:
# Disconnect from manager
await websocket_manager.disconnect(job_id, websocket)
logger.info("WebSocket connection closed",
job_id=job_id,
tenant_id=tenant_id)
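# ---------------------------------------------------------------------------
# Illustrative client sketch (assumption, not part of the service): connecting
# to the live-progress endpoint with the `websockets` library, sending the
# "ping" / "get_status" text messages this handler understands, and reading
# JSON updates until a terminal message arrives. The ws_base value and the use
# of the `websockets` package are assumptions for the example.
# ---------------------------------------------------------------------------
async def _example_follow_training_progress(
    ws_base: str, tenant_id: str, job_id: str, token: str
) -> None:
    import json
    import websockets  # example-only dependency

    url = f"{ws_base}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
    async with websockets.connect(url) as ws:
        await ws.send("get_status")  # ask for the current snapshot immediately
        async for raw in ws:
            if raw == "pong":
                continue
            message = json.loads(raw)
            print(message["type"], message.get("data", {}).get("progress"))
            if message["type"] in ("completed", "failed"):
                break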

View File

@@ -0,0 +1,435 @@
"""
Training Event Consumer
Processes ML model retraining requests from RabbitMQ
Queues training jobs and manages model lifecycle
"""
import json
import structlog
from typing import Dict, Any, Optional
from datetime import datetime
from uuid import UUID
from shared.messaging import RabbitMQClient
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
logger = structlog.get_logger()
class TrainingEventConsumer:
"""
    Consumes retraining events and queues ML training jobs
Ensures no duplicate training jobs and manages priorities
"""
def __init__(self, db_session: AsyncSession):
self.db_session = db_session
async def consume_training_events(
self,
rabbitmq_client: RabbitMQClient
):
"""
Start consuming training events from RabbitMQ
"""
async def process_message(message):
"""Process a single training event message"""
try:
async with message.process():
# Parse event data
event_data = json.loads(message.body.decode())
logger.info(
"Received training event",
event_id=event_data.get('event_id'),
event_type=event_data.get('event_type'),
tenant_id=event_data.get('tenant_id')
)
# Process the event
await self.process_training_event(event_data)
except Exception as e:
logger.error(
"Error processing training event",
error=str(e),
exc_info=True
)
# Start consuming events
await rabbitmq_client.consume_events(
exchange_name="training.events",
queue_name="training.retraining.queue",
routing_key="training.retrain.*",
callback=process_message
)
logger.info("Started consuming training events")
async def process_training_event(self, event_data: Dict[str, Any]) -> bool:
"""
Process a training event based on type
Args:
event_data: Full event payload from RabbitMQ
Returns:
bool: True if processed successfully
"""
try:
event_type = event_data.get('event_type')
data = event_data.get('data', {})
tenant_id = event_data.get('tenant_id')
if not tenant_id:
logger.warning("Training event missing tenant_id", event_data=event_data)
return False
# Route to appropriate handler
if event_type == 'training.retrain.requested':
success = await self._handle_retrain_requested(tenant_id, data, event_data)
elif event_type == 'training.retrain.scheduled':
success = await self._handle_retrain_scheduled(tenant_id, data)
else:
logger.warning("Unknown training event type", event_type=event_type)
success = True # Mark as processed to avoid retry
if success:
logger.info(
"Training event processed successfully",
event_type=event_type,
tenant_id=tenant_id
)
else:
logger.error(
"Training event processing failed",
event_type=event_type,
tenant_id=tenant_id
)
return success
except Exception as e:
logger.error(
"Error in process_training_event",
error=str(e),
event_id=event_data.get('event_id'),
exc_info=True
)
return False
async def _handle_retrain_requested(
self,
tenant_id: str,
data: Dict[str, Any],
event_data: Dict[str, Any]
) -> bool:
"""
Handle retraining request event
Validates model, checks for existing jobs, queues training job
Args:
tenant_id: Tenant ID
data: Retraining request data
event_data: Full event payload
Returns:
bool: True if handled successfully
"""
try:
model_id = data.get('model_id')
product_id = data.get('product_id')
trigger_reason = data.get('trigger_reason', 'unknown')
priority = data.get('priority', 'normal')
event_id = event_data.get('event_id')
if not model_id:
logger.warning("Retraining request missing model_id", data=data)
return False
# Validate model exists
from app.models import TrainedModel
stmt = select(TrainedModel).where(
TrainedModel.id == UUID(model_id),
TrainedModel.tenant_id == UUID(tenant_id)
)
result = await self.db_session.execute(stmt)
model = result.scalar_one_or_none()
if not model:
logger.error(
"Model not found for retraining",
model_id=model_id,
tenant_id=tenant_id
)
return False
# Check if model is already in training
if model.status in ['training', 'retraining_queued']:
logger.info(
"Model already in training, skipping duplicate request",
model_id=model_id,
current_status=model.status
)
return True # Consider successful (idempotent)
# Check for existing job in queue
from app.models import TrainingJobQueue
existing_job_stmt = select(TrainingJobQueue).where(
TrainingJobQueue.model_id == UUID(model_id),
TrainingJobQueue.status.in_(['pending', 'running'])
)
existing_job_result = await self.db_session.execute(existing_job_stmt)
existing_job = existing_job_result.scalar_one_or_none()
if existing_job:
logger.info(
"Training job already queued, skipping duplicate",
model_id=model_id,
job_id=str(existing_job.id)
)
return True # Idempotent
# Queue training job
job_id = await self._queue_training_job(
tenant_id=tenant_id,
model_id=model_id,
product_id=product_id,
trigger_reason=trigger_reason,
priority=priority,
event_id=event_id,
metadata=data
)
if not job_id:
logger.error("Failed to queue training job", model_id=model_id)
return False
# Update model status
model.status = 'retraining_queued'
model.updated_at = datetime.utcnow()
await self.db_session.commit()
# Publish job queued event
await self._publish_job_queued_event(
tenant_id=tenant_id,
model_id=model_id,
job_id=job_id,
priority=priority
)
logger.info(
"Retraining job queued successfully",
model_id=model_id,
job_id=job_id,
trigger_reason=trigger_reason,
priority=priority
)
return True
except Exception as e:
await self.db_session.rollback()
logger.error(
"Error handling retrain requested",
error=str(e),
model_id=data.get('model_id'),
exc_info=True
)
return False
async def _handle_retrain_scheduled(
self,
tenant_id: str,
data: Dict[str, Any]
) -> bool:
"""
Handle scheduled retraining event
Similar to retrain_requested but for scheduled/batch retraining
Args:
tenant_id: Tenant ID
data: Scheduled retraining data
Returns:
bool: True if handled successfully
"""
try:
# Similar logic to _handle_retrain_requested
# but may have different priority or batching logic
logger.info(
"Handling scheduled retraining",
tenant_id=tenant_id,
model_count=len(data.get('models', []))
)
# For now, redirect to retrain_requested handler
success_count = 0
for model_data in data.get('models', []):
if await self._handle_retrain_requested(
tenant_id,
model_data,
{'event_id': data.get('schedule_id'), 'tenant_id': tenant_id}
):
success_count += 1
logger.info(
"Scheduled retraining processed",
tenant_id=tenant_id,
successful=success_count,
total=len(data.get('models', []))
)
return success_count > 0
except Exception as e:
logger.error(
"Error handling retrain scheduled",
error=str(e),
tenant_id=tenant_id,
exc_info=True
)
return False
async def _queue_training_job(
self,
tenant_id: str,
model_id: str,
product_id: str,
trigger_reason: str,
priority: str,
event_id: str,
metadata: Dict[str, Any]
) -> Optional[str]:
"""
Queue a training job in the database
Args:
tenant_id: Tenant ID
model_id: Model ID to retrain
product_id: Product ID
trigger_reason: Why retraining was triggered
priority: Job priority (low, normal, high)
event_id: Originating event ID
metadata: Additional job metadata
Returns:
Job ID if successful, None otherwise
"""
try:
from app.models import TrainingJobQueue
import uuid
# Map priority to numeric value for sorting
priority_map = {
'low': 1,
'normal': 2,
'high': 3,
'critical': 4
}
job = TrainingJobQueue(
id=uuid.uuid4(),
tenant_id=UUID(tenant_id),
model_id=UUID(model_id),
product_id=UUID(product_id) if product_id else None,
job_type='retrain',
status='pending',
priority=priority,
priority_score=priority_map.get(priority, 2),
trigger_reason=trigger_reason,
event_id=event_id,
metadata=metadata,
created_at=datetime.utcnow(),
scheduled_at=datetime.utcnow()
)
self.db_session.add(job)
await self.db_session.commit()
logger.info(
"Training job created",
job_id=str(job.id),
model_id=model_id,
priority=priority,
trigger_reason=trigger_reason
)
return str(job.id)
except Exception as e:
await self.db_session.rollback()
logger.error(
"Failed to queue training job",
model_id=model_id,
error=str(e),
exc_info=True
)
return None
async def _publish_job_queued_event(
self,
tenant_id: str,
model_id: str,
job_id: str,
priority: str
):
"""
Publish event that training job was queued
Args:
tenant_id: Tenant ID
model_id: Model ID
job_id: Training job ID
priority: Job priority
"""
try:
from shared.messaging import get_rabbitmq_client
import uuid
rabbitmq_client = get_rabbitmq_client()
if not rabbitmq_client:
logger.warning("RabbitMQ client not available for event publishing")
return
event_payload = {
"event_id": str(uuid.uuid4()),
"event_type": "training.retrain.queued",
"timestamp": datetime.utcnow().isoformat(),
"tenant_id": tenant_id,
"data": {
"job_id": job_id,
"model_id": model_id,
"priority": priority,
"status": "queued"
}
}
await rabbitmq_client.publish_event(
exchange_name="training.events",
routing_key="training.retrain.queued",
event_data=event_payload
)
logger.info(
"Published job queued event",
job_id=job_id,
event_id=event_payload["event_id"]
)
except Exception as e:
logger.error(
"Failed to publish job queued event",
job_id=job_id,
error=str(e)
)
# Don't fail the main operation if event publishing fails
# Factory function for creating consumer instance
def create_training_event_consumer(db_session: AsyncSession) -> TrainingEventConsumer:
"""Create training event consumer instance"""
return TrainingEventConsumer(db_session)
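# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not part of the service): the shape of a
# `training.retrain.requested` event this consumer expects, and a direct call
# to process_training_event() with an already-open session. All IDs below are
# placeholder values, and `session_factory` stands in for whatever async
# session provider the caller already has (e.g. a database manager).
# ---------------------------------------------------------------------------
EXAMPLE_RETRAIN_REQUESTED_EVENT = {
    "event_id": "6f1f0d6e-0000-0000-0000-000000000000",
    "event_type": "training.retrain.requested",
    "tenant_id": "11111111-1111-1111-1111-111111111111",
    "data": {
        "model_id": "22222222-2222-2222-2222-222222222222",
        "product_id": "33333333-3333-3333-3333-333333333333",
        "trigger_reason": "performance_degradation",
        "priority": "high",
    },
}


async def _example_process_one_event(session_factory) -> bool:
    """Process the sample event once; returns True when handled successfully."""
    async with session_factory() as session:
        consumer = create_training_event_consumer(session)
        return await consumer.process_training_event(EXAMPLE_RETRAIN_REQUESTED_EVENT)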

View File

View File

@@ -0,0 +1,89 @@
# ================================================================
# TRAINING SERVICE CONFIGURATION
# services/training/app/core/config.py
# ================================================================
"""
Training service configuration
ML model training and management
"""
from shared.config.base import BaseServiceSettings
import os
class TrainingSettings(BaseServiceSettings):
"""Training service specific settings"""
# Service Identity
APP_NAME: str = "Training Service"
SERVICE_NAME: str = "training-service"
DESCRIPTION: str = "Machine learning model training service"
# Database configuration (secure approach - build from components)
@property
def DATABASE_URL(self) -> str:
"""Build database URL from secure components"""
# Try complete URL first (for backward compatibility)
complete_url = os.getenv("TRAINING_DATABASE_URL")
if complete_url:
return complete_url
# Build from components (secure approach)
user = os.getenv("TRAINING_DB_USER", "training_user")
password = os.getenv("TRAINING_DB_PASSWORD", "training_pass123")
host = os.getenv("TRAINING_DB_HOST", "localhost")
port = os.getenv("TRAINING_DB_PORT", "5432")
name = os.getenv("TRAINING_DB_NAME", "training_db")
return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"
# Redis Database (dedicated for training cache)
REDIS_DB: int = 1
# ML Model Storage
MODEL_BACKUP_ENABLED: bool = os.getenv("MODEL_BACKUP_ENABLED", "true").lower() == "true"
MODEL_VERSIONING_ENABLED: bool = os.getenv("MODEL_VERSIONING_ENABLED", "true").lower() == "true"
# MinIO Configuration
MINIO_ENDPOINT: str = os.getenv("MINIO_ENDPOINT", "minio.bakery-ia.svc.cluster.local:9000")
MINIO_ACCESS_KEY: str = os.getenv("MINIO_ACCESS_KEY", "training-service")
MINIO_SECRET_KEY: str = os.getenv("MINIO_SECRET_KEY", "training-secret-key")
MINIO_USE_SSL: bool = os.getenv("MINIO_USE_SSL", "true").lower() == "true"
MINIO_MODEL_BUCKET: str = os.getenv("MINIO_MODEL_BUCKET", "training-models")
MINIO_CONSOLE_PORT: str = os.getenv("MINIO_CONSOLE_PORT", "9001")
MINIO_API_PORT: str = os.getenv("MINIO_API_PORT", "9000")
MINIO_REGION: str = os.getenv("MINIO_REGION", "us-east-1")
MINIO_MODEL_LIFECYCLE_DAYS: int = int(os.getenv("MINIO_MODEL_LIFECYCLE_DAYS", "90"))
MINIO_CACHE_TTL_SECONDS: int = int(os.getenv("MINIO_CACHE_TTL_SECONDS", "3600"))
# Training Configuration
MAX_CONCURRENT_TRAINING_JOBS: int = int(os.getenv("MAX_CONCURRENT_TRAINING_JOBS", "3"))
# Prophet Specific Configuration
PROPHET_HOLIDAYS_PRIOR_SCALE: float = float(os.getenv("PROPHET_HOLIDAYS_PRIOR_SCALE", "10.0"))
# Spanish Holiday Integration
ENABLE_CUSTOM_HOLIDAYS: bool = os.getenv("ENABLE_CUSTOM_HOLIDAYS", "true").lower() == "true"
# Data Processing
DATA_PREPROCESSING_ENABLED: bool = True
OUTLIER_DETECTION_ENABLED: bool = os.getenv("OUTLIER_DETECTION_ENABLED", "true").lower() == "true"
SEASONAL_DECOMPOSITION_ENABLED: bool = os.getenv("SEASONAL_DECOMPOSITION_ENABLED", "true").lower() == "true"
# Model Validation
CROSS_VALIDATION_ENABLED: bool = os.getenv("CROSS_VALIDATION_ENABLED", "true").lower() == "true"
VALIDATION_SPLIT_RATIO: float = float(os.getenv("VALIDATION_SPLIT_RATIO", "0.2"))
MIN_MODEL_ACCURACY: float = float(os.getenv("MIN_MODEL_ACCURACY", "0.7"))
# Distributed Training (for future scaling)
DISTRIBUTED_TRAINING_ENABLED: bool = os.getenv("DISTRIBUTED_TRAINING_ENABLED", "false").lower() == "true"
TRAINING_WORKER_COUNT: int = int(os.getenv("TRAINING_WORKER_COUNT", "1"))
PROPHET_DAILY_SEASONALITY: bool = True
PROPHET_WEEKLY_SEASONALITY: bool = True
PROPHET_YEARLY_SEASONALITY: bool = True
# Throttling settings for parallel training to prevent heartbeat blocking
MAX_CONCURRENT_TRAININGS: int = int(os.getenv("MAX_CONCURRENT_TRAININGS", "3"))
settings = TrainingSettings()
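# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not a deployment recipe): with only the
# component variables set, DATABASE_URL is assembled into an asyncpg DSN, e.g.
#   TRAINING_DB_USER=training_user TRAINING_DB_HOST=postgres TRAINING_DB_NAME=training_db
#   -> postgresql+asyncpg://training_user:<password>@postgres:5432/training_db
# A complete TRAINING_DATABASE_URL, if present, takes precedence over the components.
# ---------------------------------------------------------------------------
if __name__ == "__main__":  # pragma: no cover - manual sanity check only
    print("Resolved database DSN:", settings.DATABASE_URL)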

View File

@@ -0,0 +1,97 @@
"""
Training Service Constants
Centralized constants to avoid magic numbers throughout the codebase
"""
# Data Validation Thresholds
MIN_DATA_POINTS_REQUIRED = 30
RECOMMENDED_DATA_POINTS = 90
MAX_ZERO_RATIO_ERROR = 0.9 # 90% zeros = error
HIGH_ZERO_RATIO_WARNING = 0.7 # 70% zeros = warning
MAX_ZERO_RATIO_INTERMITTENT = 0.8 # Products with >80% zeros are intermittent
MODERATE_SPARSITY_THRESHOLD = 0.6 # 60% zeros = moderate sparsity
# Training Time Periods (in days)
MIN_NON_ZERO_DAYS = 30 # Minimum days with non-zero sales
DATA_QUALITY_DAY_THRESHOLD_LOW = 90
DATA_QUALITY_DAY_THRESHOLD_HIGH = 365
MAX_TRAINING_RANGE_DAYS = 730 # 2 years
MIN_TRAINING_RANGE_DAYS = 30
# Product Classification Thresholds
HIGH_VOLUME_MEAN_SALES = 10.0
HIGH_VOLUME_ZERO_RATIO = 0.3
MEDIUM_VOLUME_MEAN_SALES = 5.0
MEDIUM_VOLUME_ZERO_RATIO = 0.5
LOW_VOLUME_MEAN_SALES = 2.0
LOW_VOLUME_ZERO_RATIO = 0.7
# Hyperparameter Optimization
OPTUNA_TRIALS_HIGH_VOLUME = 30
OPTUNA_TRIALS_MEDIUM_VOLUME = 25
OPTUNA_TRIALS_LOW_VOLUME = 20
OPTUNA_TRIALS_INTERMITTENT = 15
OPTUNA_TIMEOUT_SECONDS = 600
# Prophet Uncertainty Sampling
UNCERTAINTY_SAMPLES_SPARSE_MIN = 100
UNCERTAINTY_SAMPLES_SPARSE_MAX = 200
UNCERTAINTY_SAMPLES_LOW_MIN = 150
UNCERTAINTY_SAMPLES_LOW_MAX = 300
UNCERTAINTY_SAMPLES_MEDIUM_MIN = 200
UNCERTAINTY_SAMPLES_MEDIUM_MAX = 500
UNCERTAINTY_SAMPLES_HIGH_MIN = 300
UNCERTAINTY_SAMPLES_HIGH_MAX = 800
# MAPE Calculation
MAPE_LOW_VOLUME_THRESHOLD = 2.0
MAPE_MEDIUM_VOLUME_THRESHOLD = 5.0
MAPE_CALCULATION_MIN_THRESHOLD = 0.5
MAPE_CALCULATION_MID_THRESHOLD = 1.0
MAPE_MAX_CAP = 200.0 # Cap MAPE at 200%
MAPE_MEDIUM_CAP = 150.0
# Baseline MAPE estimates for improvement calculation
BASELINE_MAPE_VERY_SPARSE = 80.0
BASELINE_MAPE_SPARSE = 60.0
BASELINE_MAPE_HIGH_VOLUME = 25.0
BASELINE_MAPE_MEDIUM_VOLUME = 35.0
BASELINE_MAPE_LOW_VOLUME = 45.0
IMPROVEMENT_SIGNIFICANCE_THRESHOLD = 0.8 # Only claim improvement if MAPE < 80% of baseline
# Cross-validation
CV_N_SPLITS = 2
CV_MIN_VALIDATION_DAYS = 7
# Progress tracking
PROGRESS_DATA_PREPARATION_START = 0
PROGRESS_DATA_PREPARATION_END = 45
PROGRESS_MODEL_TRAINING_START = 45
PROGRESS_MODEL_TRAINING_END = 85
PROGRESS_FINALIZATION_START = 85
PROGRESS_FINALIZATION_END = 100
# HTTP Client Configuration
HTTP_TIMEOUT_DEFAULT = 30.0 # seconds
HTTP_TIMEOUT_LONG_RUNNING = 60.0 # for training data fetches
HTTP_MAX_RETRIES = 3
HTTP_RETRY_BACKOFF_FACTOR = 2.0
# WebSocket Configuration
WEBSOCKET_PING_TIMEOUT = 60.0 # seconds
WEBSOCKET_ACTIVITY_WARNING_THRESHOLD = 90.0 # seconds
WEBSOCKET_CONSUMER_HEARTBEAT_INTERVAL = 10.0 # seconds
# Synthetic Data Generation
SYNTHETIC_TEMP_DEFAULT = 50.0
SYNTHETIC_TEMP_VARIATION = 100.0
SYNTHETIC_TRAFFIC_DEFAULT = 50.0
SYNTHETIC_TRAFFIC_VARIATION = 100.0
# Model Storage
MODEL_FILE_EXTENSION = ".pkl"
METADATA_FILE_EXTENSION = ".json"
# Data Quality Scoring
MIN_QUALITY_SCORE = 0.1
MAX_QUALITY_SCORE = 1.0
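# ---------------------------------------------------------------------------
# Illustrative sketch (assumption): one way the volume/sparsity thresholds
# above could be combined to classify a product's demand pattern. The actual
# classification logic lives in the ML pipeline; this helper only demonstrates
# how the constants relate to each other.
# ---------------------------------------------------------------------------
def example_classify_demand_pattern(mean_daily_sales: float, zero_ratio: float) -> str:
    """Return a coarse demand-pattern label for a product."""
    if zero_ratio > MAX_ZERO_RATIO_INTERMITTENT:
        return "intermittent"
    if mean_daily_sales >= HIGH_VOLUME_MEAN_SALES and zero_ratio <= HIGH_VOLUME_ZERO_RATIO:
        return "high_volume"
    if mean_daily_sales >= MEDIUM_VOLUME_MEAN_SALES and zero_ratio <= MEDIUM_VOLUME_ZERO_RATIO:
        return "medium_volume"
    if mean_daily_sales >= LOW_VOLUME_MEAN_SALES and zero_ratio <= LOW_VOLUME_ZERO_RATIO:
        return "low_volume"
    return "sparse"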

View File

@@ -0,0 +1,432 @@
# services/training/app/core/database.py
"""
Database configuration for training service
Uses shared database infrastructure
"""
import structlog
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from contextlib import asynccontextmanager
from sqlalchemy import text
from shared.database.base import DatabaseManager, Base
from app.core.config import settings
logger = structlog.get_logger()
# Initialize database manager with connection pooling configuration
database_manager = DatabaseManager(
settings.DATABASE_URL,
pool_size=settings.DB_POOL_SIZE,
max_overflow=settings.DB_MAX_OVERFLOW,
pool_timeout=settings.DB_POOL_TIMEOUT,
pool_recycle=settings.DB_POOL_RECYCLE,
pool_pre_ping=settings.DB_POOL_PRE_PING,
echo=settings.DB_ECHO
)
# Alias for convenience - matches the existing interface
get_db = database_manager.get_db
@asynccontextmanager
async def get_background_db_session():
async with database_manager.async_session_local() as session:
try:
yield session
await session.commit()
except Exception as e:
await session.rollback()
raise
finally:
await session.close()
async def get_db_health() -> bool:
"""
Health check function for database connectivity
Enhanced version of the shared functionality
"""
try:
async with database_manager.async_engine.begin() as conn:
await conn.execute(text("SELECT 1"))
logger.debug("Database health check passed")
return True
except Exception as e:
logger.error("Database health check failed", error=str(e))
return False
async def get_comprehensive_db_health() -> dict:
"""
Comprehensive health check that verifies both connectivity and table existence
"""
health_status = {
"status": "healthy",
"connectivity": False,
"tables_exist": False,
"tables_verified": [],
"missing_tables": [],
"errors": []
}
try:
# Test basic connectivity
health_status["connectivity"] = await get_db_health()
if not health_status["connectivity"]:
health_status["status"] = "unhealthy"
health_status["errors"].append("Database connectivity failed")
return health_status
# Test table existence
tables_verified = await _verify_tables_exist()
health_status["tables_exist"] = tables_verified
if tables_verified:
health_status["tables_verified"] = [
'model_training_logs', 'trained_models', 'model_performance_metrics',
'training_job_queue', 'model_artifacts'
]
else:
health_status["status"] = "unhealthy"
health_status["errors"].append("Required tables missing or inaccessible")
# Try to identify which specific tables are missing
try:
async with database_manager.get_session() as session:
for table_name in ['model_training_logs', 'trained_models', 'model_performance_metrics',
'training_job_queue', 'model_artifacts']:
try:
await session.execute(text(f"SELECT 1 FROM {table_name} LIMIT 1"))
health_status["tables_verified"].append(table_name)
except Exception:
health_status["missing_tables"].append(table_name)
except Exception as e:
health_status["errors"].append(f"Error checking individual tables: {str(e)}")
logger.debug("Comprehensive database health check completed",
status=health_status["status"],
connectivity=health_status["connectivity"],
tables_exist=health_status["tables_exist"])
except Exception as e:
health_status["status"] = "unhealthy"
health_status["errors"].append(f"Health check failed: {str(e)}")
logger.error("Comprehensive database health check failed", error=str(e))
return health_status
# Training service specific database utilities
class TrainingDatabaseUtils:
"""Training service specific database utilities"""
@staticmethod
async def cleanup_old_training_logs(days_old: int = 90):
"""Clean up old training logs"""
try:
async with database_manager.async_session_local() as session:
if settings.DATABASE_URL.startswith("sqlite"):
query = text(
"DELETE FROM model_training_logs "
"WHERE start_time < datetime('now', :days_param)"
)
params = {"days_param": f"-{days_old} days"}
else:
query = text(
"DELETE FROM model_training_logs "
"WHERE start_time < NOW() - INTERVAL :days_param"
)
params = {"days_param": f"{days_old} days"}
result = await session.execute(query, params)
await session.commit()
deleted_count = result.rowcount
logger.info("Cleaned up old training logs",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Training logs cleanup failed", error=str(e))
raise
@staticmethod
async def cleanup_old_models(days_old: int = 365):
"""Clean up old inactive models"""
try:
async with database_manager.async_session_local() as session:
if settings.DATABASE_URL.startswith("sqlite"):
query = text(
"DELETE FROM trained_models "
"WHERE is_active = 0 AND created_at < datetime('now', :days_param)"
)
params = {"days_param": f"-{days_old} days"}
else:
query = text(
"DELETE FROM trained_models "
"WHERE is_active = false AND created_at < NOW() - INTERVAL :days_param"
)
params = {"days_param": f"{days_old} days"}
result = await session.execute(query, params)
await session.commit()
deleted_count = result.rowcount
logger.info("Cleaned up old models",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Model cleanup failed", error=str(e))
raise
@staticmethod
async def get_training_statistics(tenant_id: str = None) -> dict:
"""Get training statistics"""
try:
async with database_manager.async_session_local() as session:
# Base query for training logs
if tenant_id:
logs_query = text(
"SELECT status, COUNT(*) as count "
"FROM model_training_logs "
"WHERE tenant_id = :tenant_id "
"GROUP BY status"
)
models_query = text(
"SELECT COUNT(*) as count "
"FROM trained_models "
"WHERE tenant_id = :tenant_id AND is_active = :is_active"
)
params = {"tenant_id": tenant_id}
else:
logs_query = text(
"SELECT status, COUNT(*) as count "
"FROM model_training_logs "
"GROUP BY status"
)
models_query = text(
"SELECT COUNT(*) as count "
"FROM trained_models "
"WHERE is_active = :is_active"
)
params = {}
# Get training job statistics
logs_result = await session.execute(logs_query, params)
job_stats = {row.status: row.count for row in logs_result.fetchall()}
# Get active models count
active_models_result = await session.execute(
models_query,
{**params, "is_active": True}
)
active_models = active_models_result.scalar() or 0
# Get inactive models count
inactive_models_result = await session.execute(
models_query,
{**params, "is_active": False}
)
inactive_models = inactive_models_result.scalar() or 0
return {
"training_jobs": job_stats,
"active_models": active_models,
"inactive_models": inactive_models,
"total_models": active_models + inactive_models
}
except Exception as e:
logger.error("Failed to get training statistics", error=str(e))
return {
"training_jobs": {},
"active_models": 0,
"inactive_models": 0,
"total_models": 0
}
@staticmethod
async def check_tenant_data_exists(tenant_id: str) -> bool:
"""Check if tenant has any training data"""
try:
async with database_manager.async_session_local() as session:
query = text(
"SELECT COUNT(*) as count "
"FROM model_training_logs "
"WHERE tenant_id = :tenant_id "
"LIMIT 1"
)
result = await session.execute(query, {"tenant_id": tenant_id})
count = result.scalar() or 0
return count > 0
except Exception as e:
logger.error("Failed to check tenant data existence",
tenant_id=tenant_id, error=str(e))
return False
# Enhanced database session dependency with better error handling
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""
Enhanced database session dependency with better logging and error handling
"""
async with database_manager.async_session_local() as session:
try:
logger.debug("Database session created")
yield session
except Exception as e:
logger.error("Database session error", error=str(e), exc_info=True)
await session.rollback()
raise
finally:
await session.close()
logger.debug("Database session closed")
# Database initialization for training service
async def initialize_training_database():
"""Initialize database tables for training service with retry logic and verification"""
import asyncio
from sqlalchemy import text
max_retries = 5
retry_delay = 2.0
for attempt in range(1, max_retries + 1):
try:
logger.info("Initializing training service database",
attempt=attempt, max_retries=max_retries)
# Step 1: Test database connectivity first
logger.info("Testing database connectivity...")
connection_ok = await database_manager.test_connection()
if not connection_ok:
raise Exception("Database connection test failed")
logger.info("Database connectivity verified")
# Step 2: Import models to ensure they're registered
logger.info("Importing and registering database models...")
from app.models.training import (
ModelTrainingLog,
TrainedModel,
ModelPerformanceMetric,
TrainingJobQueue,
ModelArtifact
)
# Verify models are registered in metadata
expected_tables = {
'model_training_logs', 'trained_models', 'model_performance_metrics',
'training_job_queue', 'model_artifacts'
}
registered_tables = set(Base.metadata.tables.keys())
missing_tables = expected_tables - registered_tables
if missing_tables:
raise Exception(f"Models not properly registered: {missing_tables}")
logger.info("Models registered successfully",
tables=list(registered_tables))
# Step 3: Create tables using shared infrastructure with verification
logger.info("Creating database tables...")
await database_manager.create_tables()
# Step 4: Verify tables were actually created
logger.info("Verifying table creation...")
verification_successful = await _verify_tables_exist()
if not verification_successful:
raise Exception("Table verification failed - tables were not created properly")
logger.info("Training service database initialized and verified successfully",
attempt=attempt)
return
except Exception as e:
logger.error("Database initialization failed",
attempt=attempt,
max_retries=max_retries,
error=str(e))
if attempt == max_retries:
logger.error("All database initialization attempts failed - giving up")
raise Exception(f"Failed to initialize training database after {max_retries} attempts: {str(e)}")
# Wait before retry with exponential backoff
wait_time = retry_delay * (2 ** (attempt - 1))
logger.info("Retrying database initialization",
retry_in_seconds=wait_time,
next_attempt=attempt + 1)
await asyncio.sleep(wait_time)
async def _verify_tables_exist() -> bool:
"""Verify that all required tables exist in the database"""
try:
async with database_manager.get_session() as session:
# Check each required table exists and is accessible
required_tables = [
'model_training_logs',
'trained_models',
'model_performance_metrics',
'training_job_queue',
'model_artifacts'
]
for table_name in required_tables:
try:
# Try to query the table structure
result = await session.execute(
text(f"SELECT 1 FROM {table_name} LIMIT 1")
)
logger.debug(f"Table {table_name} exists and is accessible")
except Exception as table_error:
# If it's a "relation does not exist" error, table creation failed
if "does not exist" in str(table_error).lower():
logger.error(f"Table {table_name} does not exist", error=str(table_error))
return False
# If it's an empty table, that's fine - table exists
elif "no data" in str(table_error).lower():
logger.debug(f"Table {table_name} exists but is empty (normal)")
else:
logger.warning(f"Unexpected error querying {table_name}", error=str(table_error))
logger.info("All required tables verified successfully",
tables=required_tables)
return True
except Exception as e:
logger.error("Table verification failed", error=str(e))
return False
# Database cleanup for training service
async def cleanup_training_database():
"""Cleanup database connections for training service"""
try:
logger.info("Cleaning up training service database connections")
# Close engine connections
if hasattr(database_manager, 'async_engine') and database_manager.async_engine:
await database_manager.async_engine.dispose()
logger.info("Training service database cleanup completed")
except Exception as e:
logger.error("Failed to cleanup training service database", error=str(e))
# Export the commonly used items to maintain compatibility
__all__ = [
'Base',
'database_manager',
'get_db',
'get_db_session',
'get_db_health',
'TrainingDatabaseUtils',
'initialize_training_database',
'cleanup_training_database'
]
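# ---------------------------------------------------------------------------
# Illustrative maintenance sketch (assumption, not a scheduled job in this
# service): how the utilities above could be combined in a one-off cleanup
# script. The retention periods are placeholders.
# ---------------------------------------------------------------------------
async def _example_nightly_maintenance() -> dict:
    if not await get_db_health():
        raise RuntimeError("database unavailable, skipping maintenance")
    deleted_logs = await TrainingDatabaseUtils.cleanup_old_training_logs(days_old=90)
    deleted_models = await TrainingDatabaseUtils.cleanup_old_models(days_old=365)
    stats = await TrainingDatabaseUtils.get_training_statistics()
    return {
        "deleted_logs": deleted_logs,
        "deleted_models": deleted_models,
        "statistics": stats,
    }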

View File

@@ -0,0 +1,35 @@
"""
Training Progress Constants
Centralized constants for training progress tracking and timing
"""
# Progress Milestones (percentage)
PROGRESS_STARTED = 0
PROGRESS_DATA_VALIDATION = 10
PROGRESS_DATA_ANALYSIS = 20
PROGRESS_DATA_PREPARATION_COMPLETE = 30
PROGRESS_ML_TRAINING_START = 40
PROGRESS_TRAINING_COMPLETE = 85
PROGRESS_STORING_MODELS = 92
PROGRESS_STORING_METRICS = 94
PROGRESS_COMPLETED = 100
# Progress Ranges
PROGRESS_TRAINING_RANGE_START = 20 # After data analysis
PROGRESS_TRAINING_RANGE_END = 80 # Before finalization
PROGRESS_TRAINING_RANGE_WIDTH = PROGRESS_TRAINING_RANGE_END - PROGRESS_TRAINING_RANGE_START # 60%
# Time Limits and Intervals (seconds)
MAX_ESTIMATED_TIME_REMAINING_SECONDS = 1800 # 30 minutes
WEBSOCKET_HEARTBEAT_INTERVAL_SECONDS = 30
WEBSOCKET_RECONNECT_MAX_ATTEMPTS = 3
WEBSOCKET_RECONNECT_INITIAL_DELAY_SECONDS = 1
WEBSOCKET_RECONNECT_MAX_DELAY_SECONDS = 10
# Training Timeouts (seconds)
TRAINING_SKIP_OPTION_DELAY_SECONDS = 120 # 2 minutes
HTTP_POLLING_INTERVAL_MS = 5000 # 5 seconds
HTTP_POLLING_DEBOUNCE_MS = 5000 # 5 seconds before enabling after WebSocket disconnect
# Frontend Display
TRAINING_COMPLETION_DELAY_MS = 2000 # Delay before navigating after completion
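# ---------------------------------------------------------------------------
# Illustrative sketch (assumption): mapping per-product completion onto the
# overall training band defined above, and capping the remaining-time
# estimate. The real mapping lives in the training pipeline; this only shows
# how the constants are intended to relate.
# ---------------------------------------------------------------------------
def example_overall_progress(products_completed: int, products_total: int) -> int:
    """Project product-level completion onto the 20-80% training band."""
    if products_total <= 0:
        return PROGRESS_TRAINING_RANGE_START
    fraction = min(max(products_completed / products_total, 0.0), 1.0)
    return int(PROGRESS_TRAINING_RANGE_START + fraction * PROGRESS_TRAINING_RANGE_WIDTH)


def example_capped_eta(raw_estimate_seconds: float) -> float:
    """Never report more than the configured maximum remaining time."""
    return min(raw_estimate_seconds, MAX_ESTIMATED_TIME_REMAINING_SECONDS)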

View File

@@ -0,0 +1,265 @@
# ================================================================
# services/training/app/main.py
# ================================================================
"""
Training Service Main Application
ML training service for bakery demand forecasting
"""
import asyncio
from fastapi import FastAPI, Request
from sqlalchemy import text
from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
from app.api import training_jobs, training_operations, models, health, monitoring, websocket_operations, audit
from app.services.training_events import setup_messaging, cleanup_messaging
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers
from shared.service_base import StandardFastAPIService
from shared.monitoring.system_metrics import SystemMetricsCollector
class TrainingService(StandardFastAPIService):
"""Training Service with standardized setup"""
def __init__(self):
# Define expected database tables for health checks
training_expected_tables = [
'model_training_logs', 'trained_models', 'model_performance_metrics',
'training_job_queue', 'model_artifacts'
]
super().__init__(
service_name="training-service",
app_name="Bakery Training Service",
description="ML training service for bakery demand forecasting",
version="1.0.0",
log_level=settings.LOG_LEVEL,
cors_origins=settings.CORS_ORIGINS_LIST,
api_prefix="",
database_manager=database_manager,
expected_tables=training_expected_tables,
enable_messaging=True
)
async def _setup_messaging(self):
"""Setup messaging for training service"""
await setup_messaging()
self.logger.info("Messaging setup completed")
# Initialize Redis pub/sub for cross-pod WebSocket broadcasting
await self._setup_websocket_redis()
# Set up WebSocket event consumer (listens to RabbitMQ and broadcasts to WebSockets)
success = await setup_websocket_event_consumer()
if success:
self.logger.info("WebSocket event consumer setup completed")
else:
self.logger.warning("WebSocket event consumer setup failed")
async def _setup_websocket_redis(self):
"""
Initialize Redis pub/sub for WebSocket cross-pod broadcasting.
CRITICAL FOR HORIZONTAL SCALING:
Without this, WebSocket clients on Pod A won't receive events
from training jobs running on Pod B.
"""
try:
from app.websocket.manager import websocket_manager
from app.core.config import settings
redis_url = settings.REDIS_URL
success = await websocket_manager.initialize_redis(redis_url)
if success:
self.logger.info("WebSocket Redis pub/sub initialized for horizontal scaling")
else:
self.logger.warning(
"WebSocket Redis pub/sub failed to initialize. "
"WebSocket events will only be delivered to local connections."
)
except Exception as e:
self.logger.error("Failed to setup WebSocket Redis pub/sub",
error=str(e))
# Don't fail startup - WebSockets will work locally without Redis
async def _cleanup_messaging(self):
"""Cleanup messaging for training service"""
# Shutdown WebSocket Redis pub/sub
try:
from app.websocket.manager import websocket_manager
await websocket_manager.shutdown()
self.logger.info("WebSocket Redis pub/sub shutdown completed")
except Exception as e:
self.logger.warning("Error shutting down WebSocket Redis", error=str(e))
await cleanup_websocket_consumers()
await cleanup_messaging()
async def verify_migrations(self):
"""Verify database schema matches the latest migrations dynamically."""
try:
async with self.database_manager.get_session() as session:
result = await session.execute(text("SELECT version_num FROM alembic_version"))
version = result.scalar()
if not version:
self.logger.error("No migration version found in database")
raise RuntimeError("Database not initialized - no alembic version found")
self.logger.info(f"Migration verification successful: {version}")
return version
except Exception as e:
self.logger.error(f"Migration verification failed: {e}")
raise
async def on_startup(self, app: FastAPI):
"""Custom startup logic including migration verification"""
await self.verify_migrations()
# Initialize system metrics collection
system_metrics = SystemMetricsCollector("training")
self.logger.info("System metrics collection started")
# Recover stale jobs from previous pod crashes
# This is important for horizontal scaling - jobs may be left in 'running'
# state if a pod crashes. We mark them as failed so they can be retried.
await self._recover_stale_jobs()
self.logger.info("Training service startup completed")
async def _recover_stale_jobs(self):
"""
Recover stale training jobs on startup.
When a pod crashes mid-training, jobs are left in 'running' or 'pending' state.
This method finds jobs that haven't been updated in a while and marks them
as failed so users can retry them.
"""
try:
from app.repositories.training_log_repository import TrainingLogRepository
async with self.database_manager.get_session() as session:
log_repo = TrainingLogRepository(session)
# Recover jobs that haven't been updated in 60 minutes
# This is conservative - most training jobs complete within 30 minutes
recovered = await log_repo.recover_stale_jobs(stale_threshold_minutes=60)
if recovered:
self.logger.warning(
"Recovered stale training jobs on startup",
recovered_count=len(recovered),
job_ids=[j.job_id for j in recovered]
)
else:
self.logger.info("No stale training jobs to recover")
except Exception as e:
# Don't fail startup if recovery fails - just log the error
self.logger.error("Failed to recover stale jobs on startup", error=str(e))
async def on_shutdown(self, app: FastAPI):
"""Custom shutdown logic for training service"""
await cleanup_training_database()
self.logger.info("Training database cleanup completed")
def get_service_features(self):
"""Return training-specific features"""
return [
"ml_model_training",
"demand_forecasting",
"model_performance_tracking",
"training_job_queue",
"model_artifacts_management",
"websocket_support",
"messaging_integration"
]
def setup_custom_middleware(self):
"""Setup custom middleware for training service"""
# Request middleware for logging and metrics
@self.app.middleware("http")
async def process_request(request: Request, call_next):
"""Process requests with logging and metrics"""
start_time = asyncio.get_event_loop().time()
try:
response = await call_next(request)
duration = asyncio.get_event_loop().time() - start_time
self.logger.info(
"Request completed",
method=request.method,
path=request.url.path,
status_code=response.status_code,
duration_ms=round(duration * 1000, 2)
)
return response
except Exception as e:
duration = asyncio.get_event_loop().time() - start_time
self.logger.error(
"Request failed",
method=request.method,
path=request.url.path,
error=str(e),
duration_ms=round(duration * 1000, 2)
)
raise
def setup_custom_endpoints(self):
"""Setup custom endpoints for training service"""
# Note: Metrics are exported via OpenTelemetry OTLP to SigNoz
# The /metrics endpoint is not needed as metrics are pushed automatically
# @self.app.get("/metrics")
# async def get_metrics():
# """Prometheus metrics endpoint"""
# if self.metrics_collector:
# return self.metrics_collector.get_metrics()
# return {"status": "metrics not available"}
@self.app.get("/")
async def root():
return {"service": "training-service", "version": "1.0.0"}
# Create service instance
service = TrainingService()
# Create FastAPI app with standardized setup
app = service.create_app(
docs_url="/docs",
redoc_url="/redoc"
)
# Setup standard endpoints
service.setup_standard_endpoints()
# Setup custom middleware
service.setup_custom_middleware()
# Setup custom endpoints
service.setup_custom_endpoints()
# Include API routers
# IMPORTANT: Register audit router FIRST to avoid route matching conflicts
service.add_router(audit.router)
service.add_router(training_jobs.router, tags=["training-jobs"])
service.add_router(training_operations.router, tags=["training-operations"])
service.add_router(models.router, tags=["models"])
service.add_router(health.router, tags=["health"])
service.add_router(monitoring.router, tags=["monitoring"])
service.add_router(websocket_operations.router, tags=["websocket"])
if __name__ == "__main__":
uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=settings.PORT,
reload=settings.DEBUG,
log_level=settings.LOG_LEVEL.lower()
)

View File

@@ -0,0 +1,14 @@
"""
ML Pipeline Components
Machine learning training and prediction components
"""
from .trainer import EnhancedBakeryMLTrainer
from .data_processor import EnhancedBakeryDataProcessor
from .prophet_manager import BakeryProphetManager
__all__ = [
"EnhancedBakeryMLTrainer",
"EnhancedBakeryDataProcessor",
"BakeryProphetManager"
]

View File

@@ -0,0 +1,307 @@
"""
Calendar-based Feature Engineering
Hyperlocal school calendar and event features for demand forecasting
"""
import pandas as pd
import structlog
from typing import Dict, List, Any, Optional
from datetime import datetime, date, time, timedelta
from shared.clients.external_client import ExternalServiceClient
logger = structlog.get_logger()
class CalendarFeatureEngine:
"""
Generates features based on school calendars and local events
for hyperlocal demand forecasting enhancement
"""
def __init__(self, external_client: ExternalServiceClient):
self.external_client = external_client
self.calendar_cache = {} # Cache calendar data to avoid repeated API calls
async def get_calendar_for_tenant(
self,
tenant_id: str,
city_id: Optional[str] = "madrid"
) -> Optional[Dict[str, Any]]:
"""
Get the assigned school calendar for a tenant
If tenant has no assignment, returns None
"""
try:
# Check cache first
cache_key = f"tenant_{tenant_id}_calendar"
if cache_key in self.calendar_cache:
logger.debug("Using cached calendar", tenant_id=tenant_id)
return self.calendar_cache[cache_key]
# Get tenant location context
context = await self.external_client.get_tenant_location_context(tenant_id)
if not context or not context.get("calendar"):
logger.info(
"No calendar assigned to tenant, using default if available",
tenant_id=tenant_id
)
return None
calendar = context["calendar"]
self.calendar_cache[cache_key] = calendar
logger.info(
"Retrieved calendar for tenant",
tenant_id=tenant_id,
calendar_name=calendar.get("calendar_name")
)
return calendar
except Exception as e:
logger.error(
"Error retrieving calendar for tenant",
tenant_id=tenant_id,
error=str(e)
)
return None
def _is_date_in_holiday_period(
self,
check_date: date,
holiday_periods: List[Dict[str, Any]]
) -> tuple[bool, Optional[str]]:
"""
Check if a date falls within any holiday period
Returns:
(is_holiday, holiday_name)
"""
for period in holiday_periods:
start = datetime.strptime(period["start_date"], "%Y-%m-%d").date()
end = datetime.strptime(period["end_date"], "%Y-%m-%d").date()
if start <= check_date <= end:
return True, period["name"]
return False, None
def _is_school_hours_active(
self,
check_datetime: datetime,
school_hours: Dict[str, Any]
) -> bool:
"""
Check if datetime falls during school operating hours
Args:
check_datetime: DateTime to check
school_hours: School hours configuration dict
Returns:
True if during school hours, False otherwise
"""
# Only check weekdays
if check_datetime.weekday() >= 5: # Saturday=5, Sunday=6
return False
check_time = check_datetime.time()
# Morning session
morning_start = datetime.strptime(
school_hours["morning_start"], "%H:%M"
).time()
morning_end = datetime.strptime(
school_hours["morning_end"], "%H:%M"
).time()
if morning_start <= check_time <= morning_end:
return True
# Afternoon session (if applicable)
if school_hours.get("has_afternoon_session", False):
afternoon_start = datetime.strptime(
school_hours["afternoon_start"], "%H:%M"
).time()
afternoon_end = datetime.strptime(
school_hours["afternoon_end"], "%H:%M"
).time()
if afternoon_start <= check_time <= afternoon_end:
return True
return False
def _calculate_school_proximity_intensity(
self,
check_datetime: datetime,
school_hours: Dict[str, Any]
) -> float:
"""
Calculate intensity of school-related foot traffic
Peaks during drop-off and pick-up times
Returns:
Float between 0.0 (no impact) and 1.0 (peak impact)
"""
# Only weekdays
if check_datetime.weekday() >= 5:
return 0.0
check_time = check_datetime.time()
# Define peak windows (30 minutes before and after school start/end)
morning_start = datetime.strptime(
school_hours["morning_start"], "%H:%M"
).time()
morning_end = datetime.strptime(
school_hours["morning_end"], "%H:%M"
).time()
# Morning drop-off peak (30 min before to 15 min after start)
drop_off_start = (
datetime.combine(date.today(), morning_start) - timedelta(minutes=30)
).time()
drop_off_end = (
datetime.combine(date.today(), morning_start) + timedelta(minutes=15)
).time()
if drop_off_start <= check_time <= drop_off_end:
return 1.0 # Peak morning traffic
# Morning pick-up peak (15 min before to 30 min after end)
pickup_start = (
datetime.combine(date.today(), morning_end) - timedelta(minutes=15)
).time()
pickup_end = (
datetime.combine(date.today(), morning_end) + timedelta(minutes=30)
).time()
if pickup_start <= check_time <= pickup_end:
return 1.0 # Peak afternoon traffic
# During school hours (moderate impact)
if morning_start <= check_time <= morning_end:
return 0.3
# Afternoon session if applicable
if school_hours.get("has_afternoon_session", False):
afternoon_start = datetime.strptime(
school_hours["afternoon_start"], "%H:%M"
).time()
afternoon_end = datetime.strptime(
school_hours["afternoon_end"], "%H:%M"
).time()
if afternoon_start <= check_time <= afternoon_end:
return 0.3
return 0.0
async def add_calendar_features(
self,
df: pd.DataFrame,
tenant_id: str,
date_column: str = "date"
) -> pd.DataFrame:
"""
Add calendar-based features to dataframe
Features added:
- is_school_holiday: Binary (1/0)
- school_holiday_name: String (name of holiday or None)
- school_hours_active: Binary (1/0) - if during school operating hours
- school_proximity_intensity: Float (0.0-1.0) - peak during drop-off/pick-up
Args:
df: DataFrame with date/datetime column
tenant_id: Tenant ID to get calendar assignment
date_column: Name of date column
Returns:
DataFrame with added calendar features
"""
try:
logger.info(
"Adding calendar-based features",
tenant_id=tenant_id,
rows=len(df)
)
# Get calendar for tenant
calendar = await self.get_calendar_for_tenant(tenant_id)
if not calendar:
logger.warning(
"No calendar available, using fallback features",
tenant_id=tenant_id
)
# Add default features (all zeros)
df["is_school_holiday"] = 0
df["school_holiday_name"] = None
df["school_hours_active"] = 0
df["school_proximity_intensity"] = 0.0
return df
holiday_periods = calendar.get("holiday_periods", [])
school_hours = calendar.get("school_hours", {})
# Initialize feature columns
school_holidays = []
holiday_names = []
hours_active = []
proximity_intensity = []
# Process each row
for idx, row in df.iterrows():
row_date = pd.to_datetime(row[date_column])
# Check if holiday
is_holiday, holiday_name = self._is_date_in_holiday_period(
row_date.date(),
holiday_periods
)
school_holidays.append(1 if is_holiday else 0)
holiday_names.append(holiday_name)
                # Check if during school hours (requires a real time component;
                # pd.to_datetime maps date-only values to midnight, and hasattr(row_date, 'hour')
                # is always True for a Timestamp, so test the time itself)
                if row_date.time() != datetime.min.time():
hours_active.append(
1 if self._is_school_hours_active(row_date, school_hours) else 0
)
proximity_intensity.append(
self._calculate_school_proximity_intensity(row_date, school_hours)
)
else:
# Date only, no time component
hours_active.append(0)
proximity_intensity.append(0.0)
# Add features to dataframe
df["is_school_holiday"] = school_holidays
df["school_holiday_name"] = holiday_names
df["school_hours_active"] = hours_active
df["school_proximity_intensity"] = proximity_intensity
logger.info(
"Calendar features added successfully",
tenant_id=tenant_id,
holiday_periods_count=len(holiday_periods),
holidays_found=sum(school_holidays)
)
return df
except Exception as e:
logger.error(
"Error adding calendar features",
tenant_id=tenant_id,
error=str(e)
)
# Return df with default features on error
df["is_school_holiday"] = 0
df["school_holiday_name"] = None
df["school_hours_active"] = 0
df["school_proximity_intensity"] = 0.0
return df
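
# --- Illustrative sketch (not part of the original module) ---
# Minimal, standalone example of the calendar payload shape the helpers above assume
# (holiday_periods with ISO "YYYY-MM-DD" bounds, school_hours with "HH:MM" strings)
# and of the holiday check they perform. The sample calendar is hypothetical; the
# helper below mirrors _is_date_in_holiday_period rather than importing the service class.
if __name__ == "__main__":
    from datetime import datetime, date

    sample_calendar = {
        "holiday_periods": [
            {"name": "Semana Santa", "start_date": "2025-04-14", "end_date": "2025-04-21"},
        ],
        "school_hours": {
            "morning_start": "09:00",
            "morning_end": "14:00",
            "has_afternoon_session": False,
        },
    }

    def in_holiday_period(check_date: date, periods):
        # Same date-window logic as _is_date_in_holiday_period above
        for period in periods:
            start = datetime.strptime(period["start_date"], "%Y-%m-%d").date()
            end = datetime.strptime(period["end_date"], "%Y-%m-%d").date()
            if start <= check_date <= end:
                return True, period["name"]
        return False, None

    print(in_holiday_period(date(2025, 4, 16), sample_calendar["holiday_periods"]))
    # Expected: (True, 'Semana Santa')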

File diff suppressed because it is too large

View File

@@ -0,0 +1,355 @@
"""
Enhanced Feature Engineering for Hybrid Prophet + XGBoost Models
Adds lagged features, rolling statistics, and advanced interactions
"""
import pandas as pd
import numpy as np
from typing import Dict, List, Optional
import structlog
from shared.ml.feature_calculator import HistoricalFeatureCalculator
logger = structlog.get_logger()
class AdvancedFeatureEngineer:
"""
Advanced feature engineering for hybrid forecasting models.
Adds lagged features, rolling statistics, and complex interactions.
"""
def __init__(self):
self.feature_columns = []
self.feature_calculator = HistoricalFeatureCalculator()
def add_lagged_features(self, df: pd.DataFrame, lag_days: List[int] = None) -> pd.DataFrame:
"""
Add lagged demand features for capturing recent trends.
Uses shared feature calculator for consistency with prediction service.
Args:
df: DataFrame with 'quantity' column
lag_days: List of lag periods (default: [1, 7, 14])
Returns:
DataFrame with added lagged features
"""
if lag_days is None:
lag_days = [1, 7, 14]
# Use shared calculator for consistent lag calculation
df = self.feature_calculator.calculate_lag_features(
df,
lag_days=lag_days,
mode='training'
)
# Update feature columns list
for lag in lag_days:
col_name = f'lag_{lag}_day'
if col_name not in self.feature_columns:
self.feature_columns.append(col_name)
logger.info(f"Added {len(lag_days)} lagged features (using shared calculator)", lags=lag_days)
return df
def add_rolling_features(
self,
df: pd.DataFrame,
windows: List[int] = None,
features: List[str] = None
) -> pd.DataFrame:
"""
Add rolling statistics (mean, std, max, min).
Uses shared feature calculator for consistency with prediction service.
Args:
df: DataFrame with 'quantity' column
windows: List of window sizes (default: [7, 14, 30])
features: List of statistics to calculate (default: ['mean', 'std', 'max', 'min'])
Returns:
DataFrame with rolling features
"""
if windows is None:
windows = [7, 14, 30]
if features is None:
features = ['mean', 'std', 'max', 'min']
# Use shared calculator for consistent rolling calculation
df = self.feature_calculator.calculate_rolling_features(
df,
windows=windows,
statistics=features,
mode='training'
)
# Update feature columns list
for window in windows:
for feature in features:
col_name = f'rolling_{feature}_{window}d'
if col_name not in self.feature_columns:
self.feature_columns.append(col_name)
logger.info(f"Added rolling features (using shared calculator)", windows=windows, features=features)
return df
def add_day_of_week_features(self, df: pd.DataFrame, date_column: str = 'date') -> pd.DataFrame:
"""
Add enhanced day-of-week features.
Args:
df: DataFrame with date column
date_column: Name of date column
Returns:
DataFrame with day-of-week features
"""
df = df.copy()
# Day of week (0=Monday, 6=Sunday)
df['day_of_week'] = df[date_column].dt.dayofweek
# Is weekend
df['is_weekend'] = (df['day_of_week'] >= 5).astype(int)
# Is Friday (often higher demand due to weekend prep)
df['is_friday'] = (df['day_of_week'] == 4).astype(int)
# Is Monday (often lower demand after weekend)
df['is_monday'] = (df['day_of_week'] == 0).astype(int)
# Add to feature list
for col in ['day_of_week', 'is_weekend', 'is_friday', 'is_monday']:
if col not in self.feature_columns:
self.feature_columns.append(col)
return df
def add_calendar_enhanced_features(self, df: pd.DataFrame, date_column: str = 'date') -> pd.DataFrame:
"""
Add enhanced calendar features beyond basic temporal features.
Args:
df: DataFrame with date column
date_column: Name of date column
Returns:
DataFrame with enhanced calendar features
"""
df = df.copy()
# Month and quarter (if not already present)
if 'month' not in df.columns:
df['month'] = df[date_column].dt.month
if 'quarter' not in df.columns:
df['quarter'] = df[date_column].dt.quarter
# Day of month
df['day_of_month'] = df[date_column].dt.day
# Is month start/end
df['is_month_start'] = (df['day_of_month'] <= 3).astype(int)
df['is_month_end'] = (df[date_column].dt.is_month_end).astype(int)
# Week of year
df['week_of_year'] = df[date_column].dt.isocalendar().week
# Payday indicators for Spain (high bakery traffic)
# Spain commonly pays on: 28th, 15th, or last day of month
df['is_payday'] = (
(df['day_of_month'] == 15) | # Mid-month payday
(df['day_of_month'] == 28) | # Common Spanish payday (28th)
df[date_column].dt.is_month_end # End of month
).astype(int)
# Add to feature list
for col in ['month', 'quarter', 'day_of_month', 'is_month_start', 'is_month_end',
'week_of_year', 'is_payday']:
if col not in self.feature_columns:
self.feature_columns.append(col)
return df
def add_interaction_features(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Add interaction features between variables.
Args:
df: DataFrame with base features
Returns:
DataFrame with interaction features
"""
df = df.copy()
# Weekend × Temperature (people buy more cold drinks in hot weekends)
if 'is_weekend' in df.columns and 'temperature' in df.columns:
df['weekend_temp_interaction'] = df['is_weekend'] * df['temperature']
self.feature_columns.append('weekend_temp_interaction')
# Rain × Weekend (bad weather reduces weekend traffic)
if 'is_weekend' in df.columns and 'precipitation' in df.columns:
df['rain_weekend_interaction'] = df['is_weekend'] * (df['precipitation'] > 0).astype(int)
self.feature_columns.append('rain_weekend_interaction')
# Friday × Traffic (high Friday traffic means weekend prep buying)
if 'is_friday' in df.columns and 'traffic_volume' in df.columns:
df['friday_traffic_interaction'] = df['is_friday'] * df['traffic_volume']
self.feature_columns.append('friday_traffic_interaction')
# Month × Temperature (seasonal temperature patterns)
if 'month' in df.columns and 'temperature' in df.columns:
df['month_temp_interaction'] = df['month'] * df['temperature']
self.feature_columns.append('month_temp_interaction')
# Payday × Weekend (big shopping days)
if 'is_payday' in df.columns and 'is_weekend' in df.columns:
df['payday_weekend_interaction'] = df['is_payday'] * df['is_weekend']
self.feature_columns.append('payday_weekend_interaction')
logger.info(f"Added {len([c for c in self.feature_columns if 'interaction' in c])} interaction features")
return df
def add_trend_features(self, df: pd.DataFrame, date_column: str = 'date') -> pd.DataFrame:
"""
Add trend-based features.
Uses shared feature calculator for consistency with prediction service.
Args:
df: DataFrame with date and quantity
date_column: Name of date column
Returns:
DataFrame with trend features
"""
# Use shared calculator for consistent trend calculation
df = self.feature_calculator.calculate_trend_features(
df,
mode='training'
)
# Update feature columns list
for feature_name in ['days_since_start', 'momentum_1_7', 'trend_7_30', 'velocity_week']:
if feature_name in df.columns and feature_name not in self.feature_columns:
self.feature_columns.append(feature_name)
logger.debug("Added trend features (using shared calculator)")
return df
def add_cyclical_encoding(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Add cyclical encoding for periodic features (day_of_week, month).
Helps models understand that Monday follows Sunday, December follows January.
Args:
df: DataFrame with day_of_week and month columns
Returns:
DataFrame with cyclical features
"""
df = df.copy()
# Day of week cyclical encoding
if 'day_of_week' in df.columns:
df['day_of_week_sin'] = np.sin(2 * np.pi * df['day_of_week'] / 7)
df['day_of_week_cos'] = np.cos(2 * np.pi * df['day_of_week'] / 7)
self.feature_columns.extend(['day_of_week_sin', 'day_of_week_cos'])
# Month cyclical encoding
if 'month' in df.columns:
df['month_sin'] = np.sin(2 * np.pi * df['month'] / 12)
df['month_cos'] = np.cos(2 * np.pi * df['month'] / 12)
self.feature_columns.extend(['month_sin', 'month_cos'])
logger.info("Added cyclical encoding for temporal features")
return df
def create_all_features(
self,
df: pd.DataFrame,
date_column: str = 'date',
include_lags: bool = True,
include_rolling: bool = True,
include_interactions: bool = True,
include_cyclical: bool = True
) -> pd.DataFrame:
"""
Create all enhanced features in one go.
Args:
df: DataFrame with base data
date_column: Name of date column
include_lags: Whether to include lagged features
include_rolling: Whether to include rolling statistics
include_interactions: Whether to include interaction features
include_cyclical: Whether to include cyclical encoding
Returns:
DataFrame with all enhanced features
"""
logger.info("Creating comprehensive feature set for hybrid model")
# Reset feature list
self.feature_columns = []
# Day of week and calendar features (always needed)
df = self.add_day_of_week_features(df, date_column)
df = self.add_calendar_enhanced_features(df, date_column)
# Optional features
if include_lags:
df = self.add_lagged_features(df)
if include_rolling:
df = self.add_rolling_features(df)
if include_interactions:
df = self.add_interaction_features(df)
if include_cyclical:
df = self.add_cyclical_encoding(df)
# Trend features (depends on lags and rolling)
if include_lags or include_rolling:
df = self.add_trend_features(df, date_column)
logger.info(f"Created {len(self.feature_columns)} enhanced features for hybrid model")
return df
def get_feature_columns(self) -> List[str]:
"""Get list of all created feature column names."""
return self.feature_columns.copy()
def fill_na_values(self, df: pd.DataFrame, strategy: str = 'forward_mean') -> pd.DataFrame:
"""
Fill NA values in lagged and rolling features.
IMPORTANT: Never uses backward fill to prevent data leakage in time series training.
Args:
df: DataFrame with potential NA values
strategy: 'forward_mean', 'zero', 'mean'
Returns:
DataFrame with filled NA values
"""
df = df.copy()
if strategy == 'forward_mean':
            # Forward fill first (use previous values only); .ffill() replaces the
            # deprecated fillna(method='ffill') call
            df = df.ffill()
            # Fill remaining NAs (typically at the beginning of the series) with column means.
            # NEVER use bfill, as it would leak future information into the training data.
            df = df.fillna(df.mean(numeric_only=True))
        elif strategy == 'zero':
            df = df.fillna(0)
        elif strategy == 'mean':
            df = df.fillna(df.mean(numeric_only=True))
return df
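
# --- Illustrative sketch (not part of the original module) ---
# A tiny demonstration of why fill_na_values() forward-fills and then falls back to
# the mean instead of back-filling: the NaNs introduced by lagging sit at the start
# of the series, and a backward fill would copy later (future) rows into them.
# The toy frame below is hypothetical and uses only the pandas import above.
if __name__ == "__main__":
    demo = pd.DataFrame({"quantity": [10.0, 12.0, 9.0, 11.0, 13.0]})
    demo["lag_2_day"] = demo["quantity"].shift(2)       # first two rows become NaN

    leakage_free = demo.ffill().fillna(demo.mean(numeric_only=True))
    leaky = demo.bfill()                                # copies future values backwards

    print(leakage_free["lag_2_day"].tolist())  # leading NaNs filled with the column mean
    print(leaky["lag_2_day"].tolist())         # leading rows now "see" later demand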

View File

@@ -0,0 +1,253 @@
"""
Event Feature Generator
Converts calendar events into features for demand forecasting
"""
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Optional
from datetime import date, timedelta
import structlog
logger = structlog.get_logger()
class EventFeatureGenerator:
"""
Generate event-related features for demand forecasting.
Features include:
- Binary flags for event presence
- Event impact multipliers
- Event type indicators
- Days until/since major events
"""
# Event type impact weights (default multipliers)
EVENT_IMPACT_WEIGHTS = {
'promotion': 1.3,
'festival': 1.8,
'holiday': 0.7, # Bakeries often close or have reduced demand
'weather_event': 0.8, # Bad weather reduces foot traffic
'school_break': 1.2,
'sport_event': 1.4,
'market': 1.5,
'concert': 1.3,
'local_event': 1.2
}
def __init__(self):
pass
def generate_event_features(
self,
dates: pd.DatetimeIndex,
events: List[Dict[str, Any]]
) -> pd.DataFrame:
"""
Generate event features for given dates.
Args:
dates: Dates to generate features for
events: List of event dictionaries with keys:
- event_date: date
- event_type: str
- impact_multiplier: float (optional)
- event_name: str
Returns:
DataFrame with event features
"""
df = pd.DataFrame({'date': dates})
# Initialize feature columns
df['has_event'] = 0
df['event_impact'] = 1.0 # Neutral impact
df['is_promotion'] = 0
df['is_festival'] = 0
df['is_local_event'] = 0
df['days_to_next_event'] = 365
df['days_since_last_event'] = 365
if not events:
logger.debug("No events provided, returning default features")
return df
# Convert events to DataFrame for easier processing
events_df = pd.DataFrame(events)
events_df['event_date'] = pd.to_datetime(events_df['event_date'])
for idx, row in df.iterrows():
current_date = pd.to_datetime(row['date'])
# Check if there's an event on this date
day_events = events_df[events_df['event_date'] == current_date]
if not day_events.empty:
df.at[idx, 'has_event'] = 1
# Use custom impact multiplier if provided, else use default
if 'impact_multiplier' in day_events.columns and not day_events['impact_multiplier'].isna().all():
impact = day_events['impact_multiplier'].max()
else:
# Use default impact based on event type
event_types = day_events['event_type'].tolist()
impacts = [self.EVENT_IMPACT_WEIGHTS.get(et, 1.0) for et in event_types]
impact = max(impacts)
df.at[idx, 'event_impact'] = impact
# Set event type flags
event_types = day_events['event_type'].tolist()
if 'promotion' in event_types:
df.at[idx, 'is_promotion'] = 1
if 'festival' in event_types:
df.at[idx, 'is_festival'] = 1
if 'local_event' in event_types or 'market' in event_types:
df.at[idx, 'is_local_event'] = 1
# Calculate days to/from nearest event
future_events = events_df[events_df['event_date'] > current_date]
if not future_events.empty:
next_event_date = future_events['event_date'].min()
df.at[idx, 'days_to_next_event'] = (next_event_date - current_date).days
past_events = events_df[events_df['event_date'] < current_date]
if not past_events.empty:
last_event_date = past_events['event_date'].max()
df.at[idx, 'days_since_last_event'] = (current_date - last_event_date).days
# Cap days values at 365
df['days_to_next_event'] = df['days_to_next_event'].clip(upper=365)
df['days_since_last_event'] = df['days_since_last_event'].clip(upper=365)
logger.debug("Generated event features",
total_days=len(df),
days_with_events=df['has_event'].sum())
return df
def add_event_features_to_forecast_data(
self,
forecast_data: pd.DataFrame,
event_features: pd.DataFrame
) -> pd.DataFrame:
"""
Add event features to forecast input data.
Args:
forecast_data: Existing forecast data with 'date' column
event_features: Event features from generate_event_features()
Returns:
Enhanced forecast data with event features
"""
forecast_data = forecast_data.copy()
forecast_data['date'] = pd.to_datetime(forecast_data['date'])
event_features['date'] = pd.to_datetime(event_features['date'])
# Merge event features
enhanced_data = forecast_data.merge(
event_features[[
'date', 'has_event', 'event_impact', 'is_promotion',
'is_festival', 'is_local_event', 'days_to_next_event',
'days_since_last_event'
]],
on='date',
how='left'
)
        # Fill missing values with defaults (column-wise fillna avoids the deprecated
        # chained `df[col].fillna(..., inplace=True)` pattern)
        enhanced_data = enhanced_data.fillna({
            'has_event': 0,
            'event_impact': 1.0,
            'is_promotion': 0,
            'is_festival': 0,
            'is_local_event': 0,
            'days_to_next_event': 365,
            'days_since_last_event': 365
        })
return enhanced_data
def get_event_summary(self, events: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Get summary statistics about events.
Args:
events: List of event dictionaries
Returns:
Summary dict with counts by type, avg impact, etc.
"""
if not events:
return {
'total_events': 0,
'events_by_type': {},
'avg_impact': 1.0
}
events_df = pd.DataFrame(events)
summary = {
'total_events': len(events),
'events_by_type': events_df['event_type'].value_counts().to_dict(),
'date_range': {
'start': events_df['event_date'].min().isoformat() if not events_df.empty else None,
'end': events_df['event_date'].max().isoformat() if not events_df.empty else None
}
}
if 'impact_multiplier' in events_df.columns:
summary['avg_impact'] = float(events_df['impact_multiplier'].mean())
return summary
def create_event_calendar_features(
dates: pd.DatetimeIndex,
tenant_id: str,
event_repository = None
) -> pd.DataFrame:
"""
Convenience function to fetch events from database and generate features.
Args:
dates: Dates to generate features for
tenant_id: Tenant UUID
event_repository: EventRepository instance (optional)
Returns:
DataFrame with event features
"""
if event_repository is None:
logger.warning("No event repository provided, using empty events")
events = []
else:
# Fetch events from database
from datetime import date
start_date = dates.min().date()
end_date = dates.max().date()
        try:
            import asyncio
            from uuid import UUID
            # Run the async repository call from this synchronous helper.
            # NOTE: asyncio.run() cannot be invoked from inside an already-running
            # event loop; in async contexts call the repository coroutine directly.
            events_objects = asyncio.run(
                event_repository.get_events_by_date_range(
                    tenant_id=UUID(tenant_id),
                    start_date=start_date,
                    end_date=end_date,
                    confirmed_only=False
                )
            )
# Convert to dict format
events = [event.to_dict() for event in events_objects]
except Exception as e:
logger.error(f"Failed to fetch events from database: {e}")
events = []
# Generate features
generator = EventFeatureGenerator()
return generator.generate_event_features(dates, events)
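
# --- Illustrative sketch (not part of the original module) ---
# Example usage of EventFeatureGenerator with a hypothetical two-event calendar,
# showing the binary flags, impact multipliers and days-to/since-event features
# described above. Event names, dates and multipliers are made up for the example.
if __name__ == "__main__":
    dates = pd.date_range("2025-06-01", periods=7, freq="D")
    sample_events = [
        {"event_date": "2025-06-03", "event_type": "promotion", "event_name": "2x1 croissants"},
        {"event_date": "2025-06-07", "event_type": "festival", "event_name": "Fiesta del barrio",
         "impact_multiplier": 2.0},
    ]

    generator = EventFeatureGenerator()
    features = generator.generate_event_features(dates, sample_events)
    print(features[["date", "has_event", "event_impact", "days_to_next_event"]])
    # 2025-06-03 -> has_event=1, event_impact=1.3 (default weight for 'promotion')
    # 2025-06-07 -> has_event=1, event_impact=2.0 (custom multiplier)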

View File

@@ -0,0 +1,463 @@
"""
Hybrid Prophet + XGBoost Trainer
Combines Prophet's seasonality modeling with XGBoost's pattern learning
"""
import pandas as pd
import numpy as np
import io
from typing import Dict, List, Any, Optional, Tuple
import structlog
from datetime import datetime, timezone
import joblib
from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_absolute_percentage_error
from sklearn.model_selection import TimeSeriesSplit
import warnings
warnings.filterwarnings('ignore')
# Import XGBoost
try:
import xgboost as xgb
except ImportError:
raise ImportError("XGBoost not installed. Run: pip install xgboost")
from app.ml.prophet_manager import BakeryProphetManager
from app.ml.enhanced_features import AdvancedFeatureEngineer
logger = structlog.get_logger()
class HybridProphetXGBoost:
"""
Hybrid forecasting model combining Prophet and XGBoost.
Approach:
1. Train Prophet on historical data (captures trend, seasonality, holidays)
2. Calculate residuals (actual - prophet_prediction)
3. Train XGBoost on residuals using enhanced features
4. Final prediction = prophet_prediction + xgboost_residual_prediction
Benefits:
- Prophet handles seasonality, holidays, trends
- XGBoost captures complex patterns Prophet misses
- Maintains Prophet's interpretability
- Improves accuracy by 10-25% over Prophet alone
"""
def __init__(self, database_manager=None):
self.prophet_manager = BakeryProphetManager(database_manager)
self.feature_engineer = AdvancedFeatureEngineer()
self.xgb_model = None
self.feature_columns = []
self.prophet_model_data = None
async def train_hybrid_model(
self,
tenant_id: str,
inventory_product_id: str,
df: pd.DataFrame,
job_id: str,
validation_split: float = 0.2,
session = None
) -> Dict[str, Any]:
"""
Train hybrid Prophet + XGBoost model.
Args:
tenant_id: Tenant identifier
inventory_product_id: Product identifier
df: Training data (must have 'ds', 'y' and regressor columns)
job_id: Training job identifier
validation_split: Fraction of data for validation
session: Optional database session (uses parent session if provided to avoid nested sessions)
Returns:
Dictionary with model metadata and performance metrics
"""
logger.info(
"Starting hybrid Prophet + XGBoost training",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
data_points=len(df)
)
# Step 1: Train Prophet model (base forecaster)
logger.info("Step 1: Training Prophet base model")
# ✅ FIX: Pass session to prophet_manager to avoid nested session issues
prophet_result = await self.prophet_manager.train_bakery_model(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
df=df.copy(),
job_id=job_id,
session=session
)
self.prophet_model_data = prophet_result
# Step 2: Create enhanced features for XGBoost
logger.info("Step 2: Engineering enhanced features for XGBoost")
df_enhanced = self._prepare_xgboost_features(df)
# Step 3: Split into train/validation
split_idx = int(len(df_enhanced) * (1 - validation_split))
train_df = df_enhanced.iloc[:split_idx].copy()
val_df = df_enhanced.iloc[split_idx:].copy()
logger.info(
"Data split",
train_samples=len(train_df),
val_samples=len(val_df)
)
# Step 4: Get Prophet predictions on training data
logger.info("Step 3: Generating Prophet predictions for residual calculation")
train_prophet_pred = await self._get_prophet_predictions(prophet_result, train_df)
val_prophet_pred = await self._get_prophet_predictions(prophet_result, val_df)
# Step 5: Calculate residuals (actual - prophet_prediction)
train_residuals = train_df['y'].values - train_prophet_pred
val_residuals = val_df['y'].values - val_prophet_pred
logger.info(
"Residuals calculated",
train_residual_mean=float(np.mean(train_residuals)),
train_residual_std=float(np.std(train_residuals))
)
# Step 6: Prepare feature matrix for XGBoost
X_train = train_df[self.feature_columns].values
X_val = val_df[self.feature_columns].values
# Step 7: Train XGBoost on residuals
logger.info("Step 4: Training XGBoost on residuals")
self.xgb_model = await self._train_xgboost(
X_train, train_residuals,
X_val, val_residuals
)
# Step 8: Evaluate hybrid model
logger.info("Step 5: Evaluating hybrid model performance")
metrics = await self._evaluate_hybrid_model(
train_df, val_df,
train_prophet_pred, val_prophet_pred,
prophet_result
)
# Step 9: Save hybrid model
model_data = self._package_hybrid_model(
prophet_result, metrics, tenant_id, inventory_product_id
)
logger.info(
"Hybrid model training complete",
prophet_mape=metrics['prophet_val_mape'],
hybrid_mape=metrics['hybrid_val_mape'],
improvement_pct=metrics['improvement_percentage']
)
return model_data
def _prepare_xgboost_features(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Prepare enhanced features for XGBoost.
Args:
df: Base dataframe with 'ds', 'y' and regressor columns
Returns:
DataFrame with all enhanced features
"""
# Rename 'ds' to 'date' for feature engineering
df_prep = df.copy()
if 'ds' in df_prep.columns:
df_prep['date'] = df_prep['ds']
# Ensure 'quantity' column for feature engineering
if 'y' in df_prep.columns:
df_prep['quantity'] = df_prep['y']
# Create all enhanced features
df_enhanced = self.feature_engineer.create_all_features(
df_prep,
date_column='date',
include_lags=True,
include_rolling=True,
include_interactions=True,
include_cyclical=True
)
# Fill NA values (from lagged features at beginning)
df_enhanced = self.feature_engineer.fill_na_values(df_enhanced)
# Get feature column list (excluding target and date columns)
self.feature_columns = [
col for col in self.feature_engineer.get_feature_columns()
if col in df_enhanced.columns
]
# Also include original regressor columns if present
regressor_cols = [
col for col in df.columns
if col not in ['ds', 'y', 'date', 'quantity'] and col in df_enhanced.columns
]
self.feature_columns.extend(regressor_cols)
self.feature_columns = list(set(self.feature_columns)) # Remove duplicates
logger.info(f"Prepared {len(self.feature_columns)} features for XGBoost")
return df_enhanced
async def _get_prophet_predictions(
self,
prophet_result: Dict[str, Any],
df: pd.DataFrame
) -> np.ndarray:
"""
Get Prophet predictions for given dataframe.
Args:
prophet_result: Prophet model result from training (contains model_path)
df: DataFrame with 'ds' column
Returns:
Array of predictions
"""
# Get the model path from result instead of expecting the model object directly
model_path = prophet_result.get('model_path')
if model_path is None:
raise ValueError("Prophet model path not found in result")
# Load the actual Prophet model from the stored path
try:
if model_path.startswith("minio://"):
# Use prophet_manager to load from MinIO
prophet_model = await self.prophet_manager._load_model_from_minio(model_path)
else:
# Fallback to direct loading for local paths
import joblib
prophet_model = joblib.load(model_path)
except Exception as e:
raise ValueError(f"Failed to load Prophet model from path {model_path}: {str(e)}")
# Prepare dataframe for prediction
pred_df = df[['ds']].copy()
# Add regressors if present
regressor_cols = [col for col in df.columns if col not in ['ds', 'y', 'date', 'quantity']]
for col in regressor_cols:
if col in df.columns:
pred_df[col] = df[col]
# Get predictions
forecast = prophet_model.predict(pred_df)
return forecast['yhat'].values
async def _train_xgboost(
self,
X_train: np.ndarray,
y_train: np.ndarray,
X_val: np.ndarray,
y_val: np.ndarray
) -> xgb.XGBRegressor:
"""
Train XGBoost model on residuals.
Args:
X_train: Training features
y_train: Training residuals
X_val: Validation features
y_val: Validation residuals
Returns:
Trained XGBoost model
"""
# XGBoost parameters optimized for residual learning
params = {
'n_estimators': 100,
'max_depth': 3, # Shallow trees to prevent overfitting
'learning_rate': 0.1,
'subsample': 0.8,
'colsample_bytree': 0.8,
'min_child_weight': 3,
'reg_alpha': 0.1, # L1 regularization
'reg_lambda': 1.0, # L2 regularization
'objective': 'reg:squarederror',
'random_state': 42,
'n_jobs': -1,
'early_stopping_rounds': 10
}
# Initialize model
model = xgb.XGBRegressor(**params)
# ✅ FIX: Run blocking model.fit() in thread pool to avoid blocking event loop
import asyncio
await asyncio.to_thread(
model.fit,
X_train, y_train,
eval_set=[(X_val, y_val)],
verbose=False
)
logger.info(
"XGBoost training complete",
best_iteration=model.best_iteration if hasattr(model, 'best_iteration') else None
)
return model
async def _evaluate_hybrid_model(
self,
train_df: pd.DataFrame,
val_df: pd.DataFrame,
train_prophet_pred: np.ndarray,
val_prophet_pred: np.ndarray,
prophet_result: Dict[str, Any]
) -> Dict[str, Any]:
"""
Evaluate the overall performance of the hybrid model using threading for metrics.
"""
import asyncio
# Get XGBoost predictions on training and validation
X_train = train_df[self.feature_columns].values
X_val = val_df[self.feature_columns].values
train_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_train)
val_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_val)
# Hybrid prediction = Prophet prediction + XGBoost residual prediction
train_hybrid_pred = train_prophet_pred + train_xgb_pred
val_hybrid_pred = val_prophet_pred + val_xgb_pred
actual_train = train_df['y'].values
actual_val = val_df['y'].values
# Basic RMSE calculation
train_rmse = float(np.sqrt(np.mean((actual_train - train_hybrid_pred)**2)))
val_rmse = float(np.sqrt(np.mean((actual_val - val_hybrid_pred)**2)))
# MAE
train_mae = float(np.mean(np.abs(actual_train - train_hybrid_pred)))
val_mae = float(np.mean(np.abs(actual_val - val_hybrid_pred)))
# MAPE (with safety for zero sales)
train_mape = float(np.mean(np.abs((actual_train - train_hybrid_pred) / np.maximum(actual_train, 1))))
val_mape = float(np.mean(np.abs((actual_val - val_hybrid_pred) / np.maximum(actual_val, 1))))
# Calculate improvement
prophet_metrics = prophet_result.get("metrics", {})
prophet_val_mae = prophet_metrics.get("val_mae", val_mae) # Fallback to hybrid if missing
prophet_val_mape = prophet_metrics.get("val_mape", val_mape)
improvement_pct = 0.0
if prophet_val_mape > 0:
improvement_pct = ((prophet_val_mape - val_mape) / prophet_val_mape) * 100
metrics = {
"train_rmse": train_rmse,
"val_rmse": val_rmse,
"train_mae": train_mae,
"val_mae": val_mae,
"train_mape": train_mape,
"val_mape": val_mape,
"prophet_val_mape": prophet_val_mape,
"hybrid_val_mape": val_mape,
"improvement_percentage": float(improvement_pct),
"prophet_metrics": prophet_metrics
}
logger.info(
"Hybrid model evaluation complete",
val_rmse=val_rmse,
val_mae=val_mae,
val_mape=val_mape,
improvement=improvement_pct
)
return metrics
def _package_hybrid_model(
self,
prophet_result: Dict[str, Any],
metrics: Dict[str, Any],
tenant_id: str,
inventory_product_id: str
) -> Dict[str, Any]:
"""
Package hybrid model for storage.
"""
return {
'model_type': 'hybrid_prophet_xgboost',
'prophet_model_path': prophet_result.get('model_path'),
'xgboost_model': self.xgb_model,
'feature_columns': self.feature_columns,
'metrics': metrics,
'tenant_id': tenant_id,
'inventory_product_id': inventory_product_id,
'trained_at': datetime.now(timezone.utc).isoformat()
}
async def predict(
self,
future_df: pd.DataFrame,
model_data: Dict[str, Any]
) -> pd.DataFrame:
"""
Make predictions using hybrid model.
Args:
future_df: DataFrame with future dates and regressors
model_data: Loaded hybrid model data
Returns:
DataFrame with predictions
"""
# Step 1: Get Prophet model from path and make predictions
prophet_model_path = model_data.get('prophet_model_path')
if prophet_model_path is None:
raise ValueError("Prophet model path not found in model data")
# Load the Prophet model from the stored path
try:
if prophet_model_path.startswith("minio://"):
# Use prophet_manager to load from MinIO
prophet_model = await self.prophet_manager._load_model_from_minio(prophet_model_path)
else:
# Fallback to direct loading for local paths
import joblib
prophet_model = joblib.load(prophet_model_path)
except Exception as e:
raise ValueError(f"Failed to load Prophet model from path {prophet_model_path}: {str(e)}")
# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
import asyncio
prophet_forecast = await asyncio.to_thread(prophet_model.predict, future_df)
# Step 2: Prepare features for XGBoost
future_enhanced = self._prepare_xgboost_features(future_df)
# Step 3: Get XGBoost predictions
xgb_model = model_data['xgboost_model']
feature_columns = model_data['feature_columns']
X_future = future_enhanced[feature_columns].values
# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
xgb_pred = await asyncio.to_thread(xgb_model.predict, X_future)
# Step 4: Combine predictions
hybrid_pred = prophet_forecast['yhat'].values + xgb_pred
# Step 5: Create result dataframe
result = pd.DataFrame({
'ds': future_df['ds'],
'prophet_yhat': prophet_forecast['yhat'],
'xgb_adjustment': xgb_pred,
'yhat': hybrid_pred,
'yhat_lower': prophet_forecast['yhat_lower'] + xgb_pred,
'yhat_upper': prophet_forecast['yhat_upper'] + xgb_pred
})
return result
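
# --- Illustrative sketch (not part of the original module) ---
# Numerical toy example of the residual-stacking idea used by HybridProphetXGBoost:
# a base forecast (standing in for Prophet's yhat) plus an XGBoost model trained on
# the residuals of that forecast. All data below is synthetic; it only illustrates
# the "final prediction = base prediction + residual correction" combination step.
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    n = 200
    X = rng.normal(size=(n, 3))                          # stand-in feature matrix
    base_pred = 50 + 5 * np.sin(np.arange(n) / 7.0)      # smooth "Prophet-like" forecast
    actual = base_pred + 3.0 * X[:, 0] - 2.0 * X[:, 1] + rng.normal(scale=0.5, size=n)

    residuals = actual - base_pred                       # what the base model misses
    residual_model = xgb.XGBRegressor(n_estimators=50, max_depth=3, learning_rate=0.1)
    residual_model.fit(X, residuals)

    hybrid_pred = base_pred + residual_model.predict(X)
    print("base MAE:  ", float(np.mean(np.abs(actual - base_pred))))
    print("hybrid MAE:", float(np.mean(np.abs(actual - hybrid_pred))))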

View File

@@ -0,0 +1,257 @@
"""
Model Selection System
Determines whether to use Prophet-only or Hybrid Prophet+XGBoost models
"""
import pandas as pd
import numpy as np
from typing import Dict, Any, Optional
import structlog
logger = structlog.get_logger()
class ModelSelector:
"""
Intelligent model selection based on data characteristics.
Decision Criteria:
- Data size: Hybrid needs more data (min 90 days)
- Complexity: High variance benefits from XGBoost
- Seasonality strength: Weak seasonality benefits from XGBoost
- Historical performance: Compare models on validation set
"""
# Thresholds for model selection
MIN_DATA_POINTS_HYBRID = 90 # Minimum data points for hybrid
HIGH_VARIANCE_THRESHOLD = 0.5 # CV > 0.5 suggests complex patterns
LOW_SEASONALITY_THRESHOLD = 0.3 # Weak seasonal patterns
HYBRID_IMPROVEMENT_THRESHOLD = 0.05 # 5% MAPE improvement to justify hybrid
def __init__(self):
pass
def select_model_type(
self,
df: pd.DataFrame,
product_category: str = "unknown",
force_prophet: bool = False,
force_hybrid: bool = False
) -> str:
"""
Select best model type based on data characteristics.
Args:
df: Training data with 'y' column
product_category: Product category (bread, pastries, etc.)
force_prophet: Force Prophet-only model
force_hybrid: Force hybrid model
Returns:
"prophet" or "hybrid"
"""
# Honor forced selections
if force_prophet:
logger.info("Prophet-only model forced by configuration")
return "prophet"
if force_hybrid:
logger.info("Hybrid model forced by configuration")
return "hybrid"
# Check minimum data requirements
if len(df) < self.MIN_DATA_POINTS_HYBRID:
logger.info(
"Insufficient data for hybrid model, using Prophet",
data_points=len(df),
min_required=self.MIN_DATA_POINTS_HYBRID
)
return "prophet"
# Calculate data characteristics
characteristics = self._analyze_data_characteristics(df)
# Decision logic
score_hybrid = 0
score_prophet = 0
# Factor 1: Data complexity (variance)
if characteristics['coefficient_of_variation'] > self.HIGH_VARIANCE_THRESHOLD:
score_hybrid += 2
logger.debug("High variance detected, favoring hybrid", cv=characteristics['coefficient_of_variation'])
else:
score_prophet += 1
# Factor 2: Seasonality strength
if characteristics['seasonality_strength'] < self.LOW_SEASONALITY_THRESHOLD:
score_hybrid += 2
logger.debug("Weak seasonality detected, favoring hybrid", strength=characteristics['seasonality_strength'])
else:
score_prophet += 1
# Factor 3: Data size (more data = better for hybrid)
if len(df) > 180:
score_hybrid += 1
elif len(df) < 120:
score_prophet += 1
# Factor 4: Product category considerations
if product_category in ['seasonal', 'cakes']:
# Event-driven products benefit from XGBoost pattern learning
score_hybrid += 1
elif product_category in ['bread', 'savory']:
# Stable products work well with Prophet
score_prophet += 1
# Factor 5: Zero ratio (sparse data)
if characteristics['zero_ratio'] > 0.3:
# High zero ratio suggests difficult forecasting, hybrid might help
score_hybrid += 1
# Make decision
selected_model = "hybrid" if score_hybrid > score_prophet else "prophet"
logger.info(
"Model selection complete",
selected_model=selected_model,
score_hybrid=score_hybrid,
score_prophet=score_prophet,
data_points=len(df),
cv=characteristics['coefficient_of_variation'],
seasonality=characteristics['seasonality_strength'],
category=product_category
)
return selected_model
def _analyze_data_characteristics(self, df: pd.DataFrame) -> Dict[str, float]:
"""
Analyze time series characteristics.
Args:
df: DataFrame with 'y' column (sales data)
Returns:
Dictionary with data characteristics
"""
y = df['y'].values
# Coefficient of variation
cv = np.std(y) / np.mean(y) if np.mean(y) > 0 else 0
# Zero ratio
zero_ratio = (y == 0).sum() / len(y)
# Seasonality strength using autocorrelation at key lags (7 days, 30 days)
# This better captures periodic patterns without using future data
if len(df) >= 14:
# Calculate autocorrelation at weekly lag (7 days)
# Higher autocorrelation indicates stronger weekly patterns
try:
weekly_autocorr = pd.Series(y).autocorr(lag=7) if len(y) > 7 else 0
# Calculate autocorrelation at monthly lag if enough data
monthly_autocorr = pd.Series(y).autocorr(lag=30) if len(y) > 30 else 0
# Combine autocorrelations (weekly weighted more for bakery data)
seasonality_strength = abs(weekly_autocorr) * 0.7 + abs(monthly_autocorr) * 0.3
# Ensure in valid range [0, 1]
seasonality_strength = max(0.0, min(1.0, seasonality_strength))
except Exception:
# Fallback to simpler calculation if autocorrelation fails
seasonality_strength = 0.5
else:
seasonality_strength = 0.5 # Default
# Trend strength
if len(df) >= 30:
from scipy import stats
x = np.arange(len(y))
slope, _, r_value, _, _ = stats.linregress(x, y)
trend_strength = abs(r_value)
else:
trend_strength = 0
return {
'coefficient_of_variation': float(cv),
'zero_ratio': float(zero_ratio),
'seasonality_strength': float(seasonality_strength),
'trend_strength': float(trend_strength),
'mean': float(np.mean(y)),
'std': float(np.std(y))
}
def compare_models(
self,
prophet_metrics: Dict[str, float],
hybrid_metrics: Dict[str, float]
) -> str:
"""
Compare Prophet and Hybrid model performance.
Args:
prophet_metrics: Prophet model metrics (with 'mape' key)
hybrid_metrics: Hybrid model metrics (with 'mape' key)
Returns:
"prophet" or "hybrid" based on better performance
"""
prophet_mape = prophet_metrics.get('mape', float('inf'))
hybrid_mape = hybrid_metrics.get('mape', float('inf'))
# Calculate improvement
if prophet_mape > 0:
improvement = (prophet_mape - hybrid_mape) / prophet_mape
else:
improvement = 0
# Hybrid must improve by at least threshold to justify complexity
if improvement >= self.HYBRID_IMPROVEMENT_THRESHOLD:
logger.info(
"Hybrid model selected based on performance",
prophet_mape=prophet_mape,
hybrid_mape=hybrid_mape,
improvement=f"{improvement*100:.1f}%"
)
return "hybrid"
else:
logger.info(
"Prophet model selected (hybrid improvement insufficient)",
prophet_mape=prophet_mape,
hybrid_mape=hybrid_mape,
improvement=f"{improvement*100:.1f}%"
)
return "prophet"
def should_use_hybrid_model(
df: pd.DataFrame,
product_category: str = "unknown",
tenant_settings: Dict[str, Any] = None
) -> bool:
"""
Convenience function to determine if hybrid model should be used.
Args:
df: Training data
product_category: Product category
tenant_settings: Optional tenant-specific settings
Returns:
True if hybrid model should be used, False otherwise
"""
selector = ModelSelector()
# Check tenant settings
force_prophet = tenant_settings.get('force_prophet_only', False) if tenant_settings else False
force_hybrid = tenant_settings.get('force_hybrid', False) if tenant_settings else False
selected = selector.select_model_type(
df=df,
product_category=product_category,
force_prophet=force_prophet,
force_hybrid=force_hybrid
)
return selected == "hybrid"
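
# --- Illustrative sketch (not part of the original module) ---
# Example of the selection heuristic on two synthetic series: a stable, strongly
# weekly-seasonal product (expected to stay on Prophet) and a noisy, weakly seasonal
# one (expected to score towards the hybrid model). The series and category labels
# below are made up purely to exercise select_model_type().
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    days = 200
    weekly = np.tile([30, 32, 31, 33, 35, 50, 45], days // 7 + 1)[:days]

    stable_df = pd.DataFrame({"y": weekly + rng.normal(scale=1.0, size=days)})
    noisy_df = pd.DataFrame({"y": np.clip(rng.normal(20, 15, size=days), 0, None)})

    selector = ModelSelector()
    print("stable product ->", selector.select_model_type(stable_df, product_category="bread"))
    print("noisy product  ->", selector.select_model_type(noisy_df, product_category="cakes"))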

View File

@@ -0,0 +1,192 @@
"""
POI Feature Integrator
Integrates POI features into ML training pipeline.
Fetches POI context from External service and merges features into training data.
"""
from typing import Dict, Any, Optional, List
import structlog
import pandas as pd
from shared.clients.external_client import ExternalServiceClient
logger = structlog.get_logger()
class POIFeatureIntegrator:
"""
POI feature integration for ML training.
Fetches POI context from External service and adds features
to training dataframes for location-based demand forecasting.
"""
def __init__(self, external_client: ExternalServiceClient = None):
"""
Initialize POI feature integrator.
Args:
external_client: External service client instance (optional)
"""
if external_client is None:
from app.core.config import settings
self.external_client = ExternalServiceClient(settings, "training-service")
else:
self.external_client = external_client
async def fetch_poi_features(
self,
tenant_id: str,
latitude: float,
longitude: float,
force_refresh: bool = False
) -> Optional[Dict[str, Any]]:
"""
Fetch POI features for tenant location (optimized for training).
First checks if POI context exists. If not, returns None without triggering detection.
POI detection should be triggered during tenant registration, not during training.
Args:
tenant_id: Tenant UUID
latitude: Bakery latitude
longitude: Bakery longitude
force_refresh: Force re-detection (only use if POI context already exists)
Returns:
Dictionary with POI features or None if not available
"""
try:
# Try to get existing POI context first
existing_context = await self.external_client.get_poi_context(tenant_id)
if existing_context:
poi_context = existing_context.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
# Check if stale and force_refresh is requested
is_stale = existing_context.get("is_stale", False)
if not is_stale or not force_refresh:
logger.info(
"Using existing POI context",
tenant_id=tenant_id,
is_stale=is_stale,
feature_count=len(ml_features)
)
return ml_features
else:
logger.info(
"POI context is stale and force_refresh=True, refreshing",
tenant_id=tenant_id
)
# Only refresh if explicitly requested and context exists
detection_result = await self.external_client.detect_poi_for_tenant(
tenant_id=tenant_id,
latitude=latitude,
longitude=longitude,
force_refresh=True
)
if detection_result:
poi_context = detection_result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
"POI refresh completed",
tenant_id=tenant_id,
feature_count=len(ml_features)
)
return ml_features
else:
logger.warning(
"POI refresh failed, returning existing features",
tenant_id=tenant_id
)
return ml_features
else:
logger.info(
"No existing POI context found - POI detection should be triggered during tenant registration",
tenant_id=tenant_id
)
return None
except Exception as e:
logger.warning(
"Error fetching POI features - returning None",
tenant_id=tenant_id,
error=str(e)
)
return None
def add_poi_features_to_dataframe(
self,
df: pd.DataFrame,
poi_features: Dict[str, Any]
) -> pd.DataFrame:
"""
Add POI features to training dataframe.
POI features are static (don't vary by date), so they're
broadcast to all rows in the dataframe.
Args:
df: Training dataframe
poi_features: Dictionary of POI ML features
Returns:
Dataframe with POI features added as columns
"""
if not poi_features:
logger.warning("No POI features to add")
return df
logger.info(
"Adding POI features to dataframe",
feature_count=len(poi_features),
dataframe_rows=len(df)
)
# Add each POI feature as a column with constant value
for feature_name, feature_value in poi_features.items():
df[feature_name] = feature_value
logger.info(
"POI features added successfully",
new_columns=list(poi_features.keys())
)
return df
def get_poi_feature_names(self, poi_features: Dict[str, Any]) -> List[str]:
"""
Get list of POI feature names for model registration.
Args:
poi_features: Dictionary of POI ML features
Returns:
List of feature names
"""
return list(poi_features.keys()) if poi_features else []
async def check_poi_service_health(self) -> bool:
"""
Check if POI service is accessible through the external client.
Returns:
True if service is healthy, False otherwise
"""
try:
# We can test the external service health by attempting to get POI context for a dummy tenant
# This will go through the proper authentication and routing
dummy_context = await self.external_client.get_poi_context("test-tenant")
# If we can successfully make a request (even if it returns None for missing tenant),
# it means the service is accessible
return True
except Exception as e:
logger.error(
"POI service health check failed",
error=str(e)
)
return False
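
# --- Illustrative sketch (not part of the original module) ---
# POI features are static per location, so integration is a plain column broadcast,
# mirroring what add_poi_features_to_dataframe() does above. The feature names and
# values below are hypothetical; no External service call is made here.
if __name__ == "__main__":
    training_df = pd.DataFrame({
        "date": pd.date_range("2025-05-01", periods=3, freq="D"),
        "quantity": [42, 38, 51],
    })
    sample_poi_features = {
        "poi_schools_within_500m": 2,
        "poi_offices_within_500m": 14,
        "poi_competition_density": 0.35,
    }
    for name, value in sample_poi_features.items():
        training_df[name] = value            # constant value broadcast to every row
    print(training_df)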

View File

@@ -0,0 +1,361 @@
"""
Product Categorization System
Classifies bakery products into categories for category-specific forecasting
"""
import pandas as pd
import numpy as np
from typing import Dict, List, Optional, Tuple
from enum import Enum
import structlog
logger = structlog.get_logger()
class ProductCategory(str, Enum):
"""Product categories for bakery items"""
BREAD = "bread"
PASTRIES = "pastries"
CAKES = "cakes"
DRINKS = "drinks"
SEASONAL = "seasonal"
SAVORY = "savory"
UNKNOWN = "unknown"
class ProductCategorizer:
"""
Automatic product categorization based on product name and sales patterns.
Categories have different characteristics:
- BREAD: Daily staple, high volume, consistent demand, short shelf life (1 day)
- PASTRIES: Morning peak, weekend boost, medium shelf life (2-3 days)
- CAKES: Event-driven, weekends, advance orders, longer shelf life (3-5 days)
- DRINKS: Weather-dependent, hot/cold seasonal patterns
- SEASONAL: Holiday-specific (roscón, panettone, etc.)
- SAVORY: Lunch peak, weekday focus
"""
def __init__(self):
# Keywords for automatic classification
self.category_keywords = {
ProductCategory.BREAD: [
'pan', 'baguette', 'hogaza', 'chapata', 'integral', 'centeno',
'bread', 'loaf', 'barra', 'molde', 'candeal'
],
ProductCategory.PASTRIES: [
'croissant', 'napolitana', 'palmera', 'ensaimada', 'magdalena',
'bollo', 'brioche', 'suizo', 'caracola', 'donut', 'berlina'
],
ProductCategory.CAKES: [
'tarta', 'pastel', 'bizcocho', 'cake', 'torta', 'milhojas',
'saint honoré', 'selva negra', 'tres leches'
],
ProductCategory.DRINKS: [
                'café', 'coffee', 'té', 'tea', 'zumo', 'juice', 'batido',
'smoothie', 'refresco', 'agua', 'water'
],
ProductCategory.SEASONAL: [
'roscón', 'panettone', 'turrón', 'polvorón', 'mona de pascua',
'huevo de pascua', 'buñuelo', 'torrija'
],
ProductCategory.SAVORY: [
'empanada', 'quiche', 'pizza', 'focaccia', 'salado', 'bocadillo',
'sandwich', 'croqueta', 'hojaldre salado'
]
}
def categorize_product(
self,
product_name: str,
product_id: str = None,
sales_data: pd.DataFrame = None
) -> ProductCategory:
"""
Categorize a product based on name and optional sales patterns.
Args:
product_name: Product name
product_id: Optional product ID
sales_data: Optional historical sales data for pattern analysis
Returns:
ProductCategory enum
"""
# First try keyword matching
category = self._categorize_by_keywords(product_name)
if category != ProductCategory.UNKNOWN:
logger.info(f"Product categorized by keywords",
product=product_name,
category=category.value)
return category
# If no keyword match and we have sales data, analyze patterns
if sales_data is not None and len(sales_data) > 30:
category = self._categorize_by_sales_pattern(product_name, sales_data)
logger.info(f"Product categorized by sales pattern",
product=product_name,
category=category.value)
return category
logger.warning(f"Could not categorize product, using UNKNOWN",
product=product_name)
return ProductCategory.UNKNOWN
def _categorize_by_keywords(self, product_name: str) -> ProductCategory:
"""Categorize by matching keywords in product name"""
product_name_lower = product_name.lower()
# Check each category's keywords
for category, keywords in self.category_keywords.items():
for keyword in keywords:
if keyword in product_name_lower:
return category
return ProductCategory.UNKNOWN
def _categorize_by_sales_pattern(
self,
product_name: str,
sales_data: pd.DataFrame
) -> ProductCategory:
"""
Categorize by analyzing sales patterns.
Patterns:
- BREAD: Consistent daily sales, low variance
- PASTRIES: Weekend boost, morning peak
- CAKES: Weekend spike, event correlation
- DRINKS: Temperature correlation
- SEASONAL: Concentrated in specific months
- SAVORY: Weekday focus, lunch peak
"""
try:
# Ensure we have required columns
if 'date' not in sales_data.columns or 'quantity' not in sales_data.columns:
return ProductCategory.UNKNOWN
sales_data = sales_data.copy()
sales_data['date'] = pd.to_datetime(sales_data['date'])
sales_data['day_of_week'] = sales_data['date'].dt.dayofweek
sales_data['month'] = sales_data['date'].dt.month
sales_data['is_weekend'] = sales_data['day_of_week'].isin([5, 6])
# Calculate pattern metrics
weekend_avg = sales_data[sales_data['is_weekend']]['quantity'].mean()
weekday_avg = sales_data[~sales_data['is_weekend']]['quantity'].mean()
overall_avg = sales_data['quantity'].mean()
cv = sales_data['quantity'].std() / overall_avg if overall_avg > 0 else 0
# Weekend ratio
weekend_ratio = weekend_avg / weekday_avg if weekday_avg > 0 else 1.0
# Seasonal concentration (Gini coefficient for months)
monthly_sales = sales_data.groupby('month')['quantity'].sum()
seasonal_concentration = self._gini_coefficient(monthly_sales.values)
# Decision rules based on patterns
if seasonal_concentration > 0.6:
# High concentration in specific months = seasonal
return ProductCategory.SEASONAL
elif cv < 0.3 and weekend_ratio < 1.2:
# Low variance, consistent daily = bread
return ProductCategory.BREAD
elif weekend_ratio > 1.5:
# Strong weekend boost = cakes
return ProductCategory.CAKES
elif weekend_ratio > 1.2:
# Moderate weekend boost = pastries
return ProductCategory.PASTRIES
elif weekend_ratio < 0.9:
# Weekday focus = savory
return ProductCategory.SAVORY
else:
return ProductCategory.UNKNOWN
except Exception as e:
logger.error(f"Error analyzing sales pattern: {e}")
return ProductCategory.UNKNOWN
def _gini_coefficient(self, values: np.ndarray) -> float:
"""Calculate Gini coefficient for concentration measurement"""
if len(values) == 0:
return 0.0
sorted_values = np.sort(values)
n = len(values)
cumsum = np.cumsum(sorted_values)
# Gini coefficient formula
return (2 * np.sum((np.arange(1, n + 1) * sorted_values))) / (n * cumsum[-1]) - (n + 1) / n
def get_category_characteristics(self, category: ProductCategory) -> Dict[str, any]:
"""
Get forecasting characteristics for a category.
Returns hyperparameters and settings specific to the category.
"""
characteristics = {
ProductCategory.BREAD: {
"shelf_life_days": 1,
"demand_stability": "high",
"seasonality_strength": "low",
"weekend_factor": 0.95, # Slightly lower on weekends
"holiday_factor": 0.7, # Much lower on holidays
"weather_sensitivity": "low",
"prophet_params": {
"seasonality_mode": "additive",
"yearly_seasonality": False,
"weekly_seasonality": True,
"daily_seasonality": False,
"changepoint_prior_scale": 0.01, # Very stable
"seasonality_prior_scale": 5.0
}
},
ProductCategory.PASTRIES: {
"shelf_life_days": 2,
"demand_stability": "medium",
"seasonality_strength": "medium",
"weekend_factor": 1.3, # Boost on weekends
"holiday_factor": 1.1, # Slight boost on holidays
"weather_sensitivity": "medium",
"prophet_params": {
"seasonality_mode": "multiplicative",
"yearly_seasonality": True,
"weekly_seasonality": True,
"daily_seasonality": False,
"changepoint_prior_scale": 0.05,
"seasonality_prior_scale": 10.0
}
},
ProductCategory.CAKES: {
"shelf_life_days": 4,
"demand_stability": "low",
"seasonality_strength": "high",
"weekend_factor": 2.0, # Large weekend boost
"holiday_factor": 1.5, # Holiday boost
"weather_sensitivity": "low",
"prophet_params": {
"seasonality_mode": "multiplicative",
"yearly_seasonality": True,
"weekly_seasonality": True,
"daily_seasonality": False,
"changepoint_prior_scale": 0.1, # More flexible
"seasonality_prior_scale": 15.0
}
},
ProductCategory.DRINKS: {
"shelf_life_days": 1,
"demand_stability": "medium",
"seasonality_strength": "high",
"weekend_factor": 1.1,
"holiday_factor": 1.2,
"weather_sensitivity": "very_high",
"prophet_params": {
"seasonality_mode": "multiplicative",
"yearly_seasonality": True,
"weekly_seasonality": True,
"daily_seasonality": False,
"changepoint_prior_scale": 0.08,
"seasonality_prior_scale": 12.0
}
},
ProductCategory.SEASONAL: {
"shelf_life_days": 7,
"demand_stability": "very_low",
"seasonality_strength": "very_high",
"weekend_factor": 1.2,
"holiday_factor": 3.0, # Massive holiday boost
"weather_sensitivity": "low",
"prophet_params": {
"seasonality_mode": "multiplicative",
"yearly_seasonality": True,
"weekly_seasonality": False,
"daily_seasonality": False,
"changepoint_prior_scale": 0.2, # Very flexible
"seasonality_prior_scale": 20.0
}
},
ProductCategory.SAVORY: {
"shelf_life_days": 1,
"demand_stability": "medium",
"seasonality_strength": "low",
"weekend_factor": 0.8, # Lower on weekends
"holiday_factor": 0.6, # Much lower on holidays
"weather_sensitivity": "medium",
"prophet_params": {
"seasonality_mode": "additive",
"yearly_seasonality": False,
"weekly_seasonality": True,
"daily_seasonality": False,
"changepoint_prior_scale": 0.03,
"seasonality_prior_scale": 7.0
}
},
ProductCategory.UNKNOWN: {
"shelf_life_days": 2,
"demand_stability": "medium",
"seasonality_strength": "medium",
"weekend_factor": 1.0,
"holiday_factor": 1.0,
"weather_sensitivity": "medium",
"prophet_params": {
"seasonality_mode": "multiplicative",
"yearly_seasonality": True,
"weekly_seasonality": True,
"daily_seasonality": False,
"changepoint_prior_scale": 0.05,
"seasonality_prior_scale": 10.0
}
}
}
return characteristics.get(category, characteristics[ProductCategory.UNKNOWN])
def batch_categorize(
self,
products: List[Dict[str, any]],
sales_data: pd.DataFrame = None
) -> Dict[str, ProductCategory]:
"""
Categorize multiple products at once.
Args:
products: List of dicts with 'id' and 'name' keys
sales_data: Optional sales data with 'inventory_product_id' column
Returns:
Dict mapping product_id to category
"""
results = {}
for product in products:
product_id = product.get('id')
product_name = product.get('name', '')
# Filter sales data for this product if available
product_sales = None
if sales_data is not None and 'inventory_product_id' in sales_data.columns:
product_sales = sales_data[
sales_data['inventory_product_id'] == product_id
].copy()
category = self.categorize_product(
product_name=product_name,
product_id=product_id,
sales_data=product_sales
)
results[product_id] = category
logger.info(f"Batch categorization complete",
total_products=len(products),
categories=dict(pd.Series(list(results.values())).value_counts()))
return results
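
# --- Illustrative sketch (not part of the original module) ---
# Keyword-based categorization of a few hypothetical Spanish product names, plus a
# lookup of the category-specific Prophet settings returned by
# get_category_characteristics(). The product names are examples only.
if __name__ == "__main__":
    categorizer = ProductCategorizer()
    for name in ["Barra de pan candeal", "Croissant de mantequilla", "Tarta de queso", "Roscón de Reyes"]:
        category = categorizer.categorize_product(product_name=name)
        print(f"{name:30s} -> {category.value}")

    bread_profile = categorizer.get_category_characteristics(ProductCategory.BREAD)
    print("BREAD changepoint_prior_scale:",
          bread_profile["prophet_params"]["changepoint_prior_scale"])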

File diff suppressed because it is too large

View File

@@ -0,0 +1,284 @@
"""
Traffic Forecasting System
Predicts bakery foot traffic using weather and temporal features
"""
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional
from prophet import Prophet
import structlog
from datetime import datetime, timedelta
logger = structlog.get_logger()
class TrafficForecaster:
"""
Forecast bakery foot traffic using Prophet with weather and temporal features.
Traffic patterns are influenced by:
- Weather: Temperature, precipitation, conditions
- Time: Day of week, holidays, season
- Special events: Local events, promotions
"""
def __init__(self):
self.model = None
self.is_trained = False
def train(
self,
historical_traffic: pd.DataFrame,
weather_data: pd.DataFrame = None
) -> Dict[str, Any]:
"""
Train traffic forecasting model.
Args:
historical_traffic: DataFrame with columns ['date', 'traffic_count']
weather_data: Optional weather data with columns ['date', 'temperature', 'precipitation', 'condition']
Returns:
Training metrics
"""
try:
logger.info("Training traffic forecasting model",
data_points=len(historical_traffic))
# Prepare Prophet format
df = historical_traffic.copy()
df = df.rename(columns={'date': 'ds', 'traffic_count': 'y'})
df['ds'] = pd.to_datetime(df['ds'])
df = df.sort_values('ds')
# Merge with weather data if available
if weather_data is not None:
weather_data = weather_data.copy()
weather_data['date'] = pd.to_datetime(weather_data['date'])
df = df.merge(weather_data, left_on='ds', right_on='date', how='left')
# Create Prophet model with custom settings for traffic
self.model = Prophet(
seasonality_mode='multiplicative',
yearly_seasonality=True,
weekly_seasonality=True,
daily_seasonality=False,
changepoint_prior_scale=0.05, # Moderate flexibility
seasonality_prior_scale=10.0,
holidays_prior_scale=10.0
)
# Add weather regressors if available
if 'temperature' in df.columns:
self.model.add_regressor('temperature')
if 'precipitation' in df.columns:
self.model.add_regressor('precipitation')
if 'is_rainy' in df.columns:
self.model.add_regressor('is_rainy')
            # Add Spanish national holidays using Prophet's built-in country calendar
            # (the _get_spanish_holidays helper below remains available for custom holiday frames)
            self.model.add_country_holidays(country_name='ES')
# Fit model
self.model.fit(df)
self.is_trained = True
# Calculate training metrics
predictions = self.model.predict(df)
metrics = self._calculate_metrics(df['y'].values, predictions['yhat'].values)
logger.info("Traffic forecasting model trained successfully",
mape=metrics['mape'],
rmse=metrics['rmse'])
return metrics
except Exception as e:
logger.error(f"Failed to train traffic forecasting model: {e}")
raise
def predict(
self,
future_dates: pd.DatetimeIndex,
weather_forecast: pd.DataFrame = None
) -> pd.DataFrame:
"""
Predict traffic for future dates.
Args:
future_dates: Dates to predict traffic for
weather_forecast: Optional weather forecast data
Returns:
DataFrame with columns ['date', 'predicted_traffic', 'yhat_lower', 'yhat_upper']
"""
if not self.is_trained:
raise ValueError("Model not trained. Call train() first.")
try:
# Create future dataframe
future = pd.DataFrame({'ds': future_dates})
# Add weather features if available
if weather_forecast is not None:
weather_forecast = weather_forecast.copy()
weather_forecast['date'] = pd.to_datetime(weather_forecast['date'])
future = future.merge(weather_forecast, left_on='ds', right_on='date', how='left')
                # Fill missing weather values with neutral defaults
                # (column assignment instead of the deprecated chained inplace fillna)
                for col, default in (('temperature', 15.0), ('precipitation', 0.0), ('is_rainy', 0)):
                    if col in future.columns:
                        future[col] = future[col].fillna(default)
# Predict
forecast = self.model.predict(future)
# Format results
results = pd.DataFrame({
'date': forecast['ds'],
'predicted_traffic': forecast['yhat'].clip(lower=0), # Traffic can't be negative
'yhat_lower': forecast['yhat_lower'].clip(lower=0),
'yhat_upper': forecast['yhat_upper'].clip(lower=0)
})
logger.info("Traffic predictions generated",
dates=len(results),
avg_traffic=results['predicted_traffic'].mean())
return results
except Exception as e:
logger.error(f"Failed to predict traffic: {e}")
raise
def _calculate_metrics(self, actual: np.ndarray, predicted: np.ndarray) -> Dict[str, float]:
"""Calculate forecast accuracy metrics"""
mae = np.mean(np.abs(actual - predicted))
mse = np.mean((actual - predicted) ** 2)
rmse = np.sqrt(mse)
# MAPE (handle zeros)
mask = actual != 0
mape = np.mean(np.abs((actual[mask] - predicted[mask]) / actual[mask])) * 100 if mask.any() else 0
return {
'mae': float(mae),
'mse': float(mse),
'rmse': float(rmse),
'mape': float(mape)
}
def _get_spanish_holidays(self, start_year: int, end_year: int) -> pd.DataFrame:
"""Get Spanish holidays for the date range"""
try:
import holidays
es_holidays = holidays.Spain(years=range(start_year, end_year + 1))
holiday_dates = []
holiday_names = []
for date, name in es_holidays.items():
holiday_dates.append(date)
holiday_names.append(name)
return pd.DataFrame({
'ds': pd.to_datetime(holiday_dates),
'holiday': holiday_names
})
except Exception as e:
logger.warning(f"Could not load Spanish holidays: {e}")
return pd.DataFrame(columns=['ds', 'holiday'])
class TrafficFeatureGenerator:
"""
Generate traffic-related features for demand forecasting.
Uses predicted traffic as a feature in product demand models.
"""
def __init__(self, traffic_forecaster: TrafficForecaster = None):
self.traffic_forecaster = traffic_forecaster or TrafficForecaster()
def generate_traffic_features(
self,
dates: pd.DatetimeIndex,
weather_forecast: pd.DataFrame = None
) -> pd.DataFrame:
"""
Generate traffic features for given dates.
Args:
dates: Dates to generate features for
weather_forecast: Optional weather forecast
Returns:
DataFrame with traffic features
"""
if not self.traffic_forecaster.is_trained:
logger.warning("Traffic forecaster not trained, using default traffic values")
return pd.DataFrame({
'date': dates,
'predicted_traffic': 100.0, # Default baseline
'traffic_normalized': 1.0
})
# Predict traffic
traffic_predictions = self.traffic_forecaster.predict(dates, weather_forecast)
# Normalize traffic (0-2 range, 1 = average)
mean_traffic = traffic_predictions['predicted_traffic'].mean()
traffic_predictions['traffic_normalized'] = (
traffic_predictions['predicted_traffic'] / mean_traffic
).clip(0, 2)
# Add traffic categories
traffic_predictions['traffic_category'] = pd.cut(
traffic_predictions['predicted_traffic'],
bins=[0, 50, 100, 150, np.inf],
labels=['low', 'medium', 'high', 'very_high']
)
return traffic_predictions
def add_traffic_features_to_forecast_data(
self,
forecast_data: pd.DataFrame,
traffic_predictions: pd.DataFrame
) -> pd.DataFrame:
"""
Add traffic features to forecast input data.
Args:
forecast_data: Existing forecast data with 'date' column
traffic_predictions: Traffic predictions from generate_traffic_features()
Returns:
Enhanced forecast data with traffic features
"""
forecast_data = forecast_data.copy()
forecast_data['date'] = pd.to_datetime(forecast_data['date'])
traffic_predictions['date'] = pd.to_datetime(traffic_predictions['date'])
# Merge traffic features
enhanced_data = forecast_data.merge(
traffic_predictions[['date', 'predicted_traffic', 'traffic_normalized']],
on='date',
how='left'
)
        # Fill missing with defaults (assignment form avoids chained inplace fillna)
        enhanced_data['predicted_traffic'] = enhanced_data['predicted_traffic'].fillna(100.0)
        enhanced_data['traffic_normalized'] = enhanced_data['traffic_normalized'].fillna(1.0)
return enhanced_data
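
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows how
# the two classes above are intended to fit together, using synthetic
# placeholder data rather than real bakery traffic or weather history.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd

    days = pd.date_range("2024-01-01", periods=180, freq="D")
    rng = np.random.default_rng(0)
    history = pd.DataFrame({
        "date": days,
        # Synthetic traffic: weekend uplift plus noise
        "traffic_count": 120 + 30 * (days.dayofweek >= 5) + rng.normal(0, 10, len(days)),
    })

    forecaster = TrafficForecaster()
    print("training metrics:", forecaster.train(history))

    future_dates = pd.date_range(days[-1] + pd.Timedelta(days=1), periods=14, freq="D")
    generator = TrafficFeatureGenerator(forecaster)
    traffic_features = generator.generate_traffic_features(future_dates)

    forecast_input = generator.add_traffic_features_to_forecast_data(
        pd.DataFrame({"date": future_dates}), traffic_features
    )
    print(forecast_input.head())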

File diff suppressed because it is too large

View File

@@ -0,0 +1,33 @@
"""
Training Service Models Package
Import all models to ensure they are registered with SQLAlchemy Base.
"""
# Import AuditLog model for this service
from shared.security import create_audit_log_model
from shared.database.base import Base
# Create audit log model for this service
AuditLog = create_audit_log_model(Base)
# Import all models to register them with the Base metadata
from .training import (
TrainedModel,
ModelTrainingLog,
ModelPerformanceMetric,
TrainingJobQueue,
ModelArtifact,
TrainingPerformanceMetrics,
)
# List all models for easier access
__all__ = [
"TrainedModel",
"ModelTrainingLog",
"ModelPerformanceMetric",
"TrainingJobQueue",
"ModelArtifact",
"TrainingPerformanceMetrics",
"AuditLog",
]

View File

@@ -0,0 +1,254 @@
# services/training/app/models/training.py
"""
Database models for training service
"""
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, JSON, Float
from sqlalchemy.dialects.postgresql import UUID, ARRAY
from shared.database.base import Base
from datetime import datetime, timezone
import uuid
class ModelTrainingLog(Base):
"""
Table to track training job execution and status.
Replaces the old Celery task tracking.
"""
__tablename__ = "model_training_logs"
id = Column(Integer, primary_key=True, index=True)
job_id = Column(String(255), unique=True, index=True, nullable=False)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
status = Column(String(50), nullable=False, default="pending") # pending, running, completed, failed, cancelled
progress = Column(Integer, default=0) # 0-100 percentage
current_step = Column(String(500), default="")
# Timestamps
start_time = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
end_time = Column(DateTime(timezone=True), nullable=True)
# Configuration and results
config = Column(JSON, nullable=True) # Training job configuration
results = Column(JSON, nullable=True) # Training results
error_message = Column(Text, nullable=True)
# Metadata
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
class ModelPerformanceMetric(Base):
"""
Table to track model performance over time.
"""
__tablename__ = "model_performance_metrics"
id = Column(Integer, primary_key=True, index=True)
model_id = Column(String(255), index=True, nullable=False)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
inventory_product_id = Column(UUID(as_uuid=True), index=True, nullable=False)
# Performance metrics
mae = Column(Float, nullable=True) # Mean Absolute Error
mse = Column(Float, nullable=True) # Mean Squared Error
rmse = Column(Float, nullable=True) # Root Mean Squared Error
mape = Column(Float, nullable=True) # Mean Absolute Percentage Error
r2_score = Column(Float, nullable=True) # R-squared score
# Additional metrics
accuracy_percentage = Column(Float, nullable=True)
prediction_confidence = Column(Float, nullable=True)
# Evaluation information
evaluation_period_start = Column(DateTime, nullable=True)
evaluation_period_end = Column(DateTime, nullable=True)
evaluation_samples = Column(Integer, nullable=True)
# Metadata
measured_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
class TrainingJobQueue(Base):
"""
Table to manage training job queue and scheduling.
"""
__tablename__ = "training_job_queue"
id = Column(Integer, primary_key=True, index=True)
job_id = Column(String(255), unique=True, index=True, nullable=False)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# Job configuration
job_type = Column(String(50), nullable=False) # full_training, single_product, evaluation
priority = Column(Integer, default=1) # Higher number = higher priority
config = Column(JSON, nullable=True)
# Scheduling information
scheduled_at = Column(DateTime, nullable=True)
started_at = Column(DateTime, nullable=True)
estimated_duration_minutes = Column(Integer, nullable=True)
# Status
status = Column(String(50), nullable=False, default="queued") # queued, running, completed, failed
retry_count = Column(Integer, default=0)
max_retries = Column(Integer, default=3)
# Metadata
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
cancelled_by = Column(String, nullable=True)
class ModelArtifact(Base):
"""
Table to track model files and artifacts.
"""
__tablename__ = "model_artifacts"
id = Column(Integer, primary_key=True, index=True)
model_id = Column(String(255), index=True, nullable=False)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# Artifact information
artifact_type = Column(String(50), nullable=False) # model_file, metadata, training_data, etc.
file_path = Column(String(1000), nullable=False)
file_size_bytes = Column(Integer, nullable=True)
checksum = Column(String(255), nullable=True) # For file integrity
# Storage information
storage_location = Column(String(100), nullable=False, default="local") # local, s3, gcs, etc.
compression = Column(String(50), nullable=True) # gzip, lz4, etc.
# Metadata
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
expires_at = Column(DateTime(timezone=True), nullable=True) # For automatic cleanup
class TrainedModel(Base):
__tablename__ = "trained_models"
# Primary identification - Updated to use UUID properly
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
inventory_product_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# Model information
model_type = Column(String, default="prophet_optimized")
model_version = Column(String, default="1.0")
job_id = Column(String, nullable=False)
# File storage
model_path = Column(String, nullable=False) # Path to the .pkl file
metadata_path = Column(String) # Path to metadata JSON
# Training metrics
mape = Column(Float)
mae = Column(Float)
rmse = Column(Float)
r2_score = Column(Float)
training_samples = Column(Integer)
# Hyperparameters and features
hyperparameters = Column(JSON) # Store optimized parameters
features_used = Column(JSON) # List of regressor columns
normalization_params = Column(JSON) # Store feature normalization parameters for consistent predictions
product_category = Column(String, nullable=True) # Product category for category-specific forecasting
# Model status
is_active = Column(Boolean, default=True)
is_production = Column(Boolean, default=False)
# Timestamps - Updated to be timezone-aware with proper defaults
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
last_used_at = Column(DateTime(timezone=True))
# Training data info
training_start_date = Column(DateTime(timezone=True))
training_end_date = Column(DateTime(timezone=True))
data_quality_score = Column(Float)
# Additional metadata
notes = Column(Text)
created_by = Column(String) # User who triggered training
def to_dict(self):
return {
"id": str(self.id),
"model_id": str(self.id),
"tenant_id": str(self.tenant_id),
"inventory_product_id": str(self.inventory_product_id),
"model_type": self.model_type,
"model_version": self.model_version,
"model_path": self.model_path,
"mape": self.mape,
"mae": self.mae,
"rmse": self.rmse,
"r2_score": self.r2_score,
"training_samples": self.training_samples,
"hyperparameters": self.hyperparameters,
"features_used": self.features_used,
"features": self.features_used, # Alias for frontend compatibility (ModelDetailsModal expects 'features')
"product_category": self.product_category,
"is_active": self.is_active,
"is_production": self.is_production,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
"training_start_date": self.training_start_date.isoformat() if self.training_start_date else None,
"training_end_date": self.training_end_date.isoformat() if self.training_end_date else None,
"data_quality_score": self.data_quality_score
}
class TrainingPerformanceMetrics(Base):
"""
Table to track historical training performance for time estimation.
Stores aggregated metrics from completed training jobs.
"""
__tablename__ = "training_performance_metrics"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
job_id = Column(String(255), nullable=False, index=True)
# Training job statistics
total_products = Column(Integer, nullable=False)
successful_products = Column(Integer, nullable=False)
failed_products = Column(Integer, nullable=False)
# Time metrics
total_duration_seconds = Column(Float, nullable=False)
avg_time_per_product = Column(Float, nullable=False) # Key metric for estimation
data_analysis_time_seconds = Column(Float, nullable=True)
training_time_seconds = Column(Float, nullable=True)
finalization_time_seconds = Column(Float, nullable=True)
# Job metadata
completed_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
def __repr__(self):
return (
f"<TrainingPerformanceMetrics("
f"tenant_id={self.tenant_id}, "
f"job_id={self.job_id}, "
f"total_products={self.total_products}, "
f"avg_time_per_product={self.avg_time_per_product:.2f}s"
f")>"
)
def to_dict(self):
return {
"id": str(self.id),
"tenant_id": str(self.tenant_id),
"job_id": self.job_id,
"total_products": self.total_products,
"successful_products": self.successful_products,
"failed_products": self.failed_products,
"total_duration_seconds": self.total_duration_seconds,
"avg_time_per_product": self.avg_time_per_product,
"data_analysis_time_seconds": self.data_analysis_time_seconds,
"training_time_seconds": self.training_time_seconds,
"finalization_time_seconds": self.finalization_time_seconds,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"created_at": self.created_at.isoformat() if self.created_at else None
}
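
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how a TrainedModel
# row is typically populated after a training run and serialized for an API
# response. All identifiers and paths below are placeholders; column defaults
# (id, model_type, timestamps) are only applied when the row is inserted.
if __name__ == "__main__":
    example = TrainedModel(
        tenant_id=uuid.uuid4(),
        inventory_product_id=uuid.uuid4(),
        job_id="job-example-001",
        model_path="/models/example/prophet_model.pkl",
        mape=14.2,
        mae=3.1,
        rmse=4.7,
        training_samples=365,
        hyperparameters={"changepoint_prior_scale": 0.05},
        features_used=["temperature", "is_rainy"],
    )
    print(example.to_dict())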

View File

@@ -0,0 +1,11 @@
# services/training/app/models/training_models.py
"""
Legacy file - TrainedModel has been moved to training.py
This file is deprecated and should be removed after migration.
"""
# Import the actual model from the correct location
from .training import TrainedModel
# For backward compatibility, re-export the model
__all__ = ["TrainedModel"]

View File

@@ -0,0 +1,20 @@
"""
Training Service Repositories
Repository implementations for training service
"""
from .base import TrainingBaseRepository
from .model_repository import ModelRepository
from .training_log_repository import TrainingLogRepository
from .performance_repository import PerformanceRepository
from .job_queue_repository import JobQueueRepository
from .artifact_repository import ArtifactRepository
__all__ = [
"TrainingBaseRepository",
"ModelRepository",
"TrainingLogRepository",
"PerformanceRepository",
"JobQueueRepository",
"ArtifactRepository"
]

View File

@@ -0,0 +1,560 @@
"""
Artifact Repository
Repository for model artifact operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelArtifact
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class ArtifactRepository(TrainingBaseRepository):
"""Repository for model artifact operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 1800):
# Artifacts are stable, longer cache time (30 minutes)
super().__init__(ModelArtifact, session, cache_ttl)
async def create_artifact(self, artifact_data: Dict[str, Any]) -> ModelArtifact:
"""Create a new model artifact record"""
try:
# Validate artifact data
validation_result = self._validate_training_data(
artifact_data,
["model_id", "tenant_id", "artifact_type", "file_path"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid artifact data: {validation_result['errors']}")
# Set default values
if "storage_location" not in artifact_data:
artifact_data["storage_location"] = "local"
# Create artifact record
artifact = await self.create(artifact_data)
logger.info("Model artifact created",
model_id=artifact.model_id,
tenant_id=artifact.tenant_id,
artifact_type=artifact.artifact_type,
file_path=artifact.file_path)
return artifact
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create model artifact",
model_id=artifact_data.get("model_id"),
error=str(e))
raise DatabaseError(f"Failed to create artifact: {str(e)}")
async def get_artifacts_by_model(
self,
model_id: str,
artifact_type: str = None
) -> List[ModelArtifact]:
"""Get all artifacts for a model"""
try:
filters = {"model_id": model_id}
if artifact_type:
filters["artifact_type"] = artifact_type
return await self.get_multi(
filters=filters,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get artifacts by model",
model_id=model_id,
artifact_type=artifact_type,
error=str(e))
raise DatabaseError(f"Failed to get artifacts: {str(e)}")
async def get_artifacts_by_tenant(
self,
tenant_id: str,
artifact_type: str = None,
skip: int = 0,
limit: int = 100
) -> List[ModelArtifact]:
"""Get artifacts for a tenant"""
try:
filters = {"tenant_id": tenant_id}
if artifact_type:
filters["artifact_type"] = artifact_type
return await self.get_multi(
filters=filters,
skip=skip,
limit=limit,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get artifacts by tenant",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get tenant artifacts: {str(e)}")
async def get_artifact_by_path(self, file_path: str) -> Optional[ModelArtifact]:
"""Get artifact by file path"""
try:
return await self.get_by_field("file_path", file_path)
except Exception as e:
logger.error("Failed to get artifact by path",
file_path=file_path,
error=str(e))
raise DatabaseError(f"Failed to get artifact: {str(e)}")
async def update_artifact_size(self, artifact_id: int, file_size_bytes: int) -> Optional[ModelArtifact]:
"""Update artifact file size"""
try:
return await self.update(artifact_id, {"file_size_bytes": file_size_bytes})
except Exception as e:
logger.error("Failed to update artifact size",
artifact_id=artifact_id,
error=str(e))
return None
async def update_artifact_checksum(self, artifact_id: int, checksum: str) -> Optional[ModelArtifact]:
"""Update artifact checksum for integrity verification"""
try:
return await self.update(artifact_id, {"checksum": checksum})
except Exception as e:
logger.error("Failed to update artifact checksum",
artifact_id=artifact_id,
error=str(e))
return None
async def mark_artifact_expired(self, artifact_id: int, expires_at: datetime = None) -> Optional[ModelArtifact]:
"""Mark artifact for expiration/cleanup"""
try:
if not expires_at:
expires_at = datetime.now()
return await self.update(artifact_id, {"expires_at": expires_at})
except Exception as e:
logger.error("Failed to mark artifact as expired",
artifact_id=artifact_id,
error=str(e))
return None
async def get_expired_artifacts(self, days_expired: int = 0) -> List[ModelArtifact]:
"""Get artifacts that have expired"""
try:
cutoff_date = datetime.now() - timedelta(days=days_expired)
query_text = """
SELECT * FROM model_artifacts
WHERE expires_at IS NOT NULL
AND expires_at <= :cutoff_date
ORDER BY expires_at ASC
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
expired_artifacts = []
for row in result.fetchall():
record_dict = dict(row._mapping)
artifact = self.model(**record_dict)
expired_artifacts.append(artifact)
return expired_artifacts
except Exception as e:
logger.error("Failed to get expired artifacts",
days_expired=days_expired,
error=str(e))
return []
async def cleanup_expired_artifacts(self, days_expired: int = 0) -> int:
"""Clean up expired artifacts"""
try:
cutoff_date = datetime.now() - timedelta(days=days_expired)
query_text = """
DELETE FROM model_artifacts
WHERE expires_at IS NOT NULL
AND expires_at <= :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info("Cleaned up expired artifacts",
deleted_count=deleted_count,
days_expired=days_expired)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup expired artifacts",
days_expired=days_expired,
error=str(e))
raise DatabaseError(f"Artifact cleanup failed: {str(e)}")
async def get_large_artifacts(self, min_size_mb: int = 100) -> List[ModelArtifact]:
"""Get artifacts larger than specified size"""
try:
min_size_bytes = min_size_mb * 1024 * 1024 # Convert MB to bytes
query_text = """
SELECT * FROM model_artifacts
WHERE file_size_bytes >= :min_size_bytes
ORDER BY file_size_bytes DESC
"""
result = await self.session.execute(text(query_text), {"min_size_bytes": min_size_bytes})
large_artifacts = []
for row in result.fetchall():
record_dict = dict(row._mapping)
artifact = self.model(**record_dict)
large_artifacts.append(artifact)
return large_artifacts
except Exception as e:
logger.error("Failed to get large artifacts",
min_size_mb=min_size_mb,
error=str(e))
return []
async def get_artifacts_by_storage_location(
self,
storage_location: str,
tenant_id: str = None
) -> List[ModelArtifact]:
"""Get artifacts by storage location"""
try:
filters = {"storage_location": storage_location}
if tenant_id:
filters["tenant_id"] = tenant_id
return await self.get_multi(
filters=filters,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get artifacts by storage location",
storage_location=storage_location,
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get artifacts: {str(e)}")
async def get_artifact_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get artifact statistics"""
try:
base_filters = {}
if tenant_id:
base_filters["tenant_id"] = tenant_id
# Get basic counts
total_artifacts = await self.count(filters=base_filters)
# Get artifacts by type
type_query_params = {}
type_query_filter = ""
if tenant_id:
type_query_filter = "WHERE tenant_id = :tenant_id"
type_query_params["tenant_id"] = tenant_id
type_query = text(f"""
SELECT artifact_type, COUNT(*) as count
FROM model_artifacts
{type_query_filter}
GROUP BY artifact_type
ORDER BY count DESC
""")
result = await self.session.execute(type_query, type_query_params)
artifacts_by_type = {row.artifact_type: row.count for row in result.fetchall()}
# Get storage location stats
location_query = text(f"""
SELECT
storage_location,
COUNT(*) as count,
SUM(COALESCE(file_size_bytes, 0)) as total_size_bytes
FROM model_artifacts
{type_query_filter}
GROUP BY storage_location
ORDER BY count DESC
""")
location_result = await self.session.execute(location_query, type_query_params)
storage_stats = {}
total_size_bytes = 0
for row in location_result.fetchall():
storage_stats[row.storage_location] = {
"artifact_count": row.count,
"total_size_bytes": int(row.total_size_bytes or 0),
"total_size_mb": round((row.total_size_bytes or 0) / (1024 * 1024), 2)
}
total_size_bytes += row.total_size_bytes or 0
# Get expired artifacts count
expired_artifacts = len(await self.get_expired_artifacts())
return {
"total_artifacts": total_artifacts,
"expired_artifacts": expired_artifacts,
"active_artifacts": total_artifacts - expired_artifacts,
"artifacts_by_type": artifacts_by_type,
"storage_statistics": storage_stats,
"total_storage": {
"total_size_bytes": total_size_bytes,
"total_size_mb": round(total_size_bytes / (1024 * 1024), 2),
"total_size_gb": round(total_size_bytes / (1024 * 1024 * 1024), 2)
}
}
except Exception as e:
logger.error("Failed to get artifact statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_artifacts": 0,
"expired_artifacts": 0,
"active_artifacts": 0,
"artifacts_by_type": {},
"storage_statistics": {},
"total_storage": {
"total_size_bytes": 0,
"total_size_mb": 0.0,
"total_size_gb": 0.0
}
}
async def verify_artifact_integrity(self, artifact_id: int) -> Dict[str, Any]:
"""Verify artifact file integrity with actual file system checks"""
try:
import os
import hashlib
artifact = await self.get_by_id(artifact_id)
if not artifact:
return {"exists": False, "error": "Artifact not found"}
# Check if file exists
file_exists = os.path.exists(artifact.file_path)
if not file_exists:
return {
"artifact_id": artifact_id,
"file_path": artifact.file_path,
"exists": False,
"checksum_valid": False,
"size_valid": False,
"storage_location": artifact.storage_location,
"last_verified": datetime.now().isoformat(),
"error": "File does not exist on disk"
}
# Verify file size
actual_size = os.path.getsize(artifact.file_path)
size_valid = True
if artifact.file_size_bytes:
size_valid = (actual_size == artifact.file_size_bytes)
# Verify checksum if stored
checksum_valid = True
actual_checksum = None
if artifact.checksum:
# Calculate checksum of actual file
sha256_hash = hashlib.sha256()
try:
with open(artifact.file_path, "rb") as f:
# Read file in chunks to handle large files
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
actual_checksum = sha256_hash.hexdigest()
checksum_valid = (actual_checksum == artifact.checksum)
except Exception as checksum_error:
logger.error(f"Failed to calculate checksum: {checksum_error}")
checksum_valid = False
actual_checksum = None
# Overall integrity status
integrity_valid = file_exists and size_valid and checksum_valid
result = {
"artifact_id": artifact_id,
"file_path": artifact.file_path,
"exists": file_exists,
"checksum_valid": checksum_valid,
"size_valid": size_valid,
"integrity_valid": integrity_valid,
"storage_location": artifact.storage_location,
"last_verified": datetime.now().isoformat(),
"details": {
"stored_size_bytes": artifact.file_size_bytes,
"actual_size_bytes": actual_size if file_exists else None,
"stored_checksum": artifact.checksum,
"actual_checksum": actual_checksum
}
}
if not integrity_valid:
issues = []
if not file_exists:
issues.append("file_missing")
if not size_valid:
issues.append("size_mismatch")
if not checksum_valid:
issues.append("checksum_mismatch")
result["issues"] = issues
return result
except Exception as e:
logger.error("Failed to verify artifact integrity",
artifact_id=artifact_id,
error=str(e))
return {
"exists": False,
"error": f"Verification failed: {str(e)}"
}
async def migrate_artifacts_to_storage(
self,
from_location: str,
to_location: str,
tenant_id: str = None,
copy_only: bool = False,
verify: bool = True
) -> Dict[str, Any]:
"""Migrate artifacts from one storage location to another with actual file operations"""
try:
import os
import shutil
import hashlib
# Get artifacts to migrate
artifacts = await self.get_artifacts_by_storage_location(from_location, tenant_id)
migrated_count = 0
failed_count = 0
failed_artifacts = []
verified_count = 0
for artifact in artifacts:
try:
# Determine new file path
new_file_path = artifact.file_path.replace(from_location, to_location, 1)
# Create destination directory if it doesn't exist
dest_dir = os.path.dirname(new_file_path)
os.makedirs(dest_dir, exist_ok=True)
# Check if source file exists
if not os.path.exists(artifact.file_path):
logger.warning(f"Source file not found: {artifact.file_path}")
failed_count += 1
failed_artifacts.append({
"artifact_id": artifact.id,
"file_path": artifact.file_path,
"reason": "source_file_not_found"
})
continue
# Copy or move file
if copy_only:
shutil.copy2(artifact.file_path, new_file_path)
logger.debug(f"Copied file from {artifact.file_path} to {new_file_path}")
else:
shutil.move(artifact.file_path, new_file_path)
logger.debug(f"Moved file from {artifact.file_path} to {new_file_path}")
# Verify file was copied/moved successfully
if verify and os.path.exists(new_file_path):
# Verify file size
new_size = os.path.getsize(new_file_path)
if artifact.file_size_bytes and new_size != artifact.file_size_bytes:
logger.warning(f"File size mismatch after migration: {new_file_path}")
failed_count += 1
failed_artifacts.append({
"artifact_id": artifact.id,
"file_path": new_file_path,
"reason": "size_mismatch_after_migration"
})
continue
# Verify checksum if available
if artifact.checksum:
sha256_hash = hashlib.sha256()
with open(new_file_path, "rb") as f:
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
new_checksum = sha256_hash.hexdigest()
if new_checksum != artifact.checksum:
logger.warning(f"Checksum mismatch after migration: {new_file_path}")
failed_count += 1
failed_artifacts.append({
"artifact_id": artifact.id,
"file_path": new_file_path,
"reason": "checksum_mismatch_after_migration"
})
continue
verified_count += 1
# Update database with new location
await self.update(artifact.id, {
"storage_location": to_location,
"file_path": new_file_path
})
migrated_count += 1
except Exception as migration_error:
logger.error("Failed to migrate artifact",
artifact_id=artifact.id,
error=str(migration_error))
failed_count += 1
failed_artifacts.append({
"artifact_id": artifact.id,
"file_path": artifact.file_path,
"reason": str(migration_error)
})
logger.info("Artifact migration completed",
from_location=from_location,
to_location=to_location,
migrated_count=migrated_count,
failed_count=failed_count,
verified_count=verified_count)
return {
"from_location": from_location,
"to_location": to_location,
"total_artifacts": len(artifacts),
"migrated_count": migrated_count,
"failed_count": failed_count,
"verified_count": verified_count if verify else None,
"success_rate": round((migrated_count / len(artifacts)) * 100, 2) if artifacts else 100,
"copy_only": copy_only,
"failed_artifacts": failed_artifacts if failed_artifacts else None
}
except Exception as e:
logger.error("Failed to migrate artifacts",
from_location=from_location,
to_location=to_location,
error=str(e))
return {
"error": f"Migration failed: {str(e)}"
}
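
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal
# maintenance routine driven by this repository. The caller is assumed to
# provide an open AsyncSession; the artifact id below is a placeholder.
async def _example_artifact_maintenance(session: AsyncSession) -> None:
    repo = ArtifactRepository(session)

    # Verify a single artifact on disk (existence, size and sha256 checksum)
    report = await repo.verify_artifact_integrity(artifact_id=42)
    if not report.get("integrity_valid", False):
        logger.warning("Artifact integrity issues", issues=report.get("issues"))

    # Delete records for artifacts whose expires_at has already passed
    deleted = await repo.cleanup_expired_artifacts(days_expired=0)
    logger.info("Expired artifacts removed", deleted_count=deleted)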

View File

@@ -0,0 +1,179 @@
"""
Base Repository for Training Service
Service-specific repository base class with training service utilities
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timezone, timedelta
import structlog
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class TrainingBaseRepository(BaseRepository):
"""Base repository for training service with common training operations"""
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Training data changes frequently, shorter cache time (5 minutes)
super().__init__(model, session, cache_ttl)
async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
"""Get records by tenant ID"""
if hasattr(self.model, 'tenant_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"tenant_id": tenant_id},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
"""Get active records (if model has is_active field)"""
if hasattr(self.model, 'is_active'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"is_active": True},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def get_by_job_id(self, job_id: str) -> Optional:
"""Get record by job ID (if model has job_id field)"""
if hasattr(self.model, 'job_id'):
return await self.get_by_field("job_id", job_id)
return None
async def get_by_model_id(self, model_id: str) -> Optional:
"""Get record by model ID (if model has model_id field)"""
if hasattr(self.model, 'model_id'):
return await self.get_by_field("model_id", model_id)
return None
async def deactivate_record(self, record_id: Any) -> Optional:
"""Deactivate a record instead of deleting it"""
if hasattr(self.model, 'is_active'):
return await self.update(record_id, {"is_active": False})
return await self.delete(record_id)
async def activate_record(self, record_id: Any) -> Optional:
"""Activate a record"""
if hasattr(self.model, 'is_active'):
return await self.update(record_id, {"is_active": True})
return await self.get_by_id(record_id)
async def cleanup_old_records(self, days_old: int = 90, status_filter: str = None) -> int:
"""Clean up old training records"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
table_name = self.model.__tablename__
# Build query based on available fields
conditions = [f"created_at < :cutoff_date"]
params = {"cutoff_date": cutoff_date}
if status_filter and hasattr(self.model, 'status'):
conditions.append(f"status = :status")
params["status"] = status_filter
query_text = f"""
DELETE FROM {table_name}
WHERE {' AND '.join(conditions)}
"""
result = await self.session.execute(text(query_text), params)
deleted_count = result.rowcount
logger.info(f"Cleaned up old {self.model.__name__} records",
deleted_count=deleted_count,
days_old=days_old,
status_filter=status_filter)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old records",
model=self.model.__name__,
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
async def get_records_by_date_range(
self,
start_date: datetime,
end_date: datetime,
skip: int = 0,
limit: int = 100
) -> List:
"""Get records within date range"""
if not hasattr(self.model, 'created_at'):
logger.warning(f"Model {self.model.__name__} has no created_at field")
return []
try:
table_name = self.model.__tablename__
query_text = f"""
SELECT * FROM {table_name}
WHERE created_at >= :start_date
AND created_at <= :end_date
ORDER BY created_at DESC
LIMIT :limit OFFSET :skip
"""
result = await self.session.execute(text(query_text), {
"start_date": start_date,
"end_date": end_date,
"limit": limit,
"skip": skip
})
# Convert rows to model objects
records = []
for row in result.fetchall():
# Create model instance from row data
record_dict = dict(row._mapping)
record = self.model(**record_dict)
records.append(record)
return records
except Exception as e:
logger.error("Failed to get records by date range",
model=self.model.__name__,
start_date=start_date,
end_date=end_date,
error=str(e))
raise DatabaseError(f"Date range query failed: {str(e)}")
def _validate_training_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
"""Validate training-related data"""
errors = []
for field in required_fields:
if field not in data or not data[field]:
errors.append(f"Missing required field: {field}")
# Validate tenant_id format if present
if "tenant_id" in data and data["tenant_id"]:
tenant_id = data["tenant_id"]
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
errors.append("Invalid tenant_id format")
# Validate job_id format if present
if "job_id" in data and data["job_id"]:
job_id = data["job_id"]
if not isinstance(job_id, str) or len(job_id) < 1:
errors.append("Invalid job_id format")
return {
"is_valid": len(errors) == 0,
"errors": errors
}
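
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): the generic
# helpers every concrete training repository inherits from this base class.
# `repo` stands in for any subclass instance built on an open AsyncSession.
async def _example_generic_maintenance(repo: "TrainingBaseRepository") -> None:
    now = datetime.now(timezone.utc)

    # Inspect the last week of records for this repository's model
    recent = await repo.get_records_by_date_range(now - timedelta(days=7), now, limit=50)
    logger.info("Recent records fetched", count=len(recent))

    # Drop failed records older than 90 days
    deleted = await repo.cleanup_old_records(days_old=90, status_filter="failed")
    logger.info("Old records removed", deleted_count=deleted)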

View File

@@ -0,0 +1,445 @@
"""
Job Queue Repository
Repository for training job queue operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc, bindparam
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import TrainingJobQueue
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class JobQueueRepository(TrainingBaseRepository):
"""Repository for training job queue operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 60):
# Job queue changes frequently, very short cache time (1 minute)
super().__init__(TrainingJobQueue, session, cache_ttl)
async def enqueue_job(self, job_data: Dict[str, Any]) -> TrainingJobQueue:
"""Add a job to the training queue"""
try:
# Validate job data
validation_result = self._validate_training_data(
job_data,
["job_id", "tenant_id", "job_type"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid job data: {validation_result['errors']}")
# Set default values
if "priority" not in job_data:
job_data["priority"] = 1
if "status" not in job_data:
job_data["status"] = "queued"
if "max_retries" not in job_data:
job_data["max_retries"] = 3
# Create queue entry
queued_job = await self.create(job_data)
logger.info("Job enqueued",
job_id=queued_job.job_id,
tenant_id=queued_job.tenant_id,
job_type=queued_job.job_type,
priority=queued_job.priority)
return queued_job
except ValidationError:
raise
except Exception as e:
logger.error("Failed to enqueue job",
job_id=job_data.get("job_id"),
error=str(e))
raise DatabaseError(f"Failed to enqueue job: {str(e)}")
async def get_next_job(self, job_types: List[str] = None) -> Optional[TrainingJobQueue]:
"""Get the next job to process from the queue"""
try:
# Build filters for job types if specified
filters = {"status": "queued"}
if job_types:
                # For multiple job types, bind them as an expanding parameter
                # rather than interpolating values into the SQL string
                query = text("""
                    SELECT * FROM training_job_queue
                    WHERE status = 'queued'
                    AND job_type IN :job_types
                    AND (scheduled_at IS NULL OR scheduled_at <= :now)
                    ORDER BY priority DESC, created_at ASC
                    LIMIT 1
                """).bindparams(bindparam("job_types", expanding=True))
                result = await self.session.execute(
                    query, {"job_types": job_types, "now": datetime.now()}
                )
row = result.fetchone()
if row:
record_dict = dict(row._mapping)
return self.model(**record_dict)
return None
else:
# Simple case - get any queued job
jobs = await self.get_multi(
filters=filters,
limit=1,
order_by="priority",
order_desc=True
)
return jobs[0] if jobs else None
except Exception as e:
logger.error("Failed to get next job from queue",
job_types=job_types,
error=str(e))
raise DatabaseError(f"Failed to get next job: {str(e)}")
async def start_job(self, job_id: str) -> Optional[TrainingJobQueue]:
"""Mark a job as started"""
try:
job = await self.get_by_job_id(job_id)
if not job:
logger.error(f"Job not found in queue: {job_id}")
return None
if job.status != "queued":
logger.warning(f"Job {job_id} is not queued (status: {job.status})")
return job
updated_job = await self.update(job.id, {
"status": "running",
"started_at": datetime.now(),
"updated_at": datetime.now()
})
logger.info("Job started",
job_id=job_id,
job_type=job.job_type)
return updated_job
except Exception as e:
logger.error("Failed to start job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to start job: {str(e)}")
async def complete_job(self, job_id: str) -> Optional[TrainingJobQueue]:
"""Mark a job as completed"""
try:
job = await self.get_by_job_id(job_id)
if not job:
logger.error(f"Job not found in queue: {job_id}")
return None
updated_job = await self.update(job.id, {
"status": "completed",
"updated_at": datetime.now()
})
logger.info("Job completed",
job_id=job_id,
job_type=job.job_type if job else "unknown")
return updated_job
except Exception as e:
logger.error("Failed to complete job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to complete job: {str(e)}")
async def fail_job(self, job_id: str, error_message: str = None) -> Optional[TrainingJobQueue]:
"""Mark a job as failed and handle retries"""
try:
job = await self.get_by_job_id(job_id)
if not job:
logger.error(f"Job not found in queue: {job_id}")
return None
# Increment retry count
new_retry_count = job.retry_count + 1
# Check if we should retry
if new_retry_count < job.max_retries:
# Reset to queued for retry
updated_job = await self.update(job.id, {
"status": "queued",
"retry_count": new_retry_count,
"updated_at": datetime.now(),
"started_at": None # Reset started_at for retry
})
logger.info("Job failed, queued for retry",
job_id=job_id,
retry_count=new_retry_count,
max_retries=job.max_retries)
else:
# Mark as permanently failed
updated_job = await self.update(job.id, {
"status": "failed",
"retry_count": new_retry_count,
"updated_at": datetime.now()
})
logger.error("Job permanently failed",
job_id=job_id,
retry_count=new_retry_count,
error_message=error_message)
return updated_job
except Exception as e:
logger.error("Failed to handle job failure",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to handle job failure: {str(e)}")
async def cancel_job(self, job_id: str, cancelled_by: str = None) -> Optional[TrainingJobQueue]:
"""Cancel a job"""
try:
job = await self.get_by_job_id(job_id)
if not job:
logger.error(f"Job not found in queue: {job_id}")
return None
if job.status in ["completed", "failed"]:
logger.warning(f"Cannot cancel job {job_id} with status {job.status}")
return job
updated_job = await self.update(job.id, {
"status": "cancelled",
"cancelled_by": cancelled_by,
"updated_at": datetime.now()
})
logger.info("Job cancelled",
job_id=job_id,
cancelled_by=cancelled_by)
return updated_job
except Exception as e:
logger.error("Failed to cancel job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to cancel job: {str(e)}")
async def get_queue_status(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get queue status and statistics"""
try:
base_filters = {}
if tenant_id:
base_filters["tenant_id"] = tenant_id
# Get counts by status
queued_jobs = await self.count(filters={**base_filters, "status": "queued"})
running_jobs = await self.count(filters={**base_filters, "status": "running"})
completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
cancelled_jobs = await self.count(filters={**base_filters, "status": "cancelled"})
# Get jobs by type
type_query = text(f"""
SELECT job_type, COUNT(*) as count
FROM training_job_queue
WHERE 1=1
{' AND tenant_id = :tenant_id' if tenant_id else ''}
GROUP BY job_type
ORDER BY count DESC
""")
params = {"tenant_id": tenant_id} if tenant_id else {}
result = await self.session.execute(type_query, params)
jobs_by_type = {row.job_type: row.count for row in result.fetchall()}
# Get average wait time for completed jobs
wait_time_query = text(f"""
SELECT
AVG(EXTRACT(EPOCH FROM (started_at - created_at))/60) as avg_wait_minutes
FROM training_job_queue
WHERE status = 'completed'
AND started_at IS NOT NULL
AND created_at IS NOT NULL
{' AND tenant_id = :tenant_id' if tenant_id else ''}
""")
wait_result = await self.session.execute(wait_time_query, params)
wait_row = wait_result.fetchone()
avg_wait_time = float(wait_row.avg_wait_minutes) if wait_row and wait_row.avg_wait_minutes else 0.0
return {
"tenant_id": tenant_id,
"queue_counts": {
"queued": queued_jobs,
"running": running_jobs,
"completed": completed_jobs,
"failed": failed_jobs,
"cancelled": cancelled_jobs,
"total": queued_jobs + running_jobs + completed_jobs + failed_jobs + cancelled_jobs
},
"jobs_by_type": jobs_by_type,
"avg_wait_time_minutes": round(avg_wait_time, 2),
"queue_health": {
"has_queued_jobs": queued_jobs > 0,
"has_running_jobs": running_jobs > 0,
"failure_rate": round((failed_jobs / max(completed_jobs + failed_jobs, 1)) * 100, 2)
}
}
except Exception as e:
logger.error("Failed to get queue status",
tenant_id=tenant_id,
error=str(e))
return {
"tenant_id": tenant_id,
"queue_counts": {
"queued": 0, "running": 0, "completed": 0,
"failed": 0, "cancelled": 0, "total": 0
},
"jobs_by_type": {},
"avg_wait_time_minutes": 0.0,
"queue_health": {
"has_queued_jobs": False,
"has_running_jobs": False,
"failure_rate": 0.0
}
}
async def get_jobs_by_tenant(
self,
tenant_id: str,
status: str = None,
job_type: str = None,
skip: int = 0,
limit: int = 100
) -> List[TrainingJobQueue]:
"""Get jobs for a tenant with optional filtering"""
try:
filters = {"tenant_id": tenant_id}
if status:
filters["status"] = status
if job_type:
filters["job_type"] = job_type
return await self.get_multi(
filters=filters,
skip=skip,
limit=limit,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get jobs by tenant",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get tenant jobs: {str(e)}")
async def cleanup_old_jobs(self, days_old: int = 30, status_filter: str = None) -> int:
"""Clean up old completed/failed/cancelled jobs"""
try:
cutoff_date = datetime.now() - timedelta(days=days_old)
# Only clean up finished jobs by default
default_statuses = ["completed", "failed", "cancelled"]
if status_filter:
status_condition = "status = :status"
params = {"cutoff_date": cutoff_date, "status": status_filter}
else:
status_list = "', '".join(default_statuses)
status_condition = f"status IN ('{status_list}')"
params = {"cutoff_date": cutoff_date}
query_text = f"""
DELETE FROM training_job_queue
WHERE created_at < :cutoff_date
AND {status_condition}
"""
result = await self.session.execute(text(query_text), params)
deleted_count = result.rowcount
logger.info("Cleaned up old queue jobs",
deleted_count=deleted_count,
days_old=days_old,
status_filter=status_filter)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old queue jobs",
error=str(e))
raise DatabaseError(f"Queue cleanup failed: {str(e)}")
async def get_stuck_jobs(self, hours_stuck: int = 2) -> List[TrainingJobQueue]:
"""Get jobs that have been running for too long"""
try:
cutoff_time = datetime.now() - timedelta(hours=hours_stuck)
query_text = """
SELECT * FROM training_job_queue
WHERE status = 'running'
AND started_at IS NOT NULL
AND started_at < :cutoff_time
ORDER BY started_at ASC
"""
result = await self.session.execute(text(query_text), {"cutoff_time": cutoff_time})
stuck_jobs = []
for row in result.fetchall():
record_dict = dict(row._mapping)
job = self.model(**record_dict)
stuck_jobs.append(job)
if stuck_jobs:
logger.warning("Found stuck jobs",
count=len(stuck_jobs),
hours_stuck=hours_stuck)
return stuck_jobs
except Exception as e:
logger.error("Failed to get stuck jobs",
hours_stuck=hours_stuck,
error=str(e))
return []
async def reset_stuck_jobs(self, hours_stuck: int = 2) -> int:
"""Reset stuck jobs back to queued status"""
try:
stuck_jobs = await self.get_stuck_jobs(hours_stuck)
reset_count = 0
for job in stuck_jobs:
# Reset job to queued status
await self.update(job.id, {
"status": "queued",
"started_at": None,
"updated_at": datetime.now()
})
reset_count += 1
if reset_count > 0:
logger.info("Reset stuck jobs",
reset_count=reset_count,
hours_stuck=hours_stuck)
return reset_count
except Exception as e:
logger.error("Failed to reset stuck jobs",
hours_stuck=hours_stuck,
error=str(e))
raise DatabaseError(f"Failed to reset stuck jobs: {str(e)}")

View File

@@ -0,0 +1,375 @@
"""
Model Repository
Repository for trained model operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timezone, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import TrainedModel
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
logger = structlog.get_logger()
class ModelRepository(TrainingBaseRepository):
"""Repository for trained model operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Models are relatively stable, longer cache time (10 minutes)
super().__init__(TrainedModel, session, cache_ttl)
async def create_model(self, model_data: Dict[str, Any]) -> TrainedModel:
"""Create a new trained model with validation"""
try:
# Validate model data
validation_result = self._validate_training_data(
model_data,
["tenant_id", "inventory_product_id", "model_path", "job_id"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid model data: {validation_result['errors']}")
# Check for duplicate active models for same tenant+product
existing_model = await self.get_active_model_for_product(
model_data["tenant_id"],
model_data["inventory_product_id"]
)
# If there's an existing active model, we may want to deactivate it
if existing_model and model_data.get("is_production", False):
logger.info("Deactivating previous production model",
previous_model_id=existing_model.id,
tenant_id=model_data["tenant_id"],
inventory_product_id=model_data["inventory_product_id"])
await self.update(existing_model.id, {"is_production": False})
# Create new model
model = await self.create(model_data)
logger.info("Trained model created successfully",
model_id=model.id,
tenant_id=model.tenant_id,
inventory_product_id=str(model.inventory_product_id),
model_type=model.model_type)
return model
except (ValidationError, DuplicateRecordError):
raise
except Exception as e:
logger.error("Failed to create trained model",
tenant_id=model_data.get("tenant_id"),
inventory_product_id=model_data.get("inventory_product_id"),
error=str(e))
raise DatabaseError(f"Failed to create model: {str(e)}")
async def get_model_by_tenant_and_product(
self,
tenant_id: str,
inventory_product_id: str
) -> List[TrainedModel]:
"""Get all models for a tenant and product"""
try:
return await self.get_multi(
filters={
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id
},
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get models by tenant and product",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
raise DatabaseError(f"Failed to get models: {str(e)}")
async def get_active_model_for_product(
self,
tenant_id: str,
inventory_product_id: str
) -> Optional[TrainedModel]:
"""Get the active production model for a product"""
try:
models = await self.get_multi(
filters={
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"is_active": True,
"is_production": True
},
order_by="created_at",
order_desc=True,
limit=1
)
return models[0] if models else None
except Exception as e:
logger.error("Failed to get active model for product",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
raise DatabaseError(f"Failed to get active model: {str(e)}")
async def get_models_by_tenant(
self,
tenant_id: str,
skip: int = 0,
limit: int = 100
) -> List[TrainedModel]:
"""Get all models for a tenant"""
return await self.get_by_tenant_id(tenant_id, skip=skip, limit=limit)
async def promote_to_production(self, model_id: str) -> Optional[TrainedModel]:
"""Promote a model to production"""
try:
# Get the model first
model = await self.get_by_id(model_id)
if not model:
raise ValueError(f"Model {model_id} not found")
# Deactivate other production models for the same tenant+product
await self._deactivate_other_production_models(
model.tenant_id,
str(model.inventory_product_id),
model_id
)
# Promote this model
updated_model = await self.update(model_id, {
"is_production": True,
"last_used_at": datetime.now(timezone.utc)
})
logger.info("Model promoted to production",
model_id=model_id,
tenant_id=model.tenant_id,
inventory_product_id=str(model.inventory_product_id))
return updated_model
except Exception as e:
logger.error("Failed to promote model to production",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to promote model: {str(e)}")
async def update_model_usage(self, model_id: str) -> Optional[TrainedModel]:
"""Update model last used timestamp"""
try:
return await self.update(model_id, {
"last_used_at": datetime.now(timezone.utc)
})
except Exception as e:
logger.error("Failed to update model usage",
model_id=model_id,
error=str(e))
# Don't raise here - usage update is not critical
return None
async def archive_old_models(self, tenant_id: str, days_old: int = 90) -> int:
"""Archive old non-production models"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
query = text("""
UPDATE trained_models
SET is_active = false
WHERE tenant_id = :tenant_id
AND is_production = false
AND created_at < :cutoff_date
AND is_active = true
""")
result = await self.session.execute(query, {
"tenant_id": tenant_id,
"cutoff_date": cutoff_date
})
archived_count = result.rowcount
logger.info("Archived old models",
tenant_id=tenant_id,
archived_count=archived_count,
days_old=days_old)
return archived_count
except Exception as e:
logger.error("Failed to archive old models",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Model archival failed: {str(e)}")
async def get_model_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get model statistics for a tenant"""
try:
# Get basic counts
total_models = await self.count(filters={"tenant_id": tenant_id})
active_models = await self.count(filters={
"tenant_id": tenant_id,
"is_active": True
})
production_models = await self.count(filters={
"tenant_id": tenant_id,
"is_production": True
})
# Get models by product using raw query
product_query = text("""
SELECT inventory_product_id, COUNT(*) as count
FROM trained_models
WHERE tenant_id = :tenant_id
AND is_active = true
GROUP BY inventory_product_id
ORDER BY count DESC
""")
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}
# Recent activity (models created in last 30 days)
thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
recent_models_query = text("""
SELECT COUNT(*) as count
FROM trained_models
WHERE tenant_id = :tenant_id
AND created_at >= :thirty_days_ago
""")
recent_result = await self.session.execute(
recent_models_query,
{"tenant_id": tenant_id, "thirty_days_ago": thirty_days_ago}
)
recent_models = recent_result.scalar() or 0
# Calculate average accuracy from model metrics
accuracy_query = text("""
SELECT AVG(mape) as average_mape, COUNT(*) as total_models_with_metrics
FROM trained_models
WHERE tenant_id = :tenant_id
AND mape IS NOT NULL
AND is_active = true
""")
accuracy_result = await self.session.execute(accuracy_query, {"tenant_id": tenant_id})
accuracy_row = accuracy_result.fetchone()
average_mape = accuracy_row.average_mape if accuracy_row and accuracy_row.average_mape else 0
total_models_with_metrics = accuracy_row.total_models_with_metrics if accuracy_row else 0
# Convert MAPE to accuracy percentage (lower MAPE = higher accuracy)
# Use 100 - MAPE as a simple conversion, but cap it at reasonable bounds
# Return None if no models have metrics (no data), rather than 0
            if total_models_with_metrics == 0:
                average_accuracy = None
            else:
                # 100 - MAPE, clamped to [0, 100]; a MAPE of 0 maps to 100% accuracy
                average_accuracy = max(0.0, min(100.0, 100.0 - float(average_mape)))
return {
"total_models": total_models,
"active_models": active_models,
"inactive_models": total_models - active_models,
"production_models": production_models,
"models_by_product": product_stats,
"recent_models_30d": recent_models,
"average_accuracy": average_accuracy,
"total_models_with_metrics": total_models_with_metrics,
"average_mape": float(average_mape) if average_mape > 0 else 0
}
except Exception as e:
logger.error("Failed to get model statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_models": 0,
"active_models": 0,
"inactive_models": 0,
"production_models": 0,
"models_by_product": {},
"recent_models_30d": 0,
"average_accuracy": 0,
"total_models_with_metrics": 0,
"average_mape": 0
}
async def _deactivate_other_production_models(
self,
tenant_id: str,
inventory_product_id: str,
exclude_model_id: str
) -> int:
"""Deactivate other production models for the same tenant+product"""
try:
query = text("""
UPDATE trained_models
SET is_production = false
WHERE tenant_id = :tenant_id
AND inventory_product_id = :inventory_product_id
AND id != :exclude_model_id
AND is_production = true
""")
result = await self.session.execute(query, {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"exclude_model_id": exclude_model_id
})
return result.rowcount
except Exception as e:
logger.error("Failed to deactivate other production models",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
raise DatabaseError(f"Failed to deactivate models: {str(e)}")
async def get_model_performance_summary(self, model_id: str) -> Dict[str, Any]:
"""Get performance summary for a model"""
try:
model = await self.get_by_id(model_id)
if not model:
return {}
return {
"model_id": model.id,
"tenant_id": model.tenant_id,
"inventory_product_id": str(model.inventory_product_id),
"model_type": model.model_type,
"metrics": {
"mape": model.mape,
"mae": model.mae,
"rmse": model.rmse,
"r2_score": model.r2_score
},
"training_info": {
"training_samples": model.training_samples,
"training_start_date": model.training_start_date.isoformat() if model.training_start_date else None,
"training_end_date": model.training_end_date.isoformat() if model.training_end_date else None,
"data_quality_score": model.data_quality_score
},
"status": {
"is_active": model.is_active,
"is_production": model.is_production,
"created_at": model.created_at.isoformat() if model.created_at else None,
"last_used_at": model.last_used_at.isoformat() if model.last_used_at else None
},
"features": {
"hyperparameters": model.hyperparameters,
"features_used": model.features_used
}
}
except Exception as e:
logger.error("Failed to get model performance summary",
model_id=model_id,
error=str(e))
return {}

View File

@@ -0,0 +1,433 @@
"""
Performance Repository
Repository for model performance metrics operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc, bindparam
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelPerformanceMetric
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PerformanceRepository(TrainingBaseRepository):
"""Repository for model performance metrics operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
# Performance metrics are relatively stable, longer cache time (15 minutes)
super().__init__(ModelPerformanceMetric, session, cache_ttl)
async def create_performance_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
"""Create a new performance metric record"""
try:
# Validate metric data
validation_result = self._validate_training_data(
metric_data,
["model_id", "tenant_id", "inventory_product_id"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid metric data: {validation_result['errors']}")
# Set measurement timestamp if not provided
if "measured_at" not in metric_data:
metric_data["measured_at"] = datetime.now()
# Create metric record
metric = await self.create(metric_data)
logger.info("Performance metric created",
model_id=metric.model_id,
tenant_id=metric.tenant_id,
inventory_product_id=str(metric.inventory_product_id))
return metric
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create performance metric",
model_id=metric_data.get("model_id"),
error=str(e))
raise DatabaseError(f"Failed to create metric: {str(e)}")
async def get_metrics_by_model(
self,
model_id: str,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get all performance metrics for a model"""
try:
return await self.get_multi(
filters={"model_id": model_id},
skip=skip,
limit=limit,
order_by="measured_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get metrics by model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get metrics: {str(e)}")
async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
"""Get the latest performance metric for a model"""
try:
metrics = await self.get_multi(
filters={"model_id": model_id},
limit=1,
order_by="measured_at",
order_desc=True
)
return metrics[0] if metrics else None
except Exception as e:
logger.error("Failed to get latest metric for model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get latest metric: {str(e)}")
async def get_metrics_by_tenant_and_product(
self,
tenant_id: str,
inventory_product_id: str,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get performance metrics for a tenant's product"""
try:
return await self.get_multi(
filters={
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id
},
skip=skip,
limit=limit,
order_by="measured_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get metrics by tenant and product",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
raise DatabaseError(f"Failed to get metrics: {str(e)}")
async def get_metrics_in_date_range(
self,
start_date: datetime,
end_date: datetime,
tenant_id: str = None,
model_id: str = None,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get performance metrics within a date range"""
try:
# Build filters
table_name = self.model.__tablename__
conditions = ["measured_at >= :start_date", "measured_at <= :end_date"]
params = {"start_date": start_date, "end_date": end_date, "limit": limit, "skip": skip}
if tenant_id:
conditions.append("tenant_id = :tenant_id")
params["tenant_id"] = tenant_id
if model_id:
conditions.append("model_id = :model_id")
params["model_id"] = model_id
query_text = f"""
SELECT * FROM {table_name}
WHERE {' AND '.join(conditions)}
ORDER BY measured_at DESC
LIMIT :limit OFFSET :skip
"""
result = await self.session.execute(text(query_text), params)
# Convert rows to model objects
metrics = []
for row in result.fetchall():
record_dict = dict(row._mapping)
metric = self.model(**record_dict)
metrics.append(metric)
return metrics
except Exception as e:
logger.error("Failed to get metrics in date range",
start_date=start_date,
end_date=end_date,
error=str(e))
raise DatabaseError(f"Date range query failed: {str(e)}")
async def get_performance_trends(
self,
tenant_id: str,
inventory_product_id: str = None,
days: int = 30
) -> Dict[str, Any]:
"""Get performance trends for analysis"""
try:
start_date = datetime.now() - timedelta(days=days)
end_date = datetime.now()
# Build query for performance trends
conditions = ["tenant_id = :tenant_id", "measured_at >= :start_date"]
params = {"tenant_id": tenant_id, "start_date": start_date}
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
query_text = f"""
SELECT
inventory_product_id,
AVG(mae) as avg_mae,
AVG(mse) as avg_mse,
AVG(rmse) as avg_rmse,
AVG(mape) as avg_mape,
AVG(r2_score) as avg_r2_score,
AVG(accuracy_percentage) as avg_accuracy,
COUNT(*) as measurement_count,
MIN(measured_at) as first_measurement,
MAX(measured_at) as last_measurement
FROM model_performance_metrics
WHERE {' AND '.join(conditions)}
GROUP BY inventory_product_id
ORDER BY avg_accuracy DESC
"""
result = await self.session.execute(text(query_text), params)
trends = []
for row in result.fetchall():
trends.append({
"inventory_product_id": row.inventory_product_id,
"metrics": {
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
"avg_mse": float(row.avg_mse) if row.avg_mse else None,
"avg_rmse": float(row.avg_rmse) if row.avg_rmse else None,
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
"avg_r2_score": float(row.avg_r2_score) if row.avg_r2_score else None,
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
},
"measurement_count": int(row.measurement_count),
"period": {
"start": row.first_measurement.isoformat() if row.first_measurement else None,
"end": row.last_measurement.isoformat() if row.last_measurement else None,
"days": days
}
})
return {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"trends": trends,
"period_days": days,
"total_products": len(trends)
}
except Exception as e:
logger.error("Failed to get performance trends",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
return {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"trends": [],
"period_days": days,
"total_products": 0
}
async def get_best_performing_models(
self,
tenant_id: str,
metric_type: str = "accuracy_percentage",
limit: int = 10
) -> List[Dict[str, Any]]:
"""Get best performing models based on a specific metric"""
try:
# Validate metric type
valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
if metric_type not in valid_metrics:
metric_type = "accuracy_percentage"
# For error metrics (mae, mse, rmse, mape), lower is better
# For performance metrics (r2_score, accuracy_percentage), higher is better
order_desc = metric_type in ["r2_score", "accuracy_percentage"]
order_direction = "DESC" if order_desc else "ASC"
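# Example (illustrative): for mape, 8.0 ranks ahead of 12.0 (ASC);
# for r2_score, 0.92 ranks ahead of 0.85 (DESC).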
query_text = f"""
SELECT DISTINCT ON (inventory_product_id, model_id)
model_id,
inventory_product_id,
{metric_type},
measured_at,
evaluation_samples
FROM model_performance_metrics
WHERE tenant_id = :tenant_id
AND {metric_type} IS NOT NULL
ORDER BY inventory_product_id, model_id, measured_at DESC, {metric_type} {order_direction}
LIMIT :limit
"""
result = await self.session.execute(text(query_text), {
"tenant_id": tenant_id,
"limit": limit
})
best_models = []
for row in result.fetchall():
best_models.append({
"model_id": row.model_id,
"inventory_product_id": row.inventory_product_id,
"metric_value": float(getattr(row, metric_type)),
"metric_type": metric_type,
"measured_at": row.measured_at.isoformat() if row.measured_at else None,
"evaluation_samples": int(row.evaluation_samples) if row.evaluation_samples else None
})
return best_models
except Exception as e:
logger.error("Failed to get best performing models",
tenant_id=tenant_id,
metric_type=metric_type,
error=str(e))
return []
async def cleanup_old_metrics(self, days_old: int = 180) -> int:
"""Clean up old performance metrics"""
return await self.cleanup_old_records(days_old=days_old)
async def get_metric_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get performance metric statistics for a tenant"""
try:
# Get basic counts
total_metrics = await self.count(filters={"tenant_id": tenant_id})
# Get metrics by product using raw query
product_query = text("""
SELECT
inventory_product_id,
COUNT(*) as metric_count,
AVG(accuracy_percentage) as avg_accuracy
FROM model_performance_metrics
WHERE tenant_id = :tenant_id
GROUP BY inventory_product_id
ORDER BY avg_accuracy DESC
""")
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
product_stats = {}
for row in result.fetchall():
product_stats[row.inventory_product_id] = {
"metric_count": row.metric_count,
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
}
# Recent activity (metrics in last 7 days)
seven_days_ago = datetime.now() - timedelta(days=7)
recent_metrics = len(await self.get_records_by_date_range(
seven_days_ago,
datetime.now(),
limit=1000 # High limit to get accurate count
))
return {
"total_metrics": total_metrics,
"products_tracked": len(product_stats),
"metrics_by_product": product_stats,
"recent_metrics_7d": recent_metrics
}
except Exception as e:
logger.error("Failed to get metric statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_metrics": 0,
"products_tracked": 0,
"metrics_by_product": {},
"recent_metrics_7d": 0
}
async def compare_model_performance(
self,
model_ids: List[str],
metric_type: str = "accuracy_percentage"
) -> Dict[str, Any]:
"""Compare performance between multiple models"""
try:
if not model_ids or len(model_ids) < 2:
return {"error": "At least 2 model IDs required for comparison"}
# Validate metric type
valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
if metric_type not in valid_metrics:
metric_type = "accuracy_percentage"
model_ids_str = "', '".join(model_ids)
query_text = f"""
SELECT
model_id,
inventory_product_id,
AVG({metric_type}) as avg_metric,
MIN({metric_type}) as min_metric,
MAX({metric_type}) as max_metric,
COUNT(*) as measurement_count,
MAX(measured_at) as latest_measurement
FROM model_performance_metrics
WHERE model_id IN ('{model_ids_str}')
AND {metric_type} IS NOT NULL
GROUP BY model_id, inventory_product_id
ORDER BY avg_metric DESC
"""
result = await self.session.execute(text(query_text))
comparisons = []
for row in result.fetchall():
comparisons.append({
"model_id": row.model_id,
"inventory_product_id": row.inventory_product_id,
"avg_metric": float(row.avg_metric),
"min_metric": float(row.min_metric),
"max_metric": float(row.max_metric),
"measurement_count": int(row.measurement_count),
"latest_measurement": row.latest_measurement.isoformat() if row.latest_measurement else None
})
# Find best and worst performing models:
# for error metrics (mae, mse, rmse, mape) lower is better,
# for r2_score and accuracy_percentage higher is better
if comparisons:
lower_is_better = metric_type in ["mae", "mse", "rmse", "mape"]
best_model = min(comparisons, key=lambda x: x["avg_metric"]) if lower_is_better else max(comparisons, key=lambda x: x["avg_metric"])
worst_model = max(comparisons, key=lambda x: x["avg_metric"]) if lower_is_better else min(comparisons, key=lambda x: x["avg_metric"])
else:
best_model = worst_model = None
return {
"metric_type": metric_type,
"models_compared": len(set(comp["model_id"] for comp in comparisons)),
"comparisons": comparisons,
"best_performing": best_model,
"worst_performing": worst_model
}
except Exception as e:
logger.error("Failed to compare model performance",
model_ids=model_ids,
metric_type=metric_type,
error=str(e))
return {"error": f"Comparison failed: {str(e)}"}

View File

@@ -0,0 +1,507 @@
"""
Training Log Repository
Repository for model training log operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelTrainingLog
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class TrainingLogRepository(TrainingBaseRepository):
"""Repository for training log operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Training logs change frequently, shorter cache time (5 minutes)
super().__init__(ModelTrainingLog, session, cache_ttl)
async def create_training_log(self, log_data: Dict[str, Any]) -> ModelTrainingLog:
"""Create a new training log entry"""
try:
# Validate log data
validation_result = self._validate_training_data(
log_data,
["job_id", "tenant_id", "status"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid training log data: {validation_result['errors']}")
# Set default values
if "progress" not in log_data:
log_data["progress"] = 0
if "current_step" not in log_data:
log_data["current_step"] = "initializing"
# Create log entry
log_entry = await self.create(log_data)
logger.info("Training log created",
job_id=log_entry.job_id,
tenant_id=log_entry.tenant_id,
status=log_entry.status)
return log_entry
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create training log",
job_id=log_data.get("job_id"),
error=str(e))
raise DatabaseError(f"Failed to create training log: {str(e)}")
async def get_log_by_job_id(self, job_id: str) -> Optional[ModelTrainingLog]:
"""Get training log by job ID"""
return await self.get_by_job_id(job_id)
async def update_log_progress(
self,
job_id: str,
progress: int,
current_step: str = None,
status: str = None
) -> Optional[ModelTrainingLog]:
"""Update training log progress"""
try:
update_data = {"progress": progress, "updated_at": datetime.now()}
if current_step:
update_data["current_step"] = current_step
if status:
update_data["status"] = status
log_entry = await self.get_by_job_id(job_id)
if not log_entry:
logger.error(f"Training log not found for job {job_id}")
return None
updated_log = await self.update(log_entry.id, update_data)
logger.debug("Training log progress updated",
job_id=job_id,
progress=progress,
step=current_step)
return updated_log
except Exception as e:
logger.error("Failed to update training log progress",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to update progress: {str(e)}")
async def complete_training_log(
self,
job_id: str,
results: Dict[str, Any] = None,
error_message: str = None
) -> Optional[ModelTrainingLog]:
"""Mark training log as completed or failed"""
try:
status = "failed" if error_message else "completed"
update_data = {
"status": status,
"end_time": datetime.now(),
"updated_at": datetime.now()
}
# Only force progress to 100 on success; keep the last reported progress on failure
if status == "completed":
update_data["progress"] = 100
if results:
update_data["results"] = results
if error_message:
update_data["error_message"] = error_message
log_entry = await self.get_by_job_id(job_id)
if not log_entry:
logger.error(f"Training log not found for job {job_id}")
return None
updated_log = await self.update(log_entry.id, update_data)
logger.info("Training log completed",
job_id=job_id,
status=status,
has_results=bool(results))
return updated_log
except Exception as e:
logger.error("Failed to complete training log",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to complete training log: {str(e)}")
async def get_logs_by_tenant(
self,
tenant_id: str,
status: str = None,
skip: int = 0,
limit: int = 100
) -> List[ModelTrainingLog]:
"""Get training logs for a tenant"""
try:
filters = {"tenant_id": tenant_id}
if status:
filters["status"] = status
return await self.get_multi(
filters=filters,
skip=skip,
limit=limit,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get logs by tenant",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get training logs: {str(e)}")
async def get_active_jobs(self, tenant_id: str = None) -> List[ModelTrainingLog]:
"""Get currently running training jobs"""
try:
filters = {"status": "running"}
if tenant_id:
filters["tenant_id"] = tenant_id
return await self.get_multi(
filters=filters,
order_by="start_time",
order_desc=True
)
except Exception as e:
logger.error("Failed to get active jobs",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get active jobs: {str(e)}")
async def cancel_job(self, job_id: str, cancelled_by: str = None) -> Optional[ModelTrainingLog]:
"""Cancel a training job"""
try:
update_data = {
"status": "cancelled",
"end_time": datetime.now(),
"updated_at": datetime.now()
}
if cancelled_by:
update_data["error_message"] = f"Cancelled by {cancelled_by}"
log_entry = await self.get_by_job_id(job_id)
if not log_entry:
logger.error(f"Training log not found for job {job_id}")
return None
# Only cancel if job is still running
if log_entry.status not in ["pending", "running"]:
logger.warning(f"Cannot cancel job {job_id} with status {log_entry.status}")
return log_entry
updated_log = await self.update(log_entry.id, update_data)
logger.info("Training job cancelled",
job_id=job_id,
cancelled_by=cancelled_by)
return updated_log
except Exception as e:
logger.error("Failed to cancel training job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to cancel job: {str(e)}")
async def get_job_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get training job statistics"""
try:
base_filters = {}
if tenant_id:
base_filters["tenant_id"] = tenant_id
# Get counts by status
total_jobs = await self.count(filters=base_filters)
completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
running_jobs = await self.count(filters={**base_filters, "status": "running"})
pending_jobs = await self.count(filters={**base_filters, "status": "pending"})
# Get recent activity (jobs in last 7 days)
seven_days_ago = datetime.now() - timedelta(days=7)
recent_jobs = len(await self.get_records_by_date_range(
seven_days_ago,
datetime.now(),
limit=1000 # High limit to get accurate count
))
# Calculate success rate
finished_jobs = completed_jobs + failed_jobs
success_rate = (completed_jobs / finished_jobs * 100) if finished_jobs > 0 else 0
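# Example: 45 completed and 5 failed jobs -> 45 / 50 * 100 = 90.0% success rate.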
return {
"total_jobs": total_jobs,
"completed_jobs": completed_jobs,
"failed_jobs": failed_jobs,
"running_jobs": running_jobs,
"pending_jobs": pending_jobs,
"cancelled_jobs": total_jobs - completed_jobs - failed_jobs - running_jobs - pending_jobs,
"success_rate": round(success_rate, 2),
"recent_jobs_7d": recent_jobs
}
except Exception as e:
logger.error("Failed to get job statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_jobs": 0,
"completed_jobs": 0,
"failed_jobs": 0,
"running_jobs": 0,
"pending_jobs": 0,
"cancelled_jobs": 0,
"success_rate": 0.0,
"recent_jobs_7d": 0
}
async def cleanup_old_logs(self, days_old: int = 90) -> int:
"""Clean up old completed/failed training logs"""
return await self.cleanup_old_records(
days_old=days_old,
status_filter=None # Clean up all old records regardless of status
)
async def get_job_duration_stats(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get job duration statistics"""
try:
# Use raw SQL for complex duration calculations
tenant_filter = "AND tenant_id = :tenant_id" if tenant_id else ""
params = {"tenant_id": tenant_id} if tenant_id else {}
query = text(f"""
SELECT
AVG(EXTRACT(EPOCH FROM (end_time - start_time))/60) as avg_duration_minutes,
MIN(EXTRACT(EPOCH FROM (end_time - start_time))/60) as min_duration_minutes,
MAX(EXTRACT(EPOCH FROM (end_time - start_time))/60) as max_duration_minutes,
COUNT(*) as completed_jobs_with_duration
FROM model_training_logs
WHERE status = 'completed'
AND start_time IS NOT NULL
AND end_time IS NOT NULL
{tenant_filter}
""")
result = await self.session.execute(query, params)
row = result.fetchone()
if row and row.completed_jobs_with_duration > 0:
return {
"avg_duration_minutes": round(float(row.avg_duration_minutes or 0), 2),
"min_duration_minutes": round(float(row.min_duration_minutes or 0), 2),
"max_duration_minutes": round(float(row.max_duration_minutes or 0), 2),
"completed_jobs_with_duration": int(row.completed_jobs_with_duration)
}
return {
"avg_duration_minutes": 0.0,
"min_duration_minutes": 0.0,
"max_duration_minutes": 0.0,
"completed_jobs_with_duration": 0
}
except Exception as e:
logger.error("Failed to get job duration statistics",
tenant_id=tenant_id,
error=str(e))
return {
"avg_duration_minutes": 0.0,
"min_duration_minutes": 0.0,
"max_duration_minutes": 0.0,
"completed_jobs_with_duration": 0
}
async def get_start_time(self, job_id: str) -> Optional[datetime]:
"""Get the start time for a training job"""
try:
log_entry = await self.get_by_job_id(job_id)
if log_entry and log_entry.start_time:
return log_entry.start_time
return None
except Exception as e:
logger.error("Failed to get start time",
job_id=job_id,
error=str(e))
return None
async def create_job_atomic(
self,
job_id: str,
tenant_id: str,
config: Dict[str, Any] = None
) -> tuple[Optional[ModelTrainingLog], bool]:
"""
Atomically create a training job, respecting the unique constraint.
This method uses a check-then-insert pattern backed by the database's unique
constraint to handle race conditions when multiple pods try to create a job
for the same tenant simultaneously; a constraint violation on insert is caught
and the existing active job is returned instead.
The database constraint (idx_unique_active_training_per_tenant) ensures
only one active job per tenant can exist.
Args:
job_id: Unique job identifier
tenant_id: Tenant identifier
config: Optional job configuration
Returns:
Tuple of (job, created):
- If created: (new_job, True)
- If conflict (existing active job): (existing_job, False)
- If error: raises DatabaseError
"""
try:
# First, try to find an existing active job
existing = await self.get_active_jobs(tenant_id=tenant_id)
pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)
if existing or pending:
# Return existing job
active_job = existing[0] if existing else pending[0]
logger.info("Found existing active job, skipping creation",
existing_job_id=active_job.job_id,
tenant_id=tenant_id,
requested_job_id=job_id)
return (active_job, False)
# Try to create the new job
# If another pod created one in the meantime, the unique constraint will prevent this
log_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "pending",
"progress": 0,
"current_step": "initializing",
"config": config or {}
}
try:
new_job = await self.create_training_log(log_data)
await self.session.commit()
logger.info("Created new training job atomically",
job_id=job_id,
tenant_id=tenant_id)
return (new_job, True)
except Exception as create_error:
error_str = str(create_error).lower()
# Check if this is a unique constraint violation
if "unique" in error_str or "duplicate" in error_str or "constraint" in error_str:
await self.session.rollback()
# Another pod created a job, fetch it
logger.info("Unique constraint hit, fetching existing job",
tenant_id=tenant_id,
requested_job_id=job_id)
existing = await self.get_active_jobs(tenant_id=tenant_id)
pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)
if existing or pending:
active_job = existing[0] if existing else pending[0]
return (active_job, False)
# If still no job found, something went wrong
raise DatabaseError(f"Constraint violation but no active job found: {create_error}")
else:
raise
except DatabaseError:
raise
except Exception as e:
logger.error("Failed to create job atomically",
job_id=job_id,
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to create training job atomically: {str(e)}")
async def recover_stale_jobs(self, stale_threshold_minutes: int = 60) -> List[ModelTrainingLog]:
"""
Find and mark stale running jobs as failed.
This is used during service startup to clean up jobs that were
running when a pod crashed. With multiple replicas, only stale
jobs (not updated recently) should be marked as failed.
Args:
stale_threshold_minutes: Jobs not updated for this long are considered stale
Returns:
List of jobs that were marked as failed
"""
try:
stale_cutoff = datetime.now() - timedelta(minutes=stale_threshold_minutes)
# Find running jobs that haven't been updated recently
query = text("""
SELECT id, job_id, tenant_id, status, updated_at
FROM model_training_logs
WHERE status IN ('running', 'pending')
AND updated_at < :stale_cutoff
""")
result = await self.session.execute(query, {"stale_cutoff": stale_cutoff})
stale_jobs = result.fetchall()
recovered_jobs = []
for row in stale_jobs:
try:
# Mark as failed
update_query = text("""
UPDATE model_training_logs
SET status = 'failed',
error_message = :error_msg,
end_time = :end_time,
updated_at = :updated_at
WHERE id = :id AND status IN ('running', 'pending')
""")
await self.session.execute(update_query, {
"id": row.id,
"error_msg": f"Job recovered as failed - not updated since {row.updated_at.isoformat()}. Pod may have crashed.",
"end_time": datetime.now(),
"updated_at": datetime.now()
})
logger.warning("Recovered stale training job",
job_id=row.job_id,
tenant_id=str(row.tenant_id),
last_updated=row.updated_at.isoformat() if row.updated_at else "unknown")
# Fetch the updated job to return
job = await self.get_by_job_id(row.job_id)
if job:
recovered_jobs.append(job)
except Exception as job_error:
logger.error("Failed to recover individual stale job",
job_id=row.job_id,
error=str(job_error))
if recovered_jobs:
await self.session.commit()
logger.info("Stale job recovery completed",
recovered_count=len(recovered_jobs),
stale_threshold_minutes=stale_threshold_minutes)
return recovered_jobs
except Exception as e:
logger.error("Failed to recover stale jobs",
error=str(e))
await self.session.rollback()
return []

View File

@@ -0,0 +1,384 @@
# services/training/app/schemas/training.py
"""
Complete schema definitions for training service
Includes all request/response schemas used by the API endpoints
"""
from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Any, Union, Tuple
from datetime import datetime
from enum import Enum
from uuid import UUID
class TrainingStatus(str, Enum):
"""Training job status enumeration"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class TrainingJobRequest(BaseModel):
"""Request schema for starting a training job"""
products: Optional[List[str]] = Field(None, description="Specific products to train (if None, trains all)")
start_date: Optional[datetime] = Field(None, description="Start date for training data")
end_date: Optional[datetime] = Field(None, description="End date for training data")
class SingleProductTrainingRequest(BaseModel):
"""Request schema for training a single product"""
start_date: Optional[datetime] = Field(None, description="Start date for training data")
end_date: Optional[datetime] = Field(None, description="End date for training data")
# Prophet-specific parameters
seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
daily_seasonality: bool = Field(True, description="Enable daily seasonality")
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
# Location parameters
bakery_location: Optional[Tuple[float, float]] = Field(None, description="Bakery coordinates (latitude, longitude)")
class DateRangeInfo(BaseModel):
"""Schema for date range information"""
start: str = Field(..., description="Start date in ISO format")
end: str = Field(..., description="End date in ISO format")
class DataSummary(BaseModel):
"""Schema for training data summary"""
sales_records: int = Field(..., description="Number of sales records used")
weather_records: int = Field(..., description="Number of weather records used")
traffic_records: int = Field(..., description="Number of traffic records used")
date_range: DateRangeInfo = Field(..., description="Date range of training data")
data_sources_used: List[str] = Field(..., description="List of data sources used")
constraints_applied: Dict[str, str] = Field(default_factory=dict, description="Constraints applied during data collection")
class ProductTrainingResult(BaseModel):
"""Schema for individual product training results"""
inventory_product_id: UUID = Field(..., description="Inventory product UUID")
status: str = Field(..., description="Training status for this product")
model_id: Optional[str] = Field(None, description="Trained model identifier")
data_points: int = Field(..., description="Number of data points used for training")
metrics: Optional[Dict[str, float]] = Field(None, description="Training metrics (MAE, MAPE, etc.)")
training_time_seconds: Optional[float] = Field(None, description="Time taken to train this model")
error_message: Optional[str] = Field(None, description="Error message if training failed")
class TrainingResults(BaseModel):
"""Schema for overall training results"""
total_products: int = Field(..., description="Total number of products")
successful_trainings: int = Field(..., description="Number of successfully trained models")
failed_trainings: int = Field(..., description="Number of failed trainings")
products: List[ProductTrainingResult] = Field(..., description="Results for each product")
overall_training_time_seconds: float = Field(..., description="Total training time")
class TrainingJobResponse(BaseModel):
"""Enhanced response schema for training job with detailed results"""
job_id: str = Field(..., description="Unique training job identifier")
tenant_id: str = Field(..., description="Tenant identifier")
status: TrainingStatus = Field(..., description="Overall job status")
# Required fields for basic response (backwards compatibility)
message: str = Field(..., description="Status message")
created_at: datetime = Field(..., description="Job creation timestamp")
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
# New detailed fields (optional for backwards compatibility)
training_results: Optional[TrainingResults] = Field(None, description="Detailed training results")
data_summary: Optional[DataSummary] = Field(None, description="Summary of training data used")
completed_at: Optional[str] = Field(None, description="Job completion timestamp in ISO format")
# Additional optional fields
error_details: Optional[Dict[str, Any]] = Field(None, description="Detailed error information if failed")
processing_metadata: Optional[Dict[str, Any]] = Field(None, description="Additional processing metadata")
@validator('tenant_id', 'job_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class TrainingJobStatus(BaseModel):
"""Response schema for training job status checks"""
job_id: str = Field(..., description="Training job identifier")
status: TrainingStatus = Field(..., description="Current job status")
progress: int = Field(0, description="Progress percentage (0-100)")
current_step: str = Field("", description="Current processing step")
started_at: datetime = Field(..., description="Job start timestamp")
completed_at: Optional[datetime] = Field(None, description="Job completion timestamp")
products_total: int = Field(0, description="Total number of products to train")
products_completed: int = Field(0, description="Number of products completed")
products_failed: int = Field(0, description="Number of products that failed")
error_message: Optional[str] = Field(None, description="Error message if failed")
estimated_time_remaining_seconds: Optional[int] = Field(None, description="Estimated time remaining in seconds")
message: Optional[str] = Field(None, description="Optional status message")
@validator('job_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class TrainingJobProgress(BaseModel):
"""Schema for real-time training job progress updates"""
job_id: str = Field(..., description="Training job identifier")
status: TrainingStatus = Field(..., description="Current job status")
progress: int = Field(0, description="Progress percentage (0-100)", ge=0, le=100)
current_step: str = Field(..., description="Current processing step")
current_product: Optional[str] = Field(None, description="Currently training product")
products_completed: int = Field(0, description="Number of products completed")
products_total: int = Field(0, description="Total number of products")
estimated_time_remaining_minutes: Optional[int] = Field(None, description="Estimated time remaining")
timestamp: datetime = Field(default_factory=datetime.now, description="Progress update timestamp")
@validator('job_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class DataValidationRequest(BaseModel):
"""Request schema for validating training data"""
products: Optional[List[str]] = Field(None, description="Specific products to validate (if None, validates all)")
min_data_points: int = Field(30, description="Minimum required data points per product", ge=10, le=1000)
start_date: Optional[datetime] = Field(None, description="Start date for data validation")
end_date: Optional[datetime] = Field(None, description="End date for data validation")
@validator('min_data_points')
def validate_min_data_points(cls, v):
if v < 10:
raise ValueError('min_data_points must be at least 10')
return v
class DataValidationResponse(BaseModel):
"""Response schema for data validation results"""
is_valid: bool = Field(..., description="Whether the data is valid for training")
issues: List[str] = Field(default_factory=list, description="List of data quality issues")
recommendations: List[str] = Field(default_factory=list, description="Recommendations for improvement")
estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
products_analyzed: int = Field(..., description="Number of products analyzed")
total_data_points: int = Field(..., description="Total data points available")
products_with_insufficient_data: List[str] = Field(default_factory=list, description="Products with insufficient data")
data_quality_score: float = Field(0.0, description="Overall data quality score (0-1)", ge=0.0, le=1.0)
class ModelInfo(BaseModel):
"""Schema for trained model information"""
model_id: str = Field(..., description="Unique model identifier")
model_path: str = Field(..., description="Path to stored model")
model_type: str = Field("prophet", description="Type of ML model")
training_samples: int = Field(..., description="Number of training samples")
features: List[str] = Field(..., description="List of features used")
hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
trained_at: datetime = Field(..., description="Training completion timestamp")
data_period: Dict[str, str] = Field(..., description="Training data period")
# NOTE: this redefines ProductTrainingResult; schemas declared below this point resolve to
# this detailed variant, while TrainingResults above already bound the earlier definition.
class ProductTrainingResult(BaseModel):
"""Schema for individual product training result (detailed variant with model info)"""
inventory_product_id: UUID = Field(..., description="Inventory product UUID")
status: str = Field(..., description="Training status for this product")
model_info: Optional[ModelInfo] = Field(None, description="Model information if successful")
data_points: int = Field(..., description="Number of data points used")
error_message: Optional[str] = Field(None, description="Error message if failed")
trained_at: datetime = Field(..., description="Training completion timestamp")
training_duration_seconds: Optional[float] = Field(None, description="Training duration in seconds")
class TrainingResultsResponse(BaseModel):
"""Response schema for complete training results"""
job_id: str = Field(..., description="Training job identifier")
tenant_id: str = Field(..., description="Tenant identifier")
status: TrainingStatus = Field(..., description="Overall job status")
products_trained: int = Field(..., description="Number of products successfully trained")
products_failed: int = Field(..., description="Number of products that failed training")
total_products: int = Field(..., description="Total number of products processed")
training_results: Dict[str, ProductTrainingResult] = Field(..., description="Per-product results")
summary: Dict[str, Any] = Field(..., description="Training summary statistics")
completed_at: datetime = Field(..., description="Job completion timestamp")
@validator('tenant_id', 'job_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class TrainingValidationResult(BaseModel):
"""Schema for training data validation results"""
is_valid: bool = Field(..., description="Whether the data is valid for training")
issues: List[str] = Field(default_factory=list, description="List of data quality issues")
recommendations: List[str] = Field(default_factory=list, description="Recommendations for improvement")
estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
products_analyzed: int = Field(..., description="Number of products analyzed")
total_data_points: int = Field(..., description="Total data points available")
class TrainingMetrics(BaseModel):
"""Schema for training performance metrics"""
mae: float = Field(..., description="Mean Absolute Error")
mse: float = Field(..., description="Mean Squared Error")
rmse: float = Field(..., description="Root Mean Squared Error")
mape: float = Field(..., description="Mean Absolute Percentage Error")
r2_score: float = Field(..., description="R-squared score")
mean_actual: float = Field(..., description="Mean of actual values")
mean_predicted: float = Field(..., description="Mean of predicted values")
class ExternalDataConfig(BaseModel):
"""Configuration for external data sources"""
weather_enabled: bool = Field(True, description="Enable weather data")
traffic_enabled: bool = Field(True, description="Enable traffic data")
weather_features: List[str] = Field(
default_factory=lambda: ["temperature", "precipitation", "humidity"],
description="Weather features to include"
)
traffic_features: List[str] = Field(
default_factory=lambda: ["traffic_volume"],
description="Traffic features to include"
)
class TrainingJobConfig(BaseModel):
"""Complete training job configuration"""
external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)
prophet_params: Dict[str, Any] = Field(
default_factory=lambda: {
"seasonality_mode": "additive",
"daily_seasonality": True,
"weekly_seasonality": True,
"yearly_seasonality": True
},
description="Prophet model parameters"
)
data_filters: Dict[str, Any] = Field(
default_factory=dict,
description="Data filtering parameters"
)
validation_params: Dict[str, Any] = Field(
default_factory=lambda: {"min_data_points": 30},
description="Data validation parameters"
)
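# Illustrative construction with selected defaults overridden (values are assumptions):
#   config = TrainingJobConfig(
#       external_data=ExternalDataConfig(traffic_enabled=False),
#       prophet_params={"seasonality_mode": "multiplicative", "weekly_seasonality": True},
#   )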
class TrainedModelResponse(BaseModel):
"""Response schema for trained model information"""
model_id: str = Field(..., description="Unique model identifier")
tenant_id: str = Field(..., description="Tenant identifier")
inventory_product_id: UUID = Field(..., description="Inventory product UUID")
model_type: str = Field(..., description="Type of ML model")
model_path: str = Field(..., description="Path to stored model")
version: int = Field(..., description="Model version")
training_samples: int = Field(..., description="Number of training samples")
features: List[str] = Field(..., description="List of features used")
hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
is_active: bool = Field(..., description="Whether model is active")
created_at: datetime = Field(..., description="Model creation timestamp")
data_period_start: Optional[datetime] = Field(None, description="Training data start date")
data_period_end: Optional[datetime] = Field(None, description="Training data end date")
@validator('tenant_id', 'model_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class ModelTrainingStats(BaseModel):
"""Schema for model training statistics"""
total_models: int = Field(..., description="Total number of trained models")
active_models: int = Field(..., description="Number of active models")
last_training_date: Optional[datetime] = Field(None, description="Last training date")
avg_training_time_minutes: float = Field(..., description="Average training time in minutes")
success_rate: float = Field(..., description="Training success rate (0-1)")
class BulkTrainingRequest(BaseModel):
"""Request schema for bulk training operations"""
tenant_ids: List[str] = Field(..., description="List of tenant IDs to train")
config: TrainingJobConfig = Field(default_factory=TrainingJobConfig, description="Training configuration")
priority: int = Field(1, description="Training priority (1-10)", ge=1, le=10)
schedule_time: Optional[datetime] = Field(None, description="Schedule training for specific time")
class TrainingScheduleResponse(BaseModel):
"""Response schema for scheduled training jobs"""
schedule_id: str = Field(..., description="Unique schedule identifier")
tenant_ids: List[str] = Field(..., description="Scheduled tenant IDs")
scheduled_time: datetime = Field(..., description="Scheduled execution time")
status: str = Field(..., description="Schedule status")
created_at: datetime = Field(..., description="Schedule creation timestamp")
# WebSocket response schemas for real-time updates
class TrainingProgressUpdate(BaseModel):
"""WebSocket message for training progress updates"""
type: str = Field("training_progress", description="Message type")
job_id: str = Field(..., description="Training job identifier")
progress: TrainingJobProgress = Field(..., description="Progress information")
class TrainingCompletedUpdate(BaseModel):
"""WebSocket message for training completion"""
type: str = Field("training_completed", description="Message type")
job_id: str = Field(..., description="Training job identifier")
results: TrainingResultsResponse = Field(..., description="Training results")
class TrainingErrorUpdate(BaseModel):
"""WebSocket message for training errors"""
type: str = Field("training_error", description="Message type")
job_id: str = Field(..., description="Training job identifier")
error: str = Field(..., description="Error message")
timestamp: datetime = Field(default_factory=datetime.now, description="Error timestamp")
class ModelMetricsResponse(BaseModel):
"""Response schema for model performance metrics"""
model_id: str = Field(..., description="Unique model identifier")
accuracy: float = Field(..., description="Model accuracy (R2 score)")
mape: float = Field(..., description="Mean Absolute Percentage Error")
mae: float = Field(..., description="Mean Absolute Error")
rmse: float = Field(..., description="Root Mean Square Error")
r2_score: float = Field(..., description="R-squared score")
training_samples: int = Field(..., description="Number of training samples used")
features_used: List[str] = Field(..., description="List of features used in training")
model_type: str = Field(..., description="Type of ML model")
created_at: Optional[str] = Field(None, description="Model creation timestamp")
last_used_at: Optional[str] = Field(None, description="Last time model was used")
class Config:
from_attributes = True
# Union type for all WebSocket messages
TrainingWebSocketMessage = Union[
TrainingProgressUpdate,
TrainingCompletedUpdate,
TrainingErrorUpdate
]
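# Illustrative producer-side usage (assumes an async `send_text`-style WebSocket API):
#   update = TrainingProgressUpdate(
#       job_id="job-123",
#       progress=TrainingJobProgress(
#           job_id="job-123",
#           status=TrainingStatus.RUNNING,
#           progress=40,
#           current_step="training product 2/5",
#       ),
#   )
#   await websocket.send_text(update.json())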

View File

@@ -0,0 +1,317 @@
"""
Comprehensive Input Validation Schemas
Ensures all API inputs are properly validated before processing
"""
from pydantic import BaseModel, Field, validator, root_validator
from typing import Optional, List, Dict, Any
from datetime import datetime, timedelta
from uuid import UUID
import re
class TrainingJobCreateRequest(BaseModel):
"""Schema for creating a new training job"""
tenant_id: UUID = Field(..., description="Tenant identifier")
start_date: Optional[str] = Field(
None,
description="Training data start date (ISO format: YYYY-MM-DD)",
example="2024-01-01"
)
end_date: Optional[str] = Field(
None,
description="Training data end date (ISO format: YYYY-MM-DD)",
example="2024-12-31"
)
product_ids: Optional[List[UUID]] = Field(
None,
description="Specific products to train (optional, trains all if not provided)"
)
force_retrain: bool = Field(
default=False,
description="Force retraining even if recent models exist"
)
@validator('start_date', 'end_date')
def validate_date_format(cls, v):
"""Validate date is in ISO format"""
if v is not None:
try:
datetime.fromisoformat(v)
except ValueError:
raise ValueError(f"Invalid date format: {v}. Use YYYY-MM-DD format")
return v
@root_validator
def validate_date_range(cls, values):
"""Validate date range is logical"""
start = values.get('start_date')
end = values.get('end_date')
if start and end:
start_dt = datetime.fromisoformat(start)
end_dt = datetime.fromisoformat(end)
if end_dt <= start_dt:
raise ValueError("end_date must be after start_date")
# Check reasonable range (max 3 years)
if (end_dt - start_dt).days > 1095:
raise ValueError("Date range cannot exceed 3 years (1095 days)")
# Check not in future
if end_dt > datetime.now():
raise ValueError("end_date cannot be in the future")
return values
class Config:
schema_extra = {
"example": {
"tenant_id": "123e4567-e89b-12d3-a456-426614174000",
"start_date": "2024-01-01",
"end_date": "2024-12-31",
"product_ids": None,
"force_retrain": False
}
}
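# Illustrative validation behaviour: a request whose end_date precedes start_date,
# e.g. TrainingJobCreateRequest(tenant_id=some_uuid, start_date="2024-06-01", end_date="2024-01-01"),
# is rejected with a pydantic ValidationError ("end_date must be after start_date").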
class ForecastRequest(BaseModel):
"""Schema for generating forecasts"""
tenant_id: UUID = Field(..., description="Tenant identifier")
product_id: UUID = Field(..., description="Product identifier")
forecast_days: int = Field(
default=30,
ge=1,
le=365,
description="Number of days to forecast (1-365)"
)
include_regressors: bool = Field(
default=True,
description="Include weather and traffic data in forecast"
)
confidence_level: float = Field(
default=0.80,
ge=0.5,
le=0.99,
description="Confidence interval (0.5-0.99)"
)
class Config:
schema_extra = {
"example": {
"tenant_id": "123e4567-e89b-12d3-a456-426614174000",
"product_id": "223e4567-e89b-12d3-a456-426614174000",
"forecast_days": 30,
"include_regressors": True,
"confidence_level": 0.80
}
}
class ModelEvaluationRequest(BaseModel):
"""Schema for model evaluation"""
tenant_id: UUID = Field(..., description="Tenant identifier")
product_id: Optional[UUID] = Field(None, description="Specific product (optional)")
evaluation_start_date: str = Field(..., description="Evaluation period start")
evaluation_end_date: str = Field(..., description="Evaluation period end")
@validator('evaluation_start_date', 'evaluation_end_date')
def validate_date_format(cls, v):
try:
datetime.fromisoformat(v)
except ValueError:
raise ValueError(f"Invalid date format: {v}")
return v
@root_validator
def validate_evaluation_period(cls, values):
start = values.get('evaluation_start_date')
end = values.get('evaluation_end_date')
if start and end:
start_dt = datetime.fromisoformat(start)
end_dt = datetime.fromisoformat(end)
if end_dt <= start_dt:
raise ValueError("evaluation_end_date must be after evaluation_start_date")
# Minimum 7 days for meaningful evaluation
if (end_dt - start_dt).days < 7:
raise ValueError("Evaluation period must be at least 7 days")
return values
class BulkTrainingRequest(BaseModel):
"""Schema for bulk training operations"""
tenant_ids: List[UUID] = Field(
...,
min_items=1,
max_items=100,
description="List of tenant IDs (max 100)"
)
start_date: Optional[str] = Field(None, description="Common start date")
end_date: Optional[str] = Field(None, description="Common end date")
parallel: bool = Field(
default=True,
description="Execute training jobs in parallel"
)
@validator('tenant_ids')
def validate_unique_tenants(cls, v):
if len(v) != len(set(v)):
raise ValueError("Duplicate tenant IDs not allowed")
return v
class HyperparameterOverride(BaseModel):
"""Schema for manual hyperparameter override"""
changepoint_prior_scale: Optional[float] = Field(
None, ge=0.001, le=0.5,
description="Flexibility of trend changes"
)
seasonality_prior_scale: Optional[float] = Field(
None, ge=0.01, le=10.0,
description="Strength of seasonality"
)
holidays_prior_scale: Optional[float] = Field(
None, ge=0.01, le=10.0,
description="Strength of holiday effects"
)
seasonality_mode: Optional[str] = Field(
None,
description="Seasonality mode",
regex="^(additive|multiplicative)$"
)
daily_seasonality: Optional[bool] = None
weekly_seasonality: Optional[bool] = None
yearly_seasonality: Optional[bool] = None
class Config:
schema_extra = {
"example": {
"changepoint_prior_scale": 0.05,
"seasonality_prior_scale": 10.0,
"holidays_prior_scale": 10.0,
"seasonality_mode": "additive",
"daily_seasonality": False,
"weekly_seasonality": True,
"yearly_seasonality": True
}
}
class AdvancedTrainingRequest(TrainingJobCreateRequest):
"""Extended training request with advanced options"""
hyperparameter_override: Optional[HyperparameterOverride] = Field(
None,
description="Manual hyperparameter settings (skips optimization)"
)
enable_cross_validation: bool = Field(
default=True,
description="Enable cross-validation during training"
)
cv_folds: int = Field(
default=3,
ge=2,
le=10,
description="Number of cross-validation folds"
)
optimization_trials: Optional[int] = Field(
None,
ge=5,
le=100,
description="Number of hyperparameter optimization trials (overrides defaults)"
)
save_diagnostics: bool = Field(
default=False,
description="Save detailed diagnostic plots and metrics"
)
class DataQualityCheckRequest(BaseModel):
"""Schema for data quality validation"""
tenant_id: UUID = Field(..., description="Tenant identifier")
start_date: str = Field(..., description="Check period start")
end_date: str = Field(..., description="Check period end")
product_ids: Optional[List[UUID]] = Field(
None,
description="Specific products to check"
)
include_recommendations: bool = Field(
default=True,
description="Include improvement recommendations"
)
@validator('start_date', 'end_date')
def validate_date(cls, v):
try:
datetime.fromisoformat(v)
except ValueError:
raise ValueError(f"Invalid date format: {v}")
return v
class ModelQueryParams(BaseModel):
"""Query parameters for model listing"""
tenant_id: Optional[UUID] = None
product_id: Optional[UUID] = None
is_active: Optional[bool] = None
is_production: Optional[bool] = None
model_type: Optional[str] = Field(None, regex="^(prophet|prophet_optimized|lstm|arima)$")
min_accuracy: Optional[float] = Field(None, ge=0.0, le=1.0)
created_after: Optional[datetime] = None
created_before: Optional[datetime] = None
limit: int = Field(default=100, ge=1, le=1000)
offset: int = Field(default=0, ge=0)
class Config:
schema_extra = {
"example": {
"tenant_id": "123e4567-e89b-12d3-a456-426614174000",
"is_active": True,
"is_production": True,
"limit": 50,
"offset": 0
}
}
def validate_uuid(value: str) -> UUID:
"""Validate and convert string to UUID"""
try:
return UUID(value)
except (ValueError, AttributeError):
raise ValueError(f"Invalid UUID format: {value}")
def validate_date_string(value: str) -> datetime:
"""Validate and convert date string to datetime"""
try:
return datetime.fromisoformat(value)
except ValueError:
raise ValueError(f"Invalid date format: {value}. Use ISO format (YYYY-MM-DD)")
def validate_positive_integer(value: int, field_name: str = "value") -> int:
"""Validate positive integer"""
if value <= 0:
raise ValueError(f"{field_name} must be positive, got {value}")
return value
def validate_probability(value: float, field_name: str = "value") -> float:
"""Validate probability value (0.0-1.0)"""
if not 0.0 <= value <= 1.0:
raise ValueError(f"{field_name} must be between 0.0 and 1.0, got {value}")
return value

View File

@@ -0,0 +1,16 @@
"""
Training Service Layer
Business logic services for ML training and model management
"""
from .training_service import EnhancedTrainingService
from .training_orchestrator import TrainingDataOrchestrator
from .date_alignment_service import DateAlignmentService
from .data_client import DataClient
__all__ = [
"EnhancedTrainingService",
"TrainingDataOrchestrator",
"DateAlignmentService",
"DataClient"
]

View File

@@ -0,0 +1,410 @@
# services/training/app/services/data_client.py
"""
Training Service Data Client
Migrated to use shared service clients with timeout configuration
"""
import structlog
from typing import Dict, Any, List, Optional
from datetime import datetime
import httpx
# Import the shared clients
from shared.clients import get_sales_client, get_external_client, get_service_clients
from app.core.config import settings
from app.core import constants as const
from app.utils.circuit_breaker import circuit_breaker_registry, CircuitBreakerError
from app.utils.retry import with_retry, HTTP_RETRY_STRATEGY, EXTERNAL_SERVICE_RETRY_STRATEGY
logger = structlog.get_logger()
class DataClient:
"""
Data client for training service
Now uses the shared data service client under the hood
"""
def __init__(self):
# Get the new specialized clients with timeout configuration
self.sales_client = get_sales_client(settings, "training")
self.external_client = get_external_client(settings, "training")
# ExternalServiceClient always has get_stored_traffic_data_for_training method
self.supports_stored_traffic_data = True
# Configure timeouts for HTTP clients
self._configure_timeouts()
# Initialize circuit breakers for external services
self._init_circuit_breakers()
def _configure_timeouts(self):
"""Configure appropriate timeouts for HTTP clients"""
timeout = httpx.Timeout(
connect=const.HTTP_TIMEOUT_DEFAULT,
read=const.HTTP_TIMEOUT_LONG_RUNNING,
write=const.HTTP_TIMEOUT_DEFAULT,
pool=const.HTTP_TIMEOUT_DEFAULT
)
# Apply timeout to clients if they have httpx clients
# Note: BaseServiceClient manages its own HTTP client internally
if hasattr(self.sales_client, 'client') and isinstance(self.sales_client.client, httpx.AsyncClient):
self.sales_client.client.timeout = timeout
if hasattr(self.external_client, 'client') and isinstance(self.external_client.client, httpx.AsyncClient):
self.external_client.client.timeout = timeout
def _init_circuit_breakers(self):
"""Initialize circuit breakers for external service calls"""
# Sales service circuit breaker
self.sales_cb = circuit_breaker_registry.get_or_create(
name="sales_service",
failure_threshold=5,
recovery_timeout=60.0,
expected_exception=Exception
)
# Weather service circuit breaker
self.weather_cb = circuit_breaker_registry.get_or_create(
name="weather_service",
failure_threshold=3, # Weather is optional, fail faster
recovery_timeout=30.0,
expected_exception=Exception
)
# Traffic service circuit breaker
self.traffic_cb = circuit_breaker_registry.get_or_create(
name="traffic_service",
failure_threshold=3, # Traffic is optional, fail faster
recovery_timeout=30.0,
expected_exception=Exception
)
@with_retry(max_attempts=3, initial_delay=1.0, max_delay=10.0)
async def _fetch_sales_data_internal(
self,
tenant_id: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
product_id: Optional[str] = None,
fetch_all: bool = True
) -> List[Dict[str, Any]]:
"""Internal method to fetch sales data with automatic retry"""
if fetch_all:
sales_data = await self.sales_client.get_all_sales_data(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
product_id=product_id,
aggregation="daily",
page_size=1000,
max_pages=100
)
else:
sales_data = await self.sales_client.get_sales_data(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
product_id=product_id,
aggregation="daily"
)
sales_data = sales_data or []
if sales_data:
logger.info(f"Fetched {len(sales_data)} sales records",
tenant_id=tenant_id, product_id=product_id, fetch_all=fetch_all)
return sales_data
else:
logger.error("No sales data returned", tenant_id=tenant_id)
raise ValueError(f"No sales data available for tenant {tenant_id}")
async def fetch_sales_data(
self,
tenant_id: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
product_id: Optional[str] = None,
fetch_all: bool = True
) -> List[Dict[str, Any]]:
"""
Fetch sales data for training with circuit breaker protection
"""
try:
return await self.sales_cb.call(
self._fetch_sales_data_internal,
tenant_id, start_date, end_date, product_id, fetch_all
)
except CircuitBreakerError as exc:
logger.error("Sales service circuit breaker open", error_message=str(exc))
raise RuntimeError(f"Sales service unavailable: {str(exc)}")
except ValueError:
raise
except Exception as exc:
logger.error("Error fetching sales data", tenant_id=tenant_id, error_message=str(exc))
raise RuntimeError(f"Failed to fetch sales data: {str(exc)}")
async def fetch_weather_data(
self,
tenant_id: str,
start_date: str,
end_date: str,
latitude: Optional[float] = None,
longitude: Optional[float] = None
) -> List[Dict[str, Any]]:
"""
Fetch weather data for training.
Error handling and retry logic live in the shared external service client;
on failure an empty list is returned so callers can fall back to synthetic data.
"""
try:
weather_data = await self.external_client.get_weather_historical(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
latitude=latitude,
longitude=longitude
)
if weather_data:
logger.info(f"Fetched {len(weather_data)} weather records",
tenant_id=tenant_id)
return weather_data
else:
logger.warning("No weather data returned, will use synthetic data", tenant_id=tenant_id)
return []
except Exception as exc:
logger.warning("Error fetching weather data, will use synthetic data", tenant_id=tenant_id, error_message=str(exc))
return []
async def fetch_traffic_data_unified(
self,
tenant_id: str,
start_date: str,
end_date: str,
latitude: Optional[float] = None,
longitude: Optional[float] = None,
force_refresh: bool = False
) -> List[Dict[str, Any]]:
"""
Unified traffic data fetching with intelligent cache-first strategy
Strategy:
1. Check if stored/cached traffic data exists for the date range
2. If exists and not force_refresh, return cached data
3. If not exists or force_refresh, fetch fresh data
4. Always return data without duplicate fetching
Args:
tenant_id: Tenant identifier
start_date: Start date string (ISO format)
end_date: End date string (ISO format)
latitude: Optional latitude for location-based data
longitude: Optional longitude for location-based data
force_refresh: If True, bypass cache and fetch fresh data
"""
cache_key = f"{tenant_id}_{start_date}_{end_date}_{latitude}_{longitude}"
try:
# Step 1: Try to get stored/cached data first (unless force_refresh)
if not force_refresh and self.supports_stored_traffic_data:
logger.info("Attempting to fetch cached traffic data",
tenant_id=tenant_id, cache_key=cache_key)
try:
cached_data = await self.external_client.get_stored_traffic_data_for_training(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
latitude=latitude,
longitude=longitude
)
if cached_data and len(cached_data) > 0:
logger.info(f"✅ Using cached traffic data: {len(cached_data)} records",
tenant_id=tenant_id)
return cached_data
else:
logger.info("No cached traffic data found, fetching fresh data",
tenant_id=tenant_id)
except Exception as cache_error:
logger.warning(f"Cache fetch failed, falling back to fresh data: {cache_error}",
tenant_id=tenant_id)
# Step 2: Fetch fresh data if no cache or force_refresh
logger.info("Fetching fresh traffic data" + (" (force refresh)" if force_refresh else ""),
tenant_id=tenant_id)
fresh_data = await self.external_client.get_traffic_data(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
latitude=latitude,
longitude=longitude
)
if fresh_data and len(fresh_data) > 0:
logger.info(f"✅ Fetched fresh traffic data: {len(fresh_data)} records",
tenant_id=tenant_id)
return fresh_data
else:
logger.warning("No fresh traffic data available", tenant_id=tenant_id)
return []
except Exception as exc:
logger.error("Error in unified traffic data fetch",
tenant_id=tenant_id, cache_key=cache_key, error_message=str(exc))
return []
# Legacy methods for backward compatibility - now delegate to unified method
async def fetch_traffic_data(
self,
tenant_id: str,
start_date: str,
end_date: str,
latitude: Optional[float] = None,
longitude: Optional[float] = None
) -> List[Dict[str, Any]]:
"""Legacy method - delegates to unified fetcher with cache-first strategy"""
logger.info("Legacy fetch_traffic_data called - delegating to unified method", tenant_id=tenant_id)
return await self.fetch_traffic_data_unified(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
latitude=latitude,
longitude=longitude,
force_refresh=False # Use cache-first for legacy calls
)
async def fetch_stored_traffic_data_for_training(
self,
tenant_id: str,
start_date: str,
end_date: str,
latitude: Optional[float] = None,
longitude: Optional[float] = None
) -> List[Dict[str, Any]]:
"""Legacy method - delegates to unified fetcher with cache-first strategy"""
logger.info("Legacy fetch_stored_traffic_data_for_training called - delegating to unified method", tenant_id=tenant_id)
return await self.fetch_traffic_data_unified(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
latitude=latitude,
longitude=longitude,
force_refresh=False # Use cache-first for training calls
)
async def refresh_traffic_data(
self,
tenant_id: str,
start_date: str,
end_date: str,
latitude: Optional[float] = None,
longitude: Optional[float] = None
) -> List[Dict[str, Any]]:
"""Convenience method to force refresh traffic data"""
logger.info("Force refreshing traffic data (bypassing cache)", tenant_id=tenant_id)
return await self.fetch_traffic_data_unified(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
latitude=latitude,
longitude=longitude,
force_refresh=True # Force fresh data
)
async def validate_data_quality(
self,
tenant_id: str,
start_date: str,
end_date: str,
sales_data: List[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Validate data quality before training with comprehensive checks
"""
try:
errors = []
warnings = []
# If sales data provided, validate it directly
if sales_data is not None:
if not sales_data or len(sales_data) == 0:
errors.append("No sales data available for the specified period")
return {"is_valid": False, "errors": errors, "warnings": warnings}
# Check minimum data points
if len(sales_data) < 30:
errors.append(f"Insufficient data points: {len(sales_data)} (minimum 30 required)")
elif len(sales_data) < 90:
warnings.append(f"Limited data points: {len(sales_data)} (recommended 90+)")
# Check for required fields
required_fields = ['date', 'inventory_product_id']
for record in sales_data[:5]: # Sample check
missing = [f for f in required_fields if f not in record or record[f] is None]
if missing:
errors.append(f"Missing required fields: {missing}")
break
# Check for data quality issues
zero_count = sum(1 for r in sales_data if r.get('quantity', 0) == 0)
zero_ratio = zero_count / len(sales_data)
if zero_ratio > 0.9:
errors.append(f"Too many zero values: {zero_ratio:.1%} of records")
elif zero_ratio > 0.7:
warnings.append(f"High zero value ratio: {zero_ratio:.1%}")
# Check product diversity
unique_products = set(r.get('inventory_product_id') for r in sales_data if r.get('inventory_product_id'))
if len(unique_products) == 0:
errors.append("No valid product IDs found in sales data")
elif len(unique_products) == 1:
warnings.append("Only one product found - consider adding more products")
else:
# Fetch data for validation
sales_data = await self.fetch_sales_data(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
fetch_all=False
)
if not sales_data:
errors.append("Unable to fetch sales data for validation")
return {"is_valid": False, "errors": errors, "warnings": warnings}
# Recursive call with fetched data
return await self.validate_data_quality(
tenant_id, start_date, end_date, sales_data
)
is_valid = len(errors) == 0
result = {
"is_valid": is_valid,
"errors": errors,
"warnings": warnings,
"data_points": len(sales_data) if sales_data else 0,
"unique_products": len(unique_products) if sales_data else 0
}
if is_valid:
logger.info("Data validation passed",
tenant_id=tenant_id,
data_points=result["data_points"],
warnings_count=len(warnings))
else:
logger.error("Data validation failed",
tenant_id=tenant_id,
errors=errors)
return result
except Exception as exc:
logger.error("Error validating data", tenant_id=tenant_id, error_message=str(exc))
raise ValueError(f"Data validation failed: {str(exc)}")
# Global instance - same as before, but much simpler implementation
data_client = DataClient()
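# Illustrative usage sketch (not wired into the service): the typical call order
# for a training run is validate first, then pull the optional external sources.
# The tenant ID, dates, and coordinates below are placeholders, and running this
# requires the sales and external services to be reachable.
if __name__ == "__main__":
    import asyncio

    async def _example_training_fetch():
        report = await data_client.validate_data_quality(
            tenant_id="example-tenant",
            start_date="2024-01-01",
            end_date="2024-06-30"
        )
        if not report["is_valid"]:
            raise ValueError(f"Data not suitable for training: {report['errors']}")
        sales = await data_client.fetch_sales_data("example-tenant", fetch_all=True)
        weather = await data_client.fetch_weather_data(
            "example-tenant", "2024-01-01", "2024-06-30",
            latitude=40.4168, longitude=-3.7038
        )
        traffic = await data_client.fetch_traffic_data_unified(
            "example-tenant", "2024-01-01", "2024-06-30",
            latitude=40.4168, longitude=-3.7038
        )
        return sales, weather, traffic

    asyncio.run(_example_training_fetch())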

View File

@@ -0,0 +1,239 @@
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
import logging
from app.utils.ml_datetime import ensure_timezone_aware
logger = logging.getLogger(__name__)
class DataSourceType(Enum):
BAKERY_SALES = "bakery_sales"
MADRID_TRAFFIC = "madrid_traffic"
WEATHER_FORECAST = "weather_forecast"
@dataclass
class DateRange:
start: datetime
end: datetime
source: DataSourceType
def duration_days(self) -> int:
return (self.end - self.start).days
def overlaps_with(self, other: 'DateRange') -> bool:
return self.start <= other.end and other.start <= self.end
@dataclass
class AlignedDateRange:
start: datetime
end: datetime
available_sources: List[DataSourceType]
constraints: Dict[str, str]
class DateAlignmentService:
"""
Central service for managing and aligning dates across multiple data sources
for the bakery sales prediction model.
"""
def __init__(self):
self.MAX_TRAINING_RANGE_DAYS = 730 # Maximum training data range
self.MIN_TRAINING_RANGE_DAYS = 30 # Minimum viable training data
def validate_and_align_dates(
self,
user_sales_range: DateRange,
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None
) -> AlignedDateRange:
"""
Main method to validate and align dates across all data sources.
Args:
user_sales_range: Date range of user-provided sales data
requested_start: Optional explicit start date for training
requested_end: Optional explicit end date for training
Returns:
AlignedDateRange with validated start/end dates and available sources
"""
try:
# Step 1: Determine the base date range
base_range = self._determine_base_range(
user_sales_range, requested_start, requested_end
)
# Step 2: Apply data source constraints
aligned_range = self._apply_data_source_constraints(base_range)
# Step 3: Validate final range
self._validate_final_range(aligned_range)
logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
return aligned_range
except Exception as e:
logger.error(f"Date alignment failed: {str(e)}")
raise ValueError(f"Unable to align dates: {str(e)}")
def _determine_base_range(
self,
user_sales_range: DateRange,
requested_start: Optional[datetime],
requested_end: Optional[datetime]
) -> DateRange:
"""Determine the base date range for training."""
# Use explicit dates if provided
if requested_start and requested_end:
requested_start = ensure_timezone_aware(requested_start)
requested_end = ensure_timezone_aware(requested_end)
if requested_end <= requested_start:
raise ValueError("End date must be after start date")
return DateRange(requested_start, requested_end, DataSourceType.BAKERY_SALES)
# Otherwise, use the user's sales data range as the foundation
start_date = ensure_timezone_aware(requested_start or user_sales_range.start)
end_date = ensure_timezone_aware(requested_end or user_sales_range.end)
# Ensure we don't exceed maximum training range
if (end_date - start_date).days > self.MAX_TRAINING_RANGE_DAYS:
start_date = end_date - timedelta(days=self.MAX_TRAINING_RANGE_DAYS)
logger.warning(f"Limiting training range to {self.MAX_TRAINING_RANGE_DAYS} days")
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
def _apply_data_source_constraints(self, base_range: DateRange) -> AlignedDateRange:
"""Apply constraints from each data source and determine final aligned range."""
available_sources = [DataSourceType.BAKERY_SALES] # Always have sales data
constraints = {}
# Madrid Traffic Data Constraint
madrid_end_date = self._get_madrid_traffic_end_date()
if base_range.end > madrid_end_date:
# If requested end date is in current month, adjust it
new_end = madrid_end_date
constraints["madrid_traffic"] = f"Adjusted end date to {new_end.date()} (latest available traffic data)"
logger.info(f"Madrid traffic constraint: end date adjusted to {new_end.date()}")
else:
new_end = base_range.end
available_sources.append(DataSourceType.MADRID_TRAFFIC)
# Weather Forecast Constraint
# Weather data available from yesterday backward
weather_end_date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
if base_range.end > weather_end_date:
if new_end > weather_end_date:
new_end = weather_end_date
constraints["weather"] = f"Adjusted end date to {new_end.date()} (latest available weather data)"
logger.info(f"Weather constraint: end date adjusted to {new_end.date()}")
if new_end >= base_range.start:
available_sources.append(DataSourceType.WEATHER_FORECAST)
# Ensure minimum training period
final_start = base_range.start
if (new_end - final_start).days < self.MIN_TRAINING_RANGE_DAYS:
final_start = new_end - timedelta(days=self.MIN_TRAINING_RANGE_DAYS)
constraints["minimum_period"] = f"Adjusted start date to ensure {self.MIN_TRAINING_RANGE_DAYS} day minimum training period"
logger.info(f"Minimum period constraint: start date adjusted to {final_start.date()}")
return AlignedDateRange(
start=final_start,
end=new_end,
available_sources=available_sources,
constraints=constraints
)
def _get_madrid_traffic_end_date(self) -> datetime:
"""
Get the latest available date for Madrid traffic data.
Data for current month is not available until the following month.
"""
now = datetime.now(timezone.utc)
# Data up to the previous month is available
# Go to first day of current month, then subtract 1 day to get last day of previous month
last_day_of_previous_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
return last_day_of_previous_month
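# Worked example (illustrative): if "now" is 2024-07-15 UTC, the first day of the
# current month is 2024-07-01 and subtracting one day yields 2024-06-30, so any
# requested training range is clamped to the end of June for traffic data.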
def _validate_final_range(self, aligned_range: AlignedDateRange) -> None:
"""Validate the final aligned date range."""
if aligned_range.start >= aligned_range.end:
raise ValueError("Invalid date range: start date must be before end date")
duration = (aligned_range.end - aligned_range.start).days
if duration < self.MIN_TRAINING_RANGE_DAYS:
raise ValueError(f"Insufficient training data: {duration} days (minimum: {self.MIN_TRAINING_RANGE_DAYS})")
if duration > self.MAX_TRAINING_RANGE_DAYS:
raise ValueError(f"Training period too long: {duration} days (maximum: {self.MAX_TRAINING_RANGE_DAYS})")
# Ensure we have at least sales data
if DataSourceType.BAKERY_SALES not in aligned_range.available_sources:
raise ValueError("No sales data available for the aligned date range")
def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
"""
Generate a data collection plan based on the aligned date range.
Returns:
Dictionary with collection plans for each data source
"""
plan = {}
# Bakery Sales Data
if DataSourceType.BAKERY_SALES in aligned_range.available_sources:
plan["sales_data"] = {
"start_date": aligned_range.start,
"end_date": aligned_range.end,
"source": "user_upload",
"required": True
}
# Madrid Traffic Data
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
plan["traffic_data"] = {
"start_date": aligned_range.start,
"end_date": aligned_range.end,
"source": "madrid_opendata",
"required": False,
"constraint": "Cannot request current month data"
}
# Weather Data
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
plan["weather_data"] = {
"start_date": aligned_range.start,
"end_date": aligned_range.end,
"source": "aemet_api",
"required": False,
"constraint": "Available from yesterday backward"
}
return plan
def check_madrid_current_month_constraint(self, end_date: datetime) -> bool:
"""
Check if the end date violates the Madrid Open Data current month constraint.
Args:
end_date: The requested end date
Returns:
True if the constraint is violated (end date is in current month)
"""
now = datetime.now(timezone.utc)
current_month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
# Debug logging
logger.info(f"🔍 Madrid constraint check: end_date={end_date}, current_month_start={current_month_start}, violation={end_date >= current_month_start}")
return end_date >= current_month_start
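# Illustrative usage sketch: align a hypothetical sales range against the traffic
# and weather availability constraints, then build the collection plan. The dates
# are placeholders; in the service the range comes from the uploaded sales data.
if __name__ == "__main__":
    service = DateAlignmentService()
    sales_range = DateRange(
        start=datetime(2024, 1, 1, tzinfo=timezone.utc),
        end=datetime(2024, 6, 30, tzinfo=timezone.utc),
        source=DataSourceType.BAKERY_SALES
    )
    aligned = service.validate_and_align_dates(user_sales_range=sales_range)
    print(f"Aligned range: {aligned.start.date()} -> {aligned.end.date()}")
    print(f"Available sources: {[s.value for s in aligned.available_sources]}")
    print(f"Constraints applied: {aligned.constraints}")
    plan = service.get_data_collection_plan(aligned)
    for name, details in plan.items():
        print(f"  {name}: {details['start_date'].date()} -> {details['end_date'].date()}")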

View File

@@ -0,0 +1,120 @@
"""
Training Progress Tracker
Manages progress calculation for parallel product training (20-80% range)
"""
import asyncio
import structlog
from typing import Optional
from datetime import datetime, timezone
from app.services.training_events import publish_product_training_completed
from app.utils.time_estimation import calculate_estimated_completion_time
from app.core.training_constants import (
PROGRESS_TRAINING_RANGE_START,
PROGRESS_TRAINING_RANGE_END,
PROGRESS_TRAINING_RANGE_WIDTH
)
logger = structlog.get_logger()
class ParallelProductProgressTracker:
"""
Tracks parallel product training progress and emits events.
For N products training in parallel:
- Each product completion contributes 60/N% to overall progress
- Progress range: 20% (after data analysis) to 80% (before completion)
- Thread-safe for concurrent product trainings
- Calculates time estimates based on elapsed time and progress
"""
def __init__(self, job_id: str, tenant_id: str, total_products: int):
self.job_id = job_id
self.tenant_id = tenant_id
self.total_products = max(total_products, 1) # Ensure at least 1 to avoid division by zero
self.products_completed = 0
self._lock = asyncio.Lock()
self.start_time = datetime.now(timezone.utc)
# Calculate progress increment per product
# Training range (from PROGRESS_TRAINING_RANGE_START to PROGRESS_TRAINING_RANGE_END) divided by number of products
self.progress_per_product = PROGRESS_TRAINING_RANGE_WIDTH / self.total_products if self.total_products > 0 else 0
if total_products == 0:
logger.warning("ParallelProductProgressTracker initialized with zero products",
job_id=job_id)
logger.info("ParallelProductProgressTracker initialized",
job_id=job_id,
total_products=self.total_products,
progress_per_product=f"{self.progress_per_product:.2f}%")
async def mark_product_completed(self, product_name: str) -> int:
"""
Mark a product as completed and publish event with time estimates.
Returns the current overall progress percentage.
"""
async with self._lock:
self.products_completed += 1
current_progress = self.products_completed
# Calculate time estimates based on elapsed time and progress
elapsed_seconds = (datetime.now(timezone.utc) - self.start_time).total_seconds()
products_remaining = self.total_products - current_progress
# Calculate estimated time remaining
# Avg time per product * remaining products
estimated_time_remaining_seconds = None
estimated_completion_time = None
if current_progress > 0 and products_remaining > 0:
avg_time_per_product = elapsed_seconds / current_progress
estimated_time_remaining_seconds = int(avg_time_per_product * products_remaining)
# Calculate estimated completion time
estimated_duration_minutes = estimated_time_remaining_seconds / 60
completion_datetime = calculate_estimated_completion_time(estimated_duration_minutes)
estimated_completion_time = completion_datetime.isoformat()
# Publish product completion event with time estimates
await publish_product_training_completed(
job_id=self.job_id,
tenant_id=self.tenant_id,
product_name=product_name,
products_completed=current_progress,
total_products=self.total_products,
estimated_time_remaining_seconds=estimated_time_remaining_seconds,
estimated_completion_time=estimated_completion_time
)
# Calculate overall progress (PROGRESS_TRAINING_RANGE_START base + share of completed products).
# Consumers derive the same value from the event payload; this local copy is returned
# for logging and internal status checks.
if self.total_products > 0:
overall_progress = PROGRESS_TRAINING_RANGE_START + int((current_progress / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
else:
overall_progress = PROGRESS_TRAINING_RANGE_START
logger.info("Product training completed",
job_id=self.job_id,
product_name=product_name,
products_completed=current_progress,
total_products=self.total_products,
overall_progress=overall_progress,
estimated_time_remaining_seconds=estimated_time_remaining_seconds)
return overall_progress
def get_progress(self) -> dict:
"""Get current progress summary"""
if self.total_products > 0:
progress_percentage = PROGRESS_TRAINING_RANGE_START + int((self.products_completed / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
else:
progress_percentage = PROGRESS_TRAINING_RANGE_START
return {
"products_completed": self.products_completed,
"total_products": self.total_products,
"progress_percentage": progress_percentage
}
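# Worked example (illustrative, assuming PROGRESS_TRAINING_RANGE_START = 20 and
# PROGRESS_TRAINING_RANGE_WIDTH = 60, as the documented 20-80% range implies):
# with total_products = 6, each completion contributes 60 / 6 = 10 points, so
#   after 1 of 6 products: 20 + int((1 / 6) * 60) = 30%
#   after 3 of 6 products: 20 + int((3 / 6) * 60) = 50%
#   after 6 of 6 products: 20 + int((6 / 6) * 60) = 80%
# The 0-20% band is covered by the start/data-analysis events and 80-100% by the
# completion event, so parallel product training always reports within 20-80%.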

View File

@@ -0,0 +1,339 @@
# services/training/app/services/tenant_deletion_service.py
"""
Tenant Data Deletion Service for Training Service
Handles deletion of all training-related data for a tenant
"""
import os
from typing import Dict
from sqlalchemy import select, func, delete
from sqlalchemy.ext.asyncio import AsyncSession
import structlog
from shared.services.tenant_deletion import (
BaseTenantDataDeletionService,
TenantDataDeletionResult
)
from app.models import (
TrainedModel,
ModelTrainingLog,
ModelPerformanceMetric,
TrainingJobQueue,
ModelArtifact,
AuditLog
)
logger = structlog.get_logger(__name__)
class TrainingTenantDeletionService(BaseTenantDataDeletionService):
"""Service for deleting all training-related data for a tenant"""
def __init__(self, db: AsyncSession):
self.db = db
self.service_name = "training"
async def get_tenant_data_preview(self, tenant_id: str) -> Dict[str, int]:
"""
Get counts of what would be deleted for a tenant (dry-run)
Args:
tenant_id: The tenant ID to preview deletion for
Returns:
Dictionary with entity names and their counts
"""
logger.info("training.tenant_deletion.preview", tenant_id=tenant_id)
preview = {}
try:
# Count trained models
model_count = await self.db.scalar(
select(func.count(TrainedModel.id)).where(
TrainedModel.tenant_id == tenant_id
)
)
preview["trained_models"] = model_count or 0
# Count model artifacts
artifact_count = await self.db.scalar(
select(func.count(ModelArtifact.id)).where(
ModelArtifact.tenant_id == tenant_id
)
)
preview["model_artifacts"] = artifact_count or 0
# Count training logs
log_count = await self.db.scalar(
select(func.count(ModelTrainingLog.id)).where(
ModelTrainingLog.tenant_id == tenant_id
)
)
preview["model_training_logs"] = log_count or 0
# Count performance metrics
metric_count = await self.db.scalar(
select(func.count(ModelPerformanceMetric.id)).where(
ModelPerformanceMetric.tenant_id == tenant_id
)
)
preview["model_performance_metrics"] = metric_count or 0
# Count training job queue entries
queue_count = await self.db.scalar(
select(func.count(TrainingJobQueue.id)).where(
TrainingJobQueue.tenant_id == tenant_id
)
)
preview["training_job_queue"] = queue_count or 0
# Count audit logs
audit_count = await self.db.scalar(
select(func.count(AuditLog.id)).where(
AuditLog.tenant_id == tenant_id
)
)
preview["audit_logs"] = audit_count or 0
logger.info(
"training.tenant_deletion.preview_complete",
tenant_id=tenant_id,
preview=preview
)
except Exception as e:
logger.error(
"training.tenant_deletion.preview_error",
tenant_id=tenant_id,
error=str(e),
exc_info=True
)
raise
return preview
async def delete_tenant_data(self, tenant_id: str) -> TenantDataDeletionResult:
"""
Permanently delete all training data for a tenant
Deletion order:
1. ModelArtifact (references models)
2. ModelPerformanceMetric (references models)
3. ModelTrainingLog (independent job logs)
4. TrainingJobQueue (independent queue entries)
5. TrainedModel (parent model records)
6. AuditLog (independent)
Note: This also deletes physical model files from disk/storage
Args:
tenant_id: The tenant ID to delete data for
Returns:
TenantDataDeletionResult with deletion counts and any errors
"""
logger.info("training.tenant_deletion.started", tenant_id=tenant_id)
result = TenantDataDeletionResult(tenant_id=tenant_id, service_name=self.service_name)
try:
# Step 1: Delete model artifacts (references models)
logger.info("training.tenant_deletion.deleting_artifacts", tenant_id=tenant_id)
# Delete physical files from storage before deleting DB records
artifacts = await self.db.execute(
select(ModelArtifact).where(ModelArtifact.tenant_id == tenant_id)
)
deleted_files = 0
failed_files = 0
for artifact in artifacts.scalars():
try:
if artifact.file_path and os.path.exists(artifact.file_path):
os.remove(artifact.file_path)
deleted_files += 1
logger.info("Deleted artifact file",
path=artifact.file_path,
artifact_id=artifact.id)
except Exception as e:
failed_files += 1
logger.warning("Failed to delete artifact file",
path=artifact.file_path,
artifact_id=artifact.id if hasattr(artifact, 'id') else 'unknown',
error=str(e))
logger.info("Artifact files deletion complete",
deleted_files=deleted_files,
failed_files=failed_files)
# Now delete DB records
artifacts_result = await self.db.execute(
delete(ModelArtifact).where(
ModelArtifact.tenant_id == tenant_id
)
)
result.deleted_counts["model_artifacts"] = artifacts_result.rowcount
result.deleted_counts["artifact_files_deleted"] = deleted_files
result.deleted_counts["artifact_files_failed"] = failed_files
logger.info(
"training.tenant_deletion.artifacts_deleted",
tenant_id=tenant_id,
count=artifacts_result.rowcount
)
# Step 2: Delete model performance metrics
logger.info("training.tenant_deletion.deleting_metrics", tenant_id=tenant_id)
metrics_result = await self.db.execute(
delete(ModelPerformanceMetric).where(
ModelPerformanceMetric.tenant_id == tenant_id
)
)
result.deleted_counts["model_performance_metrics"] = metrics_result.rowcount
logger.info(
"training.tenant_deletion.metrics_deleted",
tenant_id=tenant_id,
count=metrics_result.rowcount
)
# Step 3: Delete training logs
logger.info("training.tenant_deletion.deleting_logs", tenant_id=tenant_id)
logs_result = await self.db.execute(
delete(ModelTrainingLog).where(
ModelTrainingLog.tenant_id == tenant_id
)
)
result.deleted_counts["model_training_logs"] = logs_result.rowcount
logger.info(
"training.tenant_deletion.logs_deleted",
tenant_id=tenant_id,
count=logs_result.rowcount
)
# Step 4: Delete training job queue entries
logger.info("training.tenant_deletion.deleting_queue", tenant_id=tenant_id)
queue_result = await self.db.execute(
delete(TrainingJobQueue).where(
TrainingJobQueue.tenant_id == tenant_id
)
)
result.deleted_counts["training_job_queue"] = queue_result.rowcount
logger.info(
"training.tenant_deletion.queue_deleted",
tenant_id=tenant_id,
count=queue_result.rowcount
)
# Step 5: Delete trained models (parent records)
logger.info("training.tenant_deletion.deleting_models", tenant_id=tenant_id)
# Delete physical model files (.pkl) before deleting DB records
models = await self.db.execute(
select(TrainedModel).where(TrainedModel.tenant_id == tenant_id)
)
deleted_model_files = 0
failed_model_files = 0
for model in models.scalars():
try:
# Delete .pkl file
if hasattr(model, 'model_path') and model.model_path and os.path.exists(model.model_path):
os.remove(model.model_path)
deleted_model_files += 1
logger.info("Deleted model file",
path=model.model_path,
model_id=model.id)
# Delete model_file_path if it exists
if hasattr(model, 'model_file_path') and model.model_file_path and os.path.exists(model.model_file_path):
os.remove(model.model_file_path)
deleted_model_files += 1
logger.info("Deleted model file",
path=model.model_file_path,
model_id=model.id)
# Delete metadata file if exists
if hasattr(model, 'metadata_path') and model.metadata_path and os.path.exists(model.metadata_path):
os.remove(model.metadata_path)
logger.info("Deleted metadata file",
path=model.metadata_path,
model_id=model.id)
except Exception as e:
failed_model_files += 1
logger.warning("Failed to delete model file",
path=getattr(model, 'model_path', getattr(model, 'model_file_path', 'unknown')),
model_id=model.id if hasattr(model, 'id') else 'unknown',
error=str(e))
logger.info("Model files deletion complete",
deleted_files=deleted_model_files,
failed_files=failed_model_files)
# Now delete DB records
models_result = await self.db.execute(
delete(TrainedModel).where(
TrainedModel.tenant_id == tenant_id
)
)
result.deleted_counts["trained_models"] = models_result.rowcount
result.deleted_counts["model_files_deleted"] = deleted_model_files
result.deleted_counts["model_files_failed"] = failed_model_files
logger.info(
"training.tenant_deletion.models_deleted",
tenant_id=tenant_id,
count=models_result.rowcount
)
# Step 6: Delete audit logs
logger.info("training.tenant_deletion.deleting_audit_logs", tenant_id=tenant_id)
audit_result = await self.db.execute(
delete(AuditLog).where(
AuditLog.tenant_id == tenant_id
)
)
result.deleted_counts["audit_logs"] = audit_result.rowcount
logger.info(
"training.tenant_deletion.audit_logs_deleted",
tenant_id=tenant_id,
count=audit_result.rowcount
)
# Commit the transaction
await self.db.commit()
# Calculate total deleted
total_deleted = sum(result.deleted_counts.values())
logger.info(
"training.tenant_deletion.completed",
tenant_id=tenant_id,
total_deleted=total_deleted,
breakdown=result.deleted_counts,
note="Physical model files should be cleaned up separately"
)
result.success = True
except Exception as e:
await self.db.rollback()
error_msg = f"Failed to delete training data for tenant {tenant_id}: {str(e)}"
logger.error(
"training.tenant_deletion.failed",
tenant_id=tenant_id,
error=str(e),
exc_info=True
)
result.errors.append(error_msg)
result.success = False
return result
def get_training_tenant_deletion_service(
db: AsyncSession
) -> TrainingTenantDeletionService:
"""
Factory function to create TrainingTenantDeletionService instance
Args:
db: AsyncSession database session
Returns:
TrainingTenantDeletionService instance
"""
return TrainingTenantDeletionService(db)
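# Illustrative usage sketch (not exposed as an endpoint): preview what would be
# deleted, then delete, inside a single session. The session factory import path
# is an assumption based on the service's database setup, and tenant_id is a
# placeholder.
async def _example_delete_tenant_training_data(tenant_id: str) -> TenantDataDeletionResult:
    from app.core.database import database_manager  # assumed session factory

    async with database_manager.get_session() as session:
        service = get_training_tenant_deletion_service(session)
        preview = await service.get_tenant_data_preview(tenant_id)
        logger.info("training.tenant_deletion.example_preview",
                    tenant_id=tenant_id, preview=preview)
        result = await service.delete_tenant_data(tenant_id)
        if not result.success:
            logger.error("training.tenant_deletion.example_failed",
                         tenant_id=tenant_id, errors=result.errors)
        return result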

View File

@@ -0,0 +1,330 @@
"""
Training Progress Events Publisher
Simple, clean event publisher for the 4 main training steps
"""
import structlog
from datetime import datetime
from typing import Dict, Any, Optional
from shared.messaging import RabbitMQClient
from app.core.config import settings
logger = structlog.get_logger()
# Single global publisher instance
training_publisher = RabbitMQClient(settings.RABBITMQ_URL, "training-service")
async def setup_messaging():
"""Initialize messaging"""
success = await training_publisher.connect()
if success:
logger.info("Training messaging initialized")
else:
logger.warning("Training messaging failed to initialize")
return success
async def cleanup_messaging():
"""Cleanup messaging"""
await training_publisher.disconnect()
logger.info("Training messaging cleaned up")
# ==========================================
# 4 MAIN TRAINING PROGRESS EVENTS
# ==========================================
async def publish_training_started(
job_id: str,
tenant_id: str,
total_products: int,
estimated_duration_minutes: Optional[int] = None,
estimated_completion_time: Optional[str] = None
) -> bool:
"""
Event 1: Training Started (0% progress)
Args:
job_id: Training job identifier
tenant_id: Tenant identifier
total_products: Number of products to train
estimated_duration_minutes: Estimated time to completion in minutes
estimated_completion_time: ISO timestamp of estimated completion
"""
event_data = {
"service_name": "training-service",
"event_type": "training.started",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"progress": 0,
"current_step": "Training Started",
"step_details": f"Starting training for {total_products} products",
"total_products": total_products,
"estimated_duration_minutes": estimated_duration_minutes,
"estimated_completion_time": estimated_completion_time,
"estimated_time_remaining_seconds": estimated_duration_minutes * 60 if estimated_duration_minutes else None
}
}
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.started",
event_data=event_data
)
if success:
logger.info("Published training started event",
job_id=job_id,
tenant_id=tenant_id,
total_products=total_products,
estimated_duration_minutes=estimated_duration_minutes)
else:
logger.error("Failed to publish training started event", job_id=job_id)
return success
async def publish_data_analysis(
job_id: str,
tenant_id: str,
analysis_details: Optional[str] = None,
estimated_time_remaining_seconds: Optional[int] = None,
estimated_completion_time: Optional[str] = None
) -> bool:
"""
Event 2: Data Analysis (20% progress)
Args:
job_id: Training job identifier
tenant_id: Tenant identifier
analysis_details: Details about the analysis
estimated_time_remaining_seconds: Estimated time remaining in seconds
estimated_completion_time: ISO timestamp of estimated completion
"""
event_data = {
"service_name": "training-service",
"event_type": "training.progress",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"progress": 20,
"current_step": "Data Analysis",
"step_details": analysis_details or "Analyzing sales, weather, and traffic data",
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
"estimated_completion_time": estimated_completion_time
}
}
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.progress",
event_data=event_data
)
if success:
logger.info("Published data analysis event",
job_id=job_id,
progress=20)
else:
logger.error("Failed to publish data analysis event", job_id=job_id)
return success
async def publish_training_progress(
job_id: str,
tenant_id: str,
progress: int,
current_step: str,
step_details: Optional[str] = None,
estimated_time_remaining_seconds: Optional[int] = None,
estimated_completion_time: Optional[str] = None
) -> bool:
"""
Generic Training Progress Event (for any progress percentage)
Args:
job_id: Training job identifier
tenant_id: Tenant identifier
progress: Progress percentage (0-100)
current_step: Current step name
step_details: Details about the current step
estimated_time_remaining_seconds: Estimated time remaining in seconds
estimated_completion_time: ISO timestamp of estimated completion
"""
event_data = {
"service_name": "training-service",
"event_type": "training.progress",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"progress": progress,
"current_step": current_step,
"step_details": step_details or current_step,
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
"estimated_completion_time": estimated_completion_time
}
}
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.progress",
event_data=event_data
)
if success:
logger.info("Published training progress event",
job_id=job_id,
progress=progress,
current_step=current_step)
else:
logger.error("Failed to publish training progress event",
job_id=job_id,
progress=progress)
return success
async def publish_product_training_completed(
job_id: str,
tenant_id: str,
product_name: str,
products_completed: int,
total_products: int,
estimated_time_remaining_seconds: Optional[int] = None,
estimated_completion_time: Optional[str] = None
) -> bool:
"""
Event 3: Product Training Completed (contributes to 20-80% progress)
This event is published each time a product training completes.
The frontend/consumer will calculate the progress as:
progress = 20 + (products_completed / total_products) * 60
Args:
job_id: Training job identifier
tenant_id: Tenant identifier
product_name: Name of the product that was trained
products_completed: Number of products completed so far
total_products: Total number of products
estimated_time_remaining_seconds: Estimated time remaining in seconds
estimated_completion_time: ISO timestamp of estimated completion
"""
event_data = {
"service_name": "training-service",
"event_type": "training.product.completed",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"product_name": product_name,
"products_completed": products_completed,
"total_products": total_products,
"current_step": "Model Training",
"step_details": f"Completed training for {product_name} ({products_completed}/{total_products})",
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
"estimated_completion_time": estimated_completion_time
}
}
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.product.completed",
event_data=event_data
)
if success:
logger.info("Published product training completed event",
job_id=job_id,
product_name=product_name,
products_completed=products_completed,
total_products=total_products)
else:
logger.error("Failed to publish product training completed event",
job_id=job_id)
return success
async def publish_training_completed(
job_id: str,
tenant_id: str,
successful_trainings: int,
failed_trainings: int,
total_duration_seconds: float
) -> bool:
"""
Event 4: Training Completed (100% progress)
"""
event_data = {
"service_name": "training-service",
"event_type": "training.completed",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"progress": 100,
"current_step": "Training Completed",
"step_details": f"Training completed: {successful_trainings} successful, {failed_trainings} failed",
"successful_trainings": successful_trainings,
"failed_trainings": failed_trainings,
"total_duration_seconds": total_duration_seconds
}
}
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.completed",
event_data=event_data
)
if success:
logger.info("Published training completed event",
job_id=job_id,
successful_trainings=successful_trainings,
failed_trainings=failed_trainings)
else:
logger.error("Failed to publish training completed event", job_id=job_id)
return success
async def publish_training_failed(
job_id: str,
tenant_id: str,
error_message: str
) -> bool:
"""
Event: Training Failed
"""
event_data = {
"service_name": "training-service",
"event_type": "training.failed",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"current_step": "Training Failed",
"error_message": error_message
}
}
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.failed",
event_data=event_data
)
if success:
logger.info("Published training failed event",
job_id=job_id,
error=error_message)
else:
logger.error("Failed to publish training failed event", job_id=job_id)
return success
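# Illustrative publishing sequence for one job (a sketch, not called by the
# service itself): assumes RabbitMQ is reachable at settings.RABBITMQ_URL and
# uses placeholder job/tenant identifiers and product names.
async def _example_publish_sequence() -> None:
    await setup_messaging()
    try:
        await publish_training_started("job-123", "tenant-abc", total_products=3,
                                       estimated_duration_minutes=15)
        await publish_data_analysis("job-123", "tenant-abc",
                                    analysis_details="Analyzed 180 days of sales data")
        for i, product in enumerate(["baguette", "croissant", "sourdough"], start=1):
            await publish_product_training_completed(
                "job-123", "tenant-abc", product,
                products_completed=i, total_products=3
            )
        await publish_training_completed("job-123", "tenant-abc",
                                         successful_trainings=3, failed_trainings=0,
                                         total_duration_seconds=540.0)
    finally:
        await cleanup_messaging()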

View File

@@ -0,0 +1,971 @@
# services/training/app/services/training_orchestrator.py
"""
Training Data Orchestrator - Enhanced Integration Layer
Orchestrates data collection, date alignment, and preparation for ML training
"""
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
import asyncio
import structlog
from concurrent.futures import ThreadPoolExecutor
from datetime import timezone
import pandas as pd
from app.services.data_client import DataClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
from app.ml.poi_feature_integrator import POIFeatureIntegrator
from app.services.training_events import publish_training_failed
logger = structlog.get_logger()
@dataclass
class TrainingDataSet:
"""Container for all training data with metadata"""
sales_data: List[Dict[str, Any]]
weather_data: List[Dict[str, Any]]
traffic_data: List[Dict[str, Any]]
poi_features: Dict[str, Any] # POI features for location-based forecasting
date_range: AlignedDateRange
metadata: Dict[str, Any]
class TrainingDataOrchestrator:
"""
Enhanced orchestrator for data collection from multiple sources.
Ensures date alignment, handles data source constraints, and prepares data for ML training.
Uses the new abstracted traffic service layer for multi-city support.
"""
def __init__(self,
date_alignment_service: DateAlignmentService = None,
poi_feature_integrator: POIFeatureIntegrator = None):
self.data_client = DataClient()
self.date_alignment_service = date_alignment_service or DateAlignmentService()
self.poi_feature_integrator = poi_feature_integrator or POIFeatureIntegrator()
self.max_concurrent_requests = 5 # Increased for better performance
async def prepare_training_data(
self,
tenant_id: str,
bakery_location: Tuple[float, float], # (lat, lon)
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None,
job_id: Optional[str] = None
) -> TrainingDataSet:
"""
Main method to prepare all training data with comprehensive date alignment.
Args:
tenant_id: Tenant identifier
bakery_location: Bakery coordinates (lat, lon)
requested_start: Optional explicit start date
requested_end: Optional explicit end date
job_id: Training job identifier for logging
Returns:
TrainingDataSet with all aligned and validated data
"""
logger.info(f"Starting comprehensive training data preparation for tenant {tenant_id}, job {job_id}")
try:
# Step 1: Fetch and validate sales data (unified approach)
sales_data = await self.data_client.fetch_sales_data(tenant_id, fetch_all=True)
if not sales_data or len(sales_data) == 0:
error_msg = f"No sales data available for tenant {tenant_id}. Please import sales data before starting training."
logger.error("Training aborted - no sales data", tenant_id=tenant_id, job_id=job_id)
raise ValueError(error_msg)
# Debug: Analyze the sales data structure to understand product distribution
sales_df_debug = pd.DataFrame(sales_data)
if 'inventory_product_id' in sales_df_debug.columns:
unique_products_found = sales_df_debug['inventory_product_id'].unique()
product_counts = sales_df_debug['inventory_product_id'].value_counts().to_dict()
logger.info("Sales data analysis (moved from pre-flight)",
tenant_id=tenant_id,
job_id=job_id,
total_sales_records=len(sales_data),
unique_products_count=len(unique_products_found),
unique_products=unique_products_found.tolist(),
records_per_product=product_counts)
if len(unique_products_found) == 1:
logger.warning("POTENTIAL ISSUE: Only ONE unique product found in all sales data",
tenant_id=tenant_id,
single_product=unique_products_found[0],
record_count=len(sales_data))
else:
logger.warning("No 'inventory_product_id' column found in sales data",
tenant_id=tenant_id,
columns=list(sales_df_debug.columns))
logger.info(f"Sales data validation passed: {len(sales_data)} sales records found",
tenant_id=tenant_id, job_id=job_id)
# Step 2: Extract and validate sales data date range
sales_date_range = self._extract_sales_date_range(sales_data)
logger.info(f"Sales data range detected: {sales_date_range.start} to {sales_date_range.end}")
# Step 3: Apply date alignment across all data sources
aligned_range = self.date_alignment_service.validate_and_align_dates(
user_sales_range=sales_date_range,
requested_start=requested_start,
requested_end=requested_end
)
logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
if aligned_range.constraints:
logger.info(f"Applied constraints: {aligned_range.constraints}")
# Step 4: Filter sales data to aligned date range
filtered_sales = self._filter_sales_data(sales_data, aligned_range)
# Step 5: Collect external data sources concurrently
logger.info("Collecting external data sources...")
weather_data, traffic_data, poi_features = await self._collect_external_data(
aligned_range, bakery_location, tenant_id
)
# Step 6: Validate data quality
data_quality_results = self._validate_data_sources(
filtered_sales, weather_data, traffic_data, aligned_range
)
# Step 7: Create comprehensive training dataset
training_dataset = TrainingDataSet(
sales_data=filtered_sales,
weather_data=weather_data,
traffic_data=traffic_data,
poi_features=poi_features or {}, # POI features (static, location-based)
date_range=aligned_range,
metadata={
"tenant_id": tenant_id,
"job_id": job_id,
"bakery_location": bakery_location,
"data_sources_used": aligned_range.available_sources,
"constraints_applied": aligned_range.constraints,
"data_quality": data_quality_results,
"preparation_timestamp": datetime.now().isoformat(),
"original_sales_range": {
"start": sales_date_range.start.isoformat(),
"end": sales_date_range.end.isoformat()
},
"poi_features_count": len(poi_features) if poi_features else 0
}
)
# Step 8: Final validation
final_validation = self.validate_training_data_quality(training_dataset)
training_dataset.metadata["final_validation"] = final_validation
logger.info(f"Training data preparation completed successfully:")
logger.info(f" - Sales records: {len(filtered_sales)}")
logger.info(f" - Weather records: {len(weather_data)}")
logger.info(f" - Traffic records: {len(traffic_data)}")
logger.info(f" - POI features: {len(poi_features) if poi_features else 0}")
logger.info(f" - Data quality score: {final_validation.get('data_quality_score', 'N/A')}")
return training_dataset
except Exception as e:
if job_id and tenant_id:
await publish_training_failed(job_id, tenant_id, str(e))
logger.error(f"Training data preparation failed: {str(e)}")
raise ValueError(f"Failed to prepare training data: {str(e)}")
@staticmethod
def extract_sales_date_range_utc_localize(sales_data_df: pd.DataFrame):
"""
Extracts the UTC-aware date range from a sales DataFrame using tz_localize.
Args:
sales_data_df: A pandas DataFrame containing a 'date' column.
Returns:
A tuple of timezone-aware start and end dates in UTC.
"""
if 'date' not in sales_data_df.columns:
raise ValueError("DataFrame does not contain a 'date' column.")
# Convert the 'date' column to datetime objects
sales_data_df['date'] = pd.to_datetime(sales_data_df['date'])
# Localize the naive datetime objects to UTC
sales_data_df['date'] = sales_data_df['date'].tz_localize('UTC')
# Find the minimum and maximum dates
start_date = sales_data_df['date'].min()
end_date = sales_data_df['date'].max()
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
def _extract_sales_date_range(self, sales_data: List[Dict[str, Any]]) -> 'DateRange':
"""
Extract date range from sales data with proper date parsing
Args:
sales_data: List of sales records
Returns:
DateRange object with timezone-aware start and end dates
"""
if not sales_data:
raise ValueError("No sales data provided for date range extraction")
# Convert to DataFrame for easier processing
sales_df = pd.DataFrame(sales_data)
if 'date' not in sales_df.columns:
raise ValueError("Sales data does not contain a 'date' column")
# Convert dates to datetime with proper parsing
# This will use the improved date parsing from the data import service
sales_df['date'] = pd.to_datetime(sales_df['date'], utc=True, errors='coerce')
# Remove any rows with invalid dates
sales_df = sales_df.dropna(subset=['date'])
if len(sales_df) == 0:
raise ValueError("No valid dates found in sales data")
# Find the minimum and maximum dates
start_date = sales_df['date'].min()
end_date = sales_df['date'].max()
logger.info(f"Extracted sales date range: {start_date} to {end_date}")
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
def _filter_sales_data(
self,
sales_data: List[Dict[str, Any]],
aligned_range: AlignedDateRange
) -> List[Dict[str, Any]]:
"""Filter sales data to the aligned date range with enhanced validation"""
filtered_data = []
filtered_count = 0
for record in sales_data:
try:
if 'date' in record:
record_date = record['date']
# Parse dates with timezone info preserved; truncating to the date part here
# previously caused valid records to be dropped from the aligned range.
if isinstance(record_date, str):
# Parse the complete ISO datetime string, keeping timezone info intact
if 'T' in record_date:
record_date = record_date.replace('Z', '+00:00')
# Parse with FULL datetime info, not just date part
parsed_date = datetime.fromisoformat(record_date)
# Ensure timezone-aware
if parsed_date.tzinfo is None:
parsed_date = parsed_date.replace(tzinfo=timezone.utc)
record_date = parsed_date
elif isinstance(record_date, datetime):
# Ensure timezone-aware
if record_date.tzinfo is None:
record_date = record_date.replace(tzinfo=timezone.utc)
# Keep the actual datetime rather than normalizing to start of day so the
# range comparison stays accurate; daily aggregation, if needed, happens later.
# Ensure the aligned_range bounds are timezone-aware before comparing
aligned_start = aligned_range.start
aligned_end = aligned_range.end
if aligned_start.tzinfo is None:
aligned_start = aligned_start.replace(tzinfo=timezone.utc)
if aligned_end.tzinfo is None:
aligned_end = aligned_end.replace(tzinfo=timezone.utc)
# Check if date falls within aligned range (now both are timezone-aware)
if aligned_start <= record_date <= aligned_end:
# Validate that record has required fields
if self._validate_sales_record(record):
filtered_data.append(record)
else:
filtered_count += 1
else:
# Record outside date range
filtered_count += 1
except Exception as e:
logger.warning(f"Error processing sales record: {str(e)}")
filtered_count += 1
continue
logger.info(f"Filtered sales data: {len(filtered_data)} records in aligned range")
if filtered_count > 0:
logger.warning(f"Filtered out {filtered_count} invalid records")
return filtered_data
def _validate_sales_record(self, record: Dict[str, Any]) -> bool:
"""Validate individual sales record"""
required_fields = ['date', 'inventory_product_id']
quantity_fields = ['quantity', 'quantity_sold', 'sales', 'units_sold']
# Check required fields
for field in required_fields:
if field not in record or record[field] is None:
return False
# Check at least one quantity field exists
has_quantity = any(field in record and record[field] is not None for field in quantity_fields)
if not has_quantity:
return False
# Validate quantity is numeric and non-negative
for field in quantity_fields:
if field in record and record[field] is not None:
try:
quantity = float(record[field])
if quantity < 0:
return False
except (ValueError, TypeError):
return False
break
return True
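# Example (illustrative): a record such as
#   {"date": "2024-03-01T00:00:00+00:00", "inventory_product_id": "prod-42", "quantity": 18}
# passes validation, while a record missing inventory_product_id or carrying a
# negative or non-numeric quantity is rejected and counted as filtered out.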
async def _collect_external_data(
self,
aligned_range: AlignedDateRange,
bakery_location: Tuple[float, float],
tenant_id: str
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, Any]]:
"""Collect weather, traffic, and POI data concurrently with enhanced error handling"""
lat, lon = bakery_location
# Create collection tasks with timeout
tasks = []
# Weather data collection
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
weather_task = asyncio.create_task(
self._collect_weather_data_with_timeout(lat, lon, aligned_range, tenant_id)
)
tasks.append(("weather", weather_task))
# Enhanced Traffic data collection (supports multiple cities)
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
logger.info(f"🚛 Traffic data source available for multiple cities, creating collection task for date range: {aligned_range.start} to {aligned_range.end}")
traffic_task = asyncio.create_task(
self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)
)
tasks.append(("traffic", traffic_task))
else:
logger.warning(f"🚫 Traffic data source NOT available in sources: {[s.value for s in aligned_range.available_sources]}")
# POI features collection (static, location-based)
poi_task = asyncio.create_task(
self._collect_poi_features(lat, lon, tenant_id)
)
tasks.append(("poi", poi_task))
# Execute tasks concurrently with proper error handling
results = {}
if tasks:
try:
completed_tasks = await asyncio.gather(
*[task for _, task in tasks],
return_exceptions=True
)
for i, (task_name, _) in enumerate(tasks):
result = completed_tasks[i]
if isinstance(result, Exception):
logger.warning(f"{task_name} data collection failed: {result}")
results[task_name] = [] if task_name != "poi" else {}
else:
results[task_name] = result
if task_name == "poi":
logger.info(f"{task_name} features collected: {len(result) if result else 0} features")
else:
logger.info(f"{task_name} data collection completed: {len(result)} records")
except Exception as e:
logger.error(f"Error in concurrent data collection: {str(e)}")
results = {"weather": [], "traffic": [], "poi": {}}
weather_data = results.get("weather", [])
traffic_data = results.get("traffic", [])
poi_features = results.get("poi", {})
return weather_data, traffic_data, poi_features
async def _collect_poi_features(
self,
lat: float,
lon: float,
tenant_id: str
) -> Dict[str, Any]:
"""
Collect POI features for bakery location (non-blocking).
POI features are static (location-based, not time-varying).
This method is non-blocking with a short timeout to prevent training delays.
If POI detection hasn't been run yet, training continues without POI features.
Returns:
Dictionary with POI features or empty dict if unavailable
"""
try:
logger.info(
"Collecting POI features (non-blocking)",
tenant_id=tenant_id,
location=(lat, lon)
)
# Set a short timeout to prevent blocking training
# POI detection should have been triggered during tenant registration
poi_features = await asyncio.wait_for(
self.poi_feature_integrator.fetch_poi_features(
tenant_id=tenant_id,
latitude=lat,
longitude=lon,
force_refresh=False
),
timeout=15.0 # 15 second timeout - POI should be cached from registration
)
if poi_features:
logger.info(
"POI features collected successfully",
tenant_id=tenant_id,
feature_count=len(poi_features)
)
else:
logger.warning(
"No POI features collected (service may be unavailable or not yet detected)",
tenant_id=tenant_id
)
return poi_features or {}
except asyncio.TimeoutError:
logger.warning(
"POI collection timeout (15s) - continuing training without POI features. "
"POI detection should be triggered during tenant registration for best results.",
tenant_id=tenant_id
)
return {}
except Exception as e:
logger.warning(
"Failed to collect POI features (non-blocking) - continuing training without them",
tenant_id=tenant_id,
error=str(e)
)
return {}
async def _collect_weather_data_with_timeout(
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange,
tenant_id: str
) -> List[Dict[str, Any]]:
"""Collect weather data with timeout and fallback"""
try:
start_date_str = aligned_range.start.isoformat()
end_date_str = aligned_range.end.isoformat()
weather_data = await self.data_client.fetch_weather_data(
tenant_id=tenant_id,
start_date=start_date_str,
end_date=end_date_str,
latitude=lat,
longitude=lon)
# Validate weather data
if self._validate_weather_data(weather_data):
logger.info(f"Collected {len(weather_data)} valid weather records")
return weather_data
else:
logger.warning("Invalid weather data received, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
except asyncio.TimeoutError:
logger.warning(f"Weather data collection timed out, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
except Exception as e:
logger.warning(f"Weather data collection failed: {e}, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
async def _collect_traffic_data_with_timeout_enhanced(
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange,
tenant_id: str
) -> List[Dict[str, Any]]:
"""
Enhanced traffic data collection with multi-city support and improved storage
Uses the new abstracted traffic service layer
"""
try:
# Double-check constraints before making request
constraint_violated = self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end)
if constraint_violated:
logger.warning(f"🚫 Current month constraint violation: end_date={aligned_range.end}, no traffic data available")
return []
else:
logger.info(f"✅ Date constraints passed: end_date={aligned_range.end}, proceeding with traffic data request")
start_date_str = aligned_range.start.isoformat()
end_date_str = aligned_range.end.isoformat()
# Enhanced: Fetch traffic data using unified cache-first method
# This automatically detects the appropriate city and uses the right client
traffic_data = await self.data_client.fetch_traffic_data_unified(
tenant_id=tenant_id,
start_date=start_date_str,
end_date=end_date_str,
latitude=lat,
longitude=lon,
force_refresh=False # Use cache-first strategy
)
# Enhanced validation including pedestrian inference data
if self._validate_traffic_data_enhanced(traffic_data):
logger.info(f"Collected and stored {len(traffic_data)} valid enhanced traffic records for re-training")
# Log storage success with enhanced metadata
self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, len(traffic_data), traffic_data)
return traffic_data
else:
logger.warning("Invalid enhanced traffic data received")
return []
except asyncio.TimeoutError:
logger.warning(f"Enhanced traffic data collection timed out")
return []
except Exception as e:
logger.warning(f"Enhanced traffic data collection failed: {e}")
return []
def _log_enhanced_traffic_data_storage(self,
lat: float,
lon: float,
aligned_range: AlignedDateRange,
record_count: int,
traffic_data: List[Dict[str, Any]]):
"""Enhanced logging for traffic data storage with detailed metadata"""
cities_detected = set()
has_pedestrian_data = 0
data_sources = set()
districts_covered = set()
for record in traffic_data:
if 'city' in record and record['city']:
cities_detected.add(record['city'])
if 'pedestrian_count' in record and record['pedestrian_count'] is not None:
has_pedestrian_data += 1
if 'source' in record and record['source']:
data_sources.add(record['source'])
if 'district' in record and record['district']:
districts_covered.add(record['district'])
logger.info(
"Enhanced traffic data stored for re-training",
location=f"{lat:.4f},{lon:.4f}",
date_range=f"{aligned_range.start.isoformat()} to {aligned_range.end.isoformat()}",
records_stored=record_count,
cities_detected=list(cities_detected),
pedestrian_inference_coverage=f"{has_pedestrian_data}/{record_count}",
data_sources=list(data_sources),
districts_covered=list(districts_covered),
storage_timestamp=datetime.now().isoformat(),
purpose="model_training_and_retraining"
)
def _validate_weather_data(self, weather_data: List[Dict[str, Any]]) -> bool:
"""Validate weather data quality"""
if not weather_data:
return False
required_fields = ['date']
weather_fields = ['temperature', 'temp', 'temperatura', 'precipitation', 'rain', 'lluvia']
valid_records = 0
for record in weather_data:
# Check required fields
if not all(field in record for field in required_fields):
continue
# Check at least one weather field exists
if any(field in record and record[field] is not None for field in weather_fields):
valid_records += 1
# Consider valid if at least 50% of records are valid
validity_threshold = 0.5
is_valid = (valid_records / len(weather_data)) >= validity_threshold
if not is_valid:
logger.warning(f"Weather data validation failed: {valid_records}/{len(weather_data)} valid records")
return is_valid
def _validate_traffic_data_enhanced(self, traffic_data: List[Dict[str, Any]]) -> bool:
"""Enhanced validation for traffic data including pedestrian inference and city-specific fields"""
if not traffic_data:
return False
required_fields = ['date']
traffic_fields = ['traffic_volume', 'traffic_intensity', 'intensidad', 'trafico']
enhanced_fields = ['pedestrian_count', 'congestion_level', 'source']
city_specific_fields = ['city', 'measurement_point_id', 'district']
valid_records = 0
enhanced_records = 0
city_aware_records = 0
for record in traffic_data:
record_score = 0
# Check required fields
if all(field in record and record[field] is not None for field in required_fields):
record_score += 1
# Check traffic data fields
if any(field in record and record[field] is not None for field in traffic_fields):
record_score += 1
# Check enhanced fields (pedestrian inference, etc.)
enhanced_count = sum(1 for field in enhanced_fields
if field in record and record[field] is not None)
if enhanced_count >= 2: # At least 2 enhanced fields
enhanced_records += 1
record_score += 1
# Check city-specific awareness
city_count = sum(1 for field in city_specific_fields
if field in record and record[field] is not None)
if city_count >= 1: # At least some city awareness
city_aware_records += 1
# Record is valid if it has basic requirements (date + any traffic field)
# Lowered requirement from >= 2 to >= 1 to accept records with just date or traffic data
if record_score >= 1:
valid_records += 1
total_records = len(traffic_data)
validity_threshold = 0.1 # Reduced from 0.3 to 0.1 - accept if 10% of records are valid
enhancement_threshold = 0.1 # Reduced threshold for enhanced features
basic_validity = (valid_records / total_records) >= validity_threshold
has_enhancements = (enhanced_records / total_records) >= enhancement_threshold
has_city_awareness = (city_aware_records / total_records) >= enhancement_threshold
logger.info("Enhanced traffic data validation results",
total_records=total_records,
valid_records=valid_records,
enhanced_records=enhanced_records,
city_aware_records=city_aware_records,
basic_validity=basic_validity,
has_enhancements=has_enhancements,
has_city_awareness=has_city_awareness)
if not basic_validity:
logger.warning(f"Traffic data basic validation failed: {valid_records}/{total_records} valid records")
return basic_validity
def _validate_traffic_data(self, traffic_data: List[Dict[str, Any]]) -> bool:
"""Legacy validation method - redirects to enhanced version"""
return self._validate_traffic_data_enhanced(traffic_data)
def _validate_data_sources(
self,
sales_data: List[Dict[str, Any]],
weather_data: List[Dict[str, Any]],
traffic_data: List[Dict[str, Any]],
aligned_range: AlignedDateRange
) -> Dict[str, Any]:
"""Validate all data sources and provide quality metrics"""
validation_results = {
"sales_data": {
"record_count": len(sales_data),
"is_valid": len(sales_data) > 0,
"coverage_days": (aligned_range.end - aligned_range.start).days,
"quality_score": 0.0
},
"weather_data": {
"record_count": len(weather_data),
"is_valid": self._validate_weather_data(weather_data) if weather_data else False,
"quality_score": 0.0
},
"traffic_data": {
"record_count": len(traffic_data),
"is_valid": self._validate_traffic_data(traffic_data) if traffic_data else False,
"quality_score": 0.0
},
"overall_quality_score": 0.0
}
# Calculate quality scores
# Sales data quality (most important)
if validation_results["sales_data"]["record_count"] > 0:
coverage_ratio = min(1.0, validation_results["sales_data"]["record_count"] / validation_results["sales_data"]["coverage_days"])
validation_results["sales_data"]["quality_score"] = coverage_ratio * 100
# Weather data quality
if validation_results["weather_data"]["record_count"] > 0:
expected_weather_records = (aligned_range.end - aligned_range.start).days
coverage_ratio = min(1.0, validation_results["weather_data"]["record_count"] / expected_weather_records)
validation_results["weather_data"]["quality_score"] = coverage_ratio * 100
# Traffic data quality
if validation_results["traffic_data"]["record_count"] > 0:
expected_traffic_records = (aligned_range.end - aligned_range.start).days
coverage_ratio = min(1.0, validation_results["traffic_data"]["record_count"] / expected_traffic_records)
validation_results["traffic_data"]["quality_score"] = coverage_ratio * 100
# Overall quality score (weighted by importance)
weights = {"sales_data": 0.7, "weather_data": 0.2, "traffic_data": 0.1}
overall_score = sum(
validation_results[source]["quality_score"] * weight
for source, weight in weights.items()
)
validation_results["overall_quality_score"] = round(overall_score, 2)
return validation_results
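# Worked example of the weighting above (illustrative numbers only): with per-source
# quality scores of sales=90.0, weather=60.0, traffic=40.0, the overall score is
# 0.7*90.0 + 0.2*60.0 + 0.1*40.0 = 63.0 + 12.0 + 4.0 = 79.0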
def _generate_synthetic_weather_data(
self,
aligned_range: AlignedDateRange
) -> List[Dict[str, Any]]:
"""Generate realistic synthetic weather data for Madrid"""
import random  # used below to add stochastic variation to the synthetic records
synthetic_data = []
current_date = aligned_range.start
# Madrid seasonal temperature patterns
seasonal_temps = {
1: 9, 2: 11, 3: 15, 4: 17, 5: 21, 6: 26,
7: 29, 8: 28, 9: 24, 10: 18, 11: 12, 12: 9
}
while current_date <= aligned_range.end:
month = current_date.month
base_temp = seasonal_temps.get(month, 15)
# Add some realistic variation
temp_variation = random.gauss(0, 3) # ±3°C variation
temperature = max(0, base_temp + temp_variation)
# Precipitation patterns (Madrid is relatively dry)
precipitation = 0.0
if random.random() < 0.15: # 15% chance of rain
precipitation = random.uniform(0.1, 15.0)
synthetic_data.append({
"date": current_date,
"temperature": round(temperature, 1),
"precipitation": round(precipitation, 1),
"humidity": round(random.uniform(40, 80), 1),
"wind_speed": round(random.uniform(2, 15), 1),
"pressure": round(random.uniform(1005, 1025), 1),
"source": "synthetic_madrid_pattern"
})
current_date = current_date + timedelta(days=1)
logger.info(f"Generated {len(synthetic_data)} synthetic weather records with Madrid patterns")
return synthetic_data
def validate_training_data_quality(self, dataset: TrainingDataSet) -> Dict[str, Any]:
"""Enhanced validation of training data quality"""
validation_results = {
"is_valid": True,
"warnings": [],
"errors": [],
"data_quality_score": 100.0,
"recommendations": []
}
# Check sales data completeness
sales_count = len(dataset.sales_data)
if sales_count < 30:
validation_results["warnings"].append(
f"Limited sales data: {sales_count} records (recommended: 30+)"
)
validation_results["data_quality_score"] -= 20
validation_results["recommendations"].append("Consider collecting more historical sales data")
elif sales_count < 90:
validation_results["warnings"].append(
f"Moderate sales data: {sales_count} records (optimal: 90+)"
)
validation_results["data_quality_score"] -= 10
# Check date coverage
date_coverage = (dataset.date_range.end - dataset.date_range.start).days
if date_coverage < 90:
validation_results["warnings"].append(
f"Limited date coverage: {date_coverage} days (recommended: 90+)"
)
validation_results["data_quality_score"] -= 15
validation_results["recommendations"].append("Extend date range for better seasonality detection")
# Check external data availability
if not dataset.weather_data:
validation_results["warnings"].append("No weather data available")
validation_results["data_quality_score"] -= 10
validation_results["recommendations"].append("Weather data improves forecast accuracy")
elif len(dataset.weather_data) < date_coverage * 0.5:
validation_results["warnings"].append("Sparse weather data coverage")
validation_results["data_quality_score"] -= 5
if not dataset.traffic_data:
validation_results["warnings"].append("No traffic data available")
validation_results["data_quality_score"] -= 5
validation_results["recommendations"].append("Traffic data can help with location-based patterns")
# Check data consistency
unique_products = set()
for record in dataset.sales_data:
if 'inventory_product_id' in record:
unique_products.add(record['inventory_product_id'])
if len(unique_products) == 0:
validation_results["errors"].append("No product names found in sales data")
validation_results["is_valid"] = False
elif len(unique_products) > 50:
validation_results["warnings"].append(
f"Many products detected ({len(unique_products)}). Consider training models in batches."
)
validation_results["recommendations"].append("Group similar products for better training efficiency")
# Check for data source constraints
if dataset.date_range.constraints:
constraint_info = []
for constraint_type, message in dataset.date_range.constraints.items():
constraint_info.append(f"{constraint_type}: {message}")
validation_results["warnings"].append(
f"Data source constraints applied: {'; '.join(constraint_info)}"
)
# Final validation
if validation_results["errors"]:
validation_results["is_valid"] = False
validation_results["data_quality_score"] = 0.0
# Ensure score doesn't go below 0
validation_results["data_quality_score"] = max(0.0, validation_results["data_quality_score"])
# Add quality assessment
score = validation_results["data_quality_score"]
if score >= 80:
validation_results["quality_assessment"] = "Excellent"
elif score >= 60:
validation_results["quality_assessment"] = "Good"
elif score >= 40:
validation_results["quality_assessment"] = "Fair"
else:
validation_results["quality_assessment"] = "Poor"
return validation_results
def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
"""
Generate an enhanced data collection plan based on the aligned date range.
"""
plan = {
"collection_summary": {
"start_date": aligned_range.start.isoformat(),
"end_date": aligned_range.end.isoformat(),
"duration_days": (aligned_range.end - aligned_range.start).days,
"available_sources": [source.value for source in aligned_range.available_sources],
"constraints": aligned_range.constraints
},
"data_sources": {}
}
# Bakery Sales Data
if DataSourceType.BAKERY_SALES in aligned_range.available_sources:
plan["data_sources"]["sales_data"] = {
"start_date": aligned_range.start.isoformat(),
"end_date": aligned_range.end.isoformat(),
"source": "user_upload",
"required": True,
"priority": "high",
"expected_records": "variable",
"data_points": ["date", "inventory_product_id", "quantity"],
"validation": "required_fields_check"
}
# Madrid Traffic Data
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
plan["data_sources"]["traffic_data"] = {
"start_date": aligned_range.start.isoformat(),
"end_date": aligned_range.end.isoformat(),
"source": "madrid_opendata",
"required": False,
"priority": "medium",
"expected_records": (aligned_range.end - aligned_range.start).days,
"constraint": "Cannot request current month data",
"data_points": ["date", "traffic_volume", "congestion_level"],
"validation": "date_constraint_check"
}
# Weather Data
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
plan["data_sources"]["weather_data"] = {
"start_date": aligned_range.start.isoformat(),
"end_date": aligned_range.end.isoformat(),
"source": "aemet_api",
"required": False,
"priority": "high",
"expected_records": (aligned_range.end - aligned_range.start).days,
"constraint": "Available from yesterday backward",
"data_points": ["date", "temperature", "precipitation", "humidity"],
"validation": "temporal_constraint_check",
"fallback": "synthetic_madrid_weather"
}
return plan
def get_orchestration_summary(self, dataset: TrainingDataSet) -> Dict[str, Any]:
"""
Generate a comprehensive summary of the orchestration process.
"""
return {
"tenant_id": dataset.metadata.get("tenant_id"),
"job_id": dataset.metadata.get("job_id"),
"orchestration_completed_at": dataset.metadata.get("preparation_timestamp"),
"data_alignment": {
"original_range": dataset.metadata.get("original_sales_range"),
"aligned_range": {
"start": dataset.date_range.start.isoformat(),
"end": dataset.date_range.end.isoformat(),
"duration_days": (dataset.date_range.end - dataset.date_range.start).days
},
"constraints_applied": dataset.date_range.constraints,
"available_sources": [source.value for source in dataset.date_range.available_sources]
},
"data_collection_results": {
"sales_records": len(dataset.sales_data),
"weather_records": len(dataset.weather_data),
"traffic_records": len(dataset.traffic_data),
"total_records": len(dataset.sales_data) + len(dataset.weather_data) + len(dataset.traffic_data)
},
"data_quality": dataset.metadata.get("data_quality", {}),
"validation_results": dataset.metadata.get("final_validation", {}),
"processing_metadata": {
"bakery_location": dataset.metadata.get("bakery_location"),
"data_sources_requested": len(dataset.date_range.available_sources),
"data_sources_successful": sum([
1 if len(dataset.sales_data) > 0 else 0,
1 if len(dataset.weather_data) > 0 else 0,
1 if len(dataset.traffic_data) > 0 else 0
])
}
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,92 @@
"""
Training Service Utilities
"""
from .ml_datetime import (
ensure_timezone_aware,
ensure_timezone_naive,
normalize_datetime_to_utc,
normalize_dataframe_datetime_column,
prepare_prophet_datetime,
safe_datetime_comparison,
get_current_utc,
convert_timestamp_to_datetime
)
from .circuit_breaker import (
CircuitBreaker,
CircuitBreakerError,
CircuitState,
circuit_breaker_registry
)
from .file_utils import (
calculate_file_checksum,
verify_file_checksum,
get_file_size,
ensure_directory_exists,
safe_file_delete,
get_file_metadata,
ChecksummedFile
)
from .distributed_lock import (
DatabaseLock,
SimpleDatabaseLock,
LockAcquisitionError,
get_training_lock
)
from .retry import (
RetryStrategy,
RetryError,
retry_async,
with_retry,
retry_with_timeout,
AdaptiveRetryStrategy,
TimeoutRetryStrategy,
HTTP_RETRY_STRATEGY,
DATABASE_RETRY_STRATEGY,
EXTERNAL_SERVICE_RETRY_STRATEGY
)
__all__ = [
# Timezone utilities
'ensure_timezone_aware',
'ensure_timezone_naive',
'normalize_datetime_to_utc',
'normalize_dataframe_datetime_column',
'prepare_prophet_datetime',
'safe_datetime_comparison',
'get_current_utc',
'convert_timestamp_to_datetime',
# Circuit breaker
'CircuitBreaker',
'CircuitBreakerError',
'CircuitState',
'circuit_breaker_registry',
# File utilities
'calculate_file_checksum',
'verify_file_checksum',
'get_file_size',
'ensure_directory_exists',
'safe_file_delete',
'get_file_metadata',
'ChecksummedFile',
# Distributed locking
'DatabaseLock',
'SimpleDatabaseLock',
'LockAcquisitionError',
'get_training_lock',
# Retry mechanisms
'RetryStrategy',
'RetryError',
'retry_async',
'with_retry',
'retry_with_timeout',
'AdaptiveRetryStrategy',
'TimeoutRetryStrategy',
'HTTP_RETRY_STRATEGY',
'DATABASE_RETRY_STRATEGY',
'EXTERNAL_SERVICE_RETRY_STRATEGY'
]

View File

@@ -0,0 +1,198 @@
"""
Circuit Breaker Pattern Implementation
Protects against cascading failures from external service calls
"""
import asyncio
import time
from enum import Enum
from typing import Callable, Any, Optional
import logging
from functools import wraps
logger = logging.getLogger(__name__)
class CircuitState(Enum):
"""Circuit breaker states"""
CLOSED = "closed" # Normal operation
OPEN = "open" # Circuit is open, rejecting requests
HALF_OPEN = "half_open" # Testing if service recovered
class CircuitBreakerError(Exception):
"""Raised when circuit breaker is open"""
pass
class CircuitBreaker:
"""
Circuit breaker to prevent cascading failures.
States:
- CLOSED: Normal operation, requests pass through
- OPEN: Too many failures, rejecting all requests
- HALF_OPEN: Testing if service recovered, allowing limited requests
"""
def __init__(
self,
failure_threshold: int = 5,
recovery_timeout: float = 60.0,
expected_exception: type = Exception,
name: str = "circuit_breaker"
):
"""
Initialize circuit breaker.
Args:
failure_threshold: Number of failures before opening circuit
recovery_timeout: Seconds to wait before attempting recovery
expected_exception: Exception type to catch (others will pass through)
name: Name for logging purposes
"""
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
self.expected_exception = expected_exception
self.name = name
self.failure_count = 0
self.last_failure_time: Optional[float] = None
self.state = CircuitState.CLOSED
def _record_success(self):
"""Record successful call"""
self.failure_count = 0
self.last_failure_time = None
if self.state == CircuitState.HALF_OPEN:
logger.info(f"Circuit breaker '{self.name}' recovered, closing circuit")
self.state = CircuitState.CLOSED
def _record_failure(self):
"""Record failed call"""
self.failure_count += 1
self.last_failure_time = time.time()
if self.failure_count >= self.failure_threshold:
if self.state != CircuitState.OPEN:
logger.warning(
f"Circuit breaker '{self.name}' opened after {self.failure_count} failures"
)
self.state = CircuitState.OPEN
def _should_attempt_reset(self) -> bool:
"""Check if we should attempt to reset circuit"""
return (
self.state == CircuitState.OPEN
and self.last_failure_time is not None
and time.time() - self.last_failure_time >= self.recovery_timeout
)
async def call(self, func: Callable, *args, **kwargs) -> Any:
"""
Execute function with circuit breaker protection.
Args:
func: Async function to execute
*args: Positional arguments for func
**kwargs: Keyword arguments for func
Returns:
Result from func
Raises:
CircuitBreakerError: If circuit is open
Exception: Original exception if not expected_exception type
"""
# Check if circuit is open
if self.state == CircuitState.OPEN:
if self._should_attempt_reset():
logger.info(f"Circuit breaker '{self.name}' attempting recovery (half-open)")
self.state = CircuitState.HALF_OPEN
else:
raise CircuitBreakerError(
f"Circuit breaker '{self.name}' is open. "
f"Service unavailable for {self.recovery_timeout}s after {self.failure_count} failures."
)
try:
# Execute the function
result = await func(*args, **kwargs)
self._record_success()
return result
except self.expected_exception as e:
self._record_failure()
logger.error(
f"Circuit breaker '{self.name}' caught failure",
error=str(e),
failure_count=self.failure_count,
state=self.state.value
)
raise
def __call__(self, func: Callable) -> Callable:
"""Decorator interface for circuit breaker"""
@wraps(func)
async def wrapper(*args, **kwargs):
return await self.call(func, *args, **kwargs)
return wrapper
def get_state(self) -> dict:
"""Get current circuit breaker state for monitoring"""
return {
"name": self.name,
"state": self.state.value,
"failure_count": self.failure_count,
"failure_threshold": self.failure_threshold,
"last_failure_time": self.last_failure_time,
"recovery_timeout": self.recovery_timeout
}
class CircuitBreakerRegistry:
"""Registry to manage multiple circuit breakers"""
def __init__(self):
self._breakers: dict[str, CircuitBreaker] = {}
def get_or_create(
self,
name: str,
failure_threshold: int = 5,
recovery_timeout: float = 60.0,
expected_exception: type = Exception
) -> CircuitBreaker:
"""Get existing circuit breaker or create new one"""
if name not in self._breakers:
self._breakers[name] = CircuitBreaker(
failure_threshold=failure_threshold,
recovery_timeout=recovery_timeout,
expected_exception=expected_exception,
name=name
)
return self._breakers[name]
def get(self, name: str) -> Optional[CircuitBreaker]:
"""Get circuit breaker by name"""
return self._breakers.get(name)
def get_all_states(self) -> dict:
"""Get states of all circuit breakers"""
return {
name: breaker.get_state()
for name, breaker in self._breakers.items()
}
def reset(self, name: str):
"""Manually reset a circuit breaker"""
if name in self._breakers:
breaker = self._breakers[name]
breaker.failure_count = 0
breaker.last_failure_time = None
breaker.state = CircuitState.CLOSED
logger.info(f"Circuit breaker '{name}' manually reset")
# Global registry instance
circuit_breaker_registry = CircuitBreakerRegistry()
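# Usage sketch (illustrative only; `call_weather_api` is a hypothetical coroutine, not part
# of this module). It shows the intended pattern: fetch a shared breaker from the registry,
# then route the external call through it so repeated failures open the circuit instead of
# cascading into the caller.
async def _example_protected_call(call_weather_api, url: str):
    breaker = circuit_breaker_registry.get_or_create(
        "weather_api",
        failure_threshold=3,
        recovery_timeout=30.0
    )
    # Raises CircuitBreakerError immediately while the circuit is open.
    return await breaker.call(call_weather_api, url)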

View File

@@ -0,0 +1,250 @@
"""
Distributed Locking Mechanisms
Prevents concurrent training jobs for the same product
HORIZONTAL SCALING FIX:
- Uses SHA256 for stable hash across all Python processes/pods
- Python's built-in hash() varies between processes due to hash randomization (Python 3.3+)
- This ensures all pods compute the same lock ID for the same lock name
"""
import asyncio
import time
import hashlib
from typing import Optional
import logging
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timezone, timedelta
logger = logging.getLogger(__name__)
class LockAcquisitionError(Exception):
"""Raised when lock cannot be acquired"""
pass
class DatabaseLock:
"""
Database-based distributed lock using PostgreSQL advisory locks.
Works across multiple service instances.
"""
def __init__(self, lock_name: str, timeout: float = 30.0):
"""
Initialize database lock.
Args:
lock_name: Unique identifier for the lock
timeout: Maximum seconds to wait for lock acquisition
"""
self.lock_name = lock_name
self.timeout = timeout
self.lock_id = self._hash_lock_name(lock_name)
def _hash_lock_name(self, name: str) -> int:
"""
Convert lock name to integer ID for PostgreSQL advisory lock.
CRITICAL: Uses SHA256 for stable hash across all Python processes/pods.
Python's built-in hash() varies between processes due to hash randomization
(PYTHONHASHSEED, enabled by default since Python 3.3), which would cause
different pods to compute different lock IDs for the same lock name,
defeating the purpose of distributed locking.
"""
# Use SHA256 for stable, cross-process hash
hash_bytes = hashlib.sha256(name.encode('utf-8')).digest()
# Take first 4 bytes and convert to positive 31-bit integer
# (PostgreSQL advisory locks use bigint, but we use 31-bit for safety)
return int.from_bytes(hash_bytes[:4], 'big') % (2**31)
@asynccontextmanager
async def acquire(self, session: AsyncSession):
"""
Acquire distributed lock as async context manager.
Args:
session: Database session for lock operations
Raises:
LockAcquisitionError: If lock cannot be acquired within timeout
"""
acquired = False
start_time = time.time()
try:
# Try to acquire lock with timeout
while time.time() - start_time < self.timeout:
# Try non-blocking lock acquisition
result = await session.execute(
text("SELECT pg_try_advisory_lock(:lock_id)"),
{"lock_id": self.lock_id}
)
acquired = result.scalar()
if acquired:
logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
break
# Wait a bit before retrying
await asyncio.sleep(0.1)
if not acquired:
raise LockAcquisitionError(
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
)
yield
finally:
if acquired:
# Release lock
await session.execute(
text("SELECT pg_advisory_unlock(:lock_id)"),
{"lock_id": self.lock_id}
)
logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")
class SimpleDatabaseLock:
"""
Simple table-based distributed lock.
Alternative to advisory locks, uses a dedicated locks table.
"""
def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
"""
Initialize simple database lock.
Args:
lock_name: Unique identifier for the lock
timeout: Maximum seconds to wait for lock acquisition
ttl: Time-to-live for stale lock cleanup (seconds)
"""
self.lock_name = lock_name
self.timeout = timeout
self.ttl = ttl
async def _ensure_lock_table(self, session: AsyncSession):
"""Ensure locks table exists"""
create_table_sql = """
CREATE TABLE IF NOT EXISTS distributed_locks (
lock_name VARCHAR(255) PRIMARY KEY,
acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
acquired_by VARCHAR(255),
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
)
"""
await session.execute(text(create_table_sql))
await session.commit()
async def _cleanup_stale_locks(self, session: AsyncSession):
"""Remove expired locks"""
cleanup_sql = """
DELETE FROM distributed_locks
WHERE expires_at < :now
"""
await session.execute(
text(cleanup_sql),
{"now": datetime.now(timezone.utc)}
)
await session.commit()
@asynccontextmanager
async def acquire(self, session: AsyncSession, owner: str = "training-service"):
"""
Acquire simple database lock.
Args:
session: Database session
owner: Identifier for lock owner
Raises:
LockAcquisitionError: If lock cannot be acquired
"""
await self._ensure_lock_table(session)
await self._cleanup_stale_locks(session)
acquired = False
start_time = time.time()
try:
# Try to acquire lock
while time.time() - start_time < self.timeout:
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=self.ttl)
try:
# Try to insert lock record
insert_sql = """
INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
ON CONFLICT (lock_name) DO NOTHING
RETURNING lock_name
"""
result = await session.execute(
text(insert_sql),
{
"lock_name": self.lock_name,
"acquired_at": now,
"acquired_by": owner,
"expires_at": expires_at
}
)
await session.commit()
if result.rowcount > 0:
acquired = True
logger.info(f"Acquired simple lock: {self.lock_name}")
break
except Exception as e:
logger.debug(f"Lock acquisition attempt failed: {e}")
await session.rollback()
# Wait before retrying
await asyncio.sleep(0.5)
if not acquired:
raise LockAcquisitionError(
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
)
yield
finally:
if acquired:
# Release lock
delete_sql = """
DELETE FROM distributed_locks
WHERE lock_name = :lock_name
"""
await session.execute(
text(delete_sql),
{"lock_name": self.lock_name}
)
await session.commit()
logger.info(f"Released simple lock: {self.lock_name}")
def get_training_lock(tenant_id: str, product_id: str, use_advisory: bool = True) -> DatabaseLock:
"""
Get distributed lock for training a specific product.
Args:
tenant_id: Tenant identifier
product_id: Product identifier
use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)
Returns:
Lock instance
"""
lock_name = f"training:{tenant_id}:{product_id}"
if use_advisory:
return DatabaseLock(lock_name, timeout=60.0)
else:
return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)
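# Usage sketch (illustrative only; `session` is assumed to be an AsyncSession obtained from
# the application's session factory, which is not shown here). It demonstrates how a training
# worker can serialize work on a single product across pods.
async def _example_exclusive_training(session: AsyncSession, tenant_id: str, product_id: str):
    lock = get_training_lock(tenant_id, product_id, use_advisory=True)
    try:
        async with lock.acquire(session):
            # Only one pod at a time reaches this block for the given tenant/product.
            ...
    except LockAcquisitionError:
        logger.warning(f"Training already in progress for {tenant_id}:{product_id}")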

View File

@@ -0,0 +1,216 @@
"""
File Utility Functions
Utilities for secure file operations including checksum verification
"""
import hashlib
import os
from pathlib import Path
from typing import Optional
import logging
logger = logging.getLogger(__name__)
def calculate_file_checksum(file_path: str, algorithm: str = "sha256") -> str:
"""
Calculate checksum of a file.
Args:
file_path: Path to file
algorithm: Hash algorithm (sha256, md5, etc.)
Returns:
Hexadecimal checksum string
Raises:
FileNotFoundError: If file doesn't exist
ValueError: If algorithm not supported
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
hash_func = hashlib.new(algorithm)
except ValueError:
raise ValueError(f"Unsupported hash algorithm: {algorithm}")
# Read file in chunks to handle large files efficiently
with open(file_path, 'rb') as f:
while chunk := f.read(8192):
hash_func.update(chunk)
return hash_func.hexdigest()
def verify_file_checksum(file_path: str, expected_checksum: str, algorithm: str = "sha256") -> bool:
"""
Verify file matches expected checksum.
Args:
file_path: Path to file
expected_checksum: Expected checksum value
algorithm: Hash algorithm used
Returns:
True if checksum matches, False otherwise
"""
try:
actual_checksum = calculate_file_checksum(file_path, algorithm)
matches = actual_checksum == expected_checksum
if matches:
logger.debug(f"Checksum verified for {file_path}")
else:
logger.warning(
f"Checksum mismatch for {file_path}: "
f"expected {expected_checksum}, got {actual_checksum}"
)
return matches
except Exception as e:
logger.error(f"Error verifying checksum for {file_path}: {e}")
return False
def get_file_size(file_path: str) -> int:
"""
Get file size in bytes.
Args:
file_path: Path to file
Returns:
File size in bytes
Raises:
FileNotFoundError: If file doesn't exist
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
return os.path.getsize(file_path)
def ensure_directory_exists(directory: str) -> Path:
"""
Ensure directory exists, create if necessary.
Args:
directory: Directory path
Returns:
Path object for directory
"""
path = Path(directory)
path.mkdir(parents=True, exist_ok=True)
return path
def safe_file_delete(file_path: str) -> bool:
"""
Safely delete a file, logging any errors.
Args:
file_path: Path to file
Returns:
True if deleted successfully, False otherwise
"""
try:
if os.path.exists(file_path):
os.remove(file_path)
logger.info(f"Deleted file: {file_path}")
return True
else:
logger.warning(f"File not found for deletion: {file_path}")
return False
except Exception as e:
logger.error(f"Error deleting file {file_path}: {e}")
return False
def get_file_metadata(file_path: str) -> dict:
"""
Get comprehensive file metadata.
Args:
file_path: Path to file
Returns:
Dictionary with file metadata
Raises:
FileNotFoundError: If file doesn't exist
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
stat = os.stat(file_path)
return {
"path": file_path,
"size_bytes": stat.st_size,
"created_at": stat.st_ctime,
"modified_at": stat.st_mtime,
"accessed_at": stat.st_atime,
"is_file": os.path.isfile(file_path),
"is_dir": os.path.isdir(file_path),
"exists": True
}
class ChecksummedFile:
"""
Context manager for working with checksummed files.
Automatically calculates and stores checksum when file is written.
"""
def __init__(self, file_path: str, checksum_path: Optional[str] = None, algorithm: str = "sha256"):
"""
Initialize checksummed file handler.
Args:
file_path: Path to the file
checksum_path: Path to store checksum (default: file_path + '.checksum')
algorithm: Hash algorithm to use
"""
self.file_path = file_path
self.checksum_path = checksum_path or f"{file_path}.checksum"
self.algorithm = algorithm
self.checksum: Optional[str] = None
def calculate_and_save_checksum(self) -> str:
"""Calculate checksum and save to file"""
self.checksum = calculate_file_checksum(self.file_path, self.algorithm)
with open(self.checksum_path, 'w') as f:
f.write(f"{self.checksum} {os.path.basename(self.file_path)}\n")
logger.info(f"Saved checksum for {self.file_path}: {self.checksum}")
return self.checksum
def load_and_verify_checksum(self) -> bool:
"""Load expected checksum and verify file"""
try:
with open(self.checksum_path, 'r') as f:
expected_checksum = f.read().strip().split()[0]
return verify_file_checksum(self.file_path, expected_checksum, self.algorithm)
except FileNotFoundError:
logger.warning(f"Checksum file not found: {self.checksum_path}")
return False
except Exception as e:
logger.error(f"Error loading checksum: {e}")
return False
def get_stored_checksum(self) -> Optional[str]:
"""Get checksum from stored file"""
try:
with open(self.checksum_path, 'r') as f:
return f.read().strip().split()[0]
except FileNotFoundError:
return None
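# Usage sketch (illustrative only; the model path is a made-up placeholder and the file is
# assumed to already exist). It shows the intended write-then-verify flow: persist an
# artifact, record its checksum alongside it, and verify integrity before loading it later.
def _example_checksummed_model(model_path: str = "/tmp/example_model.pkl") -> bool:
    handler = ChecksummedFile(model_path)
    handler.calculate_and_save_checksum()   # writes "<model_path>.checksum"
    return handler.load_and_verify_checksum()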

View File

@@ -0,0 +1,270 @@
"""
ML-Specific DateTime Utilities
DateTime utilities for machine learning operations, specifically for:
- Prophet forecasting model (requires timezone-naive datetimes)
- Pandas DataFrame datetime operations
- Time series data processing
"""
from datetime import datetime, timezone
from typing import Union
import pandas as pd
import logging
logger = logging.getLogger(__name__)
def ensure_timezone_aware(dt: datetime, default_tz=timezone.utc) -> datetime:
"""
Ensure a datetime is timezone-aware.
Args:
dt: Datetime to check
default_tz: Timezone to apply if datetime is naive (default: UTC)
Returns:
Timezone-aware datetime
"""
if dt is None:
return None
if dt.tzinfo is None:
return dt.replace(tzinfo=default_tz)
return dt
def ensure_timezone_naive(dt: datetime) -> datetime:
"""
Remove timezone information from a datetime.
Args:
dt: Datetime to process
Returns:
Timezone-naive datetime
"""
if dt is None:
return None
if dt.tzinfo is not None:
return dt.replace(tzinfo=None)
return dt
def normalize_datetime_to_utc(dt: Union[datetime, pd.Timestamp]) -> datetime:
"""
Normalize any datetime to UTC timezone-aware datetime.
Args:
dt: Datetime or pandas Timestamp to normalize
Returns:
UTC timezone-aware datetime
"""
if dt is None:
return None
if isinstance(dt, pd.Timestamp):
dt = dt.to_pydatetime()
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt.astimezone(timezone.utc)
def normalize_dataframe_datetime_column(
df: pd.DataFrame,
column: str,
target_format: str = 'naive'
) -> pd.DataFrame:
"""
Normalize a datetime column in a dataframe to consistent format.
Args:
df: DataFrame to process
column: Name of datetime column
target_format: 'naive' or 'aware' (UTC)
Returns:
DataFrame with normalized datetime column
"""
if column not in df.columns:
logger.warning(f"Column {column} not found in dataframe")
return df
df[column] = pd.to_datetime(df[column])
if target_format == 'naive':
if df[column].dt.tz is not None:
df[column] = df[column].dt.tz_localize(None)
elif target_format == 'aware':
if df[column].dt.tz is None:
df[column] = df[column].dt.tz_localize(timezone.utc)
else:
df[column] = df[column].dt.tz_convert(timezone.utc)
else:
raise ValueError(f"Invalid target_format: {target_format}. Must be 'naive' or 'aware'")
return df
def prepare_prophet_datetime(df: pd.DataFrame, datetime_col: str = 'ds') -> pd.DataFrame:
"""
Prepare datetime column for Prophet (requires timezone-naive datetimes).
Args:
df: DataFrame with datetime column
datetime_col: Name of datetime column (default: 'ds')
Returns:
DataFrame with Prophet-compatible datetime column
"""
df = df.copy()
df = normalize_dataframe_datetime_column(df, datetime_col, target_format='naive')
return df
def safe_datetime_comparison(dt1: datetime, dt2: datetime) -> int:
"""
Safely compare two datetimes, handling timezone mismatches.
Args:
dt1: First datetime
dt2: Second datetime
Returns:
-1 if dt1 < dt2, 0 if equal, 1 if dt1 > dt2
"""
dt1_utc = normalize_datetime_to_utc(dt1)
dt2_utc = normalize_datetime_to_utc(dt2)
if dt1_utc < dt2_utc:
return -1
elif dt1_utc > dt2_utc:
return 1
else:
return 0
def get_current_utc() -> datetime:
"""
Get current datetime in UTC with timezone awareness.
Returns:
Current UTC datetime
"""
return datetime.now(timezone.utc)
def convert_timestamp_to_datetime(timestamp: Union[int, float, str]) -> datetime:
"""
Convert various timestamp formats to datetime.
Args:
timestamp: Unix timestamp (seconds or milliseconds) or ISO string
Returns:
UTC timezone-aware datetime
"""
if isinstance(timestamp, str):
dt = pd.to_datetime(timestamp)
return normalize_datetime_to_utc(dt)
# Heuristic: values above ~1e10 are treated as millisecond timestamps
if timestamp > 1e10:
timestamp = timestamp / 1000
dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)
return dt
def align_dataframe_dates(
dfs: list[pd.DataFrame],
date_column: str = 'ds',
method: str = 'inner'
) -> list[pd.DataFrame]:
"""
Align multiple dataframes to have the same date range.
Args:
dfs: List of DataFrames to align
date_column: Name of the date column
method: 'inner' (intersection) or 'outer' (union)
Returns:
List of aligned DataFrames
"""
if not dfs:
return []
if len(dfs) == 1:
return dfs
all_dates = None
for df in dfs:
if date_column not in df.columns:
continue
dates = set(pd.to_datetime(df[date_column]).dt.date)
if all_dates is None:
all_dates = dates
else:
if method == 'inner':
all_dates = all_dates.intersection(dates)
elif method == 'outer':
all_dates = all_dates.union(dates)
aligned_dfs = []
for df in dfs:
if date_column not in df.columns:
aligned_dfs.append(df)
continue
df = df.copy()
df[date_column] = pd.to_datetime(df[date_column])
df['_date_only'] = df[date_column].dt.date
df = df[df['_date_only'].isin(all_dates)]
df = df.drop('_date_only', axis=1)
aligned_dfs.append(df)
return aligned_dfs
def fill_missing_dates(
df: pd.DataFrame,
date_column: str = 'ds',
freq: str = 'D',
fill_value: float = 0.0
) -> pd.DataFrame:
"""
Fill missing dates in a DataFrame with a specified frequency.
Args:
df: DataFrame with date column
date_column: Name of the date column
freq: Pandas frequency string ('D' for daily, 'H' for hourly, etc.)
fill_value: Value to fill for missing dates
Returns:
DataFrame with filled dates
"""
df = df.copy()
df[date_column] = pd.to_datetime(df[date_column])
df = df.set_index(date_column)
full_range = pd.date_range(
start=df.index.min(),
end=df.index.max(),
freq=freq
)
df = df.reindex(full_range, fill_value=fill_value)
df = df.reset_index()
df = df.rename(columns={'index': date_column})
return df
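# Usage sketch (illustrative only; the input column names "date" and "quantity" are assumed,
# and the output follows Prophet's 'ds'/'y' convention). It shows a typical preprocessing
# chain for a daily sales series before fitting Prophet: rename columns, fill calendar gaps
# with zero sales, and strip timezones as Prophet requires.
def _example_prophet_preparation(sales_df: pd.DataFrame) -> pd.DataFrame:
    df = sales_df.rename(columns={"date": "ds", "quantity": "y"})
    df = fill_missing_dates(df, date_column="ds", freq="D", fill_value=0.0)
    df = prepare_prophet_datetime(df, datetime_col="ds")
    return df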

View File

@@ -0,0 +1,316 @@
"""
Retry Mechanism with Exponential Backoff
Handles transient failures with intelligent retry strategies
"""
import asyncio
import time
import random
from typing import Callable, Any, Optional, Type, Tuple
from functools import wraps
import logging
logger = logging.getLogger(__name__)
class RetryError(Exception):
"""Raised when all retry attempts are exhausted"""
def __init__(self, message: str, attempts: int, last_exception: Exception):
super().__init__(message)
self.attempts = attempts
self.last_exception = last_exception
class RetryStrategy:
"""Base retry strategy"""
def __init__(
self,
max_attempts: int = 3,
initial_delay: float = 1.0,
max_delay: float = 60.0,
exponential_base: float = 2.0,
jitter: bool = True,
retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
):
"""
Initialize retry strategy.
Args:
max_attempts: Maximum number of retry attempts
initial_delay: Initial delay in seconds
max_delay: Maximum delay between retries
exponential_base: Base for exponential backoff
jitter: Add random jitter to prevent thundering herd
retriable_exceptions: Tuple of exception types to retry
"""
self.max_attempts = max_attempts
self.initial_delay = initial_delay
self.max_delay = max_delay
self.exponential_base = exponential_base
self.jitter = jitter
self.retriable_exceptions = retriable_exceptions
def calculate_delay(self, attempt: int) -> float:
"""Calculate delay for given attempt using exponential backoff"""
delay = min(
self.initial_delay * (self.exponential_base ** attempt),
self.max_delay
)
if self.jitter:
# Apply random jitter: keep 50-100% of the computed delay
delay = delay * (0.5 + random.random() * 0.5)
return delay
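# Worked example with the defaults (initial_delay=1.0, exponential_base=2.0, max_delay=60.0):
# attempts 0, 1, 2 give raw delays of 1s, 2s, 4s; with jitter enabled each delay is scaled
# into the 50-100% band, so attempt 2 sleeps roughly between 2s and 4s.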
def is_retriable(self, exception: Exception) -> bool:
"""Check if exception should trigger retry"""
return isinstance(exception, self.retriable_exceptions)
async def retry_async(
func: Callable,
*args,
strategy: Optional[RetryStrategy] = None,
**kwargs
) -> Any:
"""
Retry async function with exponential backoff.
Args:
func: Async function to retry
*args: Positional arguments for func
strategy: Retry strategy (uses default if None)
**kwargs: Keyword arguments for func
Returns:
Result from func
Raises:
RetryError: When all attempts exhausted
"""
if strategy is None:
strategy = RetryStrategy()
last_exception = None
for attempt in range(strategy.max_attempts):
try:
result = await func(*args, **kwargs)
if attempt > 0:
logger.info(
f"Retry of {func.__name__} succeeded on attempt {attempt + 1}"
)
return result
except Exception as e:
last_exception = e
if not strategy.is_retriable(e):
logger.error(
f"Non-retriable exception in {func.__name__}: {e}"
)
raise
if attempt < strategy.max_attempts - 1:
delay = strategy.calculate_delay(attempt)
logger.warning(
f"Attempt {attempt + 1}/{strategy.max_attempts} of {func.__name__} failed ({e}), "
f"retrying in {delay:.2f}s"
)
await asyncio.sleep(delay)
else:
logger.error(
f"All {strategy.max_attempts} retry attempts of {func.__name__} exhausted: {e}"
)
raise RetryError(
f"Failed after {strategy.max_attempts} attempts: {str(last_exception)}",
attempts=strategy.max_attempts,
last_exception=last_exception
)
def with_retry(
max_attempts: int = 3,
initial_delay: float = 1.0,
max_delay: float = 60.0,
exponential_base: float = 2.0,
jitter: bool = True,
retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
):
"""
Decorator to add retry logic to async functions.
Example:
@with_retry(max_attempts=5, initial_delay=2.0)
async def fetch_data():
# Your code here
pass
"""
strategy = RetryStrategy(
max_attempts=max_attempts,
initial_delay=initial_delay,
max_delay=max_delay,
exponential_base=exponential_base,
jitter=jitter,
retriable_exceptions=retriable_exceptions
)
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
return await retry_async(func, *args, strategy=strategy, **kwargs)
return wrapper
return decorator
class AdaptiveRetryStrategy(RetryStrategy):
"""
Adaptive retry strategy that adjusts based on success/failure patterns.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.success_count = 0
self.failure_count = 0
self.consecutive_failures = 0
def calculate_delay(self, attempt: int) -> float:
"""Calculate delay with adaptation based on recent history"""
base_delay = super().calculate_delay(attempt)
# Increase delay if seeing consecutive failures
if self.consecutive_failures > 5:
multiplier = min(2.0, 1.0 + (self.consecutive_failures - 5) * 0.2)
base_delay *= multiplier
return min(base_delay, self.max_delay)
def record_success(self):
"""Record successful attempt"""
self.success_count += 1
self.consecutive_failures = 0
def record_failure(self):
"""Record failed attempt"""
self.failure_count += 1
self.consecutive_failures += 1
class TimeoutRetryStrategy(RetryStrategy):
"""
Retry strategy with overall timeout across all attempts.
"""
def __init__(self, *args, timeout: float = 300.0, **kwargs):
"""
Args:
timeout: Total timeout in seconds for all attempts
"""
super().__init__(*args, **kwargs)
self.timeout = timeout
self.start_time: Optional[float] = None
def should_retry(self, attempt: int) -> bool:
"""Check if should attempt another retry"""
if self.start_time is None:
self.start_time = time.time()
return True
elapsed = time.time() - self.start_time
return elapsed < self.timeout and attempt < self.max_attempts
async def retry_with_timeout(
func: Callable,
*args,
max_attempts: int = 3,
timeout: float = 300.0,
**kwargs
) -> Any:
"""
Retry with overall timeout.
Args:
func: Function to retry
max_attempts: Maximum attempts
timeout: Overall timeout in seconds
Returns:
Result from func
"""
strategy = TimeoutRetryStrategy(
max_attempts=max_attempts,
timeout=timeout
)
start_time = time.time()
strategy.start_time = start_time
last_exception = None
for attempt in range(strategy.max_attempts):
if time.time() - start_time >= timeout:
raise RetryError(
f"Timeout of {timeout}s exceeded",
attempts=attempt + 1,
last_exception=last_exception
)
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
if not strategy.is_retriable(e):
raise
if attempt < strategy.max_attempts - 1:
delay = strategy.calculate_delay(attempt)
await asyncio.sleep(delay)
raise RetryError(
f"Failed after {strategy.max_attempts} attempts",
attempts=strategy.max_attempts,
last_exception=last_exception
)
# Pre-configured strategies for common use cases
HTTP_RETRY_STRATEGY = RetryStrategy(
max_attempts=3,
initial_delay=1.0,
max_delay=10.0,
exponential_base=2.0,
jitter=True
)
DATABASE_RETRY_STRATEGY = RetryStrategy(
max_attempts=5,
initial_delay=0.5,
max_delay=5.0,
exponential_base=1.5,
jitter=True
)
EXTERNAL_SERVICE_RETRY_STRATEGY = RetryStrategy(
max_attempts=4,
initial_delay=2.0,
max_delay=30.0,
exponential_base=2.5,
jitter=True
)
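# Usage sketch (illustrative only; `fetch_page` is a hypothetical coroutine). It shows how a
# caller wraps an outbound call with one of the pre-configured strategies above; RetryError
# is raised once the three HTTP attempts are exhausted.
async def _example_http_with_retry(fetch_page, url: str):
    return await retry_async(fetch_page, url, strategy=HTTP_RETRY_STRATEGY)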

View File

@@ -0,0 +1,340 @@
"""
Training Time Estimation Utilities
Provides intelligent time estimation for training jobs based on:
- Product count
- Historical performance data
- Current progress and throughput
"""
from typing import List, Optional
from datetime import datetime, timedelta, timezone
import structlog
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
logger = structlog.get_logger()
def calculate_initial_estimate(
total_products: int,
avg_training_time_per_product: float = 60.0, # seconds, default 1 min/product
data_analysis_overhead: float = 120.0, # seconds, data loading & analysis
finalization_overhead: float = 60.0, # seconds, saving models & cleanup
min_estimate_minutes: int = 5,
max_estimate_minutes: int = 60
) -> int:
"""
Calculate realistic initial time estimate for training job.
Formula:
total_time = data_analysis + (products * avg_time_per_product) + finalization
Args:
total_products: Number of products to train
avg_training_time_per_product: Average time per product in seconds
data_analysis_overhead: Time for data loading and analysis in seconds
finalization_overhead: Time for saving models and cleanup in seconds
min_estimate_minutes: Minimum estimate (prevents unrealistic low values)
max_estimate_minutes: Maximum estimate (prevents unrealistic high values)
Returns:
Estimated duration in minutes
Examples:
>>> calculate_initial_estimate(1)
5 # 120 + 60 + 60 = 240s = 4min, raised to the 5-minute minimum
>>> calculate_initial_estimate(5)
8 # 120 + 300 + 60 = 480s = 8min
>>> calculate_initial_estimate(10)
13 # 120 + 600 + 60 = 780s = 13min
>>> calculate_initial_estimate(20)
23 # 120 + 1200 + 60 = 1380s = 23min
>>> calculate_initial_estimate(100)
60 # Capped at max (would be 103 min)
"""
# Calculate total estimated time in seconds
estimated_seconds = (
data_analysis_overhead +
(total_products * avg_training_time_per_product) +
finalization_overhead
)
# Convert to minutes, rounding to the nearest minute
estimated_minutes = int((estimated_seconds / 60) + 0.5)
# Apply min/max bounds
estimated_minutes = max(min_estimate_minutes, min(max_estimate_minutes, estimated_minutes))
logger.info(
"Calculated initial time estimate",
total_products=total_products,
estimated_seconds=estimated_seconds,
estimated_minutes=estimated_minutes,
avg_time_per_product=avg_training_time_per_product
)
return estimated_minutes
def calculate_estimated_completion_time(
estimated_duration_minutes: int,
start_time: Optional[datetime] = None
) -> datetime:
"""
Calculate estimated completion timestamp.
Args:
estimated_duration_minutes: Estimated duration in minutes
start_time: Job start time (defaults to now)
Returns:
Estimated completion datetime (timezone-aware UTC)
"""
if start_time is None:
start_time = datetime.now(timezone.utc)
completion_time = start_time + timedelta(minutes=estimated_duration_minutes)
return completion_time
def calculate_remaining_time_smart(
progress: int,
elapsed_time: float,
products_completed: int,
total_products: int,
recent_product_times: Optional[List[float]] = None,
max_remaining_seconds: int = 1800 # 30 minutes
) -> Optional[int]:
"""
Calculate remaining time using smart algorithm that considers:
- Current progress percentage
- Actual throughput (products completed / elapsed time)
- Recent performance (weighted moving average)
Args:
progress: Current progress percentage (0-100)
elapsed_time: Time elapsed since job start (seconds)
products_completed: Number of products completed
total_products: Total number of products
recent_product_times: List of recent product training times (seconds)
max_remaining_seconds: Maximum remaining time (safety cap)
Returns:
Estimated remaining time in seconds, or None if can't calculate
"""
# Job completed or not started
if progress >= 100 or progress <= 0:
return None
# Early stage (0-20%): Use weighted estimate
if progress <= 20:
# In data analysis phase - estimate based on remaining products
remaining_products = total_products - products_completed
if recent_product_times and len(recent_product_times) > 0:
# Use recent performance if available
avg_time_per_product = sum(recent_product_times) / len(recent_product_times)
else:
# Fallback to default
avg_time_per_product = 60.0 # 1 minute per product
# Estimate: remaining products * avg time + overhead
estimated_remaining = (remaining_products * avg_time_per_product) + 60.0 # +1 min overhead
logger.debug(
"Early stage estimation",
progress=progress,
remaining_products=remaining_products,
avg_time_per_product=avg_time_per_product,
estimated_remaining=estimated_remaining
)
# Mid/late stage (21-99%): Use actual throughput
else:
if products_completed > 0:
# Calculate actual time per product from current run
actual_time_per_product = elapsed_time / products_completed
remaining_products = total_products - products_completed
estimated_remaining = remaining_products * actual_time_per_product
logger.debug(
"Mid/late stage estimation",
progress=progress,
products_completed=products_completed,
total_products=total_products,
actual_time_per_product=actual_time_per_product,
estimated_remaining=estimated_remaining
)
else:
# Fallback to linear extrapolation
estimated_total = (elapsed_time / progress) * 100
estimated_remaining = estimated_total - elapsed_time
logger.debug(
"Fallback linear estimation",
progress=progress,
elapsed_time=elapsed_time,
estimated_remaining=estimated_remaining
)
# Apply safety cap
estimated_remaining = min(estimated_remaining, max_remaining_seconds)
return int(estimated_remaining)
def calculate_average_product_time(
products_completed: int,
elapsed_time: float,
min_products_threshold: int = 3
) -> Optional[float]:
"""
Calculate average time per product from current job progress.
Args:
products_completed: Number of products completed
elapsed_time: Time elapsed since job start (seconds)
min_products_threshold: Minimum products needed for reliable calculation
Returns:
Average time per product in seconds, or None if insufficient data
"""
if products_completed < min_products_threshold:
return None
avg_time = elapsed_time / products_completed
logger.debug(
"Calculated average product time",
products_completed=products_completed,
elapsed_time=elapsed_time,
avg_time=avg_time
)
return avg_time
def format_time_remaining(seconds: int) -> str:
"""
Format remaining time in human-readable format.
Args:
seconds: Time in seconds
Returns:
Formatted string (e.g., "5 minutes", "1 hour 23 minutes")
Examples:
>>> format_time_remaining(45)
"45 seconds"
>>> format_time_remaining(180)
"3 minutes"
>>> format_time_remaining(5400)
"1 hour 30 minutes"
"""
if seconds < 60:
return f"{seconds} seconds"
minutes = seconds // 60
remaining_seconds = seconds % 60
if minutes < 60:
if remaining_seconds > 0:
return f"{minutes} minute{'s' if minutes > 1 else ''} {remaining_seconds} seconds"
return f"{minutes} minute{'s' if minutes > 1 else ''}"
hours = minutes // 60
remaining_minutes = minutes % 60
if remaining_minutes > 0:
return f"{hours} hour{'s' if hours > 1 else ''} {remaining_minutes} minute{'s' if remaining_minutes > 1 else ''}"
return f"{hours} hour{'s' if hours > 1 else ''}"
async def get_historical_average_estimate(
db_session: AsyncSession,
tenant_id: str,
lookback_days: int = 30,
limit: int = 10
) -> Optional[float]:
"""
Get historical average training time per product for a tenant.
This function queries the TrainingPerformanceMetrics table to get
recent historical data and calculate an average.
Args:
db_session: Async database session
tenant_id: Tenant UUID
lookback_days: How many days back to look
limit: Maximum number of historical records to consider
Returns:
Average time per product in seconds, or None if no historical data
"""
try:
from app.models.training import TrainingPerformanceMetrics
cutoff = datetime.now(timezone.utc) - timedelta(days=lookback_days)
# Query recent training performance metrics using SQLAlchemy 2.0 async pattern
query = (
select(TrainingPerformanceMetrics)
.where(
TrainingPerformanceMetrics.tenant_id == tenant_id,
TrainingPerformanceMetrics.completed_at >= cutoff
)
.order_by(TrainingPerformanceMetrics.completed_at.desc())
.limit(limit)
)
result = await db_session.execute(query)
metrics = result.scalars().all()
if not metrics:
logger.info(
"No historical training data found",
tenant_id=tenant_id,
lookback_days=lookback_days
)
return None
# Calculate weighted average (more recent = higher weight)
total_weight = 0
weighted_sum = 0
for i, metric in enumerate(metrics):
# Weight: newer records get higher weight
weight = limit - i
weighted_sum += metric.avg_time_per_product * weight
total_weight += weight
if total_weight == 0:
return None
weighted_avg = weighted_sum / total_weight
logger.info(
"Calculated historical average",
tenant_id=tenant_id,
records_used=len(metrics),
weighted_avg=weighted_avg
)
return weighted_avg
except Exception as e:
logger.error(
"Error getting historical average",
tenant_id=tenant_id,
error=str(e)
)
return None
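# Usage sketch (illustrative only; the 12-product figure is an arbitrary example). It shows
# how the estimation helpers are meant to fit together when a training job is created.
def _example_job_estimate(total_products: int = 12) -> dict:
    estimated_minutes = calculate_initial_estimate(total_products)
    completion_at = calculate_estimated_completion_time(estimated_minutes)
    return {
        "estimated_duration_minutes": estimated_minutes,
        "estimated_completion_time": completion_at.isoformat(),
        "human_readable": format_time_remaining(estimated_minutes * 60),
    }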

View File

@@ -0,0 +1,11 @@
"""WebSocket support for training service"""
from app.websocket.manager import websocket_manager, WebSocketConnectionManager
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers
__all__ = [
'websocket_manager',
'WebSocketConnectionManager',
'setup_websocket_event_consumer',
'cleanup_websocket_consumers'
]

View File

@@ -0,0 +1,148 @@
"""
RabbitMQ Event Consumer for WebSocket Broadcasting
Listens to training events from RabbitMQ and broadcasts them to WebSocket clients
"""
import asyncio
import json
from typing import Dict, Set
import structlog
from app.websocket.manager import websocket_manager
from app.services.training_events import training_publisher
logger = structlog.get_logger()
# Track active consumers
_active_consumers: Set[asyncio.Task] = set()
async def handle_training_event(message) -> None:
"""
Handle incoming RabbitMQ training events and broadcast to WebSocket clients.
This is the bridge between RabbitMQ and WebSocket.
"""
try:
# Parse message
body = message.body.decode()
data = json.loads(body)
event_type = data.get('event_type', 'unknown')
event_data = data.get('data', {})
job_id = event_data.get('job_id')
if not job_id:
logger.warning("Received event without job_id, skipping", event_type=event_type)
await message.ack()
return
logger.info("Received training event from RabbitMQ",
job_id=job_id,
event_type=event_type,
progress=event_data.get('progress'))
# Map RabbitMQ event types to WebSocket message types
ws_message_type = _map_event_type(event_type)
# Create WebSocket message
ws_message = {
"type": ws_message_type,
"job_id": job_id,
"timestamp": data.get('timestamp'),
"data": event_data
}
# Broadcast to all WebSocket clients for this job
sent_count = await websocket_manager.broadcast(job_id, ws_message)
logger.info("Broadcasted event to WebSocket clients",
job_id=job_id,
event_type=event_type,
ws_message_type=ws_message_type,
clients_notified=sent_count)
# Always acknowledge the message to avoid infinite redelivery loops
# Progress events (started, progress, product_completed) are ephemeral and don't need redelivery
# Final events (completed, failed) should always be acknowledged
await message.ack()
except Exception as e:
logger.error("Error handling training event",
error=str(e),
exc_info=True)
# Always acknowledge even on error to avoid infinite redelivery loops
# The event is logged so we can debug issues
try:
await message.ack()
except Exception:
pass # Message already gone or connection closed
def _map_event_type(rabbitmq_event_type: str) -> str:
"""Map RabbitMQ event types to WebSocket message types"""
mapping = {
"training.started": "started",
"training.progress": "progress",
"training.step.completed": "step_completed",
"training.product.completed": "product_completed",
"training.completed": "completed",
"training.failed": "failed",
}
return mapping.get(rabbitmq_event_type, "unknown")
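# Illustrative only: a hypothetical helper, not used by the service, showing how a
# raw RabbitMQ body maps onto the WebSocket message that handle_training_event
# builds and broadcasts. The sample payload values are invented for the example.
def _example_ws_message() -> dict:
    sample_body = json.dumps({
        "event_type": "training.progress",
        "timestamp": "2026-01-21T17:20:00Z",
        "data": {"job_id": "abc-123", "progress": 42},
    })
    data = json.loads(sample_body)
    return {
        "type": _map_event_type(data["event_type"]),  # -> "progress"
        "job_id": data["data"]["job_id"],
        "timestamp": data["timestamp"],
        "data": data["data"],
    }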
async def setup_websocket_event_consumer() -> bool:
"""
Set up a global RabbitMQ consumer that listens to all training events
and broadcasts them to connected WebSocket clients.
"""
try:
# Ensure publisher is connected
if not training_publisher.connected:
logger.info("Connecting training publisher for WebSocket event consumer")
success = await training_publisher.connect()
if not success:
logger.error("Failed to connect training publisher")
return False
# Create a unique queue for WebSocket broadcasting
queue_name = "training_websocket_broadcast"
logger.info("Setting up WebSocket event consumer", queue_name=queue_name)
# Subscribe to all training events (routing key: training.#)
success = await training_publisher.consume_events(
exchange_name="training.events",
queue_name=queue_name,
routing_key="training.#", # Listen to all training events (multi-level)
callback=handle_training_event
)
if success:
logger.info("WebSocket event consumer set up successfully")
return True
else:
logger.error("Failed to set up WebSocket event consumer")
return False
except Exception as e:
logger.error("Error setting up WebSocket event consumer",
error=str(e),
exc_info=True)
return False
async def cleanup_websocket_consumers() -> None:
"""Clean up WebSocket event consumers"""
logger.info("Cleaning up WebSocket event consumers")
for task in _active_consumers:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
_active_consumers.clear()
logger.info("WebSocket event consumers cleaned up")

View File

@@ -0,0 +1,300 @@
"""
WebSocket Connection Manager for Training Service
Manages WebSocket connections and broadcasts RabbitMQ events to connected clients
HORIZONTAL SCALING:
- Uses Redis pub/sub for cross-pod WebSocket broadcasting
- Each pod subscribes to a Redis channel and broadcasts to its local connections
- Events published to Redis are received by all pods, ensuring clients on any
pod receive events from training jobs running on any other pod
"""
import asyncio
import json
import os
from typing import Dict, Optional
from fastapi import WebSocket
import structlog
logger = structlog.get_logger()
# Redis pub/sub channel for WebSocket events
REDIS_WEBSOCKET_CHANNEL = "training:websocket:events"
class WebSocketConnectionManager:
"""
WebSocket connection manager with Redis pub/sub for horizontal scaling.
In a multi-pod deployment:
1. Events are published to Redis pub/sub (not just local broadcast)
2. Each pod subscribes to Redis and broadcasts to its local WebSocket connections
3. This ensures clients connected to any pod receive events from any pod
Flow:
- RabbitMQ event → Pod A receives → Pod A publishes to Redis
- Redis pub/sub → All pods receive → Each pod broadcasts to local WebSockets
"""
def __init__(self):
# Structure: {job_id: {websocket_id: WebSocket}}
self._connections: Dict[str, Dict[int, WebSocket]] = {}
self._lock = asyncio.Lock()
# Store latest event for each job to provide initial state
self._latest_events: Dict[str, dict] = {}
# Redis client for pub/sub
self._redis: Optional[object] = None
self._pubsub: Optional[object] = None
self._subscriber_task: Optional[asyncio.Task] = None
self._running = False
self._instance_id = f"{os.environ.get('HOSTNAME', 'unknown')}:{os.getpid()}"
async def initialize_redis(self, redis_url: str) -> bool:
"""
Initialize Redis connection for cross-pod pub/sub.
Args:
redis_url: Redis connection URL
Returns:
True if successful, False otherwise
"""
try:
import redis.asyncio as redis_async
self._redis = redis_async.from_url(redis_url, decode_responses=True)
await self._redis.ping()
# Create pub/sub subscriber
self._pubsub = self._redis.pubsub()
await self._pubsub.subscribe(REDIS_WEBSOCKET_CHANNEL)
# Start subscriber task
self._running = True
self._subscriber_task = asyncio.create_task(self._redis_subscriber_loop())
logger.info("Redis pub/sub initialized for WebSocket broadcasting",
instance_id=self._instance_id,
channel=REDIS_WEBSOCKET_CHANNEL)
return True
except Exception as e:
logger.error("Failed to initialize Redis pub/sub",
error=str(e),
instance_id=self._instance_id)
return False
async def shutdown(self):
"""Shutdown Redis pub/sub connection"""
self._running = False
if self._subscriber_task:
self._subscriber_task.cancel()
try:
await self._subscriber_task
except asyncio.CancelledError:
pass
if self._pubsub:
await self._pubsub.unsubscribe(REDIS_WEBSOCKET_CHANNEL)
await self._pubsub.close()
if self._redis:
await self._redis.close()
logger.info("Redis pub/sub shutdown complete",
instance_id=self._instance_id)
async def _redis_subscriber_loop(self):
"""Background task to receive Redis pub/sub messages and broadcast locally"""
try:
while self._running:
try:
message = await self._pubsub.get_message(
ignore_subscribe_messages=True,
timeout=1.0
)
if message and message['type'] == 'message':
await self._handle_redis_message(message['data'])
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error in Redis subscriber loop",
error=str(e),
instance_id=self._instance_id)
await asyncio.sleep(1) # Backoff on error
except asyncio.CancelledError:
pass
logger.info("Redis subscriber loop stopped",
instance_id=self._instance_id)
async def _handle_redis_message(self, data: str):
"""Handle a message received from Redis pub/sub"""
try:
payload = json.loads(data)
job_id = payload.get('job_id')
message = payload.get('message')
source_instance = payload.get('source_instance')
if not job_id or not message:
return
# Log cross-pod message
if source_instance != self._instance_id:
logger.debug("Received cross-pod WebSocket event",
job_id=job_id,
source_instance=source_instance,
local_instance=self._instance_id)
# Broadcast to local WebSocket connections
await self._broadcast_local(job_id, message)
except json.JSONDecodeError as e:
logger.warning("Invalid JSON in Redis message", error=str(e))
except Exception as e:
logger.error("Error handling Redis message", error=str(e))
async def connect(self, job_id: str, websocket: WebSocket) -> None:
"""Register a new WebSocket connection for a job"""
await websocket.accept()
async with self._lock:
if job_id not in self._connections:
self._connections[job_id] = {}
ws_id = id(websocket)
self._connections[job_id][ws_id] = websocket
# Send initial state if available
if job_id in self._latest_events:
try:
await websocket.send_json({
"type": "initial_state",
"job_id": job_id,
"data": self._latest_events[job_id]
})
except Exception as e:
logger.warning("Failed to send initial state to new connection", error=str(e))
logger.info("WebSocket connected",
job_id=job_id,
websocket_id=ws_id,
total_connections=len(self._connections[job_id]),
instance_id=self._instance_id)
    async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        """Remove a WebSocket connection"""
        ws_id = id(websocket)
        async with self._lock:
            if job_id in self._connections:
                self._connections[job_id].pop(ws_id, None)
                # Clean up empty job connections
                if not self._connections[job_id]:
                    del self._connections[job_id]
        logger.info("WebSocket disconnected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    remaining_connections=len(self._connections.get(job_id, {})),
                    instance_id=self._instance_id)
async def broadcast(self, job_id: str, message: dict) -> int:
"""
Broadcast a message to all connections for a specific job across ALL pods.
If Redis is configured, publishes to Redis pub/sub which then broadcasts
to all pods. Otherwise, falls back to local-only broadcast.
        Returns the number of successful local sends; this is 0 when the message
        is routed through Redis, since delivery then happens via the subscriber loop.
        """
# Store the latest event for this job to provide initial state to new connections
if message.get('type') != 'initial_state':
self._latest_events[job_id] = message
# If Redis is available, publish to Redis for cross-pod broadcast
if self._redis:
try:
payload = json.dumps({
'job_id': job_id,
'message': message,
'source_instance': self._instance_id
})
await self._redis.publish(REDIS_WEBSOCKET_CHANNEL, payload)
logger.debug("Published WebSocket event to Redis",
job_id=job_id,
message_type=message.get('type'),
instance_id=self._instance_id)
# Return 0 here because the actual broadcast happens via subscriber
# The count will be from _broadcast_local when the message is received
return 0
except Exception as e:
logger.warning("Failed to publish to Redis, falling back to local broadcast",
error=str(e),
job_id=job_id)
# Fall through to local broadcast
# Local-only broadcast (when Redis is not available)
return await self._broadcast_local(job_id, message)
async def _broadcast_local(self, job_id: str, message: dict) -> int:
"""
Broadcast a message to local WebSocket connections only.
        This is called either directly (when Redis is not configured) or from the Redis subscriber loop.
"""
if job_id not in self._connections:
logger.debug("No active local connections for job",
job_id=job_id,
instance_id=self._instance_id)
return 0
connections = list(self._connections[job_id].values())
successful_sends = 0
failed_websockets = []
for websocket in connections:
try:
await websocket.send_json(message)
successful_sends += 1
except Exception as e:
logger.warning("Failed to send message to WebSocket",
job_id=job_id,
error=str(e))
failed_websockets.append(websocket)
# Clean up failed connections
if failed_websockets:
async with self._lock:
for ws in failed_websockets:
ws_id = id(ws)
self._connections[job_id].pop(ws_id, None)
if successful_sends > 0:
logger.info("Broadcasted message to local WebSocket clients",
job_id=job_id,
message_type=message.get('type'),
successful_sends=successful_sends,
failed_sends=len(failed_websockets),
instance_id=self._instance_id)
return successful_sends
def get_connection_count(self, job_id: str) -> int:
"""Get the number of active local connections for a job"""
return len(self._connections.get(job_id, {}))
def get_total_connection_count(self) -> int:
"""Get total number of active connections across all jobs"""
return sum(len(conns) for conns in self._connections.values())
def is_redis_enabled(self) -> bool:
"""Check if Redis pub/sub is enabled"""
return self._redis is not None and self._running
# Global singleton instance
websocket_manager = WebSocketConnectionManager()
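
The route that hands client sockets to this manager lives elsewhere in the API layer and is not shown here. A minimal sketch of such an endpoint, assuming a /ws/training/{job_id} path and that authentication is handled upstream (the path and router below are illustrative, not the service's actual code):

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from app.websocket.manager import websocket_manager

router = APIRouter()

@router.websocket("/ws/training/{job_id}")
async def training_progress_ws(websocket: WebSocket, job_id: str):
    # connect() calls websocket.accept() itself, so the endpoint only registers
    # the socket and keeps the connection open until the client disconnects.
    await websocket_manager.connect(job_id, websocket)
    try:
        while True:
            await websocket.receive_text()  # inbound messages are ignored here
    except WebSocketDisconnect:
        pass
    finally:
        await websocket_manager.disconnect(job_id, websocket)

In a multi-pod deployment the application would also call websocket_manager.initialize_redis(redis_url) at startup so that broadcast() fans events out through Redis instead of reaching only the pod's local connections.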