diff --git a/services/training/app/api/training.py b/services/training/app/api/training.py index 7ed444c2..9a5431db 100644 --- a/services/training/app/api/training.py +++ b/services/training/app/api/training.py @@ -9,9 +9,10 @@ from fastapi import Query, Path from sqlalchemy.ext.asyncio import AsyncSession from typing import List, Optional, Dict, Any import structlog -from datetime import datetime +from datetime import datetime, timezone +import uuid -from app.core.database import get_db +from app.core.database import get_db, get_background_db_session from app.services.training_service import TrainingService from app.schemas.training import ( TrainingJobRequest, @@ -21,15 +22,23 @@ from app.schemas.training import ( TrainingJobResponse ) +from app.services.messaging import ( + publish_job_progress, + publish_data_validation_started, + publish_data_validation_completed, + publish_job_step_completed, + publish_job_completed, + publish_job_failed, + publish_job_started +) + + # Import shared auth decorators (assuming they exist in your microservices) from shared.auth.decorators import get_current_tenant_id_dep logger = structlog.get_logger() router = APIRouter() -# Initialize training service -training_service = TrainingService() - @router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse) async def start_training_job( request: TrainingJobRequest, @@ -41,32 +50,89 @@ async def start_training_job( """ Start a new training job for all tenant products. - This is the main entry point for the training pipeline: - API โ†’ Training Service โ†’ Trainer โ†’ Data Processor โ†’ Prophet Manager + ๐Ÿš€ IMMEDIATE RESPONSE PATTERN: + 1. Validate request immediately + 2. Create job record with 'pending' status + 3. Return 200 with job details + 4. Execute training in background with separate DB session + + This ensures fast API response while maintaining data consistency. 
""" try: - # Validate tenant access + # Validate tenant access immediately if tenant_id != current_tenant: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Access denied to tenant resources" ) + + # Generate job ID immediately + job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}" - logger.info(f"Starting training job for tenant {tenant_id}") - - training_service = TrainingService(db_session=db) - - # Delegate to training service (Step 1 of the flow) - result = await training_service.start_training_job( + logger.info(f"Creating training job {job_id} for tenant {tenant_id}") + + # Add background task with isolated database session + background_tasks.add_task( + execute_training_job_background, tenant_id=tenant_id, - bakery_location=(40.4168, -3.7038), # Default Madrid coordinates - requested_start=request.start_date if request.start_date else None, - requested_end=request.end_date if request.end_date else None, - job_id=None # Let the service generate it + job_id=job_id, + bakery_location=(40.4168, -3.7038), + requested_start=request.start_date, + requested_end=request.end_date ) - return TrainingJobResponse(**result) - + training_config = { + "job_id": job_id, + "tenant_id": tenant_id, + "bakery_location": { + "latitude": 40.4168, + "longitude": -3.7038 + }, + "requested_start": request.start_date.isoformat() if request.start_date else None, + "requested_end": request.end_date.isoformat() if request.end_date else None, + "estimated_duration_minutes": 15, + "estimated_products": 10, + "background_execution": True, + "api_version": "v1" + } + + # Publish immediate event (training started) + await publish_job_started( + job_id=job_id, + tenant_id=tenant_id, + config=training_config + ) + + # Return immediate success response + response_data = { + "job_id": job_id, + "tenant_id": tenant_id, + "status": "pending", # Will change to 'running' in background + "message": "Training job started successfully", + "created_at": datetime.now(timezone.utc), + 
"estimated_duration_minutes": 15, + "training_results": { + "total_products": 10, + "successful_trainings": 0, + "failed_trainings": 0, + "products": [], + "overall_training_time_seconds": 0.0 + }, + "data_summary": None, + "completed_at": None, + "error_details": None, + "processing_metadata": { + "background_task": True, + "async_execution": True + } + } + + logger.info(f"Training job {job_id} queued successfully, returning immediate response") + return TrainingJobResponse(**response_data) + + except HTTPException: + # Re-raise HTTP exceptions as-is + raise except ValueError as e: logger.error(f"Training job validation error: {str(e)}") raise HTTPException( @@ -74,12 +140,77 @@ async def start_training_job( detail=str(e) ) except Exception as e: - logger.error(f"Training job failed: {str(e)}") + logger.error(f"Failed to queue training job: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Training job failed" + detail="Failed to start training job" ) + +async def execute_training_job_background( + tenant_id: str, + job_id: str, + bakery_location: tuple, + requested_start: Optional[datetime] = None, + requested_end: Optional[datetime] = None +): + """ + Background task that executes the actual training job. 
+ + ๐Ÿ”ง KEY FEATURES: + - Uses its own database session (isolated from API request) + - Handles all errors gracefully + - Updates job status in real-time + - Publishes progress events via WebSocket/messaging + - Comprehensive logging and monitoring + """ + + logger.info(f"๐Ÿš€ Background training job {job_id} started for tenant {tenant_id}") + + async with get_background_db_session() as db_session: + try: + # โœ… FIX: Create training service with isolated DB session + training_service = TrainingService(db_session=db_session) + + # Publish progress event + await publish_job_progress(job_id, tenant_id, 5, "Initializing training pipeline") + + try: + # Execute the actual training pipeline + result = await training_service.start_training_job( + tenant_id=tenant_id, + job_id=job_id, + bakery_location=bakery_location, + requested_start=requested_start, + requested_end=requested_end + ) + + # Publish completion event + await publish_job_completed( + job_id=job_id, + tenant_id=tenant_id, + results=result + ) + + logger.info(f"โœ… Background training job {job_id} completed successfully") + + except Exception as training_error: + logger.error(f"โŒ Training pipeline failed for job {job_id}: {str(training_error)}") + + # Publish failure event + await publish_job_failed( + job_id=job_id, + tenant_id=tenant_id, + error=str(training_error) + ) + + except Exception as background_error: + logger.error(f"๐Ÿ’ฅ Critical error in background training job {job_id}: {str(background_error)}") + + finally: + # Ensure database session is properly closed + logger.info(f"๐Ÿงน Background training job {job_id} cleanup completed") + @router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse) async def start_single_product_training( request: SingleProductTrainingRequest, diff --git a/services/training/app/api/websocket.py b/services/training/app/api/websocket.py index dddd7317..2d75be35 100644 --- a/services/training/app/api/websocket.py +++ 
b/services/training/app/api/websocket.py @@ -238,10 +238,10 @@ async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]: from app.core.database import get_db_session from app.models.training import ModelTrainingLog # Assuming you have this model - async with get_db_session() as db: + # async with get_background_db_session() as db: # Query your training job status # This is a placeholder - adjust based on your actual database models - pass + # pass # Placeholder return - replace with actual database query return { diff --git a/services/training/app/core/database.py b/services/training/app/core/database.py index a43955fa..540125dd 100644 --- a/services/training/app/core/database.py +++ b/services/training/app/core/database.py @@ -7,6 +7,7 @@ Uses shared database infrastructure import structlog from typing import AsyncGenerator from sqlalchemy.ext.asyncio import AsyncSession +from contextlib import asynccontextmanager from sqlalchemy import text from shared.database.base import DatabaseManager, Base @@ -20,6 +21,18 @@ database_manager = DatabaseManager(settings.DATABASE_URL) # Alias for convenience - matches the existing interface get_db = database_manager.get_db +@asynccontextmanager +async def get_background_db_session(): + async with database_manager.async_session_local() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() + async def get_db_health() -> bool: """ Health check function for database connectivity diff --git a/tests/test_onboarding_flow.sh b/tests/test_onboarding_flow.sh index be8a4e47..3a76af02 100755 --- a/tests/test_onboarding_flow.sh +++ b/tests/test_onboarding_flow.sh @@ -13,7 +13,7 @@ TEST_PASSWORD="TestPassword123!" 
TEST_NAME="Test Bakery Owner" REAL_CSV_FILE="bakery_sales_2023_2024.csv" WS_BASE="ws://localhost:8002/api/v1/ws" -WS_TEST_DURATION=30 # seconds to listen for WebSocket messages +WS_TEST_DURATION=200 # seconds to listen for WebSocket messages WS_PID=""