Training job in the background

Urtzi Alfaro
2025-08-01 16:26:36 +02:00
parent e67ce2a594
commit 2f6f13bfef
4 changed files with 169 additions and 25 deletions

View File

@@ -9,9 +9,10 @@ from fastapi import Query, Path
 from sqlalchemy.ext.asyncio import AsyncSession
 from typing import List, Optional, Dict, Any
 import structlog
-from datetime import datetime
+from datetime import datetime, timezone
+import uuid
-from app.core.database import get_db
+from app.core.database import get_db, get_background_db_session
 from app.services.training_service import TrainingService
 from app.schemas.training import (
     TrainingJobRequest,
@@ -21,15 +22,23 @@ from app.schemas.training import (
     TrainingJobResponse
 )
+from app.services.messaging import (
+    publish_job_progress,
+    publish_data_validation_started,
+    publish_data_validation_completed,
+    publish_job_step_completed,
+    publish_job_completed,
+    publish_job_failed,
+    publish_job_started
+)
 # Import shared auth decorators (assuming they exist in your microservices)
 from shared.auth.decorators import get_current_tenant_id_dep
 logger = structlog.get_logger()
 router = APIRouter()
+# Initialize training service
+training_service = TrainingService()
 @router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
 async def start_training_job(
     request: TrainingJobRequest,
@@ -41,32 +50,89 @@ async def start_training_job(
     """
     Start a new training job for all tenant products.
-    This is the main entry point for the training pipeline:
-    API → Training Service → Trainer → Data Processor → Prophet Manager
+    🚀 IMMEDIATE RESPONSE PATTERN:
+    1. Validate request immediately
+    2. Create job record with 'pending' status
+    3. Return 200 with job details
+    4. Execute training in background with separate DB session
+    This ensures a fast API response while maintaining data consistency.
     """
     try:
-        # Validate tenant access
+        # Validate tenant access immediately
         if tenant_id != current_tenant:
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
                 detail="Access denied to tenant resources"
             )
-        logger.info(f"Starting training job for tenant {tenant_id}")
-        training_service = TrainingService(db_session=db)
-        # Delegate to training service (Step 1 of the flow)
-        result = await training_service.start_training_job(
+        # Generate job ID immediately
+        job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
+        logger.info(f"Creating training job {job_id} for tenant {tenant_id}")
+        # Add background task with isolated database session
+        background_tasks.add_task(
+            execute_training_job_background,
             tenant_id=tenant_id,
-            bakery_location=(40.4168, -3.7038),  # Default Madrid coordinates
-            requested_start=request.start_date if request.start_date else None,
-            requested_end=request.end_date if request.end_date else None,
-            job_id=None  # Let the service generate it
+            job_id=job_id,
+            bakery_location=(40.4168, -3.7038),
+            requested_start=request.start_date,
+            requested_end=request.end_date
         )
-        return TrainingJobResponse(**result)
+        training_config = {
+            "job_id": job_id,
+            "tenant_id": tenant_id,
+            "bakery_location": {
+                "latitude": 40.4168,
+                "longitude": -3.7038
+            },
+            "requested_start": request.start_date.isoformat() if request.start_date else None,
+            "requested_end": request.end_date.isoformat() if request.end_date else None,
+            "estimated_duration_minutes": 15,
+            "estimated_products": 10,
+            "background_execution": True,
+            "api_version": "v1"
+        }
+        # Publish immediate event (training started)
+        await publish_job_started(
+            job_id=job_id,
+            tenant_id=tenant_id,
+            config=training_config
+        )
+        # Return immediate success response
+        response_data = {
+            "job_id": job_id,
+            "tenant_id": tenant_id,
+            "status": "pending",  # Will change to 'running' in background
+            "message": "Training job started successfully",
+            "created_at": datetime.now(timezone.utc),
+            "estimated_duration_minutes": 15,
+            "training_results": {
+                "total_products": 10,
+                "successful_trainings": 0,
+                "failed_trainings": 0,
+                "products": [],
+                "overall_training_time_seconds": 0.0
+            },
+            "data_summary": None,
+            "completed_at": None,
+            "error_details": None,
+            "processing_metadata": {
+                "background_task": True,
+                "async_execution": True
+            }
+        }
+        logger.info(f"Training job {job_id} queued successfully, returning immediate response")
+        return TrainingJobResponse(**response_data)
+    except HTTPException:
+        # Re-raise HTTP exceptions as-is
+        raise
     except ValueError as e:
         logger.error(f"Training job validation error: {str(e)}")
         raise HTTPException(
@@ -74,12 +140,77 @@ async def start_training_job(
             detail=str(e)
         )
     except Exception as e:
-        logger.error(f"Training job failed: {str(e)}")
+        logger.error(f"Failed to queue training job: {str(e)}")
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail="Training job failed"
+            detail="Failed to start training job"
         )
+
+async def execute_training_job_background(
+    tenant_id: str,
+    job_id: str,
+    bakery_location: tuple,
+    requested_start: Optional[datetime] = None,
+    requested_end: Optional[datetime] = None
+):
+    """
+    Background task that executes the actual training job.
+    🔧 KEY FEATURES:
+    - Uses its own database session (isolated from the API request)
+    - Handles all errors gracefully
+    - Updates job status in real time
+    - Publishes progress events via WebSocket/messaging
+    - Comprehensive logging and monitoring
+    """
+    logger.info(f"🚀 Background training job {job_id} started for tenant {tenant_id}")
+    async with get_background_db_session() as db_session:
+        try:
+            # ✅ FIX: Create training service with isolated DB session
+            training_service = TrainingService(db_session=db_session)
+            # Publish progress event
+            await publish_job_progress(job_id, tenant_id, 5, "Initializing training pipeline")
+            try:
+                # Execute the actual training pipeline
+                result = await training_service.start_training_job(
+                    tenant_id=tenant_id,
+                    job_id=job_id,
+                    bakery_location=bakery_location,
+                    requested_start=requested_start,
+                    requested_end=requested_end
+                )
+                # Publish completion event
+                await publish_job_completed(
+                    job_id=job_id,
+                    tenant_id=tenant_id,
+                    results=result
+                )
+                logger.info(f"✅ Background training job {job_id} completed successfully")
+            except Exception as training_error:
+                logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}")
+                # Publish failure event
+                await publish_job_failed(
+                    job_id=job_id,
+                    tenant_id=tenant_id,
+                    error=str(training_error)
+                )
+        except Exception as background_error:
+            logger.error(f"💥 Critical error in background training job {job_id}: {str(background_error)}")
+        finally:
+            # Session commit/rollback/close is handled by get_background_db_session
+            logger.info(f"🧹 Background training job {job_id} cleanup completed")
+
 @router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
 async def start_single_product_training(
     request: SingleProductTrainingRequest,
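
For readers skimming the diff, here is the immediate-response pattern from this file reduced to a minimal, self-contained sketch. The names (run_job, JOBS) are illustrative stand-ins, not part of this codebase; FastAPI's BackgroundTasks runs the queued callable in-process after the response has been sent, which is why the real code hands it a session factory instead of the request-scoped session.

# Minimal sketch of the immediate-response pattern (illustrative names only).
import asyncio
import uuid

from fastapi import BackgroundTasks, FastAPI

app = FastAPI()
JOBS: dict = {}  # job_id -> status; stand-in for the job record in the DB

async def run_job(job_id: str) -> None:
    # Executes after the HTTP response is sent; must not reuse the
    # request-scoped DB session, hence get_background_db_session above.
    JOBS[job_id] = "running"
    await asyncio.sleep(5)  # stand-in for the actual training pipeline
    JOBS[job_id] = "completed"

@app.post("/jobs")
async def start_job(background_tasks: BackgroundTasks) -> dict:
    job_id = f"job_{uuid.uuid4().hex[:8]}"
    JOBS[job_id] = "pending"
    background_tasks.add_task(run_job, job_id)  # queued, runs after response
    return {"job_id": job_id, "status": "pending"}  # immediate response

Clients then poll the job status (or subscribe to the WebSocket events published above) to observe the pending → running → completed transition.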

View File

@@ -238,10 +238,10 @@ async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
     from app.core.database import get_db_session
     from app.models.training import ModelTrainingLog  # Assuming you have this model
-    async with get_db_session() as db:
+    # async with get_background_db_session() as db:
         # Query your training job status
         # This is a placeholder - adjust based on your actual database models
-        pass
+        # pass
     # Placeholder return - replace with actual database query
     return {
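
The status lookup above remains a stub; this commit only comments it out. If it is eventually filled in, it could look something like the sketch below. The ModelTrainingLog columns (job_id, tenant_id) are assumptions for illustration and are not confirmed by this diff.

# Hypothetical sketch only: assumes ModelTrainingLog exposes job_id and
# tenant_id columns, which this diff does not confirm.
from sqlalchemy import select

async def fetch_job_status(job_id: str, tenant_id: str):
    from app.core.database import get_background_db_session
    from app.models.training import ModelTrainingLog

    async with get_background_db_session() as db:
        result = await db.execute(
            select(ModelTrainingLog).where(
                ModelTrainingLog.job_id == job_id,
                ModelTrainingLog.tenant_id == tenant_id,
            )
        )
        return result.scalar_one_or_none()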

View File

@@ -7,6 +7,7 @@ Uses shared database infrastructure
 import structlog
 from typing import AsyncGenerator
 from sqlalchemy.ext.asyncio import AsyncSession
+from contextlib import asynccontextmanager
 from sqlalchemy import text
 from shared.database.base import DatabaseManager, Base
@@ -20,6 +21,18 @@ database_manager = DatabaseManager(settings.DATABASE_URL)
 # Alias for convenience - matches the existing interface
 get_db = database_manager.get_db
+@asynccontextmanager
+async def get_background_db_session():
+    """Yield an isolated session for background tasks: commit on success, roll back on error."""
+    async with database_manager.async_session_local() as session:
+        try:
+            yield session
+            await session.commit()
+        except Exception:
+            await session.rollback()
+            raise
+        finally:
+            await session.close()
+
 async def get_db_health() -> bool:
     """
     Health check function for database connectivity
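
A usage sketch for the new context manager: callers never commit explicitly, because get_background_db_session commits on clean exit and rolls back on any exception. The table and column names below are hypothetical, chosen only to make the example concrete.

# Hypothetical usage; table/column names are illustrative only.
from sqlalchemy import text

from app.core.database import get_background_db_session

async def record_heartbeat(job_id: str) -> None:
    async with get_background_db_session() as session:
        await session.execute(
            text("UPDATE model_training_logs SET updated_at = now() WHERE job_id = :jid"),
            {"jid": job_id},
        )
        # No explicit commit: the context manager commits on clean exit.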

View File

@@ -13,7 +13,7 @@ TEST_PASSWORD="TestPassword123!"
 TEST_NAME="Test Bakery Owner"
 REAL_CSV_FILE="bakery_sales_2023_2024.csv"
 WS_BASE="ws://localhost:8002/api/v1/ws"
-WS_TEST_DURATION=30   # seconds to listen for WebSocket messages
+WS_TEST_DURATION=200  # seconds to listen for WebSocket messages
 WS_PID=""
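
The longer listening window follows from the new background execution: completion events can now arrive well after the immediate API response, so the test must stay subscribed longer. A sketch of what such a listener could look like in Python (the exact message schema is not shown in this diff):

# Hypothetical listener for WS_BASE above; requires `pip install websockets`.
import asyncio
import json

import websockets

async def listen(ws_url: str, duration: float = 200.0) -> None:
    async with websockets.connect(ws_url) as ws:
        deadline = asyncio.get_running_loop().time() + duration
        while (remaining := deadline - asyncio.get_running_loop().time()) > 0:
            try:
                message = await asyncio.wait_for(ws.recv(), timeout=remaining)
            except asyncio.TimeoutError:
                break  # listening window elapsed
            print(json.loads(message))  # progress / completed / failed events

asyncio.run(listen("ws://localhost:8002/api/v1/ws"))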