Training job in the background
This commit is contained in:
@@ -9,9 +9,10 @@ from fastapi import Query, Path
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
import structlog
|
import structlog
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
|
import uuid
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db, get_background_db_session
|
||||||
from app.services.training_service import TrainingService
|
from app.services.training_service import TrainingService
|
||||||
from app.schemas.training import (
|
from app.schemas.training import (
|
||||||
TrainingJobRequest,
|
TrainingJobRequest,
|
||||||
@@ -21,15 +22,23 @@ from app.schemas.training import (
|
|||||||
TrainingJobResponse
|
TrainingJobResponse
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from app.services.messaging import (
|
||||||
|
publish_job_progress,
|
||||||
|
publish_data_validation_started,
|
||||||
|
publish_data_validation_completed,
|
||||||
|
publish_job_step_completed,
|
||||||
|
publish_job_completed,
|
||||||
|
publish_job_failed,
|
||||||
|
publish_job_started
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Import shared auth decorators (assuming they exist in your microservices)
|
# Import shared auth decorators (assuming they exist in your microservices)
|
||||||
from shared.auth.decorators import get_current_tenant_id_dep
|
from shared.auth.decorators import get_current_tenant_id_dep
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
# Initialize training service
|
|
||||||
training_service = TrainingService()
|
|
||||||
|
|
||||||
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
|
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
|
||||||
async def start_training_job(
|
async def start_training_job(
|
||||||
request: TrainingJobRequest,
|
request: TrainingJobRequest,
|
||||||
@@ -41,32 +50,89 @@ async def start_training_job(
|
|||||||
"""
|
"""
|
||||||
Start a new training job for all tenant products.
|
Start a new training job for all tenant products.
|
||||||
|
|
||||||
This is the main entry point for the training pipeline:
|
🚀 IMMEDIATE RESPONSE PATTERN:
|
||||||
API → Training Service → Trainer → Data Processor → Prophet Manager
|
1. Validate request immediately
|
||||||
|
2. Create job record with 'pending' status
|
||||||
|
3. Return 200 with job details
|
||||||
|
4. Execute training in background with separate DB session
|
||||||
|
|
||||||
|
This ensures fast API response while maintaining data consistency.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Validate tenant access
|
# Validate tenant access immediately
|
||||||
if tenant_id != current_tenant:
|
if tenant_id != current_tenant:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Access denied to tenant resources"
|
detail="Access denied to tenant resources"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Generate job ID immediately
|
||||||
|
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
logger.info(f"Starting training job for tenant {tenant_id}")
|
logger.info(f"Creating training job {job_id} for tenant {tenant_id}")
|
||||||
|
|
||||||
training_service = TrainingService(db_session=db)
|
# Add background task with isolated database session
|
||||||
|
background_tasks.add_task(
|
||||||
# Delegate to training service (Step 1 of the flow)
|
execute_training_job_background,
|
||||||
result = await training_service.start_training_job(
|
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
bakery_location=(40.4168, -3.7038), # Default Madrid coordinates
|
job_id=job_id,
|
||||||
requested_start=request.start_date if request.start_date else None,
|
bakery_location=(40.4168, -3.7038),
|
||||||
requested_end=request.end_date if request.end_date else None,
|
requested_start=request.start_date,
|
||||||
job_id=None # Let the service generate it
|
requested_end=request.end_date
|
||||||
)
|
)
|
||||||
|
|
||||||
return TrainingJobResponse(**result)
|
training_config = {
|
||||||
|
"job_id": job_id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"bakery_location": {
|
||||||
|
"latitude": 40.4168,
|
||||||
|
"longitude": -3.7038
|
||||||
|
},
|
||||||
|
"requested_start": request.start_date.isoformat() if request.start_date else None,
|
||||||
|
"requested_end": request.end_date.isoformat() if request.end_date else None,
|
||||||
|
"estimated_duration_minutes": 15,
|
||||||
|
"estimated_products": 10,
|
||||||
|
"background_execution": True,
|
||||||
|
"api_version": "v1"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Publish immediate event (training started)
|
||||||
|
await publish_job_started(
|
||||||
|
job_id=job_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=training_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return immediate success response
|
||||||
|
response_data = {
|
||||||
|
"job_id": job_id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"status": "pending", # Will change to 'running' in background
|
||||||
|
"message": "Training job started successfully",
|
||||||
|
"created_at": datetime.now(timezone.utc),
|
||||||
|
"estimated_duration_minutes": "15",
|
||||||
|
"training_results": {
|
||||||
|
"total_products": 10,
|
||||||
|
"successful_trainings": 0,
|
||||||
|
"failed_trainings": 0,
|
||||||
|
"products": [],
|
||||||
|
"overall_training_time_seconds": 0.0
|
||||||
|
},
|
||||||
|
"data_summary": None,
|
||||||
|
"completed_at": None,
|
||||||
|
"error_details": None,
|
||||||
|
"processing_metadata": {
|
||||||
|
"background_task": True,
|
||||||
|
"async_execution": True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Training job {job_id} queued successfully, returning immediate response")
|
||||||
|
return TrainingJobResponse(**response_data)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
# Re-raise HTTP exceptions as-is
|
||||||
|
raise
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Training job validation error: {str(e)}")
|
logger.error(f"Training job validation error: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -74,12 +140,77 @@ async def start_training_job(
|
|||||||
detail=str(e)
|
detail=str(e)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Training job failed: {str(e)}")
|
logger.error(f"Failed to queue training job: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Training job failed"
|
detail="Failed to start training job"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_training_job_background(
|
||||||
|
tenant_id: str,
|
||||||
|
job_id: str,
|
||||||
|
bakery_location: tuple,
|
||||||
|
requested_start: Optional[datetime] = None,
|
||||||
|
requested_end: Optional[datetime] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Background task that executes the actual training job.
|
||||||
|
|
||||||
|
🔧 KEY FEATURES:
|
||||||
|
- Uses its own database session (isolated from API request)
|
||||||
|
- Handles all errors gracefully
|
||||||
|
- Updates job status in real-time
|
||||||
|
- Publishes progress events via WebSocket/messaging
|
||||||
|
- Comprehensive logging and monitoring
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger.info(f"🚀 Background training job {job_id} started for tenant {tenant_id}")
|
||||||
|
|
||||||
|
async with get_background_db_session() as db_session:
|
||||||
|
try:
|
||||||
|
# ✅ FIX: Create training service with isolated DB session
|
||||||
|
training_service = TrainingService(db_session=db_session)
|
||||||
|
|
||||||
|
# Publish progress event
|
||||||
|
await publish_job_progress(job_id, tenant_id, 5, "Initializing training pipeline")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Execute the actual training pipeline
|
||||||
|
result = await training_service.start_training_job(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
job_id=job_id,
|
||||||
|
bakery_location=bakery_location,
|
||||||
|
requested_start=requested_start,
|
||||||
|
requested_end=requested_end
|
||||||
|
)
|
||||||
|
|
||||||
|
# Publish completion event
|
||||||
|
await publish_job_completed(
|
||||||
|
job_id=job_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
results=result
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"✅ Background training job {job_id} completed successfully")
|
||||||
|
|
||||||
|
except Exception as training_error:
|
||||||
|
logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}")
|
||||||
|
|
||||||
|
# Publish failure event
|
||||||
|
await publish_job_failed(
|
||||||
|
job_id=job_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
error=str(training_error)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as background_error:
|
||||||
|
logger.error(f"💥 Critical error in background training job {job_id}: {str(background_error)}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Ensure database session is properly closed
|
||||||
|
logger.info(f"🧹 Background training job {job_id} cleanup completed")
|
||||||
|
|
||||||
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
|
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
|
||||||
async def start_single_product_training(
|
async def start_single_product_training(
|
||||||
request: SingleProductTrainingRequest,
|
request: SingleProductTrainingRequest,
|
||||||
|
|||||||
@@ -238,10 +238,10 @@ async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
|
|||||||
from app.core.database import get_db_session
|
from app.core.database import get_db_session
|
||||||
from app.models.training import ModelTrainingLog # Assuming you have this model
|
from app.models.training import ModelTrainingLog # Assuming you have this model
|
||||||
|
|
||||||
async with get_db_session() as db:
|
# async with get_background_db_session() as db:
|
||||||
# Query your training job status
|
# Query your training job status
|
||||||
# This is a placeholder - adjust based on your actual database models
|
# This is a placeholder - adjust based on your actual database models
|
||||||
pass
|
# pass
|
||||||
|
|
||||||
# Placeholder return - replace with actual database query
|
# Placeholder return - replace with actual database query
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ Uses shared database infrastructure
|
|||||||
import structlog
|
import structlog
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
from shared.database.base import DatabaseManager, Base
|
from shared.database.base import DatabaseManager, Base
|
||||||
@@ -20,6 +21,18 @@ database_manager = DatabaseManager(settings.DATABASE_URL)
|
|||||||
# Alias for convenience - matches the existing interface
|
# Alias for convenience - matches the existing interface
|
||||||
get_db = database_manager.get_db
|
get_db = database_manager.get_db
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def get_background_db_session():
|
||||||
|
async with database_manager.async_session_local() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
async def get_db_health() -> bool:
|
async def get_db_health() -> bool:
|
||||||
"""
|
"""
|
||||||
Health check function for database connectivity
|
Health check function for database connectivity
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ TEST_PASSWORD="TestPassword123!"
|
|||||||
TEST_NAME="Test Bakery Owner"
|
TEST_NAME="Test Bakery Owner"
|
||||||
REAL_CSV_FILE="bakery_sales_2023_2024.csv"
|
REAL_CSV_FILE="bakery_sales_2023_2024.csv"
|
||||||
WS_BASE="ws://localhost:8002/api/v1/ws"
|
WS_BASE="ws://localhost:8002/api/v1/ws"
|
||||||
WS_TEST_DURATION=30 # seconds to listen for WebSocket messages
|
WS_TEST_DURATION=200 # seconds to listen for WebSocket messages
|
||||||
WS_PID=""
|
WS_PID=""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user