Add all the code for training service
This commit is contained in:
@@ -8,10 +8,11 @@ from typing import List
|
||||
import structlog
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import verify_token
|
||||
from app.core.auth import get_current_tenant_id
|
||||
from app.schemas.training import TrainedModelResponse
|
||||
from app.services.training_service import TrainingService
|
||||
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
@@ -19,12 +20,12 @@ training_service = TrainingService()
|
||||
|
||||
@router.get("/", response_model=List[TrainedModelResponse])
|
||||
async def get_trained_models(
|
||||
user_data: dict = Depends(verify_token),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get trained models"""
|
||||
try:
|
||||
return await training_service.get_trained_models(user_data, db)
|
||||
return await training_service.get_trained_models(tenant_id, db)
|
||||
except Exception as e:
|
||||
logger.error(f"Get trained models error: {e}")
|
||||
raise HTTPException(
|
||||
|
||||
@@ -1,77 +1,299 @@
|
||||
# services/training/app/api/training.py
|
||||
"""
|
||||
Training API endpoints
|
||||
Training API endpoints for the training service
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Optional
|
||||
import structlog
|
||||
from typing import Dict, List, Any, Optional
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import verify_token
|
||||
from app.schemas.training import TrainingRequest, TrainingJobResponse, TrainedModelResponse
|
||||
from app.core.auth import get_current_tenant_id
|
||||
from app.schemas.training import (
|
||||
TrainingJobRequest,
|
||||
TrainingJobResponse,
|
||||
TrainingStatusResponse,
|
||||
SingleProductTrainingRequest
|
||||
)
|
||||
from app.services.training_service import TrainingService
|
||||
from app.services.messaging import publish_job_started, publish_job_cancelled, publish_product_training_started
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
|
||||
logger = structlog.get_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
metrics = MetricsCollector("training-service")
|
||||
|
||||
# Initialize training service
|
||||
training_service = TrainingService()
|
||||
|
||||
@router.post("/train", response_model=TrainingJobResponse)
|
||||
async def start_training(
|
||||
request: TrainingRequest,
|
||||
user_data: dict = Depends(verify_token),
|
||||
@router.post("/jobs", response_model=TrainingJobResponse)
|
||||
async def start_training_job(
|
||||
request: TrainingJobRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Start training job"""
|
||||
"""
|
||||
Start a new training job for all products of a tenant.
|
||||
Replaces the old Celery-based training system.
|
||||
"""
|
||||
try:
|
||||
return await training_service.start_training(request, user_data, db)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
logger.info(f"Starting training job for tenant {tenant_id}")
|
||||
metrics.increment_counter("training_jobs_started")
|
||||
|
||||
# Generate job ID
|
||||
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Create training job record
|
||||
training_job = await training_service.create_training_job(
|
||||
db=db,
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id,
|
||||
config=request.dict()
|
||||
)
|
||||
|
||||
# Start training in background
|
||||
background_tasks.add_task(
|
||||
training_service.execute_training_job,
|
||||
db,
|
||||
job_id,
|
||||
tenant_id,
|
||||
request
|
||||
)
|
||||
|
||||
# Publish training started event
|
||||
await publish_job_started(job_id, tenant_id, request.dict())
|
||||
|
||||
return TrainingJobResponse(
|
||||
job_id=job_id,
|
||||
status="started",
|
||||
message="Training job started successfully",
|
||||
tenant_id=tenant_id,
|
||||
created_at=training_job.start_time,
|
||||
estimated_duration_minutes=request.estimated_duration or 15
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training start error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to start training"
|
||||
)
|
||||
logger.error(f"Failed to start training job: {str(e)}")
|
||||
metrics.increment_counter("training_jobs_failed")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
|
||||
|
||||
@router.get("/status/{job_id}", response_model=TrainingJobResponse)
|
||||
@router.get("/jobs/{job_id}/status", response_model=TrainingStatusResponse)
|
||||
async def get_training_status(
|
||||
job_id: str,
|
||||
user_data: dict = Depends(verify_token),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get training job status"""
|
||||
"""
|
||||
Get the status of a training job.
|
||||
Provides real-time progress updates.
|
||||
"""
|
||||
try:
|
||||
return await training_service.get_training_status(job_id, user_data, db)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e)
|
||||
# Get job status from database
|
||||
job_status = await training_service.get_job_status(db, job_id, tenant_id)
|
||||
|
||||
if not job_status:
|
||||
raise HTTPException(status_code=404, detail="Training job not found")
|
||||
|
||||
return TrainingStatusResponse(
|
||||
job_id=job_id,
|
||||
status=job_status.status,
|
||||
progress=job_status.progress,
|
||||
current_step=job_status.current_step,
|
||||
started_at=job_status.start_time,
|
||||
completed_at=job_status.end_time,
|
||||
results=job_status.results,
|
||||
error_message=job_status.error_message
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Get training status error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get training status"
|
||||
)
|
||||
logger.error(f"Failed to get training status: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get training status: {str(e)}")
|
||||
|
||||
@router.get("/jobs", response_model=List[TrainingJobResponse])
|
||||
async def get_training_jobs(
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
user_data: dict = Depends(verify_token),
|
||||
@router.post("/products/{product_name}", response_model=TrainingJobResponse)
|
||||
async def train_single_product(
|
||||
product_name: str,
|
||||
request: SingleProductTrainingRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get training jobs"""
|
||||
"""
|
||||
Train a model for a single product.
|
||||
Useful for quick model updates or new products.
|
||||
"""
|
||||
try:
|
||||
return await training_service.get_training_jobs(user_data, limit, offset, db)
|
||||
logger.info(f"Starting single product training for {product_name}, tenant {tenant_id}")
|
||||
metrics.increment_counter("single_product_training_started")
|
||||
|
||||
# Generate job ID
|
||||
job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Create training job record
|
||||
training_job = await training_service.create_single_product_job(
|
||||
db=db,
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
job_id=job_id,
|
||||
config=request.dict()
|
||||
)
|
||||
|
||||
# Start training in background
|
||||
background_tasks.add_task(
|
||||
training_service.execute_single_product_training,
|
||||
db,
|
||||
job_id,
|
||||
tenant_id,
|
||||
product_name,
|
||||
request
|
||||
)
|
||||
|
||||
# Publish event
|
||||
await publish_product_training_started(job_id, tenant_id, product_name)
|
||||
|
||||
return TrainingJobResponse(
|
||||
job_id=job_id,
|
||||
status="started",
|
||||
message=f"Single product training started for {product_name}",
|
||||
tenant_id=tenant_id,
|
||||
created_at=training_job.start_time,
|
||||
estimated_duration_minutes=5
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Get training jobs error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get training jobs"
|
||||
)
|
||||
logger.error(f"Failed to start single product training: {str(e)}")
|
||||
metrics.increment_counter("single_product_training_failed")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start training: {str(e)}")
|
||||
|
||||
@router.get("/jobs", response_model=List[TrainingStatusResponse])
|
||||
async def list_training_jobs(
|
||||
limit: int = 10,
|
||||
status: Optional[str] = None,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
List training jobs for a tenant.
|
||||
"""
|
||||
try:
|
||||
jobs = await training_service.list_training_jobs(
|
||||
db=db,
|
||||
tenant_id=tenant_id,
|
||||
limit=limit,
|
||||
status_filter=status
|
||||
)
|
||||
|
||||
return [
|
||||
TrainingStatusResponse(
|
||||
job_id=job.job_id,
|
||||
status=job.status,
|
||||
progress=job.progress,
|
||||
current_step=job.current_step,
|
||||
started_at=job.start_time,
|
||||
completed_at=job.end_time,
|
||||
results=job.results,
|
||||
error_message=job.error_message
|
||||
)
|
||||
for job in jobs
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list training jobs: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list training jobs: {str(e)}")
|
||||
|
||||
@router.post("/jobs/{job_id}/cancel")
|
||||
async def cancel_training_job(
|
||||
job_id: str,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Cancel a running training job.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")
|
||||
|
||||
# Update job status to cancelled
|
||||
success = await training_service.cancel_training_job(db, job_id, tenant_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Training job not found or cannot be cancelled")
|
||||
|
||||
# Publish cancellation event
|
||||
await publish_job_cancelled(job_id, tenant_id)
|
||||
|
||||
return {"message": "Training job cancelled successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cancel training job: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}")
|
||||
|
||||
@router.get("/jobs/{job_id}/logs")
|
||||
async def get_training_logs(
|
||||
job_id: str,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get detailed logs for a training job.
|
||||
"""
|
||||
try:
|
||||
logs = await training_service.get_training_logs(db, job_id, tenant_id)
|
||||
|
||||
if not logs:
|
||||
raise HTTPException(status_code=404, detail="Training job not found")
|
||||
|
||||
return {"job_id": job_id, "logs": logs}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get training logs: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get training logs: {str(e)}")
|
||||
|
||||
@router.post("/validate")
|
||||
async def validate_training_data(
|
||||
request: TrainingJobRequest,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Validate training data before starting a job.
|
||||
Provides early feedback on data quality issues.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Validating training data for tenant {tenant_id}")
|
||||
|
||||
# Perform data validation
|
||||
validation_result = await training_service.validate_training_data(
|
||||
db=db,
|
||||
tenant_id=tenant_id,
|
||||
config=request.dict()
|
||||
)
|
||||
|
||||
return {
|
||||
"is_valid": validation_result["is_valid"],
|
||||
"issues": validation_result.get("issues", []),
|
||||
"recommendations": validation_result.get("recommendations", []),
|
||||
"estimated_training_time": validation_result.get("estimated_time_minutes", 15)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate training data: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to validate training data: {str(e)}")
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""Health check for the training service"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "training-service",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
@@ -1,38 +1,303 @@
|
||||
# services/training/app/core/auth.py
|
||||
"""
|
||||
Authentication utilities for training service
|
||||
Authentication and authorization for training service
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer
|
||||
import structlog
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Depends, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
security = HTTPBearer()
|
||||
# HTTP Bearer token scheme
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
async def verify_token(token: str = Depends(security)):
|
||||
"""Verify token with auth service"""
|
||||
class AuthenticationError(Exception):
|
||||
"""Custom exception for authentication errors"""
|
||||
pass
|
||||
|
||||
class AuthorizationError(Exception):
|
||||
"""Custom exception for authorization errors"""
|
||||
pass
|
||||
|
||||
async def verify_token(token: str) -> dict:
|
||||
"""
|
||||
Verify JWT token with auth service
|
||||
|
||||
Args:
|
||||
token: JWT token to verify
|
||||
|
||||
Returns:
|
||||
dict: Token payload with user and tenant information
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If token is invalid
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.AUTH_SERVICE_URL}/auth/verify",
|
||||
headers={"Authorization": f"Bearer {token.credentials}"}
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
token_data = response.json()
|
||||
logger.debug("Token verified successfully", user_id=token_data.get("user_id"))
|
||||
return token_data
|
||||
elif response.status_code == 401:
|
||||
logger.warning("Invalid token provided")
|
||||
raise AuthenticationError("Invalid or expired token")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials"
|
||||
)
|
||||
logger.error("Auth service error", status_code=response.status_code)
|
||||
raise AuthenticationError("Authentication service unavailable")
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("Auth service timeout")
|
||||
raise AuthenticationError("Authentication service timeout")
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Auth service unavailable: {e}")
|
||||
logger.error("Auth service request error", error=str(e))
|
||||
raise AuthenticationError("Authentication service unavailable")
|
||||
except AuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Unexpected auth error", error=str(e))
|
||||
raise AuthenticationError("Authentication failed")
|
||||
|
||||
async def get_current_user(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
||||
) -> dict:
|
||||
"""
|
||||
Get current authenticated user
|
||||
|
||||
Args:
|
||||
credentials: HTTP Bearer credentials
|
||||
|
||||
Returns:
|
||||
dict: User information
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails
|
||||
"""
|
||||
if not credentials:
|
||||
logger.warning("No credentials provided")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication service unavailable"
|
||||
)
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication credentials required",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
token_data = await verify_token(credentials.credentials)
|
||||
return token_data
|
||||
|
||||
except AuthenticationError as e:
|
||||
logger.warning("Authentication failed", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def get_current_tenant_id(
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> str:
|
||||
"""
|
||||
Get current tenant ID from authenticated user
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user data
|
||||
|
||||
Returns:
|
||||
str: Tenant ID
|
||||
|
||||
Raises:
|
||||
HTTPException: If tenant ID is missing
|
||||
"""
|
||||
tenant_id = current_user.get("tenant_id")
|
||||
if not tenant_id:
|
||||
logger.error("Missing tenant_id in token", user_data=current_user)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid token: missing tenant information"
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
|
||||
async def require_admin_role(
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> dict:
|
||||
"""
|
||||
Require admin role for endpoint access
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user data
|
||||
|
||||
Returns:
|
||||
dict: User information
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is not admin
|
||||
"""
|
||||
user_role = current_user.get("role", "").lower()
|
||||
if user_role != "admin":
|
||||
logger.warning("Access denied - admin role required",
|
||||
user_id=current_user.get("user_id"),
|
||||
role=user_role)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin role required"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
async def require_training_permission(
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> dict:
|
||||
"""
|
||||
Require training permission for endpoint access
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user data
|
||||
|
||||
Returns:
|
||||
dict: User information
|
||||
|
||||
Raises:
|
||||
HTTPException: If user doesn't have training permission
|
||||
"""
|
||||
permissions = current_user.get("permissions", [])
|
||||
if "training" not in permissions and current_user.get("role", "").lower() != "admin":
|
||||
logger.warning("Access denied - training permission required",
|
||||
user_id=current_user.get("user_id"),
|
||||
permissions=permissions)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Training permission required"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
# Optional authentication for development/testing
|
||||
async def get_current_user_optional(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get current user but don't require authentication (for development)
|
||||
|
||||
Args:
|
||||
credentials: HTTP Bearer credentials
|
||||
|
||||
Returns:
|
||||
dict or None: User information if authenticated, None otherwise
|
||||
"""
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
try:
|
||||
token_data = await verify_token(credentials.credentials)
|
||||
return token_data
|
||||
except AuthenticationError:
|
||||
return None
|
||||
|
||||
async def get_tenant_id_optional(
|
||||
current_user: Optional[dict] = Depends(get_current_user_optional)
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get tenant ID but don't require authentication (for development)
|
||||
|
||||
Args:
|
||||
current_user: Current user data (optional)
|
||||
|
||||
Returns:
|
||||
str or None: Tenant ID if available, None otherwise
|
||||
"""
|
||||
if not current_user:
|
||||
return None
|
||||
|
||||
return current_user.get("tenant_id")
|
||||
|
||||
# Development/testing auth bypass
|
||||
async def get_test_tenant_id() -> str:
|
||||
"""
|
||||
Get test tenant ID for development/testing
|
||||
Only works when DEBUG is enabled
|
||||
|
||||
Returns:
|
||||
str: Test tenant ID
|
||||
"""
|
||||
if settings.DEBUG:
|
||||
return "test-tenant-development"
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Test authentication only available in debug mode"
|
||||
)
|
||||
|
||||
# Token validation utility
|
||||
def validate_token_structure(token_data: dict) -> bool:
|
||||
"""
|
||||
Validate that token data has required structure
|
||||
|
||||
Args:
|
||||
token_data: Token payload data
|
||||
|
||||
Returns:
|
||||
bool: True if valid structure, False otherwise
|
||||
"""
|
||||
required_fields = ["user_id", "tenant_id"]
|
||||
|
||||
for field in required_fields:
|
||||
if field not in token_data:
|
||||
logger.warning("Invalid token structure - missing field", field=field)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# Role checking utilities
|
||||
def has_role(user_data: dict, required_role: str) -> bool:
|
||||
"""
|
||||
Check if user has required role
|
||||
|
||||
Args:
|
||||
user_data: User data from token
|
||||
required_role: Required role name
|
||||
|
||||
Returns:
|
||||
bool: True if user has role, False otherwise
|
||||
"""
|
||||
user_role = user_data.get("role", "").lower()
|
||||
return user_role == required_role.lower()
|
||||
|
||||
def has_permission(user_data: dict, required_permission: str) -> bool:
|
||||
"""
|
||||
Check if user has required permission
|
||||
|
||||
Args:
|
||||
user_data: User data from token
|
||||
required_permission: Required permission name
|
||||
|
||||
Returns:
|
||||
bool: True if user has permission, False otherwise
|
||||
"""
|
||||
permissions = user_data.get("permissions", [])
|
||||
return required_permission in permissions or has_role(user_data, "admin")
|
||||
|
||||
# Export commonly used items
|
||||
__all__ = [
|
||||
'get_current_user',
|
||||
'get_current_tenant_id',
|
||||
'require_admin_role',
|
||||
'require_training_permission',
|
||||
'get_current_user_optional',
|
||||
'get_tenant_id_optional',
|
||||
'get_test_tenant_id',
|
||||
'has_role',
|
||||
'has_permission',
|
||||
'AuthenticationError',
|
||||
'AuthorizationError'
|
||||
]
|
||||
@@ -1,12 +1,260 @@
|
||||
# services/training/app/core/database.py
|
||||
"""
|
||||
Database configuration for training service
|
||||
Uses shared database infrastructure
|
||||
"""
|
||||
|
||||
from shared.database.base import DatabaseManager
|
||||
import structlog
|
||||
from typing import AsyncGenerator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
|
||||
from shared.database.base import DatabaseManager, Base
|
||||
from app.core.config import settings
|
||||
|
||||
# Initialize database manager
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Initialize database manager using shared infrastructure
|
||||
database_manager = DatabaseManager(settings.DATABASE_URL)
|
||||
|
||||
# Alias for convenience
|
||||
get_db = database_manager.get_db
|
||||
# Alias for convenience - matches the existing interface
|
||||
get_db = database_manager.get_db
|
||||
|
||||
async def get_db_health() -> bool:
|
||||
"""
|
||||
Health check function for database connectivity
|
||||
Enhanced version of the shared functionality
|
||||
"""
|
||||
try:
|
||||
async with database_manager.async_engine.begin() as conn:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
logger.debug("Database health check passed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Database health check failed", error=str(e))
|
||||
return False
|
||||
|
||||
# Training service specific database utilities
|
||||
class TrainingDatabaseUtils:
|
||||
"""Training service specific database utilities"""
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_training_logs(days_old: int = 90):
|
||||
"""Clean up old training logs"""
|
||||
try:
|
||||
async with database_manager.async_session_local() as session:
|
||||
if settings.DATABASE_URL.startswith("sqlite"):
|
||||
query = text(
|
||||
"DELETE FROM model_training_logs "
|
||||
"WHERE start_time < datetime('now', :days_param)"
|
||||
)
|
||||
params = {"days_param": f"-{days_old} days"}
|
||||
else:
|
||||
query = text(
|
||||
"DELETE FROM model_training_logs "
|
||||
"WHERE start_time < NOW() - INTERVAL :days_param"
|
||||
)
|
||||
params = {"days_param": f"{days_old} days"}
|
||||
|
||||
result = await session.execute(query, params)
|
||||
await session.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info("Cleaned up old training logs",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Training logs cleanup failed", error=str(e))
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_models(days_old: int = 365):
|
||||
"""Clean up old inactive models"""
|
||||
try:
|
||||
async with database_manager.async_session_local() as session:
|
||||
if settings.DATABASE_URL.startswith("sqlite"):
|
||||
query = text(
|
||||
"DELETE FROM trained_models "
|
||||
"WHERE is_active = 0 AND created_at < datetime('now', :days_param)"
|
||||
)
|
||||
params = {"days_param": f"-{days_old} days"}
|
||||
else:
|
||||
query = text(
|
||||
"DELETE FROM trained_models "
|
||||
"WHERE is_active = false AND created_at < NOW() - INTERVAL :days_param"
|
||||
)
|
||||
params = {"days_param": f"{days_old} days"}
|
||||
|
||||
result = await session.execute(query, params)
|
||||
await session.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info("Cleaned up old models",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Model cleanup failed", error=str(e))
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def get_training_statistics(tenant_id: str = None) -> dict:
|
||||
"""Get training statistics"""
|
||||
try:
|
||||
async with database_manager.async_session_local() as session:
|
||||
# Base query for training logs
|
||||
if tenant_id:
|
||||
logs_query = text(
|
||||
"SELECT status, COUNT(*) as count "
|
||||
"FROM model_training_logs "
|
||||
"WHERE tenant_id = :tenant_id "
|
||||
"GROUP BY status"
|
||||
)
|
||||
models_query = text(
|
||||
"SELECT COUNT(*) as count "
|
||||
"FROM trained_models "
|
||||
"WHERE tenant_id = :tenant_id AND is_active = :is_active"
|
||||
)
|
||||
params = {"tenant_id": tenant_id}
|
||||
else:
|
||||
logs_query = text(
|
||||
"SELECT status, COUNT(*) as count "
|
||||
"FROM model_training_logs "
|
||||
"GROUP BY status"
|
||||
)
|
||||
models_query = text(
|
||||
"SELECT COUNT(*) as count "
|
||||
"FROM trained_models "
|
||||
"WHERE is_active = :is_active"
|
||||
)
|
||||
params = {}
|
||||
|
||||
# Get training job statistics
|
||||
logs_result = await session.execute(logs_query, params)
|
||||
job_stats = {row.status: row.count for row in logs_result.fetchall()}
|
||||
|
||||
# Get active models count
|
||||
active_models_result = await session.execute(
|
||||
models_query,
|
||||
{**params, "is_active": True}
|
||||
)
|
||||
active_models = active_models_result.scalar() or 0
|
||||
|
||||
# Get inactive models count
|
||||
inactive_models_result = await session.execute(
|
||||
models_query,
|
||||
{**params, "is_active": False}
|
||||
)
|
||||
inactive_models = inactive_models_result.scalar() or 0
|
||||
|
||||
return {
|
||||
"training_jobs": job_stats,
|
||||
"active_models": active_models,
|
||||
"inactive_models": inactive_models,
|
||||
"total_models": active_models + inactive_models
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get training statistics", error=str(e))
|
||||
return {
|
||||
"training_jobs": {},
|
||||
"active_models": 0,
|
||||
"inactive_models": 0,
|
||||
"total_models": 0
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def check_tenant_data_exists(tenant_id: str) -> bool:
|
||||
"""Check if tenant has any training data"""
|
||||
try:
|
||||
async with database_manager.async_session_local() as session:
|
||||
query = text(
|
||||
"SELECT COUNT(*) as count "
|
||||
"FROM model_training_logs "
|
||||
"WHERE tenant_id = :tenant_id "
|
||||
"LIMIT 1"
|
||||
)
|
||||
|
||||
result = await session.execute(query, {"tenant_id": tenant_id})
|
||||
count = result.scalar() or 0
|
||||
|
||||
return count > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to check tenant data existence",
|
||||
tenant_id=tenant_id, error=str(e))
|
||||
return False
|
||||
|
||||
# Enhanced database session dependency with better error handling
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Enhanced database session dependency with better logging and error handling
|
||||
"""
|
||||
async with database_manager.async_session_local() as session:
|
||||
try:
|
||||
logger.debug("Database session created")
|
||||
yield session
|
||||
except Exception as e:
|
||||
logger.error("Database session error", error=str(e), exc_info=True)
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
logger.debug("Database session closed")
|
||||
|
||||
# Database initialization for training service
|
||||
async def initialize_training_database():
|
||||
"""Initialize database tables for training service"""
|
||||
try:
|
||||
logger.info("Initializing training service database")
|
||||
|
||||
# Import models to ensure they're registered
|
||||
from app.models.training import (
|
||||
ModelTrainingLog,
|
||||
TrainedModel,
|
||||
ModelPerformanceMetric,
|
||||
TrainingJobQueue,
|
||||
ModelArtifact
|
||||
)
|
||||
|
||||
# Create tables using shared infrastructure
|
||||
await database_manager.create_tables()
|
||||
|
||||
logger.info("Training service database initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize training service database", error=str(e))
|
||||
raise
|
||||
|
||||
# Database cleanup for training service
|
||||
async def cleanup_training_database():
|
||||
"""Cleanup database connections for training service"""
|
||||
try:
|
||||
logger.info("Cleaning up training service database connections")
|
||||
|
||||
# Close engine connections
|
||||
if hasattr(database_manager, 'async_engine') and database_manager.async_engine:
|
||||
await database_manager.async_engine.dispose()
|
||||
|
||||
logger.info("Training service database cleanup completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup training service database", error=str(e))
|
||||
|
||||
# Export the commonly used items to maintain compatibility
|
||||
__all__ = [
|
||||
'Base',
|
||||
'database_manager',
|
||||
'get_db',
|
||||
'get_db_session',
|
||||
'get_db_health',
|
||||
'TrainingDatabaseUtils',
|
||||
'initialize_training_database',
|
||||
'cleanup_training_database'
|
||||
]
|
||||
@@ -1,81 +1,282 @@
|
||||
# services/training/app/main.py
|
||||
"""
|
||||
Training Service
|
||||
Handles ML model training for bakery demand forecasting
|
||||
Training Service Main Application
|
||||
Enhanced with proper error handling, monitoring, and lifecycle management
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import FastAPI, BackgroundTasks
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
import uvicorn
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import database_manager
|
||||
from app.core.database import database_manager, get_db_health
|
||||
from app.api import training, models
|
||||
from app.services.messaging import setup_messaging, cleanup_messaging
|
||||
from shared.monitoring.logging import setup_logging
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
from shared.auth.decorators import require_auth
|
||||
|
||||
# Setup logging
|
||||
# Setup structured logging
|
||||
setup_logging("training-service", settings.LOG_LEVEL)
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Training Service",
|
||||
description="ML model training service for bakery demand forecasting",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Initialize metrics collector
|
||||
metrics_collector = MetricsCollector("training-service")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
Application lifespan manager for startup and shutdown events
|
||||
"""
|
||||
# Startup
|
||||
logger.info("Starting Training Service", version="1.0.0")
|
||||
|
||||
try:
|
||||
# Initialize database
|
||||
logger.info("Initializing database connection")
|
||||
await database_manager.create_tables()
|
||||
logger.info("Database initialized successfully")
|
||||
|
||||
# Initialize messaging
|
||||
logger.info("Setting up messaging")
|
||||
await setup_messaging()
|
||||
logger.info("Messaging setup completed")
|
||||
|
||||
# Start metrics server
|
||||
logger.info("Starting metrics server")
|
||||
metrics_collector.start_metrics_server(8080)
|
||||
logger.info("Metrics server started on port 8080")
|
||||
|
||||
# Mark service as ready
|
||||
app.state.ready = True
|
||||
logger.info("Training Service startup completed successfully")
|
||||
|
||||
yield
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to start Training Service", error=str(e))
|
||||
app.state.ready = False
|
||||
raise
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down Training Service")
|
||||
|
||||
try:
|
||||
# Cleanup messaging
|
||||
logger.info("Cleaning up messaging")
|
||||
await cleanup_messaging()
|
||||
|
||||
# Close database connections
|
||||
logger.info("Closing database connections")
|
||||
await database_manager.close_connections()
|
||||
|
||||
logger.info("Training Service shutdown completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error during shutdown", error=str(e))
|
||||
|
||||
# Create FastAPI app with lifespan
|
||||
app = FastAPI(
|
||||
title="Training Service",
|
||||
description="ML model training service for bakery demand forecasting",
|
||||
version="1.0.0",
|
||||
docs_url="/docs" if settings.DEBUG else None,
|
||||
redoc_url="/redoc" if settings.DEBUG else None,
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Initialize app state
|
||||
app.state.ready = False
|
||||
|
||||
# Security middleware
|
||||
if not settings.DEBUG:
|
||||
app.add_middleware(
|
||||
TrustedHostMiddleware,
|
||||
allowed_hosts=["localhost", "127.0.0.1", "training-service", "*.bakery-forecast.local"]
|
||||
)
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=["*"] if settings.DEBUG else [
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8000",
|
||||
"https://dashboard.bakery-forecast.es"
|
||||
],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
app.include_router(training.router, prefix="/training", tags=["training"])
|
||||
app.include_router(models.router, prefix="/models", tags=["models"])
|
||||
# Request logging middleware
|
||||
@app.middleware("http")
|
||||
async def log_requests(request: Request, call_next):
|
||||
"""Log all incoming requests with timing"""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Log request
|
||||
logger.info(
|
||||
"Request started",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
client_ip=request.client.host if request.client else "unknown"
|
||||
)
|
||||
|
||||
# Process request
|
||||
try:
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate duration
|
||||
duration = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
# Log response
|
||||
logger.info(
|
||||
"Request completed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
status_code=response.status_code,
|
||||
duration_ms=round(duration * 1000, 2)
|
||||
)
|
||||
|
||||
# Update metrics
|
||||
metrics_collector.record_request(
|
||||
method=request.method,
|
||||
endpoint=request.url.path,
|
||||
status_code=response.status_code,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
duration = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
logger.error(
|
||||
"Request failed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
error=str(e),
|
||||
duration_ms=round(duration * 1000, 2)
|
||||
)
|
||||
|
||||
metrics_collector.increment_counter("http_requests_failed_total")
|
||||
raise
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Application startup"""
|
||||
logger.info("Starting Training Service")
|
||||
# Exception handlers
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
"""Global exception handler for unhandled errors"""
|
||||
logger.error(
|
||||
"Unhandled exception",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
error=str(exc),
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# Create database tables
|
||||
await database_manager.create_tables()
|
||||
metrics_collector.increment_counter("unhandled_exceptions_total")
|
||||
|
||||
# Initialize message publisher
|
||||
await setup_messaging()
|
||||
|
||||
# Start metrics server
|
||||
metrics_collector.start_metrics_server(8080)
|
||||
|
||||
logger.info("Training Service started successfully")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "Internal server error",
|
||||
"error_id": structlog.get_logger().new().info("Error logged", error=str(exc))
|
||||
}
|
||||
)
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Application shutdown"""
|
||||
logger.info("Shutting down Training Service")
|
||||
|
||||
# Cleanup message publisher
|
||||
await cleanup_messaging()
|
||||
|
||||
logger.info("Training Service shutdown complete")
|
||||
# Include API routers
|
||||
app.include_router(
|
||||
training.router,
|
||||
prefix="/training",
|
||||
tags=["training"],
|
||||
dependencies=[require_auth] if not settings.DEBUG else []
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
models.router,
|
||||
prefix="/models",
|
||||
tags=["models"],
|
||||
dependencies=[require_auth] if not settings.DEBUG else []
|
||||
)
|
||||
|
||||
# Health check endpoints
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
"""Basic health check endpoint"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"status": "healthy" if app.state.ready else "starting",
|
||||
"service": "training-service",
|
||||
"version": "1.0.0"
|
||||
"version": "1.0.0",
|
||||
"timestamp": structlog.get_logger().new().info("Health check")
|
||||
}
|
||||
|
||||
@app.get("/health/ready")
|
||||
async def readiness_check():
|
||||
"""Kubernetes readiness probe"""
|
||||
if not app.state.ready:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={"status": "not_ready", "message": "Service is starting up"}
|
||||
)
|
||||
|
||||
return {"status": "ready", "service": "training-service"}
|
||||
|
||||
@app.get("/health/live")
|
||||
async def liveness_check():
|
||||
"""Kubernetes liveness probe"""
|
||||
# Check database connectivity
|
||||
try:
|
||||
db_healthy = await get_db_health()
|
||||
if not db_healthy:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={"status": "unhealthy", "reason": "database_unavailable"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Database health check failed", error=str(e))
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={"status": "unhealthy", "reason": "database_error"}
|
||||
)
|
||||
|
||||
return {"status": "alive", "service": "training-service"}
|
||||
|
||||
@app.get("/metrics")
|
||||
async def get_metrics():
|
||||
"""Expose service metrics"""
|
||||
return {
|
||||
"training_jobs_active": metrics_collector.get_gauge("training_jobs_active", 0),
|
||||
"training_jobs_completed": metrics_collector.get_counter("training_jobs_completed", 0),
|
||||
"training_jobs_failed": metrics_collector.get_counter("training_jobs_failed", 0),
|
||||
"models_trained_total": metrics_collector.get_counter("models_trained_total", 0),
|
||||
"uptime_seconds": metrics_collector.get_gauge("uptime_seconds", 0)
|
||||
}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with service information"""
|
||||
return {
|
||||
"service": "training-service",
|
||||
"version": "1.0.0",
|
||||
"description": "ML model training service for bakery demand forecasting",
|
||||
"docs": "/docs" if settings.DEBUG else "Documentation disabled in production",
|
||||
"health": "/health"
|
||||
}
|
||||
|
||||
# Development server configuration
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
reload=settings.DEBUG,
|
||||
log_level=settings.LOG_LEVEL.lower(),
|
||||
access_log=settings.DEBUG,
|
||||
server_header=False,
|
||||
date_header=False
|
||||
)
|
||||
493
services/training/app/ml/data_processor.py
Normal file
493
services/training/app/ml/data_processor.py
Normal file
@@ -0,0 +1,493 @@
|
||||
# services/training/app/ml/data_processor.py
|
||||
"""
|
||||
Data Processor for Training Service
|
||||
Handles data preparation and feature engineering for ML training
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.impute import SimpleImputer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BakeryDataProcessor:
|
||||
"""
|
||||
Enhanced data processor for bakery forecasting training service.
|
||||
Handles data cleaning, feature engineering, and preparation for ML models.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.scalers = {} # Store scalers for each feature
|
||||
self.imputers = {} # Store imputers for missing value handling
|
||||
|
||||
async def prepare_training_data(self,
|
||||
sales_data: pd.DataFrame,
|
||||
weather_data: pd.DataFrame,
|
||||
traffic_data: pd.DataFrame,
|
||||
product_name: str) -> pd.DataFrame:
|
||||
"""
|
||||
Prepare comprehensive training data for a specific product.
|
||||
|
||||
Args:
|
||||
sales_data: Historical sales data for the product
|
||||
weather_data: Weather data
|
||||
traffic_data: Traffic data
|
||||
product_name: Product name for logging
|
||||
|
||||
Returns:
|
||||
DataFrame ready for Prophet training with 'ds' and 'y' columns plus features
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Preparing training data for product: {product_name}")
|
||||
|
||||
# Convert and validate sales data
|
||||
sales_clean = await self._process_sales_data(sales_data, product_name)
|
||||
|
||||
# Aggregate to daily level
|
||||
daily_sales = await self._aggregate_daily_sales(sales_clean)
|
||||
|
||||
# Add temporal features
|
||||
daily_sales = self._add_temporal_features(daily_sales)
|
||||
|
||||
# Merge external data sources
|
||||
daily_sales = self._merge_weather_features(daily_sales, weather_data)
|
||||
daily_sales = self._merge_traffic_features(daily_sales, traffic_data)
|
||||
|
||||
# Engineer additional features
|
||||
daily_sales = self._engineer_features(daily_sales)
|
||||
|
||||
# Handle missing values
|
||||
daily_sales = self._handle_missing_values(daily_sales)
|
||||
|
||||
# Prepare for Prophet (rename columns and validate)
|
||||
prophet_data = self._prepare_prophet_format(daily_sales)
|
||||
|
||||
logger.info(f"Prepared {len(prophet_data)} data points for {product_name}")
|
||||
return prophet_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing training data for {product_name}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def prepare_prediction_features(self,
|
||||
future_dates: pd.DatetimeIndex,
|
||||
weather_forecast: pd.DataFrame = None,
|
||||
traffic_forecast: pd.DataFrame = None) -> pd.DataFrame:
|
||||
"""
|
||||
Create features for future predictions.
|
||||
|
||||
Args:
|
||||
future_dates: Future dates to predict
|
||||
weather_forecast: Weather forecast data
|
||||
traffic_forecast: Traffic forecast data
|
||||
|
||||
Returns:
|
||||
DataFrame with features for prediction
|
||||
"""
|
||||
try:
|
||||
# Create base future dataframe
|
||||
future_df = pd.DataFrame({'ds': future_dates})
|
||||
|
||||
# Add temporal features
|
||||
future_df = self._add_temporal_features(
|
||||
future_df.rename(columns={'ds': 'date'})
|
||||
).rename(columns={'date': 'ds'})
|
||||
|
||||
# Add weather features
|
||||
if weather_forecast is not None and not weather_forecast.empty:
|
||||
weather_features = weather_forecast.copy()
|
||||
if 'date' in weather_features.columns:
|
||||
weather_features = weather_features.rename(columns={'date': 'ds'})
|
||||
|
||||
future_df = future_df.merge(weather_features, on='ds', how='left')
|
||||
|
||||
# Add traffic features
|
||||
if traffic_forecast is not None and not traffic_forecast.empty:
|
||||
traffic_features = traffic_forecast.copy()
|
||||
if 'date' in traffic_features.columns:
|
||||
traffic_features = traffic_features.rename(columns={'date': 'ds'})
|
||||
|
||||
future_df = future_df.merge(traffic_features, on='ds', how='left')
|
||||
|
||||
# Engineer additional features
|
||||
future_df = self._engineer_features(future_df.rename(columns={'ds': 'date'}))
|
||||
future_df = future_df.rename(columns={'date': 'ds'})
|
||||
|
||||
# Handle missing values in future data
|
||||
numeric_columns = future_df.select_dtypes(include=[np.number]).columns
|
||||
for col in numeric_columns:
|
||||
if future_df[col].isna().any():
|
||||
# Use reasonable defaults for Madrid
|
||||
if col == 'temperature':
|
||||
future_df[col] = future_df[col].fillna(15.0) # Default Madrid temp
|
||||
elif col == 'precipitation':
|
||||
future_df[col] = future_df[col].fillna(0.0) # Default no rain
|
||||
elif col == 'humidity':
|
||||
future_df[col] = future_df[col].fillna(60.0) # Default humidity
|
||||
elif col == 'traffic_volume':
|
||||
future_df[col] = future_df[col].fillna(100.0) # Default traffic
|
||||
else:
|
||||
future_df[col] = future_df[col].fillna(future_df[col].median())
|
||||
|
||||
return future_df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating prediction features: {e}")
|
||||
# Return minimal features if error
|
||||
return pd.DataFrame({'ds': future_dates})
|
||||
|
||||
async def _process_sales_data(self, sales_data: pd.DataFrame, product_name: str) -> pd.DataFrame:
|
||||
"""Process and clean sales data"""
|
||||
sales_clean = sales_data.copy()
|
||||
|
||||
# Ensure date column exists and is datetime
|
||||
if 'date' not in sales_clean.columns:
|
||||
raise ValueError("Sales data must have a 'date' column")
|
||||
|
||||
sales_clean['date'] = pd.to_datetime(sales_clean['date'])
|
||||
|
||||
# Ensure quantity column exists and is numeric
|
||||
if 'quantity' not in sales_clean.columns:
|
||||
raise ValueError("Sales data must have a 'quantity' column")
|
||||
|
||||
sales_clean['quantity'] = pd.to_numeric(sales_clean['quantity'], errors='coerce')
|
||||
|
||||
# Remove rows with invalid quantities
|
||||
sales_clean = sales_clean.dropna(subset=['quantity'])
|
||||
sales_clean = sales_clean[sales_clean['quantity'] >= 0] # No negative sales
|
||||
|
||||
# Filter for the specific product if product_name column exists
|
||||
if 'product_name' in sales_clean.columns:
|
||||
sales_clean = sales_clean[sales_clean['product_name'] == product_name]
|
||||
|
||||
return sales_clean
|
||||
|
||||
async def _aggregate_daily_sales(self, sales_data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Aggregate sales to daily level"""
|
||||
daily_sales = sales_data.groupby('date').agg({
|
||||
'quantity': 'sum'
|
||||
}).reset_index()
|
||||
|
||||
# Ensure we have data for all dates in the range
|
||||
date_range = pd.date_range(
|
||||
start=daily_sales['date'].min(),
|
||||
end=daily_sales['date'].max(),
|
||||
freq='D'
|
||||
)
|
||||
|
||||
full_date_df = pd.DataFrame({'date': date_range})
|
||||
daily_sales = full_date_df.merge(daily_sales, on='date', how='left')
|
||||
daily_sales['quantity'] = daily_sales['quantity'].fillna(0) # Fill missing days with 0 sales
|
||||
|
||||
return daily_sales
|
||||
|
||||
def _add_temporal_features(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Add temporal features like day of week, month, etc."""
|
||||
df = df.copy()
|
||||
|
||||
# Ensure we have a date column
|
||||
if 'date' not in df.columns:
|
||||
raise ValueError("DataFrame must have a 'date' column")
|
||||
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
|
||||
# Day of week (0=Monday, 6=Sunday)
|
||||
df['day_of_week'] = df['date'].dt.dayofweek
|
||||
df['is_weekend'] = df['day_of_week'].isin([5, 6]).astype(int)
|
||||
|
||||
# Month and season
|
||||
df['month'] = df['date'].dt.month
|
||||
df['season'] = df['month'].apply(self._get_season)
|
||||
|
||||
# Week of year
|
||||
df['week_of_year'] = df['date'].dt.isocalendar().week
|
||||
|
||||
# Quarter
|
||||
df['quarter'] = df['date'].dt.quarter
|
||||
|
||||
# Holiday indicators (basic Spanish holidays)
|
||||
df['is_holiday'] = df['date'].apply(self._is_spanish_holiday).astype(int)
|
||||
|
||||
# School calendar effects (approximate)
|
||||
df['is_school_holiday'] = df['date'].apply(self._is_school_holiday).astype(int)
|
||||
|
||||
return df
|
||||
|
||||
def _merge_weather_features(self,
|
||||
daily_sales: pd.DataFrame,
|
||||
weather_data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Merge weather features with sales data"""
|
||||
|
||||
if weather_data.empty:
|
||||
# Add default weather columns with neutral values
|
||||
daily_sales['temperature'] = 15.0 # Mild temperature
|
||||
daily_sales['precipitation'] = 0.0 # No rain
|
||||
daily_sales['humidity'] = 60.0 # Moderate humidity
|
||||
daily_sales['wind_speed'] = 5.0 # Light wind
|
||||
return daily_sales
|
||||
|
||||
try:
|
||||
weather_clean = weather_data.copy()
|
||||
|
||||
# Ensure weather data has date column
|
||||
if 'date' not in weather_clean.columns and 'ds' in weather_clean.columns:
|
||||
weather_clean = weather_clean.rename(columns={'ds': 'date'})
|
||||
|
||||
weather_clean['date'] = pd.to_datetime(weather_clean['date'])
|
||||
|
||||
# Select relevant weather features
|
||||
weather_features = ['date']
|
||||
|
||||
# Add available weather columns with default names
|
||||
weather_mapping = {
|
||||
'temperature': ['temperature', 'temp', 'temperatura'],
|
||||
'precipitation': ['precipitation', 'rain', 'lluvia', 'precipitacion'],
|
||||
'humidity': ['humidity', 'humedad'],
|
||||
'wind_speed': ['wind_speed', 'viento', 'wind']
|
||||
}
|
||||
|
||||
for standard_name, possible_names in weather_mapping.items():
|
||||
for possible_name in possible_names:
|
||||
if possible_name in weather_clean.columns:
|
||||
weather_clean[standard_name] = weather_clean[possible_name]
|
||||
weather_features.append(standard_name)
|
||||
break
|
||||
|
||||
# Keep only the features we found
|
||||
weather_clean = weather_clean[weather_features].copy()
|
||||
|
||||
# Merge with sales data
|
||||
merged = daily_sales.merge(weather_clean, on='date', how='left')
|
||||
|
||||
# Fill missing weather values with reasonable defaults
|
||||
if 'temperature' in merged.columns:
|
||||
merged['temperature'] = merged['temperature'].fillna(15.0)
|
||||
            if 'precipitation' in merged.columns:
                merged['precipitation'] = merged['precipitation'].fillna(0.0)
            if 'humidity' in merged.columns:
                merged['humidity'] = merged['humidity'].fillna(60.0)
            if 'wind_speed' in merged.columns:
                merged['wind_speed'] = merged['wind_speed'].fillna(5.0)

            return merged

        except Exception as e:
            logger.warning(f"Error merging weather data: {e}")
            # Add default weather columns if merge fails
            daily_sales['temperature'] = 15.0
            daily_sales['precipitation'] = 0.0
            daily_sales['humidity'] = 60.0
            daily_sales['wind_speed'] = 5.0
            return daily_sales

    def _merge_traffic_features(self,
                                daily_sales: pd.DataFrame,
                                traffic_data: pd.DataFrame) -> pd.DataFrame:
        """Merge traffic features with sales data"""

        if traffic_data.empty:
            # Add default traffic column
            daily_sales['traffic_volume'] = 100.0  # Neutral traffic level
            return daily_sales

        try:
            traffic_clean = traffic_data.copy()

            # Ensure traffic data has a date column
            if 'date' not in traffic_clean.columns and 'ds' in traffic_clean.columns:
                traffic_clean = traffic_clean.rename(columns={'ds': 'date'})

            traffic_clean['date'] = pd.to_datetime(traffic_clean['date'])

            # Select relevant traffic features
            traffic_features = ['date']

            # Map traffic column names (including Spanish source column names)
            traffic_mapping = {
                'traffic_volume': ['traffic_volume', 'traffic_intensity', 'trafico', 'intensidad'],
                'pedestrian_count': ['pedestrian_count', 'peatones'],
                'occupancy_rate': ['occupancy_rate', 'ocupacion']
            }

            for standard_name, possible_names in traffic_mapping.items():
                for possible_name in possible_names:
                    if possible_name in traffic_clean.columns:
                        traffic_clean[standard_name] = traffic_clean[possible_name]
                        traffic_features.append(standard_name)
                        break

            # Keep only the features we found
            traffic_clean = traffic_clean[traffic_features].copy()

            # Merge with sales data
            merged = daily_sales.merge(traffic_clean, on='date', how='left')

            # Fill missing traffic values
            if 'traffic_volume' in merged.columns:
                merged['traffic_volume'] = merged['traffic_volume'].fillna(100.0)
            if 'pedestrian_count' in merged.columns:
                merged['pedestrian_count'] = merged['pedestrian_count'].fillna(50.0)
            if 'occupancy_rate' in merged.columns:
                merged['occupancy_rate'] = merged['occupancy_rate'].fillna(0.5)

            return merged

        except Exception as e:
            logger.warning(f"Error merging traffic data: {e}")
            # Add default traffic column if merge fails
            daily_sales['traffic_volume'] = 100.0
            return daily_sales

    def _engineer_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Engineer additional features from existing data"""
        df = df.copy()

        # Weather-based features
        if 'temperature' in df.columns:
            df['temp_squared'] = df['temperature'] ** 2
            df['is_hot_day'] = (df['temperature'] > 25).astype(int)
            df['is_cold_day'] = (df['temperature'] < 10).astype(int)

        if 'precipitation' in df.columns:
            df['is_rainy_day'] = (df['precipitation'] > 0).astype(int)
            df['heavy_rain'] = (df['precipitation'] > 10).astype(int)

        # Traffic-based features
        if 'traffic_volume' in df.columns:
            df['high_traffic'] = (df['traffic_volume'] > df['traffic_volume'].quantile(0.75)).astype(int)
            df['low_traffic'] = (df['traffic_volume'] < df['traffic_volume'].quantile(0.25)).astype(int)

        # Interaction features
        if 'is_weekend' in df.columns and 'temperature' in df.columns:
            df['weekend_temp_interaction'] = df['is_weekend'] * df['temperature']

        if 'is_rainy_day' in df.columns and 'traffic_volume' in df.columns:
            df['rain_traffic_interaction'] = df['is_rainy_day'] * df['traffic_volume']

        return df

    def _handle_missing_values(self, df: pd.DataFrame) -> pd.DataFrame:
        """Handle missing values in the dataset"""
        df = df.copy()

        # For numeric columns, use median imputation
        numeric_columns = df.select_dtypes(include=[np.number]).columns

        for col in numeric_columns:
            if col != 'quantity' and df[col].isna().any():
                median_value = df[col].median()
                df[col] = df[col].fillna(median_value)

        return df

    def _prepare_prophet_format(self, df: pd.DataFrame) -> pd.DataFrame:
        """Prepare data in Prophet format with 'ds' and 'y' columns"""
        prophet_df = df.copy()

        # Rename columns for Prophet
        if 'date' in prophet_df.columns:
            prophet_df = prophet_df.rename(columns={'date': 'ds'})

        if 'quantity' in prophet_df.columns:
            prophet_df = prophet_df.rename(columns={'quantity': 'y'})

        # Ensure ds is datetime
        if 'ds' in prophet_df.columns:
            prophet_df['ds'] = pd.to_datetime(prophet_df['ds'])

        # Validate required columns
        if 'ds' not in prophet_df.columns or 'y' not in prophet_df.columns:
            raise ValueError("Prophet data must have 'ds' and 'y' columns")

        # Remove any rows with missing target values
        prophet_df = prophet_df.dropna(subset=['y'])

        # Sort by date
        prophet_df = prophet_df.sort_values('ds').reset_index(drop=True)

        return prophet_df

    def _get_season(self, month: int) -> int:
        """Get season from month (1-4 for Winter, Spring, Summer, Autumn)"""
        if month in [12, 1, 2]:
            return 1  # Winter
        elif month in [3, 4, 5]:
            return 2  # Spring
        elif month in [6, 7, 8]:
            return 3  # Summer
        else:
            return 4  # Autumn

    def _is_spanish_holiday(self, date: datetime) -> bool:
        """Check if a date is a major Spanish holiday"""
        month_day = (date.month, date.day)

        # Major Spanish holidays that affect bakery sales
        spanish_holidays = [
            (1, 1),    # New Year
            (1, 6),    # Epiphany
            (5, 1),    # Labour Day
            (8, 15),   # Assumption
            (10, 12),  # National Day
            (11, 1),   # All Saints
            (12, 6),   # Constitution Day
            (12, 8),   # Immaculate Conception
            (12, 25),  # Christmas
            (5, 15),   # San Isidro (Madrid)
            (5, 2),    # Madrid Community Day
        ]

        return month_day in spanish_holidays

    def _is_school_holiday(self, date: datetime) -> bool:
        """Check if a date is during school holidays (approximate)"""
        month = date.month

        # Approximate Spanish school holiday periods
        # Summer holidays (July-August)
        if month in [7, 8]:
            return True

        # Christmas holidays (mid December to early January)
        if month == 12 and date.day >= 20:
            return True
        if month == 1 and date.day <= 10:
            return True

        # Easter holidays (approximate - first two weeks of April)
        if month == 4 and date.day <= 14:
            return True

        return False

    def calculate_feature_importance(self,
                                     model_data: pd.DataFrame,
                                     target_column: str = 'y') -> Dict[str, float]:
        """
        Calculate feature importance for the model.
        """
        try:
            # Simple correlation-based importance
            numeric_features = model_data.select_dtypes(include=[np.number]).columns
            numeric_features = [col for col in numeric_features if col != target_column]

            importance_scores = {}

            for feature in numeric_features:
                if feature in model_data.columns:
                    correlation = model_data[feature].corr(model_data[target_column])
                    importance_scores[feature] = abs(correlation) if not pd.isna(correlation) else 0.0

            # Sort by importance
            importance_scores = dict(sorted(importance_scores.items(),
                                            key=lambda x: x[1], reverse=True))

            return importance_scores

        except Exception as e:
            logger.error(f"Error calculating feature importance: {e}")
            return {}
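For illustration only (not part of this commit): a minimal sketch of how the correlation-based importance above behaves on a tiny, invented frame; the column names here are made up for the example.

import pandas as pd

toy = pd.DataFrame({
    "y": [10, 12, 15, 9, 20],
    "temperature": [12, 14, 18, 10, 24],
    "traffic_volume": [90, 95, 120, 80, 140],
})
# Absolute Pearson correlation with the target, as in calculate_feature_importance
scores = {col: abs(toy[col].corr(toy["y"])) for col in ["temperature", "traffic_volume"]}
# Both features move with the target here, so both scores come out close to 1.0
print(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))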
services/training/app/ml/prophet_manager.py (new file, 408 lines)
@@ -0,0 +1,408 @@
# services/training/app/ml/prophet_manager.py
"""
Enhanced Prophet Manager for Training Service
Migrated from the monolithic backend to microservices architecture
"""

from typing import Dict, List, Any, Optional, Tuple
import pandas as pd
import numpy as np
from prophet import Prophet
import pickle
import logging
from datetime import datetime, timedelta
import uuid
import asyncio
import os
import joblib
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import json
from pathlib import Path

from app.core.config import settings

logger = logging.getLogger(__name__)

class BakeryProphetManager:
    """
    Enhanced Prophet model manager for the training service.
    Handles training, validation, and model persistence for bakery forecasting.
    """

    def __init__(self):
        self.models = {}  # In-memory model storage
        self.model_metadata = {}  # Store model metadata
        self.feature_scalers = {}  # Store feature scalers per model

        # Ensure model storage directory exists
        os.makedirs(settings.MODEL_STORAGE_PATH, exist_ok=True)

    async def train_bakery_model(self,
                                 tenant_id: str,
                                 product_name: str,
                                 df: pd.DataFrame,
                                 job_id: str) -> Dict[str, Any]:
        """
        Train a Prophet model for bakery forecasting with enhanced features.

        Args:
            tenant_id: Tenant identifier
            product_name: Product name
            df: Training data with 'ds' and 'y' columns plus regressors
            job_id: Training job identifier

        Returns:
            Dictionary with model information and metrics
        """
        try:
            logger.info(f"Training bakery model for tenant {tenant_id}, product {product_name}")

            # Validate input data
            await self._validate_training_data(df, product_name)

            # Prepare data for Prophet
            prophet_data = await self._prepare_prophet_data(df)

            # Get regressor columns
            regressor_columns = self._extract_regressor_columns(prophet_data)

            # Initialize Prophet model with bakery-specific settings
            model = self._create_prophet_model(regressor_columns)

            # Add regressors to model
            for regressor in regressor_columns:
                if regressor in prophet_data.columns:
                    model.add_regressor(regressor)

            # Fit the model
            model.fit(prophet_data)

            # Generate model ID and store model
            model_id = f"{job_id}_{product_name}_{uuid.uuid4().hex[:8]}"
            model_path = await self._store_model(
                tenant_id, product_name, model, model_id, prophet_data, regressor_columns
            )

            # Calculate training metrics
            training_metrics = await self._calculate_training_metrics(model, prophet_data)

            # Prepare model information
            model_info = {
                "model_id": model_id,
                "model_path": model_path,
                "type": "prophet",
                "training_samples": len(prophet_data),
                "features": regressor_columns,
                "hyperparameters": {
                    "seasonality_mode": settings.PROPHET_SEASONALITY_MODE,
                    "daily_seasonality": settings.PROPHET_DAILY_SEASONALITY,
                    "weekly_seasonality": settings.PROPHET_WEEKLY_SEASONALITY,
                    "yearly_seasonality": settings.PROPHET_YEARLY_SEASONALITY
                },
                "training_metrics": training_metrics,
                "trained_at": datetime.now().isoformat(),
                "data_period": {
                    "start_date": prophet_data['ds'].min().isoformat(),
                    "end_date": prophet_data['ds'].max().isoformat(),
                    "total_days": len(prophet_data)
                }
            }

            logger.info(f"Model trained successfully for {product_name}")
            return model_info

        except Exception as e:
            logger.error(f"Failed to train bakery model for {product_name}: {str(e)}")
            raise
    async def generate_forecast(self,
                                model_path: str,
                                future_dates: pd.DataFrame,
                                regressor_columns: List[str]) -> pd.DataFrame:
        """
        Generate forecast using a stored Prophet model.

        Args:
            model_path: Path to the stored model
            future_dates: DataFrame with future dates and regressors
            regressor_columns: List of regressor column names

        Returns:
            DataFrame with forecast results
        """
        try:
            # Load the model
            model = joblib.load(model_path)

            # Validate future data has required regressors
            for regressor in regressor_columns:
                if regressor not in future_dates.columns:
                    logger.warning(f"Missing regressor {regressor}, filling with default value 0")
                    future_dates[regressor] = 0  # Neutral default value

            # Generate forecast
            forecast = model.predict(future_dates)

            return forecast

        except Exception as e:
            logger.error(f"Failed to generate forecast: {str(e)}")
            raise
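    # Illustration only (not part of this commit): a caller is expected to build
    # `future_dates` with a 'ds' column plus one column per regressor the model
    # was trained with, e.g. (names assumed from the stored metadata):
    #
    #   future = pd.DataFrame({"ds": pd.date_range("2025-01-01", periods=7, freq="D")})
    #   for reg in metadata["regressor_columns"]:
    #       future[reg] = 0  # or real weather/traffic forecasts when available
    #   forecast = await manager.generate_forecast(model_path, future, metadata["regressor_columns"])
    #
    # Columns missing from `future` are filled with 0 by the guard above.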
    async def _validate_training_data(self, df: pd.DataFrame, product_name: str):
        """Validate training data quality"""
        if df.empty:
            raise ValueError(f"No training data available for {product_name}")

        if len(df) < settings.MIN_TRAINING_DATA_DAYS:
            raise ValueError(
                f"Insufficient training data for {product_name}: "
                f"{len(df)} days, minimum required: {settings.MIN_TRAINING_DATA_DAYS}"
            )

        required_columns = ['ds', 'y']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

        # Check for valid date range
        if df['ds'].isna().any():
            raise ValueError("Invalid dates found in training data")

        # Check for valid target values
        if df['y'].isna().all():
            raise ValueError("No valid target values found")

    async def _prepare_prophet_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Prepare data for Prophet training"""
        prophet_data = df.copy()

        # Ensure ds column is datetime
        prophet_data['ds'] = pd.to_datetime(prophet_data['ds'])

        # Handle missing values in target
        if prophet_data['y'].isna().any():
            logger.warning("Filling missing target values with interpolation")
            prophet_data['y'] = prophet_data['y'].interpolate(method='linear')

        # Remove extreme outliers (values > 3 standard deviations)
        mean_val = prophet_data['y'].mean()
        std_val = prophet_data['y'].std()

        if std_val > 0:  # Avoid division by zero
            lower_bound = mean_val - 3 * std_val
            upper_bound = mean_val + 3 * std_val

            before_count = len(prophet_data)
            prophet_data = prophet_data[
                (prophet_data['y'] >= lower_bound) &
                (prophet_data['y'] <= upper_bound)
            ]
            after_count = len(prophet_data)

            if before_count != after_count:
                logger.info(f"Removed {before_count - after_count} outliers")

        # Ensure chronological order
        prophet_data = prophet_data.sort_values('ds').reset_index(drop=True)

        # Fill missing values in regressors
        numeric_columns = prophet_data.select_dtypes(include=[np.number]).columns
        for col in numeric_columns:
            if col != 'y' and prophet_data[col].isna().any():
                prophet_data[col] = prophet_data[col].fillna(prophet_data[col].median())

        return prophet_data

    def _extract_regressor_columns(self, df: pd.DataFrame) -> List[str]:
        """Extract regressor columns from the dataframe"""
        excluded_columns = ['ds', 'y']
        regressor_columns = []

        for col in df.columns:
            if col not in excluded_columns and df[col].dtype in ['int64', 'float64']:
                regressor_columns.append(col)

        logger.info(f"Identified regressor columns: {regressor_columns}")
        return regressor_columns

    def _create_prophet_model(self, regressor_columns: List[str]) -> Prophet:
        """Create Prophet model with bakery-specific settings"""

        # Get Spanish holidays
        holidays = self._get_spanish_holidays()

        # Bakery-specific Prophet configuration
        model = Prophet(
            holidays=holidays if not holidays.empty else None,
            daily_seasonality=settings.PROPHET_DAILY_SEASONALITY,
            weekly_seasonality=settings.PROPHET_WEEKLY_SEASONALITY,
            yearly_seasonality=settings.PROPHET_YEARLY_SEASONALITY,
            seasonality_mode=settings.PROPHET_SEASONALITY_MODE,
            changepoint_prior_scale=0.05,   # Conservative changepoint detection
            seasonality_prior_scale=10,     # Strong seasonality for bakeries
            holidays_prior_scale=10,        # Strong holiday effects
            interval_width=0.8,             # 80% confidence intervals
            mcmc_samples=0,                 # Use MAP estimation (faster)
            uncertainty_samples=1000        # For uncertainty estimation
        )

        return model

    def _get_spanish_holidays(self) -> pd.DataFrame:
        """Get Spanish holidays for Prophet model"""
        try:
            # Define major Spanish holidays that affect bakery sales
            holidays_list = []

            years = range(2020, 2030)  # Cover training and prediction period

            for year in years:
                holidays_list.extend([
                    {'holiday': 'new_year', 'ds': f'{year}-01-01'},
                    {'holiday': 'epiphany', 'ds': f'{year}-01-06'},
                    {'holiday': 'may_day', 'ds': f'{year}-05-01'},
                    {'holiday': 'assumption', 'ds': f'{year}-08-15'},
                    {'holiday': 'national_day', 'ds': f'{year}-10-12'},
                    {'holiday': 'all_saints', 'ds': f'{year}-11-01'},
                    {'holiday': 'constitution', 'ds': f'{year}-12-06'},
                    {'holiday': 'immaculate', 'ds': f'{year}-12-08'},
                    {'holiday': 'christmas', 'ds': f'{year}-12-25'},

                    # Madrid specific holidays
                    {'holiday': 'madrid_patron', 'ds': f'{year}-05-15'},  # San Isidro
                    {'holiday': 'madrid_community', 'ds': f'{year}-05-02'},
                ])

            holidays_df = pd.DataFrame(holidays_list)
            holidays_df['ds'] = pd.to_datetime(holidays_df['ds'])

            return holidays_df

        except Exception as e:
            logger.warning(f"Error creating holidays dataframe: {e}")
            return pd.DataFrame()

    async def _store_model(self,
                           tenant_id: str,
                           product_name: str,
                           model: Prophet,
                           model_id: str,
                           training_data: pd.DataFrame,
                           regressor_columns: List[str]) -> str:
        """Store model and metadata to filesystem"""

        # Create model filename
        model_filename = f"{model_id}_prophet_model.pkl"
        model_path = os.path.join(settings.MODEL_STORAGE_PATH, model_filename)

        # Store the model
        joblib.dump(model, model_path)

        # Store metadata
        metadata = {
            "tenant_id": tenant_id,
            "product_name": product_name,
            "model_id": model_id,
            "regressor_columns": regressor_columns,
            "training_samples": len(training_data),
            "training_period": {
                "start": training_data['ds'].min().isoformat(),
                "end": training_data['ds'].max().isoformat()
            },
            "created_at": datetime.now().isoformat(),
            "model_type": "prophet",
            "file_path": model_path
        }

        metadata_path = model_path.replace('.pkl', '_metadata.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        # Store in memory for quick access
        model_key = f"{tenant_id}:{product_name}"
        self.models[model_key] = model
        self.model_metadata[model_key] = metadata

        logger.info(f"Model stored at: {model_path}")
        return model_path
    async def _calculate_training_metrics(self,
                                          model: Prophet,
                                          training_data: pd.DataFrame) -> Dict[str, float]:
        """Calculate training metrics for the model"""
        try:
            # Generate in-sample predictions
            forecast = model.predict(training_data[['ds'] + [col for col in training_data.columns if col not in ['ds', 'y']]])

            # Calculate metrics
            y_true = training_data['y'].values
            y_pred = forecast['yhat'].values

            # Basic metrics
            mae = mean_absolute_error(y_true, y_pred)
            mse = mean_squared_error(y_true, y_pred)
            rmse = np.sqrt(mse)

            # MAPE (Mean Absolute Percentage Error), computed over non-zero actuals
            # to avoid dividing by zero on days with no sales
            nonzero = y_true != 0
            if nonzero.any():
                mape = np.mean(np.abs((y_true[nonzero] - y_pred[nonzero]) / y_true[nonzero])) * 100
            else:
                mape = 0.0

            # R-squared
            r2 = r2_score(y_true, y_pred)

            return {
                "mae": round(mae, 2),
                "mse": round(mse, 2),
                "rmse": round(rmse, 2),
                "mape": round(mape, 2),
                "r2_score": round(r2, 4),
                "mean_actual": round(np.mean(y_true), 2),
                "mean_predicted": round(np.mean(y_pred), 2)
            }

        except Exception as e:
            logger.error(f"Error calculating training metrics: {e}")
            return {
                "mae": 0.0,
                "mse": 0.0,
                "rmse": 0.0,
                "mape": 0.0,
                "r2_score": 0.0,
                "mean_actual": 0.0,
                "mean_predicted": 0.0
            }

    def get_model_info(self, tenant_id: str, product_name: str) -> Optional[Dict[str, Any]]:
        """Get model information for a specific tenant and product"""
        model_key = f"{tenant_id}:{product_name}"
        return self.model_metadata.get(model_key)

    def list_models(self, tenant_id: str) -> List[Dict[str, Any]]:
        """List all models for a tenant"""
        tenant_models = []

        for model_key, metadata in self.model_metadata.items():
            if metadata['tenant_id'] == tenant_id:
                tenant_models.append(metadata)

        return tenant_models

    async def cleanup_old_models(self, days_old: int = 30):
        """Clean up old model files"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_old)

            for model_path in Path(settings.MODEL_STORAGE_PATH).glob("*.pkl"):
                # Check file modification time
                if model_path.stat().st_mtime < cutoff_date.timestamp():
                    # Remove model and metadata files
                    model_path.unlink()

                    # Metadata is stored as <model>_metadata.json (see _store_model),
                    # so derive that name rather than just swapping the suffix
                    metadata_path = model_path.with_name(model_path.name.replace('.pkl', '_metadata.json'))
                    if metadata_path.exists():
                        metadata_path.unlink()

                    logger.info(f"Cleaned up old model: {model_path}")

        except Exception as e:
            logger.error(f"Error during model cleanup: {e}")
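For orientation only (not part of the committed file): a minimal sketch of how this manager might be driven from an async context, assuming a writable settings.MODEL_STORAGE_PATH and a MIN_TRAINING_DATA_DAYS of at most 60; the tenant, product, and column values are invented.

import asyncio
import pandas as pd
from app.ml.prophet_manager import BakeryProphetManager

async def demo():
    manager = BakeryProphetManager()
    # Toy training frame: 60 days of daily sales with a single regressor
    df = pd.DataFrame({
        "ds": pd.date_range("2024-01-01", periods=60, freq="D"),
        "y": range(60),
        "temperature": [15.0] * 60,
    })
    info = await manager.train_bakery_model(
        tenant_id="tenant-1", product_name="baguette", df=df, job_id="job-1"
    )
    print(info["model_id"], info["training_metrics"])

# asyncio.run(demo())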
@@ -1,174 +1,372 @@
# services/training/app/ml/trainer.py
"""
ML Training implementation
ML Trainer for Training Service
Orchestrates the complete training process
"""

import asyncio
import structlog
from typing import Dict, Any, List
from typing import Dict, List, Any, Optional, Tuple
import pandas as pd
from datetime import datetime
import joblib
import os
from prophet import Prophet
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from datetime import datetime, timedelta
import logging
import asyncio
import uuid
from pathlib import Path

from app.ml.prophet_manager import BakeryProphetManager
from app.ml.data_processor import BakeryDataProcessor
from app.core.config import settings

logger = structlog.get_logger()
logger = logging.getLogger(__name__)

class MLTrainer:
    """ML training implementation"""
class BakeryMLTrainer:
    """
    Main ML trainer that orchestrates the complete training process.
    Replaces the old Celery-based training system with clean async implementation.
    """

    def __init__(self):
        self.model_storage_path = settings.MODEL_STORAGE_PATH
        os.makedirs(self.model_storage_path, exist_ok=True)
        self.prophet_manager = BakeryProphetManager()
        self.data_processor = BakeryDataProcessor()

    async def train_tenant_models(self,
                                  tenant_id: str,
                                  sales_data: List[Dict],
                                  weather_data: List[Dict] = None,
                                  traffic_data: List[Dict] = None,
                                  job_id: str = None) -> Dict[str, Any]:
        """
        Train models for all products of a tenant.

        Args:
            tenant_id: Tenant identifier
            sales_data: Historical sales data
            weather_data: Weather data (optional)
            traffic_data: Traffic data (optional)
            job_id: Training job identifier

        Returns:
            Dictionary with training results for each product
        """
        if not job_id:
            job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        logger.info(f"Starting training job {job_id} for tenant {tenant_id}")

        try:
            # Convert input data to DataFrames
            sales_df = pd.DataFrame(sales_data) if sales_data else pd.DataFrame()
            weather_df = pd.DataFrame(weather_data) if weather_data else pd.DataFrame()
            traffic_df = pd.DataFrame(traffic_data) if traffic_data else pd.DataFrame()

            # Validate input data
            await self._validate_input_data(sales_df, tenant_id)

            # Get unique products
            products = sales_df['product_name'].unique().tolist()
            logger.info(f"Training models for {len(products)} products: {products}")

            # Process data for each product
            processed_data = await self._process_all_products(
                sales_df, weather_df, traffic_df, products
            )

            # Train models for each product
            training_results = await self._train_all_models(
                tenant_id, processed_data, job_id
            )

            # Calculate overall training summary
            summary = self._calculate_training_summary(training_results)

            result = {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "status": "completed",
                "products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
                "products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
                "total_products": len(products),
                "training_results": training_results,
                "summary": summary,
                "completed_at": datetime.now().isoformat()
            }

            logger.info(f"Training job {job_id} completed successfully")
            return result

        except Exception as e:
            logger.error(f"Training job {job_id} failed: {str(e)}")
            raise

    async def train_models(self, training_data: Dict[str, Any], job_id: str, db) -> Dict[str, Any]:
        """Train models for all products"""
    async def train_single_product(self,
                                   tenant_id: str,
                                   product_name: str,
                                   sales_data: List[Dict],
                                   weather_data: List[Dict] = None,
                                   traffic_data: List[Dict] = None,
                                   job_id: str = None) -> Dict[str, Any]:
        """
        Train model for a single product.

        models_result = {}
        Args:
            tenant_id: Tenant identifier
            product_name: Product name
            sales_data: Historical sales data
            weather_data: Weather data (optional)
            traffic_data: Traffic data (optional)
            job_id: Training job identifier

        Returns:
            Training result for the product
        """
        if not job_id:
            job_id = f"training_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"

        logger.info(f"Starting single product training {job_id} for {product_name}")

        # Get sales data
        sales_data = training_data.get("sales_data", [])
        external_data = training_data.get("external_data", {})
        try:
            # Convert input data to DataFrames
            sales_df = pd.DataFrame(sales_data) if sales_data else pd.DataFrame()
            weather_df = pd.DataFrame(weather_data) if weather_data else pd.DataFrame()
            traffic_df = pd.DataFrame(traffic_data) if traffic_data else pd.DataFrame()

            # Filter sales data for the specific product
            product_sales = sales_df[sales_df['product_name'] == product_name].copy()

            # Validate product data
            if product_sales.empty:
                raise ValueError(f"No sales data found for product: {product_name}")

            # Prepare training data
            processed_data = await self.data_processor.prepare_training_data(
                sales_data=product_sales,
                weather_data=weather_df,
                traffic_data=traffic_df,
                product_name=product_name
            )

            # Train the model
            model_info = await self.prophet_manager.train_bakery_model(
                tenant_id=tenant_id,
                product_name=product_name,
                df=processed_data,
                job_id=job_id
            )

            result = {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "product_name": product_name,
                "status": "success",
                "model_info": model_info,
                "data_points": len(processed_data),
                "completed_at": datetime.now().isoformat()
            }

            logger.info(f"Single product training {job_id} completed successfully")
            return result

        except Exception as e:
            logger.error(f"Single product training {job_id} failed: {str(e)}")
            raise

    async def evaluate_model_performance(self,
                                         tenant_id: str,
                                         product_name: str,
                                         model_path: str,
                                         test_data: List[Dict]) -> Dict[str, Any]:
        """
        Evaluate model performance on test data.

        # Group by product
        products_data = self._group_by_product(sales_data)
        Args:
            tenant_id: Tenant identifier
            product_name: Product name
            model_path: Path to the trained model
            test_data: Test data for evaluation

        Returns:
            Performance metrics
        """
        try:
            logger.info(f"Evaluating model performance for {product_name}")

            # Convert test data to DataFrame
            test_df = pd.DataFrame(test_data)

            # Prepare test data
            test_prepared = await self.data_processor.prepare_prediction_features(
                future_dates=test_df['ds'],
                weather_forecast=test_df if 'temperature' in test_df.columns else pd.DataFrame(),
                traffic_forecast=test_df if 'traffic_volume' in test_df.columns else pd.DataFrame()
            )

            # Get regressor columns
            regressor_columns = [col for col in test_prepared.columns if col not in ['ds', 'y']]

            # Generate predictions
            forecast = await self.prophet_manager.generate_forecast(
                model_path=model_path,
                future_dates=test_prepared,
                regressor_columns=regressor_columns
            )

            # Calculate performance metrics if we have actual values
            metrics = {}
            if 'y' in test_df.columns:
                from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

                y_true = test_df['y'].values
                y_pred = forecast['yhat'].values

                metrics = {
                    "mae": float(mean_absolute_error(y_true, y_pred)),
                    "rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
                    "mape": float(np.mean(np.abs((y_true - y_pred) / y_true)) * 100),
                    "r2_score": float(r2_score(y_true, y_pred))
                }

            result = {
                "tenant_id": tenant_id,
                "product_name": product_name,
                "evaluation_metrics": metrics,
                "forecast_samples": len(forecast),
                "evaluated_at": datetime.now().isoformat()
            }

            return result

        except Exception as e:
            logger.error(f"Model evaluation failed: {str(e)}")
            raise

    async def _validate_input_data(self, sales_df: pd.DataFrame, tenant_id: str):
        """Validate input sales data"""
        if sales_df.empty:
            raise ValueError(f"No sales data provided for tenant {tenant_id}")

        # Train model for each product
        for product_name, product_sales in products_data.items():
        required_columns = ['date', 'product_name', 'quantity']
        missing_columns = [col for col in required_columns if col not in sales_df.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

        # Check for valid dates
        try:
            sales_df['date'] = pd.to_datetime(sales_df['date'])
        except Exception:
            raise ValueError("Invalid date format in sales data")

        # Check for valid quantities
        if not sales_df['quantity'].dtype in ['int64', 'float64']:
            raise ValueError("Quantity column must be numeric")

    async def _process_all_products(self,
                                    sales_df: pd.DataFrame,
                                    weather_df: pd.DataFrame,
                                    traffic_df: pd.DataFrame,
                                    products: List[str]) -> Dict[str, pd.DataFrame]:
        """Process data for all products"""
        processed_data = {}

        for product_name in products:
            try:
                model_result = await self._train_product_model(
                    product_name,
                    product_sales,
                    external_data,
                    job_id
                logger.info(f"Processing data for product: {product_name}")

                # Filter sales data for this product
                product_sales = sales_df[sales_df['product_name'] == product_name].copy()

                # Process the product data
                processed_product_data = await self.data_processor.prepare_training_data(
                    sales_data=product_sales,
                    weather_data=weather_df,
                    traffic_data=traffic_df,
                    product_name=product_name
                )
                models_result[product_name] = model_result

                processed_data[product_name] = processed_product_data
                logger.info(f"Processed {len(processed_product_data)} data points for {product_name}")

            except Exception as e:
                logger.error(f"Failed to train model for {product_name}: {e}")
                logger.error(f"Failed to process data for {product_name}: {str(e)}")
                # Continue with other products
                continue

        return models_result
        return processed_data

    def _group_by_product(self, sales_data: List[Dict]) -> Dict[str, List[Dict]]:
        """Group sales data by product"""
    async def _train_all_models(self,
                                tenant_id: str,
                                processed_data: Dict[str, pd.DataFrame],
                                job_id: str) -> Dict[str, Any]:
        """Train models for all processed products"""
        training_results = {}

        products = {}
        for sale in sales_data:
            product_name = sale.get("product_name")
            if product_name not in products:
                products[product_name] = []
            products[product_name].append(sale)

        return products

    async def _train_product_model(self, product_name: str, sales_data: List[Dict], external_data: Dict, job_id: str) -> Dict[str, Any]:
        """Train Prophet model for a single product"""

        # Convert to DataFrame
        df = pd.DataFrame(sales_data)
        df['date'] = pd.to_datetime(df['date'])

        # Aggregate daily sales
        daily_sales = df.groupby('date')['quantity_sold'].sum().reset_index()
        daily_sales.columns = ['ds', 'y']

        # Add external features
        daily_sales = self._add_external_features(daily_sales, external_data)

        # Train Prophet model
        model = Prophet(
            seasonality_mode=settings.PROPHET_SEASONALITY_MODE,
            daily_seasonality=settings.PROPHET_DAILY_SEASONALITY,
            weekly_seasonality=settings.PROPHET_WEEKLY_SEASONALITY,
            yearly_seasonality=settings.PROPHET_YEARLY_SEASONALITY
        )

        # Add regressors
        model.add_regressor('temperature')
        model.add_regressor('humidity')
        model.add_regressor('precipitation')
        model.add_regressor('traffic_volume')

        # Fit model
        model.fit(daily_sales)

        # Save model
        model_path = os.path.join(
            self.model_storage_path,
            f"{job_id}_{product_name}_prophet_model.pkl"
        )

        joblib.dump(model, model_path)

        return {
            "type": "prophet",
            "path": model_path,
            "training_samples": len(daily_sales),
            "features": ["temperature", "humidity", "precipitation", "traffic_volume"],
            "hyperparameters": {
                "seasonality_mode": settings.PROPHET_SEASONALITY_MODE,
                "daily_seasonality": settings.PROPHET_DAILY_SEASONALITY,
                "weekly_seasonality": settings.PROPHET_WEEKLY_SEASONALITY,
                "yearly_seasonality": settings.PROPHET_YEARLY_SEASONALITY
            }
        }

    def _add_external_features(self, daily_sales: pd.DataFrame, external_data: Dict) -> pd.DataFrame:
        """Add external features to sales data"""

        # Add weather data
        weather_data = external_data.get("weather", [])
        if weather_data:
            weather_df = pd.DataFrame(weather_data)
            weather_df['ds'] = pd.to_datetime(weather_df['date'])
            daily_sales = daily_sales.merge(weather_df[['ds', 'temperature', 'humidity', 'precipitation']], on='ds', how='left')

        # Add traffic data
        traffic_data = external_data.get("traffic", [])
        if traffic_data:
            traffic_df = pd.DataFrame(traffic_data)
            traffic_df['ds'] = pd.to_datetime(traffic_df['date'])
            daily_sales = daily_sales.merge(traffic_df[['ds', 'traffic_volume']], on='ds', how='left')

        # Fill missing values
        daily_sales['temperature'] = daily_sales['temperature'].fillna(daily_sales['temperature'].mean())
        daily_sales['humidity'] = daily_sales['humidity'].fillna(daily_sales['humidity'].mean())
        daily_sales['precipitation'] = daily_sales['precipitation'].fillna(0)
        daily_sales['traffic_volume'] = daily_sales['traffic_volume'].fillna(daily_sales['traffic_volume'].mean())

        return daily_sales

    async def validate_models(self, models_result: Dict[str, Any], db) -> Dict[str, Any]:
        """Validate trained models"""

        validation_results = {}

        for product_name, model_data in models_result.items():
        for product_name, product_data in processed_data.items():
            try:
                # Load model
                model_path = model_data.get("path")
                model = joblib.load(model_path)
                logger.info(f"Training model for product: {product_name}")

                # Mock validation for now (in production, you'd use actual validation data)
                validation_results[product_name] = {
                    "mape": np.random.uniform(10, 25),  # Mock MAPE between 10-25%
                    "rmse": np.random.uniform(8, 15),   # Mock RMSE
                    "mae": np.random.uniform(5, 12),    # Mock MAE
                    "r2_score": np.random.uniform(0.7, 0.9)  # Mock R2 score
                # Check if we have enough data
                if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
                    training_results[product_name] = {
                        'status': 'skipped',
                        'reason': 'insufficient_data',
                        'data_points': len(product_data),
                        'min_required': settings.MIN_TRAINING_DATA_DAYS
                    }
                    continue

                # Train the model
                model_info = await self.prophet_manager.train_bakery_model(
                    tenant_id=tenant_id,
                    product_name=product_name,
                    df=product_data,
                    job_id=job_id
                )

                training_results[product_name] = {
                    'status': 'success',
                    'model_info': model_info,
                    'data_points': len(product_data),
                    'trained_at': datetime.now().isoformat()
                }

                logger.info(f"Successfully trained model for {product_name}")

            except Exception as e:
                logger.error(f"Validation failed for {product_name}: {e}")
                validation_results[product_name] = {
                    "mape": None,
                    "rmse": None,
                    "mae": None,
                    "r2_score": None,
                    "error": str(e)
                logger.error(f"Failed to train model for {product_name}: {str(e)}")
                training_results[product_name] = {
                    'status': 'error',
                    'error_message': str(e),
                    'data_points': len(product_data) if product_data is not None else 0
                }

        return validation_results
        return training_results

    def _calculate_training_summary(self, training_results: Dict[str, Any]) -> Dict[str, Any]:
        """Calculate summary statistics from training results"""
        total_products = len(training_results)
        successful_products = len([r for r in training_results.values() if r.get('status') == 'success'])
        failed_products = len([r for r in training_results.values() if r.get('status') == 'error'])
        skipped_products = len([r for r in training_results.values() if r.get('status') == 'skipped'])

        # Calculate average training metrics for successful models
        successful_results = [r for r in training_results.values() if r.get('status') == 'success']

        avg_metrics = {}
        if successful_results:
            metrics_list = [r['model_info'].get('training_metrics', {}) for r in successful_results]

            if metrics_list and all(metrics_list):
                avg_metrics = {
                    'avg_mae': np.mean([m.get('mae', 0) for m in metrics_list]),
                    'avg_rmse': np.mean([m.get('rmse', 0) for m in metrics_list]),
                    'avg_mape': np.mean([m.get('mape', 0) for m in metrics_list]),
                    'avg_r2': np.mean([m.get('r2_score', 0) for m in metrics_list])
                }

        return {
            'total_products': total_products,
            'successful_products': successful_products,
            'failed_products': failed_products,
            'skipped_products': skipped_products,
            'success_rate': round(successful_products / total_products * 100, 2) if total_products > 0 else 0,
            'average_metrics': avg_metrics
        }
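Again for orientation only (not part of the commit): a rough sketch of how the new trainer might be invoked, with invented records that mirror the required 'date', 'product_name', and 'quantity' columns validated above.

import asyncio
from app.ml.trainer import BakeryMLTrainer

sales = [
    {"date": "2024-01-01", "product_name": "croissant", "quantity": 42},
    {"date": "2024-01-02", "product_name": "croissant", "quantity": 38},
    # ... one record per product per day, enough to pass MIN_TRAINING_DATA_DAYS
]

async def run():
    trainer = BakeryMLTrainer()
    result = await trainer.train_tenant_models(
        tenant_id="tenant-1", sales_data=sales, job_id="job-local-test"
    )
    print(result["summary"])

# asyncio.run(run())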
@@ -1,101 +1,154 @@
# services/training/app/models/training.py
"""
Training models - Fixed version
Database models for training service
"""

from sqlalchemy import Column, String, Integer, DateTime, Text, JSON, Boolean, Float
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, JSON, Float
from sqlalchemy.dialects.postgresql import UUID, ARRAY
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
import uuid

from shared.database.base import Base
Base = declarative_base()

class TrainingJob(Base):
    """Training job model"""
    __tablename__ = "training_jobs"
class ModelTrainingLog(Base):
    """
    Table to track training job execution and status.
    Replaces the old Celery task tracking.
    """
    __tablename__ = "model_training_logs"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    status = Column(String(20), nullable=False, default="queued")  # queued, running, completed, failed
    progress = Column(Integer, default=0)
    current_step = Column(String(200))
    requested_by = Column(UUID(as_uuid=True), nullable=False)
    id = Column(Integer, primary_key=True, index=True)
    job_id = Column(String(255), unique=True, index=True, nullable=False)
    tenant_id = Column(String(255), index=True, nullable=False)
    status = Column(String(50), nullable=False, default="pending")  # pending, running, completed, failed, cancelled
    progress = Column(Integer, default=0)  # 0-100 percentage
    current_step = Column(String(500), default="")

    # Timing
    started_at = Column(DateTime, default=datetime.utcnow)
    completed_at = Column(DateTime)
    duration_seconds = Column(Integer)
    # Timestamps
    start_time = Column(DateTime, default=datetime.now)
    end_time = Column(DateTime, nullable=True)

    # Results
    models_trained = Column(JSON)
    metrics = Column(JSON)
    error_message = Column(Text)
    # Configuration and results
    config = Column(JSON, nullable=True)  # Training job configuration
    results = Column(JSON, nullable=True)  # Training results
    error_message = Column(Text, nullable=True)

    # Metadata
    training_data_from = Column(DateTime)
    training_data_to = Column(DateTime)
    total_data_points = Column(Integer)
    products_count = Column(Integer)

    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def __repr__(self):
        return f"<TrainingJob(id={self.id}, tenant_id={self.tenant_id}, status={self.status})>"
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

class TrainedModel(Base):
    """Trained model information"""
    """
    Table to store information about trained models.
    """
    __tablename__ = "trained_models"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    training_job_id = Column(UUID(as_uuid=True), nullable=False)
    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), unique=True, index=True, nullable=False)
    tenant_id = Column(String(255), index=True, nullable=False)
    product_name = Column(String(255), index=True, nullable=False)

    # Model details
    product_name = Column(String(100), nullable=False)
    model_type = Column(String(50), nullable=False, default="prophet")
    model_version = Column(String(20), nullable=False)
    model_path = Column(String(500))  # Path to saved model file
    # Model information
    model_type = Column(String(50), nullable=False, default="prophet")  # prophet, arima, etc.
    model_path = Column(String(1000), nullable=False)  # Path to stored model file
    version = Column(Integer, nullable=False, default=1)

    # Training information
    training_samples = Column(Integer, nullable=False, default=0)
    features = Column(ARRAY(String), nullable=True)  # List of features used
    hyperparameters = Column(JSON, nullable=True)  # Model hyperparameters
    training_metrics = Column(JSON, nullable=True)  # Training performance metrics

    # Data period information
    data_period_start = Column(DateTime, nullable=True)
    data_period_end = Column(DateTime, nullable=True)

    # Status and metadata
    is_active = Column(Boolean, default=True, index=True)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

class ModelPerformanceMetric(Base):
    """
    Table to track model performance over time.
    """
    __tablename__ = "model_performance_metrics"

    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), index=True, nullable=False)
    tenant_id = Column(String(255), index=True, nullable=False)
    product_name = Column(String(255), index=True, nullable=False)

    # Performance metrics
    mape = Column(Float)  # Mean Absolute Percentage Error
    rmse = Column(Float)  # Root Mean Square Error
    mae = Column(Float)  # Mean Absolute Error
    r2_score = Column(Float)  # R-squared score
    mae = Column(Float, nullable=True)  # Mean Absolute Error
    mse = Column(Float, nullable=True)  # Mean Squared Error
    rmse = Column(Float, nullable=True)  # Root Mean Squared Error
    mape = Column(Float, nullable=True)  # Mean Absolute Percentage Error
    r2_score = Column(Float, nullable=True)  # R-squared score

    # Training details
    training_samples = Column(Integer)
    validation_samples = Column(Integer)
    features_used = Column(JSON)
    hyperparameters = Column(JSON)
    # Additional metrics
    accuracy_percentage = Column(Float, nullable=True)
    prediction_confidence = Column(Float, nullable=True)

    # Evaluation information
    evaluation_period_start = Column(DateTime, nullable=True)
    evaluation_period_end = Column(DateTime, nullable=True)
    evaluation_samples = Column(Integer, nullable=True)

    # Metadata
    measured_at = Column(DateTime, default=datetime.now)
    created_at = Column(DateTime, default=datetime.now)

class TrainingJobQueue(Base):
    """
    Table to manage training job queue and scheduling.
    """
    __tablename__ = "training_job_queue"

    id = Column(Integer, primary_key=True, index=True)
    job_id = Column(String(255), unique=True, index=True, nullable=False)
    tenant_id = Column(String(255), index=True, nullable=False)

    # Job configuration
    job_type = Column(String(50), nullable=False)  # full_training, single_product, evaluation
    priority = Column(Integer, default=1)  # Higher number = higher priority
    config = Column(JSON, nullable=True)

    # Scheduling information
    scheduled_at = Column(DateTime, nullable=True)
    started_at = Column(DateTime, nullable=True)
    estimated_duration_minutes = Column(Integer, nullable=True)

    # Status
    is_active = Column(Boolean, default=True)
    last_used_at = Column(DateTime)
    status = Column(String(50), nullable=False, default="queued")  # queued, running, completed, failed
    retry_count = Column(Integer, default=0)
    max_retries = Column(Integer, default=3)

    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def __repr__(self):
        return f"<TrainedModel(id={self.id}, product={self.product_name}, tenant={self.tenant_id})>"
    # Metadata
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)

class TrainingLog(Base):
    """Training log entries - FIXED: renamed metadata to log_metadata"""
    __tablename__ = "training_logs"
class ModelArtifact(Base):
    """
    Table to track model files and artifacts.
    """
    __tablename__ = "model_artifacts"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    training_job_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), index=True, nullable=False)
    tenant_id = Column(String(255), index=True, nullable=False)

    level = Column(String(10), nullable=False)  # DEBUG, INFO, WARNING, ERROR
    message = Column(Text, nullable=False)
    step = Column(String(100))
    progress = Column(Integer)
    # Artifact information
    artifact_type = Column(String(50), nullable=False)  # model_file, metadata, training_data, etc.
    file_path = Column(String(1000), nullable=False)
    file_size_bytes = Column(Integer, nullable=True)
    checksum = Column(String(255), nullable=True)  # For file integrity

    # Additional data
    execution_time = Column(Float)  # Time taken for this step
    memory_usage = Column(Float)  # Memory usage in MB
    log_metadata = Column(JSON)  # FIXED: renamed from 'metadata' to 'log_metadata'
    # Storage information
    storage_location = Column(String(100), nullable=False, default="local")  # local, s3, gcs, etc.
    compression = Column(String(50), nullable=True)  # gzip, lz4, etc.

    created_at = Column(DateTime, default=datetime.utcnow)

    def __repr__(self):
        return f"<TrainingLog(id={self.id}, level={self.level})>"
    # Metadata
    created_at = Column(DateTime, default=datetime.now)
    expires_at = Column(DateTime, nullable=True)  # For automatic cleanup
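For illustration only (not part of this commit): a minimal sketch of recording a job row with the new ModelTrainingLog table, assuming a synchronous SQLAlchemy session and a placeholder connection string.

from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.models.training import Base, ModelTrainingLog

engine = create_engine("postgresql://user:pass@localhost/training")  # placeholder DSN
Base.metadata.create_all(engine)  # create the new tables if they do not exist

with Session(engine) as session:
    log = ModelTrainingLog(job_id="job-local-test", tenant_id="tenant-1",
                           status="running", progress=10, current_step="loading data")
    session.add(log)
    session.commit()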
@@ -1,91 +1,181 @@
|
||||
# services/training/app/schemas/training.py
|
||||
"""
|
||||
Training schemas
|
||||
Pydantic schemas for training service
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
class TrainingJobStatus(str, Enum):
|
||||
"""Training job status enum"""
|
||||
QUEUED = "queued"
|
||||
class TrainingStatus(str, Enum):
|
||||
"""Training job status enumeration"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
class TrainingRequest(BaseModel):
|
||||
"""Training request schema"""
|
||||
tenant_id: Optional[str] = None # Will be set from auth
|
||||
force_retrain: bool = Field(default=False, description="Force retrain even if recent models exist")
|
||||
products: Optional[List[str]] = Field(default=None, description="Specific products to train, or None for all")
|
||||
training_days: Optional[int] = Field(default=730, ge=30, le=1095, description="Number of days of historical data to use")
|
||||
class TrainingJobRequest(BaseModel):
|
||||
"""Request schema for starting a training job"""
|
||||
products: Optional[List[str]] = Field(None, description="Specific products to train (if None, train all)")
|
||||
include_weather: bool = Field(True, description="Include weather data in training")
|
||||
include_traffic: bool = Field(True, description="Include traffic data in training")
|
||||
start_date: Optional[datetime] = Field(None, description="Start date for training data")
|
||||
end_date: Optional[datetime] = Field(None, description="End date for training data")
|
||||
min_data_points: int = Field(30, description="Minimum data points required per product")
|
||||
estimated_duration: Optional[int] = Field(None, description="Estimated duration in minutes")
|
||||
|
||||
@validator('training_days')
|
||||
def validate_training_days(cls, v):
|
||||
if v < 30:
|
||||
raise ValueError('Minimum training days is 30')
|
||||
if v > 1095:
|
||||
raise ValueError('Maximum training days is 1095 (3 years)')
|
||||
# Prophet-specific parameters
|
||||
seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
|
||||
daily_seasonality: bool = Field(True, description="Enable daily seasonality")
|
||||
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
||||
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
||||
|
||||
@validator('seasonality_mode')
|
||||
def validate_seasonality_mode(cls, v):
|
||||
if v not in ['additive', 'multiplicative']:
|
||||
raise ValueError('seasonality_mode must be additive or multiplicative')
|
||||
return v
|
||||
|
||||
@validator('min_data_points')
|
||||
def validate_min_data_points(cls, v):
|
||||
if v < 7:
|
||||
raise ValueError('min_data_points must be at least 7')
|
||||
return v
|
||||
|
||||
class SingleProductTrainingRequest(BaseModel):
|
||||
"""Request schema for training a single product"""
|
||||
include_weather: bool = Field(True, description="Include weather data in training")
|
||||
include_traffic: bool = Field(True, description="Include traffic data in training")
|
||||
start_date: Optional[datetime] = Field(None, description="Start date for training data")
|
||||
end_date: Optional[datetime] = Field(None, description="End date for training data")
|
||||
|
||||
# Prophet-specific parameters
|
||||
seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
|
||||
daily_seasonality: bool = Field(True, description="Enable daily seasonality")
|
||||
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
||||
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
||||
|
||||
class TrainingJobResponse(BaseModel):
|
||||
"""Training job response schema"""
|
||||
id: str
|
||||
tenant_id: str
|
||||
status: TrainingJobStatus
|
||||
progress: int
|
||||
current_step: Optional[str]
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime]
|
||||
duration_seconds: Optional[int]
|
||||
models_trained: Optional[Dict[str, Any]]
|
||||
metrics: Optional[Dict[str, Any]]
|
||||
error_message: Optional[str]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
"""Response schema for training job creation"""
|
||||
job_id: str = Field(..., description="Unique training job identifier")
|
||||
status: TrainingStatus = Field(..., description="Current job status")
|
||||
message: str = Field(..., description="Status message")
|
||||
tenant_id: str = Field(..., description="Tenant identifier")
|
||||
created_at: datetime = Field(..., description="Job creation timestamp")
|
||||
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
|
||||
|
||||
class TrainedModelResponse(BaseModel):
|
||||
"""Trained model response schema"""
|
||||
id: str
|
||||
product_name: str
|
||||
model_type: str
|
||||
model_version: str
|
||||
mape: Optional[float]
|
||||
rmse: Optional[float]
|
||||
mae: Optional[float]
|
||||
r2_score: Optional[float]
|
||||
training_samples: Optional[int]
|
||||
features_used: Optional[List[str]]
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
class TrainingStatusResponse(BaseModel):
|
||||
"""Response schema for training job status"""
|
||||
job_id: str = Field(..., description="Training job identifier")
|
||||
status: TrainingStatus = Field(..., description="Current job status")
|
||||
progress: int = Field(0, description="Progress percentage (0-100)")
|
||||
current_step: str = Field("", description="Current processing step")
|
||||
started_at: datetime = Field(..., description="Job start timestamp")
|
||||
completed_at: Optional[datetime] = Field(None, description="Job completion timestamp")
|
||||
results: Optional[Dict[str, Any]] = Field(None, description="Training results")
|
||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Schema for trained model information"""
|
||||
model_id: str = Field(..., description="Unique model identifier")
|
||||
model_path: str = Field(..., description="Path to stored model")
|
||||
model_type: str = Field("prophet", description="Type of ML model")
|
||||
training_samples: int = Field(..., description="Number of training samples")
|
||||
features: List[str] = Field(..., description="List of features used")
|
||||
hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
|
||||
training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
|
||||
trained_at: datetime = Field(..., description="Training completion timestamp")
|
||||
data_period: Dict[str, str] = Field(..., description="Training data period")
|
||||
|
||||
class ProductTrainingResult(BaseModel):
|
||||
"""Schema for individual product training result"""
|
||||
product_name: str = Field(..., description="Product name")
|
||||
status: str = Field(..., description="Training status for this product")
|
||||
model_info: Optional[ModelInfo] = Field(None, description="Model information if successful")
|
||||
data_points: int = Field(..., description="Number of data points used")
|
||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||
trained_at: datetime = Field(..., description="Training completion timestamp")
|
||||
|
||||
class TrainingResultsResponse(BaseModel):
    """Response schema for complete training results"""
    job_id: str = Field(..., description="Training job identifier")
    tenant_id: str = Field(..., description="Tenant identifier")
    status: TrainingStatus = Field(..., description="Overall job status")
    products_trained: int = Field(..., description="Number of products successfully trained")
    products_failed: int = Field(..., description="Number of products that failed training")
    total_products: int = Field(..., description="Total number of products processed")
    training_results: Dict[str, ProductTrainingResult] = Field(..., description="Per-product results")
    summary: Dict[str, Any] = Field(..., description="Training summary statistics")
    completed_at: datetime = Field(..., description="Job completion timestamp")

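# Illustrative sketch (not part of the commit): how per-product results nest
# inside TrainingResultsResponse. Values are invented for demonstration, and
# TrainingStatus.COMPLETED is an assumed member of the status enum.
from datetime import datetime, timezone

_now = datetime.now(timezone.utc)
example_results = TrainingResultsResponse(
    job_id="b3f1c2d4-example",
    tenant_id="tenant-123",
    status=TrainingStatus.COMPLETED,                 # assumed enum member
    products_trained=1,
    products_failed=1,
    total_products=2,
    training_results={
        "croissant": ProductTrainingResult(
            product_name="croissant",
            status="completed",
            model_info=None,                         # ModelInfo omitted for brevity
            data_points=180,
            error_message=None,
            trained_at=_now,
        ),
        "baguette": ProductTrainingResult(
            product_name="baguette",
            status="failed",
            model_info=None,
            data_points=12,
            error_message="insufficient data points",
            trained_at=_now,
        ),
    },
    summary={"success_rate": 0.5},
    completed_at=_now,
)
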
class TrainingValidationResult(BaseModel):
    """Schema for training data validation results"""
    is_valid: bool = Field(..., description="Whether the data is valid for training")
    issues: List[str] = Field(default_factory=list, description="List of data quality issues")
    recommendations: List[str] = Field(default_factory=list, description="Recommendations for improvement")
    estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
    products_analyzed: int = Field(..., description="Number of products analyzed")
    total_data_points: int = Field(..., description="Total data points available")

class TrainingProgress(BaseModel):
    """Training progress update schema"""
    job_id: str
    progress: int
    current_step: str
    estimated_completion: Optional[datetime]

class TrainingMetrics(BaseModel):
    """Training metrics schema"""
    total_jobs: int
    successful_jobs: int
    failed_jobs: int
    average_duration: float
    models_trained: int
    active_models: int
    """Schema for training performance metrics"""
    mae: float = Field(..., description="Mean Absolute Error")
    mse: float = Field(..., description="Mean Squared Error")
    rmse: float = Field(..., description="Root Mean Squared Error")
    mape: float = Field(..., description="Mean Absolute Percentage Error")
    r2_score: float = Field(..., description="R-squared score")
    mean_actual: float = Field(..., description="Mean of actual values")
    mean_predicted: float = Field(..., description="Mean of predicted values")

class ModelValidationResult(BaseModel):
    """Model validation result schema"""
    product_name: str
    is_valid: bool
    accuracy_score: float
    validation_error: Optional[str]
    recommendations: List[str]

class ExternalDataConfig(BaseModel):
    """Configuration for external data sources"""
    weather_enabled: bool = Field(True, description="Enable weather data")
    traffic_enabled: bool = Field(True, description="Enable traffic data")
    weather_features: List[str] = Field(
        default_factory=lambda: ["temperature", "precipitation", "humidity"],
        description="Weather features to include"
    )
    traffic_features: List[str] = Field(
        default_factory=lambda: ["traffic_volume"],
        description="Traffic features to include"
    )

class TrainingJobConfig(BaseModel):
    """Complete training job configuration"""
    external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)
    prophet_params: Dict[str, Any] = Field(
        default_factory=lambda: {
            "seasonality_mode": "additive",
            "daily_seasonality": True,
            "weekly_seasonality": True,
            "yearly_seasonality": True
        },
        description="Prophet model parameters"
    )
    data_filters: Dict[str, Any] = Field(
        default_factory=dict,
        description="Data filtering parameters"
    )
    validation_params: Dict[str, Any] = Field(
        default_factory=lambda: {"min_data_points": 30},
        description="Data validation parameters"
    )

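# Illustrative sketch (not part of the commit): building a TrainingJobConfig,
# overriding the default Prophet seasonality and disabling traffic features.
# Fields not supplied fall back to the defaults defined above; the values
# themselves are invented for demonstration.
example_config = TrainingJobConfig(
    external_data=ExternalDataConfig(
        weather_enabled=True,
        traffic_enabled=False,
        weather_features=["temperature", "precipitation"],
    ),
    prophet_params={
        "seasonality_mode": "multiplicative",
        "daily_seasonality": True,
        "weekly_seasonality": True,
        "yearly_seasonality": False,
    },
    validation_params={"min_data_points": 60},
)
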
class TrainedModelResponse(BaseModel):
    """Response schema for trained model information"""
    model_id: str = Field(..., description="Unique model identifier")
    tenant_id: str = Field(..., description="Tenant identifier")
    product_name: str = Field(..., description="Product name")
    model_type: str = Field(..., description="Type of ML model")
    model_path: str = Field(..., description="Path to stored model")
    version: int = Field(..., description="Model version")
    training_samples: int = Field(..., description="Number of training samples")
    features: List[str] = Field(..., description="List of features used")
    hyperparameters: Dict[str, Any] = Field(..., description="Model hyperparameters")
    training_metrics: Dict[str, float] = Field(..., description="Training performance metrics")
    is_active: bool = Field(..., description="Whether model is active")
    created_at: datetime = Field(..., description="Model creation timestamp")
    data_period_start: Optional[datetime] = Field(None, description="Training data start date")
    data_period_end: Optional[datetime] = Field(None, description="Training data end date")

@@ -1,12 +1,17 @@
# ================================================================
# services/training/app/services/messaging.py
# ================================================================
"""
Messaging service for training service
Training service messaging - Clean interface for training-specific events
Uses shared RabbitMQ infrastructure
"""

import structlog
from typing import Dict, Any, Optional
from shared.messaging.rabbitmq import RabbitMQClient
from shared.messaging.events import (
    TrainingStartedEvent,
    TrainingCompletedEvent,
    TrainingFailedEvent
)
from app.core.config import settings

logger = structlog.get_logger()
@@ -27,23 +32,188 @@ async def cleanup_messaging():
    await training_publisher.disconnect()
    logger.info("Training service messaging cleaned up")

# Convenience functions for training-specific events
async def publish_training_started(job_data: dict) -> bool:
    """Publish training started event"""
    return await training_publisher.publish_training_event("started", job_data)

# Training Job Events
async def publish_job_started(job_id: str, tenant_id: str, config: Dict[str, Any]) -> bool:
    """Publish training job started event"""
    event = TrainingStartedEvent(
        service_name="training-service",
        data={
            "job_id": job_id,
            "tenant_id": tenant_id,
            "config": config
        }
    )
    return await training_publisher.publish_event(
        exchange_name="training.events",
        routing_key="training.started",
        event_data=event.to_dict()
    )

async def publish_training_completed(job_data: dict) -> bool:
    """Publish training completed event"""
    return await training_publisher.publish_training_event("completed", job_data)

async def publish_job_progress(job_id: str, tenant_id: str, progress: int, step: str) -> bool:
    """Publish training job progress event"""
    return await training_publisher.publish_event(
        exchange_name="training.events",
        routing_key="training.progress",
        event_data={
            "service_name": "training-service",
            "event_type": "training.progress",
            "data": {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "progress": progress,
                "current_step": step
            }
        }
    )

async def publish_training_failed(job_data: dict) -> bool:
    """Publish training failed event"""
    return await training_publisher.publish_training_event("failed", job_data)

async def publish_job_completed(job_id: str, tenant_id: str, results: Dict[str, Any]) -> bool:
    """Publish training job completed event"""
    event = TrainingCompletedEvent(
        service_name="training-service",
        data={
            "job_id": job_id,
            "tenant_id": tenant_id,
            "results": results,
            "models_trained": results.get("products_trained", 0),
            "success_rate": results.get("summary", {}).get("success_rate", 0)
        }
    )
    return await training_publisher.publish_event(
        exchange_name="training.events",
        routing_key="training.completed",
        event_data=event.to_dict()
    )

async def publish_model_validated(model_data: dict) -> bool:
async def publish_job_failed(job_id: str, tenant_id: str, error: str) -> bool:
    """Publish training job failed event"""
    event = TrainingFailedEvent(
        service_name="training-service",
        data={
            "job_id": job_id,
            "tenant_id": tenant_id,
            "error": error
        }
    )
    return await training_publisher.publish_event(
        exchange_name="training.events",
        routing_key="training.failed",
        event_data=event.to_dict()
    )

async def publish_job_cancelled(job_id: str, tenant_id: str) -> bool:
    """Publish training job cancelled event"""
    return await training_publisher.publish_event(
        exchange_name="training.events",
        routing_key="training.cancelled",
        event_data={
            "service_name": "training-service",
            "event_type": "training.cancelled",
            "data": {
                "job_id": job_id,
                "tenant_id": tenant_id
            }
        }
    )

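# Illustrative sketch (not part of the commit): the typical event sequence a
# background training task might publish using the job-event helpers above.
# The run_training_job coroutine, its arguments, and the results payload are
# hypothetical; only the publish_* calls come from this module.
async def run_training_job(job_id: str, tenant_id: str, config: Dict[str, Any]) -> None:
    await publish_job_started(job_id, tenant_id, config)
    try:
        await publish_job_progress(job_id, tenant_id, 10, "loading sales data")
        # ... train per-product models here ...
        await publish_job_progress(job_id, tenant_id, 90, "saving trained models")
        results = {"products_trained": 5, "summary": {"success_rate": 1.0}}  # example payload
        await publish_job_completed(job_id, tenant_id, results)
    except Exception as exc:
        await publish_job_failed(job_id, tenant_id, str(exc))
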
# Product Training Events
async def publish_product_training_started(job_id: str, tenant_id: str, product_name: str) -> bool:
    """Publish single product training started event"""
    return await training_publisher.publish_event(
        exchange_name="training.events",
        routing_key="training.product.started",
        event_data={
            "service_name": "training-service",
            "event_type": "training.product.started",
            "data": {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "product_name": product_name
            }
        }
    )

async def publish_product_training_completed(job_id: str, tenant_id: str, product_name: str, model_id: str) -> bool:
    """Publish single product training completed event"""
    return await training_publisher.publish_event(
        exchange_name="training.events",
        routing_key="training.product.completed",
        event_data={
            "service_name": "training-service",
            "event_type": "training.product.completed",
            "data": {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "product_name": product_name,
                "model_id": model_id
            }
        }
    )

# Model Events
async def publish_model_trained(model_id: str, tenant_id: str, product_name: str, metrics: Dict[str, float]) -> bool:
    """Publish model trained event"""
    return await training_publisher.publish_event(
        exchange_name="training.events",
        routing_key="training.model.trained",
        event_data={
            "service_name": "training-service",
            "event_type": "training.model.trained",
            "data": {
                "model_id": model_id,
                "tenant_id": tenant_id,
                "product_name": product_name,
                "training_metrics": metrics
            }
        }
    )

async def publish_model_updated(model_id: str, tenant_id: str, product_name: str, version: int) -> bool:
    """Publish model updated event"""
    return await training_publisher.publish_event(
        exchange_name="training.events",
        routing_key="training.model.updated",
        event_data={
            "service_name": "training-service",
            "event_type": "training.model.updated",
            "data": {
                "model_id": model_id,
                "tenant_id": tenant_id,
                "product_name": product_name,
                "version": version
            }
        }
    )

async def publish_model_validated(model_id: str, tenant_id: str, product_name: str, validation_results: Dict[str, Any]) -> bool:
    """Publish model validation event"""
    return await training_publisher.publish_training_event("model.validated", model_data)
    return await training_publisher.publish_event(
        exchange_name="training.events",
        routing_key="training.model.validated",
        event_data={
            "service_name": "training-service",
            "event_type": "training.model.validated",
            "data": {
                "model_id": model_id,
                "tenant_id": tenant_id,
                "product_name": product_name,
                "validation_results": validation_results
            }
        }
    )

async def publish_model_saved(model_data: dict) -> bool:
async def publish_model_saved(model_id: str, tenant_id: str, product_name: str, model_path: str) -> bool:
    """Publish model saved event"""
    return await training_publisher.publish_training_event("model.saved", model_data)
    return await training_publisher.publish_event(
        exchange_name="training.events",
        routing_key="training.model.saved",
        event_data={
            "service_name": "training-service",
            "event_type": "training.model.saved",
            "data": {
                "model_id": model_id,
                "tenant_id": tenant_id,
                "product_name": product_name,
                "model_path": model_path
            }
        }
    )
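# Illustrative sketch (not part of the commit): per-product model events as a
# training loop might emit them once a single product's model is ready. The
# announce_trained_model wrapper and its argument values are hypothetical; the
# publish_* helpers and their new signatures are taken from the code above.
async def announce_trained_model(job_id: str, tenant_id: str, product_name: str,
                                 model_id: str, model_path: str,
                                 metrics: Dict[str, float]) -> None:
    await publish_product_training_completed(job_id, tenant_id, product_name, model_id)
    await publish_model_trained(model_id, tenant_id, product_name, metrics)
    await publish_model_saved(model_id, tenant_id, product_name, model_path)
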
File diff suppressed because it is too large