Training job in the background
This commit is contained in:
@@ -9,9 +9,10 @@ from fastapi import Query, Path
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Optional, Dict, Any
|
||||
import structlog
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.database import get_db, get_background_db_session
|
||||
from app.services.training_service import TrainingService
|
||||
from app.schemas.training import (
|
||||
TrainingJobRequest,
|
||||
@@ -21,15 +22,23 @@ from app.schemas.training import (
|
||||
TrainingJobResponse
|
||||
)
|
||||
|
||||
from app.services.messaging import (
|
||||
publish_job_progress,
|
||||
publish_data_validation_started,
|
||||
publish_data_validation_completed,
|
||||
publish_job_step_completed,
|
||||
publish_job_completed,
|
||||
publish_job_failed,
|
||||
publish_job_started
|
||||
)
|
||||
|
||||
|
||||
# Import shared auth decorators (assuming they exist in your microservices)
|
||||
from shared.auth.decorators import get_current_tenant_id_dep
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize training service
|
||||
training_service = TrainingService()
|
||||
|
||||
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
|
||||
async def start_training_job(
|
||||
request: TrainingJobRequest,
|
||||
@@ -41,32 +50,89 @@ async def start_training_job(
|
||||
"""
|
||||
Start a new training job for all tenant products.
|
||||
|
||||
This is the main entry point for the training pipeline:
|
||||
API → Training Service → Trainer → Data Processor → Prophet Manager
|
||||
🚀 IMMEDIATE RESPONSE PATTERN:
|
||||
1. Validate request immediately
|
||||
2. Create job record with 'pending' status
|
||||
3. Return 200 with job details
|
||||
4. Execute training in background with separate DB session
|
||||
|
||||
This ensures fast API response while maintaining data consistency.
|
||||
"""
|
||||
try:
|
||||
# Validate tenant access
|
||||
# Validate tenant access immediately
|
||||
if tenant_id != current_tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Generate job ID immediately
|
||||
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
logger.info(f"Starting training job for tenant {tenant_id}")
|
||||
|
||||
training_service = TrainingService(db_session=db)
|
||||
|
||||
# Delegate to training service (Step 1 of the flow)
|
||||
result = await training_service.start_training_job(
|
||||
logger.info(f"Creating training job {job_id} for tenant {tenant_id}")
|
||||
|
||||
# Add background task with isolated database session
|
||||
background_tasks.add_task(
|
||||
execute_training_job_background,
|
||||
tenant_id=tenant_id,
|
||||
bakery_location=(40.4168, -3.7038), # Default Madrid coordinates
|
||||
requested_start=request.start_date if request.start_date else None,
|
||||
requested_end=request.end_date if request.end_date else None,
|
||||
job_id=None # Let the service generate it
|
||||
job_id=job_id,
|
||||
bakery_location=(40.4168, -3.7038),
|
||||
requested_start=request.start_date,
|
||||
requested_end=request.end_date
|
||||
)
|
||||
|
||||
return TrainingJobResponse(**result)
|
||||
|
||||
training_config = {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"bakery_location": {
|
||||
"latitude": 40.4168,
|
||||
"longitude": -3.7038
|
||||
},
|
||||
"requested_start": request.start_date.isoformat() if request.start_date else None,
|
||||
"requested_end": request.end_date.isoformat() if request.end_date else None,
|
||||
"estimated_duration_minutes": 15,
|
||||
"estimated_products": 10,
|
||||
"background_execution": True,
|
||||
"api_version": "v1"
|
||||
}
|
||||
|
||||
# Publish immediate event (training started)
|
||||
await publish_job_started(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
config=training_config
|
||||
)
|
||||
|
||||
# Return immediate success response
|
||||
response_data = {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"status": "pending", # Will change to 'running' in background
|
||||
"message": "Training job started successfully",
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"estimated_duration_minutes": "15",
|
||||
"training_results": {
|
||||
"total_products": 10,
|
||||
"successful_trainings": 0,
|
||||
"failed_trainings": 0,
|
||||
"products": [],
|
||||
"overall_training_time_seconds": 0.0
|
||||
},
|
||||
"data_summary": None,
|
||||
"completed_at": None,
|
||||
"error_details": None,
|
||||
"processing_metadata": {
|
||||
"background_task": True,
|
||||
"async_execution": True
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(f"Training job {job_id} queued successfully, returning immediate response")
|
||||
return TrainingJobResponse(**response_data)
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions as-is
|
||||
raise
|
||||
except ValueError as e:
|
||||
logger.error(f"Training job validation error: {str(e)}")
|
||||
raise HTTPException(
|
||||
@@ -74,12 +140,77 @@ async def start_training_job(
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Training job failed: {str(e)}")
|
||||
logger.error(f"Failed to queue training job: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Training job failed"
|
||||
detail="Failed to start training job"
|
||||
)
|
||||
|
||||
|
||||
async def execute_training_job_background(
|
||||
tenant_id: str,
|
||||
job_id: str,
|
||||
bakery_location: tuple,
|
||||
requested_start: Optional[datetime] = None,
|
||||
requested_end: Optional[datetime] = None
|
||||
):
|
||||
"""
|
||||
Background task that executes the actual training job.
|
||||
|
||||
🔧 KEY FEATURES:
|
||||
- Uses its own database session (isolated from API request)
|
||||
- Handles all errors gracefully
|
||||
- Updates job status in real-time
|
||||
- Publishes progress events via WebSocket/messaging
|
||||
- Comprehensive logging and monitoring
|
||||
"""
|
||||
|
||||
logger.info(f"🚀 Background training job {job_id} started for tenant {tenant_id}")
|
||||
|
||||
async with get_background_db_session() as db_session:
|
||||
try:
|
||||
# ✅ FIX: Create training service with isolated DB session
|
||||
training_service = TrainingService(db_session=db_session)
|
||||
|
||||
# Publish progress event
|
||||
await publish_job_progress(job_id, tenant_id, 5, "Initializing training pipeline")
|
||||
|
||||
try:
|
||||
# Execute the actual training pipeline
|
||||
result = await training_service.start_training_job(
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id,
|
||||
bakery_location=bakery_location,
|
||||
requested_start=requested_start,
|
||||
requested_end=requested_end
|
||||
)
|
||||
|
||||
# Publish completion event
|
||||
await publish_job_completed(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
results=result
|
||||
)
|
||||
|
||||
logger.info(f"✅ Background training job {job_id} completed successfully")
|
||||
|
||||
except Exception as training_error:
|
||||
logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}")
|
||||
|
||||
# Publish failure event
|
||||
await publish_job_failed(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
error=str(training_error)
|
||||
)
|
||||
|
||||
except Exception as background_error:
|
||||
logger.error(f"💥 Critical error in background training job {job_id}: {str(background_error)}")
|
||||
|
||||
finally:
|
||||
# Ensure database session is properly closed
|
||||
logger.info(f"🧹 Background training job {job_id} cleanup completed")
|
||||
|
||||
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
|
||||
async def start_single_product_training(
|
||||
request: SingleProductTrainingRequest,
|
||||
|
||||
@@ -238,10 +238,10 @@ async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
|
||||
from app.core.database import get_db_session
|
||||
from app.models.training import ModelTrainingLog # Assuming you have this model
|
||||
|
||||
async with get_db_session() as db:
|
||||
# async with get_background_db_session() as db:
|
||||
# Query your training job status
|
||||
# This is a placeholder - adjust based on your actual database models
|
||||
pass
|
||||
# pass
|
||||
|
||||
# Placeholder return - replace with actual database query
|
||||
return {
|
||||
|
||||
@@ -7,6 +7,7 @@ Uses shared database infrastructure
|
||||
import structlog
|
||||
from typing import AsyncGenerator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy import text
|
||||
|
||||
from shared.database.base import DatabaseManager, Base
|
||||
@@ -20,6 +21,18 @@ database_manager = DatabaseManager(settings.DATABASE_URL)
|
||||
# Alias for convenience - matches the existing interface
|
||||
get_db = database_manager.get_db
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_background_db_session():
|
||||
async with database_manager.async_session_local() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def get_db_health() -> bool:
|
||||
"""
|
||||
Health check function for database connectivity
|
||||
|
||||
@@ -13,7 +13,7 @@ TEST_PASSWORD="TestPassword123!"
|
||||
TEST_NAME="Test Bakery Owner"
|
||||
REAL_CSV_FILE="bakery_sales_2023_2024.csv"
|
||||
WS_BASE="ws://localhost:8002/api/v1/ws"
|
||||
WS_TEST_DURATION=30 # seconds to listen for WebSocket messages
|
||||
WS_TEST_DURATION=200 # seconds to listen for WebSocket messages
|
||||
WS_PID=""
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user