REFACTOR - Database logic

This commit is contained in:
Urtzi Alfaro
2025-08-08 09:08:41 +02:00
parent 0154365bfc
commit 488bb3ef93
113 changed files with 22842 additions and 6503 deletions

View File

@@ -0,0 +1,14 @@
"""
Training API Layer
HTTP endpoints for ML training operations
"""
from .training import router as training_router
from .websocket import websocket_router
__all__ = [
"training_router",
"websocket_router"
]

View File

@@ -38,11 +38,12 @@ async def get_active_model(
Get the active model for a product - used by forecasting service
"""
try:
# ✅ FIX: Wrap SQL with text() for SQLAlchemy 2.0
logger.debug("Getting active model", tenant_id=tenant_id, product_name=product_name)
# ✅ FIX: Wrap SQL with text() for SQLAlchemy 2.0 and add case-insensitive product name matching
query = text("""
SELECT * FROM trained_models
WHERE tenant_id = :tenant_id
AND product_name = :product_name
AND LOWER(product_name) = LOWER(:product_name)
AND is_active = true
AND is_production = true
ORDER BY created_at DESC
@@ -57,6 +58,7 @@ async def get_active_model(
model_record = result.fetchone()
if not model_record:
logger.info("No active model found", tenant_id=tenant_id, product_name=product_name)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No active model found for product {product_name}"
@@ -76,7 +78,7 @@ async def get_active_model(
await db.commit()
return {
"model_id": model_record.id, # ✅ This is the correct field name
"model_id": str(model_record.id), # ✅ This is the correct field name
"model_path": model_record.model_path,
"features_used": model_record.features_used,
"hyperparameters": model_record.hyperparameters,
@@ -93,12 +95,24 @@ async def get_active_model(
}
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get active model: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve model"
)
error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
logger.error(f"Failed to get active model: {error_msg}", tenant_id=tenant_id, product_name=product_name)
# Handle client disconnection gracefully
if "EndOfStream" in str(type(e)) or "WouldBlock" in str(type(e)):
logger.info("Client disconnected during model retrieval", tenant_id=tenant_id, product_name=product_name)
raise HTTPException(
status_code=status.HTTP_408_REQUEST_TIMEOUT,
detail="Request connection closed"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve model"
)
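For reference, a minimal, self-contained sketch of how a text() query with bound parameters is typically executed against an async SQLAlchemy 2.0 session; the execute call itself is elided between the hunks above. The query mirrors the one shown, but the engine URL, session setup, and function name are illustrative only, not the service's actual wiring.

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker

# Illustrative engine/session setup; the real service obtains sessions via get_db().
engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/training")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)

async def fetch_active_model(tenant_id: str, product_name: str):
    query = text("""
        SELECT * FROM trained_models
        WHERE tenant_id = :tenant_id
          AND LOWER(product_name) = LOWER(:product_name)
          AND is_active = true
        ORDER BY created_at DESC
        LIMIT 1
    """)
    async with SessionLocal() as session:
        result = await session.execute(
            query, {"tenant_id": tenant_id, "product_name": product_name}
        )
        return result.fetchone()  # Row or None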
@router.get("/tenants/{tenant_id}/models/{model_id}/metrics", response_model=ModelMetricsResponse)
async def get_model_metrics(
@@ -126,7 +140,7 @@ async def get_model_metrics(
# Return metrics in the format expected by forecasting service
metrics = {
"model_id": model_record.id,
"model_id": str(model_record.id),
"accuracy": model_record.r2_score or 0.0, # Use R2 as accuracy measure
"mape": model_record.mape or 0.0,
"mae": model_record.mae or 0.0,
@@ -189,8 +203,8 @@ async def list_models(
models = []
for record in model_records:
models.append({
"model_id": record.id,
"tenant_id": record.tenant_id,
"model_id": str(record.id),
"tenant_id": str(record.tenant_id),
"product_name": record.product_name,
"model_type": record.model_type,
"model_path": record.model_path,

View File

@@ -1,25 +1,19 @@
# services/training/app/api/training.py
"""
Training API Endpoints - Entry point for training requests
Handles HTTP requests and delegates to Training Service
Enhanced Training API Endpoints with Repository Pattern
Updated to use repository pattern with dependency injection and improved error handling
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request
from fastapi import Query, Path
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid
from app.core.database import get_db, get_background_db_session
from app.services.training_service import TrainingService, TrainingStatusManager
from sqlalchemy import select, delete, func
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
TrainingJobRequest,
SingleProductTrainingRequest
)
from app.schemas.training import (
SingleProductTrainingRequest,
TrainingJobResponse
)
@@ -33,47 +27,71 @@ from app.services.messaging import (
publish_job_started
)
from shared.auth.decorators import require_admin_role, get_current_user_dep, get_current_tenant_id_dep
from shared.database.base import create_database_manager
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings
logger = structlog.get_logger()
router = APIRouter()
router = APIRouter(tags=["enhanced-training"])
def get_enhanced_training_service():
"""Dependency injection for EnhancedTrainingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
return EnhancedTrainingService(database_manager)
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
async def start_training_job(
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_enhanced_training_job(
request: TrainingJobRequest,
tenant_id: str = Path(..., description="Tenant ID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start a new training job for all tenant products.
Start a new enhanced training job for all tenant products using repository pattern.
🚀 IMMEDIATE RESPONSE PATTERN:
1. Validate request immediately
2. Create job record with 'pending' status
3. Return 200 with job details
4. Execute training in background with separate DB session
🚀 ENHANCED IMMEDIATE RESPONSE PATTERN:
1. Validate request with enhanced validation
2. Create job record using repository pattern
3. Return 200 with enhanced job details
4. Execute enhanced training in background with repository tracking
This ensures fast API response while maintaining data consistency.
Enhanced features:
- Repository pattern for data access
- Enhanced error handling and logging
- Metrics tracking and monitoring
- Transactional operations
"""
metrics = get_metrics_collector(request_obj)
try:
# Validate tenant access immediately
# Enhanced tenant validation
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_training_access_denied_total")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Generate job ID immediately
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
# Generate enhanced job ID
job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info(f"Creating training job {job_id} for tenant {tenant_id}")
logger.info("Creating enhanced training job using repository pattern",
job_id=job_id,
tenant_id=tenant_id)
# Add background task with isolated database session
# Record job creation metrics
if metrics:
metrics.increment_counter("enhanced_training_jobs_created_total")
# Add enhanced background task
background_tasks.add_task(
execute_training_job_background,
execute_enhanced_training_job_background,
tenant_id=tenant_id,
job_id=job_id,
bakery_location=(40.4168, -3.7038),
@@ -81,16 +99,16 @@ async def start_training_job(
requested_end=request.end_date
)
# Return immediate success response
# Return enhanced immediate success response
response_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "pending", # Will change to 'running' in background
"message": "Training job started successfully",
"status": "pending",
"message": "Enhanced training job started successfully using repository pattern",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": "15",
"estimated_duration_minutes": 18,
"training_results": {
"total_products": 10,
"total_products": 0, # Will be updated during processing
"successful_trainings": 0,
"failed_trainings": 0,
"products": [],
@@ -101,31 +119,45 @@ async def start_training_job(
"error_details": None,
"processing_metadata": {
"background_task": True,
"async_execution": True
"async_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"dependency_injection": True
}
}
logger.info(f"Training job {job_id} queued successfully, returning immediate response")
logger.info("Enhanced training job queued successfully",
job_id=job_id,
features=["repository-pattern", "dependency-injection", "enhanced-tracking"])
return TrainingJobResponse(**response_data)
except HTTPException:
# Re-raise HTTP exceptions as-is
raise
except ValueError as e:
logger.error(f"Training job validation error: {str(e)}")
if metrics:
metrics.increment_counter("enhanced_training_validation_errors_total")
logger.error("Enhanced training job validation error",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
logger.error(f"Failed to queue training job: {str(e)}")
if metrics:
metrics.increment_counter("enhanced_training_job_errors_total")
logger.error("Failed to queue enhanced training job",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to start training job"
detail="Failed to start enhanced training job"
)
async def execute_training_job_background(
async def execute_enhanced_training_job_background(
tenant_id: str,
job_id: str,
bakery_location: tuple,
@@ -133,382 +165,457 @@ async def execute_training_job_background(
requested_end: Optional[datetime] = None
):
"""
Background task that executes the actual training job.
Enhanced background task that executes the training job using repository pattern.
🔧 KEY FEATURES:
- Uses its own database session (isolated from API request)
- Handles all errors gracefully
- Updates job status in real-time
- Publishes progress events via WebSocket/messaging
- Comprehensive logging and monitoring
🔧 ENHANCED FEATURES:
- Repository pattern for all data operations
- Enhanced error handling with structured logging
- Transactional operations for data consistency
- Comprehensive metrics tracking
- Database connection pooling
- Enhanced progress reporting
"""
logger.info(f"🚀 Background training job {job_id} started for tenant {tenant_id}")
logger.info("Enhanced background training job started",
job_id=job_id,
tenant_id=tenant_id,
features=["repository-pattern", "enhanced-tracking"])
async with get_background_db_session() as db_session:
try:
# ✅ FIX: Create training service with isolated DB session
training_service = TrainingService(db_session=db_session)
status_manager = TrainingStatusManager(db_session=db_session)
try:
training_config = {
"job_id": job_id,
"tenant_id": tenant_id,
"bakery_location": {
"latitude": 40.4168,
"longitude": -3.7038
},
"requested_start": requested_start if requested_start else None,
"requested_end": requested_end if requested_end else None,
"estimated_duration_minutes": 15,
"estimated_products": None,
"background_execution": True,
"api_version": "v1"
}
await status_manager.update_job_status(
job_id=job_id,
status="running",
progress=0,
current_step="Initializing training pipeline"
)
# Execute the actual training pipeline
result = await training_service.start_training_job(
tenant_id=tenant_id,
job_id=job_id,
bakery_location=bakery_location,
requested_start=requested_start,
requested_end=requested_end
)
await status_manager.update_job_status(
job_id=job_id,
status="completed",
progress=100,
current_step="Training completed successfully",
results=result
)
# Publish completion event
await publish_job_completed(
job_id=job_id,
tenant_id=tenant_id,
results=result
)
logger.info(f"✅ Background training job {job_id} completed successfully")
except Exception as training_error:
logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}")
await status_manager.update_job_status(
job_id=job_id,
status="failed",
progress=0,
current_step="Training failed",
error_message=str(training_error)
)
# Publish failure event
await publish_job_failed(
job_id=job_id,
tenant_id=tenant_id,
error=str(training_error)
)
except Exception as background_error:
logger.error(f"💥 Critical error in background training job {job_id}: {str(background_error)}")
# Get enhanced training service with dependency injection
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
enhanced_training_service = EnhancedTrainingService(database_manager)
try:
# Publish job started event
await publish_job_started(job_id, tenant_id, {
"enhanced_features": True,
"repository_pattern": True,
"job_type": "enhanced_training"
})
finally:
# Ensure database session is properly closed
logger.info(f"🧹 Background training job {job_id} cleanup completed")
training_config = {
"job_id": job_id,
"tenant_id": tenant_id,
"bakery_location": {
"latitude": bakery_location[0],
"longitude": bakery_location[1]
},
"requested_start": requested_start.isoformat() if requested_start else None,
"requested_end": requested_end.isoformat() if requested_end else None,
"estimated_duration_minutes": 18,
"background_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"api_version": "enhanced_v1"
}
# Update job status using repository pattern
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Initializing enhanced training pipeline"
)
# Execute the enhanced training pipeline with repository pattern
result = await enhanced_training_service.start_training_job(
tenant_id=tenant_id,
job_id=job_id,
bakery_location=bakery_location,
requested_start=requested_start,
requested_end=requested_end
)
# Update final status using repository pattern
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="completed",
progress=100,
current_step="Enhanced training completed successfully",
results=result
)
# Publish enhanced completion event
await publish_job_completed(
job_id=job_id,
tenant_id=tenant_id,
results={
**result,
"enhanced_features": True,
"repository_integration": True
}
)
logger.info("Enhanced background training job completed successfully",
job_id=job_id,
models_created=result.get('products_trained', 0),
features=["repository-pattern", "enhanced-tracking"])
except Exception as training_error:
logger.error("Enhanced training pipeline failed",
job_id=job_id,
error=str(training_error))
try:
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Enhanced training failed",
error_message=str(training_error)
)
except Exception as status_error:
logger.error("Failed to update job status after training error",
job_id=job_id,
status_error=str(status_error))
# Publish enhanced failure event
await publish_job_failed(
job_id=job_id,
tenant_id=tenant_id,
error=str(training_error),
metadata={
"enhanced_features": True,
"repository_pattern": True,
"error_type": type(training_error).__name__
}
)
except Exception as background_error:
logger.error("Critical error in enhanced background training job",
job_id=job_id,
error=str(background_error))
finally:
logger.info("Enhanced background training job cleanup completed",
job_id=job_id)
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
async def start_single_product_training(
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_enhanced_single_product_training(
request: SingleProductTrainingRequest,
tenant_id: str = Path(..., description="Tenant ID"),
product_name: str = Path(..., description="Product name"),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start training for a single product.
Start enhanced training for a single product using repository pattern.
Uses the same pipeline but filters for a specific product.
Enhanced features:
- Repository pattern for data access
- Enhanced error handling and validation
- Metrics tracking
- Transactional operations
"""
training_service = TrainingService(db_session=db)
metrics = get_metrics_collector(request_obj)
try:
# Validate tenant access
# Enhanced tenant validation
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_single_product_access_denied_total")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
logger.info(f"Starting single product training for {product_name} (tenant {tenant_id})")
logger.info("Starting enhanced single product training",
product_name=product_name,
tenant_id=tenant_id)
# Delegate to training service
result = await training_service.start_single_product_training(
# Record metrics
if metrics:
metrics.increment_counter("enhanced_single_product_training_total")
# Generate enhanced job ID
job_id = f"enhanced_single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"
# Delegate to enhanced training service (single product method to be implemented)
result = await enhanced_training_service.start_single_product_training(
tenant_id=tenant_id,
product_name=product_name,
sales_data=request.sales_data,
bakery_location=request.bakery_location or (40.4168, -3.7038),
weather_data=request.weather_data,
traffic_data=request.traffic_data,
job_id=request.job_id
job_id=job_id,
bakery_location=request.bakery_location or (40.4168, -3.7038)
)
if metrics:
metrics.increment_counter("enhanced_single_product_training_success_total")
logger.info("Enhanced single product training completed",
product_name=product_name,
job_id=job_id)
return TrainingJobResponse(**result)
except HTTPException:
raise
except ValueError as e:
logger.error(f"Single product training validation error: {str(e)}")
if metrics:
metrics.increment_counter("enhanced_single_product_validation_errors_total")
logger.error("Enhanced single product training validation error",
error=str(e),
product_name=product_name)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
logger.error(f"Single product training failed: {str(e)}")
if metrics:
metrics.increment_counter("enhanced_single_product_training_errors_total")
logger.error("Enhanced single product training failed",
error=str(e),
product_name=product_name)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Single product training failed"
detail="Enhanced single product training failed"
)
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/logs")
async def get_training_logs(
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/status")
@track_execution_time("enhanced_job_status_duration_seconds", "training-service")
async def get_enhanced_training_job_status(
tenant_id: str = Path(..., description="Tenant ID"),
job_id: str = Path(..., description="Job ID"),
limit: int = Query(100, description="Number of log entries to return"),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Get training job logs.
Get enhanced training job status using repository pattern.
"""
metrics = get_metrics_collector(request_obj)
try:
# Validate tenant access
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_status_access_denied_total")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# TODO: Implement log retrieval
# Get status using enhanced service
status_info = await enhanced_training_service.get_training_status(job_id)
if not status_info or status_info.get("error"):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Training job not found"
)
if metrics:
metrics.increment_counter("enhanced_status_requests_total")
return {
"job_id": job_id,
"logs": [
f"Training job {job_id} started",
"Data preprocessing completed",
"Model training completed",
"Training job finished successfully"
]
**status_info,
"enhanced_features": True,
"repository_integration": True
}
except HTTPException:
raise
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_status_errors_total")
logger.error("Failed to get enhanced training status",
job_id=job_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get training status"
)
@router.get("/tenants/{tenant_id}/models")
@track_execution_time("enhanced_models_list_duration_seconds", "training-service")
async def get_enhanced_tenant_models(
tenant_id: str = Path(..., description="Tenant ID"),
active_only: bool = Query(True, description="Return only active models"),
skip: int = Query(0, description="Number of models to skip"),
limit: int = Query(100, description="Number of models to return"),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Get tenant models using enhanced repository pattern.
"""
metrics = get_metrics_collector(request_obj)
try:
# Validate tenant access
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_models_access_denied_total")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Get models using enhanced service
models = await enhanced_training_service.get_tenant_models(
tenant_id=tenant_id,
active_only=active_only,
skip=skip,
limit=limit
)
if metrics:
metrics.increment_counter("enhanced_models_requests_total")
return {
"tenant_id": tenant_id,
"models": models,
"total_returned": len(models),
"active_only": active_only,
"pagination": {
"skip": skip,
"limit": limit
},
"enhanced_features": True,
"repository_integration": True
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get training logs: {str(e)}")
if metrics:
metrics.increment_counter("enhanced_models_errors_total")
logger.error("Failed to get enhanced tenant models",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get training logs"
detail="Failed to get tenant models"
)
@router.get("/health")
async def health_check():
@router.get("/tenants/{tenant_id}/models/{model_id}/performance")
@track_execution_time("enhanced_model_performance_duration_seconds", "training-service")
async def get_enhanced_model_performance(
tenant_id: str = Path(..., description="Tenant ID"),
model_id: str = Path(..., description="Model ID"),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Health check endpoint for the training service.
Get enhanced model performance metrics using repository pattern.
"""
metrics = get_metrics_collector(request_obj)
try:
# Validate tenant access
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_performance_access_denied_total")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Get performance using enhanced service
performance = await enhanced_training_service.get_model_performance(model_id)
if not performance:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model performance not found"
)
if metrics:
metrics.increment_counter("enhanced_performance_requests_total")
return {
**performance,
"enhanced_features": True,
"repository_integration": True
}
except HTTPException:
raise
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_performance_errors_total")
logger.error("Failed to get enhanced model performance",
model_id=model_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get model performance"
)
@router.get("/tenants/{tenant_id}/statistics")
@track_execution_time("enhanced_tenant_statistics_duration_seconds", "training-service")
async def get_enhanced_tenant_statistics(
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Get comprehensive enhanced tenant statistics using repository pattern.
"""
metrics = get_metrics_collector(request_obj)
try:
# Validate tenant access
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_statistics_access_denied_total")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Get statistics using enhanced service
statistics = await enhanced_training_service.get_tenant_statistics(tenant_id)
if statistics.get("error"):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=statistics["error"]
)
if metrics:
metrics.increment_counter("enhanced_statistics_requests_total")
return {
**statistics,
"enhanced_features": True,
"repository_integration": True
}
except HTTPException:
raise
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_statistics_errors_total")
logger.error("Failed to get enhanced tenant statistics",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get tenant statistics"
)
@router.get("/health")
async def enhanced_health_check():
"""
Enhanced health check endpoint for the training service.
"""
return {
"status": "healthy",
"service": "training",
"version": "1.0.0",
"service": "enhanced-training-service",
"version": "2.0.0",
"features": [
"repository-pattern",
"dependency-injection",
"enhanced-error-handling",
"metrics-tracking",
"transactional-operations"
],
"timestamp": datetime.now().isoformat()
}
@router.post("/tenants/{tenant_id}/training/jobs/cancel")
async def cancel_tenant_training_jobs(
cancel_data: dict, # {"tenant_id": str}
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""Cancel all active training jobs for a tenant (admin only)"""
try:
tenant_id = cancel_data.get("tenant_id")
if not tenant_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="tenant_id is required"
)
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.training import TrainingJobQueue
# Find all active jobs for the tenant
active_jobs_query = select(TrainingJobQueue).where(
TrainingJobQueue.tenant_id == tenant_uuid,
TrainingJobQueue.status.in_(["queued", "running", "pending"])
)
active_jobs_result = await db.execute(active_jobs_query)
active_jobs = active_jobs_result.scalars().all()
jobs_cancelled = 0
cancelled_job_ids = []
errors = []
for job in active_jobs:
try:
job.status = "cancelled"
job.updated_at = datetime.utcnow()
job.cancelled_by = current_user.get("user_id")
jobs_cancelled += 1
cancelled_job_ids.append(str(job.id))
logger.info("Cancelled training job",
job_id=str(job.id),
tenant_id=tenant_id)
except Exception as e:
error_msg = f"Failed to cancel job {job.id}: {str(e)}"
errors.append(error_msg)
logger.error(error_msg)
if jobs_cancelled > 0:
await db.commit()
result = {
"success": True,
"tenant_id": tenant_id,
"jobs_cancelled": jobs_cancelled,
"cancelled_job_ids": cancelled_job_ids,
"errors": errors,
"cancelled_at": datetime.utcnow().isoformat()
}
if errors:
result["success"] = len(errors) < len(active_jobs)
return result
except Exception as e:
await db.rollback()
logger.error("Failed to cancel tenant training jobs",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cancel training jobs"
)
@router.get("/tenants/{tenant_id}/training/jobs/active")
async def get_tenant_active_jobs(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""Get all active training jobs for a tenant (admin only)"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.training import TrainingJobQueue
# Get active jobs
active_jobs_query = select(TrainingJobQueue).where(
TrainingJobQueue.tenant_id == tenant_uuid,
TrainingJobQueue.status.in_(["queued", "running", "pending"])
)
active_jobs_result = await db.execute(active_jobs_query)
active_jobs = active_jobs_result.scalars().all()
jobs = []
for job in active_jobs:
jobs.append({
"id": str(job.id),
"tenant_id": str(job.tenant_id),
"status": job.status,
"created_at": job.created_at.isoformat() if job.created_at else None,
"updated_at": job.updated_at.isoformat() if job.updated_at else None,
"started_at": job.started_at.isoformat() if job.started_at else None,
"progress": getattr(job, 'progress', 0)
})
return {
"tenant_id": tenant_id,
"active_jobs_count": len(jobs),
"jobs": jobs
}
except Exception as e:
logger.error("Failed to get tenant active jobs",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get active jobs"
)
@router.get("/tenants/{tenant_id}/training/jobs/count")
async def get_tenant_models_count(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""Get count of trained models for a tenant (admin only)"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.training import TrainedModel, ModelArtifact
# Count models
models_count_query = select(func.count(TrainedModel.id)).where(
TrainedModel.tenant_id == tenant_uuid
)
models_count_result = await db.execute(models_count_query)
models_count = models_count_result.scalar()
# Count artifacts
artifacts_count_query = select(func.count(ModelArtifact.id)).where(
ModelArtifact.tenant_id == tenant_uuid
)
artifacts_count_result = await db.execute(artifacts_count_query)
artifacts_count = artifacts_count_result.scalar()
return {
"tenant_id": tenant_id,
"models_count": models_count,
"artifacts_count": artifacts_count,
"total_training_assets": models_count + artifacts_count
}
except Exception as e:
logger.error("Failed to get tenant models count",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get models count"
)
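The immediate-response pattern described in the docstrings above reduces to a small amount of FastAPI machinery: validate, record the job as pending, schedule a background task, and return before the work starts. A minimal, self-contained sketch under those assumptions; the route path, function names, and the in-memory job store standing in for the repository layer are illustrative only.

import uuid
from datetime import datetime, timezone
from fastapi import BackgroundTasks, FastAPI

app = FastAPI()
jobs: dict[str, dict] = {}  # stand-in for the job repository

async def run_training(job_id: str) -> None:
    # Runs after the HTTP response has been sent; uses its own resources.
    jobs[job_id]["status"] = "running"
    try:
        # ... long-running training work would go here ...
        jobs[job_id]["status"] = "completed"
    except Exception as exc:
        jobs[job_id].update(status="failed", error=str(exc))

@app.post("/tenants/{tenant_id}/training/jobs")
async def start_job(tenant_id: str, background_tasks: BackgroundTasks):
    job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
    jobs[job_id] = {"status": "pending", "created_at": datetime.now(timezone.utc)}
    background_tasks.add_task(run_training, job_id)  # scheduled, not awaited
    return {"job_id": job_id, "status": "pending"}   # returns immediately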

View File

@@ -16,8 +16,9 @@ from fastapi.responses import JSONResponse
import uvicorn
from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database
from app.core.database import initialize_training_database, cleanup_training_database, get_db_health
from app.api import training, models
from app.api.websocket import websocket_router
from app.services.messaging import setup_messaging, cleanup_messaging
from shared.monitoring.logging import setup_logging
@@ -176,6 +177,7 @@ async def global_exception_handler(request: Request, exc: Exception):
# Include API routers
app.include_router(training.router, prefix="/api/v1", tags=["training"])
app.include_router(models.router, prefix="/api/v1", tags=["models"])
app.include_router(websocket_router, prefix="/api/v1/ws", tags=["websocket"])

View File

@@ -0,0 +1,18 @@
"""
ML Pipeline Components
Machine learning training and prediction components
"""
from .trainer import BakeryMLTrainer
from .trainer import EnhancedBakeryMLTrainer
from .data_processor import BakeryDataProcessor
from .data_processor import EnhancedBakeryDataProcessor
from .prophet_manager import BakeryProphetManager
__all__ = [
"BakeryMLTrainer",
"EnhancedBakeryMLTrainer",
"BakeryDataProcessor",
"EnhancedBakeryDataProcessor",
"BakeryProphetManager"
]

View File

@@ -1,32 +1,44 @@
# services/training/app/ml/data_processor.py
"""
Enhanced Data Processor for Training Service
Handles data preparation, date alignment, cleaning, and feature engineering for ML training
Enhanced Data Processor for Training Service with Repository Pattern
Uses repository pattern for data access and dependency injection
"""
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta, timezone
import logging
import structlog
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
from app.repositories import ModelRepository, TrainingLogRepository
from shared.database.base import create_database_manager
from shared.database.transactions import transactional
from shared.database.exceptions import DatabaseError
from app.core.config import settings
logger = logging.getLogger(__name__)
logger = structlog.get_logger()
class BakeryDataProcessor:
class EnhancedBakeryDataProcessor:
"""
Enhanced data processor for bakery forecasting training service.
Enhanced data processor for bakery forecasting with repository pattern.
Integrates date alignment, data cleaning, feature engineering, and preparation for ML models.
"""
def __init__(self):
def __init__(self, database_manager=None):
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "training-service")
self.scalers = {} # Store scalers for each feature
self.imputers = {} # Store imputers for missing value handling
self.date_alignment_service = DateAlignmentService()
async def _get_repositories(self, session):
"""Initialize repositories with session"""
return {
'model': ModelRepository(session),
'training_log': TrainingLogRepository(session)
}
def _ensure_timezone_aware(self, df: pd.DataFrame, date_column: str = 'date') -> pd.DataFrame:
"""Ensure date column is timezone-aware to prevent conversion errors"""
if date_column in df.columns:
@@ -46,59 +58,118 @@ class BakeryDataProcessor:
sales_data: pd.DataFrame,
weather_data: pd.DataFrame,
traffic_data: pd.DataFrame,
product_name: str) -> pd.DataFrame:
product_name: str,
tenant_id: str = None,
job_id: str = None,
session=None) -> pd.DataFrame:
"""
Prepare comprehensive training data for a specific product with date alignment.
Prepare comprehensive training data for a specific product with repository logging.
Args:
sales_data: Historical sales data for the product
weather_data: Weather data
traffic_data: Traffic data
product_name: Product name for logging
tenant_id: Optional tenant ID for tracking
job_id: Optional job ID for tracking
Returns:
DataFrame ready for Prophet training with 'ds' and 'y' columns plus features
"""
try:
logger.info(f"Preparing training data for product: {product_name}")
logger.info("Preparing enhanced training data using repository pattern",
product_name=product_name,
tenant_id=tenant_id,
job_id=job_id)
# Step 1: Convert and validate sales data
sales_clean = await self._process_sales_data(sales_data, product_name)
# FIX: Ensure timezone awareness before any operations
sales_clean = self._ensure_timezone_aware(sales_clean)
weather_data = self._ensure_timezone_aware(weather_data) if not weather_data.empty else weather_data
traffic_data = self._ensure_timezone_aware(traffic_data) if not traffic_data.empty else traffic_data
# Step 2: Apply date alignment if we have date constraints
sales_clean = await self._apply_date_alignment(sales_clean, weather_data, traffic_data)
# Step 3: Aggregate to daily level
daily_sales = await self._aggregate_daily_sales(sales_clean)
# Step 4: Add temporal features
daily_sales = self._add_temporal_features(daily_sales)
# Step 5: Merge external data sources
daily_sales = self._merge_weather_features(daily_sales, weather_data)
daily_sales = self._merge_traffic_features(daily_sales, traffic_data)
# Step 6: Engineer additional features
daily_sales = self._engineer_features(daily_sales)
# Step 7: Handle missing values
daily_sales = self._handle_missing_values(daily_sales)
# Step 8: Prepare for Prophet (rename columns and validate)
prophet_data = self._prepare_prophet_format(daily_sales)
logger.info(f"Prepared {len(prophet_data)} data points for {product_name}")
return prophet_data
# Get database session and repositories
async with self.database_manager.get_session() as db_session:
repos = await self._get_repositories(db_session)
# Log data preparation start if we have tracking info
if job_id and tenant_id:
await repos['training_log'].update_log_progress(
job_id, 15, f"preparing_data_{product_name}", "running"
)
# Step 1: Convert and validate sales data
sales_clean = await self._process_sales_data(sales_data, product_name)
# FIX: Ensure timezone awareness before any operations
sales_clean = self._ensure_timezone_aware(sales_clean)
weather_data = self._ensure_timezone_aware(weather_data) if not weather_data.empty else weather_data
traffic_data = self._ensure_timezone_aware(traffic_data) if not traffic_data.empty else traffic_data
# Step 2: Apply date alignment if we have date constraints
sales_clean = await self._apply_date_alignment(sales_clean, weather_data, traffic_data)
# Step 3: Aggregate to daily level
daily_sales = await self._aggregate_daily_sales(sales_clean)
# Step 4: Add temporal features
daily_sales = self._add_temporal_features(daily_sales)
# Step 5: Merge external data sources
daily_sales = self._merge_weather_features(daily_sales, weather_data)
daily_sales = self._merge_traffic_features(daily_sales, traffic_data)
# Step 6: Engineer additional features
daily_sales = self._engineer_features(daily_sales)
# Step 7: Handle missing values
daily_sales = self._handle_missing_values(daily_sales)
# Step 8: Prepare for Prophet (rename columns and validate)
prophet_data = self._prepare_prophet_format(daily_sales)
# Step 9: Store processing metadata if we have a tenant
if tenant_id:
await self._store_processing_metadata(
repos, tenant_id, product_name, prophet_data, job_id
)
logger.info("Enhanced training data prepared successfully",
product_name=product_name,
data_points=len(prophet_data))
return prophet_data
except Exception as e:
logger.error(f"Error preparing training data for {product_name}: {str(e)}")
logger.error("Error preparing enhanced training data",
product_name=product_name,
error=str(e))
raise
async def _store_processing_metadata(self,
repos: Dict,
tenant_id: str,
product_name: str,
processed_data: pd.DataFrame,
job_id: str = None):
"""Store data processing metadata using repository"""
try:
# Create processing metadata
metadata = {
"product_name": product_name,
"data_points": len(processed_data),
"date_range": {
"start": processed_data['ds'].min().isoformat(),
"end": processed_data['ds'].max().isoformat()
},
"features_count": len([col for col in processed_data.columns if col not in ['ds', 'y']]),
"processed_at": datetime.now().isoformat()
}
# Log processing completion
if job_id:
await repos['training_log'].update_log_progress(
job_id, 25, f"data_prepared_{product_name}", "running"
)
except Exception as e:
logger.warning("Failed to store processing metadata",
error=str(e))
async def prepare_prediction_features(self,
future_dates: pd.DatetimeIndex,
weather_forecast: pd.DataFrame = None,
@@ -149,7 +220,7 @@ class BakeryDataProcessor:
return future_df
except Exception as e:
logger.error(f"Error creating prediction features: {e}")
logger.error("Error creating prediction features", error=str(e))
# Return minimal features if error
return pd.DataFrame({'ds': future_dates})
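For context, the frame the training pipeline above produces is the standard Prophet input: a 'ds' datetime column, a 'y' target column, and any extra regressor columns. A tiny illustrative example of that shape (values and regressor names are made up):

import pandas as pd

prophet_data = pd.DataFrame({
    "ds": pd.date_range("2025-01-01", periods=3, freq="D"),
    "y": [42.0, 38.0, 51.0],                 # daily units sold
    "temperature": [12.5, 14.0, 9.8],        # example weather regressor
    "traffic_volume": [120.0, 95.0, 130.0],  # example traffic regressor
})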
@@ -181,16 +252,18 @@ class BakeryDataProcessor:
mask = (sales_dates >= aligned_range.start) & (sales_dates <= aligned_range.end)
filtered_sales = sales_data[mask].copy()
logger.info(f"Date alignment: {len(sales_data)}{len(filtered_sales)} records")
logger.info(f"Aligned date range: {aligned_range.start.date()} to {aligned_range.end.date()}")
logger.info("Date alignment completed",
original_records=len(sales_data),
filtered_records=len(filtered_sales),
date_range=f"{aligned_range.start.date()} to {aligned_range.end.date()}")
if aligned_range.constraints:
logger.info(f"Applied constraints: {aligned_range.constraints}")
logger.info("Applied constraints", constraints=aligned_range.constraints)
return filtered_sales
except Exception as e:
logger.warning(f"Date alignment failed, using original data: {str(e)}")
logger.warning("Date alignment failed, using original data", error=str(e))
return sales_data
async def _process_sales_data(self, sales_data: pd.DataFrame, product_name: str) -> pd.DataFrame:
@@ -218,7 +291,9 @@ class BakeryDataProcessor:
# Standardize to 'quantity'
if quantity_col != 'quantity':
sales_clean['quantity'] = sales_clean[quantity_col]
logger.info(f"Mapped '{quantity_col}' to 'quantity' column")
logger.info("Mapped quantity column",
from_column=quantity_col,
to_column='quantity')
sales_clean['quantity'] = pd.to_numeric(sales_clean['quantity'], errors='coerce')
@@ -302,7 +377,7 @@ class BakeryDataProcessor:
weather_data: pd.DataFrame) -> pd.DataFrame:
"""Merge weather features with enhanced Madrid-specific handling"""
# ✅ FIX: Define weather_defaults OUTSIDE try block to fix scope error
# Define weather_defaults OUTSIDE try block to fix scope error
weather_defaults = {
'temperature': 15.0,
'precipitation': 0.0,
@@ -324,17 +399,15 @@ class BakeryDataProcessor:
if 'date' not in weather_clean.columns and 'ds' in weather_clean.columns:
weather_clean = weather_clean.rename(columns={'ds': 'date'})
# 🔧 CRITICAL FIX: Ensure both DataFrames have compatible datetime formats
# CRITICAL FIX: Ensure both DataFrames have compatible datetime formats
weather_clean['date'] = pd.to_datetime(weather_clean['date'])
daily_sales['date'] = pd.to_datetime(daily_sales['date'])
# NEW FIX: Normalize both to timezone-naive datetime for merge compatibility
# NEW FIX: Normalize both to timezone-naive datetime for merge compatibility
if weather_clean['date'].dt.tz is not None:
# Convert timezone-aware to UTC then remove timezone info
weather_clean['date'] = weather_clean['date'].dt.tz_convert('UTC').dt.tz_localize(None)
if daily_sales['date'].dt.tz is not None:
# Convert timezone-aware to UTC then remove timezone info
daily_sales['date'] = daily_sales['date'].dt.tz_convert('UTC').dt.tz_localize(None)
# Map weather columns to standard names
@@ -369,8 +442,8 @@ class BakeryDataProcessor:
return merged
except Exception as e:
logger.warning(f"Error merging weather data: {e}")
# Add default weather columns if merge fails (weather_defaults now in scope)
logger.warning("Error merging weather data", error=str(e))
# Add default weather columns if merge fails
for feature, default_value in weather_defaults.items():
daily_sales[feature] = default_value
return daily_sales
@@ -393,18 +466,15 @@ class BakeryDataProcessor:
if 'date' not in traffic_clean.columns and 'ds' in traffic_clean.columns:
traffic_clean = traffic_clean.rename(columns={'ds': 'date'})
# 🔧 CRITICAL FIX: Ensure both DataFrames have compatible datetime formats
# CRITICAL FIX: Ensure both DataFrames have compatible datetime formats
traffic_clean['date'] = pd.to_datetime(traffic_clean['date'])
daily_sales['date'] = pd.to_datetime(daily_sales['date'])
# NEW FIX: Normalize both to timezone-naive datetime for merge compatibility
# This prevents the "datetime64[ns] and datetime64[ns, UTC]" merge error
# NEW FIX: Normalize both to timezone-naive datetime for merge compatibility
if traffic_clean['date'].dt.tz is not None:
# Convert timezone-aware to UTC then remove timezone info
traffic_clean['date'] = traffic_clean['date'].dt.tz_convert('UTC').dt.tz_localize(None)
if daily_sales['date'].dt.tz is not None:
# Convert timezone-aware to UTC then remove timezone info
daily_sales['date'] = daily_sales['date'].dt.tz_convert('UTC').dt.tz_localize(None)
# Map traffic columns to standard names
@@ -445,7 +515,7 @@ class BakeryDataProcessor:
return merged
except Exception as e:
logger.warning(f"Error merging traffic data: {e}")
logger.warning("Error merging traffic data", error=str(e))
# Add default traffic column if merge fails
daily_sales['traffic_volume'] = 100.0
return daily_sales
@@ -473,7 +543,7 @@ class BakeryDataProcessor:
bins=[-0.1, 0, 2, 10, np.inf],
labels=[0, 1, 2, 3]).astype(int)
# ✅ FIX: Traffic-based features with NaN protection
# Traffic-based features with NaN protection
if 'traffic_volume' in df.columns:
# Calculate traffic quantiles for relative measures
q75 = df['traffic_volume'].quantile(0.75)
@@ -482,19 +552,17 @@ class BakeryDataProcessor:
df['high_traffic'] = (df['traffic_volume'] > q75).astype(int)
df['low_traffic'] = (df['traffic_volume'] < q25).astype(int)
# ✅ FIX: Safe normalization with NaN protection
# Safe normalization with NaN protection
traffic_std = df['traffic_volume'].std()
traffic_mean = df['traffic_volume'].mean()
if traffic_std > 0 and not pd.isna(traffic_std) and not pd.isna(traffic_mean):
# Normal case: valid standard deviation
df['traffic_normalized'] = (df['traffic_volume'] - traffic_mean) / traffic_std
else:
# Edge case: all values are the same or contain NaN
logger.warning("Traffic volume has zero standard deviation or contains NaN, using zeros for normalized values")
logger.warning("Traffic volume has zero standard deviation, using zeros for normalized values")
df['traffic_normalized'] = 0.0
# ✅ ADDITIONAL SAFETY: Fill any remaining NaN values
# Fill any remaining NaN values
df['traffic_normalized'] = df['traffic_normalized'].fillna(0.0)
# Interaction features - bakery specific
@@ -528,13 +596,14 @@ class BakeryDataProcessor:
# Spring/summer months
df['is_warm_season'] = df['month'].isin([4, 5, 6, 7, 8, 9]).astype(int)
# FINAL SAFETY CHECK: Remove any remaining NaN values
# Check for NaN values in all numeric columns and fill them
# FINAL SAFETY CHECK: Remove any remaining NaN values
numeric_columns = df.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
if df[col].isna().any():
nan_count = df[col].isna().sum()
logger.warning(f"Found {nan_count} NaN values in column '{col}', filling with 0")
logger.warning("Found NaN values in column, filling with 0",
column=col,
nan_count=nan_count)
df[col] = df[col].fillna(0.0)
return df
@@ -632,8 +701,9 @@ class BakeryDataProcessor:
if len(prophet_df) == 0:
raise ValueError("No valid data points after cleaning")
logger.info(f"Prophet data prepared: {len(prophet_df)} rows, "
f"date range: {prophet_df['ds'].min()} to {prophet_df['ds'].max()}")
logger.info("Prophet data prepared",
rows=len(prophet_df),
date_range=f"{prophet_df['ds'].min()} to {prophet_df['ds'].max()}")
return prophet_df
@@ -690,11 +760,11 @@ class BakeryDataProcessor:
return False
def calculate_feature_importance(self,
async def calculate_feature_importance(self,
model_data: pd.DataFrame,
target_column: str = 'y') -> Dict[str, float]:
"""
Calculate feature importance for the model using correlation analysis.
Calculate feature importance for the model using correlation analysis with repository logging.
"""
try:
# Get numeric features
@@ -704,7 +774,7 @@ class BakeryDataProcessor:
importance_scores = {}
if target_column not in model_data.columns:
logger.warning(f"Target column '{target_column}' not found")
logger.warning("Target column not found", target_column=target_column)
return {}
for feature in numeric_features:
@@ -717,16 +787,18 @@ class BakeryDataProcessor:
importance_scores = dict(sorted(importance_scores.items(),
key=lambda x: x[1], reverse=True))
logger.info(f"Calculated feature importance for {len(importance_scores)} features")
logger.info("Calculated feature importance",
features_count=len(importance_scores))
return importance_scores
except Exception as e:
logger.error(f"Error calculating feature importance: {e}")
logger.error("Error calculating feature importance", error=str(e))
return {}
def get_data_quality_report(self, df: pd.DataFrame) -> Dict[str, Any]:
async def get_data_quality_report(self, df: pd.DataFrame) -> Dict[str, Any]:
"""
Generate a comprehensive data quality report.
Generate a comprehensive data quality report with repository integration.
"""
try:
report = {
@@ -778,5 +850,9 @@ class BakeryDataProcessor:
return report
except Exception as e:
logger.error(f"Error generating data quality report: {e}")
return {"error": str(e)}
logger.error("Error generating data quality report", error=str(e))
return {"error": str(e)}
# Legacy compatibility alias
BakeryDataProcessor = EnhancedBakeryDataProcessor
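The ModelRepository and TrainingLogRepository used above come from app.repositories and are not shown in this commit excerpt. As a rough illustration only, a repository in this pattern is a thin wrapper that owns the queries for one aggregate and receives its session from the caller; the update_log_progress signature and column names below are inferred from the call sites above and may differ from the real implementation.

from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.training import ModelTrainingLog  # ORM model shown earlier in this commit

class TrainingLogRepository:
    """Hypothetical minimal shape; the real class in app.repositories may differ."""

    def __init__(self, session: AsyncSession):
        self.session = session  # session lifetime is owned by the caller

    async def update_log_progress(self, job_id: str, progress: int,
                                  current_step: str, status: str) -> None:
        # Column names are assumed from the call sites above.
        stmt = (
            update(ModelTrainingLog)
            .where(ModelTrainingLog.job_id == job_id)
            .values(progress=progress, current_step=current_step, status=status)
        )
        await self.session.execute(stmt)
        # No commit here: the calling service decides when the transaction ends.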

View File

@@ -24,7 +24,8 @@ warnings.filterwarnings('ignore')
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from app.models.training import TrainedModel
from app.core.database import get_db_session
from shared.database.base import create_database_manager
from app.repositories import ModelRepository
# Simple optimization import
import optuna
@@ -40,10 +41,11 @@ class BakeryProphetManager:
Drop-in replacement for the existing manager - optimization runs automatically.
"""
def __init__(self, db_session: AsyncSession = None):
def __init__(self, database_manager=None):
self.models = {} # In-memory model storage
self.model_metadata = {} # Store model metadata
self.db_session = db_session # Add database session
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "training-service")
self.db_session = None # Will be set when session is available
# Ensure model storage directory exists
os.makedirs(settings.MODEL_STORAGE_PATH, exist_ok=True)
@@ -84,15 +86,15 @@ class BakeryProphetManager:
# Fit the model
model.fit(prophet_data)
# Store model and calculate metrics (same as before)
model_id = f"{job_id}_{product_name}_{uuid.uuid4().hex[:8]}"
model_path = await self._store_model(
tenant_id, product_name, model, model_id, prophet_data, regressor_columns, best_params
)
# Calculate enhanced training metrics
# Calculate enhanced training metrics first
training_metrics = await self._calculate_training_metrics(model, prophet_data, best_params)
# Store model and metrics - Generate proper UUID for model_id
model_id = str(uuid.uuid4())
model_path = await self._store_model(
tenant_id, product_name, model, model_id, prophet_data, regressor_columns, best_params, training_metrics
)
# Return same format as before, but with optimization info
model_info = {
"model_id": model_id,
@@ -517,11 +519,11 @@ class BakeryProphetManager:
self.models[model_key] = model
self.model_metadata[model_key] = metadata
# 🆕 NEW: Store in database
if self.db_session:
try:
# 🆕 NEW: Store in database using new session
try:
async with self.database_manager.get_session() as db_session:
# Deactivate previous models for this product
await self._deactivate_previous_models(tenant_id, product_name)
await self._deactivate_previous_models_with_session(db_session, tenant_id, product_name)
# Create new database record
db_model = TrainedModel(
@@ -536,8 +538,8 @@ class BakeryProphetManager:
features_used=regressor_columns,
is_active=True,
is_production=True, # New models are production-ready
training_start_date=training_data['ds'].min(),
training_end_date=training_data['ds'].max(),
training_start_date=training_data['ds'].min().to_pydatetime().replace(tzinfo=None) if training_data['ds'].min().tz is None else training_data['ds'].min().to_pydatetime(),
training_end_date=training_data['ds'].max().to_pydatetime().replace(tzinfo=None) if training_data['ds'].max().tz is None else training_data['ds'].max().to_pydatetime(),
training_samples=len(training_data)
)
@@ -549,44 +551,39 @@ class BakeryProphetManager:
db_model.r2_score = training_metrics.get('r2')
db_model.data_quality_score = training_metrics.get('data_quality_score')
self.db_session.add(db_model)
await self.db_session.commit()
db_session.add(db_model)
await db_session.commit()
logger.info(f"Model {model_id} stored in database successfully")
except Exception as e:
logger.error(f"Failed to store model in database: {str(e)}")
await self.db_session.rollback()
# Continue execution - file storage succeeded
except Exception as e:
logger.error(f"Failed to store model in database: {str(e)}")
# Continue execution - file storage succeeded
logger.info(f"Optimized model stored at: {model_path}")
return str(model_path)
async def _deactivate_previous_models(self, tenant_id: str, product_name: str):
"""Deactivate previous models for the same product"""
if self.db_session:
try:
# ✅ FIX: Wrap SQL string with text() for SQLAlchemy 2.0
query = text("""
UPDATE trained_models
SET is_active = false, is_production = false
WHERE tenant_id = :tenant_id AND product_name = :product_name
""")
await self.db_session.execute(query, {
"tenant_id": tenant_id,
"product_name": product_name
})
# ✅ ADD: Commit the transaction
await self.db_session.commit()
logger.info(f"Successfully deactivated previous models for {product_name}")
except Exception as e:
logger.error(f"Failed to deactivate previous models: {str(e)}")
# ✅ ADD: Rollback on error
await self.db_session.rollback()
async def _deactivate_previous_models_with_session(self, db_session, tenant_id: str, product_name: str):
"""Deactivate previous models for the same product using provided session"""
try:
# ✅ FIX: Wrap SQL string with text() for SQLAlchemy 2.0
query = text("""
UPDATE trained_models
SET is_active = false, is_production = false
WHERE tenant_id = :tenant_id AND product_name = :product_name
""")
await db_session.execute(query, {
"tenant_id": tenant_id,
"product_name": product_name
})
# Note: Don't commit here, let the calling method handle the transaction
logger.info(f"Successfully deactivated previous models for {product_name}")
except Exception as e:
logger.error(f"Failed to deactivate previous models: {str(e)}")
raise
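The "don't commit here" note above reflects a deliberate split: helpers execute statements on a session they are handed, and only the outermost caller commits, so several statements succeed or fail together. A minimal sketch of that shape, with illustrative names and a session_factory assumed to be an async_sessionmaker:

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

async def deactivate_models(session: AsyncSession, tenant_id: str, product: str) -> None:
    # Executes inside the caller's transaction; no commit here.
    await session.execute(
        text("UPDATE trained_models SET is_active = false "
             "WHERE tenant_id = :t AND product_name = :p"),
        {"t": tenant_id, "p": product},
    )

async def store_new_model(session_factory, tenant_id: str, product: str, record) -> None:
    async with session_factory() as session:      # one session...
        async with session.begin():               # ...one transaction
            await deactivate_models(session, tenant_id, product)
            session.add(record)                   # new model row joins the same transaction
        # session.begin() commits on success and rolls back on any exception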
# Keep all existing methods unchanged
async def generate_forecast(self,

File diff suppressed because it is too large

View File

@@ -6,7 +6,7 @@ Database models for training service
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, JSON, Float
from sqlalchemy.dialects.postgresql import UUID, ARRAY
from shared.database.base import Base
from datetime import datetime
from datetime import datetime, timezone
import uuid
@@ -25,8 +25,8 @@ class ModelTrainingLog(Base):
current_step = Column(String(500), default="")
# Timestamps
start_time = Column(DateTime, default=datetime.now)
end_time = Column(DateTime, nullable=True)
start_time = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
end_time = Column(DateTime(timezone=True), nullable=True)
# Configuration and results
config = Column(JSON, nullable=True) # Training job configuration
@@ -34,8 +34,8 @@ class ModelTrainingLog(Base):
error_message = Column(Text, nullable=True)
# Metadata
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
class ModelPerformanceMetric(Base):
"""
@@ -65,8 +65,8 @@ class ModelPerformanceMetric(Base):
evaluation_samples = Column(Integer, nullable=True)
# Metadata
measured_at = Column(DateTime, default=datetime.now)
created_at = Column(DateTime, default=datetime.now)
measured_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
class TrainingJobQueue(Base):
"""
@@ -94,8 +94,8 @@ class TrainingJobQueue(Base):
max_retries = Column(Integer, default=3)
# Metadata
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
cancelled_by = Column(String, nullable=True)
class ModelArtifact(Base):
@@ -119,15 +119,15 @@ class ModelArtifact(Base):
compression = Column(String(50), nullable=True) # gzip, lz4, etc.
# Metadata
created_at = Column(DateTime, default=datetime.now)
expires_at = Column(DateTime, nullable=True) # For automatic cleanup
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
expires_at = Column(DateTime(timezone=True), nullable=True) # For automatic cleanup
class TrainedModel(Base):
__tablename__ = "trained_models"
# Primary identification
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
tenant_id = Column(String, nullable=False, index=True)
# Primary identification - Updated to use UUID properly
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
product_name = Column(String, nullable=False, index=True)
# Model information
@@ -154,13 +154,14 @@ class TrainedModel(Base):
is_active = Column(Boolean, default=True)
is_production = Column(Boolean, default=False)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
last_used_at = Column(DateTime)
# Timestamps - Updated to be timezone-aware with proper defaults
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
last_used_at = Column(DateTime(timezone=True))
# Training data info
training_start_date = Column(DateTime)
training_end_date = Column(DateTime)
training_start_date = Column(DateTime(timezone=True))
training_end_date = Column(DateTime(timezone=True))
data_quality_score = Column(Float)
# Additional metadata
@@ -169,9 +170,9 @@ class TrainedModel(Base):
def to_dict(self):
return {
"id": self.id,
"model_id": self.id,
"tenant_id": self.tenant_id,
"id": str(self.id),
"model_id": str(self.id),
"tenant_id": str(self.tenant_id),
"product_name": self.product_name,
"model_type": self.model_type,
"model_version": self.model_version,
@@ -186,6 +187,7 @@ class TrainedModel(Base):
"is_active": self.is_active,
"is_production": self.is_production,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
"training_start_date": self.training_start_date.isoformat() if self.training_start_date else None,
"training_end_date": self.training_end_date.isoformat() if self.training_end_date else None,

View File

@@ -1,80 +1,11 @@
# services/training/app/models/training_models.py
"""
Database models for trained ML models
Legacy file - TrainedModel has been moved to training.py
This file is deprecated and should be removed after migration.
"""
from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Boolean, JSON
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
import uuid
# Import the actual model from the correct location
from .training import TrainedModel
Base = declarative_base()
class TrainedModel(Base):
__tablename__ = "trained_models"
# Primary identification
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
tenant_id = Column(String, nullable=False, index=True)
product_name = Column(String, nullable=False, index=True)
# Model information
model_type = Column(String, default="prophet_optimized")
model_version = Column(String, default="1.0")
job_id = Column(String, nullable=False)
# File storage
model_path = Column(String, nullable=False) # Path to the .pkl file
metadata_path = Column(String) # Path to metadata JSON
# Training metrics
mape = Column(Float)
mae = Column(Float)
rmse = Column(Float)
r2_score = Column(Float)
training_samples = Column(Integer)
# Hyperparameters and features
hyperparameters = Column(JSON) # Store optimized parameters
features_used = Column(JSON) # List of regressor columns
# Model status
is_active = Column(Boolean, default=True)
is_production = Column(Boolean, default=False)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
last_used_at = Column(DateTime)
# Training data info
training_start_date = Column(DateTime)
training_end_date = Column(DateTime)
data_quality_score = Column(Float)
# Additional metadata
notes = Column(Text)
created_by = Column(String) # User who triggered training
def to_dict(self):
return {
"id": self.id,
"tenant_id": self.tenant_id,
"product_name": self.product_name,
"model_type": self.model_type,
"model_version": self.model_version,
"model_path": self.model_path,
"mape": self.mape,
"mae": self.mae,
"rmse": self.rmse,
"r2_score": self.r2_score,
"training_samples": self.training_samples,
"hyperparameters": self.hyperparameters,
"features_used": self.features_used,
"is_active": self.is_active,
"is_production": self.is_production,
"created_at": self.created_at.isoformat() if self.created_at else None,
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
"training_start_date": self.training_start_date.isoformat() if self.training_start_date else None,
"training_end_date": self.training_end_date.isoformat() if self.training_end_date else None,
"data_quality_score": self.data_quality_score
}
# For backward compatibility, re-export the model
__all__ = ["TrainedModel"]

View File

@@ -0,0 +1,20 @@
"""
Training Service Repositories
Repository implementations for training service
"""
from .base import TrainingBaseRepository
from .model_repository import ModelRepository
from .training_log_repository import TrainingLogRepository
from .performance_repository import PerformanceRepository
from .job_queue_repository import JobQueueRepository
from .artifact_repository import ArtifactRepository
__all__ = [
"TrainingBaseRepository",
"ModelRepository",
"TrainingLogRepository",
"PerformanceRepository",
"JobQueueRepository",
"ArtifactRepository"
]
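The package only re-exports the repository classes; each one is constructed directly with an AsyncSession, typically handed in through the service's dependencies. A minimal wiring sketch (an assumption, not part of this commit), taking the package to be importable as app.repositories and using an illustrative connection string:

from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker

from app.repositories import ModelRepository

engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/training")  # illustrative DSN
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)

async def list_tenant_models(tenant_id: str) -> list:
    # one session per unit of work; repositories share it
    async with async_session_factory() as session:
        repo = ModelRepository(session)
        records = await repo.get_models_by_tenant(tenant_id, limit=20)
        return [model.to_dict() for model in records]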

View File

@@ -0,0 +1,433 @@
"""
Artifact Repository
Repository for model artifact operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelArtifact
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class ArtifactRepository(TrainingBaseRepository):
"""Repository for model artifact operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 1800):
# Artifacts are stable, longer cache time (30 minutes)
super().__init__(ModelArtifact, session, cache_ttl)
async def create_artifact(self, artifact_data: Dict[str, Any]) -> ModelArtifact:
"""Create a new model artifact record"""
try:
# Validate artifact data
validation_result = self._validate_training_data(
artifact_data,
["model_id", "tenant_id", "artifact_type", "file_path"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid artifact data: {validation_result['errors']}")
# Set default values
if "storage_location" not in artifact_data:
artifact_data["storage_location"] = "local"
# Create artifact record
artifact = await self.create(artifact_data)
logger.info("Model artifact created",
model_id=artifact.model_id,
tenant_id=artifact.tenant_id,
artifact_type=artifact.artifact_type,
file_path=artifact.file_path)
return artifact
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create model artifact",
model_id=artifact_data.get("model_id"),
error=str(e))
raise DatabaseError(f"Failed to create artifact: {str(e)}")
async def get_artifacts_by_model(
self,
model_id: str,
artifact_type: str = None
) -> List[ModelArtifact]:
"""Get all artifacts for a model"""
try:
filters = {"model_id": model_id}
if artifact_type:
filters["artifact_type"] = artifact_type
return await self.get_multi(
filters=filters,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get artifacts by model",
model_id=model_id,
artifact_type=artifact_type,
error=str(e))
raise DatabaseError(f"Failed to get artifacts: {str(e)}")
async def get_artifacts_by_tenant(
self,
tenant_id: str,
artifact_type: str = None,
skip: int = 0,
limit: int = 100
) -> List[ModelArtifact]:
"""Get artifacts for a tenant"""
try:
filters = {"tenant_id": tenant_id}
if artifact_type:
filters["artifact_type"] = artifact_type
return await self.get_multi(
filters=filters,
skip=skip,
limit=limit,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get artifacts by tenant",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get tenant artifacts: {str(e)}")
async def get_artifact_by_path(self, file_path: str) -> Optional[ModelArtifact]:
"""Get artifact by file path"""
try:
return await self.get_by_field("file_path", file_path)
except Exception as e:
logger.error("Failed to get artifact by path",
file_path=file_path,
error=str(e))
raise DatabaseError(f"Failed to get artifact: {str(e)}")
async def update_artifact_size(self, artifact_id: int, file_size_bytes: int) -> Optional[ModelArtifact]:
"""Update artifact file size"""
try:
return await self.update(artifact_id, {"file_size_bytes": file_size_bytes})
except Exception as e:
logger.error("Failed to update artifact size",
artifact_id=artifact_id,
error=str(e))
return None
async def update_artifact_checksum(self, artifact_id: int, checksum: str) -> Optional[ModelArtifact]:
"""Update artifact checksum for integrity verification"""
try:
return await self.update(artifact_id, {"checksum": checksum})
except Exception as e:
logger.error("Failed to update artifact checksum",
artifact_id=artifact_id,
error=str(e))
return None
async def mark_artifact_expired(self, artifact_id: int, expires_at: datetime = None) -> Optional[ModelArtifact]:
"""Mark artifact for expiration/cleanup"""
try:
if not expires_at:
expires_at = datetime.now()
return await self.update(artifact_id, {"expires_at": expires_at})
except Exception as e:
logger.error("Failed to mark artifact as expired",
artifact_id=artifact_id,
error=str(e))
return None
async def get_expired_artifacts(self, days_expired: int = 0) -> List[ModelArtifact]:
"""Get artifacts that have expired"""
try:
cutoff_date = datetime.now() - timedelta(days=days_expired)
query_text = """
SELECT * FROM model_artifacts
WHERE expires_at IS NOT NULL
AND expires_at <= :cutoff_date
ORDER BY expires_at ASC
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
expired_artifacts = []
for row in result.fetchall():
record_dict = dict(row._mapping)
artifact = self.model(**record_dict)
expired_artifacts.append(artifact)
return expired_artifacts
except Exception as e:
logger.error("Failed to get expired artifacts",
days_expired=days_expired,
error=str(e))
return []
async def cleanup_expired_artifacts(self, days_expired: int = 0) -> int:
"""Clean up expired artifacts"""
try:
cutoff_date = datetime.now() - timedelta(days=days_expired)
query_text = """
DELETE FROM model_artifacts
WHERE expires_at IS NOT NULL
AND expires_at <= :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info("Cleaned up expired artifacts",
deleted_count=deleted_count,
days_expired=days_expired)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup expired artifacts",
days_expired=days_expired,
error=str(e))
raise DatabaseError(f"Artifact cleanup failed: {str(e)}")
async def get_large_artifacts(self, min_size_mb: int = 100) -> List[ModelArtifact]:
"""Get artifacts larger than specified size"""
try:
min_size_bytes = min_size_mb * 1024 * 1024 # Convert MB to bytes
query_text = """
SELECT * FROM model_artifacts
WHERE file_size_bytes >= :min_size_bytes
ORDER BY file_size_bytes DESC
"""
result = await self.session.execute(text(query_text), {"min_size_bytes": min_size_bytes})
large_artifacts = []
for row in result.fetchall():
record_dict = dict(row._mapping)
artifact = self.model(**record_dict)
large_artifacts.append(artifact)
return large_artifacts
except Exception as e:
logger.error("Failed to get large artifacts",
min_size_mb=min_size_mb,
error=str(e))
return []
async def get_artifacts_by_storage_location(
self,
storage_location: str,
tenant_id: str = None
) -> List[ModelArtifact]:
"""Get artifacts by storage location"""
try:
filters = {"storage_location": storage_location}
if tenant_id:
filters["tenant_id"] = tenant_id
return await self.get_multi(
filters=filters,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get artifacts by storage location",
storage_location=storage_location,
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get artifacts: {str(e)}")
async def get_artifact_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get artifact statistics"""
try:
base_filters = {}
if tenant_id:
base_filters["tenant_id"] = tenant_id
# Get basic counts
total_artifacts = await self.count(filters=base_filters)
# Get artifacts by type
type_query_params = {}
type_query_filter = ""
if tenant_id:
type_query_filter = "WHERE tenant_id = :tenant_id"
type_query_params["tenant_id"] = tenant_id
type_query = text(f"""
SELECT artifact_type, COUNT(*) as count
FROM model_artifacts
{type_query_filter}
GROUP BY artifact_type
ORDER BY count DESC
""")
result = await self.session.execute(type_query, type_query_params)
artifacts_by_type = {row.artifact_type: row.count for row in result.fetchall()}
# Get storage location stats
location_query = text(f"""
SELECT
storage_location,
COUNT(*) as count,
SUM(COALESCE(file_size_bytes, 0)) as total_size_bytes
FROM model_artifacts
{type_query_filter}
GROUP BY storage_location
ORDER BY count DESC
""")
location_result = await self.session.execute(location_query, type_query_params)
storage_stats = {}
total_size_bytes = 0
for row in location_result.fetchall():
storage_stats[row.storage_location] = {
"artifact_count": row.count,
"total_size_bytes": int(row.total_size_bytes or 0),
"total_size_mb": round((row.total_size_bytes or 0) / (1024 * 1024), 2)
}
total_size_bytes += row.total_size_bytes or 0
# Get expired artifacts count
expired_artifacts = len(await self.get_expired_artifacts())
return {
"total_artifacts": total_artifacts,
"expired_artifacts": expired_artifacts,
"active_artifacts": total_artifacts - expired_artifacts,
"artifacts_by_type": artifacts_by_type,
"storage_statistics": storage_stats,
"total_storage": {
"total_size_bytes": total_size_bytes,
"total_size_mb": round(total_size_bytes / (1024 * 1024), 2),
"total_size_gb": round(total_size_bytes / (1024 * 1024 * 1024), 2)
}
}
except Exception as e:
logger.error("Failed to get artifact statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_artifacts": 0,
"expired_artifacts": 0,
"active_artifacts": 0,
"artifacts_by_type": {},
"storage_statistics": {},
"total_storage": {
"total_size_bytes": 0,
"total_size_mb": 0.0,
"total_size_gb": 0.0
}
}
async def verify_artifact_integrity(self, artifact_id: int) -> Dict[str, Any]:
"""Verify artifact file integrity (placeholder for file system checks)"""
try:
artifact = await self.get_by_id(artifact_id)
if not artifact:
return {"exists": False, "error": "Artifact not found"}
# This is a placeholder - in a real implementation, you would:
# 1. Check if the file exists at artifact.file_path
# 2. Calculate current checksum and compare with stored checksum
# 3. Verify file size matches stored file_size_bytes
return {
"artifact_id": artifact_id,
"file_path": artifact.file_path,
"exists": True, # Would check actual file existence
"checksum_valid": True, # Would verify actual checksum
"size_valid": True, # Would verify actual file size
"storage_location": artifact.storage_location,
"last_verified": datetime.now().isoformat()
}
except Exception as e:
logger.error("Failed to verify artifact integrity",
artifact_id=artifact_id,
error=str(e))
return {
"exists": False,
"error": f"Verification failed: {str(e)}"
}
async def migrate_artifacts_to_storage(
self,
from_location: str,
to_location: str,
tenant_id: str = None
) -> Dict[str, Any]:
"""Migrate artifacts from one storage location to another (placeholder)"""
try:
# Get artifacts to migrate
artifacts = await self.get_artifacts_by_storage_location(from_location, tenant_id)
migrated_count = 0
failed_count = 0
# This is a placeholder - in a real implementation, you would:
# 1. Copy files from old location to new location
# 2. Update file paths in database
# 3. Verify successful migration
# 4. Clean up old files
for artifact in artifacts:
try:
# Placeholder migration logic
new_file_path = artifact.file_path.replace(from_location, to_location)
await self.update(artifact.id, {
"storage_location": to_location,
"file_path": new_file_path
})
migrated_count += 1
except Exception as migration_error:
logger.error("Failed to migrate artifact",
artifact_id=artifact.id,
error=str(migration_error))
failed_count += 1
logger.info("Artifact migration completed",
from_location=from_location,
to_location=to_location,
migrated_count=migrated_count,
failed_count=failed_count)
return {
"from_location": from_location,
"to_location": to_location,
"total_artifacts": len(artifacts),
"migrated_count": migrated_count,
"failed_count": failed_count,
"success_rate": round((migrated_count / len(artifacts)) * 100, 2) if artifacts else 100
}
except Exception as e:
logger.error("Failed to migrate artifacts",
from_location=from_location,
to_location=to_location,
error=str(e))
return {
"error": f"Migration failed: {str(e)}"
}
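verify_artifact_integrity deliberately leaves the file-system checks as a placeholder. A possible companion helper, sketched under the assumption that the stored checksum is a hex SHA-256 digest of the file contents (the column itself does not prescribe an algorithm):

import hashlib
import os
from typing import Optional

def check_artifact_file(file_path: str, expected_checksum: Optional[str], expected_size: Optional[int]) -> dict:
    """Compare an artifact file on disk against its stored checksum and size."""
    if not os.path.isfile(file_path):
        return {"exists": False, "checksum_valid": False, "size_valid": False}

    size_valid = expected_size is None or os.path.getsize(file_path) == expected_size

    digest = hashlib.sha256()
    with open(file_path, "rb") as handle:
        for chunk in iter(lambda: handle.read(1024 * 1024), b""):
            digest.update(chunk)
    checksum_valid = expected_checksum is None or digest.hexdigest() == expected_checksum

    return {"exists": True, "checksum_valid": checksum_valid, "size_valid": size_valid}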

View File

@@ -0,0 +1,179 @@
"""
Base Repository for Training Service
Service-specific repository base class with training service utilities
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta, timezone
import structlog
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class TrainingBaseRepository(BaseRepository):
"""Base repository for training service with common training operations"""
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Training data changes frequently, shorter cache time (5 minutes)
super().__init__(model, session, cache_ttl)
async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
"""Get records by tenant ID"""
if hasattr(self.model, 'tenant_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"tenant_id": tenant_id},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
"""Get active records (if model has is_active field)"""
if hasattr(self.model, 'is_active'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"is_active": True},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def get_by_job_id(self, job_id: str) -> Optional:
"""Get record by job ID (if model has job_id field)"""
if hasattr(self.model, 'job_id'):
return await self.get_by_field("job_id", job_id)
return None
async def get_by_model_id(self, model_id: str) -> Optional:
"""Get record by model ID (if model has model_id field)"""
if hasattr(self.model, 'model_id'):
return await self.get_by_field("model_id", model_id)
return None
async def deactivate_record(self, record_id: Any) -> Optional:
"""Deactivate a record instead of deleting it"""
if hasattr(self.model, 'is_active'):
return await self.update(record_id, {"is_active": False})
return await self.delete(record_id)
async def activate_record(self, record_id: Any) -> Optional:
"""Activate a record"""
if hasattr(self.model, 'is_active'):
return await self.update(record_id, {"is_active": True})
return await self.get_by_id(record_id)
async def cleanup_old_records(self, days_old: int = 90, status_filter: str = None) -> int:
"""Clean up old training records"""
try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
table_name = self.model.__tablename__
# Build query based on available fields
conditions = [f"created_at < :cutoff_date"]
params = {"cutoff_date": cutoff_date}
if status_filter and hasattr(self.model, 'status'):
conditions.append(f"status = :status")
params["status"] = status_filter
query_text = f"""
DELETE FROM {table_name}
WHERE {' AND '.join(conditions)}
"""
result = await self.session.execute(text(query_text), params)
deleted_count = result.rowcount
logger.info(f"Cleaned up old {self.model.__name__} records",
deleted_count=deleted_count,
days_old=days_old,
status_filter=status_filter)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old records",
model=self.model.__name__,
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
async def get_records_by_date_range(
self,
start_date: datetime,
end_date: datetime,
skip: int = 0,
limit: int = 100
) -> List:
"""Get records within date range"""
if not hasattr(self.model, 'created_at'):
logger.warning(f"Model {self.model.__name__} has no created_at field")
return []
try:
table_name = self.model.__tablename__
query_text = f"""
SELECT * FROM {table_name}
WHERE created_at >= :start_date
AND created_at <= :end_date
ORDER BY created_at DESC
LIMIT :limit OFFSET :skip
"""
result = await self.session.execute(text(query_text), {
"start_date": start_date,
"end_date": end_date,
"limit": limit,
"skip": skip
})
# Convert rows to model objects
records = []
for row in result.fetchall():
# Create model instance from row data
record_dict = dict(row._mapping)
record = self.model(**record_dict)
records.append(record)
return records
except Exception as e:
logger.error("Failed to get records by date range",
model=self.model.__name__,
start_date=start_date,
end_date=end_date,
error=str(e))
raise DatabaseError(f"Date range query failed: {str(e)}")
def _validate_training_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
"""Validate training-related data"""
errors = []
for field in required_fields:
if field not in data or not data[field]:
errors.append(f"Missing required field: {field}")
# Validate tenant_id format if present
if "tenant_id" in data and data["tenant_id"]:
tenant_id = data["tenant_id"]
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
errors.append("Invalid tenant_id format")
# Validate job_id format if present
if "job_id" in data and data["job_id"]:
job_id = data["job_id"]
if not isinstance(job_id, str) or len(job_id) < 1:
errors.append("Invalid job_id format")
return {
"is_valid": len(errors) == 0,
"errors": errors
}
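The concrete repositories that follow all use the same recipe: pick a model, pick a cache TTL, and inherit the tenant, cleanup and validation helpers. A toy sketch of that pattern, assuming the base class is importable as app.repositories.base:

from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.training import ModelTrainingLog
from app.repositories.base import TrainingBaseRepository

class ExampleLogRepository(TrainingBaseRepository):
    """Toy subclass - everything used below is inherited from the base."""
    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 120):
        super().__init__(ModelTrainingLog, session, cache_ttl)

async def prune_failed_logs(repo: ExampleLogRepository) -> int:
    # Inherited helper: deletes failed logs older than 30 days
    return await repo.cleanup_old_records(days_old=30, status_filter="failed")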

View File

@@ -0,0 +1,445 @@
"""
Job Queue Repository
Repository for training job queue operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc, bindparam
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import TrainingJobQueue
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class JobQueueRepository(TrainingBaseRepository):
"""Repository for training job queue operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 60):
# Job queue changes frequently, very short cache time (1 minute)
super().__init__(TrainingJobQueue, session, cache_ttl)
async def enqueue_job(self, job_data: Dict[str, Any]) -> TrainingJobQueue:
"""Add a job to the training queue"""
try:
# Validate job data
validation_result = self._validate_training_data(
job_data,
["job_id", "tenant_id", "job_type"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid job data: {validation_result['errors']}")
# Set default values
if "priority" not in job_data:
job_data["priority"] = 1
if "status" not in job_data:
job_data["status"] = "queued"
if "max_retries" not in job_data:
job_data["max_retries"] = 3
# Create queue entry
queued_job = await self.create(job_data)
logger.info("Job enqueued",
job_id=queued_job.job_id,
tenant_id=queued_job.tenant_id,
job_type=queued_job.job_type,
priority=queued_job.priority)
return queued_job
except ValidationError:
raise
except Exception as e:
logger.error("Failed to enqueue job",
job_id=job_data.get("job_id"),
error=str(e))
raise DatabaseError(f"Failed to enqueue job: {str(e)}")
async def get_next_job(self, job_types: List[str] = None) -> Optional[TrainingJobQueue]:
"""Get the next job to process from the queue"""
try:
# Build filters for job types if specified
filters = {"status": "queued"}
if job_types:
                # Pass the job types through an expanding bind parameter instead of
                # interpolating them into the SQL string
                query = text("""
                SELECT * FROM training_job_queue
                WHERE status = 'queued'
                AND job_type IN :job_types
                AND (scheduled_at IS NULL OR scheduled_at <= :now)
                ORDER BY priority DESC, created_at ASC
                LIMIT 1
                """).bindparams(bindparam("job_types", expanding=True))
                result = await self.session.execute(query, {"job_types": job_types, "now": datetime.now()})
row = result.fetchone()
if row:
record_dict = dict(row._mapping)
return self.model(**record_dict)
return None
else:
# Simple case - get any queued job
jobs = await self.get_multi(
filters=filters,
limit=1,
order_by="priority",
order_desc=True
)
return jobs[0] if jobs else None
except Exception as e:
logger.error("Failed to get next job from queue",
job_types=job_types,
error=str(e))
raise DatabaseError(f"Failed to get next job: {str(e)}")
async def start_job(self, job_id: str) -> Optional[TrainingJobQueue]:
"""Mark a job as started"""
try:
job = await self.get_by_job_id(job_id)
if not job:
logger.error(f"Job not found in queue: {job_id}")
return None
if job.status != "queued":
logger.warning(f"Job {job_id} is not queued (status: {job.status})")
return job
updated_job = await self.update(job.id, {
"status": "running",
"started_at": datetime.now(),
"updated_at": datetime.now()
})
logger.info("Job started",
job_id=job_id,
job_type=job.job_type)
return updated_job
except Exception as e:
logger.error("Failed to start job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to start job: {str(e)}")
async def complete_job(self, job_id: str) -> Optional[TrainingJobQueue]:
"""Mark a job as completed"""
try:
job = await self.get_by_job_id(job_id)
if not job:
logger.error(f"Job not found in queue: {job_id}")
return None
updated_job = await self.update(job.id, {
"status": "completed",
"updated_at": datetime.now()
})
logger.info("Job completed",
job_id=job_id,
job_type=job.job_type if job else "unknown")
return updated_job
except Exception as e:
logger.error("Failed to complete job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to complete job: {str(e)}")
async def fail_job(self, job_id: str, error_message: str = None) -> Optional[TrainingJobQueue]:
"""Mark a job as failed and handle retries"""
try:
job = await self.get_by_job_id(job_id)
if not job:
logger.error(f"Job not found in queue: {job_id}")
return None
# Increment retry count
new_retry_count = job.retry_count + 1
# Check if we should retry
if new_retry_count < job.max_retries:
# Reset to queued for retry
updated_job = await self.update(job.id, {
"status": "queued",
"retry_count": new_retry_count,
"updated_at": datetime.now(),
"started_at": None # Reset started_at for retry
})
logger.info("Job failed, queued for retry",
job_id=job_id,
retry_count=new_retry_count,
max_retries=job.max_retries)
else:
# Mark as permanently failed
updated_job = await self.update(job.id, {
"status": "failed",
"retry_count": new_retry_count,
"updated_at": datetime.now()
})
logger.error("Job permanently failed",
job_id=job_id,
retry_count=new_retry_count,
error_message=error_message)
return updated_job
except Exception as e:
logger.error("Failed to handle job failure",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to handle job failure: {str(e)}")
async def cancel_job(self, job_id: str, cancelled_by: str = None) -> Optional[TrainingJobQueue]:
"""Cancel a job"""
try:
job = await self.get_by_job_id(job_id)
if not job:
logger.error(f"Job not found in queue: {job_id}")
return None
if job.status in ["completed", "failed"]:
logger.warning(f"Cannot cancel job {job_id} with status {job.status}")
return job
updated_job = await self.update(job.id, {
"status": "cancelled",
"cancelled_by": cancelled_by,
"updated_at": datetime.now()
})
logger.info("Job cancelled",
job_id=job_id,
cancelled_by=cancelled_by)
return updated_job
except Exception as e:
logger.error("Failed to cancel job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to cancel job: {str(e)}")
async def get_queue_status(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get queue status and statistics"""
try:
base_filters = {}
if tenant_id:
base_filters["tenant_id"] = tenant_id
# Get counts by status
queued_jobs = await self.count(filters={**base_filters, "status": "queued"})
running_jobs = await self.count(filters={**base_filters, "status": "running"})
completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
cancelled_jobs = await self.count(filters={**base_filters, "status": "cancelled"})
# Get jobs by type
type_query = text(f"""
SELECT job_type, COUNT(*) as count
FROM training_job_queue
WHERE 1=1
{' AND tenant_id = :tenant_id' if tenant_id else ''}
GROUP BY job_type
ORDER BY count DESC
""")
params = {"tenant_id": tenant_id} if tenant_id else {}
result = await self.session.execute(type_query, params)
jobs_by_type = {row.job_type: row.count for row in result.fetchall()}
# Get average wait time for completed jobs
wait_time_query = text(f"""
SELECT
AVG(EXTRACT(EPOCH FROM (started_at - created_at))/60) as avg_wait_minutes
FROM training_job_queue
WHERE status = 'completed'
AND started_at IS NOT NULL
AND created_at IS NOT NULL
{' AND tenant_id = :tenant_id' if tenant_id else ''}
""")
wait_result = await self.session.execute(wait_time_query, params)
wait_row = wait_result.fetchone()
avg_wait_time = float(wait_row.avg_wait_minutes) if wait_row and wait_row.avg_wait_minutes else 0.0
return {
"tenant_id": tenant_id,
"queue_counts": {
"queued": queued_jobs,
"running": running_jobs,
"completed": completed_jobs,
"failed": failed_jobs,
"cancelled": cancelled_jobs,
"total": queued_jobs + running_jobs + completed_jobs + failed_jobs + cancelled_jobs
},
"jobs_by_type": jobs_by_type,
"avg_wait_time_minutes": round(avg_wait_time, 2),
"queue_health": {
"has_queued_jobs": queued_jobs > 0,
"has_running_jobs": running_jobs > 0,
"failure_rate": round((failed_jobs / max(completed_jobs + failed_jobs, 1)) * 100, 2)
}
}
except Exception as e:
logger.error("Failed to get queue status",
tenant_id=tenant_id,
error=str(e))
return {
"tenant_id": tenant_id,
"queue_counts": {
"queued": 0, "running": 0, "completed": 0,
"failed": 0, "cancelled": 0, "total": 0
},
"jobs_by_type": {},
"avg_wait_time_minutes": 0.0,
"queue_health": {
"has_queued_jobs": False,
"has_running_jobs": False,
"failure_rate": 0.0
}
}
async def get_jobs_by_tenant(
self,
tenant_id: str,
status: str = None,
job_type: str = None,
skip: int = 0,
limit: int = 100
) -> List[TrainingJobQueue]:
"""Get jobs for a tenant with optional filtering"""
try:
filters = {"tenant_id": tenant_id}
if status:
filters["status"] = status
if job_type:
filters["job_type"] = job_type
return await self.get_multi(
filters=filters,
skip=skip,
limit=limit,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get jobs by tenant",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get tenant jobs: {str(e)}")
async def cleanup_old_jobs(self, days_old: int = 30, status_filter: str = None) -> int:
"""Clean up old completed/failed/cancelled jobs"""
try:
cutoff_date = datetime.now() - timedelta(days=days_old)
# Only clean up finished jobs by default
default_statuses = ["completed", "failed", "cancelled"]
if status_filter:
status_condition = "status = :status"
params = {"cutoff_date": cutoff_date, "status": status_filter}
else:
status_list = "', '".join(default_statuses)
status_condition = f"status IN ('{status_list}')"
params = {"cutoff_date": cutoff_date}
query_text = f"""
DELETE FROM training_job_queue
WHERE created_at < :cutoff_date
AND {status_condition}
"""
result = await self.session.execute(text(query_text), params)
deleted_count = result.rowcount
logger.info("Cleaned up old queue jobs",
deleted_count=deleted_count,
days_old=days_old,
status_filter=status_filter)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old queue jobs",
error=str(e))
raise DatabaseError(f"Queue cleanup failed: {str(e)}")
async def get_stuck_jobs(self, hours_stuck: int = 2) -> List[TrainingJobQueue]:
"""Get jobs that have been running for too long"""
try:
cutoff_time = datetime.now() - timedelta(hours=hours_stuck)
query_text = """
SELECT * FROM training_job_queue
WHERE status = 'running'
AND started_at IS NOT NULL
AND started_at < :cutoff_time
ORDER BY started_at ASC
"""
result = await self.session.execute(text(query_text), {"cutoff_time": cutoff_time})
stuck_jobs = []
for row in result.fetchall():
record_dict = dict(row._mapping)
job = self.model(**record_dict)
stuck_jobs.append(job)
if stuck_jobs:
logger.warning("Found stuck jobs",
count=len(stuck_jobs),
hours_stuck=hours_stuck)
return stuck_jobs
except Exception as e:
logger.error("Failed to get stuck jobs",
hours_stuck=hours_stuck,
error=str(e))
return []
async def reset_stuck_jobs(self, hours_stuck: int = 2) -> int:
"""Reset stuck jobs back to queued status"""
try:
stuck_jobs = await self.get_stuck_jobs(hours_stuck)
reset_count = 0
for job in stuck_jobs:
# Reset job to queued status
await self.update(job.id, {
"status": "queued",
"started_at": None,
"updated_at": datetime.now()
})
reset_count += 1
if reset_count > 0:
logger.info("Reset stuck jobs",
reset_count=reset_count,
hours_stuck=hours_stuck)
return reset_count
except Exception as e:
logger.error("Failed to reset stuck jobs",
hours_stuck=hours_stuck,
error=str(e))
raise DatabaseError(f"Failed to reset stuck jobs: {str(e)}")

View File

@@ -0,0 +1,346 @@
"""
Model Repository
Repository for trained model operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta, timezone
import structlog
from .base import TrainingBaseRepository
from app.models.training import TrainedModel
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
logger = structlog.get_logger()
class ModelRepository(TrainingBaseRepository):
"""Repository for trained model operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Models are relatively stable, longer cache time (10 minutes)
super().__init__(TrainedModel, session, cache_ttl)
async def create_model(self, model_data: Dict[str, Any]) -> TrainedModel:
"""Create a new trained model with validation"""
try:
# Validate model data
validation_result = self._validate_training_data(
model_data,
["tenant_id", "product_name", "model_path", "job_id"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid model data: {validation_result['errors']}")
# Check for duplicate active models for same tenant+product
existing_model = await self.get_active_model_for_product(
model_data["tenant_id"],
model_data["product_name"]
)
# If there's an existing active model, we may want to deactivate it
if existing_model and model_data.get("is_production", False):
logger.info("Deactivating previous production model",
previous_model_id=existing_model.id,
tenant_id=model_data["tenant_id"],
product_name=model_data["product_name"])
await self.update(existing_model.id, {"is_production": False})
# Create new model
model = await self.create(model_data)
logger.info("Trained model created successfully",
model_id=model.id,
tenant_id=model.tenant_id,
product_name=model.product_name,
model_type=model.model_type)
return model
except (ValidationError, DuplicateRecordError):
raise
except Exception as e:
logger.error("Failed to create trained model",
tenant_id=model_data.get("tenant_id"),
product_name=model_data.get("product_name"),
error=str(e))
raise DatabaseError(f"Failed to create model: {str(e)}")
async def get_model_by_tenant_and_product(
self,
tenant_id: str,
product_name: str
) -> List[TrainedModel]:
"""Get all models for a tenant and product"""
try:
return await self.get_multi(
filters={
"tenant_id": tenant_id,
"product_name": product_name
},
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get models by tenant and product",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
raise DatabaseError(f"Failed to get models: {str(e)}")
async def get_active_model_for_product(
self,
tenant_id: str,
product_name: str
) -> Optional[TrainedModel]:
"""Get the active production model for a product"""
try:
models = await self.get_multi(
filters={
"tenant_id": tenant_id,
"product_name": product_name,
"is_active": True,
"is_production": True
},
order_by="created_at",
order_desc=True,
limit=1
)
return models[0] if models else None
except Exception as e:
logger.error("Failed to get active model for product",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
raise DatabaseError(f"Failed to get active model: {str(e)}")
async def get_models_by_tenant(
self,
tenant_id: str,
skip: int = 0,
limit: int = 100
) -> List[TrainedModel]:
"""Get all models for a tenant"""
return await self.get_by_tenant_id(tenant_id, skip=skip, limit=limit)
async def promote_to_production(self, model_id: str) -> Optional[TrainedModel]:
"""Promote a model to production"""
try:
# Get the model first
model = await self.get_by_id(model_id)
if not model:
raise ValueError(f"Model {model_id} not found")
# Deactivate other production models for the same tenant+product
await self._deactivate_other_production_models(
model.tenant_id,
model.product_name,
model_id
)
# Promote this model
updated_model = await self.update(model_id, {
"is_production": True,
"last_used_at": datetime.utcnow()
})
logger.info("Model promoted to production",
model_id=model_id,
tenant_id=model.tenant_id,
product_name=model.product_name)
return updated_model
except Exception as e:
logger.error("Failed to promote model to production",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to promote model: {str(e)}")
async def update_model_usage(self, model_id: str) -> Optional[TrainedModel]:
"""Update model last used timestamp"""
try:
return await self.update(model_id, {
"last_used_at": datetime.utcnow()
})
except Exception as e:
logger.error("Failed to update model usage",
model_id=model_id,
error=str(e))
# Don't raise here - usage update is not critical
return None
async def archive_old_models(self, tenant_id: str, days_old: int = 90) -> int:
"""Archive old non-production models"""
try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
query = text("""
UPDATE trained_models
SET is_active = false
WHERE tenant_id = :tenant_id
AND is_production = false
AND created_at < :cutoff_date
AND is_active = true
""")
result = await self.session.execute(query, {
"tenant_id": tenant_id,
"cutoff_date": cutoff_date
})
archived_count = result.rowcount
logger.info("Archived old models",
tenant_id=tenant_id,
archived_count=archived_count,
days_old=days_old)
return archived_count
except Exception as e:
logger.error("Failed to archive old models",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Model archival failed: {str(e)}")
async def get_model_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get model statistics for a tenant"""
try:
# Get basic counts
total_models = await self.count(filters={"tenant_id": tenant_id})
active_models = await self.count(filters={
"tenant_id": tenant_id,
"is_active": True
})
production_models = await self.count(filters={
"tenant_id": tenant_id,
"is_production": True
})
# Get models by product using raw query
product_query = text("""
SELECT product_name, COUNT(*) as count
FROM trained_models
WHERE tenant_id = :tenant_id
AND is_active = true
GROUP BY product_name
ORDER BY count DESC
""")
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
product_stats = {row.product_name: row.count for row in result.fetchall()}
# Recent activity (models created in last 30 days)
            thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
recent_models_query = text("""
SELECT COUNT(*) as count
FROM trained_models
WHERE tenant_id = :tenant_id
AND created_at >= :thirty_days_ago
""")
recent_result = await self.session.execute(
recent_models_query,
{"tenant_id": tenant_id, "thirty_days_ago": thirty_days_ago}
)
recent_models = recent_result.scalar() or 0
return {
"total_models": total_models,
"active_models": active_models,
"inactive_models": total_models - active_models,
"production_models": production_models,
"models_by_product": product_stats,
"recent_models_30d": recent_models
}
except Exception as e:
logger.error("Failed to get model statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_models": 0,
"active_models": 0,
"inactive_models": 0,
"production_models": 0,
"models_by_product": {},
"recent_models_30d": 0
}
async def _deactivate_other_production_models(
self,
tenant_id: str,
product_name: str,
exclude_model_id: str
) -> int:
"""Deactivate other production models for the same tenant+product"""
try:
query = text("""
UPDATE trained_models
SET is_production = false
WHERE tenant_id = :tenant_id
AND product_name = :product_name
AND id != :exclude_model_id
AND is_production = true
""")
result = await self.session.execute(query, {
"tenant_id": tenant_id,
"product_name": product_name,
"exclude_model_id": exclude_model_id
})
return result.rowcount
except Exception as e:
logger.error("Failed to deactivate other production models",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
raise DatabaseError(f"Failed to deactivate models: {str(e)}")
async def get_model_performance_summary(self, model_id: str) -> Dict[str, Any]:
"""Get performance summary for a model"""
try:
model = await self.get_by_id(model_id)
if not model:
return {}
return {
"model_id": model.id,
"tenant_id": model.tenant_id,
"product_name": model.product_name,
"model_type": model.model_type,
"metrics": {
"mape": model.mape,
"mae": model.mae,
"rmse": model.rmse,
"r2_score": model.r2_score
},
"training_info": {
"training_samples": model.training_samples,
"training_start_date": model.training_start_date.isoformat() if model.training_start_date else None,
"training_end_date": model.training_end_date.isoformat() if model.training_end_date else None,
"data_quality_score": model.data_quality_score
},
"status": {
"is_active": model.is_active,
"is_production": model.is_production,
"created_at": model.created_at.isoformat() if model.created_at else None,
"last_used_at": model.last_used_at.isoformat() if model.last_used_at else None
},
"features": {
"hyperparameters": model.hyperparameters,
"features_used": model.features_used
}
}
except Exception as e:
logger.error("Failed to get model performance summary",
model_id=model_id,
error=str(e))
return {}
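The repository owns the model lifecycle: create_model registers a freshly trained model, promote_to_production flips the production flag while demoting siblings, and get_active_model_for_product is what serving code queries. A condensed sketch with illustrative values, assuming the repository is already constructed with an AsyncSession and importable as app.repositories.ModelRepository:

from app.repositories import ModelRepository

async def register_and_promote(models: ModelRepository, tenant_id: str, product_name: str, job_id: str):
    created = await models.create_model({
        "tenant_id": tenant_id,
        "product_name": product_name,
        "job_id": job_id,
        "model_path": f"/models/{tenant_id}/{product_name}.pkl",  # illustrative path
        "mape": 12.4,                                             # illustrative metrics
        "r2_score": 0.87,
    })
    await models.promote_to_production(created.id)                # demotes other production models
    return await models.get_active_model_for_product(tenant_id, product_name)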

View File

@@ -0,0 +1,433 @@
"""
Performance Repository
Repository for model performance metrics operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelPerformanceMetric
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PerformanceRepository(TrainingBaseRepository):
"""Repository for model performance metrics operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
# Performance metrics are relatively stable, longer cache time (15 minutes)
super().__init__(ModelPerformanceMetric, session, cache_ttl)
async def create_performance_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
"""Create a new performance metric record"""
try:
# Validate metric data
validation_result = self._validate_training_data(
metric_data,
["model_id", "tenant_id", "product_name"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid metric data: {validation_result['errors']}")
# Set measurement timestamp if not provided
if "measured_at" not in metric_data:
metric_data["measured_at"] = datetime.now()
# Create metric record
metric = await self.create(metric_data)
logger.info("Performance metric created",
model_id=metric.model_id,
tenant_id=metric.tenant_id,
product_name=metric.product_name)
return metric
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create performance metric",
model_id=metric_data.get("model_id"),
error=str(e))
raise DatabaseError(f"Failed to create metric: {str(e)}")
async def get_metrics_by_model(
self,
model_id: str,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get all performance metrics for a model"""
try:
return await self.get_multi(
filters={"model_id": model_id},
skip=skip,
limit=limit,
order_by="measured_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get metrics by model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get metrics: {str(e)}")
async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
"""Get the latest performance metric for a model"""
try:
metrics = await self.get_multi(
filters={"model_id": model_id},
limit=1,
order_by="measured_at",
order_desc=True
)
return metrics[0] if metrics else None
except Exception as e:
logger.error("Failed to get latest metric for model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get latest metric: {str(e)}")
async def get_metrics_by_tenant_and_product(
self,
tenant_id: str,
product_name: str,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get performance metrics for a tenant's product"""
try:
return await self.get_multi(
filters={
"tenant_id": tenant_id,
"product_name": product_name
},
skip=skip,
limit=limit,
order_by="measured_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get metrics by tenant and product",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
raise DatabaseError(f"Failed to get metrics: {str(e)}")
async def get_metrics_in_date_range(
self,
start_date: datetime,
end_date: datetime,
tenant_id: str = None,
model_id: str = None,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get performance metrics within a date range"""
try:
# Build filters
table_name = self.model.__tablename__
conditions = ["measured_at >= :start_date", "measured_at <= :end_date"]
params = {"start_date": start_date, "end_date": end_date, "limit": limit, "skip": skip}
if tenant_id:
conditions.append("tenant_id = :tenant_id")
params["tenant_id"] = tenant_id
if model_id:
conditions.append("model_id = :model_id")
params["model_id"] = model_id
query_text = f"""
SELECT * FROM {table_name}
WHERE {' AND '.join(conditions)}
ORDER BY measured_at DESC
LIMIT :limit OFFSET :skip
"""
result = await self.session.execute(text(query_text), params)
# Convert rows to model objects
metrics = []
for row in result.fetchall():
record_dict = dict(row._mapping)
metric = self.model(**record_dict)
metrics.append(metric)
return metrics
except Exception as e:
logger.error("Failed to get metrics in date range",
start_date=start_date,
end_date=end_date,
error=str(e))
raise DatabaseError(f"Date range query failed: {str(e)}")
async def get_performance_trends(
self,
tenant_id: str,
product_name: str = None,
days: int = 30
) -> Dict[str, Any]:
"""Get performance trends for analysis"""
try:
start_date = datetime.now() - timedelta(days=days)
end_date = datetime.now()
# Build query for performance trends
conditions = ["tenant_id = :tenant_id", "measured_at >= :start_date"]
params = {"tenant_id": tenant_id, "start_date": start_date}
if product_name:
conditions.append("product_name = :product_name")
params["product_name"] = product_name
query_text = f"""
SELECT
product_name,
AVG(mae) as avg_mae,
AVG(mse) as avg_mse,
AVG(rmse) as avg_rmse,
AVG(mape) as avg_mape,
AVG(r2_score) as avg_r2_score,
AVG(accuracy_percentage) as avg_accuracy,
COUNT(*) as measurement_count,
MIN(measured_at) as first_measurement,
MAX(measured_at) as last_measurement
FROM model_performance_metrics
WHERE {' AND '.join(conditions)}
GROUP BY product_name
ORDER BY avg_accuracy DESC
"""
result = await self.session.execute(text(query_text), params)
trends = []
for row in result.fetchall():
trends.append({
"product_name": row.product_name,
"metrics": {
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
"avg_mse": float(row.avg_mse) if row.avg_mse else None,
"avg_rmse": float(row.avg_rmse) if row.avg_rmse else None,
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
"avg_r2_score": float(row.avg_r2_score) if row.avg_r2_score else None,
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
},
"measurement_count": int(row.measurement_count),
"period": {
"start": row.first_measurement.isoformat() if row.first_measurement else None,
"end": row.last_measurement.isoformat() if row.last_measurement else None,
"days": days
}
})
return {
"tenant_id": tenant_id,
"product_name": product_name,
"trends": trends,
"period_days": days,
"total_products": len(trends)
}
except Exception as e:
logger.error("Failed to get performance trends",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
return {
"tenant_id": tenant_id,
"product_name": product_name,
"trends": [],
"period_days": days,
"total_products": 0
}
async def get_best_performing_models(
self,
tenant_id: str,
metric_type: str = "accuracy_percentage",
limit: int = 10
) -> List[Dict[str, Any]]:
"""Get best performing models based on a specific metric"""
try:
# Validate metric type
valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
if metric_type not in valid_metrics:
metric_type = "accuracy_percentage"
# For error metrics (mae, mse, rmse, mape), lower is better
# For performance metrics (r2_score, accuracy_percentage), higher is better
order_desc = metric_type in ["r2_score", "accuracy_percentage"]
order_direction = "DESC" if order_desc else "ASC"
query_text = f"""
SELECT DISTINCT ON (product_name, model_id)
model_id,
product_name,
{metric_type},
measured_at,
evaluation_samples
FROM model_performance_metrics
WHERE tenant_id = :tenant_id
AND {metric_type} IS NOT NULL
ORDER BY product_name, model_id, measured_at DESC, {metric_type} {order_direction}
LIMIT :limit
"""
result = await self.session.execute(text(query_text), {
"tenant_id": tenant_id,
"limit": limit
})
best_models = []
for row in result.fetchall():
best_models.append({
"model_id": row.model_id,
"product_name": row.product_name,
"metric_value": float(getattr(row, metric_type)),
"metric_type": metric_type,
"measured_at": row.measured_at.isoformat() if row.measured_at else None,
"evaluation_samples": int(row.evaluation_samples) if row.evaluation_samples else None
})
return best_models
except Exception as e:
logger.error("Failed to get best performing models",
tenant_id=tenant_id,
metric_type=metric_type,
error=str(e))
return []
async def cleanup_old_metrics(self, days_old: int = 180) -> int:
"""Clean up old performance metrics"""
return await self.cleanup_old_records(days_old=days_old)
async def get_metric_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get performance metric statistics for a tenant"""
try:
# Get basic counts
total_metrics = await self.count(filters={"tenant_id": tenant_id})
# Get metrics by product using raw query
product_query = text("""
SELECT
product_name,
COUNT(*) as metric_count,
AVG(accuracy_percentage) as avg_accuracy
FROM model_performance_metrics
WHERE tenant_id = :tenant_id
GROUP BY product_name
ORDER BY avg_accuracy DESC
""")
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
product_stats = {}
for row in result.fetchall():
product_stats[row.product_name] = {
"metric_count": row.metric_count,
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
}
# Recent activity (metrics in last 7 days)
seven_days_ago = datetime.now() - timedelta(days=7)
recent_metrics = len(await self.get_records_by_date_range(
seven_days_ago,
datetime.now(),
limit=1000 # High limit to get accurate count
))
return {
"total_metrics": total_metrics,
"products_tracked": len(product_stats),
"metrics_by_product": product_stats,
"recent_metrics_7d": recent_metrics
}
except Exception as e:
logger.error("Failed to get metric statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_metrics": 0,
"products_tracked": 0,
"metrics_by_product": {},
"recent_metrics_7d": 0
}
async def compare_model_performance(
self,
model_ids: List[str],
metric_type: str = "accuracy_percentage"
) -> Dict[str, Any]:
"""Compare performance between multiple models"""
try:
if not model_ids or len(model_ids) < 2:
return {"error": "At least 2 model IDs required for comparison"}
# Validate metric type
valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
if metric_type not in valid_metrics:
metric_type = "accuracy_percentage"
model_ids_str = "', '".join(model_ids)
query_text = f"""
SELECT
model_id,
product_name,
AVG({metric_type}) as avg_metric,
MIN({metric_type}) as min_metric,
MAX({metric_type}) as max_metric,
COUNT(*) as measurement_count,
MAX(measured_at) as latest_measurement
FROM model_performance_metrics
WHERE model_id IN ('{model_ids_str}')
AND {metric_type} IS NOT NULL
GROUP BY model_id, product_name
ORDER BY avg_metric DESC
"""
result = await self.session.execute(text(query_text))
comparisons = []
for row in result.fetchall():
comparisons.append({
"model_id": row.model_id,
"product_name": row.product_name,
"avg_metric": float(row.avg_metric),
"min_metric": float(row.min_metric),
"max_metric": float(row.max_metric),
"measurement_count": int(row.measurement_count),
"latest_measurement": row.latest_measurement.isoformat() if row.latest_measurement else None
})
# Find best and worst performing models
if comparisons:
best_model = max(comparisons, key=lambda x: x["avg_metric"])
worst_model = min(comparisons, key=lambda x: x["avg_metric"])
else:
best_model = worst_model = None
return {
"metric_type": metric_type,
"models_compared": len(set(comp["model_id"] for comp in comparisons)),
"comparisons": comparisons,
"best_performing": best_model,
"worst_performing": worst_model
}
except Exception as e:
logger.error("Failed to compare model performance",
model_ids=model_ids,
metric_type=metric_type,
error=str(e))
return {"error": f"Comparison failed: {str(e)}"}

View File

@@ -0,0 +1,332 @@
"""
Training Log Repository
Repository for model training log operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelTrainingLog
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class TrainingLogRepository(TrainingBaseRepository):
"""Repository for training log operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Training logs change frequently, shorter cache time (5 minutes)
super().__init__(ModelTrainingLog, session, cache_ttl)
async def create_training_log(self, log_data: Dict[str, Any]) -> ModelTrainingLog:
"""Create a new training log entry"""
try:
# Validate log data
validation_result = self._validate_training_data(
log_data,
["job_id", "tenant_id", "status"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid training log data: {validation_result['errors']}")
# Set default values
if "progress" not in log_data:
log_data["progress"] = 0
if "current_step" not in log_data:
log_data["current_step"] = "initializing"
# Create log entry
log_entry = await self.create(log_data)
logger.info("Training log created",
job_id=log_entry.job_id,
tenant_id=log_entry.tenant_id,
status=log_entry.status)
return log_entry
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create training log",
job_id=log_data.get("job_id"),
error=str(e))
raise DatabaseError(f"Failed to create training log: {str(e)}")
async def get_log_by_job_id(self, job_id: str) -> Optional[ModelTrainingLog]:
"""Get training log by job ID"""
return await self.get_by_job_id(job_id)
async def update_log_progress(
self,
job_id: str,
progress: int,
current_step: str = None,
status: str = None
) -> Optional[ModelTrainingLog]:
"""Update training log progress"""
try:
update_data = {"progress": progress, "updated_at": datetime.now()}
if current_step:
update_data["current_step"] = current_step
if status:
update_data["status"] = status
log_entry = await self.get_by_job_id(job_id)
if not log_entry:
logger.error(f"Training log not found for job {job_id}")
return None
updated_log = await self.update(log_entry.id, update_data)
logger.debug("Training log progress updated",
job_id=job_id,
progress=progress,
step=current_step)
return updated_log
except Exception as e:
logger.error("Failed to update training log progress",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to update progress: {str(e)}")
async def complete_training_log(
self,
job_id: str,
results: Dict[str, Any] = None,
error_message: str = None
) -> Optional[ModelTrainingLog]:
"""Mark training log as completed or failed"""
try:
status = "failed" if error_message else "completed"
            update_data = {
                "status": status,
                "end_time": datetime.now(),
                "updated_at": datetime.now()
            }
            # Only force progress to 100 on success; keep the last reported value on failure
            if status == "completed":
                update_data["progress"] = 100
if results:
update_data["results"] = results
if error_message:
update_data["error_message"] = error_message
log_entry = await self.get_by_job_id(job_id)
if not log_entry:
logger.error(f"Training log not found for job {job_id}")
return None
updated_log = await self.update(log_entry.id, update_data)
logger.info("Training log completed",
job_id=job_id,
status=status,
has_results=bool(results))
return updated_log
except Exception as e:
logger.error("Failed to complete training log",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to complete training log: {str(e)}")
async def get_logs_by_tenant(
self,
tenant_id: str,
status: str = None,
skip: int = 0,
limit: int = 100
) -> List[ModelTrainingLog]:
"""Get training logs for a tenant"""
try:
filters = {"tenant_id": tenant_id}
if status:
filters["status"] = status
return await self.get_multi(
filters=filters,
skip=skip,
limit=limit,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get logs by tenant",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get training logs: {str(e)}")
async def get_active_jobs(self, tenant_id: str = None) -> List[ModelTrainingLog]:
"""Get currently running training jobs"""
try:
filters = {"status": "running"}
if tenant_id:
filters["tenant_id"] = tenant_id
return await self.get_multi(
filters=filters,
order_by="start_time",
order_desc=True
)
except Exception as e:
logger.error("Failed to get active jobs",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get active jobs: {str(e)}")
async def cancel_job(self, job_id: str, cancelled_by: str = None) -> Optional[ModelTrainingLog]:
"""Cancel a training job"""
try:
update_data = {
"status": "cancelled",
"end_time": datetime.now(),
"updated_at": datetime.now()
}
if cancelled_by:
update_data["error_message"] = f"Cancelled by {cancelled_by}"
log_entry = await self.get_by_job_id(job_id)
if not log_entry:
logger.error(f"Training log not found for job {job_id}")
return None
            # Only cancel if the job is still pending or running
if log_entry.status not in ["pending", "running"]:
logger.warning(f"Cannot cancel job {job_id} with status {log_entry.status}")
return log_entry
updated_log = await self.update(log_entry.id, update_data)
logger.info("Training job cancelled",
job_id=job_id,
cancelled_by=cancelled_by)
return updated_log
except Exception as e:
logger.error("Failed to cancel training job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to cancel job: {str(e)}")
async def get_job_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get training job statistics"""
try:
base_filters = {}
if tenant_id:
base_filters["tenant_id"] = tenant_id
# Get counts by status
total_jobs = await self.count(filters=base_filters)
completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
running_jobs = await self.count(filters={**base_filters, "status": "running"})
pending_jobs = await self.count(filters={**base_filters, "status": "pending"})
# Get recent activity (jobs in last 7 days)
seven_days_ago = datetime.now() - timedelta(days=7)
recent_jobs = len(await self.get_records_by_date_range(
seven_days_ago,
datetime.now(),
                limit=1000  # approximate: capped at 1,000 records and not tenant-filtered
))
# Calculate success rate
finished_jobs = completed_jobs + failed_jobs
success_rate = (completed_jobs / finished_jobs * 100) if finished_jobs > 0 else 0
return {
"total_jobs": total_jobs,
"completed_jobs": completed_jobs,
"failed_jobs": failed_jobs,
"running_jobs": running_jobs,
"pending_jobs": pending_jobs,
"cancelled_jobs": total_jobs - completed_jobs - failed_jobs - running_jobs - pending_jobs,
"success_rate": round(success_rate, 2),
"recent_jobs_7d": recent_jobs
}
except Exception as e:
logger.error("Failed to get job statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_jobs": 0,
"completed_jobs": 0,
"failed_jobs": 0,
"running_jobs": 0,
"pending_jobs": 0,
"cancelled_jobs": 0,
"success_rate": 0.0,
"recent_jobs_7d": 0
}
async def cleanup_old_logs(self, days_old: int = 90) -> int:
"""Clean up old completed/failed training logs"""
return await self.cleanup_old_records(
days_old=days_old,
status_filter=None # Clean up all old records regardless of status
)
async def get_job_duration_stats(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get job duration statistics"""
try:
# Use raw SQL for complex duration calculations
tenant_filter = "AND tenant_id = :tenant_id" if tenant_id else ""
params = {"tenant_id": tenant_id} if tenant_id else {}
query = text(f"""
SELECT
AVG(EXTRACT(EPOCH FROM (end_time - start_time))/60) as avg_duration_minutes,
MIN(EXTRACT(EPOCH FROM (end_time - start_time))/60) as min_duration_minutes,
MAX(EXTRACT(EPOCH FROM (end_time - start_time))/60) as max_duration_minutes,
COUNT(*) as completed_jobs_with_duration
FROM model_training_logs
WHERE status = 'completed'
AND start_time IS NOT NULL
AND end_time IS NOT NULL
{tenant_filter}
""")
result = await self.session.execute(query, params)
row = result.fetchone()
if row and row.completed_jobs_with_duration > 0:
return {
"avg_duration_minutes": round(float(row.avg_duration_minutes or 0), 2),
"min_duration_minutes": round(float(row.min_duration_minutes or 0), 2),
"max_duration_minutes": round(float(row.max_duration_minutes or 0), 2),
"completed_jobs_with_duration": int(row.completed_jobs_with_duration)
}
return {
"avg_duration_minutes": 0.0,
"min_duration_minutes": 0.0,
"max_duration_minutes": 0.0,
"completed_jobs_with_duration": 0
}
except Exception as e:
logger.error("Failed to get job duration statistics",
tenant_id=tenant_id,
error=str(e))
return {
"avg_duration_minutes": 0.0,
"min_duration_minutes": 0.0,
"max_duration_minutes": 0.0,
"completed_jobs_with_duration": 0
}
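Taken together, a typical lifecycle for this repository looks roughly like the sketch below; session acquisition and the results payload are illustrative, while the repository calls mirror the methods defined above:

# Illustrative sketch: the session argument and results dict are assumptions;
# the repository calls match the methods defined in TrainingLogRepository above.
async def run_job_with_logging(session, job_id: str, tenant_id: str) -> None:
    repo = TrainingLogRepository(session)
    await repo.create_training_log({
        "job_id": job_id,
        "tenant_id": tenant_id,
        "status": "running",
    })
    try:
        await repo.update_log_progress(job_id, progress=50, current_step="training_model")
        results = {"r2_score": 0.87}  # placeholder training results
        await repo.complete_training_log(job_id, results=results)
    except Exception as exc:
        await repo.complete_training_log(job_id, error_message=str(exc))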

View File

@@ -357,7 +357,7 @@ class TrainingErrorUpdate(BaseModel):
class ModelMetricsResponse(BaseModel):
"""Response schema for model performance metrics"""
model_id: str = Field(..., description="Unique model identifier")
accuracy: float = Field(..., description="Model accuracy (R2 score)", ge=0.0, le=1.0)
accuracy: float = Field(..., description="Model accuracy (R2 score)")
mape: float = Field(..., description="Mean Absolute Percentage Error")
mae: float = Field(..., description="Mean Absolute Error")
rmse: float = Field(..., description="Root Mean Square Error")
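Presumably the ge=0.0/le=1.0 bounds were dropped because R2 can be negative when a model underperforms a mean predictor; a minimal, self-contained Pydantic sketch (not the service's actual schema) of the behavioural difference:

# Standalone illustration only: without ge/le bounds, a negative R2 value passes validation.
from pydantic import BaseModel, Field

class MetricsSketch(BaseModel):
    accuracy: float = Field(..., description="Model accuracy (R2 score)")

MetricsSketch(accuracy=-0.12)  # accepted; the old ge=0.0 bound would have rejected this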

View File

@@ -0,0 +1,34 @@
"""
Training Service Layer
Business logic services for ML training and model management
"""
from .training_service import TrainingService
from .training_service import EnhancedTrainingService
from .training_orchestrator import TrainingDataOrchestrator
from .date_alignment_service import DateAlignmentService
from .data_client import DataClient
from .messaging import (
publish_job_progress,
publish_data_validation_started,
publish_data_validation_completed,
publish_job_step_completed,
publish_job_completed,
publish_job_failed,
TrainingStatusPublisher
)
__all__ = [
"TrainingService",
"EnhancedTrainingService",
"TrainingDataOrchestrator",
"DateAlignmentService",
"DataClient",
"publish_job_progress",
"publish_data_validation_started",
"publish_data_validation_completed",
"publish_job_step_completed",
"publish_job_completed",
"publish_job_failed",
"TrainingStatusPublisher"
]
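Downstream modules can then import from the package root rather than the individual modules; a small sketch, assuming the package path is app.services as the relative imports suggest:

# Assumed import path; adjust if the package is mounted differently.
from app.services import TrainingService, publish_job_progress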

File diff suppressed because it is too large