# ================================================================
# services/training/app/api/training.py - UPDATED WITH UNIFIED AUTH
# ================================================================
"""Training API endpoints with unified authentication"""

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
from typing import List, Optional, Dict, Any
from datetime import datetime
import structlog
import uuid

from app.schemas.training import (
    TrainingJobRequest,
    TrainingJobResponse,
    TrainingStatus,
    SingleProductTrainingRequest,
    TrainingJobProgress,
    DataValidationRequest,
    DataValidationResponse
)
from app.services.training_service import TrainingService
from app.services.messaging import (
    publish_job_started,
    publish_job_completed,
    publish_job_failed,
    publish_job_progress,
    publish_product_training_started,
    publish_product_training_completed
)

from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db_session

# Import unified authentication from shared library
from shared.auth.decorators import (
    get_current_user_dep,
    get_current_tenant_id_dep,
    require_role
)

logger = structlog.get_logger()
router = APIRouter(prefix="/training", tags=["training"])


def get_training_service() -> TrainingService:
    """Factory function for TrainingService dependency"""
    # A fresh TrainingService is constructed on every call.
    return TrainingService()


async def _get_job_for_tenant(
    training_service: TrainingService,
    job_id: str,
    tenant_id: str
):
    """Fetch a training job and ensure it belongs to ``tenant_id``.

    Raises:
        HTTPException: 404 when the service returns no job or when the job
            belongs to another tenant. The same status is used for both so
            the existence of other tenants' jobs is not revealed.
    """
    job = await training_service.get_training_job(job_id)

    # Guard against a falsy/None result so a missing job becomes a 404
    # instead of an AttributeError (-> 500) on ``job.tenant_id``.
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")

    if job.tenant_id != tenant_id:
        logger.warning("Unauthorized job access attempt",
                       job_id=job_id,
                       tenant_id=tenant_id,
                       job_tenant_id=job.tenant_id)
        raise HTTPException(status_code=404, detail="Job not found")

    return job


@router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service),
    db: AsyncSession = Depends(get_db_session)
):
    """Start a new training job for all products

    Creates the job record, publishes a best-effort "job started" event and
    schedules the actual training as a background task. Responds with the
    job in PENDING status; any unexpected failure becomes a 500.
    """
    try:
        new_job_id = str(uuid.uuid4())

        # Log the id that is actually assigned to the job so this line can
        # be correlated with later log entries for the same job.
        logger.info("Starting training job",
                    tenant_id=tenant_id,
                    job_id=new_job_id,
                    config=request.dict())

        # Create training job
        job = await training_service.create_training_job(
            db,
            tenant_id=tenant_id,
            job_id=new_job_id,
            config=request.dict()
        )

        # Publish job started event (best-effort: a messaging failure must
        # not prevent the job from being started)
        try:
            await publish_job_started(
                job_id=new_job_id,
                tenant_id=tenant_id,
                config=request.dict()
            )
        except Exception as e:
            logger.warning("Failed to publish job started event", error=str(e))

        # Start training in background.
        # NOTE(review): ``db`` is the request-scoped session; depending on
        # how get_db_session and the FastAPI version handle teardown it may
        # already be closed when the task runs - confirm.
        background_tasks.add_task(
            training_service.execute_training_job,
            db,
            job.job_id,
            job.tenant_id,
            request
        )

        logger.info("Training job created",
                    job_id=job.job_id,
                    tenant_id=tenant_id)

        return TrainingJobResponse(
            job_id=job.job_id,
            status=TrainingStatus.PENDING,
            message="Training job created successfully",
            tenant_id=tenant_id,
            created_at=job.created_at,
            estimated_duration_minutes=30
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to start training job",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}") from e


@router.get("/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
    status: Optional[TrainingStatus] = Query(None, description="Filter jobs by status"),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Get training jobs for tenant

    Supports an optional status filter and limit/offset pagination.
    """
    try:
        logger.debug("Getting training jobs",
                     tenant_id=tenant_id,
                     status=status,
                     limit=limit,
                     offset=offset)

        jobs = await training_service.get_training_jobs(
            tenant_id=tenant_id,
            status=status,
            limit=limit,
            offset=offset
        )

        logger.debug("Retrieved training jobs",
                     count=len(jobs),
                     tenant_id=tenant_id)

        return jobs

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get training jobs",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}") from e


@router.get("/jobs/{job_id}", response_model=TrainingJobResponse)
async def get_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Get specific training job details

    Responds with 404 when the job is missing or owned by another tenant.
    """
    try:
        logger.debug("Getting training job",
                     job_id=job_id,
                     tenant_id=tenant_id)

        # Fetch + verify tenant access
        return await _get_job_for_tenant(training_service, job_id, tenant_id)

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training job: {str(e)}") from e


@router.get("/jobs/{job_id}/progress", response_model=TrainingJobProgress)
async def get_training_progress(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Get real-time training progress

    Responds with 404 when the job is missing or owned by another tenant.
    """
    try:
        logger.debug("Getting training progress",
                     job_id=job_id,
                     tenant_id=tenant_id)

        # Verify job belongs to tenant before exposing its progress
        await _get_job_for_tenant(training_service, job_id, tenant_id)

        progress = await training_service.get_job_progress(job_id)
        return progress

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get training progress",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training progress: {str(e)}") from e


@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Cancel a running training job

    Responds with 404 when the job is missing or owned by another tenant.
    The cancellation is announced through the "job failed" event.
    """
    try:
        logger.info("Cancelling training job",
                    job_id=job_id,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        # Verify tenant access
        await _get_job_for_tenant(training_service, job_id, tenant_id)

        await training_service.cancel_training_job(job_id)

        # Publish cancellation event (best-effort)
        try:
            await publish_job_failed(
                job_id=job_id,
                tenant_id=tenant_id,
                error="Job cancelled by user",
                failed_at="cancellation"
            )
        except Exception as e:
            logger.warning("Failed to publish cancellation event", error=str(e))

        logger.info("Training job cancelled", job_id=job_id)

        return {"message": "Job cancelled successfully", "job_id": job_id}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to cancel training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}") from e


@router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
    product_name: str,
    request: SingleProductTrainingRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service),
    db: AsyncSession = Depends(get_db_session)
):
    """Train model for a single product

    Creates a single-product job, publishes a best-effort start event and
    schedules the training as a background task.
    """
    try:
        logger.info("Training single product",
                    product_name=product_name,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        # Create training job for single product
        job = await training_service.create_single_product_job(
            db,
            tenant_id=tenant_id,
            product_name=product_name,
            config=request.dict()
        )

        # Publish event (best-effort)
        try:
            await publish_product_training_started(
                job_id=job.job_id,
                tenant_id=tenant_id,
                product_name=product_name
            )
        except Exception as e:
            logger.warning("Failed to publish product training event", error=str(e))

        # Start training in background
        background_tasks.add_task(
            training_service.execute_single_product_training,
            job.job_id,
            product_name
        )

        logger.info("Single product training started",
                    job_id=job.job_id,
                    product_name=product_name)

        return job

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to train single product",
                     error=str(e),
                     product_name=product_name,
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to train product: {str(e)}") from e


@router.post("/validate", response_model=DataValidationResponse)
async def validate_training_data(
    request: DataValidationRequest,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Validate data before training"""
    try:
        logger.debug("Validating training data",
                     tenant_id=tenant_id,
                     products=request.products)

        validation_result = await training_service.validate_training_data(
            tenant_id=tenant_id,
            products=request.products,
            min_data_points=request.min_data_points
        )

        logger.debug("Data validation completed",
                     is_valid=validation_result.is_valid,
                     tenant_id=tenant_id)

        return validation_result

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to validate training data",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to validate data: {str(e)}") from e


@router.get("/models")
async def get_trained_models(
    product_name: Optional[str] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Get list of trained models

    Optionally filtered by ``product_name``.
    """
    try:
        logger.debug("Getting trained models",
                     tenant_id=tenant_id,
                     product_name=product_name)

        models = await training_service.get_trained_models(
            tenant_id=tenant_id,
            product_name=product_name
        )

        logger.debug("Retrieved trained models",
                     count=len(models),
                     tenant_id=tenant_id)

        return models

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get trained models",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get models: {str(e)}") from e


@router.delete("/models/{model_id}")
@require_role("admin")  # Only admins can delete models
async def delete_model(
    model_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Delete a trained model (admin only)

    Responds with 404 when the model is missing, owned by another tenant,
    or when the service reports that nothing was deleted.
    """
    try:
        logger.info("Deleting model",
                    model_id=model_id,
                    tenant_id=tenant_id,
                    admin_id=current_user["user_id"])

        # Verify model exists and belongs to tenant. The None check turns a
        # missing model into a 404 rather than an AttributeError (-> 500).
        model = await training_service.get_model(model_id)
        if model is None or model.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Model not found")

        success = await training_service.delete_model(model_id)

        if not success:
            raise HTTPException(status_code=404, detail="Model not found")

        logger.info("Model deleted successfully", model_id=model_id)

        return {"message": "Model deleted successfully", "model_id": model_id}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to delete model",
                     error=str(e),
                     model_id=model_id)
        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}") from e


@router.get("/stats")
async def get_training_stats(
    start_date: Optional[datetime] = Query(None),
    end_date: Optional[datetime] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Get training statistics for tenant

    Optionally bounded by ``start_date`` / ``end_date``.
    """
    try:
        logger.debug("Getting training stats",
                     tenant_id=tenant_id,
                     start_date=start_date,
                     end_date=end_date)

        stats = await training_service.get_training_stats(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date
        )

        logger.debug("Training stats retrieved", tenant_id=tenant_id)

        return stats

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get training stats",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get stats: {str(e)}") from e


@router.post("/retrain/all")
async def retrain_all_products(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service),
    db: AsyncSession = Depends(get_db_session)
):
    """Retrain all products with existing models

    Responds with 400 when the tenant has no trained models yet. Otherwise
    creates a job flagged ``is_retrain``, publishes a best-effort start
    event and schedules the training as a background task.
    """
    try:
        logger.info("Retraining all products",
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        # Check if models exist
        existing_models = await training_service.get_trained_models(tenant_id)
        if not existing_models:
            raise HTTPException(
                status_code=400,
                detail="No existing models found. Please run initial training first."
            )

        new_job_id = str(uuid.uuid4())
        retrain_config = {**request.dict(), "is_retrain": True}

        # Create retraining job using the same call shape as
        # start_training_job (db session first, explicit job_id).
        # NOTE(review): the previous call passed ``user_id`` and omitted
        # ``db``/``job_id``; confirm against TrainingService whether the
        # user id should still be recorded.
        job = await training_service.create_training_job(
            db,
            tenant_id=tenant_id,
            job_id=new_job_id,
            config=retrain_config
        )

        # Publish event (best-effort); a fresh dict is built so the
        # publisher never shares a mutable object with the service call.
        try:
            await publish_job_started(
                job_id=job.job_id,
                tenant_id=tenant_id,
                config={**request.dict(), "is_retrain": True}
            )
        except Exception as e:
            logger.warning("Failed to publish retrain event", error=str(e))

        # Start retraining in background with the same arguments that
        # start_training_job hands to execute_training_job.
        # NOTE(review): ``db`` is the request-scoped session; confirm it is
        # still usable when the background task runs.
        background_tasks.add_task(
            training_service.execute_training_job,
            db,
            job.job_id,
            job.tenant_id,
            request
        )

        logger.info("Retraining job created", job_id=job.job_id)

        return job

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to start retraining",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to start retraining: {str(e)}") from e