# ================================================================
# services/training/app/api/training.py - UPDATED WITH UNIFIED AUTH
# ================================================================
"""Training API endpoints with unified authentication.

Every endpoint resolves the caller's tenant and user through the shared
authentication dependencies and delegates the actual work to
``TrainingService``. Unexpected failures are logged and converted to HTTP 500
responses; ``HTTPException`` instances are always propagated unchanged so that
deliberate 4xx responses are never masked as server errors.

NOTE(review): the 500 responses embed ``str(e)`` in ``detail``. Confirm that
exposing internal error text to API clients is acceptable.
"""

from datetime import datetime
from typing import Any, Dict, List, Optional

import structlog
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query

from app.schemas.training import (
    TrainingJobRequest,
    TrainingJobResponse,
    TrainingJobStatus,
    SingleProductTrainingRequest,
    TrainingJobProgress,
    DataValidationRequest,
    DataValidationResponse
)
from app.services.training_service import TrainingService
from app.services.messaging import (
    publish_job_started,
    publish_job_completed,
    publish_job_failed,
    publish_job_progress,
    publish_product_training_started,
    publish_product_training_completed
)

# Import unified authentication from shared library
from shared.auth.decorators import (
    get_current_user_dep,
    get_current_tenant_id_dep,
    require_role
)

logger = structlog.get_logger()
router = APIRouter(prefix="/training", tags=["training"])


def _ensure_tenant_access(resource: Any, tenant_id: str, detail: str) -> None:
    """Raise a 404 unless ``resource`` exists and belongs to ``tenant_id``.

    A tenant mismatch produces the same 404 response as a missing resource.

    Args:
        resource: Object exposing a ``tenant_id`` attribute, or ``None``.
        tenant_id: Tenant identifier of the authenticated caller.
        detail: Message used as the ``detail`` of the 404 response.

    Raises:
        HTTPException: 404 when ``resource`` is ``None`` or is owned by a
            different tenant.
    """
    # Defensive: whether the service lookups return None or raise for an
    # unknown id is not visible from this module, so handle None explicitly
    # instead of letting an AttributeError surface as a 500.
    if resource is None or resource.tenant_id != tenant_id:
        raise HTTPException(status_code=404, detail=detail)


@router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Start a new training job for all products.

    Creates the job record, publishes a "job started" event on a best-effort
    basis, and schedules the training run as a background task.

    Args:
        request: Training configuration; passed to the service as a dict.
        background_tasks: FastAPI background task queue for the training run.
        tenant_id: Tenant of the authenticated caller.
        current_user: Authenticated user info; ``"user_id"`` is read from it.
        training_service: Service performing the training operations.

    Returns:
        The job object returned by ``TrainingService.create_training_job``.

    Raises:
        HTTPException: Propagated unchanged if raised downstream; otherwise
            500 when any other error occurs.
    """
    try:
        logger.info("Starting training job",
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"],
                    config=request.dict())

        # Create training job
        job = await training_service.create_training_job(
            tenant_id=tenant_id,
            user_id=current_user["user_id"],
            config=request.dict()
        )

        # Publish job started event (best-effort: a messaging failure must
        # not prevent the job from starting)
        try:
            await publish_job_started(
                job_id=job.job_id,
                tenant_id=tenant_id,
                config=request.dict()
            )
        except Exception as e:
            logger.warning("Failed to publish job started event", error=str(e))

        # Start training in background
        background_tasks.add_task(
            training_service.execute_training_job,
            job.job_id
        )

        logger.info("Training job created",
                    job_id=job.job_id,
                    tenant_id=tenant_id)

        return job

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to start training job",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start training job: {str(e)}"
        ) from e


@router.get("/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
    status: Optional[TrainingJobStatus] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get training jobs for tenant.

    Args:
        status: Optional status filter forwarded to the service.
        limit: Page size, between 1 and 1000 (default 100).
        offset: Number of jobs to skip, 0 or greater (default 0).
        tenant_id: Tenant of the authenticated caller.
        current_user: Authenticated user info (required for authentication).
        training_service: Service performing the training operations.

    Returns:
        The jobs returned by ``TrainingService.get_training_jobs``.

    Raises:
        HTTPException: Propagated unchanged if raised downstream; otherwise
            500 when any other error occurs.
    """
    try:
        logger.debug("Getting training jobs",
                     tenant_id=tenant_id,
                     status=status,
                     limit=limit,
                     offset=offset)

        jobs = await training_service.get_training_jobs(
            tenant_id=tenant_id,
            status=status,
            limit=limit,
            offset=offset
        )

        logger.debug("Retrieved training jobs",
                     count=len(jobs),
                     tenant_id=tenant_id)

        return jobs

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get training jobs",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get training jobs: {str(e)}"
        ) from e


@router.get("/jobs/{job_id}", response_model=TrainingJobResponse)
async def get_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get specific training job details.

    Args:
        job_id: Identifier of the job to fetch.
        tenant_id: Tenant of the authenticated caller.
        current_user: Authenticated user info (required for authentication).
        training_service: Service performing the training operations.

    Returns:
        The job returned by ``TrainingService.get_training_job``.

    Raises:
        HTTPException: 404 when the job is missing or owned by another
            tenant; 500 when any other error occurs.
    """
    try:
        logger.debug("Getting training job",
                     job_id=job_id,
                     tenant_id=tenant_id)

        job = await training_service.get_training_job(job_id)

        # Defensive guard in case the service returns None for an unknown id.
        if job is None:
            raise HTTPException(status_code=404, detail="Job not found")

        # Verify tenant access
        if job.tenant_id != tenant_id:
            logger.warning("Unauthorized job access attempt",
                           job_id=job_id,
                           tenant_id=tenant_id,
                           job_tenant_id=job.tenant_id)
            raise HTTPException(status_code=404, detail="Job not found")

        return job

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get training job: {str(e)}"
        ) from e


@router.get("/jobs/{job_id}/progress", response_model=TrainingJobProgress)
async def get_training_progress(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get real-time training progress.

    Args:
        job_id: Identifier of the job whose progress is requested.
        tenant_id: Tenant of the authenticated caller.
        current_user: Authenticated user info (required for authentication).
        training_service: Service performing the training operations.

    Returns:
        The progress object returned by ``TrainingService.get_job_progress``.

    Raises:
        HTTPException: 404 when the job is missing or owned by another
            tenant; 500 when any other error occurs.
    """
    try:
        logger.debug("Getting training progress",
                     job_id=job_id,
                     tenant_id=tenant_id)

        # Verify job belongs to tenant
        job = await training_service.get_training_job(job_id)
        _ensure_tenant_access(job, tenant_id, "Job not found")

        progress = await training_service.get_job_progress(job_id)

        return progress

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get training progress",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get training progress: {str(e)}"
        ) from e


@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Cancel a running training job.

    After the service cancels the job, a cancellation notice is published via
    ``publish_job_failed`` on a best-effort basis.

    Args:
        job_id: Identifier of the job to cancel.
        tenant_id: Tenant of the authenticated caller.
        current_user: Authenticated user info; ``"user_id"`` is read from it.
        training_service: Service performing the training operations.

    Returns:
        A dict with a confirmation ``"message"`` and the ``"job_id"``.

    Raises:
        HTTPException: 404 when the job is missing or owned by another
            tenant; 500 when any other error occurs.
    """
    try:
        logger.info("Cancelling training job",
                    job_id=job_id,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        job = await training_service.get_training_job(job_id)

        # Verify tenant access
        _ensure_tenant_access(job, tenant_id, "Job not found")

        await training_service.cancel_training_job(job_id)

        # Publish cancellation event (best-effort: the job is already
        # cancelled at this point, so a messaging failure is only logged)
        try:
            await publish_job_failed(
                job_id=job_id,
                tenant_id=tenant_id,
                error="Job cancelled by user",
                failed_at="cancellation"
            )
        except Exception as e:
            logger.warning("Failed to publish cancellation event", error=str(e))

        logger.info("Training job cancelled", job_id=job_id)

        return {"message": "Job cancelled successfully", "job_id": job_id}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to cancel training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to cancel training job: {str(e)}"
        ) from e


@router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
    product_name: str,
    request: SingleProductTrainingRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Train model for a single product.

    Creates a single-product job, publishes a "product training started"
    event on a best-effort basis, and schedules the training run as a
    background task.

    Args:
        product_name: Name of the product to train, taken from the URL path.
        request: Training configuration; passed to the service as a dict.
        background_tasks: FastAPI background task queue for the training run.
        tenant_id: Tenant of the authenticated caller.
        current_user: Authenticated user info; ``"user_id"`` is read from it.
        training_service: Service performing the training operations.

    Returns:
        The job returned by ``TrainingService.create_single_product_job``.

    Raises:
        HTTPException: Propagated unchanged if raised downstream; otherwise
            500 when any other error occurs.
    """
    try:
        logger.info("Training single product",
                    product_name=product_name,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        # Create training job for single product
        job = await training_service.create_single_product_job(
            tenant_id=tenant_id,
            user_id=current_user["user_id"],
            product_name=product_name,
            config=request.dict()
        )

        # Publish event (best-effort)
        try:
            await publish_product_training_started(
                job_id=job.job_id,
                tenant_id=tenant_id,
                product_name=product_name
            )
        except Exception as e:
            logger.warning("Failed to publish product training event", error=str(e))

        # Start training in background
        background_tasks.add_task(
            training_service.execute_single_product_training,
            job.job_id,
            product_name
        )

        logger.info("Single product training started",
                    job_id=job.job_id,
                    product_name=product_name)

        return job

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to train single product",
                     error=str(e),
                     product_name=product_name,
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to train product: {str(e)}"
        ) from e


@router.post("/validate", response_model=DataValidationResponse)
async def validate_training_data(
    request: DataValidationRequest,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Validate data before training.

    Args:
        request: Validation request; ``products`` and ``min_data_points`` are
            forwarded to the service.
        tenant_id: Tenant of the authenticated caller.
        current_user: Authenticated user info (required for authentication).
        training_service: Service performing the training operations.

    Returns:
        The result returned by ``TrainingService.validate_training_data``.

    Raises:
        HTTPException: Propagated unchanged if raised downstream; otherwise
            500 when any other error occurs.
    """
    try:
        logger.debug("Validating training data",
                     tenant_id=tenant_id,
                     products=request.products)

        validation_result = await training_service.validate_training_data(
            tenant_id=tenant_id,
            products=request.products,
            min_data_points=request.min_data_points
        )

        logger.debug("Data validation completed",
                     is_valid=validation_result.is_valid,
                     tenant_id=tenant_id)

        return validation_result

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to validate training data",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to validate data: {str(e)}"
        ) from e


@router.get("/models")
async def get_trained_models(
    product_name: Optional[str] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get list of trained models.

    Args:
        product_name: Optional product filter forwarded to the service.
        tenant_id: Tenant of the authenticated caller.
        current_user: Authenticated user info (required for authentication).
        training_service: Service performing the training operations.

    Returns:
        The models returned by ``TrainingService.get_trained_models``.

    Raises:
        HTTPException: Propagated unchanged if raised downstream; otherwise
            500 when any other error occurs.
    """
    try:
        logger.debug("Getting trained models",
                     tenant_id=tenant_id,
                     product_name=product_name)

        models = await training_service.get_trained_models(
            tenant_id=tenant_id,
            product_name=product_name
        )

        logger.debug("Retrieved trained models",
                     count=len(models),
                     tenant_id=tenant_id)

        return models

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get trained models",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get models: {str(e)}"
        ) from e


@router.delete("/models/{model_id}")
@require_role("admin")  # Only admins can delete models
async def delete_model(
    model_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Delete a trained model (admin only).

    Args:
        model_id: Identifier of the model to delete.
        tenant_id: Tenant of the authenticated caller.
        current_user: Authenticated user info; ``"user_id"`` is logged as the
            acting admin.
        training_service: Service performing the training operations.

    Returns:
        A dict with a confirmation ``"message"`` and the ``"model_id"``.

    Raises:
        HTTPException: 404 when the model is missing, owned by another
            tenant, or the service reports that nothing was deleted; 500
            when any other error occurs.
    """
    try:
        logger.info("Deleting model",
                    model_id=model_id,
                    tenant_id=tenant_id,
                    admin_id=current_user["user_id"])

        # Verify model belongs to tenant
        model = await training_service.get_model(model_id)
        _ensure_tenant_access(model, tenant_id, "Model not found")

        success = await training_service.delete_model(model_id)

        # A falsy result from the service is reported as "not found".
        if not success:
            raise HTTPException(status_code=404, detail="Model not found")

        logger.info("Model deleted successfully", model_id=model_id)

        return {"message": "Model deleted successfully", "model_id": model_id}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to delete model",
                     error=str(e),
                     model_id=model_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete model: {str(e)}"
        ) from e


@router.get("/stats")
async def get_training_stats(
    start_date: Optional[datetime] = Query(None),
    end_date: Optional[datetime] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get training statistics for tenant.

    Args:
        start_date: Optional lower bound forwarded to the service.
        end_date: Optional upper bound forwarded to the service.
        tenant_id: Tenant of the authenticated caller.
        current_user: Authenticated user info (required for authentication).
        training_service: Service performing the training operations.

    Returns:
        The statistics returned by ``TrainingService.get_training_stats``.

    Raises:
        HTTPException: Propagated unchanged if raised downstream; otherwise
            500 when any other error occurs.
    """
    try:
        logger.debug("Getting training stats",
                     tenant_id=tenant_id,
                     start_date=start_date,
                     end_date=end_date)

        stats = await training_service.get_training_stats(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date
        )

        logger.debug("Training stats retrieved", tenant_id=tenant_id)

        return stats

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get training stats",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get stats: {str(e)}"
        ) from e


@router.post("/retrain/all")
async def retrain_all_products(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Retrain all products with existing models.

    Requires that the tenant already has at least one trained model. The job
    config is the request dict with ``"is_retrain": True`` added.

    Args:
        request: Training configuration; passed to the service as a dict.
        background_tasks: FastAPI background task queue for the training run.
        tenant_id: Tenant of the authenticated caller.
        current_user: Authenticated user info; ``"user_id"`` is read from it.
        training_service: Service performing the training operations.

    Returns:
        The job object returned by ``TrainingService.create_training_job``.

    Raises:
        HTTPException: 400 when the tenant has no existing models; 500 when
            any other error occurs.
    """
    try:
        logger.info("Retraining all products",
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        # Check if models exist
        existing_models = await training_service.get_trained_models(tenant_id)

        if not existing_models:
            raise HTTPException(
                status_code=400,
                detail="No existing models found. Please run initial training first."
            )

        # Create retraining job
        job = await training_service.create_training_job(
            tenant_id=tenant_id,
            user_id=current_user["user_id"],
            config={**request.dict(), "is_retrain": True}
        )

        # Publish event (best-effort)
        try:
            await publish_job_started(
                job_id=job.job_id,
                tenant_id=tenant_id,
                config={**request.dict(), "is_retrain": True}
            )
        except Exception as e:
            logger.warning("Failed to publish retrain event", error=str(e))

        # Start retraining in background
        background_tasks.add_task(
            training_service.execute_training_job,
            job.job_id
        )

        logger.info("Retraining job created", job_id=job.job_id)

        return job

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to start retraining",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start retraining: {str(e)}"
        ) from e