REFACTOR ALL APIs

This commit is contained in:
Urtzi Alfaro
2025-10-06 15:27:01 +02:00
parent dc8221bd2f
commit 38fb98bc27
166 changed files with 18454 additions and 13605 deletions

View File

@@ -3,12 +3,12 @@ Training API Layer
HTTP endpoints for ML training operations
"""
from .training import router as training_router
from .websocket import websocket_router
from .training_jobs import router as training_jobs_router
from .training_operations import router as training_operations_router
from .models import router as models_router
__all__ = [
"training_router",
"websocket_router"
"training_jobs_router",
"training_operations_router",
"models_router"
]

View File

@@ -22,13 +22,27 @@ from shared.auth.decorators import (
get_current_user_dep,
require_admin_role
)
from shared.routing import RouteBuilder
from shared.auth.access_control import (
require_user_role,
admin_role_required,
owner_role_required,
require_subscription_tier,
analytics_tier_required,
enterprise_tier_required
)
# Create route builder for consistent URL structure
route_builder = RouteBuilder('training')
logger = structlog.get_logger()
router = APIRouter()
training_service = TrainingService()
@router.get("/tenants/{tenant_id}/models/{inventory_product_id}/active")
@router.get(
route_builder.build_base_route("models") + "/{inventory_product_id}/active"
)
async def get_active_model(
tenant_id: str = Path(..., description="Tenant ID"),
inventory_product_id: str = Path(..., description="Inventory product UUID"),
@@ -114,7 +128,10 @@ async def get_active_model(
detail="Failed to retrieve model"
)
@router.get("/tenants/{tenant_id}/models/{model_id}/metrics", response_model=ModelMetricsResponse)
@router.get(
route_builder.build_nested_resource_route("models", "model_id", "metrics"),
response_model=ModelMetricsResponse
)
async def get_model_metrics(
model_id: str = Path(..., description="Model ID"),
db: AsyncSession = Depends(get_db)
@@ -168,7 +185,10 @@ async def get_model_metrics(
detail="Failed to retrieve model metrics"
)
@router.get("/tenants/{tenant_id}/models", response_model=List[TrainedModelResponse])
@router.get(
route_builder.build_base_route("models"),
response_model=List[TrainedModelResponse]
)
async def list_models(
tenant_id: str = Path(..., description="Tenant ID"),
status: Optional[str] = Query(None, description="Filter by status (active/inactive)"),
@@ -235,6 +255,7 @@ async def list_models(
)
@router.delete("/models/tenant/{tenant_id}")
@require_user_role(['admin', 'owner'])
async def delete_tenant_models_complete(
tenant_id: str,
current_user = Depends(get_current_user_dep),

View File

@@ -1,577 +0,0 @@
"""
Enhanced Training API Endpoints with Repository Pattern
Updated to use repository pattern with dependency injection and improved error handling
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request
from fastapi import Query, Path
from typing import List, Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
TrainingJobRequest,
SingleProductTrainingRequest,
TrainingJobResponse
)
from app.services.messaging import (
publish_job_progress,
publish_data_validation_started,
publish_data_validation_completed,
publish_job_step_completed,
publish_job_completed,
publish_job_failed,
publish_job_started
)
from shared.auth.decorators import require_admin_role, get_current_user_dep
from shared.database.base import create_database_manager
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings
logger = structlog.get_logger()
router = APIRouter(tags=["enhanced-training"])
def get_enhanced_training_service():
"""Dependency injection for EnhancedTrainingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
return EnhancedTrainingService(database_manager)
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_enhanced_training_job(
request: TrainingJobRequest,
tenant_id: str = Path(..., description="Tenant ID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start a new enhanced training job for all tenant products using repository pattern.
🚀 ENHANCED IMMEDIATE RESPONSE PATTERN:
1. Validate request with enhanced validation
2. Create job record using repository pattern
3. Return 200 with enhanced job details
4. Execute enhanced training in background with repository tracking
Enhanced features:
- Repository pattern for data access
- Enhanced error handling and logging
- Metrics tracking and monitoring
- Transactional operations
"""
metrics = get_metrics_collector(request_obj)
try:
# Generate enhanced job ID
job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info("Creating enhanced training job using repository pattern",
job_id=job_id,
tenant_id=tenant_id)
# Record job creation metrics
if metrics:
metrics.increment_counter("enhanced_training_jobs_created_total")
# Add enhanced background task
background_tasks.add_task(
execute_enhanced_training_job_background,
tenant_id=tenant_id,
job_id=job_id,
bakery_location=(40.4168, -3.7038),
requested_start=request.start_date,
requested_end=request.end_date
)
# Return enhanced immediate success response
response_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "pending",
"message": "Enhanced training job started successfully using repository pattern",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": 18,
"training_results": {
"total_products": 0, # Will be updated during processing
"successful_trainings": 0,
"failed_trainings": 0,
"products": [],
"overall_training_time_seconds": 0.0
},
"data_summary": None,
"completed_at": None,
"error_details": None,
"processing_metadata": {
"background_task": True,
"async_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"dependency_injection": True
}
}
logger.info("Enhanced training job queued successfully",
job_id=job_id,
features=["repository-pattern", "dependency-injection", "enhanced-tracking"])
return TrainingJobResponse(**response_data)
except HTTPException:
# Re-raise HTTP exceptions as-is
raise
except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_training_validation_errors_total")
logger.error("Enhanced training job validation error",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_training_job_errors_total")
logger.error("Failed to queue enhanced training job",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to start enhanced training job"
)
async def execute_enhanced_training_job_background(
tenant_id: str,
job_id: str,
bakery_location: tuple,
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None
):
"""
Enhanced background task that executes the training job using repository pattern.
🔧 ENHANCED FEATURES:
- Repository pattern for all data operations
- Enhanced error handling with structured logging
- Transactional operations for data consistency
- Comprehensive metrics tracking
- Database connection pooling
- Enhanced progress reporting
"""
logger.info("Enhanced background training job started",
job_id=job_id,
tenant_id=tenant_id,
features=["repository-pattern", "enhanced-tracking"])
# Get enhanced training service with dependency injection
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
enhanced_training_service = EnhancedTrainingService(database_manager)
try:
# Create initial training log entry first
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="pending",
progress=0,
current_step="Starting enhanced training job",
tenant_id=tenant_id
)
# Publish job started event
await publish_job_started(job_id, tenant_id, {
"enhanced_features": True,
"repository_pattern": True,
"job_type": "enhanced_training"
})
training_config = {
"job_id": job_id,
"tenant_id": tenant_id,
"bakery_location": {
"latitude": bakery_location[0],
"longitude": bakery_location[1]
},
"requested_start": requested_start.isoformat() if requested_start else None,
"requested_end": requested_end.isoformat() if requested_end else None,
"estimated_duration_minutes": 18,
"background_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"api_version": "enhanced_v1"
}
# Update job status using repository pattern
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Initializing enhanced training pipeline",
tenant_id=tenant_id
)
# Execute the enhanced training pipeline with repository pattern
result = await enhanced_training_service.start_training_job(
tenant_id=tenant_id,
job_id=job_id,
bakery_location=bakery_location,
requested_start=requested_start,
requested_end=requested_end
)
# Update final status using repository pattern
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="completed",
progress=100,
current_step="Enhanced training completed successfully",
results=result,
tenant_id=tenant_id
)
# Publish enhanced completion event
await publish_job_completed(
job_id=job_id,
tenant_id=tenant_id,
results={
**result,
"enhanced_features": True,
"repository_integration": True
}
)
logger.info("Enhanced background training job completed successfully",
job_id=job_id,
models_created=result.get('products_trained', 0),
features=["repository-pattern", "enhanced-tracking"])
except Exception as training_error:
logger.error("Enhanced training pipeline failed",
job_id=job_id,
error=str(training_error))
try:
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Enhanced training failed",
error_message=str(training_error),
tenant_id=tenant_id
)
except Exception as status_error:
logger.error("Failed to update job status after training error",
job_id=job_id,
status_error=str(status_error))
# Publish enhanced failure event
await publish_job_failed(
job_id=job_id,
tenant_id=tenant_id,
error=str(training_error),
metadata={
"enhanced_features": True,
"repository_pattern": True,
"error_type": type(training_error).__name__
}
)
except Exception as background_error:
logger.error("Critical error in enhanced background training job",
job_id=job_id,
error=str(background_error))
finally:
logger.info("Enhanced background training job cleanup completed",
job_id=job_id)
@router.post("/tenants/{tenant_id}/training/products/{inventory_product_id}", response_model=TrainingJobResponse)
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_enhanced_single_product_training(
request: SingleProductTrainingRequest,
tenant_id: str = Path(..., description="Tenant ID"),
inventory_product_id: str = Path(..., description="Inventory product UUID"),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start enhanced training for a single product using repository pattern.
Enhanced features:
- Repository pattern for data access
- Enhanced error handling and validation
- Metrics tracking
- Transactional operations
"""
metrics = get_metrics_collector(request_obj)
try:
logger.info("Starting enhanced single product training",
inventory_product_id=inventory_product_id,
tenant_id=tenant_id)
# Record metrics
if metrics:
metrics.increment_counter("enhanced_single_product_training_total")
# Generate enhanced job ID
job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"
# Delegate to enhanced training service (single product method to be implemented)
result = await enhanced_training_service.start_single_product_training(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
job_id=job_id,
bakery_location=request.bakery_location or (40.4168, -3.7038)
)
if metrics:
metrics.increment_counter("enhanced_single_product_training_success_total")
logger.info("Enhanced single product training completed",
inventory_product_id=inventory_product_id,
job_id=job_id)
return TrainingJobResponse(**result)
except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_single_product_validation_errors_total")
logger.error("Enhanced single product training validation error",
error=str(e),
inventory_product_id=inventory_product_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_single_product_training_errors_total")
logger.error("Enhanced single product training failed",
error=str(e),
inventory_product_id=inventory_product_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Enhanced single product training failed"
)
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/status")
@track_execution_time("enhanced_job_status_duration_seconds", "training-service")
async def get_enhanced_training_job_status(
tenant_id: str = Path(..., description="Tenant ID"),
job_id: str = Path(..., description="Job ID"),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Get enhanced training job status using repository pattern.
"""
metrics = get_metrics_collector(request_obj)
try:
# Get status using enhanced service
status_info = await enhanced_training_service.get_training_status(job_id)
if not status_info or status_info.get("error"):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Training job not found"
)
if metrics:
metrics.increment_counter("enhanced_status_requests_total")
return {
**status_info,
"enhanced_features": True,
"repository_integration": True
}
except HTTPException:
raise
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_status_errors_total")
logger.error("Failed to get enhanced training status",
job_id=job_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get training status"
)
@router.get("/tenants/{tenant_id}/models")
@track_execution_time("enhanced_models_list_duration_seconds", "training-service")
async def get_enhanced_tenant_models(
tenant_id: str = Path(..., description="Tenant ID"),
active_only: bool = Query(True, description="Return only active models"),
skip: int = Query(0, description="Number of models to skip"),
limit: int = Query(100, description="Number of models to return"),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Get tenant models using enhanced repository pattern.
"""
metrics = get_metrics_collector(request_obj)
try:
# Get models using enhanced service
models = await enhanced_training_service.get_tenant_models(
tenant_id=tenant_id,
active_only=active_only,
skip=skip,
limit=limit
)
if metrics:
metrics.increment_counter("enhanced_models_requests_total")
return {
"tenant_id": tenant_id,
"models": models,
"total_returned": len(models),
"active_only": active_only,
"pagination": {
"skip": skip,
"limit": limit
},
"enhanced_features": True,
"repository_integration": True
}
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_models_errors_total")
logger.error("Failed to get enhanced tenant models",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get tenant models"
)
@router.get("/tenants/{tenant_id}/models/{model_id}/performance")
@track_execution_time("enhanced_model_performance_duration_seconds", "training-service")
async def get_enhanced_model_performance(
tenant_id: str = Path(..., description="Tenant ID"),
model_id: str = Path(..., description="Model ID"),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Get enhanced model performance metrics using repository pattern.
"""
metrics = get_metrics_collector(request_obj)
try:
# Get performance using enhanced service
performance = await enhanced_training_service.get_model_performance(model_id)
if not performance:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model performance not found"
)
if metrics:
metrics.increment_counter("enhanced_performance_requests_total")
return {
**performance,
"enhanced_features": True,
"repository_integration": True
}
except HTTPException:
raise
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_performance_errors_total")
logger.error("Failed to get enhanced model performance",
model_id=model_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get model performance"
)
@router.get("/tenants/{tenant_id}/statistics")
@track_execution_time("enhanced_tenant_statistics_duration_seconds", "training-service")
async def get_enhanced_tenant_statistics(
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Get comprehensive enhanced tenant statistics using repository pattern.
"""
metrics = get_metrics_collector(request_obj)
try:
# Get statistics using enhanced service
statistics = await enhanced_training_service.get_tenant_statistics(tenant_id)
if statistics.get("error"):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=statistics["error"]
)
if metrics:
metrics.increment_counter("enhanced_statistics_requests_total")
return {
**statistics,
"enhanced_features": True,
"repository_integration": True
}
except HTTPException:
raise
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_statistics_errors_total")
logger.error("Failed to get enhanced tenant statistics",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get tenant statistics"
)
@router.get("/health")
async def enhanced_health_check():
"""
Enhanced health check endpoint for the training service.
"""
return {
"status": "healthy",
"service": "enhanced-training-service",
"version": "2.0.0",
"features": [
"repository-pattern",
"dependency-injection",
"enhanced-error-handling",
"metrics-tracking",
"transactional-operations"
],
"timestamp": datetime.now().isoformat()
}

View File

@@ -0,0 +1,123 @@
"""
Training Jobs API - ATOMIC CRUD operations
Handles basic training job creation and retrieval
"""
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query, Request
from typing import List, Optional
import structlog
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from datetime import datetime
import uuid
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import TrainingJobResponse
from shared.database.base import create_database_manager
from app.core.config import settings
logger = structlog.get_logger()
route_builder = RouteBuilder('training')
router = APIRouter(tags=["training-jobs"])
def get_enhanced_training_service():
"""Dependency injection for EnhancedTrainingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
return EnhancedTrainingService(database_manager)
@router.get(
route_builder.build_nested_resource_route("jobs", "job_id", "status")
)
@track_execution_time("enhanced_job_status_duration_seconds", "training-service")
async def get_training_job_status(
tenant_id: str = Path(..., description="Tenant ID"),
job_id: str = Path(..., description="Job ID"),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Get training job status using repository pattern.
"""
metrics = get_metrics_collector(request_obj)
try:
# Get status using enhanced service
status_info = await enhanced_training_service.get_training_status(job_id)
if not status_info or status_info.get("error"):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Training job not found"
)
if metrics:
metrics.increment_counter("enhanced_status_requests_total")
return {
**status_info,
"enhanced_features": True,
"repository_integration": True
}
except HTTPException:
raise
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_status_errors_total")
logger.error("Failed to get training status",
job_id=job_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get training status"
)
@router.get(
route_builder.build_base_route("statistics")
)
@track_execution_time("enhanced_tenant_statistics_duration_seconds", "training-service")
async def get_tenant_statistics(
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Get comprehensive tenant statistics using repository pattern.
"""
metrics = get_metrics_collector(request_obj)
try:
# Get statistics using enhanced service
statistics = await enhanced_training_service.get_tenant_statistics(tenant_id)
if statistics.get("error"):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=statistics["error"]
)
if metrics:
metrics.increment_counter("enhanced_statistics_requests_total")
return {
**statistics,
"enhanced_features": True,
"repository_integration": True
}
except HTTPException:
raise
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_statistics_errors_total")
logger.error("Failed to get tenant statistics",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get tenant statistics"
)

View File

@@ -0,0 +1,730 @@
"""
Training Operations API - BUSINESS logic
Handles training job execution, metrics, and WebSocket live feed
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path, WebSocket, WebSocketDisconnect
from typing import List, Optional, Dict, Any
import structlog
import asyncio
import json
import datetime
from shared.auth.access_control import require_user_role, admin_role_required, analytics_tier_required
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from datetime import datetime, timezone
import uuid
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
TrainingJobRequest,
SingleProductTrainingRequest,
TrainingJobResponse
)
from app.services.messaging import (
publish_job_progress,
publish_data_validation_started,
publish_data_validation_completed,
publish_job_step_completed,
publish_job_completed,
publish_job_failed,
publish_job_started,
training_publisher
)
from app.core.config import settings
logger = structlog.get_logger()
route_builder = RouteBuilder('training')
router = APIRouter(tags=["training-operations"])
def get_enhanced_training_service():
"""Dependency injection for EnhancedTrainingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
return EnhancedTrainingService(database_manager)
@router.post(
route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
request: TrainingJobRequest,
tenant_id: str = Path(..., description="Tenant ID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start a new training job for all tenant products using repository pattern.
Enhanced immediate response pattern:
1. Validate request with enhanced validation
2. Create job record using repository pattern
3. Return 200 with enhanced job details
4. Execute enhanced training in background with repository tracking
Enhanced features:
- Repository pattern for data access
- Enhanced error handling and logging
- Metrics tracking and monitoring
- Transactional operations
"""
metrics = get_metrics_collector(request_obj)
try:
# Generate enhanced job ID
job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info("Creating enhanced training job using repository pattern",
job_id=job_id,
tenant_id=tenant_id)
# Record job creation metrics
if metrics:
metrics.increment_counter("enhanced_training_jobs_created_total")
# Add enhanced background task
background_tasks.add_task(
execute_training_job_background,
tenant_id=tenant_id,
job_id=job_id,
bakery_location=(40.4168, -3.7038),
requested_start=request.start_date,
requested_end=request.end_date
)
# Return enhanced immediate success response
response_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "pending",
"message": "Enhanced training job started successfully using repository pattern",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": 18,
"training_results": {
"total_products": 0,
"successful_trainings": 0,
"failed_trainings": 0,
"products": [],
"overall_training_time_seconds": 0.0
},
"data_summary": None,
"completed_at": None,
"error_details": None,
"processing_metadata": {
"background_task": True,
"async_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"dependency_injection": True
}
}
logger.info("Enhanced training job queued successfully",
job_id=job_id,
features=["repository-pattern", "dependency-injection", "enhanced-tracking"])
return TrainingJobResponse(**response_data)
except HTTPException:
raise
except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_training_validation_errors_total")
logger.error("Enhanced training job validation error",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_training_job_errors_total")
logger.error("Failed to queue enhanced training job",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to start enhanced training job"
)
async def execute_training_job_background(
tenant_id: str,
job_id: str,
bakery_location: tuple,
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None
):
"""
Enhanced background task that executes the training job using repository pattern.
Enhanced features:
- Repository pattern for all data operations
- Enhanced error handling with structured logging
- Transactional operations for data consistency
- Comprehensive metrics tracking
- Database connection pooling
- Enhanced progress reporting
"""
logger.info("Enhanced background training job started",
job_id=job_id,
tenant_id=tenant_id,
features=["repository-pattern", "enhanced-tracking"])
# Get enhanced training service with dependency injection
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
enhanced_training_service = EnhancedTrainingService(database_manager)
try:
# Create initial training log entry first
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="pending",
progress=0,
current_step="Starting enhanced training job",
tenant_id=tenant_id
)
# Publish job started event
await publish_job_started(job_id, tenant_id, {
"enhanced_features": True,
"repository_pattern": True,
"job_type": "enhanced_training"
})
training_config = {
"job_id": job_id,
"tenant_id": tenant_id,
"bakery_location": {
"latitude": bakery_location[0],
"longitude": bakery_location[1]
},
"requested_start": requested_start.isoformat() if requested_start else None,
"requested_end": requested_end.isoformat() if requested_end else None,
"estimated_duration_minutes": 18,
"background_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"api_version": "enhanced_v1"
}
# Update job status using repository pattern
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Initializing enhanced training pipeline",
tenant_id=tenant_id
)
# Execute the enhanced training pipeline with repository pattern
result = await enhanced_training_service.start_training_job(
tenant_id=tenant_id,
job_id=job_id,
bakery_location=bakery_location,
requested_start=requested_start,
requested_end=requested_end
)
# Update final status using repository pattern
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="completed",
progress=100,
current_step="Enhanced training completed successfully",
results=result,
tenant_id=tenant_id
)
# Publish enhanced completion event
await publish_job_completed(
job_id=job_id,
tenant_id=tenant_id,
results={
**result,
"enhanced_features": True,
"repository_integration": True
}
)
logger.info("Enhanced background training job completed successfully",
job_id=job_id,
models_created=result.get('products_trained', 0),
features=["repository-pattern", "enhanced-tracking"])
except Exception as training_error:
logger.error("Enhanced training pipeline failed",
job_id=job_id,
error=str(training_error))
try:
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Enhanced training failed",
error_message=str(training_error),
tenant_id=tenant_id
)
except Exception as status_error:
logger.error("Failed to update job status after training error",
job_id=job_id,
status_error=str(status_error))
# Publish enhanced failure event
await publish_job_failed(
job_id=job_id,
tenant_id=tenant_id,
error=str(training_error),
metadata={
"enhanced_features": True,
"repository_pattern": True,
"error_type": type(training_error).__name__
}
)
except Exception as background_error:
logger.error("Critical error in enhanced background training job",
job_id=job_id,
error=str(background_error))
finally:
logger.info("Enhanced background training job cleanup completed",
job_id=job_id)
@router.post(
route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
request: SingleProductTrainingRequest,
tenant_id: str = Path(..., description="Tenant ID"),
inventory_product_id: str = Path(..., description="Inventory product UUID"),
request_obj: Request = None,
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start enhanced training for a single product using repository pattern.
Enhanced features:
- Repository pattern for data access
- Enhanced error handling and validation
- Metrics tracking
- Transactional operations
"""
metrics = get_metrics_collector(request_obj)
try:
logger.info("Starting enhanced single product training",
inventory_product_id=inventory_product_id,
tenant_id=tenant_id)
# Record metrics
if metrics:
metrics.increment_counter("enhanced_single_product_training_total")
# Generate enhanced job ID
job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"
# Delegate to enhanced training service
result = await enhanced_training_service.start_single_product_training(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
job_id=job_id,
bakery_location=request.bakery_location or (40.4168, -3.7038)
)
if metrics:
metrics.increment_counter("enhanced_single_product_training_success_total")
logger.info("Enhanced single product training completed",
inventory_product_id=inventory_product_id,
job_id=job_id)
return TrainingJobResponse(**result)
except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_single_product_validation_errors_total")
logger.error("Enhanced single product training validation error",
error=str(e),
inventory_product_id=inventory_product_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_single_product_training_errors_total")
logger.error("Enhanced single product training failed",
error=str(e),
inventory_product_id=inventory_product_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Enhanced single product training failed"
)
# ============================================
# WebSocket Live Feed
# ============================================
class ConnectionManager:
"""Manage WebSocket connections for training progress"""
def __init__(self):
self.active_connections: Dict[str, Dict[str, WebSocket]] = {}
# Structure: {job_id: {connection_id: websocket}}
async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
"""Accept WebSocket connection and register it"""
await websocket.accept()
if job_id not in self.active_connections:
self.active_connections[job_id] = {}
self.active_connections[job_id][connection_id] = websocket
logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")
def disconnect(self, job_id: str, connection_id: str):
"""Remove WebSocket connection"""
if job_id in self.active_connections:
self.active_connections[job_id].pop(connection_id, None)
if not self.active_connections[job_id]:
del self.active_connections[job_id]
logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")
async def send_to_job(self, job_id: str, message: dict):
"""Send message to all connections for a specific job with better error handling"""
if job_id not in self.active_connections:
logger.debug(f"No active connections for job {job_id}")
return
# Send to all connections for this job
disconnected_connections = []
for connection_id, websocket in self.active_connections[job_id].items():
try:
await websocket.send_json(message)
logger.debug(f"Sent {message.get('type', 'unknown')} to connection {connection_id}")
except Exception as e:
logger.warning(f"Failed to send message to connection {connection_id}: {e}")
disconnected_connections.append(connection_id)
# Clean up disconnected connections
for connection_id in disconnected_connections:
self.disconnect(job_id, connection_id)
# Log successful sends
active_count = len(self.active_connections.get(job_id, {}))
if active_count > 0:
logger.info(f"Sent {message.get('type', 'unknown')} message to {active_count} connection(s) for job {job_id}")
# Global connection manager
connection_manager = ConnectionManager()
@router.websocket(route_builder.build_nested_resource_route('jobs', 'job_id', 'live'))
async def training_progress_websocket(
websocket: WebSocket,
tenant_id: str,
job_id: str
):
"""
WebSocket endpoint for real-time training progress updates
"""
# Validate token from query parameters
token = websocket.query_params.get("token")
if not token:
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
await websocket.close(code=1008, reason="Authentication token required")
return
# Validate the token
from shared.auth.jwt_handler import JWTHandler
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload:
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
await websocket.close(code=1008, reason="Invalid authentication token")
return
# Verify user has access to this tenant
user_id = payload.get('user_id')
if not user_id:
logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
await websocket.close(code=1008, reason="Invalid token payload")
return
logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
except Exception as e:
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
await websocket.close(code=1008, reason="Token validation failed")
return
connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
await connection_manager.connect(websocket, job_id, connection_id)
logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
consumer_task = None
training_completed = False
try:
# Start RabbitMQ consumer
consumer_task = asyncio.create_task(
setup_rabbitmq_consumer_for_job(job_id, tenant_id)
)
last_activity = asyncio.get_event_loop().time()
while not training_completed:
try:
try:
data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
last_activity = asyncio.get_event_loop().time()
# Handle different message types
if data["type"] == "websocket.receive":
if "text" in data:
message_text = data["text"]
if message_text == "ping":
await websocket.send_text("pong")
logger.debug(f"Text ping received from job {job_id}")
elif message_text == "get_status":
current_status = await get_current_job_status(job_id, tenant_id)
if current_status:
await websocket.send_json({
"type": "current_status",
"job_id": job_id,
"data": current_status
})
elif message_text == "close":
logger.info(f"Client requested connection close for job {job_id}")
break
elif "bytes" in data:
await websocket.send_text("pong")
logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
elif data["type"] == "websocket.disconnect":
logger.info(f"WebSocket disconnect message received for job {job_id}")
break
except asyncio.TimeoutError:
current_time = asyncio.get_event_loop().time()
if current_time - last_activity > 90:
logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
try:
await websocket.send_json({
"type": "heartbeat",
"job_id": job_id,
"timestamp": str(datetime.now()),
"message": "Training service heartbeat - frontend inactive",
"inactivity_seconds": int(current_time - last_activity)
})
last_activity = current_time
except Exception as e:
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
break
else:
logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
continue
except WebSocketDisconnect:
logger.info(f"WebSocket client disconnected for job {job_id}")
break
except Exception as e:
logger.error(f"WebSocket error for job {job_id}: {e}")
if "Cannot call" in str(e) and "disconnect message" in str(e):
logger.error(f"FastAPI WebSocket disconnect error - connection already closed")
break
await asyncio.sleep(1)
logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")
except Exception as e:
logger.error(f"Critical WebSocket error for job {job_id}: {e}")
finally:
logger.info(f"Cleaning up WebSocket connection for job {job_id}")
connection_manager.disconnect(job_id, connection_id)
if consumer_task and not consumer_task.done():
if training_completed:
logger.info(f"Training completed, cancelling consumer for job {job_id}")
consumer_task.cancel()
else:
logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
try:
await consumer_task
except asyncio.CancelledError:
logger.info(f"Consumer task cancelled for job {job_id}")
except Exception as e:
logger.error(f"Consumer task error for job {job_id}: {e}")
async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
"""Set up RabbitMQ consumer to listen for training events for a specific job"""
logger.info(f"Setting up RabbitMQ consumer for job {job_id}")
try:
# Create a unique queue for this WebSocket connection
queue_name = f"websocket_training_{job_id}_{tenant_id}"
async def handle_training_message(message):
"""Handle incoming RabbitMQ messages and forward to WebSocket"""
try:
# Parse the message
body = message.body.decode()
data = json.loads(body)
logger.debug(f"Received message for job {job_id}: {data.get('event_type', 'unknown')}")
# Extract event data
event_type = data.get("event_type", "unknown")
event_data = data.get("data", {})
# Only process messages for this specific job
message_job_id = event_data.get("job_id") if event_data else None
if message_job_id != job_id:
logger.debug(f"Ignoring message for different job: {message_job_id}")
await message.ack()
return
# Transform RabbitMQ message to WebSocket message format
websocket_message = {
"type": map_event_type_to_websocket_type(event_type),
"job_id": job_id,
"timestamp": data.get("timestamp"),
"data": event_data
}
logger.info(f"Forwarding {event_type} message to WebSocket clients for job {job_id}")
# Send to all WebSocket connections for this job
await connection_manager.send_to_job(job_id, websocket_message)
# Check if this is a completion message
if event_type in ["training.completed", "training.failed"]:
logger.info(f"Training completion detected for job {job_id}: {event_type}")
# Acknowledge the message
await message.ack()
logger.debug(f"Successfully processed {event_type} for job {job_id}")
except Exception as e:
logger.error(f"Error handling training message for job {job_id}: {e}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
await message.nack(requeue=False)
# Check if training_publisher is connected
if not training_publisher.connected:
logger.warning(f"Training publisher not connected for job {job_id}, attempting to connect...")
success = await training_publisher.connect()
if not success:
logger.error(f"Failed to connect training_publisher for job {job_id}")
return
# Subscribe to training events
logger.info(f"Subscribing to training events for job {job_id}")
success = await training_publisher.consume_events(
exchange_name="training.events",
queue_name=queue_name,
routing_key="training.*",
callback=handle_training_message
)
if success:
logger.info(f"Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})")
# Keep the consumer running indefinitely until cancelled
try:
while True:
await asyncio.sleep(10)
logger.debug(f"Consumer heartbeat for job {job_id}")
except asyncio.CancelledError:
logger.info(f"Consumer cancelled for job {job_id}")
raise
except Exception as e:
logger.error(f"Consumer error for job {job_id}: {e}")
raise
else:
logger.error(f"Failed to set up RabbitMQ consumer for job {job_id}")
except Exception as e:
logger.error(f"Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
"""Map RabbitMQ event types to WebSocket message types"""
mapping = {
"training.started": "started",
"training.progress": "progress",
"training.completed": "completed",
"training.failed": "failed",
"training.cancelled": "cancelled",
"training.step.completed": "step_completed",
"training.product.started": "product_started",
"training.product.completed": "product_completed",
"training.product.failed": "product_failed",
"training.model.trained": "model_trained",
"training.data.validation.started": "validation_started",
"training.data.validation.completed": "validation_completed"
}
return mapping.get(rabbitmq_event_type, "unknown")
async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
"""Get current job status from database"""
try:
return {
"job_id": job_id,
"status": "running",
"progress": 0,
"current_step": "Starting...",
"started_at": "2025-07-30T19:00:00Z"
}
except Exception as e:
logger.error(f"Failed to get current job status: {e}")
return None
@router.get("/health")
async def health_check():
"""Health check endpoint for the training operations"""
return {
"status": "healthy",
"service": "training-operations",
"version": "2.0.0",
"features": [
"repository-pattern",
"dependency-injection",
"enhanced-error-handling",
"metrics-tracking",
"transactional-operations",
"websocket-support"
],
"timestamp": datetime.now().isoformat()
}

View File

@@ -1,377 +0,0 @@
# services/training/app/api/websocket.py
"""
WebSocket endpoints for real-time training progress updates
"""
import json
import asyncio
from typing import Dict, Any
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException
from fastapi.routing import APIRouter
import datetime
import structlog
logger = structlog.get_logger(__name__)
from app.services.messaging import training_publisher
from shared.auth.decorators import (
get_current_user_dep
)
# Create WebSocket router
websocket_router = APIRouter()
class ConnectionManager:
"""Manage WebSocket connections for training progress"""
def __init__(self):
self.active_connections: Dict[str, Dict[str, WebSocket]] = {}
# Structure: {job_id: {connection_id: websocket}}
async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
"""Accept WebSocket connection and register it"""
await websocket.accept()
if job_id not in self.active_connections:
self.active_connections[job_id] = {}
self.active_connections[job_id][connection_id] = websocket
logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")
def disconnect(self, job_id: str, connection_id: str):
"""Remove WebSocket connection"""
if job_id in self.active_connections:
self.active_connections[job_id].pop(connection_id, None)
if not self.active_connections[job_id]:
del self.active_connections[job_id]
logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")
async def send_to_job(self, job_id: str, message: dict):
"""Send message to all connections for a specific job with better error handling"""
if job_id not in self.active_connections:
logger.debug(f"No active connections for job {job_id}")
return
# Send to all connections for this job
disconnected_connections = []
for connection_id, websocket in self.active_connections[job_id].items():
try:
await websocket.send_json(message)
logger.debug(f"📤 Sent {message.get('type', 'unknown')} to connection {connection_id}")
except Exception as e:
logger.warning(f"Failed to send message to connection {connection_id}: {e}")
disconnected_connections.append(connection_id)
# Clean up disconnected connections
for connection_id in disconnected_connections:
self.disconnect(job_id, connection_id)
# Log successful sends
active_count = len(self.active_connections.get(job_id, {}))
if active_count > 0:
logger.info(f"📡 Sent {message.get('type', 'unknown')} message to {active_count} connection(s) for job {job_id}")
# Global connection manager
connection_manager = ConnectionManager()
@websocket_router.websocket("/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
websocket: WebSocket,
tenant_id: str,
job_id: str
):
# Validate token from query parameters
token = websocket.query_params.get("token")
if not token:
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
await websocket.close(code=1008, reason="Authentication token required")
return
# Validate the token (use the same JWT handler as gateway)
from shared.auth.jwt_handler import JWTHandler
from app.core.config import settings
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload:
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
await websocket.close(code=1008, reason="Invalid authentication token")
return
# Verify user has access to this tenant
user_id = payload.get('user_id')
if not user_id:
logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
await websocket.close(code=1008, reason="Invalid token payload")
return
logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
except Exception as e:
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
await websocket.close(code=1008, reason="Token validation failed")
return
connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
await connection_manager.connect(websocket, job_id, connection_id)
logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
consumer_task = None
training_completed = False
try:
# Start RabbitMQ consumer
consumer_task = asyncio.create_task(
setup_rabbitmq_consumer_for_job(job_id, tenant_id)
)
last_activity = asyncio.get_event_loop().time()
while not training_completed:
try:
# Coordinate with frontend 30s heartbeat + gateway 45s timeout
# This should be longer than gateway timeout to avoid premature closure
try:
data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
last_activity = asyncio.get_event_loop().time()
# Handle different message types
if data["type"] == "websocket.receive":
if "text" in data:
message_text = data["text"]
if message_text == "ping":
await websocket.send_text("pong")
logger.debug(f"Text ping received from job {job_id}")
elif message_text == "get_status":
current_status = await get_current_job_status(job_id, tenant_id)
if current_status:
await websocket.send_json({
"type": "current_status",
"job_id": job_id,
"data": current_status
})
elif message_text == "close":
logger.info(f"Client requested connection close for job {job_id}")
break
elif "bytes" in data:
# Handle binary messages (WebSocket ping frames) - respond with text pong for compatibility
await websocket.send_text("pong")
logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
elif data["type"] == "websocket.disconnect":
logger.info(f"WebSocket disconnect message received for job {job_id}")
break
except asyncio.TimeoutError:
# No message received in 60 seconds - this is now coordinated with gateway timeouts
current_time = asyncio.get_event_loop().time()
# Send heartbeat only if we haven't received frontend ping for too long
# Frontend sends ping every 30s, so 60s timeout + 30s grace = 90s before heartbeat
if current_time - last_activity > 90: # 90 seconds of total inactivity
logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
try:
await websocket.send_json({
"type": "heartbeat",
"job_id": job_id,
"timestamp": str(datetime.datetime.now()),
"message": "Training service heartbeat - frontend inactive",
"inactivity_seconds": int(current_time - last_activity)
})
last_activity = current_time
except Exception as e:
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
break
else:
# Normal timeout, frontend should be sending ping every 30s
logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
continue
except WebSocketDisconnect:
logger.info(f"WebSocket client disconnected for job {job_id}")
break
except Exception as e:
logger.error(f"WebSocket error for job {job_id}: {e}")
# Check if it's the specific "cannot call receive" error
if "Cannot call" in str(e) and "disconnect message" in str(e):
logger.error(f"FastAPI WebSocket disconnect error - connection already closed")
break
# Don't break immediately for other errors - try to recover
await asyncio.sleep(1)
logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")
except Exception as e:
logger.error(f"Critical WebSocket error for job {job_id}: {e}")
finally:
logger.info(f"Cleaning up WebSocket connection for job {job_id}")
connection_manager.disconnect(job_id, connection_id)
if consumer_task and not consumer_task.done():
if training_completed:
logger.info(f"Training completed, cancelling consumer for job {job_id}")
consumer_task.cancel()
else:
logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
try:
await consumer_task
except asyncio.CancelledError:
logger.info(f"Consumer task cancelled for job {job_id}")
except Exception as e:
logger.error(f"Consumer task error for job {job_id}: {e}")
async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
"""Set up RabbitMQ consumer to listen for training events for a specific job"""
logger.info(f"🚀 Setting up RabbitMQ consumer for job {job_id}")
try:
# Create a unique queue for this WebSocket connection
queue_name = f"websocket_training_{job_id}_{tenant_id}"
async def handle_training_message(message):
"""Handle incoming RabbitMQ messages and forward to WebSocket"""
try:
# Parse the message
body = message.body.decode()
data = json.loads(body)
logger.debug(f"🔍 Received message for job {job_id}: {data.get('event_type', 'unknown')}")
# Extract event data
event_type = data.get("event_type", "unknown")
event_data = data.get("data", {})
# Only process messages for this specific job
message_job_id = event_data.get("job_id") if event_data else None
if message_job_id != job_id:
logger.debug(f"⏭️ Ignoring message for different job: {message_job_id}")
await message.ack()
return
# Transform RabbitMQ message to WebSocket message format
websocket_message = {
"type": map_event_type_to_websocket_type(event_type),
"job_id": job_id,
"timestamp": data.get("timestamp"),
"data": event_data
}
logger.info(f"📤 Forwarding {event_type} message to WebSocket clients for job {job_id}")
# Send to all WebSocket connections for this job
await connection_manager.send_to_job(job_id, websocket_message)
# Check if this is a completion message
if event_type in ["training.completed", "training.failed"]:
logger.info(f"🎯 Training completion detected for job {job_id}: {event_type}")
# Mark training as completed (you might want to store this in a global state)
# For now, we'll let the WebSocket handle this through the message
# Acknowledge the message
await message.ack()
logger.debug(f"✅ Successfully processed {event_type} for job {job_id}")
except Exception as e:
logger.error(f"❌ Error handling training message for job {job_id}: {e}")
import traceback
logger.error(f"💥 Traceback: {traceback.format_exc()}")
await message.nack(requeue=False)
# Check if training_publisher is connected
if not training_publisher.connected:
logger.warning(f"⚠️ Training publisher not connected for job {job_id}, attempting to connect...")
success = await training_publisher.connect()
if not success:
logger.error(f"❌ Failed to connect training_publisher for job {job_id}")
return
# Subscribe to training events
logger.info(f"🔗 Subscribing to training events for job {job_id}")
success = await training_publisher.consume_events(
exchange_name="training.events",
queue_name=queue_name,
routing_key="training.*", # Listen to all training events
callback=handle_training_message
)
if success:
logger.info(f"✅ Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})")
# Keep the consumer running indefinitely until cancelled
try:
while True:
await asyncio.sleep(10) # Keep consumer alive
logger.debug(f"🔄 Consumer heartbeat for job {job_id}")
except asyncio.CancelledError:
logger.info(f"🛑 Consumer cancelled for job {job_id}")
raise
except Exception as e:
logger.error(f"💥 Consumer error for job {job_id}: {e}")
raise
else:
logger.error(f"❌ Failed to set up RabbitMQ consumer for job {job_id}")
except Exception as e:
logger.error(f"💥 Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}")
import traceback
logger.error(f"🔥 Traceback: {traceback.format_exc()}")
def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
"""Map RabbitMQ event types to WebSocket message types"""
mapping = {
"training.started": "started",
"training.progress": "progress",
"training.completed": "completed", # This is the key completion event
"training.failed": "failed", # This is also a completion event
"training.cancelled": "cancelled",
"training.step.completed": "step_completed",
"training.product.started": "product_started",
"training.product.completed": "product_completed",
"training.product.failed": "product_failed",
"training.model.trained": "model_trained",
"training.data.validation.started": "validation_started",
"training.data.validation.completed": "validation_completed"
}
return mapping.get(rabbitmq_event_type, "unknown")
async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
"""Get current job status from database or cache"""
try:
# This should query your database for current job status
# For now, return a placeholder - implement based on your database schema
from app.core.database import get_db_session
from app.models.training import ModelTrainingLog # Assuming you have this model
# async with get_background_db_session() as db:
# Query your training job status
# This is a placeholder - adjust based on your actual database models
# pass
# Placeholder return - replace with actual database query
return {
"job_id": job_id,
"status": "running", # or "completed", "failed", etc.
"progress": 0,
"current_step": "Starting...",
"started_at": "2025-07-30T19:00:00Z"
}
except Exception as e:
logger.error(f"Failed to get current job status: {e}")
return None

View File

@@ -11,8 +11,7 @@ from fastapi import FastAPI, Request
from sqlalchemy import text
from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database, database_manager
from app.api import training, models
from app.api.websocket import websocket_router
from app.api import training_jobs, training_operations, models
from app.services.messaging import setup_messaging, cleanup_messaging
from shared.service_base import StandardFastAPIService
@@ -55,7 +54,7 @@ class TrainingService(StandardFastAPIService):
version="1.0.0",
log_level=settings.LOG_LEVEL,
cors_origins=settings.CORS_ORIGINS_LIST,
api_prefix="/api/v1",
api_prefix="", # Empty because RouteBuilder already includes /api/v1
database_manager=database_manager,
expected_tables=training_expected_tables,
enable_messaging=True
@@ -160,9 +159,9 @@ service.setup_custom_middleware()
service.setup_custom_endpoints()
# Include API routers
service.add_router(training.router, tags=["training"])
service.add_router(training_jobs.router, tags=["training-jobs"])
service.add_router(training_operations.router, tags=["training-operations"])
service.add_router(models.router, tags=["models"])
app.include_router(websocket_router, prefix="/api/v1/ws", tags=["websocket"])
if __name__ == "__main__":
uvicorn.run(

View File

@@ -1,8 +1,8 @@
"""initial_schema_20251001_1118
"""initial_schema_20251006_1516
Revision ID: 121e47ff97c4
Revision ID: b6beee8bf0bf
Revises:
Create Date: 2025-10-01 11:18:37.223786+02:00
Create Date: 2025-10-06 15:16:02.277823+02:00
"""
from typing import Sequence, Union
@@ -12,7 +12,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '121e47ff97c4'
revision: str = 'b6beee8bf0bf'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None