REFACTOR ALL APIs
@@ -3,12 +3,12 @@ Training API Layer

HTTP endpoints for ML training operations
"""

from .training import router as training_router
from .websocket import websocket_router
from .training_jobs import router as training_jobs_router
from .training_operations import router as training_operations_router
from .models import router as models_router

__all__ = [
    "training_router",
    "websocket_router",
    "training_jobs_router",
    "training_operations_router",
    "models_router"
]
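For orientation, these exported routers would typically be mounted on the service's FastAPI application. A minimal sketch, assuming a conventional app factory (the create_app function and main.py location are assumptions, not part of this commit):

# main.py (illustrative)
from fastapi import FastAPI
from app.api import (
    training_router,
    websocket_router,
    training_jobs_router,
    training_operations_router,
    models_router,
)

def create_app() -> FastAPI:
    app = FastAPI(title="training-service")
    # Mount each router exported by app/api/__init__.py
    for r in (training_router, websocket_router, training_jobs_router,
              training_operations_router, models_router):
        app.include_router(r)
    return app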
@@ -22,13 +23,27 @@ from shared.auth.decorators import (
    get_current_user_dep,
    require_admin_role
)
from shared.routing import RouteBuilder
from shared.auth.access_control import (
    require_user_role,
    admin_role_required,
    owner_role_required,
    require_subscription_tier,
    analytics_tier_required,
    enterprise_tier_required
)

# Create route builder for consistent URL structure
route_builder = RouteBuilder('training')

logger = structlog.get_logger()
router = APIRouter()

training_service = TrainingService()

@router.get("/tenants/{tenant_id}/models/{inventory_product_id}/active")
@router.get(
    route_builder.build_base_route("models") + "/{inventory_product_id}/active"
)
async def get_active_model(
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: str = Path(..., description="Inventory product UUID"),
@@ -114,7 +128,10 @@ async def get_active_model(
            detail="Failed to retrieve model"
        )

@router.get("/tenants/{tenant_id}/models/{model_id}/metrics", response_model=ModelMetricsResponse)
@router.get(
    route_builder.build_nested_resource_route("models", "model_id", "metrics"),
    response_model=ModelMetricsResponse
)
async def get_model_metrics(
    model_id: str = Path(..., description="Model ID"),
    db: AsyncSession = Depends(get_db)
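The before/after decorator pairs in the hunks above constrain what the RouteBuilder calls expand to, though shared.routing itself is not part of this diff. Inferred expansions, stated as assumptions (RouteBuilder('training') may also inject a service-specific path segment such as /training):

# Inferred from the literal routes being replaced (assumptions, not shown in this commit)
rb = RouteBuilder('training')

rb.build_base_route("models")
# -> "/tenants/{tenant_id}/models"

rb.build_base_route("models") + "/{inventory_product_id}/active"
# -> "/tenants/{tenant_id}/models/{inventory_product_id}/active"

rb.build_nested_resource_route("models", "model_id", "metrics")
# -> "/tenants/{tenant_id}/models/{model_id}/metrics"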
@@ -168,7 +185,10 @@ async def get_model_metrics(
            detail="Failed to retrieve model metrics"
        )

@router.get("/tenants/{tenant_id}/models", response_model=List[TrainedModelResponse])
@router.get(
    route_builder.build_base_route("models"),
    response_model=List[TrainedModelResponse]
)
async def list_models(
    tenant_id: str = Path(..., description="Tenant ID"),
    status: Optional[str] = Query(None, description="Filter by status (active/inactive)"),
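A client-side view of the listing endpoint above; the host, tenant ID, and auth header are illustrative:

import httpx

resp = httpx.get(
    "http://localhost:8000/tenants/t-123/models",
    params={"status": "active"},
    headers={"Authorization": "Bearer <token>"},
)
resp.raise_for_status()
models = resp.json()  # List[TrainedModelResponse]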
@@ -235,6 +255,7 @@ async def list_models(
    )

@router.delete("/models/tenant/{tenant_id}")
@require_user_role(['admin', 'owner'])
async def delete_tenant_models_complete(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
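The new @require_user_role decorator gates the destructive delete above to admins and owners. Its implementation lives in shared.auth.access_control and is not part of this diff; one plausible shape, offered purely as an assumption:

# Hypothetical sketch of a role-checking decorator (the real one is in
# shared.auth.access_control and may differ)
from functools import wraps
from fastapi import HTTPException, status

def require_user_role(allowed_roles: list):
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, current_user=None, **kwargs):
            role = getattr(current_user, "role", None)
            if role not in allowed_roles:
                raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
                                    detail="Insufficient role")
            return await func(*args, current_user=current_user, **kwargs)
        return wrapper
    return decorator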
@@ -1,577 +0,0 @@
"""
Enhanced Training API Endpoints with Repository Pattern
Updated to use repository pattern with dependency injection and improved error handling
"""

from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request
from fastapi import Query, Path
from typing import List, Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid

from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
    TrainingJobRequest,
    SingleProductTrainingRequest,
    TrainingJobResponse
)

from app.services.messaging import (
    publish_job_progress,
    publish_data_validation_started,
    publish_data_validation_completed,
    publish_job_step_completed,
    publish_job_completed,
    publish_job_failed,
    publish_job_started
)

from shared.auth.decorators import require_admin_role, get_current_user_dep
from shared.database.base import create_database_manager
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings

logger = structlog.get_logger()
router = APIRouter(tags=["enhanced-training"])

def get_enhanced_training_service():
    """Dependency injection for EnhancedTrainingService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
    return EnhancedTrainingService(database_manager)

@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_enhanced_training_job(
    request: TrainingJobRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Start a new enhanced training job for all tenant products using repository pattern.

    🚀 ENHANCED IMMEDIATE RESPONSE PATTERN:
    1. Validate request with enhanced validation
    2. Create job record using repository pattern
    3. Return 200 with enhanced job details
    4. Execute enhanced training in background with repository tracking

    Enhanced features:
    - Repository pattern for data access
    - Enhanced error handling and logging
    - Metrics tracking and monitoring
    - Transactional operations
    """
    metrics = get_metrics_collector(request_obj)

    try:
        # Generate enhanced job ID
        job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        logger.info("Creating enhanced training job using repository pattern",
                    job_id=job_id,
                    tenant_id=tenant_id)

        # Record job creation metrics
        if metrics:
            metrics.increment_counter("enhanced_training_jobs_created_total")

        # Add enhanced background task
        background_tasks.add_task(
            execute_enhanced_training_job_background,
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=(40.4168, -3.7038),
            requested_start=request.start_date,
            requested_end=request.end_date
        )

        # Return enhanced immediate success response
        response_data = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "status": "pending",
            "message": "Enhanced training job started successfully using repository pattern",
            "created_at": datetime.now(timezone.utc),
            "estimated_duration_minutes": 18,
            "training_results": {
                "total_products": 0,  # Will be updated during processing
                "successful_trainings": 0,
                "failed_trainings": 0,
                "products": [],
                "overall_training_time_seconds": 0.0
            },
            "data_summary": None,
            "completed_at": None,
            "error_details": None,
            "processing_metadata": {
                "background_task": True,
                "async_execution": True,
                "enhanced_features": True,
                "repository_pattern": True,
                "dependency_injection": True
            }
        }

        logger.info("Enhanced training job queued successfully",
                    job_id=job_id,
                    features=["repository-pattern", "dependency-injection", "enhanced-tracking"])

        return TrainingJobResponse(**response_data)

    except HTTPException:
        # Re-raise HTTP exceptions as-is
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_training_validation_errors_total")
        logger.error("Enhanced training job validation error",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_training_job_errors_total")
        logger.error("Failed to queue enhanced training job",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to start enhanced training job"
        )


async def execute_enhanced_training_job_background(
    tenant_id: str,
    job_id: str,
    bakery_location: tuple,
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None
):
    """
    Enhanced background task that executes the training job using repository pattern.

    🔧 ENHANCED FEATURES:
    - Repository pattern for all data operations
    - Enhanced error handling with structured logging
    - Transactional operations for data consistency
    - Comprehensive metrics tracking
    - Database connection pooling
    - Enhanced progress reporting
    """

    logger.info("Enhanced background training job started",
                job_id=job_id,
                tenant_id=tenant_id,
                features=["repository-pattern", "enhanced-tracking"])

    # Get enhanced training service with dependency injection
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
    enhanced_training_service = EnhancedTrainingService(database_manager)

    try:
        # Create initial training log entry first
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="pending",
            progress=0,
            current_step="Starting enhanced training job",
            tenant_id=tenant_id
        )

        # Publish job started event
        await publish_job_started(job_id, tenant_id, {
            "enhanced_features": True,
            "repository_pattern": True,
            "job_type": "enhanced_training"
        })

        training_config = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "bakery_location": {
                "latitude": bakery_location[0],
                "longitude": bakery_location[1]
            },
            "requested_start": requested_start.isoformat() if requested_start else None,
            "requested_end": requested_end.isoformat() if requested_end else None,
            "estimated_duration_minutes": 18,
            "background_execution": True,
            "enhanced_features": True,
            "repository_pattern": True,
            "api_version": "enhanced_v1"
        }

        # Update job status using repository pattern
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="running",
            progress=0,
            current_step="Initializing enhanced training pipeline",
            tenant_id=tenant_id
        )

        # Execute the enhanced training pipeline with repository pattern
        result = await enhanced_training_service.start_training_job(
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=bakery_location,
            requested_start=requested_start,
            requested_end=requested_end
        )

        # Update final status using repository pattern
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="completed",
            progress=100,
            current_step="Enhanced training completed successfully",
            results=result,
            tenant_id=tenant_id
        )

        # Publish enhanced completion event
        await publish_job_completed(
            job_id=job_id,
            tenant_id=tenant_id,
            results={
                **result,
                "enhanced_features": True,
                "repository_integration": True
            }
        )

        logger.info("Enhanced background training job completed successfully",
                    job_id=job_id,
                    models_created=result.get('products_trained', 0),
                    features=["repository-pattern", "enhanced-tracking"])

    except Exception as training_error:
        logger.error("Enhanced training pipeline failed",
                     job_id=job_id,
                     error=str(training_error))

        try:
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="failed",
                progress=0,
                current_step="Enhanced training failed",
                error_message=str(training_error),
                tenant_id=tenant_id
            )
        except Exception as status_error:
            logger.error("Failed to update job status after training error",
                         job_id=job_id,
                         status_error=str(status_error))

        # Publish enhanced failure event
        await publish_job_failed(
            job_id=job_id,
            tenant_id=tenant_id,
            error=str(training_error),
            metadata={
                "enhanced_features": True,
                "repository_pattern": True,
                "error_type": type(training_error).__name__
            }
        )

    except Exception as background_error:
        logger.error("Critical error in enhanced background training job",
                     job_id=job_id,
                     error=str(background_error))

    finally:
        logger.info("Enhanced background training job cleanup completed",
                    job_id=job_id)


@router.post("/tenants/{tenant_id}/training/products/{inventory_product_id}", response_model=TrainingJobResponse)
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_enhanced_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: str = Path(..., description="Inventory product UUID"),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Start enhanced training for a single product using repository pattern.

    Enhanced features:
    - Repository pattern for data access
    - Enhanced error handling and validation
    - Metrics tracking
    - Transactional operations
    """
    metrics = get_metrics_collector(request_obj)

    try:
        logger.info("Starting enhanced single product training",
                    inventory_product_id=inventory_product_id,
                    tenant_id=tenant_id)

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_total")

        # Generate enhanced job ID
        job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"

        # Delegate to enhanced training service (single product method to be implemented)
        result = await enhanced_training_service.start_single_product_training(
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            job_id=job_id,
            bakery_location=request.bakery_location or (40.4168, -3.7038)
        )

        if metrics:
            metrics.increment_counter("enhanced_single_product_training_success_total")

        logger.info("Enhanced single product training completed",
                    inventory_product_id=inventory_product_id,
                    job_id=job_id)

        return TrainingJobResponse(**result)

    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_validation_errors_total")
        logger.error("Enhanced single product training validation error",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_errors_total")
        logger.error("Enhanced single product training failed",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced single product training failed"
        )


@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/status")
@track_execution_time("enhanced_job_status_duration_seconds", "training-service")
async def get_enhanced_training_job_status(
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Get enhanced training job status using repository pattern.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        # Get status using enhanced service
        status_info = await enhanced_training_service.get_training_status(job_id)

        if not status_info or status_info.get("error"):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Training job not found"
            )

        if metrics:
            metrics.increment_counter("enhanced_status_requests_total")

        return {
            **status_info,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_status_errors_total")
        logger.error("Failed to get enhanced training status",
                     job_id=job_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get training status"
        )


@router.get("/tenants/{tenant_id}/models")
@track_execution_time("enhanced_models_list_duration_seconds", "training-service")
async def get_enhanced_tenant_models(
    tenant_id: str = Path(..., description="Tenant ID"),
    active_only: bool = Query(True, description="Return only active models"),
    skip: int = Query(0, description="Number of models to skip"),
    limit: int = Query(100, description="Number of models to return"),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Get tenant models using enhanced repository pattern.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        # Get models using enhanced service
        models = await enhanced_training_service.get_tenant_models(
            tenant_id=tenant_id,
            active_only=active_only,
            skip=skip,
            limit=limit
        )

        if metrics:
            metrics.increment_counter("enhanced_models_requests_total")

        return {
            "tenant_id": tenant_id,
            "models": models,
            "total_returned": len(models),
            "active_only": active_only,
            "pagination": {
                "skip": skip,
                "limit": limit
            },
            "enhanced_features": True,
            "repository_integration": True
        }

    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_models_errors_total")
        logger.error("Failed to get enhanced tenant models",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get tenant models"
        )


@router.get("/tenants/{tenant_id}/models/{model_id}/performance")
@track_execution_time("enhanced_model_performance_duration_seconds", "training-service")
async def get_enhanced_model_performance(
    tenant_id: str = Path(..., description="Tenant ID"),
    model_id: str = Path(..., description="Model ID"),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Get enhanced model performance metrics using repository pattern.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        # Get performance using enhanced service
        performance = await enhanced_training_service.get_model_performance(model_id)

        if not performance:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Model performance not found"
            )

        if metrics:
            metrics.increment_counter("enhanced_performance_requests_total")

        return {
            **performance,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_performance_errors_total")
        logger.error("Failed to get enhanced model performance",
                     model_id=model_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get model performance"
        )


@router.get("/tenants/{tenant_id}/statistics")
@track_execution_time("enhanced_tenant_statistics_duration_seconds", "training-service")
async def get_enhanced_tenant_statistics(
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Get comprehensive enhanced tenant statistics using repository pattern.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        # Get statistics using enhanced service
        statistics = await enhanced_training_service.get_tenant_statistics(tenant_id)

        if statistics.get("error"):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=statistics["error"]
            )

        if metrics:
            metrics.increment_counter("enhanced_statistics_requests_total")

        return {
            **statistics,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_statistics_errors_total")
        logger.error("Failed to get enhanced tenant statistics",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get tenant statistics"
        )


@router.get("/health")
async def enhanced_health_check():
    """
    Enhanced health check endpoint for the training service.
    """
    return {
        "status": "healthy",
        "service": "enhanced-training-service",
        "version": "2.0.0",
        "features": [
            "repository-pattern",
            "dependency-injection",
            "enhanced-error-handling",
            "metrics-tracking",
            "transactional-operations"
        ],
        "timestamp": datetime.now().isoformat()
    }
services/training/app/api/training_jobs.py (new file, 123 lines)
@@ -0,0 +1,123 @@
"""
Training Jobs API - ATOMIC CRUD operations
Handles basic training job creation and retrieval
"""

from fastapi import APIRouter, Depends, HTTPException, status, Path, Query, Request
from typing import List, Optional
import structlog
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from datetime import datetime
import uuid

from app.services.training_service import EnhancedTrainingService
from app.schemas.training import TrainingJobResponse
from shared.database.base import create_database_manager
from app.core.config import settings

logger = structlog.get_logger()
route_builder = RouteBuilder('training')

router = APIRouter(tags=["training-jobs"])

def get_enhanced_training_service():
    """Dependency injection for EnhancedTrainingService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
    return EnhancedTrainingService(database_manager)
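Because the service is injected through Depends(get_enhanced_training_service), tests can swap in a stub with FastAPI's dependency_overrides; a sketch (FakeTrainingService and the app object are assumptions):

# test setup (illustrative)
app.dependency_overrides[get_enhanced_training_service] = lambda: FakeTrainingService()
# ... exercise endpoints with TestClient(app), then clean up:
app.dependency_overrides.clear()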
@router.get(
    route_builder.build_nested_resource_route("jobs", "job_id", "status")
)
@track_execution_time("enhanced_job_status_duration_seconds", "training-service")
async def get_training_job_status(
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Get training job status using repository pattern.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        # Get status using enhanced service
        status_info = await enhanced_training_service.get_training_status(job_id)

        if not status_info or status_info.get("error"):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Training job not found"
            )

        if metrics:
            metrics.increment_counter("enhanced_status_requests_total")

        return {
            **status_info,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_status_errors_total")
        logger.error("Failed to get training status",
                     job_id=job_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get training status"
        )


@router.get(
    route_builder.build_base_route("statistics")
)
@track_execution_time("enhanced_tenant_statistics_duration_seconds", "training-service")
async def get_tenant_statistics(
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Get comprehensive tenant statistics using repository pattern.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        # Get statistics using enhanced service
        statistics = await enhanced_training_service.get_tenant_statistics(tenant_id)

        if statistics.get("error"):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=statistics["error"]
            )

        if metrics:
            metrics.increment_counter("enhanced_statistics_requests_total")

        return {
            **statistics,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_statistics_errors_total")
        logger.error("Failed to get tenant statistics",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get tenant statistics"
        )
services/training/app/api/training_operations.py (new file, 730 lines)
@@ -0,0 +1,730 @@
"""
Training Operations API - BUSINESS logic
Handles training job execution, metrics, and WebSocket live feed
"""

from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path, WebSocket, WebSocketDisconnect
from typing import List, Optional, Dict, Any
import structlog
import asyncio
import json
import datetime
from shared.auth.access_control import require_user_role, admin_role_required, analytics_tier_required
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from datetime import datetime, timezone
import uuid

from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
    TrainingJobRequest,
    SingleProductTrainingRequest,
    TrainingJobResponse
)
from app.services.messaging import (
    publish_job_progress,
    publish_data_validation_started,
    publish_data_validation_completed,
    publish_job_step_completed,
    publish_job_completed,
    publish_job_failed,
    publish_job_started,
    training_publisher
)
from app.core.config import settings

logger = structlog.get_logger()
route_builder = RouteBuilder('training')

router = APIRouter(tags=["training-operations"])

def get_enhanced_training_service():
    """Dependency injection for EnhancedTrainingService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
    return EnhancedTrainingService(database_manager)


@router.post(
    route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
    request: TrainingJobRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Start a new training job for all tenant products using repository pattern.

    Enhanced immediate response pattern:
    1. Validate request with enhanced validation
    2. Create job record using repository pattern
    3. Return 200 with enhanced job details
    4. Execute enhanced training in background with repository tracking

    Enhanced features:
    - Repository pattern for data access
    - Enhanced error handling and logging
    - Metrics tracking and monitoring
    - Transactional operations
    """
    metrics = get_metrics_collector(request_obj)

    try:
        # Generate enhanced job ID
        job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        logger.info("Creating enhanced training job using repository pattern",
                    job_id=job_id,
                    tenant_id=tenant_id)

        # Record job creation metrics
        if metrics:
            metrics.increment_counter("enhanced_training_jobs_created_total")

        # Add enhanced background task
        background_tasks.add_task(
            execute_training_job_background,
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=(40.4168, -3.7038),
            requested_start=request.start_date,
            requested_end=request.end_date
        )

        # Return enhanced immediate success response
        response_data = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "status": "pending",
            "message": "Enhanced training job started successfully using repository pattern",
            "created_at": datetime.now(timezone.utc),
            "estimated_duration_minutes": 18,
            "training_results": {
                "total_products": 0,
                "successful_trainings": 0,
                "failed_trainings": 0,
                "products": [],
                "overall_training_time_seconds": 0.0
            },
            "data_summary": None,
            "completed_at": None,
            "error_details": None,
            "processing_metadata": {
                "background_task": True,
                "async_execution": True,
                "enhanced_features": True,
                "repository_pattern": True,
                "dependency_injection": True
            }
        }

        logger.info("Enhanced training job queued successfully",
                    job_id=job_id,
                    features=["repository-pattern", "dependency-injection", "enhanced-tracking"])

        return TrainingJobResponse(**response_data)

    except HTTPException:
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_training_validation_errors_total")
        logger.error("Enhanced training job validation error",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_training_job_errors_total")
        logger.error("Failed to queue enhanced training job",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to start enhanced training job"
        )
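From a client's perspective, the immediate-response pattern above means the POST returns a pending job at once and completion is observed separately. A polling sketch (host, tenant ID, token, request body, and the 5-second interval are illustrative; the paths follow the literal routes this commit replaces):

import time
import httpx

base = "http://localhost:8000/tenants/t-123/training"
headers = {"Authorization": "Bearer <token>"}

# Body fields per TrainingJobRequest; empty here for illustration
job = httpx.post(f"{base}/jobs", json={}, headers=headers).json()
while True:
    info = httpx.get(f"{base}/jobs/{job['job_id']}/status", headers=headers).json()
    if info["status"] in ("completed", "failed"):
        break
    time.sleep(5)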
async def execute_training_job_background(
    tenant_id: str,
    job_id: str,
    bakery_location: tuple,
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None
):
    """
    Enhanced background task that executes the training job using repository pattern.

    Enhanced features:
    - Repository pattern for all data operations
    - Enhanced error handling with structured logging
    - Transactional operations for data consistency
    - Comprehensive metrics tracking
    - Database connection pooling
    - Enhanced progress reporting
    """

    logger.info("Enhanced background training job started",
                job_id=job_id,
                tenant_id=tenant_id,
                features=["repository-pattern", "enhanced-tracking"])

    # Get enhanced training service with dependency injection
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
    enhanced_training_service = EnhancedTrainingService(database_manager)

    try:
        # Create initial training log entry first
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="pending",
            progress=0,
            current_step="Starting enhanced training job",
            tenant_id=tenant_id
        )

        # Publish job started event
        await publish_job_started(job_id, tenant_id, {
            "enhanced_features": True,
            "repository_pattern": True,
            "job_type": "enhanced_training"
        })

        training_config = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "bakery_location": {
                "latitude": bakery_location[0],
                "longitude": bakery_location[1]
            },
            "requested_start": requested_start.isoformat() if requested_start else None,
            "requested_end": requested_end.isoformat() if requested_end else None,
            "estimated_duration_minutes": 18,
            "background_execution": True,
            "enhanced_features": True,
            "repository_pattern": True,
            "api_version": "enhanced_v1"
        }

        # Update job status using repository pattern
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="running",
            progress=0,
            current_step="Initializing enhanced training pipeline",
            tenant_id=tenant_id
        )

        # Execute the enhanced training pipeline with repository pattern
        result = await enhanced_training_service.start_training_job(
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=bakery_location,
            requested_start=requested_start,
            requested_end=requested_end
        )

        # Update final status using repository pattern
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="completed",
            progress=100,
            current_step="Enhanced training completed successfully",
            results=result,
            tenant_id=tenant_id
        )

        # Publish enhanced completion event
        await publish_job_completed(
            job_id=job_id,
            tenant_id=tenant_id,
            results={
                **result,
                "enhanced_features": True,
                "repository_integration": True
            }
        )

        logger.info("Enhanced background training job completed successfully",
                    job_id=job_id,
                    models_created=result.get('products_trained', 0),
                    features=["repository-pattern", "enhanced-tracking"])

    except Exception as training_error:
        logger.error("Enhanced training pipeline failed",
                     job_id=job_id,
                     error=str(training_error))

        try:
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="failed",
                progress=0,
                current_step="Enhanced training failed",
                error_message=str(training_error),
                tenant_id=tenant_id
            )
        except Exception as status_error:
            logger.error("Failed to update job status after training error",
                         job_id=job_id,
                         status_error=str(status_error))

        # Publish enhanced failure event
        await publish_job_failed(
            job_id=job_id,
            tenant_id=tenant_id,
            error=str(training_error),
            metadata={
                "enhanced_features": True,
                "repository_pattern": True,
                "error_type": type(training_error).__name__
            }
        )

    except Exception as background_error:
        logger.error("Critical error in enhanced background training job",
                     job_id=job_id,
                     error=str(background_error))

    finally:
        logger.info("Enhanced background training job cleanup completed",
                    job_id=job_id)


@router.post(
    route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: str = Path(..., description="Inventory product UUID"),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Start enhanced training for a single product using repository pattern.

    Enhanced features:
    - Repository pattern for data access
    - Enhanced error handling and validation
    - Metrics tracking
    - Transactional operations
    """
    metrics = get_metrics_collector(request_obj)

    try:
        logger.info("Starting enhanced single product training",
                    inventory_product_id=inventory_product_id,
                    tenant_id=tenant_id)

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_total")

        # Generate enhanced job ID
        job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"

        # Delegate to enhanced training service
        result = await enhanced_training_service.start_single_product_training(
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            job_id=job_id,
            bakery_location=request.bakery_location or (40.4168, -3.7038)
        )

        if metrics:
            metrics.increment_counter("enhanced_single_product_training_success_total")

        logger.info("Enhanced single product training completed",
                    inventory_product_id=inventory_product_id,
                    job_id=job_id)

        return TrainingJobResponse(**result)

    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_validation_errors_total")
        logger.error("Enhanced single product training validation error",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_errors_total")
        logger.error("Enhanced single product training failed",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced single product training failed"
        )


# ============================================
# WebSocket Live Feed
# ============================================

class ConnectionManager:
    """Manage WebSocket connections for training progress"""

    def __init__(self):
        self.active_connections: Dict[str, Dict[str, WebSocket]] = {}
        # Structure: {job_id: {connection_id: websocket}}

    async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
        """Accept WebSocket connection and register it"""
        await websocket.accept()

        if job_id not in self.active_connections:
            self.active_connections[job_id] = {}

        self.active_connections[job_id][connection_id] = websocket
        logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")

    def disconnect(self, job_id: str, connection_id: str):
        """Remove WebSocket connection"""
        if job_id in self.active_connections:
            self.active_connections[job_id].pop(connection_id, None)
            if not self.active_connections[job_id]:
                del self.active_connections[job_id]

        logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")

    async def send_to_job(self, job_id: str, message: dict):
        """Send message to all connections for a specific job with better error handling"""
        if job_id not in self.active_connections:
            logger.debug(f"No active connections for job {job_id}")
            return

        # Send to all connections for this job
        disconnected_connections = []

        for connection_id, websocket in self.active_connections[job_id].items():
            try:
                await websocket.send_json(message)
                logger.debug(f"Sent {message.get('type', 'unknown')} to connection {connection_id}")
            except Exception as e:
                logger.warning(f"Failed to send message to connection {connection_id}: {e}")
                disconnected_connections.append(connection_id)

        # Clean up disconnected connections
        for connection_id in disconnected_connections:
            self.disconnect(job_id, connection_id)

        # Log successful sends
        active_count = len(self.active_connections.get(job_id, {}))
        if active_count > 0:
            logger.info(f"Sent {message.get('type', 'unknown')} message to {active_count} connection(s) for job {job_id}")


# Global connection manager
connection_manager = ConnectionManager()
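Any coroutine in the process that holds this manager can fan an update out to every socket watching a job; a usage sketch (job ID and payload values are illustrative):

async def notify_progress():
    await connection_manager.send_to_job(
        "enhanced_training_t-123_ab12cd34",
        {"type": "progress", "job_id": "enhanced_training_t-123_ab12cd34",
         "data": {"progress": 55, "current_step": "Training product 12/20"}},
    )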
@router.websocket(route_builder.build_nested_resource_route('jobs', 'job_id', 'live'))
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str,
    job_id: str
):
    """
    WebSocket endpoint for real-time training progress updates
    """
    # Validate token from query parameters
    token = websocket.query_params.get("token")
    if not token:
        logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
        await websocket.close(code=1008, reason="Authentication token required")
        return

    # Validate the token
    from shared.auth.jwt_handler import JWTHandler

    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

    try:
        payload = jwt_handler.verify_token(token)
        if not payload:
            logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
            await websocket.close(code=1008, reason="Invalid authentication token")
            return

        # Verify user has access to this tenant
        user_id = payload.get('user_id')
        if not user_id:
            logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
            await websocket.close(code=1008, reason="Invalid token payload")
            return

        logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")

    except Exception as e:
        logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
        await websocket.close(code=1008, reason="Token validation failed")
        return

    connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"

    await connection_manager.connect(websocket, job_id, connection_id)
    logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")

    consumer_task = None
    training_completed = False

    try:
        # Start RabbitMQ consumer
        consumer_task = asyncio.create_task(
            setup_rabbitmq_consumer_for_job(job_id, tenant_id)
        )

        last_activity = asyncio.get_event_loop().time()

        while not training_completed:
            try:
                try:
                    data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
                    last_activity = asyncio.get_event_loop().time()

                    # Handle different message types
                    if data["type"] == "websocket.receive":
                        if "text" in data:
                            message_text = data["text"]
                            if message_text == "ping":
                                await websocket.send_text("pong")
                                logger.debug(f"Text ping received from job {job_id}")
                            elif message_text == "get_status":
                                current_status = await get_current_job_status(job_id, tenant_id)
                                if current_status:
                                    await websocket.send_json({
                                        "type": "current_status",
                                        "job_id": job_id,
                                        "data": current_status
                                    })
                            elif message_text == "close":
                                logger.info(f"Client requested connection close for job {job_id}")
                                break

                        elif "bytes" in data:
                            await websocket.send_text("pong")
                            logger.debug(f"Binary ping received for job {job_id}, responding with text pong")

                    elif data["type"] == "websocket.disconnect":
                        logger.info(f"WebSocket disconnect message received for job {job_id}")
                        break

                except asyncio.TimeoutError:
                    current_time = asyncio.get_event_loop().time()

                    if current_time - last_activity > 90:
                        logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")

                        try:
                            await websocket.send_json({
                                "type": "heartbeat",
                                "job_id": job_id,
                                "timestamp": str(datetime.now()),
                                "message": "Training service heartbeat - frontend inactive",
                                "inactivity_seconds": int(current_time - last_activity)
                            })
                            last_activity = current_time
                        except Exception as e:
                            logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
                            break
                    else:
                        logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
                        continue

            except WebSocketDisconnect:
                logger.info(f"WebSocket client disconnected for job {job_id}")
                break
            except Exception as e:
                logger.error(f"WebSocket error for job {job_id}: {e}")
                if "Cannot call" in str(e) and "disconnect message" in str(e):
                    logger.error(f"FastAPI WebSocket disconnect error - connection already closed")
                    break
                await asyncio.sleep(1)

        logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")

    except Exception as e:
        logger.error(f"Critical WebSocket error for job {job_id}: {e}")

    finally:
        logger.info(f"Cleaning up WebSocket connection for job {job_id}")
        connection_manager.disconnect(job_id, connection_id)

        if consumer_task and not consumer_task.done():
            if training_completed:
                logger.info(f"Training completed, cancelling consumer for job {job_id}")
                consumer_task.cancel()
            else:
                logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")

            try:
                await consumer_task
            except asyncio.CancelledError:
                logger.info(f"Consumer task cancelled for job {job_id}")
            except Exception as e:
                logger.error(f"Consumer task error for job {job_id}: {e}")
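A possible client-side counterpart to this endpoint, using the third-party websockets package; the URL, port, and token are illustrative:

import asyncio
import json
import websockets

async def follow_job(tenant_id: str, job_id: str, token: str):
    url = f"ws://localhost:8000/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
    async with websockets.connect(url) as ws:
        await ws.send("get_status")  # ask the server for the current snapshot
        async for raw in ws:
            if raw == "pong":
                continue  # heartbeat reply, not an event
            event = json.loads(raw)
            print(event["type"], event.get("data", {}))
            if event["type"] in ("completed", "failed"):
                break  # terminal events end the live feed

# asyncio.run(follow_job("t-123", "enhanced_training_t-123_ab12cd34", "<jwt>"))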
async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
    """Set up RabbitMQ consumer to listen for training events for a specific job"""

    logger.info(f"Setting up RabbitMQ consumer for job {job_id}")

    try:
        # Create a unique queue for this WebSocket connection
        queue_name = f"websocket_training_{job_id}_{tenant_id}"

        async def handle_training_message(message):
            """Handle incoming RabbitMQ messages and forward to WebSocket"""
            try:
                # Parse the message
                body = message.body.decode()
                data = json.loads(body)

                logger.debug(f"Received message for job {job_id}: {data.get('event_type', 'unknown')}")

                # Extract event data
                event_type = data.get("event_type", "unknown")
                event_data = data.get("data", {})

                # Only process messages for this specific job
                message_job_id = event_data.get("job_id") if event_data else None
                if message_job_id != job_id:
                    logger.debug(f"Ignoring message for different job: {message_job_id}")
                    await message.ack()
                    return

                # Transform RabbitMQ message to WebSocket message format
                websocket_message = {
                    "type": map_event_type_to_websocket_type(event_type),
                    "job_id": job_id,
                    "timestamp": data.get("timestamp"),
                    "data": event_data
                }

                logger.info(f"Forwarding {event_type} message to WebSocket clients for job {job_id}")

                # Send to all WebSocket connections for this job
                await connection_manager.send_to_job(job_id, websocket_message)

                # Check if this is a completion message
                if event_type in ["training.completed", "training.failed"]:
                    logger.info(f"Training completion detected for job {job_id}: {event_type}")

                # Acknowledge the message
                await message.ack()

                logger.debug(f"Successfully processed {event_type} for job {job_id}")

            except Exception as e:
                logger.error(f"Error handling training message for job {job_id}: {e}")
                import traceback
                logger.error(f"Traceback: {traceback.format_exc()}")
                await message.nack(requeue=False)

        # Check if training_publisher is connected
        if not training_publisher.connected:
            logger.warning(f"Training publisher not connected for job {job_id}, attempting to connect...")
            success = await training_publisher.connect()
            if not success:
                logger.error(f"Failed to connect training_publisher for job {job_id}")
                return

        # Subscribe to training events
        logger.info(f"Subscribing to training events for job {job_id}")
        success = await training_publisher.consume_events(
            exchange_name="training.events",
            queue_name=queue_name,
            routing_key="training.*",
            callback=handle_training_message
        )

        if success:
            logger.info(f"Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})")

            # Keep the consumer running indefinitely until cancelled
            try:
                while True:
                    await asyncio.sleep(10)
                    logger.debug(f"Consumer heartbeat for job {job_id}")

            except asyncio.CancelledError:
                logger.info(f"Consumer cancelled for job {job_id}")
                raise
            except Exception as e:
                logger.error(f"Consumer error for job {job_id}: {e}")
                raise
        else:
            logger.error(f"Failed to set up RabbitMQ consumer for job {job_id}")

    except Exception as e:
        logger.error(f"Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")


def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
    """Map RabbitMQ event types to WebSocket message types"""
    mapping = {
        "training.started": "started",
        "training.progress": "progress",
        "training.completed": "completed",
        "training.failed": "failed",
        "training.cancelled": "cancelled",
        "training.step.completed": "step_completed",
        "training.product.started": "product_started",
        "training.product.completed": "product_completed",
        "training.product.failed": "product_failed",
        "training.model.trained": "model_trained",
        "training.data.validation.started": "validation_started",
        "training.data.validation.completed": "validation_completed"
    }

    return mapping.get(rabbitmq_event_type, "unknown")
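For reference, the consumer above expects RabbitMQ payloads of roughly this shape (field names taken from the handler code; the values are illustrative):

example_message = {
    "event_type": "training.progress",
    "timestamp": "2025-07-30T19:05:00Z",
    "data": {"job_id": "enhanced_training_t-123_ab12cd34", "progress": 42},
}
# map_event_type_to_websocket_type(example_message["event_type"]) -> "progress"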
async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
    """Get current job status from database"""
    try:
        return {
            "job_id": job_id,
            "status": "running",
            "progress": 0,
            "current_step": "Starting...",
            "started_at": "2025-07-30T19:00:00Z"
        }
    except Exception as e:
        logger.error(f"Failed to get current job status: {e}")
        return None


@router.get("/health")
async def health_check():
    """Health check endpoint for the training operations"""
    return {
        "status": "healthy",
        "service": "training-operations",
        "version": "2.0.0",
        "features": [
            "repository-pattern",
            "dependency-injection",
            "enhanced-error-handling",
            "metrics-tracking",
            "transactional-operations",
            "websocket-support"
        ],
        "timestamp": datetime.now().isoformat()
    }
@@ -1,377 +0,0 @@
|
||||
# services/training/app/api/websocket.py
|
||||
"""
|
||||
WebSocket endpoints for real-time training progress updates
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any
|
||||
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException
|
||||
from fastapi.routing import APIRouter
|
||||
import datetime
|
||||
|
||||
import structlog
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
from app.services.messaging import training_publisher
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep
|
||||
)
|
||||
|
||||
# Create WebSocket router
|
||||
websocket_router = APIRouter()

class ConnectionManager:
    """Manage WebSocket connections for training progress"""

    def __init__(self):
        self.active_connections: Dict[str, Dict[str, WebSocket]] = {}
        # Structure: {job_id: {connection_id: websocket}}

    async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
        """Accept WebSocket connection and register it"""
        await websocket.accept()

        if job_id not in self.active_connections:
            self.active_connections[job_id] = {}

        self.active_connections[job_id][connection_id] = websocket
        logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")

    def disconnect(self, job_id: str, connection_id: str):
        """Remove WebSocket connection"""
        if job_id in self.active_connections:
            self.active_connections[job_id].pop(connection_id, None)
            if not self.active_connections[job_id]:
                del self.active_connections[job_id]

        logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")

    async def send_to_job(self, job_id: str, message: dict):
        """Send message to all connections for a specific job with better error handling"""
        if job_id not in self.active_connections:
            logger.debug(f"No active connections for job {job_id}")
            return

        # Send to all connections for this job.
        # Iterate over a snapshot: each send awaits, so a concurrent
        # disconnect could otherwise mutate the dict mid-loop.
        disconnected_connections = []

        for connection_id, websocket in list(self.active_connections[job_id].items()):
            try:
                await websocket.send_json(message)
                logger.debug(f"📤 Sent {message.get('type', 'unknown')} to connection {connection_id}")
            except Exception as e:
                logger.warning(f"Failed to send message to connection {connection_id}: {e}")
                disconnected_connections.append(connection_id)

        # Clean up disconnected connections
        for connection_id in disconnected_connections:
            self.disconnect(job_id, connection_id)

        # Log successful sends
        active_count = len(self.active_connections.get(job_id, {}))
        if active_count > 0:
            logger.info(f"📡 Sent {message.get('type', 'unknown')} message to {active_count} connection(s) for job {job_id}")


# Global connection manager
connection_manager = ConnectionManager()
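
# Usage sketch (illustrative only - the job id and payload below are made up):
# any coroutine that knows a job_id can fan a message out to every socket
# currently watching that job.
#
#   await connection_manager.send_to_job(
#       "job-123",
#       {"type": "progress", "job_id": "job-123", "data": {"progress": 42}},
#   )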


@websocket_router.websocket("/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str,
    job_id: str
):
    # Validate token from query parameters
    token = websocket.query_params.get("token")
    if not token:
        logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
        await websocket.close(code=1008, reason="Authentication token required")
        return

    # Validate the token (use the same JWT handler as the gateway)
    from shared.auth.jwt_handler import JWTHandler
    from app.core.config import settings

    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

    try:
        payload = jwt_handler.verify_token(token)
        if not payload:
            logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
            await websocket.close(code=1008, reason="Invalid authentication token")
            return

        # Verify user has access to this tenant
        user_id = payload.get('user_id')
        if not user_id:
            logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
            await websocket.close(code=1008, reason="Invalid token payload")
            return

        logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")

    except Exception as e:
        logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
        await websocket.close(code=1008, reason="Token validation failed")
        return

    connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"

    await connection_manager.connect(websocket, job_id, connection_id)
    logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")

    consumer_task = None
    training_completed = False

    try:
        # Start RabbitMQ consumer
        consumer_task = asyncio.create_task(
            setup_rabbitmq_consumer_for_job(job_id, tenant_id)
        )

        last_activity = asyncio.get_event_loop().time()

        while not training_completed:
            try:
                # Coordinate with the frontend's 30s heartbeat and the gateway's 45s timeout.
                # This timeout must be longer than the gateway's to avoid premature closure.
                try:
                    data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
                    last_activity = asyncio.get_event_loop().time()

                    # Handle different message types
                    if data["type"] == "websocket.receive":
                        if "text" in data:
                            message_text = data["text"]
                            if message_text == "ping":
                                await websocket.send_text("pong")
                                logger.debug(f"Text ping received from job {job_id}")
                            elif message_text == "get_status":
                                current_status = await get_current_job_status(job_id, tenant_id)
                                if current_status:
                                    await websocket.send_json({
                                        "type": "current_status",
                                        "job_id": job_id,
                                        "data": current_status
                                    })
                            elif message_text == "close":
                                logger.info(f"Client requested connection close for job {job_id}")
                                break

                        elif "bytes" in data:
                            # Handle binary messages (WebSocket ping frames) - respond with text pong for compatibility
                            await websocket.send_text("pong")
                            logger.debug(f"Binary ping received for job {job_id}, responding with text pong")

                    elif data["type"] == "websocket.disconnect":
                        logger.info(f"WebSocket disconnect message received for job {job_id}")
                        break

                except asyncio.TimeoutError:
                    # No message received in 60 seconds - this is coordinated with the gateway timeouts
                    current_time = asyncio.get_event_loop().time()

                    # Send a heartbeat only if we haven't received a frontend ping for too long.
                    # The frontend pings every 30s, so 60s timeout + 30s grace = 90s before a heartbeat.
                    if current_time - last_activity > 90:  # 90 seconds of total inactivity
                        logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")

                        try:
                            await websocket.send_json({
                                "type": "heartbeat",
                                "job_id": job_id,
                                "timestamp": str(datetime.datetime.now()),
                                "message": "Training service heartbeat - frontend inactive",
                                "inactivity_seconds": int(current_time - last_activity)
                            })
                            last_activity = current_time
                        except Exception as e:
                            logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
                            break
                    else:
                        # Normal timeout; the frontend should be sending a ping every 30s
                        logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
                        continue

            except WebSocketDisconnect:
                logger.info(f"WebSocket client disconnected for job {job_id}")
                break
            except Exception as e:
                logger.error(f"WebSocket error for job {job_id}: {e}")
                # Check for the specific "cannot call receive" error
                if "Cannot call" in str(e) and "disconnect message" in str(e):
                    logger.error("FastAPI WebSocket disconnect error - connection already closed")
                    break
                # Don't break immediately for other errors - try to recover
                await asyncio.sleep(1)

        logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")

    except Exception as e:
        logger.error(f"Critical WebSocket error for job {job_id}: {e}")

    finally:
        logger.info(f"Cleaning up WebSocket connection for job {job_id}")
        connection_manager.disconnect(job_id, connection_id)

        if consumer_task and not consumer_task.done():
            if training_completed:
                logger.info(f"Training completed, cancelling consumer for job {job_id}")
            else:
                # The socket is gone, so the consumer has no audience left; cancel it
                # too, otherwise awaiting the never-ending task below would hang forever.
                logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
            consumer_task.cancel()

            try:
                await consumer_task
            except asyncio.CancelledError:
                logger.info(f"Consumer task cancelled for job {job_id}")
            except Exception as e:
                logger.error(f"Consumer task error for job {job_id}: {e}")


async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
    """Set up RabbitMQ consumer to listen for training events for a specific job"""

    logger.info(f"🚀 Setting up RabbitMQ consumer for job {job_id}")

    try:
        # Create a unique queue for this WebSocket connection
        queue_name = f"websocket_training_{job_id}_{tenant_id}"

        async def handle_training_message(message):
            """Handle incoming RabbitMQ messages and forward to WebSocket"""
            try:
                # Parse the message
                body = message.body.decode()
                data = json.loads(body)

                logger.debug(f"🔍 Received message for job {job_id}: {data.get('event_type', 'unknown')}")

                # Extract event data
                event_type = data.get("event_type", "unknown")
                event_data = data.get("data", {})

                # Only process messages for this specific job
                message_job_id = event_data.get("job_id") if event_data else None
                if message_job_id != job_id:
                    logger.debug(f"⏭️ Ignoring message for different job: {message_job_id}")
                    await message.ack()
                    return

                # Transform RabbitMQ message to WebSocket message format
                websocket_message = {
                    "type": map_event_type_to_websocket_type(event_type),
                    "job_id": job_id,
                    "timestamp": data.get("timestamp"),
                    "data": event_data
                }

                logger.info(f"📤 Forwarding {event_type} message to WebSocket clients for job {job_id}")

                # Send to all WebSocket connections for this job
                await connection_manager.send_to_job(job_id, websocket_message)

                # Check if this is a completion message
                if event_type in ["training.completed", "training.failed"]:
                    logger.info(f"🎯 Training completion detected for job {job_id}: {event_type}")
                    # Mark training as completed (you might want to store this in a global state)
                    # For now, we'll let the WebSocket handle this through the message

                # Acknowledge the message
                await message.ack()

                logger.debug(f"✅ Successfully processed {event_type} for job {job_id}")

            except Exception as e:
                logger.error(f"❌ Error handling training message for job {job_id}: {e}")
                import traceback
                logger.error(f"💥 Traceback: {traceback.format_exc()}")
                await message.nack(requeue=False)

        # Check if training_publisher is connected
        if not training_publisher.connected:
            logger.warning(f"⚠️ Training publisher not connected for job {job_id}, attempting to connect...")
            success = await training_publisher.connect()
            if not success:
                logger.error(f"❌ Failed to connect training_publisher for job {job_id}")
                return

        # Subscribe to training events
        logger.info(f"🔗 Subscribing to training events for job {job_id}")
        success = await training_publisher.consume_events(
            exchange_name="training.events",
            queue_name=queue_name,
            routing_key="training.*",  # Listen to all training events
            callback=handle_training_message
        )

        if success:
            logger.info(f"✅ Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})")

            # Keep the consumer running indefinitely until cancelled
            try:
                while True:
                    await asyncio.sleep(10)  # Keep consumer alive
                    logger.debug(f"🔄 Consumer heartbeat for job {job_id}")

            except asyncio.CancelledError:
                logger.info(f"🛑 Consumer cancelled for job {job_id}")
                raise
            except Exception as e:
                logger.error(f"💥 Consumer error for job {job_id}: {e}")
                raise
        else:
            logger.error(f"❌ Failed to set up RabbitMQ consumer for job {job_id}")

    except Exception as e:
        logger.error(f"💥 Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}")
        import traceback
        logger.error(f"🔥 Traceback: {traceback.format_exc()}")


def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
    """Map RabbitMQ event types to WebSocket message types"""
    mapping = {
        "training.started": "started",
        "training.progress": "progress",
        "training.completed": "completed",  # This is the key completion event
        "training.failed": "failed",        # This is also a completion event
        "training.cancelled": "cancelled",
        "training.step.completed": "step_completed",
        "training.product.started": "product_started",
        "training.product.completed": "product_completed",
        "training.product.failed": "product_failed",
        "training.model.trained": "model_trained",
        "training.data.validation.started": "validation_started",
        "training.data.validation.completed": "validation_completed"
    }

    return mapping.get(rabbitmq_event_type, "unknown")


async def get_current_job_status(job_id: str, tenant_id: str) -> Optional[Dict[str, Any]]:
    """Get current job status from database or cache"""
    try:
        # This should query your database for current job status
        # For now, return a placeholder - implement based on your database schema

        from app.core.database import get_db_session
        from app.models.training import ModelTrainingLog  # Assuming you have this model

        # async with get_background_db_session() as db:
        #     Query your training job status
        #     This is a placeholder - adjust based on your actual database models
        #     pass

        # Placeholder return - replace with actual database query
        return {
            "job_id": job_id,
            "status": "running",  # or "completed", "failed", etc.
            "progress": 0,
            "current_step": "Starting...",
            "started_at": "2025-07-30T19:00:00Z"
        }

    except Exception as e:
        logger.error(f"Failed to get current job status: {e}")
        return None
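
# A minimal sketch of how the placeholder above might be filled in, assuming
# get_db_session() is an async context manager and that ModelTrainingLog has
# job_id/tenant_id/status/progress/current_step/started_at columns (these
# names are guesses, not the project's actual schema):
#
#   from sqlalchemy import select
#   async with get_db_session() as db:
#       row = (await db.execute(
#           select(ModelTrainingLog).where(
#               ModelTrainingLog.job_id == job_id,
#               ModelTrainingLog.tenant_id == tenant_id,
#           )
#       )).scalar_one_or_none()
#       if row is None:
#           return None
#       return {
#           "job_id": row.job_id,
#           "status": row.status,
#           "progress": row.progress,
#           "current_step": row.current_step,
#           "started_at": row.started_at.isoformat(),
#       }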
@@ -11,8 +11,7 @@ from fastapi import FastAPI, Request
from sqlalchemy import text
from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
from app.api import training, models
from app.api.websocket import websocket_router
from app.api import training_jobs, training_operations, models
from app.services.messaging import setup_messaging, cleanup_messaging
from shared.service_base import StandardFastAPIService

@@ -55,7 +54,7 @@ class TrainingService(StandardFastAPIService):
            version="1.0.0",
            log_level=settings.LOG_LEVEL,
            cors_origins=settings.CORS_ORIGINS_LIST,
            api_prefix="/api/v1",
            api_prefix="",  # Empty because RouteBuilder already includes /api/v1
            database_manager=database_manager,
            expected_tables=training_expected_tables,
            enable_messaging=True

@@ -160,9 +159,9 @@ service.setup_custom_middleware()
service.setup_custom_endpoints()

# Include API routers
service.add_router(training.router, tags=["training"])
service.add_router(training_jobs.router, tags=["training-jobs"])
service.add_router(training_operations.router, tags=["training-operations"])
service.add_router(models.router, tags=["models"])
app.include_router(websocket_router, prefix="/api/v1/ws", tags=["websocket"])

if __name__ == "__main__":
    uvicorn.run(