Run the training job in the background

This commit is contained in:
Urtzi Alfaro
2025-08-01 16:26:36 +02:00
parent e67ce2a594
commit 2f6f13bfef
4 changed files with 169 additions and 25 deletions

View File

@@ -9,9 +9,10 @@ from fastapi import Query, Path
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional, Dict, Any
import structlog
from datetime import datetime
from datetime import datetime, timezone
import uuid
from app.core.database import get_db
from app.core.database import get_db, get_background_db_session
from app.services.training_service import TrainingService
from app.schemas.training import (
TrainingJobRequest,
@@ -21,15 +22,23 @@ from app.schemas.training import (
TrainingJobResponse
)
from app.services.messaging import (
publish_job_progress,
publish_data_validation_started,
publish_data_validation_completed,
publish_job_step_completed,
publish_job_completed,
publish_job_failed,
publish_job_started
)
# Import shared auth decorators (assuming they exist in your microservices)
from shared.auth.decorators import get_current_tenant_id_dep
logger = structlog.get_logger()
router = APIRouter()
# Initialize training service
training_service = TrainingService()
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
async def start_training_job(
request: TrainingJobRequest,
@@ -41,32 +50,89 @@ async def start_training_job(
"""
Start a new training job for all tenant products.
This is the main entry point for the training pipeline:
API → Training Service → Trainer → Data Processor → Prophet Manager
🚀 IMMEDIATE RESPONSE PATTERN:
1. Validate request immediately
2. Create job record with 'pending' status
3. Return 200 with job details
4. Execute training in background with separate DB session
This ensures fast API response while maintaining data consistency.
"""
try:
# Validate tenant access
# Validate tenant access immediately
if tenant_id != current_tenant:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
logger.info(f"Starting training job for tenant {tenant_id}")
# Generate job ID immediately
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
training_service = TrainingService(db_session=db)
logger.info(f"Creating training job {job_id} for tenant {tenant_id}")
# Delegate to training service (Step 1 of the flow)
result = await training_service.start_training_job(
# Add background task with isolated database session
background_tasks.add_task(
execute_training_job_background,
tenant_id=tenant_id,
bakery_location=(40.4168, -3.7038), # Default Madrid coordinates
requested_start=request.start_date if request.start_date else None,
requested_end=request.end_date if request.end_date else None,
job_id=None # Let the service generate it
job_id=job_id,
bakery_location=(40.4168, -3.7038),
requested_start=request.start_date,
requested_end=request.end_date
)
return TrainingJobResponse(**result)
training_config = {
"job_id": job_id,
"tenant_id": tenant_id,
"bakery_location": {
"latitude": 40.4168,
"longitude": -3.7038
},
"requested_start": request.start_date.isoformat() if request.start_date else None,
"requested_end": request.end_date.isoformat() if request.end_date else None,
"estimated_duration_minutes": 15,
"estimated_products": 10,
"background_execution": True,
"api_version": "v1"
}
# Publish immediate event (training started)
await publish_job_started(
job_id=job_id,
tenant_id=tenant_id,
config=training_config
)
# Return immediate success response
response_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "pending", # Will change to 'running' in background
"message": "Training job started successfully",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": "15",
"training_results": {
"total_products": 10,
"successful_trainings": 0,
"failed_trainings": 0,
"products": [],
"overall_training_time_seconds": 0.0
},
"data_summary": None,
"completed_at": None,
"error_details": None,
"processing_metadata": {
"background_task": True,
"async_execution": True
}
}
logger.info(f"Training job {job_id} queued successfully, returning immediate response")
return TrainingJobResponse(**response_data)
except HTTPException:
# Re-raise HTTP exceptions as-is
raise
except ValueError as e:
logger.error(f"Training job validation error: {str(e)}")
raise HTTPException(
@@ -74,12 +140,77 @@ async def start_training_job(
detail=str(e)
)
except Exception as e:
logger.error(f"Training job failed: {str(e)}")
logger.error(f"Failed to queue training job: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Training job failed"
detail="Failed to start training job"
)
async def execute_training_job_background(
    tenant_id: str,
    job_id: str,
    bakery_location: tuple,
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None
):
    """
    Background task that executes the actual training job.

    Runs outside the API request lifecycle with its own database session
    (from get_background_db_session), so the HTTP response is never blocked
    by the training pipeline.

    Args:
        tenant_id: Tenant whose products are trained.
        job_id: Pre-generated job identifier (created by the API endpoint).
        bakery_location: (latitude, longitude) tuple passed to the pipeline.
        requested_start: Optional start of the training data window.
        requested_end: Optional end of the training data window.

    Never raises: every error is logged and a failure event is published so
    clients are not left waiting on a job stuck in 'pending'.
    """
    logger.info(f"🚀 Background training job {job_id} started for tenant {tenant_id}")

    async with get_background_db_session() as db_session:
        try:
            # Bind the training service to the isolated background session,
            # not the API-request session (which is closed by now).
            training_service = TrainingService(db_session=db_session)

            # Early progress event so subscribers see the job leave 'pending'.
            await publish_job_progress(job_id, tenant_id, 5, "Initializing training pipeline")

            try:
                # Execute the actual training pipeline
                result = await training_service.start_training_job(
                    tenant_id=tenant_id,
                    job_id=job_id,
                    bakery_location=bakery_location,
                    requested_start=requested_start,
                    requested_end=requested_end
                )

                # Publish completion event
                await publish_job_completed(
                    job_id=job_id,
                    tenant_id=tenant_id,
                    results=result
                )
                logger.info(f"✅ Background training job {job_id} completed successfully")

            except Exception as training_error:
                logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}")
                # Publish failure event
                await publish_job_failed(
                    job_id=job_id,
                    tenant_id=tenant_id,
                    error=str(training_error)
                )

        except Exception as background_error:
            logger.error(f"💥 Critical error in background training job {job_id}: {str(background_error)}")
            # FIX: previously a setup failure (service construction, first
            # progress publish, or a failed failure-publish above) emitted no
            # event at all, leaving the job in 'pending' forever on the client
            # side. Best-effort notification; swallow secondary errors so the
            # background task itself never raises.
            try:
                await publish_job_failed(
                    job_id=job_id,
                    tenant_id=tenant_id,
                    error=str(background_error)
                )
            except Exception as publish_error:
                logger.error(f"Failed to publish failure event for job {job_id}: {str(publish_error)}")
        finally:
            # Session commit/rollback/close is handled by the
            # get_background_db_session context manager; just log.
            logger.info(f"🧹 Background training job {job_id} cleanup completed")
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
async def start_single_product_training(
request: SingleProductTrainingRequest,

View File

@@ -238,10 +238,10 @@ async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
from app.core.database import get_db_session
from app.models.training import ModelTrainingLog # Assuming you have this model
async with get_db_session() as db:
# async with get_background_db_session() as db:
# Query your training job status
# This is a placeholder - adjust based on your actual database models
pass
# pass
# Placeholder return - replace with actual database query
return {

View File

@@ -7,6 +7,7 @@ Uses shared database infrastructure
import structlog
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from contextlib import asynccontextmanager
from sqlalchemy import text
from shared.database.base import DatabaseManager, Base
@@ -20,6 +21,18 @@ database_manager = DatabaseManager(settings.DATABASE_URL)
# Alias for convenience - matches the existing interface
get_db = database_manager.get_db
@asynccontextmanager
async def get_background_db_session():
    """
    Async context manager yielding an isolated DB session for background tasks.

    Unlike the request-scoped ``get_db`` dependency, the session's lifetime is
    bound to the background task itself: commits when the task body completes
    without raising, rolls back (and re-raises) on any exception, and always
    closes the session.

    Yields:
        AsyncSession: a fresh session from the shared DatabaseManager.
    """
    async with database_manager.async_session_local() as session:
        try:
            yield session
            # Commit only if the task body completed without raising.
            await session.commit()
        except Exception:
            # No local handling needed beyond rollback; propagate to caller.
            await session.rollback()
            raise
        finally:
            # Explicit close; also covered by the enclosing ``async with``,
            # kept as a belt-and-braces guarantee.
            await session.close()
async def get_db_health() -> bool:
"""
Health check function for database connectivity

View File

@@ -13,7 +13,7 @@ TEST_PASSWORD="TestPassword123!"
TEST_NAME="Test Bakery Owner"
REAL_CSV_FILE="bakery_sales_2023_2024.csv"
WS_BASE="ws://localhost:8002/api/v1/ws"
WS_TEST_DURATION=30 # seconds to listen for WebSocket messages
WS_TEST_DURATION=200 # seconds to listen for WebSocket messages
WS_PID=""