REFACTOR - Database logic
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Training API Layer
|
||||
HTTP endpoints for ML training operations
|
||||
"""
|
||||
|
||||
from .training import router as training_router
|
||||
|
||||
from .websocket import websocket_router
|
||||
|
||||
__all__ = [
|
||||
"training_router",
|
||||
|
||||
"websocket_router"
|
||||
]
|
||||
@@ -38,11 +38,12 @@ async def get_active_model(
|
||||
Get the active model for a product - used by forecasting service
|
||||
"""
|
||||
try:
|
||||
# ✅ FIX: Wrap SQL with text() for SQLAlchemy 2.0
|
||||
logger.debug("Getting active model", tenant_id=tenant_id, product_name=product_name)
|
||||
# ✅ FIX: Wrap SQL with text() for SQLAlchemy 2.0 and add case-insensitive product name matching
|
||||
query = text("""
|
||||
SELECT * FROM trained_models
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND product_name = :product_name
|
||||
AND LOWER(product_name) = LOWER(:product_name)
|
||||
AND is_active = true
|
||||
AND is_production = true
|
||||
ORDER BY created_at DESC
|
||||
@@ -57,6 +58,7 @@ async def get_active_model(
|
||||
model_record = result.fetchone()
|
||||
|
||||
if not model_record:
|
||||
logger.info("No active model found", tenant_id=tenant_id, product_name=product_name)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No active model found for product {product_name}"
|
||||
@@ -76,7 +78,7 @@ async def get_active_model(
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"model_id": model_record.id, # ✅ This is the correct field name
|
||||
"model_id": str(model_record.id), # ✅ This is the correct field name
|
||||
"model_path": model_record.model_path,
|
||||
"features_used": model_record.features_used,
|
||||
"hyperparameters": model_record.hyperparameters,
|
||||
@@ -93,12 +95,24 @@ async def get_active_model(
|
||||
}
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get active model: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve model"
|
||||
)
|
||||
error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
|
||||
logger.error(f"Failed to get active model: {error_msg}", tenant_id=tenant_id, product_name=product_name)
|
||||
|
||||
# Handle client disconnection gracefully
|
||||
if "EndOfStream" in str(type(e)) or "WouldBlock" in str(type(e)):
|
||||
logger.info("Client disconnected during model retrieval", tenant_id=tenant_id, product_name=product_name)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_408_REQUEST_TIMEOUT,
|
||||
detail="Request connection closed"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve model"
|
||||
)
|
||||
|
||||
@router.get("/tenants/{tenant_id}/models/{model_id}/metrics", response_model=ModelMetricsResponse)
|
||||
async def get_model_metrics(
|
||||
@@ -126,7 +140,7 @@ async def get_model_metrics(
|
||||
|
||||
# Return metrics in the format expected by forecasting service
|
||||
metrics = {
|
||||
"model_id": model_record.id,
|
||||
"model_id": str(model_record.id),
|
||||
"accuracy": model_record.r2_score or 0.0, # Use R2 as accuracy measure
|
||||
"mape": model_record.mape or 0.0,
|
||||
"mae": model_record.mae or 0.0,
|
||||
@@ -189,8 +203,8 @@ async def list_models(
|
||||
models = []
|
||||
for record in model_records:
|
||||
models.append({
|
||||
"model_id": record.id,
|
||||
"tenant_id": record.tenant_id,
|
||||
"model_id": str(record.id),
|
||||
"tenant_id": str(record.tenant_id),
|
||||
"product_name": record.product_name,
|
||||
"model_type": record.model_type,
|
||||
"model_path": record.model_path,
|
||||
|
||||
@@ -1,25 +1,19 @@
|
||||
# services/training/app/api/training.py
|
||||
"""
|
||||
Training API Endpoints - Entry point for training requests
|
||||
Handles HTTP requests and delegates to Training Service
|
||||
Enhanced Training API Endpoints with Repository Pattern
|
||||
Updated to use repository pattern with dependency injection and improved error handling
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request
|
||||
from fastapi import Query, Path
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Optional, Dict, Any
|
||||
import structlog
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from app.core.database import get_db, get_background_db_session
|
||||
from app.services.training_service import TrainingService, TrainingStatusManager
|
||||
from sqlalchemy import select, delete, func
|
||||
from app.services.training_service import EnhancedTrainingService
|
||||
from app.schemas.training import (
|
||||
TrainingJobRequest,
|
||||
SingleProductTrainingRequest
|
||||
)
|
||||
from app.schemas.training import (
|
||||
SingleProductTrainingRequest,
|
||||
TrainingJobResponse
|
||||
)
|
||||
|
||||
@@ -33,47 +27,71 @@ from app.services.messaging import (
|
||||
publish_job_started
|
||||
)
|
||||
|
||||
|
||||
from shared.auth.decorators import require_admin_role, get_current_user_dep, get_current_tenant_id_dep
|
||||
from shared.database.base import create_database_manager
|
||||
from shared.monitoring.decorators import track_execution_time
|
||||
from shared.monitoring.metrics import get_metrics_collector
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
router = APIRouter(tags=["enhanced-training"])
|
||||
|
||||
def get_enhanced_training_service():
|
||||
"""Dependency injection for EnhancedTrainingService"""
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
|
||||
return EnhancedTrainingService(database_manager)
|
||||
|
||||
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
|
||||
async def start_training_job(
|
||||
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
|
||||
async def start_enhanced_training_job(
|
||||
request: TrainingJobRequest,
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
|
||||
):
|
||||
"""
|
||||
Start a new training job for all tenant products.
|
||||
Start a new enhanced training job for all tenant products using repository pattern.
|
||||
|
||||
🚀 IMMEDIATE RESPONSE PATTERN:
|
||||
1. Validate request immediately
|
||||
2. Create job record with 'pending' status
|
||||
3. Return 200 with job details
|
||||
4. Execute training in background with separate DB session
|
||||
🚀 ENHANCED IMMEDIATE RESPONSE PATTERN:
|
||||
1. Validate request with enhanced validation
|
||||
2. Create job record using repository pattern
|
||||
3. Return 200 with enhanced job details
|
||||
4. Execute enhanced training in background with repository tracking
|
||||
|
||||
This ensures fast API response while maintaining data consistency.
|
||||
Enhanced features:
|
||||
- Repository pattern for data access
|
||||
- Enhanced error handling and logging
|
||||
- Metrics tracking and monitoring
|
||||
- Transactional operations
|
||||
"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
# Validate tenant access immediately
|
||||
# Enhanced tenant validation
|
||||
if tenant_id != current_tenant:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_training_access_denied_total")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Generate job ID immediately
|
||||
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
|
||||
# Generate enhanced job ID
|
||||
job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
logger.info(f"Creating training job {job_id} for tenant {tenant_id}")
|
||||
logger.info("Creating enhanced training job using repository pattern",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
# Add background task with isolated database session
|
||||
# Record job creation metrics
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_training_jobs_created_total")
|
||||
|
||||
# Add enhanced background task
|
||||
background_tasks.add_task(
|
||||
execute_training_job_background,
|
||||
execute_enhanced_training_job_background,
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id,
|
||||
bakery_location=(40.4168, -3.7038),
|
||||
@@ -81,16 +99,16 @@ async def start_training_job(
|
||||
requested_end=request.end_date
|
||||
)
|
||||
|
||||
# Return immediate success response
|
||||
# Return enhanced immediate success response
|
||||
response_data = {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"status": "pending", # Will change to 'running' in background
|
||||
"message": "Training job started successfully",
|
||||
"status": "pending",
|
||||
"message": "Enhanced training job started successfully using repository pattern",
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"estimated_duration_minutes": "15",
|
||||
"estimated_duration_minutes": 18,
|
||||
"training_results": {
|
||||
"total_products": 10,
|
||||
"total_products": 0, # Will be updated during processing
|
||||
"successful_trainings": 0,
|
||||
"failed_trainings": 0,
|
||||
"products": [],
|
||||
@@ -101,31 +119,45 @@ async def start_training_job(
|
||||
"error_details": None,
|
||||
"processing_metadata": {
|
||||
"background_task": True,
|
||||
"async_execution": True
|
||||
"async_execution": True,
|
||||
"enhanced_features": True,
|
||||
"repository_pattern": True,
|
||||
"dependency_injection": True
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(f"Training job {job_id} queued successfully, returning immediate response")
|
||||
logger.info("Enhanced training job queued successfully",
|
||||
job_id=job_id,
|
||||
features=["repository-pattern", "dependency-injection", "enhanced-tracking"])
|
||||
|
||||
return TrainingJobResponse(**response_data)
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions as-is
|
||||
raise
|
||||
except ValueError as e:
|
||||
logger.error(f"Training job validation error: {str(e)}")
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_training_validation_errors_total")
|
||||
logger.error("Enhanced training job validation error",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to queue training job: {str(e)}")
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_training_job_errors_total")
|
||||
logger.error("Failed to queue enhanced training job",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to start training job"
|
||||
detail="Failed to start enhanced training job"
|
||||
)
|
||||
|
||||
|
||||
async def execute_training_job_background(
|
||||
async def execute_enhanced_training_job_background(
|
||||
tenant_id: str,
|
||||
job_id: str,
|
||||
bakery_location: tuple,
|
||||
@@ -133,382 +165,457 @@ async def execute_training_job_background(
|
||||
requested_end: Optional[datetime] = None
|
||||
):
|
||||
"""
|
||||
Background task that executes the actual training job.
|
||||
Enhanced background task that executes the training job using repository pattern.
|
||||
|
||||
🔧 KEY FEATURES:
|
||||
- Uses its own database session (isolated from API request)
|
||||
- Handles all errors gracefully
|
||||
- Updates job status in real-time
|
||||
- Publishes progress events via WebSocket/messaging
|
||||
- Comprehensive logging and monitoring
|
||||
🔧 ENHANCED FEATURES:
|
||||
- Repository pattern for all data operations
|
||||
- Enhanced error handling with structured logging
|
||||
- Transactional operations for data consistency
|
||||
- Comprehensive metrics tracking
|
||||
- Database connection pooling
|
||||
- Enhanced progress reporting
|
||||
"""
|
||||
|
||||
logger.info(f"🚀 Background training job {job_id} started for tenant {tenant_id}")
|
||||
logger.info("Enhanced background training job started",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
features=["repository-pattern", "enhanced-tracking"])
|
||||
|
||||
async with get_background_db_session() as db_session:
|
||||
try:
|
||||
# ✅ FIX: Create training service with isolated DB session
|
||||
training_service = TrainingService(db_session=db_session)
|
||||
|
||||
status_manager = TrainingStatusManager(db_session=db_session)
|
||||
|
||||
try:
|
||||
|
||||
training_config = {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"bakery_location": {
|
||||
"latitude": 40.4168,
|
||||
"longitude": -3.7038
|
||||
},
|
||||
"requested_start": requested_start if requested_start else None,
|
||||
"requested_end": requested_end if requested_end else None,
|
||||
"estimated_duration_minutes": 15,
|
||||
"estimated_products": None,
|
||||
"background_execution": True,
|
||||
"api_version": "v1"
|
||||
}
|
||||
|
||||
await status_manager.update_job_status(
|
||||
job_id=job_id,
|
||||
status="running",
|
||||
progress=0,
|
||||
current_step="Initializing training pipeline"
|
||||
)
|
||||
|
||||
# Execute the actual training pipeline
|
||||
result = await training_service.start_training_job(
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id,
|
||||
bakery_location=bakery_location,
|
||||
requested_start=requested_start,
|
||||
requested_end=requested_end
|
||||
)
|
||||
|
||||
await status_manager.update_job_status(
|
||||
job_id=job_id,
|
||||
status="completed",
|
||||
progress=100,
|
||||
current_step="Training completed successfully",
|
||||
results=result
|
||||
)
|
||||
|
||||
# Publish completion event
|
||||
await publish_job_completed(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
results=result
|
||||
)
|
||||
|
||||
logger.info(f"✅ Background training job {job_id} completed successfully")
|
||||
|
||||
except Exception as training_error:
|
||||
logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}")
|
||||
|
||||
await status_manager.update_job_status(
|
||||
job_id=job_id,
|
||||
status="failed",
|
||||
progress=0,
|
||||
current_step="Training failed",
|
||||
error_message=str(training_error)
|
||||
)
|
||||
|
||||
# Publish failure event
|
||||
await publish_job_failed(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
error=str(training_error)
|
||||
)
|
||||
|
||||
except Exception as background_error:
|
||||
logger.error(f"💥 Critical error in background training job {job_id}: {str(background_error)}")
|
||||
# Get enhanced training service with dependency injection
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
|
||||
enhanced_training_service = EnhancedTrainingService(database_manager)
|
||||
|
||||
try:
|
||||
# Publish job started event
|
||||
await publish_job_started(job_id, tenant_id, {
|
||||
"enhanced_features": True,
|
||||
"repository_pattern": True,
|
||||
"job_type": "enhanced_training"
|
||||
})
|
||||
|
||||
finally:
|
||||
# Ensure database session is properly closed
|
||||
logger.info(f"🧹 Background training job {job_id} cleanup completed")
|
||||
training_config = {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"bakery_location": {
|
||||
"latitude": bakery_location[0],
|
||||
"longitude": bakery_location[1]
|
||||
},
|
||||
"requested_start": requested_start.isoformat() if requested_start else None,
|
||||
"requested_end": requested_end.isoformat() if requested_end else None,
|
||||
"estimated_duration_minutes": 18,
|
||||
"background_execution": True,
|
||||
"enhanced_features": True,
|
||||
"repository_pattern": True,
|
||||
"api_version": "enhanced_v1"
|
||||
}
|
||||
|
||||
# Update job status using repository pattern
|
||||
await enhanced_training_service._update_job_status_repository(
|
||||
job_id=job_id,
|
||||
status="running",
|
||||
progress=0,
|
||||
current_step="Initializing enhanced training pipeline"
|
||||
)
|
||||
|
||||
# Execute the enhanced training pipeline with repository pattern
|
||||
result = await enhanced_training_service.start_training_job(
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id,
|
||||
bakery_location=bakery_location,
|
||||
requested_start=requested_start,
|
||||
requested_end=requested_end
|
||||
)
|
||||
|
||||
# Update final status using repository pattern
|
||||
await enhanced_training_service._update_job_status_repository(
|
||||
job_id=job_id,
|
||||
status="completed",
|
||||
progress=100,
|
||||
current_step="Enhanced training completed successfully",
|
||||
results=result
|
||||
)
|
||||
|
||||
# Publish enhanced completion event
|
||||
await publish_job_completed(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
results={
|
||||
**result,
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("Enhanced background training job completed successfully",
|
||||
job_id=job_id,
|
||||
models_created=result.get('products_trained', 0),
|
||||
features=["repository-pattern", "enhanced-tracking"])
|
||||
|
||||
except Exception as training_error:
|
||||
logger.error("Enhanced training pipeline failed",
|
||||
job_id=job_id,
|
||||
error=str(training_error))
|
||||
|
||||
try:
|
||||
await enhanced_training_service._update_job_status_repository(
|
||||
job_id=job_id,
|
||||
status="failed",
|
||||
progress=0,
|
||||
current_step="Enhanced training failed",
|
||||
error_message=str(training_error)
|
||||
)
|
||||
except Exception as status_error:
|
||||
logger.error("Failed to update job status after training error",
|
||||
job_id=job_id,
|
||||
status_error=str(status_error))
|
||||
|
||||
# Publish enhanced failure event
|
||||
await publish_job_failed(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
error=str(training_error),
|
||||
metadata={
|
||||
"enhanced_features": True,
|
||||
"repository_pattern": True,
|
||||
"error_type": type(training_error).__name__
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as background_error:
|
||||
logger.error("Critical error in enhanced background training job",
|
||||
job_id=job_id,
|
||||
error=str(background_error))
|
||||
|
||||
finally:
|
||||
logger.info("Enhanced background training job cleanup completed",
|
||||
job_id=job_id)
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
|
||||
async def start_single_product_training(
|
||||
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
|
||||
async def start_enhanced_single_product_training(
|
||||
request: SingleProductTrainingRequest,
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
product_name: str = Path(..., description="Product name"),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
|
||||
):
|
||||
"""
|
||||
Start training for a single product.
|
||||
Start enhanced training for a single product using repository pattern.
|
||||
|
||||
Uses the same pipeline but filters for specific product.
|
||||
Enhanced features:
|
||||
- Repository pattern for data access
|
||||
- Enhanced error handling and validation
|
||||
- Metrics tracking
|
||||
- Transactional operations
|
||||
"""
|
||||
|
||||
training_service = TrainingService(db_session=db)
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
# Validate tenant access
|
||||
# Enhanced tenant validation
|
||||
if tenant_id != current_tenant:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_single_product_access_denied_total")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
logger.info(f"Starting single product training for {product_name} (tenant {tenant_id})")
|
||||
logger.info("Starting enhanced single product training",
|
||||
product_name=product_name,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
# Delegate to training service
|
||||
result = await training_service.start_single_product_training(
|
||||
# Record metrics
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_single_product_training_total")
|
||||
|
||||
# Generate enhanced job ID
|
||||
job_id = f"enhanced_single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Delegate to enhanced training service (single product method to be implemented)
|
||||
result = await enhanced_training_service.start_single_product_training(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
sales_data=request.sales_data,
|
||||
bakery_location=request.bakery_location or (40.4168, -3.7038),
|
||||
weather_data=request.weather_data,
|
||||
traffic_data=request.traffic_data,
|
||||
job_id=request.job_id
|
||||
job_id=job_id,
|
||||
bakery_location=request.bakery_location or (40.4168, -3.7038)
|
||||
)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_single_product_training_success_total")
|
||||
|
||||
logger.info("Enhanced single product training completed",
|
||||
product_name=product_name,
|
||||
job_id=job_id)
|
||||
|
||||
return TrainingJobResponse(**result)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Single product training validation error: {str(e)}")
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_single_product_validation_errors_total")
|
||||
logger.error("Enhanced single product training validation error",
|
||||
error=str(e),
|
||||
product_name=product_name)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Single product training failed: {str(e)}")
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_single_product_training_errors_total")
|
||||
logger.error("Enhanced single product training failed",
|
||||
error=str(e),
|
||||
product_name=product_name)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Single product training failed"
|
||||
detail="Enhanced single product training failed"
|
||||
)
|
||||
|
||||
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/logs")
|
||||
async def get_training_logs(
|
||||
|
||||
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/status")
|
||||
@track_execution_time("enhanced_job_status_duration_seconds", "training-service")
|
||||
async def get_enhanced_training_job_status(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
job_id: str = Path(..., description="Job ID"),
|
||||
limit: int = Query(100, description="Number of log entries to return"),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
|
||||
):
|
||||
"""
|
||||
Get training job logs.
|
||||
Get enhanced training job status using repository pattern.
|
||||
"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
# Validate tenant access
|
||||
if tenant_id != current_tenant:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_status_access_denied_total")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# TODO: Implement log retrieval
|
||||
# Get status using enhanced service
|
||||
status_info = await enhanced_training_service.get_training_status(job_id)
|
||||
|
||||
if not status_info or status_info.get("error"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Training job not found"
|
||||
)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_status_requests_total")
|
||||
|
||||
return {
|
||||
"job_id": job_id,
|
||||
"logs": [
|
||||
f"Training job {job_id} started",
|
||||
"Data preprocessing completed",
|
||||
"Model training completed",
|
||||
"Training job finished successfully"
|
||||
]
|
||||
**status_info,
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_status_errors_total")
|
||||
logger.error("Failed to get enhanced training status",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get training status"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/models")
|
||||
@track_execution_time("enhanced_models_list_duration_seconds", "training-service")
|
||||
async def get_enhanced_tenant_models(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
active_only: bool = Query(True, description="Return only active models"),
|
||||
skip: int = Query(0, description="Number of models to skip"),
|
||||
limit: int = Query(100, description="Number of models to return"),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
|
||||
):
|
||||
"""
|
||||
Get tenant models using enhanced repository pattern.
|
||||
"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
# Validate tenant access
|
||||
if tenant_id != current_tenant:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_models_access_denied_total")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Get models using enhanced service
|
||||
models = await enhanced_training_service.get_tenant_models(
|
||||
tenant_id=tenant_id,
|
||||
active_only=active_only,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_models_requests_total")
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"models": models,
|
||||
"total_returned": len(models),
|
||||
"active_only": active_only,
|
||||
"pagination": {
|
||||
"skip": skip,
|
||||
"limit": limit
|
||||
},
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get training logs: {str(e)}")
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_models_errors_total")
|
||||
logger.error("Failed to get enhanced tenant models",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get training logs"
|
||||
detail="Failed to get tenant models"
|
||||
)
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
|
||||
@router.get("/tenants/{tenant_id}/models/{model_id}/performance")
|
||||
@track_execution_time("enhanced_model_performance_duration_seconds", "training-service")
|
||||
async def get_enhanced_model_performance(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
model_id: str = Path(..., description="Model ID"),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
|
||||
):
|
||||
"""
|
||||
Health check endpoint for the training service.
|
||||
Get enhanced model performance metrics using repository pattern.
|
||||
"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
# Validate tenant access
|
||||
if tenant_id != current_tenant:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_performance_access_denied_total")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Get performance using enhanced service
|
||||
performance = await enhanced_training_service.get_model_performance(model_id)
|
||||
|
||||
if not performance:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model performance not found"
|
||||
)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_performance_requests_total")
|
||||
|
||||
return {
|
||||
**performance,
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_performance_errors_total")
|
||||
logger.error("Failed to get enhanced model performance",
|
||||
model_id=model_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get model performance"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/statistics")
|
||||
@track_execution_time("enhanced_tenant_statistics_duration_seconds", "training-service")
|
||||
async def get_enhanced_tenant_statistics(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
|
||||
):
|
||||
"""
|
||||
Get comprehensive enhanced tenant statistics using repository pattern.
|
||||
"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
# Validate tenant access
|
||||
if tenant_id != current_tenant:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_statistics_access_denied_total")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Get statistics using enhanced service
|
||||
statistics = await enhanced_training_service.get_tenant_statistics(tenant_id)
|
||||
|
||||
if statistics.get("error"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=statistics["error"]
|
||||
)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_statistics_requests_total")
|
||||
|
||||
return {
|
||||
**statistics,
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_statistics_errors_total")
|
||||
logger.error("Failed to get enhanced tenant statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get tenant statistics"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def enhanced_health_check():
|
||||
"""
|
||||
Enhanced health check endpoint for the training service.
|
||||
"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "training",
|
||||
"version": "1.0.0",
|
||||
"service": "enhanced-training-service",
|
||||
"version": "2.0.0",
|
||||
"features": [
|
||||
"repository-pattern",
|
||||
"dependency-injection",
|
||||
"enhanced-error-handling",
|
||||
"metrics-tracking",
|
||||
"transactional-operations"
|
||||
],
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
@router.post("/tenants/{tenant_id}/training/jobs/cancel")
|
||||
async def cancel_tenant_training_jobs(
|
||||
cancel_data: dict, # {"tenant_id": str}
|
||||
current_user = Depends(get_current_user_dep),
|
||||
_admin_check = Depends(require_admin_role),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Cancel all active training jobs for a tenant (admin only)"""
|
||||
try:
|
||||
tenant_id = cancel_data.get("tenant_id")
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="tenant_id is required"
|
||||
)
|
||||
|
||||
tenant_uuid = uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid tenant ID format"
|
||||
)
|
||||
|
||||
try:
|
||||
from app.models.training import TrainingJobQueue
|
||||
|
||||
# Find all active jobs for the tenant
|
||||
active_jobs_query = select(TrainingJobQueue).where(
|
||||
TrainingJobQueue.tenant_id == tenant_uuid,
|
||||
TrainingJobQueue.status.in_(["queued", "running", "pending"])
|
||||
)
|
||||
active_jobs_result = await db.execute(active_jobs_query)
|
||||
active_jobs = active_jobs_result.scalars().all()
|
||||
|
||||
jobs_cancelled = 0
|
||||
cancelled_job_ids = []
|
||||
errors = []
|
||||
|
||||
for job in active_jobs:
|
||||
try:
|
||||
job.status = "cancelled"
|
||||
job.updated_at = datetime.utcnow()
|
||||
job.cancelled_by = current_user.get("user_id")
|
||||
jobs_cancelled += 1
|
||||
cancelled_job_ids.append(str(job.id))
|
||||
|
||||
logger.info("Cancelled training job",
|
||||
job_id=str(job.id),
|
||||
tenant_id=tenant_id)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to cancel job {job.id}: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
if jobs_cancelled > 0:
|
||||
await db.commit()
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"tenant_id": tenant_id,
|
||||
"jobs_cancelled": jobs_cancelled,
|
||||
"cancelled_job_ids": cancelled_job_ids,
|
||||
"errors": errors,
|
||||
"cancelled_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
if errors:
|
||||
result["success"] = len(errors) < len(active_jobs)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Failed to cancel tenant training jobs",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to cancel training jobs"
|
||||
)
|
||||
|
||||
@router.get("/tenants/{tenant_id}/training/jobs/active")
|
||||
async def get_tenant_active_jobs(
|
||||
tenant_id: str,
|
||||
current_user = Depends(get_current_user_dep),
|
||||
_admin_check = Depends(require_admin_role),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get all active training jobs for a tenant (admin only)"""
|
||||
try:
|
||||
tenant_uuid = uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid tenant ID format"
|
||||
)
|
||||
|
||||
try:
|
||||
from app.models.training import TrainingJobQueue
|
||||
|
||||
# Get active jobs
|
||||
active_jobs_query = select(TrainingJobQueue).where(
|
||||
TrainingJobQueue.tenant_id == tenant_uuid,
|
||||
TrainingJobQueue.status.in_(["queued", "running", "pending"])
|
||||
)
|
||||
active_jobs_result = await db.execute(active_jobs_query)
|
||||
active_jobs = active_jobs_result.scalars().all()
|
||||
|
||||
jobs = []
|
||||
for job in active_jobs:
|
||||
jobs.append({
|
||||
"id": str(job.id),
|
||||
"tenant_id": str(job.tenant_id),
|
||||
"status": job.status,
|
||||
"created_at": job.created_at.isoformat() if job.created_at else None,
|
||||
"updated_at": job.updated_at.isoformat() if job.updated_at else None,
|
||||
"started_at": job.started_at.isoformat() if job.started_at else None,
|
||||
"progress": getattr(job, 'progress', 0)
|
||||
})
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"active_jobs_count": len(jobs),
|
||||
"jobs": jobs
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant active jobs",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get active jobs"
|
||||
)
|
||||
|
||||
@router.get("/tenants/{tenant_id}/training/jobs/count")
|
||||
async def get_tenant_models_count(
|
||||
tenant_id: str,
|
||||
current_user = Depends(get_current_user_dep),
|
||||
_admin_check = Depends(require_admin_role),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get count of trained models for a tenant (admin only)"""
|
||||
try:
|
||||
tenant_uuid = uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid tenant ID format"
|
||||
)
|
||||
|
||||
try:
|
||||
from app.models.training import TrainedModel, ModelArtifact
|
||||
|
||||
# Count models
|
||||
models_count_query = select(func.count(TrainedModel.id)).where(
|
||||
TrainedModel.tenant_id == tenant_uuid
|
||||
)
|
||||
models_count_result = await db.execute(models_count_query)
|
||||
models_count = models_count_result.scalar()
|
||||
|
||||
# Count artifacts
|
||||
artifacts_count_query = select(func.count(ModelArtifact.id)).where(
|
||||
ModelArtifact.tenant_id == tenant_uuid
|
||||
)
|
||||
artifacts_count_result = await db.execute(artifacts_count_query)
|
||||
artifacts_count = artifacts_count_result.scalar()
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"models_count": models_count,
|
||||
"artifacts_count": artifacts_count,
|
||||
"total_training_assets": models_count + artifacts_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant models count",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get models count"
|
||||
)
|
||||
}
|
||||
Reference in New Issue
Block a user