REFACTOR - Database logic

This commit is contained in:
Urtzi Alfaro
2025-08-08 09:08:41 +02:00
parent 0154365bfc
commit 488bb3ef93
113 changed files with 22842 additions and 6503 deletions

View File

@@ -0,0 +1,14 @@
"""
Training API Layer
HTTP endpoints for ML training operations
"""
from .training import router as training_router
from .websocket import websocket_router
# Names re-exported by this package: the two routers imported above.
__all__ = [
"training_router",
"websocket_router"
]

View File

@@ -38,11 +38,12 @@ async def get_active_model(
Get the active model for a product - used by forecasting service
"""
try:
# ✅ FIX: Wrap SQL with text() for SQLAlchemy 2.0
logger.debug("Getting active model", tenant_id=tenant_id, product_name=product_name)
# ✅ FIX: Wrap SQL with text() for SQLAlchemy 2.0 and add case-insensitive product name matching
query = text("""
SELECT * FROM trained_models
WHERE tenant_id = :tenant_id
AND product_name = :product_name
AND LOWER(product_name) = LOWER(:product_name)
AND is_active = true
AND is_production = true
ORDER BY created_at DESC
@@ -57,6 +58,7 @@ async def get_active_model(
model_record = result.fetchone()
if not model_record:
logger.info("No active model found", tenant_id=tenant_id, product_name=product_name)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No active model found for product {product_name}"
@@ -76,7 +78,7 @@ async def get_active_model(
await db.commit()
return {
"model_id": model_record.id, # ✅ This is the correct field name
"model_id": str(model_record.id), # ✅ This is the correct field name
"model_path": model_record.model_path,
"features_used": model_record.features_used,
"hyperparameters": model_record.hyperparameters,
@@ -93,12 +95,24 @@ async def get_active_model(
}
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get active model: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve model"
)
error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
logger.error(f"Failed to get active model: {error_msg}", tenant_id=tenant_id, product_name=product_name)
# Handle client disconnection gracefully
if "EndOfStream" in str(type(e)) or "WouldBlock" in str(type(e)):
logger.info("Client disconnected during model retrieval", tenant_id=tenant_id, product_name=product_name)
raise HTTPException(
status_code=status.HTTP_408_REQUEST_TIMEOUT,
detail="Request connection closed"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve model"
)
@router.get("/tenants/{tenant_id}/models/{model_id}/metrics", response_model=ModelMetricsResponse)
async def get_model_metrics(
@@ -126,7 +140,7 @@ async def get_model_metrics(
# Return metrics in the format expected by forecasting service
metrics = {
"model_id": model_record.id,
"model_id": str(model_record.id),
"accuracy": model_record.r2_score or 0.0, # Use R2 as accuracy measure
"mape": model_record.mape or 0.0,
"mae": model_record.mae or 0.0,
@@ -189,8 +203,8 @@ async def list_models(
models = []
for record in model_records:
models.append({
"model_id": record.id,
"tenant_id": record.tenant_id,
"model_id": str(record.id),
"tenant_id": str(record.tenant_id),
"product_name": record.product_name,
"model_type": record.model_type,
"model_path": record.model_path,

View File

@@ -1,25 +1,19 @@
# services/training/app/api/training.py
"""
Training API Endpoints - Entry point for training requests
Handles HTTP requests and delegates to Training Service
Enhanced Training API Endpoints with Repository Pattern
Updated to use repository pattern with dependency injection and improved error handling
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request
from fastapi import Query, Path
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid
from app.core.database import get_db, get_background_db_session
from app.services.training_service import TrainingService, TrainingStatusManager
from sqlalchemy import select, delete, func
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
TrainingJobRequest,
SingleProductTrainingRequest
)
from app.schemas.training import (
SingleProductTrainingRequest,
TrainingJobResponse
)
@@ -33,47 +27,71 @@ from app.services.messaging import (
publish_job_started
)
from shared.auth.decorators import require_admin_role, get_current_user_dep, get_current_tenant_id_dep
from shared.database.base import create_database_manager
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings
logger = structlog.get_logger()
router = APIRouter()
router = APIRouter(tags=["enhanced-training"])
def get_enhanced_training_service():
    """Build an EnhancedTrainingService for use as a FastAPI dependency.

    Each call creates a database manager from ``settings.DATABASE_URL``
    (registered under the name "training-service") and wraps it in a new
    service instance.
    """
    return EnhancedTrainingService(
        create_database_manager(settings.DATABASE_URL, "training-service")
    )
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
async def start_training_job(
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_enhanced_training_job(
request: TrainingJobRequest,
tenant_id: str = Path(..., description="Tenant ID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start a new training job for all tenant products.
Start a new enhanced training job for all tenant products using repository pattern.
🚀 IMMEDIATE RESPONSE PATTERN:
1. Validate request immediately
2. Create job record with 'pending' status
3. Return 200 with job details
4. Execute training in background with separate DB session
🚀 ENHANCED IMMEDIATE RESPONSE PATTERN:
1. Validate request with enhanced validation
2. Create job record using repository pattern
3. Return 200 with enhanced job details
4. Execute enhanced training in background with repository tracking
This ensures fast API response while maintaining data consistency.
Enhanced features:
- Repository pattern for data access
- Enhanced error handling and logging
- Metrics tracking and monitoring
- Transactional operations
"""
metrics = get_metrics_collector(request_obj)
try:
# Validate tenant access immediately
# Enhanced tenant validation
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_training_access_denied_total")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Generate job ID immediately
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
# Generate enhanced job ID
job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info(f"Creating training job {job_id} for tenant {tenant_id}")
logger.info("Creating enhanced training job using repository pattern",
job_id=job_id,
tenant_id=tenant_id)
# Add background task with isolated database session
# Record job creation metrics
if metrics:
metrics.increment_counter("enhanced_training_jobs_created_total")
# Add enhanced background task
background_tasks.add_task(
execute_training_job_background,
execute_enhanced_training_job_background,
tenant_id=tenant_id,
job_id=job_id,
bakery_location=(40.4168, -3.7038),
@@ -81,16 +99,16 @@ async def start_training_job(
requested_end=request.end_date
)
# Return immediate success response
# Return enhanced immediate success response
response_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "pending", # Will change to 'running' in background
"message": "Training job started successfully",
"status": "pending",
"message": "Enhanced training job started successfully using repository pattern",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": "15",
"estimated_duration_minutes": 18,
"training_results": {
"total_products": 10,
"total_products": 0, # Will be updated during processing
"successful_trainings": 0,
"failed_trainings": 0,
"products": [],
@@ -101,31 +119,45 @@ async def start_training_job(
"error_details": None,
"processing_metadata": {
"background_task": True,
"async_execution": True
"async_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"dependency_injection": True
}
}
logger.info(f"Training job {job_id} queued successfully, returning immediate response")
logger.info("Enhanced training job queued successfully",
job_id=job_id,
features=["repository-pattern", "dependency-injection", "enhanced-tracking"])
return TrainingJobResponse(**response_data)
except HTTPException:
# Re-raise HTTP exceptions as-is
raise
except ValueError as e:
logger.error(f"Training job validation error: {str(e)}")
if metrics:
metrics.increment_counter("enhanced_training_validation_errors_total")
logger.error("Enhanced training job validation error",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
logger.error(f"Failed to queue training job: {str(e)}")
if metrics:
metrics.increment_counter("enhanced_training_job_errors_total")
logger.error("Failed to queue enhanced training job",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to start training job"
detail="Failed to start enhanced training job"
)
async def execute_training_job_background(
async def execute_enhanced_training_job_background(
tenant_id: str,
job_id: str,
bakery_location: tuple,
@@ -133,382 +165,457 @@ async def execute_training_job_background(
requested_end: Optional[datetime] = None
):
"""
Background task that executes the actual training job.
Enhanced background task that executes the training job using repository pattern.
🔧 KEY FEATURES:
- Uses its own database session (isolated from API request)
- Handles all errors gracefully
- Updates job status in real-time
- Publishes progress events via WebSocket/messaging
- Comprehensive logging and monitoring
🔧 ENHANCED FEATURES:
- Repository pattern for all data operations
- Enhanced error handling with structured logging
- Transactional operations for data consistency
- Comprehensive metrics tracking
- Database connection pooling
- Enhanced progress reporting
"""
logger.info(f"🚀 Background training job {job_id} started for tenant {tenant_id}")
logger.info("Enhanced background training job started",
job_id=job_id,
tenant_id=tenant_id,
features=["repository-pattern", "enhanced-tracking"])
async with get_background_db_session() as db_session:
try:
# ✅ FIX: Create training service with isolated DB session
training_service = TrainingService(db_session=db_session)
status_manager = TrainingStatusManager(db_session=db_session)
try:
training_config = {
"job_id": job_id,
"tenant_id": tenant_id,
"bakery_location": {
"latitude": 40.4168,
"longitude": -3.7038
},
"requested_start": requested_start if requested_start else None,
"requested_end": requested_end if requested_end else None,
"estimated_duration_minutes": 15,
"estimated_products": None,
"background_execution": True,
"api_version": "v1"
}
await status_manager.update_job_status(
job_id=job_id,
status="running",
progress=0,
current_step="Initializing training pipeline"
)
# Execute the actual training pipeline
result = await training_service.start_training_job(
tenant_id=tenant_id,
job_id=job_id,
bakery_location=bakery_location,
requested_start=requested_start,
requested_end=requested_end
)
await status_manager.update_job_status(
job_id=job_id,
status="completed",
progress=100,
current_step="Training completed successfully",
results=result
)
# Publish completion event
await publish_job_completed(
job_id=job_id,
tenant_id=tenant_id,
results=result
)
logger.info(f"✅ Background training job {job_id} completed successfully")
except Exception as training_error:
logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}")
await status_manager.update_job_status(
job_id=job_id,
status="failed",
progress=0,
current_step="Training failed",
error_message=str(training_error)
)
# Publish failure event
await publish_job_failed(
job_id=job_id,
tenant_id=tenant_id,
error=str(training_error)
)
except Exception as background_error:
logger.error(f"💥 Critical error in background training job {job_id}: {str(background_error)}")
# Get enhanced training service with dependency injection
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
enhanced_training_service = EnhancedTrainingService(database_manager)
try:
# Publish job started event
await publish_job_started(job_id, tenant_id, {
"enhanced_features": True,
"repository_pattern": True,
"job_type": "enhanced_training"
})
finally:
# Ensure database session is properly closed
logger.info(f"🧹 Background training job {job_id} cleanup completed")
training_config = {
"job_id": job_id,
"tenant_id": tenant_id,
"bakery_location": {
"latitude": bakery_location[0],
"longitude": bakery_location[1]
},
"requested_start": requested_start.isoformat() if requested_start else None,
"requested_end": requested_end.isoformat() if requested_end else None,
"estimated_duration_minutes": 18,
"background_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"api_version": "enhanced_v1"
}
# Update job status using repository pattern
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Initializing enhanced training pipeline"
)
# Execute the enhanced training pipeline with repository pattern
result = await enhanced_training_service.start_training_job(
tenant_id=tenant_id,
job_id=job_id,
bakery_location=bakery_location,
requested_start=requested_start,
requested_end=requested_end
)
# Update final status using repository pattern
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="completed",
progress=100,
current_step="Enhanced training completed successfully",
results=result
)
# Publish enhanced completion event
await publish_job_completed(
job_id=job_id,
tenant_id=tenant_id,
results={
**result,
"enhanced_features": True,
"repository_integration": True
}
)
logger.info("Enhanced background training job completed successfully",
job_id=job_id,
models_created=result.get('products_trained', 0),
features=["repository-pattern", "enhanced-tracking"])
except Exception as training_error:
logger.error("Enhanced training pipeline failed",
job_id=job_id,
error=str(training_error))
try:
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Enhanced training failed",
error_message=str(training_error)
)
except Exception as status_error:
logger.error("Failed to update job status after training error",
job_id=job_id,
status_error=str(status_error))
# Publish enhanced failure event
await publish_job_failed(
job_id=job_id,
tenant_id=tenant_id,
error=str(training_error),
metadata={
"enhanced_features": True,
"repository_pattern": True,
"error_type": type(training_error).__name__
}
)
except Exception as background_error:
logger.error("Critical error in enhanced background training job",
job_id=job_id,
error=str(background_error))
finally:
logger.info("Enhanced background training job cleanup completed",
job_id=job_id)
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_enhanced_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: str = Path(..., description="Product name"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Start enhanced training for a single product using repository pattern.

    The request is rejected unless the path tenant matches the authenticated
    tenant. A job ID is generated locally and the work is delegated to
    ``enhanced_training_service.start_single_product_training``; its result
    is unpacked into a ``TrainingJobResponse``.

    Raises:
        HTTPException: 403 when ``tenant_id`` differs from ``current_tenant``,
            400 when the service raises ``ValueError``, 500 for any other
            failure.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_single_product_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        logger.info("Starting enhanced single product training",
                    product_name=product_name,
                    tenant_id=tenant_id)

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_total")

        # Generate enhanced job ID
        job_id = f"enhanced_single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"

        # Delegate to enhanced training service; fall back to the hard-coded
        # coordinates when the request carries no bakery location.
        result = await enhanced_training_service.start_single_product_training(
            tenant_id=tenant_id,
            product_name=product_name,
            job_id=job_id,
            bakery_location=request.bakery_location or (40.4168, -3.7038)
        )

        if metrics:
            metrics.increment_counter("enhanced_single_product_training_success_total")

        logger.info("Enhanced single product training completed",
                    product_name=product_name,
                    job_id=job_id)

        return TrainingJobResponse(**result)

    except HTTPException:
        # Let the 403 above (and any HTTP error from below) through unchanged;
        # without this the generic handler would turn it into a 500.
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_validation_errors_total")
        logger.error("Enhanced single product training validation error",
                     error=str(e),
                     product_name=product_name)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_errors_total")
        logger.error("Enhanced single product training failed",
                     error=str(e),
                     product_name=product_name)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced single product training failed"
        )
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/status")
@track_execution_time("enhanced_job_status_duration_seconds", "training-service")
async def get_enhanced_training_job_status(
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Get enhanced training job status using repository pattern.

    Returns the mapping produced by
    ``enhanced_training_service.get_training_status(job_id)`` with the
    ``enhanced_features`` and ``repository_integration`` flags added.

    Raises:
        HTTPException: 403 when ``tenant_id`` differs from ``current_tenant``,
            404 when the service returns nothing or a mapping with a truthy
            ``"error"`` entry, 500 for any other failure.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_status_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Get status using enhanced service
        # NOTE(review): the lookup is keyed by job_id only; confirm the service
        # (or the job ID format) prevents reading another tenant's job.
        status_info = await enhanced_training_service.get_training_status(job_id)

        if not status_info or status_info.get("error"):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Training job not found"
            )

        if metrics:
            metrics.increment_counter("enhanced_status_requests_total")

        return {
            **status_info,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_status_errors_total")
        logger.error("Failed to get enhanced training status",
                     job_id=job_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get training status"
        )
@router.get("/tenants/{tenant_id}/models")
@track_execution_time("enhanced_models_list_duration_seconds", "training-service")
async def get_enhanced_tenant_models(
    tenant_id: str = Path(..., description="Tenant ID"),
    active_only: bool = Query(True, description="Return only active models"),
    skip: int = Query(0, description="Number of models to skip"),
    limit: int = Query(100, description="Number of models to return"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Get tenant models using enhanced repository pattern.

    Returns a dict with the models returned by
    ``enhanced_training_service.get_tenant_models``, the number returned,
    and the ``active_only`` / pagination values that were applied.

    Raises:
        HTTPException: 403 when ``tenant_id`` differs from ``current_tenant``,
            500 for any other failure.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_models_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Get models using enhanced service
        models = await enhanced_training_service.get_tenant_models(
            tenant_id=tenant_id,
            active_only=active_only,
            skip=skip,
            limit=limit
        )

        if metrics:
            metrics.increment_counter("enhanced_models_requests_total")

        return {
            "tenant_id": tenant_id,
            "models": models,
            "total_returned": len(models),
            "active_only": active_only,
            "pagination": {
                "skip": skip,
                "limit": limit
            },
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Keep the 403 above intact; the generic handler below would
        # otherwise report it as a 500.
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_models_errors_total")
        logger.error("Failed to get enhanced tenant models",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get tenant models"
        )
@router.get("/tenants/{tenant_id}/models/{model_id}/performance")
@track_execution_time("enhanced_model_performance_duration_seconds", "training-service")
async def get_enhanced_model_performance(
    tenant_id: str = Path(..., description="Tenant ID"),
    model_id: str = Path(..., description="Model ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Get enhanced model performance metrics using repository pattern.

    Returns the mapping produced by
    ``enhanced_training_service.get_model_performance(model_id)`` with the
    ``enhanced_features`` and ``repository_integration`` flags added.

    Raises:
        HTTPException: 403 when ``tenant_id`` differs from ``current_tenant``,
            404 when the service returns a falsy result, 500 for any other
            failure.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_performance_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Get performance using enhanced service
        # NOTE(review): the lookup is keyed by model_id only; confirm the
        # service restricts results to models owned by tenant_id.
        performance = await enhanced_training_service.get_model_performance(model_id)

        if not performance:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Model performance not found"
            )

        if metrics:
            metrics.increment_counter("enhanced_performance_requests_total")

        return {
            **performance,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_performance_errors_total")
        logger.error("Failed to get enhanced model performance",
                     model_id=model_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get model performance"
        )
@router.get("/tenants/{tenant_id}/statistics")
@track_execution_time("enhanced_tenant_statistics_duration_seconds", "training-service")
async def get_enhanced_tenant_statistics(
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Return the tenant's training statistics as reported by the enhanced service.

    The service result is passed through with the ``enhanced_features`` and
    ``repository_integration`` flags added.

    Raises:
        HTTPException: 403 when ``tenant_id`` differs from ``current_tenant``;
            500 when the service result carries a truthy ``"error"`` entry
            (that entry becomes the response detail) or anything else fails.
    """
    collector = get_metrics_collector(request_obj)

    def bump(counter_name):
        # Metrics are optional: count only when a collector is available.
        if collector:
            collector.increment_counter(counter_name)

    try:
        # Tenant access guard
        if tenant_id != current_tenant:
            bump("enhanced_statistics_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        stats = await enhanced_training_service.get_tenant_statistics(tenant_id)

        # The service reports failures in-band through an "error" entry.
        failure = stats.get("error")
        if failure:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=failure
            )

        bump("enhanced_statistics_requests_total")

        payload = {**stats}
        payload["enhanced_features"] = True
        payload["repository_integration"] = True
        return payload

    except HTTPException:
        raise
    except Exception as exc:
        bump("enhanced_statistics_errors_total")
        logger.error("Failed to get enhanced tenant statistics",
                     tenant_id=tenant_id,
                     error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get tenant statistics"
        )
@router.get("/health")
async def enhanced_health_check():
    """
    Enhanced health check endpoint for the training service.

    Returns a static description of the service (status, name, version and
    feature list) together with the current UTC time in ISO 8601 format.
    """
    return {
        "status": "healthy",
        "service": "enhanced-training-service",
        "version": "2.0.0",
        "features": [
            "repository-pattern",
            "dependency-injection",
            "enhanced-error-handling",
            "metrics-tracking",
            "transactional-operations"
        ],
        # Timezone-aware so the value is unambiguous across hosts.
        "timestamp": datetime.now(timezone.utc).isoformat()
    }
@router.post("/tenants/{tenant_id}/training/jobs/cancel")
async def cancel_tenant_training_jobs(
    cancel_data: dict,  # {"tenant_id": str}
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """Cancel all active training jobs for a tenant (admin only).

    Every ``TrainingJobQueue`` row of the tenant whose status is "queued",
    "running" or "pending" is marked "cancelled". The change is committed
    only if at least one job was updated.

    Args:
        cancel_data: Request body; must contain a ``"tenant_id"`` UUID string.
        current_user: Authenticated user mapping; its ``"user_id"`` is stored
            on each cancelled job.
        db: Async database session.

    Returns:
        A dict with the cancelled job IDs, per-job errors and a ``success``
        flag (false only when every active job failed to cancel).

    Raises:
        HTTPException: 400 when ``tenant_id`` is missing or not a valid UUID
            string, 500 when the database operation fails.
    """
    # NOTE(review): the tenant comes from the body, not from the {tenant_id}
    # path segment, so the two can disagree — confirm this is intended.
    tenant_id = cancel_data.get("tenant_id")
    if not tenant_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="tenant_id is required"
        )

    # uuid.UUID() raises AttributeError (not ValueError) for non-string
    # input, which would otherwise escape as an unhandled 500.
    if not isinstance(tenant_id, str):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    try:
        from app.models.training import TrainingJobQueue

        # Find all active jobs for the tenant
        active_jobs_query = select(TrainingJobQueue).where(
            TrainingJobQueue.tenant_id == tenant_uuid,
            TrainingJobQueue.status.in_(["queued", "running", "pending"])
        )
        active_jobs_result = await db.execute(active_jobs_query)
        active_jobs = active_jobs_result.scalars().all()

        # One timestamp and one actor for the whole batch.
        cancelled_at = datetime.utcnow()
        cancelled_by = current_user.get("user_id")

        jobs_cancelled = 0
        cancelled_job_ids = []
        errors = []

        for job in active_jobs:
            try:
                job.status = "cancelled"
                job.updated_at = cancelled_at
                job.cancelled_by = cancelled_by
                jobs_cancelled += 1
                cancelled_job_ids.append(str(job.id))
                logger.info("Cancelled training job",
                            job_id=str(job.id),
                            tenant_id=tenant_id)
            except Exception as e:
                # Best effort: record the failure and keep cancelling the rest.
                error_msg = f"Failed to cancel job {job.id}: {str(e)}"
                errors.append(error_msg)
                logger.error(error_msg)

        if jobs_cancelled > 0:
            await db.commit()

        result = {
            "success": True,
            "tenant_id": tenant_id,
            "jobs_cancelled": jobs_cancelled,
            "cancelled_job_ids": cancelled_job_ids,
            "errors": errors,
            "cancelled_at": cancelled_at.isoformat()
        }

        if errors:
            # Partial failure still counts as success; total failure does not.
            result["success"] = len(errors) < len(active_jobs)

        return result

    except Exception as e:
        await db.rollback()
        logger.error("Failed to cancel tenant training jobs",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to cancel training jobs"
        )
@router.get("/tenants/{tenant_id}/training/jobs/active")
async def get_tenant_active_jobs(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """List a tenant's training jobs that are still active (admin only).

    A job is active when its status is "queued", "running" or "pending".

    Raises:
        HTTPException: 400 when ``tenant_id`` is not a valid UUID string,
            500 when the query or serialization fails.
    """
    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    def iso_or_none(moment):
        # Timestamps may be unset on a job; serialize those as None.
        return moment.isoformat() if moment else None

    try:
        from app.models.training import TrainingJobQueue

        query = select(TrainingJobQueue).where(
            TrainingJobQueue.tenant_id == tenant_uuid,
            TrainingJobQueue.status.in_(["queued", "running", "pending"])
        )
        rows = (await db.execute(query)).scalars().all()

        jobs = [
            {
                "id": str(row.id),
                "tenant_id": str(row.tenant_id),
                "status": row.status,
                "created_at": iso_or_none(row.created_at),
                "updated_at": iso_or_none(row.updated_at),
                "started_at": iso_or_none(row.started_at),
                "progress": getattr(row, "progress", 0)
            }
            for row in rows
        ]

        return {
            "tenant_id": tenant_id,
            "active_jobs_count": len(jobs),
            "jobs": jobs
        }

    except Exception as exc:
        logger.error("Failed to get tenant active jobs",
                     tenant_id=tenant_id,
                     error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get active jobs"
        )
@router.get("/tenants/{tenant_id}/training/jobs/count")
async def get_tenant_models_count(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """Count a tenant's trained models and model artifacts (admin only).

    Returns both counts and their sum as ``total_training_assets``.

    Raises:
        HTTPException: 400 when ``tenant_id`` is not a valid UUID string,
            500 when either count query fails.
    """
    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    async def count_rows(model):
        # COUNT(id) over the rows of `model` that belong to the tenant.
        outcome = await db.execute(
            select(func.count(model.id)).where(model.tenant_id == tenant_uuid)
        )
        return outcome.scalar()

    try:
        from app.models.training import TrainedModel, ModelArtifact

        models_count = await count_rows(TrainedModel)
        artifacts_count = await count_rows(ModelArtifact)

        return {
            "tenant_id": tenant_id,
            "models_count": models_count,
            "artifacts_count": artifacts_count,
            "total_training_assets": models_count + artifacts_count
        }

    except Exception as exc:
        logger.error("Failed to get tenant models count",
                     tenant_id=tenant_id,
                     error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get models count"
        )
}