Add all the code for the training service

This commit is contained in:
Urtzi Alfaro
2025-07-19 16:59:37 +02:00
parent 42097202d2
commit f3071c00bd
21 changed files with 7504 additions and 764 deletions

View File

@@ -8,10 +8,11 @@ from typing import List
import structlog
from app.core.database import get_db
from app.core.auth import verify_token
from app.core.auth import get_current_tenant_id
from app.schemas.training import TrainedModelResponse
from app.services.training_service import TrainingService
logger = structlog.get_logger()
router = APIRouter()
@@ -19,12 +20,12 @@ training_service = TrainingService()
@router.get("/", response_model=List[TrainedModelResponse])
async def get_trained_models(
user_data: dict = Depends(verify_token),
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""Get trained models"""
try:
return await training_service.get_trained_models(user_data, db)
return await training_service.get_trained_models(tenant_id, db)
except Exception as e:
logger.error(f"Get trained models error: {e}")
raise HTTPException(

View File

@@ -1,77 +1,299 @@
# services/training/app/api/training.py
"""
Training API endpoints
Training API endpoints for the training service
"""
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
import structlog
from typing import Dict, List, Any, Optional
import logging
from datetime import datetime
import uuid
from app.core.database import get_db
from app.core.auth import verify_token
from app.schemas.training import TrainingRequest, TrainingJobResponse, TrainedModelResponse
from app.core.auth import get_current_tenant_id
from app.schemas.training import (
TrainingJobRequest,
TrainingJobResponse,
TrainingStatusResponse,
SingleProductTrainingRequest
)
from app.services.training_service import TrainingService
from app.services.messaging import publish_job_started, publish_job_cancelled, publish_product_training_started
from shared.monitoring.metrics import MetricsCollector
logger = structlog.get_logger()
logger = logging.getLogger(__name__)
router = APIRouter()
metrics = MetricsCollector("training-service")
# Initialize training service
training_service = TrainingService()
@router.post("/train", response_model=TrainingJobResponse)
async def start_training(
request: TrainingRequest,
user_data: dict = Depends(verify_token),
@router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job(
request: TrainingJobRequest,
background_tasks: BackgroundTasks,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""Start training job"""
"""
Start a new training job for all products of a tenant.
Replaces the old Celery-based training system.
"""
try:
return await training_service.start_training(request, user_data, db)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
logger.info(f"Starting training job for tenant {tenant_id}")
metrics.increment_counter("training_jobs_started")
# Generate job ID
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
# Create training job record
training_job = await training_service.create_training_job(
db=db,
tenant_id=tenant_id,
job_id=job_id,
config=request.dict()
)
# Start training in background
background_tasks.add_task(
training_service.execute_training_job,
db,
job_id,
tenant_id,
request
)
# Publish training started event
await publish_job_started(job_id, tenant_id, request.dict())
return TrainingJobResponse(
job_id=job_id,
status="started",
message="Training job started successfully",
tenant_id=tenant_id,
created_at=training_job.start_time,
estimated_duration_minutes=request.estimated_duration or 15
)
except Exception as e:
logger.error(f"Training start error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to start training"
)
logger.error(f"Failed to start training job: {str(e)}")
metrics.increment_counter("training_jobs_failed")
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
@router.get("/status/{job_id}", response_model=TrainingJobResponse)
@router.get("/jobs/{job_id}/status", response_model=TrainingStatusResponse)
async def get_training_status(
job_id: str,
user_data: dict = Depends(verify_token),
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""Get training job status"""
"""
Get the status of a training job.
Provides real-time progress updates.
"""
try:
return await training_service.get_training_status(job_id, user_data, db)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e)
# Get job status from database
job_status = await training_service.get_job_status(db, job_id, tenant_id)
if not job_status:
raise HTTPException(status_code=404, detail="Training job not found")
return TrainingStatusResponse(
job_id=job_id,
status=job_status.status,
progress=job_status.progress,
current_step=job_status.current_step,
started_at=job_status.start_time,
completed_at=job_status.end_time,
results=job_status.results,
error_message=job_status.error_message
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Get training status error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get training status"
)
logger.error(f"Failed to get training status: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to get training status: {str(e)}")
@router.get("/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
limit: int = Query(10, ge=1, le=100),
offset: int = Query(0, ge=0),
user_data: dict = Depends(verify_token),
@router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
product_name: str,
request: SingleProductTrainingRequest,
background_tasks: BackgroundTasks,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""Get training jobs"""
"""
Train a model for a single product.
Useful for quick model updates or new products.
"""
try:
return await training_service.get_training_jobs(user_data, limit, offset, db)
logger.info(f"Starting single product training for {product_name}, tenant {tenant_id}")
metrics.increment_counter("single_product_training_started")
# Generate job ID
job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"
# Create training job record
training_job = await training_service.create_single_product_job(
db=db,
tenant_id=tenant_id,
product_name=product_name,
job_id=job_id,
config=request.dict()
)
# Start training in background
background_tasks.add_task(
training_service.execute_single_product_training,
db,
job_id,
tenant_id,
product_name,
request
)
# Publish event
await publish_product_training_started(job_id, tenant_id, product_name)
return TrainingJobResponse(
job_id=job_id,
status="started",
message=f"Single product training started for {product_name}",
tenant_id=tenant_id,
created_at=training_job.start_time,
estimated_duration_minutes=5
)
except Exception as e:
logger.error(f"Get training jobs error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get training jobs"
)
logger.error(f"Failed to start single product training: {str(e)}")
metrics.increment_counter("single_product_training_failed")
raise HTTPException(status_code=500, detail=f"Failed to start training: {str(e)}")
@router.get("/jobs", response_model=List[TrainingStatusResponse])
async def list_training_jobs(
limit: int = 10,
status: Optional[str] = None,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""
List training jobs for a tenant.
"""
try:
jobs = await training_service.list_training_jobs(
db=db,
tenant_id=tenant_id,
limit=limit,
status_filter=status
)
return [
TrainingStatusResponse(
job_id=job.job_id,
status=job.status,
progress=job.progress,
current_step=job.current_step,
started_at=job.start_time,
completed_at=job.end_time,
results=job.results,
error_message=job.error_message
)
for job in jobs
]
except Exception as e:
logger.error(f"Failed to list training jobs: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to list training jobs: {str(e)}")
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
job_id: str,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""
Cancel a running training job.
"""
try:
logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")
# Update job status to cancelled
success = await training_service.cancel_training_job(db, job_id, tenant_id)
if not success:
raise HTTPException(status_code=404, detail="Training job not found or cannot be cancelled")
# Publish cancellation event
await publish_job_cancelled(job_id, tenant_id)
return {"message": "Training job cancelled successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to cancel training job: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}")
@router.get("/jobs/{job_id}/logs")
async def get_training_logs(
job_id: str,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""
Get detailed logs for a training job.
"""
try:
logs = await training_service.get_training_logs(db, job_id, tenant_id)
if not logs:
raise HTTPException(status_code=404, detail="Training job not found")
return {"job_id": job_id, "logs": logs}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get training logs: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to get training logs: {str(e)}")
@router.post("/validate")
async def validate_training_data(
request: TrainingJobRequest,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""
Validate training data before starting a job.
Provides early feedback on data quality issues.
"""
try:
logger.info(f"Validating training data for tenant {tenant_id}")
# Perform data validation
validation_result = await training_service.validate_training_data(
db=db,
tenant_id=tenant_id,
config=request.dict()
)
return {
"is_valid": validation_result["is_valid"],
"issues": validation_result.get("issues", []),
"recommendations": validation_result.get("recommendations", []),
"estimated_training_time": validation_result.get("estimated_time_minutes", 15)
}
except Exception as e:
logger.error(f"Failed to validate training data: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to validate training data: {str(e)}")
@router.get("/health")
async def health_check():
"""Health check for the training service"""
return {
"status": "healthy",
"service": "training-service",
"timestamp": datetime.now().isoformat()
}
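Example (not part of this commit): a minimal async client sketch showing how the new job endpoints are intended to be used together, starting a tenant-wide job via POST /training/jobs and polling GET /training/jobs/{job_id}/status. The base URL, token, request payload fields and terminal status values are assumptions here, since TrainingJobRequest and TrainingStatusResponse are defined in the schemas module.

import asyncio
import httpx

BASE_URL = "http://localhost:8000/training"   # hypothetical deployment URL
TOKEN = "<jwt-from-auth-service>"              # placeholder token

async def run_training_job() -> dict:
    headers = {"Authorization": f"Bearer {TOKEN}"}
    async with httpx.AsyncClient(base_url=BASE_URL, headers=headers, timeout=30.0) as client:
        # Start a tenant-wide training job (payload fields are illustrative)
        resp = await client.post("/jobs", json={"include_weather": True, "include_traffic": True})
        resp.raise_for_status()
        job_id = resp.json()["job_id"]

        # Poll the status endpoint until the background task finishes
        while True:
            job_status = (await client.get(f"/jobs/{job_id}/status")).json()
            if job_status["status"] in ("completed", "failed", "cancelled"):
                return job_status
            await asyncio.sleep(5)

if __name__ == "__main__":
    print(asyncio.run(run_training_job()))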

View File

@@ -1,38 +1,303 @@
# services/training/app/core/auth.py
"""
Authentication utilities for training service
Authentication and authorization for training service
"""
import httpx
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer
import structlog
from typing import Optional
from fastapi import HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import httpx
from app.core.config import settings
logger = structlog.get_logger()
security = HTTPBearer()
# HTTP Bearer token scheme
security = HTTPBearer(auto_error=False)
async def verify_token(token: str = Depends(security)):
"""Verify token with auth service"""
class AuthenticationError(Exception):
"""Custom exception for authentication errors"""
pass
class AuthorizationError(Exception):
"""Custom exception for authorization errors"""
pass
async def verify_token(token: str) -> dict:
"""
Verify JWT token with auth service
Args:
token: JWT token to verify
Returns:
dict: Token payload with user and tenant information
Raises:
AuthenticationError: If token is invalid
"""
try:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{settings.AUTH_SERVICE_URL}/auth/verify",
headers={"Authorization": f"Bearer {token.credentials}"}
headers={"Authorization": f"Bearer {token}"},
timeout=10.0
)
if response.status_code == 200:
return response.json()
token_data = response.json()
logger.debug("Token verified successfully", user_id=token_data.get("user_id"))
return token_data
elif response.status_code == 401:
logger.warning("Invalid token provided")
raise AuthenticationError("Invalid or expired token")
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials"
)
logger.error("Auth service error", status_code=response.status_code)
raise AuthenticationError("Authentication service unavailable")
except httpx.TimeoutException:
logger.error("Auth service timeout")
raise AuthenticationError("Authentication service timeout")
except httpx.RequestError as e:
logger.error(f"Auth service unavailable: {e}")
logger.error("Auth service request error", error=str(e))
raise AuthenticationError("Authentication service unavailable")
except AuthenticationError:
raise
except Exception as e:
logger.error("Unexpected auth error", error=str(e))
raise AuthenticationError("Authentication failed")
async def get_current_user(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> dict:
"""
Get current authenticated user
Args:
credentials: HTTP Bearer credentials
Returns:
dict: User information
Raises:
HTTPException: If authentication fails
"""
if not credentials:
logger.warning("No credentials provided")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Authentication service unavailable"
)
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication credentials required",
headers={"WWW-Authenticate": "Bearer"},
)
try:
token_data = await verify_token(credentials.credentials)
return token_data
except AuthenticationError as e:
logger.warning("Authentication failed", error=str(e))
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
headers={"WWW-Authenticate": "Bearer"},
)
async def get_current_tenant_id(
current_user: dict = Depends(get_current_user)
) -> str:
"""
Get current tenant ID from authenticated user
Args:
current_user: Current authenticated user data
Returns:
str: Tenant ID
Raises:
HTTPException: If tenant ID is missing
"""
tenant_id = current_user.get("tenant_id")
if not tenant_id:
logger.error("Missing tenant_id in token", user_data=current_user)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid token: missing tenant information"
)
return tenant_id
async def require_admin_role(
current_user: dict = Depends(get_current_user)
) -> dict:
"""
Require admin role for endpoint access
Args:
current_user: Current authenticated user data
Returns:
dict: User information
Raises:
HTTPException: If user is not admin
"""
user_role = current_user.get("role", "").lower()
if user_role != "admin":
logger.warning("Access denied - admin role required",
user_id=current_user.get("user_id"),
role=user_role)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin role required"
)
return current_user
async def require_training_permission(
current_user: dict = Depends(get_current_user)
) -> dict:
"""
Require training permission for endpoint access
Args:
current_user: Current authenticated user data
Returns:
dict: User information
Raises:
HTTPException: If user doesn't have training permission
"""
permissions = current_user.get("permissions", [])
if "training" not in permissions and current_user.get("role", "").lower() != "admin":
logger.warning("Access denied - training permission required",
user_id=current_user.get("user_id"),
permissions=permissions)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Training permission required"
)
return current_user
# Optional authentication for development/testing
async def get_current_user_optional(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[dict]:
"""
Get current user but don't require authentication (for development)
Args:
credentials: HTTP Bearer credentials
Returns:
dict or None: User information if authenticated, None otherwise
"""
if not credentials:
return None
try:
token_data = await verify_token(credentials.credentials)
return token_data
except AuthenticationError:
return None
async def get_tenant_id_optional(
current_user: Optional[dict] = Depends(get_current_user_optional)
) -> Optional[str]:
"""
Get tenant ID but don't require authentication (for development)
Args:
current_user: Current user data (optional)
Returns:
str or None: Tenant ID if available, None otherwise
"""
if not current_user:
return None
return current_user.get("tenant_id")
# Development/testing auth bypass
async def get_test_tenant_id() -> str:
"""
Get test tenant ID for development/testing
Only works when DEBUG is enabled
Returns:
str: Test tenant ID
"""
if settings.DEBUG:
return "test-tenant-development"
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Test authentication only available in debug mode"
)
# Token validation utility
def validate_token_structure(token_data: dict) -> bool:
"""
Validate that token data has required structure
Args:
token_data: Token payload data
Returns:
bool: True if valid structure, False otherwise
"""
required_fields = ["user_id", "tenant_id"]
for field in required_fields:
if field not in token_data:
logger.warning("Invalid token structure - missing field", field=field)
return False
return True
# Role checking utilities
def has_role(user_data: dict, required_role: str) -> bool:
"""
Check if user has required role
Args:
user_data: User data from token
required_role: Required role name
Returns:
bool: True if user has role, False otherwise
"""
user_role = user_data.get("role", "").lower()
return user_role == required_role.lower()
def has_permission(user_data: dict, required_permission: str) -> bool:
"""
Check if user has required permission
Args:
user_data: User data from token
required_permission: Required permission name
Returns:
bool: True if user has permission, False otherwise
"""
permissions = user_data.get("permissions", [])
return required_permission in permissions or has_role(user_data, "admin")
# Export commonly used items
__all__ = [
'get_current_user',
'get_current_tenant_id',
'require_admin_role',
'require_training_permission',
'get_current_user_optional',
'get_tenant_id_optional',
'get_test_tenant_id',
'has_role',
'has_permission',
'AuthenticationError',
'AuthorizationError'
]
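Usage sketch (illustrative, not part of the commit): how the dependencies above compose in a route. FastAPI caches sub-dependencies per request, so get_current_user, and with it the call to the auth service, is resolved once even though both dependencies build on it.

from fastapi import APIRouter, Depends

from app.core.auth import get_current_tenant_id, require_admin_role

router = APIRouter()

@router.get("/admin/summary")
async def admin_summary(
    tenant_id: str = Depends(get_current_tenant_id),   # tenant resolved from the verified token
    admin_user: dict = Depends(require_admin_role),    # rejects non-admin callers with 403
):
    # Hypothetical endpoint: both dependencies raise HTTPException on failure,
    # so the body only runs for an authenticated admin of a known tenant.
    return {"tenant_id": tenant_id, "requested_by": admin_user.get("user_id")}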

View File

@@ -1,12 +1,260 @@
# services/training/app/core/database.py
"""
Database configuration for training service
Uses shared database infrastructure
"""
from shared.database.base import DatabaseManager
import structlog
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from shared.database.base import DatabaseManager, Base
from app.core.config import settings
# Initialize database manager
logger = structlog.get_logger()
# Initialize database manager using shared infrastructure
database_manager = DatabaseManager(settings.DATABASE_URL)
# Alias for convenience
get_db = database_manager.get_db
# Alias for convenience - matches the existing interface
get_db = database_manager.get_db
async def get_db_health() -> bool:
"""
Health check function for database connectivity
Enhanced version of the shared functionality
"""
try:
async with database_manager.async_engine.begin() as conn:
await conn.execute(text("SELECT 1"))
logger.debug("Database health check passed")
return True
except Exception as e:
logger.error("Database health check failed", error=str(e))
return False
# Training service specific database utilities
class TrainingDatabaseUtils:
"""Training service specific database utilities"""
@staticmethod
async def cleanup_old_training_logs(days_old: int = 90):
"""Clean up old training logs"""
try:
async with database_manager.async_session_local() as session:
if settings.DATABASE_URL.startswith("sqlite"):
query = text(
"DELETE FROM model_training_logs "
"WHERE start_time < datetime('now', :days_param)"
)
params = {"days_param": f"-{days_old} days"}
else:
query = text(
"DELETE FROM model_training_logs "
"WHERE start_time < NOW() - INTERVAL :days_param"
)
params = {"days_param": f"{days_old} days"}
result = await session.execute(query, params)
await session.commit()
deleted_count = result.rowcount
logger.info("Cleaned up old training logs",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Training logs cleanup failed", error=str(e))
raise
@staticmethod
async def cleanup_old_models(days_old: int = 365):
"""Clean up old inactive models"""
try:
async with database_manager.async_session_local() as session:
if settings.DATABASE_URL.startswith("sqlite"):
query = text(
"DELETE FROM trained_models "
"WHERE is_active = 0 AND created_at < datetime('now', :days_param)"
)
params = {"days_param": f"-{days_old} days"}
else:
query = text(
"DELETE FROM trained_models "
"WHERE is_active = false AND created_at < NOW() - INTERVAL :days_param"
)
params = {"days_param": f"{days_old} days"}
result = await session.execute(query, params)
await session.commit()
deleted_count = result.rowcount
logger.info("Cleaned up old models",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Model cleanup failed", error=str(e))
raise
@staticmethod
async def get_training_statistics(tenant_id: str = None) -> dict:
"""Get training statistics"""
try:
async with database_manager.async_session_local() as session:
# Base query for training logs
if tenant_id:
logs_query = text(
"SELECT status, COUNT(*) as count "
"FROM model_training_logs "
"WHERE tenant_id = :tenant_id "
"GROUP BY status"
)
models_query = text(
"SELECT COUNT(*) as count "
"FROM trained_models "
"WHERE tenant_id = :tenant_id AND is_active = :is_active"
)
params = {"tenant_id": tenant_id}
else:
logs_query = text(
"SELECT status, COUNT(*) as count "
"FROM model_training_logs "
"GROUP BY status"
)
models_query = text(
"SELECT COUNT(*) as count "
"FROM trained_models "
"WHERE is_active = :is_active"
)
params = {}
# Get training job statistics
logs_result = await session.execute(logs_query, params)
job_stats = {row.status: row.count for row in logs_result.fetchall()}
# Get active models count
active_models_result = await session.execute(
models_query,
{**params, "is_active": True}
)
active_models = active_models_result.scalar() or 0
# Get inactive models count
inactive_models_result = await session.execute(
models_query,
{**params, "is_active": False}
)
inactive_models = inactive_models_result.scalar() or 0
return {
"training_jobs": job_stats,
"active_models": active_models,
"inactive_models": inactive_models,
"total_models": active_models + inactive_models
}
except Exception as e:
logger.error("Failed to get training statistics", error=str(e))
return {
"training_jobs": {},
"active_models": 0,
"inactive_models": 0,
"total_models": 0
}
@staticmethod
async def check_tenant_data_exists(tenant_id: str) -> bool:
"""Check if tenant has any training data"""
try:
async with database_manager.async_session_local() as session:
query = text(
"SELECT COUNT(*) as count "
"FROM model_training_logs "
"WHERE tenant_id = :tenant_id "
"LIMIT 1"
)
result = await session.execute(query, {"tenant_id": tenant_id})
count = result.scalar() or 0
return count > 0
except Exception as e:
logger.error("Failed to check tenant data existence",
tenant_id=tenant_id, error=str(e))
return False
# Enhanced database session dependency with better error handling
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""
Enhanced database session dependency with better logging and error handling
"""
async with database_manager.async_session_local() as session:
try:
logger.debug("Database session created")
yield session
except Exception as e:
logger.error("Database session error", error=str(e), exc_info=True)
await session.rollback()
raise
finally:
await session.close()
logger.debug("Database session closed")
# Database initialization for training service
async def initialize_training_database():
"""Initialize database tables for training service"""
try:
logger.info("Initializing training service database")
# Import models to ensure they're registered
from app.models.training import (
ModelTrainingLog,
TrainedModel,
ModelPerformanceMetric,
TrainingJobQueue,
ModelArtifact
)
# Create tables using shared infrastructure
await database_manager.create_tables()
logger.info("Training service database initialized successfully")
except Exception as e:
logger.error("Failed to initialize training service database", error=str(e))
raise
# Database cleanup for training service
async def cleanup_training_database():
"""Cleanup database connections for training service"""
try:
logger.info("Cleaning up training service database connections")
# Close engine connections
if hasattr(database_manager, 'async_engine') and database_manager.async_engine:
await database_manager.async_engine.dispose()
logger.info("Training service database cleanup completed")
except Exception as e:
logger.error("Failed to cleanup training service database", error=str(e))
# Export the commonly used items to maintain compatibility
__all__ = [
'Base',
'database_manager',
'get_db',
'get_db_session',
'get_db_health',
'TrainingDatabaseUtils',
'initialize_training_database',
'cleanup_training_database'
]
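Maintenance sketch (illustrative, not part of the commit): how the exported helpers could be combined in a one-off script; the retention windows are the same defaults used above.

import asyncio

from app.core.database import (
    TrainingDatabaseUtils,
    get_db_health,
    initialize_training_database,
)

async def nightly_maintenance():
    # Ensure tables exist and the database answers before doing any work
    await initialize_training_database()
    if not await get_db_health():
        raise RuntimeError("database is not reachable")

    stats = await TrainingDatabaseUtils.get_training_statistics()
    print("training statistics:", stats)

    # Prune logs older than 90 days and inactive models older than a year
    deleted_logs = await TrainingDatabaseUtils.cleanup_old_training_logs(days_old=90)
    deleted_models = await TrainingDatabaseUtils.cleanup_old_models(days_old=365)
    print(f"pruned {deleted_logs} log rows and {deleted_models} stale models")

if __name__ == "__main__":
    asyncio.run(nightly_maintenance())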

View File

@@ -1,81 +1,282 @@
# services/training/app/main.py
"""
Training Service
Handles ML model training for bakery demand forecasting
Training Service Main Application
Enhanced with proper error handling, monitoring, and lifecycle management
"""
import structlog
from fastapi import FastAPI, BackgroundTasks
import asyncio
import uuid
from datetime import datetime
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
import uvicorn
from app.core.config import settings
from app.core.database import database_manager
from app.core.database import database_manager, get_db_health
from app.api import training, models
from app.services.messaging import setup_messaging, cleanup_messaging
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector
from shared.auth.decorators import require_auth
# Setup logging
# Setup structured logging
setup_logging("training-service", settings.LOG_LEVEL)
logger = structlog.get_logger()
# Create FastAPI app
app = FastAPI(
title="Training Service",
description="ML model training service for bakery demand forecasting",
version="1.0.0"
)
# Initialize metrics collector
metrics_collector = MetricsCollector("training-service")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Application lifespan manager for startup and shutdown events
"""
# Startup
logger.info("Starting Training Service", version="1.0.0")
try:
# Initialize database
logger.info("Initializing database connection")
await database_manager.create_tables()
logger.info("Database initialized successfully")
# Initialize messaging
logger.info("Setting up messaging")
await setup_messaging()
logger.info("Messaging setup completed")
# Start metrics server
logger.info("Starting metrics server")
metrics_collector.start_metrics_server(8080)
logger.info("Metrics server started on port 8080")
# Mark service as ready
app.state.ready = True
logger.info("Training Service startup completed successfully")
yield
except Exception as e:
logger.error("Failed to start Training Service", error=str(e))
app.state.ready = False
raise
# Shutdown
logger.info("Shutting down Training Service")
try:
# Cleanup messaging
logger.info("Cleaning up messaging")
await cleanup_messaging()
# Close database connections
logger.info("Closing database connections")
await database_manager.close_connections()
logger.info("Training Service shutdown completed")
except Exception as e:
logger.error("Error during shutdown", error=str(e))
# Create FastAPI app with lifespan
app = FastAPI(
title="Training Service",
description="ML model training service for bakery demand forecasting",
version="1.0.0",
docs_url="/docs" if settings.DEBUG else None,
redoc_url="/redoc" if settings.DEBUG else None,
lifespan=lifespan
)
# Initialize app state
app.state.ready = False
# Security middleware
if not settings.DEBUG:
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=["localhost", "127.0.0.1", "training-service", "*.bakery-forecast.local"]
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=["*"] if settings.DEBUG else [
"http://localhost:3000",
"http://localhost:8000",
"https://dashboard.bakery-forecast.es"
],
allow_credentials=True,
allow_methods=["*"],
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
# Include routers
app.include_router(training.router, prefix="/training", tags=["training"])
app.include_router(models.router, prefix="/models", tags=["models"])
# Request logging middleware
@app.middleware("http")
async def log_requests(request: Request, call_next):
"""Log all incoming requests with timing"""
start_time = asyncio.get_event_loop().time()
# Log request
logger.info(
"Request started",
method=request.method,
path=request.url.path,
client_ip=request.client.host if request.client else "unknown"
)
# Process request
try:
response = await call_next(request)
# Calculate duration
duration = asyncio.get_event_loop().time() - start_time
# Log response
logger.info(
"Request completed",
method=request.method,
path=request.url.path,
status_code=response.status_code,
duration_ms=round(duration * 1000, 2)
)
# Update metrics
metrics_collector.record_request(
method=request.method,
endpoint=request.url.path,
status_code=response.status_code,
duration=duration
)
return response
except Exception as e:
duration = asyncio.get_event_loop().time() - start_time
logger.error(
"Request failed",
method=request.method,
path=request.url.path,
error=str(e),
duration_ms=round(duration * 1000, 2)
)
metrics_collector.increment_counter("http_requests_failed_total")
raise
@app.on_event("startup")
async def startup_event():
"""Application startup"""
logger.info("Starting Training Service")
# Exception handlers
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
"""Global exception handler for unhandled errors"""
logger.error(
"Unhandled exception",
path=request.url.path,
method=request.method,
error=str(exc),
exc_info=True
)
# Create database tables
await database_manager.create_tables()
metrics_collector.increment_counter("unhandled_exceptions_total")
# Initialize message publisher
await setup_messaging()
# Start metrics server
metrics_collector.start_metrics_server(8080)
logger.info("Training Service started successfully")
return JSONResponse(
status_code=500,
content={
"detail": "Internal server error",
"error_id": structlog.get_logger().new().info("Error logged", error=str(exc))
}
)
@app.on_event("shutdown")
async def shutdown_event():
"""Application shutdown"""
logger.info("Shutting down Training Service")
# Cleanup message publisher
await cleanup_messaging()
logger.info("Training Service shutdown complete")
# Include API routers
app.include_router(
training.router,
prefix="/training",
tags=["training"],
dependencies=[require_auth] if not settings.DEBUG else []
)
app.include_router(
models.router,
prefix="/models",
tags=["models"],
dependencies=[require_auth] if not settings.DEBUG else []
)
# Health check endpoints
@app.get("/health")
async def health_check():
"""Health check endpoint"""
"""Basic health check endpoint"""
return {
"status": "healthy",
"status": "healthy" if app.state.ready else "starting",
"service": "training-service",
"version": "1.0.0"
"version": "1.0.0",
"timestamp": structlog.get_logger().new().info("Health check")
}
@app.get("/health/ready")
async def readiness_check():
"""Kubernetes readiness probe"""
if not app.state.ready:
return JSONResponse(
status_code=503,
content={"status": "not_ready", "message": "Service is starting up"}
)
return {"status": "ready", "service": "training-service"}
@app.get("/health/live")
async def liveness_check():
"""Kubernetes liveness probe"""
# Check database connectivity
try:
db_healthy = await get_db_health()
if not db_healthy:
return JSONResponse(
status_code=503,
content={"status": "unhealthy", "reason": "database_unavailable"}
)
except Exception as e:
logger.error("Database health check failed", error=str(e))
return JSONResponse(
status_code=503,
content={"status": "unhealthy", "reason": "database_error"}
)
return {"status": "alive", "service": "training-service"}
@app.get("/metrics")
async def get_metrics():
"""Expose service metrics"""
return {
"training_jobs_active": metrics_collector.get_gauge("training_jobs_active", 0),
"training_jobs_completed": metrics_collector.get_counter("training_jobs_completed", 0),
"training_jobs_failed": metrics_collector.get_counter("training_jobs_failed", 0),
"models_trained_total": metrics_collector.get_counter("models_trained_total", 0),
"uptime_seconds": metrics_collector.get_gauge("uptime_seconds", 0)
}
@app.get("/")
async def root():
"""Root endpoint with service information"""
return {
"service": "training-service",
"version": "1.0.0",
"description": "ML model training service for bakery demand forecasting",
"docs": "/docs" if settings.DEBUG else "Documentation disabled in production",
"health": "/health"
}
# Development server configuration
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=8000,
reload=settings.DEBUG,
log_level=settings.LOG_LEVEL.lower(),
access_log=settings.DEBUG,
server_header=False,
date_header=False
)
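Smoke-test sketch (illustrative): exercising the three health endpoints defined above against a locally running instance. The port matches the __main__ block, and the readiness/liveness semantics match the handlers.

import httpx

def check_service(base_url: str = "http://localhost:8000") -> bool:
    with httpx.Client(base_url=base_url, timeout=5.0) as client:
        basic = client.get("/health").json()   # answers even while the service is still starting
        ready = client.get("/health/ready")    # 503 until the lifespan startup has completed
        live = client.get("/health/live")      # 503 if the database health check fails
        print(basic, ready.status_code, live.status_code)
        return ready.status_code == 200 and live.status_code == 200

if __name__ == "__main__":
    check_service()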

View File

@@ -0,0 +1,493 @@
# services/training/app/ml/data_processor.py
"""
Data Processor for Training Service
Handles data preparation and feature engineering for ML training
"""
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta
import logging
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
logger = logging.getLogger(__name__)
class BakeryDataProcessor:
"""
Enhanced data processor for bakery forecasting training service.
Handles data cleaning, feature engineering, and preparation for ML models.
"""
def __init__(self):
self.scalers = {} # Store scalers for each feature
self.imputers = {} # Store imputers for missing value handling
async def prepare_training_data(self,
sales_data: pd.DataFrame,
weather_data: pd.DataFrame,
traffic_data: pd.DataFrame,
product_name: str) -> pd.DataFrame:
"""
Prepare comprehensive training data for a specific product.
Args:
sales_data: Historical sales data for the product
weather_data: Weather data
traffic_data: Traffic data
product_name: Product name for logging
Returns:
DataFrame ready for Prophet training with 'ds' and 'y' columns plus features
"""
try:
logger.info(f"Preparing training data for product: {product_name}")
# Convert and validate sales data
sales_clean = await self._process_sales_data(sales_data, product_name)
# Aggregate to daily level
daily_sales = await self._aggregate_daily_sales(sales_clean)
# Add temporal features
daily_sales = self._add_temporal_features(daily_sales)
# Merge external data sources
daily_sales = self._merge_weather_features(daily_sales, weather_data)
daily_sales = self._merge_traffic_features(daily_sales, traffic_data)
# Engineer additional features
daily_sales = self._engineer_features(daily_sales)
# Handle missing values
daily_sales = self._handle_missing_values(daily_sales)
# Prepare for Prophet (rename columns and validate)
prophet_data = self._prepare_prophet_format(daily_sales)
logger.info(f"Prepared {len(prophet_data)} data points for {product_name}")
return prophet_data
except Exception as e:
logger.error(f"Error preparing training data for {product_name}: {str(e)}")
raise
async def prepare_prediction_features(self,
future_dates: pd.DatetimeIndex,
weather_forecast: pd.DataFrame = None,
traffic_forecast: pd.DataFrame = None) -> pd.DataFrame:
"""
Create features for future predictions.
Args:
future_dates: Future dates to predict
weather_forecast: Weather forecast data
traffic_forecast: Traffic forecast data
Returns:
DataFrame with features for prediction
"""
try:
# Create base future dataframe
future_df = pd.DataFrame({'ds': future_dates})
# Add temporal features
future_df = self._add_temporal_features(
future_df.rename(columns={'ds': 'date'})
).rename(columns={'date': 'ds'})
# Add weather features
if weather_forecast is not None and not weather_forecast.empty:
weather_features = weather_forecast.copy()
if 'date' in weather_features.columns:
weather_features = weather_features.rename(columns={'date': 'ds'})
future_df = future_df.merge(weather_features, on='ds', how='left')
# Add traffic features
if traffic_forecast is not None and not traffic_forecast.empty:
traffic_features = traffic_forecast.copy()
if 'date' in traffic_features.columns:
traffic_features = traffic_features.rename(columns={'date': 'ds'})
future_df = future_df.merge(traffic_features, on='ds', how='left')
# Engineer additional features
future_df = self._engineer_features(future_df.rename(columns={'ds': 'date'}))
future_df = future_df.rename(columns={'date': 'ds'})
# Handle missing values in future data
numeric_columns = future_df.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
if future_df[col].isna().any():
# Use reasonable defaults for Madrid
if col == 'temperature':
future_df[col] = future_df[col].fillna(15.0) # Default Madrid temp
elif col == 'precipitation':
future_df[col] = future_df[col].fillna(0.0) # Default no rain
elif col == 'humidity':
future_df[col] = future_df[col].fillna(60.0) # Default humidity
elif col == 'traffic_volume':
future_df[col] = future_df[col].fillna(100.0) # Default traffic
else:
future_df[col] = future_df[col].fillna(future_df[col].median())
return future_df
except Exception as e:
logger.error(f"Error creating prediction features: {e}")
# Return minimal features if error
return pd.DataFrame({'ds': future_dates})
async def _process_sales_data(self, sales_data: pd.DataFrame, product_name: str) -> pd.DataFrame:
"""Process and clean sales data"""
sales_clean = sales_data.copy()
# Ensure date column exists and is datetime
if 'date' not in sales_clean.columns:
raise ValueError("Sales data must have a 'date' column")
sales_clean['date'] = pd.to_datetime(sales_clean['date'])
# Ensure quantity column exists and is numeric
if 'quantity' not in sales_clean.columns:
raise ValueError("Sales data must have a 'quantity' column")
sales_clean['quantity'] = pd.to_numeric(sales_clean['quantity'], errors='coerce')
# Remove rows with invalid quantities
sales_clean = sales_clean.dropna(subset=['quantity'])
sales_clean = sales_clean[sales_clean['quantity'] >= 0] # No negative sales
# Filter for the specific product if product_name column exists
if 'product_name' in sales_clean.columns:
sales_clean = sales_clean[sales_clean['product_name'] == product_name]
return sales_clean
async def _aggregate_daily_sales(self, sales_data: pd.DataFrame) -> pd.DataFrame:
"""Aggregate sales to daily level"""
daily_sales = sales_data.groupby('date').agg({
'quantity': 'sum'
}).reset_index()
# Ensure we have data for all dates in the range
date_range = pd.date_range(
start=daily_sales['date'].min(),
end=daily_sales['date'].max(),
freq='D'
)
full_date_df = pd.DataFrame({'date': date_range})
daily_sales = full_date_df.merge(daily_sales, on='date', how='left')
daily_sales['quantity'] = daily_sales['quantity'].fillna(0) # Fill missing days with 0 sales
return daily_sales
def _add_temporal_features(self, df: pd.DataFrame) -> pd.DataFrame:
"""Add temporal features like day of week, month, etc."""
df = df.copy()
# Ensure we have a date column
if 'date' not in df.columns:
raise ValueError("DataFrame must have a 'date' column")
df['date'] = pd.to_datetime(df['date'])
# Day of week (0=Monday, 6=Sunday)
df['day_of_week'] = df['date'].dt.dayofweek
df['is_weekend'] = df['day_of_week'].isin([5, 6]).astype(int)
# Month and season
df['month'] = df['date'].dt.month
df['season'] = df['month'].apply(self._get_season)
# Week of year
df['week_of_year'] = df['date'].dt.isocalendar().week
# Quarter
df['quarter'] = df['date'].dt.quarter
# Holiday indicators (basic Spanish holidays)
df['is_holiday'] = df['date'].apply(self._is_spanish_holiday).astype(int)
# School calendar effects (approximate)
df['is_school_holiday'] = df['date'].apply(self._is_school_holiday).astype(int)
return df
def _merge_weather_features(self,
daily_sales: pd.DataFrame,
weather_data: pd.DataFrame) -> pd.DataFrame:
"""Merge weather features with sales data"""
if weather_data.empty:
# Add default weather columns with neutral values
daily_sales['temperature'] = 15.0 # Mild temperature
daily_sales['precipitation'] = 0.0 # No rain
daily_sales['humidity'] = 60.0 # Moderate humidity
daily_sales['wind_speed'] = 5.0 # Light wind
return daily_sales
try:
weather_clean = weather_data.copy()
# Ensure weather data has date column
if 'date' not in weather_clean.columns and 'ds' in weather_clean.columns:
weather_clean = weather_clean.rename(columns={'ds': 'date'})
weather_clean['date'] = pd.to_datetime(weather_clean['date'])
# Select relevant weather features
weather_features = ['date']
# Add available weather columns with default names
weather_mapping = {
'temperature': ['temperature', 'temp', 'temperatura'],
'precipitation': ['precipitation', 'rain', 'lluvia', 'precipitacion'],
'humidity': ['humidity', 'humedad'],
'wind_speed': ['wind_speed', 'viento', 'wind']
}
for standard_name, possible_names in weather_mapping.items():
for possible_name in possible_names:
if possible_name in weather_clean.columns:
weather_clean[standard_name] = weather_clean[possible_name]
weather_features.append(standard_name)
break
# Keep only the features we found
weather_clean = weather_clean[weather_features].copy()
# Merge with sales data
merged = daily_sales.merge(weather_clean, on='date', how='left')
# Fill missing weather values with reasonable defaults
if 'temperature' in merged.columns:
merged['temperature'] = merged['temperature'].fillna(15.0)
if 'precipitation' in merged.columns:
merged['precipitation'] = merged['precipitation'].fillna(0.0)
if 'humidity' in merged.columns:
merged['humidity'] = merged['humidity'].fillna(60.0)
if 'wind_speed' in merged.columns:
merged['wind_speed'] = merged['wind_speed'].fillna(5.0)
return merged
except Exception as e:
logger.warning(f"Error merging weather data: {e}")
# Add default weather columns if merge fails
daily_sales['temperature'] = 15.0
daily_sales['precipitation'] = 0.0
daily_sales['humidity'] = 60.0
daily_sales['wind_speed'] = 5.0
return daily_sales
def _merge_traffic_features(self,
daily_sales: pd.DataFrame,
traffic_data: pd.DataFrame) -> pd.DataFrame:
"""Merge traffic features with sales data"""
if traffic_data.empty:
# Add default traffic column
daily_sales['traffic_volume'] = 100.0 # Neutral traffic level
return daily_sales
try:
traffic_clean = traffic_data.copy()
# Ensure traffic data has date column
if 'date' not in traffic_clean.columns and 'ds' in traffic_clean.columns:
traffic_clean = traffic_clean.rename(columns={'ds': 'date'})
traffic_clean['date'] = pd.to_datetime(traffic_clean['date'])
# Select relevant traffic features
traffic_features = ['date']
# Map traffic column names
traffic_mapping = {
'traffic_volume': ['traffic_volume', 'traffic_intensity', 'trafico', 'intensidad'],
'pedestrian_count': ['pedestrian_count', 'peatones'],
'occupancy_rate': ['occupancy_rate', 'ocupacion']
}
for standard_name, possible_names in traffic_mapping.items():
for possible_name in possible_names:
if possible_name in traffic_clean.columns:
traffic_clean[standard_name] = traffic_clean[possible_name]
traffic_features.append(standard_name)
break
# Keep only the features we found
traffic_clean = traffic_clean[traffic_features].copy()
# Merge with sales data
merged = daily_sales.merge(traffic_clean, on='date', how='left')
# Fill missing traffic values
if 'traffic_volume' in merged.columns:
merged['traffic_volume'] = merged['traffic_volume'].fillna(100.0)
if 'pedestrian_count' in merged.columns:
merged['pedestrian_count'] = merged['pedestrian_count'].fillna(50.0)
if 'occupancy_rate' in merged.columns:
merged['occupancy_rate'] = merged['occupancy_rate'].fillna(0.5)
return merged
except Exception as e:
logger.warning(f"Error merging traffic data: {e}")
# Add default traffic column if merge fails
daily_sales['traffic_volume'] = 100.0
return daily_sales
def _engineer_features(self, df: pd.DataFrame) -> pd.DataFrame:
"""Engineer additional features from existing data"""
df = df.copy()
# Weather-based features
if 'temperature' in df.columns:
df['temp_squared'] = df['temperature'] ** 2
df['is_hot_day'] = (df['temperature'] > 25).astype(int)
df['is_cold_day'] = (df['temperature'] < 10).astype(int)
if 'precipitation' in df.columns:
df['is_rainy_day'] = (df['precipitation'] > 0).astype(int)
df['heavy_rain'] = (df['precipitation'] > 10).astype(int)
# Traffic-based features
if 'traffic_volume' in df.columns:
df['high_traffic'] = (df['traffic_volume'] > df['traffic_volume'].quantile(0.75)).astype(int)
df['low_traffic'] = (df['traffic_volume'] < df['traffic_volume'].quantile(0.25)).astype(int)
# Interaction features
if 'is_weekend' in df.columns and 'temperature' in df.columns:
df['weekend_temp_interaction'] = df['is_weekend'] * df['temperature']
if 'is_rainy_day' in df.columns and 'traffic_volume' in df.columns:
df['rain_traffic_interaction'] = df['is_rainy_day'] * df['traffic_volume']
return df
def _handle_missing_values(self, df: pd.DataFrame) -> pd.DataFrame:
"""Handle missing values in the dataset"""
df = df.copy()
# For numeric columns, use median imputation
numeric_columns = df.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
if col != 'quantity' and df[col].isna().any():
median_value = df[col].median()
df[col] = df[col].fillna(median_value)
return df
def _prepare_prophet_format(self, df: pd.DataFrame) -> pd.DataFrame:
"""Prepare data in Prophet format with 'ds' and 'y' columns"""
prophet_df = df.copy()
# Rename columns for Prophet
if 'date' in prophet_df.columns:
prophet_df = prophet_df.rename(columns={'date': 'ds'})
if 'quantity' in prophet_df.columns:
prophet_df = prophet_df.rename(columns={'quantity': 'y'})
# Ensure ds is datetime
if 'ds' in prophet_df.columns:
prophet_df['ds'] = pd.to_datetime(prophet_df['ds'])
# Validate required columns
if 'ds' not in prophet_df.columns or 'y' not in prophet_df.columns:
raise ValueError("Prophet data must have 'ds' and 'y' columns")
# Remove any rows with missing target values
prophet_df = prophet_df.dropna(subset=['y'])
# Sort by date
prophet_df = prophet_df.sort_values('ds').reset_index(drop=True)
return prophet_df
def _get_season(self, month: int) -> int:
"""Get season from month (1-4 for Winter, Spring, Summer, Autumn)"""
if month in [12, 1, 2]:
return 1 # Winter
elif month in [3, 4, 5]:
return 2 # Spring
elif month in [6, 7, 8]:
return 3 # Summer
else:
return 4 # Autumn
def _is_spanish_holiday(self, date: datetime) -> bool:
"""Check if a date is a major Spanish holiday"""
month_day = (date.month, date.day)
# Major Spanish holidays that affect bakery sales
spanish_holidays = [
(1, 1), # New Year
(1, 6), # Epiphany
(5, 1), # Labour Day
(8, 15), # Assumption
(10, 12), # National Day
(11, 1), # All Saints
(12, 6), # Constitution
(12, 8), # Immaculate Conception
(12, 25), # Christmas
(5, 15), # San Isidro (Madrid)
(5, 2), # Madrid Community Day
]
return month_day in spanish_holidays
def _is_school_holiday(self, date: datetime) -> bool:
"""Check if a date is during school holidays (approximate)"""
month = date.month
# Approximate Spanish school holiday periods
# Summer holidays (July-August)
if month in [7, 8]:
return True
# Christmas holidays (mid December to early January)
if month == 12 and date.day >= 20:
return True
if month == 1 and date.day <= 10:
return True
# Easter holidays (approximate - first two weeks of April)
if month == 4 and date.day <= 14:
return True
return False
def calculate_feature_importance(self,
model_data: pd.DataFrame,
target_column: str = 'y') -> Dict[str, float]:
"""
Calculate feature importance for the model.
"""
try:
# Simple correlation-based importance
numeric_features = model_data.select_dtypes(include=[np.number]).columns
numeric_features = [col for col in numeric_features if col != target_column]
importance_scores = {}
for feature in numeric_features:
if feature in model_data.columns:
correlation = model_data[feature].corr(model_data[target_column])
importance_scores[feature] = abs(correlation) if not pd.isna(correlation) else 0.0
# Sort by importance
importance_scores = dict(sorted(importance_scores.items(),
key=lambda x: x[1], reverse=True))
return importance_scores
except Exception as e:
logger.error(f"Error calculating feature importance: {e}")
return {}
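Usage sketch (illustrative): feeding small in-memory frames through BakeryDataProcessor. The column names match what _process_sales_data and the merge helpers expect; the sample values and product name are made up.

import asyncio
import pandas as pd

from app.ml.data_processor import BakeryDataProcessor

async def build_training_frame() -> pd.DataFrame:
    dates = pd.date_range("2024-01-01", periods=60, freq="D")
    sales = pd.DataFrame({
        "date": dates,
        "product_name": ["croissant"] * 60,
        "quantity": [120 + (i % 7) * 5 for i in range(60)],
    })
    weather = pd.DataFrame({
        "date": dates,
        "temperature": [12.0] * 60,
        "precipitation": [0.0] * 60,
    })
    traffic = pd.DataFrame(columns=["date", "traffic_volume"])  # empty -> neutral defaults are used

    processor = BakeryDataProcessor()
    prophet_df = await processor.prepare_training_data(sales, weather, traffic, "croissant")
    # Result has Prophet's 'ds'/'y' columns plus the engineered regressors
    print(prophet_df.columns.tolist())
    return prophet_df

if __name__ == "__main__":
    asyncio.run(build_training_frame())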

View File

@@ -0,0 +1,408 @@
# services/training/app/ml/prophet_manager.py
"""
Enhanced Prophet Manager for Training Service
Migrated from the monolithic backend to microservices architecture
"""
from typing import Dict, List, Any, Optional, Tuple
import pandas as pd
import numpy as np
from prophet import Prophet
import pickle
import logging
from datetime import datetime, timedelta
import uuid
import asyncio
import os
import joblib
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import json
from pathlib import Path
from app.core.config import settings
logger = logging.getLogger(__name__)
class BakeryProphetManager:
"""
Enhanced Prophet model manager for the training service.
Handles training, validation, and model persistence for bakery forecasting.
"""
def __init__(self):
self.models = {} # In-memory model storage
self.model_metadata = {} # Store model metadata
self.feature_scalers = {} # Store feature scalers per model
# Ensure model storage directory exists
os.makedirs(settings.MODEL_STORAGE_PATH, exist_ok=True)
async def train_bakery_model(self,
tenant_id: str,
product_name: str,
df: pd.DataFrame,
job_id: str) -> Dict[str, Any]:
"""
Train a Prophet model for bakery forecasting with enhanced features.
Args:
tenant_id: Tenant identifier
product_name: Product name
df: Training data with 'ds' and 'y' columns plus regressors
job_id: Training job identifier
Returns:
Dictionary with model information and metrics
"""
try:
logger.info(f"Training bakery model for tenant {tenant_id}, product {product_name}")
# Validate input data
await self._validate_training_data(df, product_name)
# Prepare data for Prophet
prophet_data = await self._prepare_prophet_data(df)
# Get regressor columns
regressor_columns = self._extract_regressor_columns(prophet_data)
# Initialize Prophet model with bakery-specific settings
model = self._create_prophet_model(regressor_columns)
# Add regressors to model
for regressor in regressor_columns:
if regressor in prophet_data.columns:
model.add_regressor(regressor)
# Fit the model
model.fit(prophet_data)
# Generate model ID and store model
model_id = f"{job_id}_{product_name}_{uuid.uuid4().hex[:8]}"
model_path = await self._store_model(
tenant_id, product_name, model, model_id, prophet_data, regressor_columns
)
# Calculate training metrics
training_metrics = await self._calculate_training_metrics(model, prophet_data)
# Prepare model information
model_info = {
"model_id": model_id,
"model_path": model_path,
"type": "prophet",
"training_samples": len(prophet_data),
"features": regressor_columns,
"hyperparameters": {
"seasonality_mode": settings.PROPHET_SEASONALITY_MODE,
"daily_seasonality": settings.PROPHET_DAILY_SEASONALITY,
"weekly_seasonality": settings.PROPHET_WEEKLY_SEASONALITY,
"yearly_seasonality": settings.PROPHET_YEARLY_SEASONALITY
},
"training_metrics": training_metrics,
"trained_at": datetime.now().isoformat(),
"data_period": {
"start_date": prophet_data['ds'].min().isoformat(),
"end_date": prophet_data['ds'].max().isoformat(),
"total_days": len(prophet_data)
}
}
logger.info(f"Model trained successfully for {product_name}")
return model_info
except Exception as e:
logger.error(f"Failed to train bakery model for {product_name}: {str(e)}")
raise
async def generate_forecast(self,
model_path: str,
future_dates: pd.DataFrame,
regressor_columns: List[str]) -> pd.DataFrame:
"""
Generate forecast using a stored Prophet model.
Args:
model_path: Path to the stored model
future_dates: DataFrame with future dates and regressors
regressor_columns: List of regressor column names
Returns:
DataFrame with forecast results
"""
try:
# Load the model
model = joblib.load(model_path)
# Validate future data has required regressors
for regressor in regressor_columns:
if regressor not in future_dates.columns:
logger.warning(f"Missing regressor {regressor}, filling with median")
future_dates[regressor] = 0 # Default value
# Generate forecast
forecast = model.predict(future_dates)
return forecast
except Exception as e:
logger.error(f"Failed to generate forecast: {str(e)}")
raise
async def _validate_training_data(self, df: pd.DataFrame, product_name: str):
"""Validate training data quality"""
if df.empty:
raise ValueError(f"No training data available for {product_name}")
if len(df) < settings.MIN_TRAINING_DATA_DAYS:
raise ValueError(
f"Insufficient training data for {product_name}: "
f"{len(df)} days, minimum required: {settings.MIN_TRAINING_DATA_DAYS}"
)
required_columns = ['ds', 'y']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
raise ValueError(f"Missing required columns: {missing_columns}")
# Check for valid date range
if df['ds'].isna().any():
raise ValueError("Invalid dates found in training data")
# Check for valid target values
if df['y'].isna().all():
raise ValueError("No valid target values found")
async def _prepare_prophet_data(self, df: pd.DataFrame) -> pd.DataFrame:
"""Prepare data for Prophet training"""
prophet_data = df.copy()
# Ensure ds column is datetime
prophet_data['ds'] = pd.to_datetime(prophet_data['ds'])
# Handle missing values in target
if prophet_data['y'].isna().any():
logger.warning("Filling missing target values with interpolation")
prophet_data['y'] = prophet_data['y'].interpolate(method='linear')
# Remove extreme outliers (values > 3 standard deviations)
mean_val = prophet_data['y'].mean()
std_val = prophet_data['y'].std()
if std_val > 0: # Avoid division by zero
lower_bound = mean_val - 3 * std_val
upper_bound = mean_val + 3 * std_val
before_count = len(prophet_data)
prophet_data = prophet_data[
(prophet_data['y'] >= lower_bound) &
(prophet_data['y'] <= upper_bound)
]
after_count = len(prophet_data)
if before_count != after_count:
logger.info(f"Removed {before_count - after_count} outliers")
# Ensure chronological order
prophet_data = prophet_data.sort_values('ds').reset_index(drop=True)
# Fill missing values in regressors
numeric_columns = prophet_data.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
if col != 'y' and prophet_data[col].isna().any():
prophet_data[col] = prophet_data[col].fillna(prophet_data[col].median())
return prophet_data
def _extract_regressor_columns(self, df: pd.DataFrame) -> List[str]:
"""Extract regressor columns from the dataframe"""
excluded_columns = ['ds', 'y']
regressor_columns = []
for col in df.columns:
if col not in excluded_columns and df[col].dtype in ['int64', 'float64']:
regressor_columns.append(col)
logger.info(f"Identified regressor columns: {regressor_columns}")
return regressor_columns
def _create_prophet_model(self, regressor_columns: List[str]) -> Prophet:
"""Create Prophet model with bakery-specific settings"""
# Get Spanish holidays
holidays = self._get_spanish_holidays()
# Bakery-specific Prophet configuration
model = Prophet(
holidays=holidays if not holidays.empty else None,
daily_seasonality=settings.PROPHET_DAILY_SEASONALITY,
weekly_seasonality=settings.PROPHET_WEEKLY_SEASONALITY,
yearly_seasonality=settings.PROPHET_YEARLY_SEASONALITY,
seasonality_mode=settings.PROPHET_SEASONALITY_MODE,
changepoint_prior_scale=0.05, # Conservative changepoint detection
seasonality_prior_scale=10, # Strong seasonality for bakeries
holidays_prior_scale=10, # Strong holiday effects
interval_width=0.8, # 80% confidence intervals
mcmc_samples=0, # Use MAP estimation (faster)
uncertainty_samples=1000 # For uncertainty estimation
)
return model
def _get_spanish_holidays(self) -> pd.DataFrame:
"""Get Spanish holidays for Prophet model"""
try:
# Define major Spanish holidays that affect bakery sales
holidays_list = []
years = range(2020, 2030) # Cover training and prediction period
for year in years:
holidays_list.extend([
{'holiday': 'new_year', 'ds': f'{year}-01-01'},
{'holiday': 'epiphany', 'ds': f'{year}-01-06'},
{'holiday': 'may_day', 'ds': f'{year}-05-01'},
{'holiday': 'assumption', 'ds': f'{year}-08-15'},
{'holiday': 'national_day', 'ds': f'{year}-10-12'},
{'holiday': 'all_saints', 'ds': f'{year}-11-01'},
{'holiday': 'constitution', 'ds': f'{year}-12-06'},
{'holiday': 'immaculate', 'ds': f'{year}-12-08'},
{'holiday': 'christmas', 'ds': f'{year}-12-25'},
# Madrid specific holidays
{'holiday': 'madrid_patron', 'ds': f'{year}-05-15'}, # San Isidro
{'holiday': 'madrid_community', 'ds': f'{year}-05-02'},
])
holidays_df = pd.DataFrame(holidays_list)
holidays_df['ds'] = pd.to_datetime(holidays_df['ds'])
return holidays_df
except Exception as e:
logger.warning(f"Error creating holidays dataframe: {e}")
return pd.DataFrame()
async def _store_model(self,
tenant_id: str,
product_name: str,
model: Prophet,
model_id: str,
training_data: pd.DataFrame,
regressor_columns: List[str]) -> str:
"""Store model and metadata to filesystem"""
# Create model filename
model_filename = f"{model_id}_prophet_model.pkl"
model_path = os.path.join(settings.MODEL_STORAGE_PATH, model_filename)
# Store the model
joblib.dump(model, model_path)
# Store metadata
metadata = {
"tenant_id": tenant_id,
"product_name": product_name,
"model_id": model_id,
"regressor_columns": regressor_columns,
"training_samples": len(training_data),
"training_period": {
"start": training_data['ds'].min().isoformat(),
"end": training_data['ds'].max().isoformat()
},
"created_at": datetime.now().isoformat(),
"model_type": "prophet",
"file_path": model_path
}
metadata_path = model_path.replace('.pkl', '_metadata.json')
with open(metadata_path, 'w') as f:
json.dump(metadata, f, indent=2)
# Store in memory for quick access
model_key = f"{tenant_id}:{product_name}"
self.models[model_key] = model
self.model_metadata[model_key] = metadata
logger.info(f"Model stored at: {model_path}")
return model_path
async def _calculate_training_metrics(self,
model: Prophet,
training_data: pd.DataFrame) -> Dict[str, float]:
"""Calculate training metrics for the model"""
try:
# Generate in-sample predictions
forecast = model.predict(training_data[['ds'] + [col for col in training_data.columns if col not in ['ds', 'y']]])
# Calculate metrics
y_true = training_data['y'].values
y_pred = forecast['yhat'].values
# Basic metrics
mae = mean_absolute_error(y_true, y_pred)
mse = mean_squared_error(y_true, y_pred)
rmse = np.sqrt(mse)
            # MAPE (Mean Absolute Percentage Error); days with zero actual sales are excluded to avoid division by zero
            nonzero = y_true != 0
            mape = float(np.mean(np.abs((y_true[nonzero] - y_pred[nonzero]) / y_true[nonzero])) * 100) if nonzero.any() else 0.0
# R-squared
r2 = r2_score(y_true, y_pred)
return {
"mae": round(mae, 2),
"mse": round(mse, 2),
"rmse": round(rmse, 2),
"mape": round(mape, 2),
"r2_score": round(r2, 4),
"mean_actual": round(np.mean(y_true), 2),
"mean_predicted": round(np.mean(y_pred), 2)
}
except Exception as e:
logger.error(f"Error calculating training metrics: {e}")
return {
"mae": 0.0,
"mse": 0.0,
"rmse": 0.0,
"mape": 0.0,
"r2_score": 0.0,
"mean_actual": 0.0,
"mean_predicted": 0.0
}
def get_model_info(self, tenant_id: str, product_name: str) -> Optional[Dict[str, Any]]:
"""Get model information for a specific tenant and product"""
model_key = f"{tenant_id}:{product_name}"
return self.model_metadata.get(model_key)
def list_models(self, tenant_id: str) -> List[Dict[str, Any]]:
"""List all models for a tenant"""
tenant_models = []
for model_key, metadata in self.model_metadata.items():
if metadata['tenant_id'] == tenant_id:
tenant_models.append(metadata)
return tenant_models
async def cleanup_old_models(self, days_old: int = 30):
"""Clean up old model files"""
try:
cutoff_date = datetime.now() - timedelta(days=days_old)
for model_path in Path(settings.MODEL_STORAGE_PATH).glob("*.pkl"):
# Check file modification time
if model_path.stat().st_mtime < cutoff_date.timestamp():
# Remove model and metadata files
model_path.unlink()
                    metadata_path = model_path.with_name(model_path.stem + '_metadata.json')  # matches the *_metadata.json naming used in _store_model
if metadata_path.exists():
metadata_path.unlink()
logger.info(f"Cleaned up old model: {model_path}")
except Exception as e:
logger.error(f"Error during model cleanup: {e}")
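For reference, a minimal sketch (not part of this commit) of how a stored model and its metadata could be loaded back for inference, assuming the `*_prophet_model.pkl` / `*_metadata.json` naming convention used by `_store_model` above; the helper name `load_model_for_inference` is hypothetical:

# Illustrative only: load a persisted Prophet model plus its metadata from disk,
# following the naming convention used by _store_model in this file.
import json
import os
import joblib

def load_model_for_inference(model_id: str, storage_path: str):
    model_path = os.path.join(storage_path, f"{model_id}_prophet_model.pkl")
    metadata_path = model_path.replace('.pkl', '_metadata.json')
    model = joblib.load(model_path)
    with open(metadata_path) as f:
        metadata = json.load(f)
    return model, metadata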

View File

@@ -1,174 +1,372 @@
# services/training/app/ml/trainer.py
"""
ML Training implementation
ML Trainer for Training Service
Orchestrates the complete training process
"""
import asyncio
import structlog
from typing import Dict, Any, List
from typing import Dict, List, Any, Optional, Tuple
import pandas as pd
from datetime import datetime
import joblib
import os
from prophet import Prophet
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from datetime import datetime, timedelta
import logging
import asyncio
import uuid
from pathlib import Path
from app.ml.prophet_manager import BakeryProphetManager
from app.ml.data_processor import BakeryDataProcessor
from app.core.config import settings
logger = structlog.get_logger()
logger = logging.getLogger(__name__)
class MLTrainer:
"""ML training implementation"""
class BakeryMLTrainer:
"""
Main ML trainer that orchestrates the complete training process.
Replaces the old Celery-based training system with clean async implementation.
"""
def __init__(self):
self.model_storage_path = settings.MODEL_STORAGE_PATH
os.makedirs(self.model_storage_path, exist_ok=True)
self.prophet_manager = BakeryProphetManager()
self.data_processor = BakeryDataProcessor()
async def train_tenant_models(self,
tenant_id: str,
sales_data: List[Dict],
weather_data: List[Dict] = None,
traffic_data: List[Dict] = None,
job_id: str = None) -> Dict[str, Any]:
"""
Train models for all products of a tenant.
Args:
tenant_id: Tenant identifier
sales_data: Historical sales data
weather_data: Weather data (optional)
traffic_data: Traffic data (optional)
job_id: Training job identifier
Returns:
Dictionary with training results for each product
"""
if not job_id:
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info(f"Starting training job {job_id} for tenant {tenant_id}")
try:
# Convert input data to DataFrames
sales_df = pd.DataFrame(sales_data) if sales_data else pd.DataFrame()
weather_df = pd.DataFrame(weather_data) if weather_data else pd.DataFrame()
traffic_df = pd.DataFrame(traffic_data) if traffic_data else pd.DataFrame()
# Validate input data
await self._validate_input_data(sales_df, tenant_id)
# Get unique products
products = sales_df['product_name'].unique().tolist()
logger.info(f"Training models for {len(products)} products: {products}")
# Process data for each product
processed_data = await self._process_all_products(
sales_df, weather_df, traffic_df, products
)
# Train models for each product
training_results = await self._train_all_models(
tenant_id, processed_data, job_id
)
# Calculate overall training summary
summary = self._calculate_training_summary(training_results)
result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "completed",
"products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
"products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
"total_products": len(products),
"training_results": training_results,
"summary": summary,
"completed_at": datetime.now().isoformat()
}
logger.info(f"Training job {job_id} completed successfully")
return result
except Exception as e:
logger.error(f"Training job {job_id} failed: {str(e)}")
raise
    async def train_single_product(self,
                                   tenant_id: str,
                                   product_name: str,
                                   sales_data: List[Dict],
                                   weather_data: List[Dict] = None,
                                   traffic_data: List[Dict] = None,
                                   job_id: str = None) -> Dict[str, Any]:
        """
        Train model for a single product.

        Args:
            tenant_id: Tenant identifier
            product_name: Product name
            sales_data: Historical sales data
            weather_data: Weather data (optional)
            traffic_data: Traffic data (optional)
            job_id: Training job identifier

        Returns:
            Training result for the product
        """
        if not job_id:
            job_id = f"training_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"

        logger.info(f"Starting single product training {job_id} for {product_name}")
try:
# Convert input data to DataFrames
sales_df = pd.DataFrame(sales_data) if sales_data else pd.DataFrame()
weather_df = pd.DataFrame(weather_data) if weather_data else pd.DataFrame()
traffic_df = pd.DataFrame(traffic_data) if traffic_data else pd.DataFrame()
# Filter sales data for the specific product
product_sales = sales_df[sales_df['product_name'] == product_name].copy()
# Validate product data
if product_sales.empty:
raise ValueError(f"No sales data found for product: {product_name}")
# Prepare training data
processed_data = await self.data_processor.prepare_training_data(
sales_data=product_sales,
weather_data=weather_df,
traffic_data=traffic_df,
product_name=product_name
)
# Train the model
model_info = await self.prophet_manager.train_bakery_model(
tenant_id=tenant_id,
product_name=product_name,
df=processed_data,
job_id=job_id
)
result = {
"job_id": job_id,
"tenant_id": tenant_id,
"product_name": product_name,
"status": "success",
"model_info": model_info,
"data_points": len(processed_data),
"completed_at": datetime.now().isoformat()
}
logger.info(f"Single product training {job_id} completed successfully")
return result
except Exception as e:
logger.error(f"Single product training {job_id} failed: {str(e)}")
raise
async def evaluate_model_performance(self,
tenant_id: str,
product_name: str,
model_path: str,
test_data: List[Dict]) -> Dict[str, Any]:
"""
Evaluate model performance on test data.
# Group by product
products_data = self._group_by_product(sales_data)
Args:
tenant_id: Tenant identifier
product_name: Product name
model_path: Path to the trained model
test_data: Test data for evaluation
Returns:
Performance metrics
"""
try:
logger.info(f"Evaluating model performance for {product_name}")
# Convert test data to DataFrame
test_df = pd.DataFrame(test_data)
# Prepare test data
test_prepared = await self.data_processor.prepare_prediction_features(
future_dates=test_df['ds'],
weather_forecast=test_df if 'temperature' in test_df.columns else pd.DataFrame(),
traffic_forecast=test_df if 'traffic_volume' in test_df.columns else pd.DataFrame()
)
# Get regressor columns
regressor_columns = [col for col in test_prepared.columns if col not in ['ds', 'y']]
# Generate predictions
forecast = await self.prophet_manager.generate_forecast(
model_path=model_path,
future_dates=test_prepared,
regressor_columns=regressor_columns
)
# Calculate performance metrics if we have actual values
metrics = {}
if 'y' in test_df.columns:
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
y_true = test_df['y'].values
y_pred = forecast['yhat'].values
metrics = {
"mae": float(mean_absolute_error(y_true, y_pred)),
"rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
"mape": float(np.mean(np.abs((y_true - y_pred) / y_true)) * 100),
"r2_score": float(r2_score(y_true, y_pred))
}
result = {
"tenant_id": tenant_id,
"product_name": product_name,
"evaluation_metrics": metrics,
"forecast_samples": len(forecast),
"evaluated_at": datetime.now().isoformat()
}
return result
except Exception as e:
logger.error(f"Model evaluation failed: {str(e)}")
raise
async def _validate_input_data(self, sales_df: pd.DataFrame, tenant_id: str):
"""Validate input sales data"""
if sales_df.empty:
raise ValueError(f"No sales data provided for tenant {tenant_id}")
        required_columns = ['date', 'product_name', 'quantity']
        missing_columns = [col for col in required_columns if col not in sales_df.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")
# Check for valid dates
try:
sales_df['date'] = pd.to_datetime(sales_df['date'])
except Exception:
raise ValueError("Invalid date format in sales data")
# Check for valid quantities
        if sales_df['quantity'].dtype not in ['int64', 'float64']:
raise ValueError("Quantity column must be numeric")
async def _process_all_products(self,
sales_df: pd.DataFrame,
weather_df: pd.DataFrame,
traffic_df: pd.DataFrame,
products: List[str]) -> Dict[str, pd.DataFrame]:
"""Process data for all products"""
processed_data = {}
for product_name in products:
try:
                logger.info(f"Processing data for product: {product_name}")

                # Filter sales data for this product
                product_sales = sales_df[sales_df['product_name'] == product_name].copy()

                # Process the product data
                processed_product_data = await self.data_processor.prepare_training_data(
                    sales_data=product_sales,
                    weather_data=weather_df,
                    traffic_data=traffic_df,
                    product_name=product_name
                )

                processed_data[product_name] = processed_product_data
                logger.info(f"Processed {len(processed_product_data)} data points for {product_name}")

            except Exception as e:
                logger.error(f"Failed to process data for {product_name}: {str(e)}")
                # Continue with other products
                continue

        return processed_data
    async def _train_all_models(self,
                               tenant_id: str,
                               processed_data: Dict[str, pd.DataFrame],
                               job_id: str) -> Dict[str, Any]:
        """Train models for all processed products"""
        training_results = {}

        for product_name, product_data in processed_data.items():
            try:
                logger.info(f"Training model for product: {product_name}")

                # Check if we have enough data
                if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
                    training_results[product_name] = {
                        'status': 'skipped',
                        'reason': 'insufficient_data',
                        'data_points': len(product_data),
                        'min_required': settings.MIN_TRAINING_DATA_DAYS
                    }
                    continue

                # Train the model
                model_info = await self.prophet_manager.train_bakery_model(
                    tenant_id=tenant_id,
                    product_name=product_name,
                    df=product_data,
                    job_id=job_id
                )

                training_results[product_name] = {
                    'status': 'success',
                    'model_info': model_info,
                    'data_points': len(product_data),
                    'trained_at': datetime.now().isoformat()
                }

                logger.info(f"Successfully trained model for {product_name}")

            except Exception as e:
                logger.error(f"Failed to train model for {product_name}: {str(e)}")
                training_results[product_name] = {
                    'status': 'error',
                    'error_message': str(e),
                    'data_points': len(product_data) if product_data is not None else 0
                }

        return training_results
def _calculate_training_summary(self, training_results: Dict[str, Any]) -> Dict[str, Any]:
"""Calculate summary statistics from training results"""
total_products = len(training_results)
successful_products = len([r for r in training_results.values() if r.get('status') == 'success'])
failed_products = len([r for r in training_results.values() if r.get('status') == 'error'])
skipped_products = len([r for r in training_results.values() if r.get('status') == 'skipped'])
# Calculate average training metrics for successful models
successful_results = [r for r in training_results.values() if r.get('status') == 'success']
avg_metrics = {}
if successful_results:
metrics_list = [r['model_info'].get('training_metrics', {}) for r in successful_results]
if metrics_list and all(metrics_list):
avg_metrics = {
'avg_mae': np.mean([m.get('mae', 0) for m in metrics_list]),
'avg_rmse': np.mean([m.get('rmse', 0) for m in metrics_list]),
'avg_mape': np.mean([m.get('mape', 0) for m in metrics_list]),
'avg_r2': np.mean([m.get('r2_score', 0) for m in metrics_list])
}
return {
'total_products': total_products,
'successful_products': successful_products,
'failed_products': failed_products,
'skipped_products': skipped_products,
'success_rate': round(successful_products / total_products * 100, 2) if total_products > 0 else 0,
'average_metrics': avg_metrics
}
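
A minimal usage sketch (not part of this commit) showing how the trainer might be invoked with in-memory records; field names follow `_validate_input_data` above, and the tenant id and sales values are toy data.

# Illustrative only: run a full-tenant training pass with toy data.
import asyncio

sales = [
    {"date": "2024-01-01", "product_name": "croissant", "quantity": 42},
    {"date": "2024-01-02", "product_name": "croissant", "quantity": 38},
    # ... in practice at least MIN_TRAINING_DATA_DAYS days per product
]
weather = [{"date": "2024-01-01", "temperature": 9.5, "humidity": 70, "precipitation": 0.0}]

async def run_example():
    trainer = BakeryMLTrainer()
    result = await trainer.train_tenant_models(
        tenant_id="tenant-123",
        sales_data=sales,
        weather_data=weather,
    )
    print(result["summary"])

# asyncio.run(run_example())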

View File

@@ -1,101 +1,154 @@
# services/training/app/models/training.py
"""
Database models for training service
"""
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, JSON, Float
from sqlalchemy.dialects.postgresql import UUID, ARRAY
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
import uuid

Base = declarative_base()
class ModelTrainingLog(Base):
    """
    Table to track training job execution and status.
    Replaces the old Celery task tracking.
    """
    __tablename__ = "model_training_logs"

    id = Column(Integer, primary_key=True, index=True)
    job_id = Column(String(255), unique=True, index=True, nullable=False)
    tenant_id = Column(String(255), index=True, nullable=False)
    status = Column(String(50), nullable=False, default="pending")  # pending, running, completed, failed, cancelled
    progress = Column(Integer, default=0)  # 0-100 percentage
    current_step = Column(String(500), default="")

    # Timestamps
    start_time = Column(DateTime, default=datetime.now)
    end_time = Column(DateTime, nullable=True)

    # Configuration and results
    config = Column(JSON, nullable=True)  # Training job configuration
    results = Column(JSON, nullable=True)  # Training results
    error_message = Column(Text, nullable=True)

    # Metadata
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
class TrainedModel(Base):
"""Trained model information"""
"""
Table to store information about trained models.
"""
__tablename__ = "trained_models"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
training_job_id = Column(UUID(as_uuid=True), nullable=False)
id = Column(Integer, primary_key=True, index=True)
model_id = Column(String(255), unique=True, index=True, nullable=False)
tenant_id = Column(String(255), index=True, nullable=False)
product_name = Column(String(255), index=True, nullable=False)
# Model details
product_name = Column(String(100), nullable=False)
model_type = Column(String(50), nullable=False, default="prophet")
model_version = Column(String(20), nullable=False)
model_path = Column(String(500)) # Path to saved model file
# Model information
model_type = Column(String(50), nullable=False, default="prophet") # prophet, arima, etc.
model_path = Column(String(1000), nullable=False) # Path to stored model file
version = Column(Integer, nullable=False, default=1)
# Training information
training_samples = Column(Integer, nullable=False, default=0)
features = Column(ARRAY(String), nullable=True) # List of features used
hyperparameters = Column(JSON, nullable=True) # Model hyperparameters
training_metrics = Column(JSON, nullable=True) # Training performance metrics
# Data period information
data_period_start = Column(DateTime, nullable=True)
data_period_end = Column(DateTime, nullable=True)
# Status and metadata
is_active = Column(Boolean, default=True, index=True)
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
class ModelPerformanceMetric(Base):
"""
Table to track model performance over time.
"""
__tablename__ = "model_performance_metrics"
id = Column(Integer, primary_key=True, index=True)
model_id = Column(String(255), index=True, nullable=False)
tenant_id = Column(String(255), index=True, nullable=False)
product_name = Column(String(255), index=True, nullable=False)
# Performance metrics
    mae = Column(Float, nullable=True)  # Mean Absolute Error
    mse = Column(Float, nullable=True)  # Mean Squared Error
    rmse = Column(Float, nullable=True)  # Root Mean Squared Error
    mape = Column(Float, nullable=True)  # Mean Absolute Percentage Error
    r2_score = Column(Float, nullable=True)  # R-squared score

    # Additional metrics
    accuracy_percentage = Column(Float, nullable=True)
    prediction_confidence = Column(Float, nullable=True)
# Evaluation information
evaluation_period_start = Column(DateTime, nullable=True)
evaluation_period_end = Column(DateTime, nullable=True)
evaluation_samples = Column(Integer, nullable=True)
# Metadata
measured_at = Column(DateTime, default=datetime.now)
created_at = Column(DateTime, default=datetime.now)
class TrainingJobQueue(Base):
"""
Table to manage training job queue and scheduling.
"""
__tablename__ = "training_job_queue"
id = Column(Integer, primary_key=True, index=True)
job_id = Column(String(255), unique=True, index=True, nullable=False)
tenant_id = Column(String(255), index=True, nullable=False)
# Job configuration
job_type = Column(String(50), nullable=False) # full_training, single_product, evaluation
priority = Column(Integer, default=1) # Higher number = higher priority
config = Column(JSON, nullable=True)
# Scheduling information
scheduled_at = Column(DateTime, nullable=True)
started_at = Column(DateTime, nullable=True)
estimated_duration_minutes = Column(Integer, nullable=True)
# Status
    status = Column(String(50), nullable=False, default="queued")  # queued, running, completed, failed
    retry_count = Column(Integer, default=0)
    max_retries = Column(Integer, default=3)
# Metadata
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
class ModelArtifact(Base):
    """
    Table to track model files and artifacts.
    """
    __tablename__ = "model_artifacts"

    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), index=True, nullable=False)
    tenant_id = Column(String(255), index=True, nullable=False)

    # Artifact information
    artifact_type = Column(String(50), nullable=False)  # model_file, metadata, training_data, etc.
    file_path = Column(String(1000), nullable=False)
    file_size_bytes = Column(Integer, nullable=True)
    checksum = Column(String(255), nullable=True)  # For file integrity

    # Storage information
    storage_location = Column(String(100), nullable=False, default="local")  # local, s3, gcs, etc.
    compression = Column(String(50), nullable=True)  # gzip, lz4, etc.

    # Metadata
    created_at = Column(DateTime, default=datetime.now)
    expires_at = Column(DateTime, nullable=True)  # For automatic cleanup
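
A brief sketch (not part of this commit) of how these tables might be queried from an async session, assuming the standard SQLAlchemy 2.x `select()` API; the helper name `get_active_models` is hypothetical.

# Illustrative only: fetch a tenant's active trained models, newest first.
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

async def get_active_models(db: AsyncSession, tenant_id: str):
    stmt = (
        select(TrainedModel)
        .where(TrainedModel.tenant_id == tenant_id, TrainedModel.is_active.is_(True))
        .order_by(TrainedModel.created_at.desc())
    )
    result = await db.execute(stmt)
    return result.scalars().all()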

View File

@@ -1,91 +1,181 @@
# services/training/app/schemas/training.py
"""
Pydantic schemas for training service
"""
from pydantic import BaseModel, Field, validator
from typing import Dict, List, Any, Optional
from datetime import datetime
from enum import Enum


class TrainingStatus(str, Enum):
    """Training job status enumeration"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class TrainingJobRequest(BaseModel):
    """Request schema for starting a training job"""
    products: Optional[List[str]] = Field(None, description="Specific products to train (if None, train all)")
    include_weather: bool = Field(True, description="Include weather data in training")
    include_traffic: bool = Field(True, description="Include traffic data in training")
    start_date: Optional[datetime] = Field(None, description="Start date for training data")
    end_date: Optional[datetime] = Field(None, description="End date for training data")
    min_data_points: int = Field(30, description="Minimum data points required per product")
    estimated_duration: Optional[int] = Field(None, description="Estimated duration in minutes")

    # Prophet-specific parameters
    seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
    daily_seasonality: bool = Field(True, description="Enable daily seasonality")
    weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
    yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")

    @validator('seasonality_mode')
    def validate_seasonality_mode(cls, v):
        if v not in ['additive', 'multiplicative']:
            raise ValueError('seasonality_mode must be additive or multiplicative')
        return v

    @validator('min_data_points')
    def validate_min_data_points(cls, v):
        if v < 7:
            raise ValueError('min_data_points must be at least 7')
        return v
class SingleProductTrainingRequest(BaseModel):
"""Request schema for training a single product"""
include_weather: bool = Field(True, description="Include weather data in training")
include_traffic: bool = Field(True, description="Include traffic data in training")
start_date: Optional[datetime] = Field(None, description="Start date for training data")
end_date: Optional[datetime] = Field(None, description="End date for training data")
# Prophet-specific parameters
seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
daily_seasonality: bool = Field(True, description="Enable daily seasonality")
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
class TrainingJobResponse(BaseModel):
"""Training job response schema"""
id: str
tenant_id: str
status: TrainingJobStatus
progress: int
current_step: Optional[str]
started_at: datetime
completed_at: Optional[datetime]
duration_seconds: Optional[int]
models_trained: Optional[Dict[str, Any]]
metrics: Optional[Dict[str, Any]]
error_message: Optional[str]
class Config:
from_attributes = True
"""Response schema for training job creation"""
job_id: str = Field(..., description="Unique training job identifier")
status: TrainingStatus = Field(..., description="Current job status")
message: str = Field(..., description="Status message")
tenant_id: str = Field(..., description="Tenant identifier")
created_at: datetime = Field(..., description="Job creation timestamp")
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
class TrainingStatusResponse(BaseModel):
"""Response schema for training job status"""
job_id: str = Field(..., description="Training job identifier")
status: TrainingStatus = Field(..., description="Current job status")
progress: int = Field(0, description="Progress percentage (0-100)")
current_step: str = Field("", description="Current processing step")
started_at: datetime = Field(..., description="Job start timestamp")
completed_at: Optional[datetime] = Field(None, description="Job completion timestamp")
results: Optional[Dict[str, Any]] = Field(None, description="Training results")
error_message: Optional[str] = Field(None, description="Error message if failed")
class ModelInfo(BaseModel):
"""Schema for trained model information"""
model_id: str = Field(..., description="Unique model identifier")
model_path: str = Field(..., description="Path to stored model")
model_type: str = Field("prophet", description="Type of ML model")
training_samples: int = Field(..., description="Number of training samples")
features: List[str] = Field(..., description="List of features used")
hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
trained_at: datetime = Field(..., description="Training completion timestamp")
data_period: Dict[str, str] = Field(..., description="Training data period")
class ProductTrainingResult(BaseModel):
"""Schema for individual product training result"""
product_name: str = Field(..., description="Product name")
status: str = Field(..., description="Training status for this product")
model_info: Optional[ModelInfo] = Field(None, description="Model information if successful")
data_points: int = Field(..., description="Number of data points used")
error_message: Optional[str] = Field(None, description="Error message if failed")
trained_at: datetime = Field(..., description="Training completion timestamp")
class TrainingResultsResponse(BaseModel):
"""Response schema for complete training results"""
job_id: str = Field(..., description="Training job identifier")
tenant_id: str = Field(..., description="Tenant identifier")
status: TrainingStatus = Field(..., description="Overall job status")
products_trained: int = Field(..., description="Number of products successfully trained")
products_failed: int = Field(..., description="Number of products that failed training")
total_products: int = Field(..., description="Total number of products processed")
training_results: Dict[str, ProductTrainingResult] = Field(..., description="Per-product results")
summary: Dict[str, Any] = Field(..., description="Training summary statistics")
completed_at: datetime = Field(..., description="Job completion timestamp")
class TrainingValidationResult(BaseModel):
"""Schema for training data validation results"""
is_valid: bool = Field(..., description="Whether the data is valid for training")
issues: List[str] = Field(default_factory=list, description="List of data quality issues")
recommendations: List[str] = Field(default_factory=list, description="Recommendations for improvement")
estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
products_analyzed: int = Field(..., description="Number of products analyzed")
total_data_points: int = Field(..., description="Total data points available")
class TrainingMetrics(BaseModel):
    """Schema for training performance metrics"""
mae: float = Field(..., description="Mean Absolute Error")
mse: float = Field(..., description="Mean Squared Error")
rmse: float = Field(..., description="Root Mean Squared Error")
mape: float = Field(..., description="Mean Absolute Percentage Error")
r2_score: float = Field(..., description="R-squared score")
mean_actual: float = Field(..., description="Mean of actual values")
mean_predicted: float = Field(..., description="Mean of predicted values")
class ModelValidationResult(BaseModel):
"""Model validation result schema"""
product_name: str
is_valid: bool
accuracy_score: float
validation_error: Optional[str]
recommendations: List[str]
class ExternalDataConfig(BaseModel):
"""Configuration for external data sources"""
weather_enabled: bool = Field(True, description="Enable weather data")
traffic_enabled: bool = Field(True, description="Enable traffic data")
weather_features: List[str] = Field(
default_factory=lambda: ["temperature", "precipitation", "humidity"],
description="Weather features to include"
)
traffic_features: List[str] = Field(
default_factory=lambda: ["traffic_volume"],
description="Traffic features to include"
)
class TrainingJobConfig(BaseModel):
"""Complete training job configuration"""
external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)
prophet_params: Dict[str, Any] = Field(
default_factory=lambda: {
"seasonality_mode": "additive",
"daily_seasonality": True,
"weekly_seasonality": True,
"yearly_seasonality": True
},
description="Prophet model parameters"
)
data_filters: Dict[str, Any] = Field(
default_factory=dict,
description="Data filtering parameters"
)
validation_params: Dict[str, Any] = Field(
default_factory=lambda: {"min_data_points": 30},
description="Data validation parameters"
)
class TrainedModelResponse(BaseModel):
"""Response schema for trained model information"""
model_id: str = Field(..., description="Unique model identifier")
tenant_id: str = Field(..., description="Tenant identifier")
product_name: str = Field(..., description="Product name")
model_type: str = Field(..., description="Type of ML model")
model_path: str = Field(..., description="Path to stored model")
version: int = Field(..., description="Model version")
training_samples: int = Field(..., description="Number of training samples")
features: List[str] = Field(..., description="List of features used")
hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
is_active: bool = Field(..., description="Whether model is active")
created_at: datetime = Field(..., description="Model creation timestamp")
data_period_start: Optional[datetime] = Field(None, description="Training data start date")
data_period_end: Optional[datetime] = Field(None, description="Training data end date")
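
A short usage sketch (not part of this commit) showing how the request schema and its validators behave; the product names and parameter values are toy data.

# Illustrative only: construct a request and let the pydantic validators run.
request = TrainingJobRequest(
    products=["croissant", "baguette"],
    include_weather=True,
    include_traffic=False,
    min_data_points=30,
    seasonality_mode="additive",
)
print(request.dict())
# seasonality_mode="weekly" or min_data_points=3 would raise a ValidationError
# because of the validators defined above.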

View File

@@ -1,12 +1,17 @@
# ================================================================
# services/training/app/services/messaging.py
# ================================================================
"""
Training service messaging - Clean interface for training-specific events
Uses shared RabbitMQ infrastructure
"""
import structlog
from typing import Dict, Any, Optional
from shared.messaging.rabbitmq import RabbitMQClient
from shared.messaging.events import (
TrainingStartedEvent,
TrainingCompletedEvent,
TrainingFailedEvent
)
from app.core.config import settings
logger = structlog.get_logger()
@@ -27,23 +32,188 @@ async def cleanup_messaging():
await training_publisher.disconnect()
logger.info("Training service messaging cleaned up")
# Training Job Events
async def publish_job_started(job_id: str, tenant_id: str, config: Dict[str, Any]) -> bool:
"""Publish training job started event"""
event = TrainingStartedEvent(
service_name="training-service",
data={
"job_id": job_id,
"tenant_id": tenant_id,
"config": config
}
)
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.started",
event_data=event.to_dict()
)
async def publish_job_progress(job_id: str, tenant_id: str, progress: int, step: str) -> bool:
"""Publish training job progress event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.progress",
event_data={
"service_name": "training-service",
"event_type": "training.progress",
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"progress": progress,
"current_step": step
}
}
)
async def publish_job_completed(job_id: str, tenant_id: str, results: Dict[str, Any]) -> bool:
"""Publish training job completed event"""
event = TrainingCompletedEvent(
service_name="training-service",
data={
"job_id": job_id,
"tenant_id": tenant_id,
"results": results,
"models_trained": results.get("products_trained", 0),
"success_rate": results.get("summary", {}).get("success_rate", 0)
}
)
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.completed",
event_data=event.to_dict()
)
async def publish_job_failed(job_id: str, tenant_id: str, error: str) -> bool:
"""Publish training job failed event"""
event = TrainingFailedEvent(
service_name="training-service",
data={
"job_id": job_id,
"tenant_id": tenant_id,
"error": error
}
)
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.failed",
event_data=event.to_dict()
)
async def publish_job_cancelled(job_id: str, tenant_id: str) -> bool:
"""Publish training job cancelled event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.cancelled",
event_data={
"service_name": "training-service",
"event_type": "training.cancelled",
"data": {
"job_id": job_id,
"tenant_id": tenant_id
}
}
)
# Product Training Events
async def publish_product_training_started(job_id: str, tenant_id: str, product_name: str) -> bool:
"""Publish single product training started event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.product.started",
event_data={
"service_name": "training-service",
"event_type": "training.product.started",
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"product_name": product_name
}
}
)
async def publish_product_training_completed(job_id: str, tenant_id: str, product_name: str, model_id: str) -> bool:
"""Publish single product training completed event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.product.completed",
event_data={
"service_name": "training-service",
"event_type": "training.product.completed",
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"product_name": product_name,
"model_id": model_id
}
}
)
# Model Events
async def publish_model_trained(model_id: str, tenant_id: str, product_name: str, metrics: Dict[str, float]) -> bool:
"""Publish model trained event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.model.trained",
event_data={
"service_name": "training-service",
"event_type": "training.model.trained",
"data": {
"model_id": model_id,
"tenant_id": tenant_id,
"product_name": product_name,
"training_metrics": metrics
}
}
)
async def publish_model_updated(model_id: str, tenant_id: str, product_name: str, version: int) -> bool:
"""Publish model updated event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.model.updated",
event_data={
"service_name": "training-service",
"event_type": "training.model.updated",
"data": {
"model_id": model_id,
"tenant_id": tenant_id,
"product_name": product_name,
"version": version
}
}
)
async def publish_model_validated(model_id: str, tenant_id: str, product_name: str, validation_results: Dict[str, Any]) -> bool:
"""Publish model validation event"""
    return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.model.validated",
event_data={
"service_name": "training-service",
"event_type": "training.model.validated",
"data": {
"model_id": model_id,
"tenant_id": tenant_id,
"product_name": product_name,
"validation_results": validation_results
}
}
)
async def publish_model_saved(model_id: str, tenant_id: str, product_name: str, model_path: str) -> bool:
    """Publish model saved event"""
    return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.model.saved",
event_data={
"service_name": "training-service",
"event_type": "training.model.saved",
"data": {
"model_id": model_id,
"tenant_id": tenant_id,
"product_name": product_name,
"model_path": model_path
}
}
)
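
A minimal sketch (not part of this commit) of how a training loop might emit events through the helpers above; the job and tenant identifiers are toy values.

# Illustrative only: publish progress and completion events during a training run.
async def report_progress_example():
    job_id, tenant_id = "training_tenant-123_ab12cd34", "tenant-123"
    await publish_job_progress(job_id, tenant_id, progress=50, step="training croissant")
    await publish_job_completed(job_id, tenant_id, results={
        "products_trained": 3,
        "summary": {"success_rate": 100.0},
    })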

File diff suppressed because it is too large