diff --git a/services/training/app/api/training.py b/services/training/app/api/training.py index c4e3040d..0af2954b 100644 --- a/services/training/app/api/training.py +++ b/services/training/app/api/training.py @@ -41,13 +41,17 @@ from shared.auth.decorators import ( logger = structlog.get_logger() router = APIRouter(prefix="/training", tags=["training"]) +def get_training_service() -> TrainingService: + """Factory function for TrainingService dependency""" + return TrainingService() + @router.post("/jobs", response_model=TrainingJobResponse) async def start_training_job( request: TrainingJobRequest, background_tasks: BackgroundTasks, tenant_id: str = Depends(get_current_tenant_id_dep), current_user: Dict[str, Any] = Depends(get_current_user_dep), - training_service: TrainingService = Depends(), + training_service: TrainingService = Depends(get_training_service), db: AsyncSession = Depends(get_db_session) # Ensure db is available ): """Start a new training job for all products""" @@ -106,7 +110,7 @@ async def get_training_jobs( offset: int = Query(0, ge=0), tenant_id: str = Depends(get_current_tenant_id_dep), current_user: Dict[str, Any] = Depends(get_current_user_dep), - training_service: TrainingService = Depends() + training_service: TrainingService = Depends(get_training_service) ): """Get training jobs for tenant""" try: @@ -140,7 +144,7 @@ async def get_training_job( job_id: str, tenant_id: str = Depends(get_current_tenant_id_dep), current_user: Dict[str, Any] = Depends(get_current_user_dep), - training_service: TrainingService = Depends() + training_service: TrainingService = Depends(get_training_service) ): """Get specific training job details""" try: @@ -173,7 +177,7 @@ async def get_training_progress( job_id: str, tenant_id: str = Depends(get_current_tenant_id_dep), current_user: Dict[str, Any] = Depends(get_current_user_dep), - training_service: TrainingService = Depends() + training_service: TrainingService = Depends(get_training_service) ): """Get real-time training progress""" try: @@ -203,7 +207,7 @@ async def cancel_training_job( job_id: str, tenant_id: str = Depends(get_current_tenant_id_dep), current_user: Dict[str, Any] = Depends(get_current_user_dep), - training_service: TrainingService = Depends() + training_service: TrainingService = Depends(get_training_service) ): """Cancel a running training job""" try: @@ -250,7 +254,8 @@ async def train_single_product( background_tasks: BackgroundTasks, tenant_id: str = Depends(get_current_tenant_id_dep), current_user: Dict[str, Any] = Depends(get_current_user_dep), - training_service: TrainingService = Depends() + training_service: TrainingService = Depends(get_training_service), + db: AsyncSession = Depends(get_db_session) ): """Train model for a single product""" try: @@ -261,8 +266,8 @@ async def train_single_product( # Create training job for single product job = await training_service.create_single_product_job( + db, tenant_id=tenant_id, - user_id=current_user["user_id"], product_name=product_name, config=request.dict() ) @@ -302,7 +307,7 @@ async def validate_training_data( request: DataValidationRequest, tenant_id: str = Depends(get_current_tenant_id_dep), current_user: Dict[str, Any] = Depends(get_current_user_dep), - training_service: TrainingService = Depends() + training_service: TrainingService = Depends(get_training_service) ): """Validate data before training""" try: @@ -333,7 +338,7 @@ async def get_trained_models( product_name: Optional[str] = Query(None), tenant_id: str = Depends(get_current_tenant_id_dep), current_user: Dict[str, Any] = Depends(get_current_user_dep), - training_service: TrainingService = Depends() + training_service: TrainingService = Depends(get_training_service) ): """Get list of trained models""" try: @@ -364,7 +369,7 @@ async def delete_model( model_id: str, tenant_id: str = Depends(get_current_tenant_id_dep), current_user: Dict[str, Any] = Depends(get_current_user_dep), - training_service: TrainingService = Depends() + training_service: TrainingService = Depends(get_training_service) ): """Delete a trained model (admin only)""" try: @@ -401,7 +406,7 @@ async def get_training_stats( end_date: Optional[datetime] = Query(None), tenant_id: str = Depends(get_current_tenant_id_dep), current_user: Dict[str, Any] = Depends(get_current_user_dep), - training_service: TrainingService = Depends() + training_service: TrainingService = Depends(get_training_service) ): """Get training statistics for tenant""" try: @@ -432,7 +437,7 @@ async def retrain_all_products( background_tasks: BackgroundTasks, tenant_id: str = Depends(get_current_tenant_id_dep), current_user: Dict[str, Any] = Depends(get_current_user_dep), - training_service: TrainingService = Depends() + training_service: TrainingService = Depends(get_training_service) ): """Retrain all products with existing models""" try: diff --git a/services/training/app/main.py b/services/training/app/main.py index 827469ab..54cd736c 100644 --- a/services/training/app/main.py +++ b/services/training/app/main.py @@ -41,7 +41,7 @@ async def lifespan(app: FastAPI): try: # Initialize database logger.info("Initializing database connection") - initialize_training_database() + await initialize_training_database() logger.info("Database initialized successfully") # Initialize messaging @@ -81,7 +81,7 @@ async def lifespan(app: FastAPI): logger.info("Messaging cleanup completed") # Close database connections - cleanup_training_database() + await cleanup_training_database() logger.info("Database connections closed") except Exception as e: