# ================================================================
# services/forecasting/app/api/retraining.py
# ================================================================
"""
Retraining API - Trigger and manage model retraining based on performance
"""

from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
from typing import Dict, Any, List
from uuid import UUID
import structlog
from pydantic import BaseModel, Field

from app.services.retraining_trigger_service import RetrainingTriggerService
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role
from shared.routing import RouteBuilder
from app.core.database import get_db
from sqlalchemy.ext.asyncio import AsyncSession

route_builder = RouteBuilder('forecasting')
router = APIRouter(tags=["retraining"])
logger = structlog.get_logger()


# ================================================================
# Request/Response Schemas
# ================================================================

class EvaluateRetrainingRequest(BaseModel):
    """Request model for retraining evaluation"""
    # When True, the service is asked to trigger retraining for the products
    # it flags; when False the evaluation is read-only.
    auto_trigger: bool = Field(
        default=False,
        description="Automatically trigger retraining for poor performers"
    )


class TriggerProductRetrainingRequest(BaseModel):
    """Request model for single product retraining"""
    inventory_product_id: UUID = Field(..., description="Product to retrain")
    reason: str = Field(..., description="Reason for retraining")
    # NOTE(review): accepted values are only documented, not validated here;
    # any string is forwarded to the service as-is.
    priority: str = Field(
        default="normal",
        description="Priority level: low, normal, high"
    )


class TriggerBulkRetrainingRequest(BaseModel):
    """Request model for bulk retraining"""
    product_ids: List[UUID] = Field(..., description="List of products to retrain")
    reason: str = Field(
        default="Bulk retraining requested",
        description="Reason for bulk retraining"
    )


class ScheduledRetrainingCheckRequest(BaseModel):
    """Request model for scheduled retraining check"""
    # Bounded to 1..90 days by the ge/le validators.
    max_model_age_days: int = Field(
        default=30,
        ge=1,
        le=90,
        description="Maximum acceptable model age"
    )


# ================================================================
# Helpers
# ================================================================

def _internal_error(
    log_event: str,
    detail_prefix: str,
    exc: Exception,
    **log_context: Any,
) -> HTTPException:
    """Log an unexpected endpoint failure and build the matching 500 response.

    Call this from inside an ``except`` block so that the active traceback
    is attached to the log entry. The caller raises the returned exception,
    preferably with ``from exc``, which keeps the cause chain.

    Args:
        log_event: structlog event message describing the failure.
        detail_prefix: Text placed before the exception message in the
            HTTP error detail.
        exc: The exception that caused the failure.
        **log_context: Extra key/value pairs added to the log entry
            (e.g. ``tenant_id``).

    Returns:
        An ``HTTPException`` with status 500 and detail
        ``"<detail_prefix>: <exc>"``.
    """
    logger.exception(log_event, error=str(exc), **log_context)
    return HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=f"{detail_prefix}: {exc}",
    )


# ================================================================
# Endpoints
# ================================================================

@router.post(
    route_builder.build_base_route("retraining/evaluate"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def evaluate_retraining_needs(
    request: EvaluateRetrainingRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Evaluate performance and optionally trigger retraining

    Analyzes 30-day performance and identifies products needing retraining.
    If auto_trigger=true, automatically triggers retraining for poor performers.
    """
    try:
        logger.info(
            "Evaluating retraining needs",
            tenant_id=tenant_id,
            auto_trigger=request.auto_trigger,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)
        return await service.evaluate_and_trigger_retraining(
            tenant_id=tenant_id,
            auto_trigger=request.auto_trigger
        )
    except HTTPException:
        # Deliberate HTTP errors from lower layers keep their own status code
        # instead of being rewritten as a 500.
        raise
    except Exception as e:
        raise _internal_error(
            "Failed to evaluate retraining needs",
            "Failed to evaluate retraining",
            e,
            tenant_id=tenant_id,
        ) from e


@router.post(
    route_builder.build_base_route("retraining/trigger-product"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def trigger_product_retraining(
    request: TriggerProductRetrainingRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Trigger retraining for a specific product

    Manually trigger model retraining for a single product.
    """
    try:
        logger.info(
            "Triggering product retraining",
            tenant_id=tenant_id,
            product_id=request.inventory_product_id,
            reason=request.reason,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)
        # NOTE(review): this calls a private (underscore-prefixed) service
        # method; consider exposing a public wrapper on RetrainingTriggerService.
        return await service._trigger_product_retraining(
            tenant_id=tenant_id,
            inventory_product_id=request.inventory_product_id,
            reason=request.reason,
            priority=request.priority
        )
    except HTTPException:
        # Preserve intentional HTTP status codes raised below this layer.
        raise
    except Exception as e:
        raise _internal_error(
            "Failed to trigger product retraining",
            "Failed to trigger retraining",
            e,
            tenant_id=tenant_id,
            product_id=request.inventory_product_id,
        ) from e


@router.post(
    route_builder.build_base_route("retraining/trigger-bulk"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def trigger_bulk_retraining(
    request: TriggerBulkRetrainingRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Trigger retraining for multiple products

    Bulk retraining operation for multiple products at once.
    """
    try:
        logger.info(
            "Triggering bulk retraining",
            tenant_id=tenant_id,
            product_count=len(request.product_ids),
            reason=request.reason,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)
        return await service.trigger_bulk_retraining(
            tenant_id=tenant_id,
            product_ids=request.product_ids,
            reason=request.reason
        )
    except HTTPException:
        # Preserve intentional HTTP status codes raised below this layer.
        raise
    except Exception as e:
        raise _internal_error(
            "Failed to trigger bulk retraining",
            "Failed to trigger bulk retraining",
            e,
            tenant_id=tenant_id,
        ) from e


@router.get(
    route_builder.build_base_route("retraining/recommendations"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_retraining_recommendations(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get retraining recommendations without triggering

    Returns recommendations for manual review and decision-making.
    """
    # Read-only endpoint, so 'member' is allowed in addition to admin/owner.
    try:
        logger.info(
            "Getting retraining recommendations",
            tenant_id=tenant_id,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)
        return await service.get_retraining_recommendations(
            tenant_id=tenant_id
        )
    except HTTPException:
        # Preserve intentional HTTP status codes raised below this layer.
        raise
    except Exception as e:
        raise _internal_error(
            "Failed to get recommendations",
            "Failed to get recommendations",
            e,
            tenant_id=tenant_id,
        ) from e


@router.post(
    route_builder.build_base_route("retraining/check-scheduled"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def check_scheduled_retraining(
    request: ScheduledRetrainingCheckRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Check for models needing scheduled retraining based on age

    Identifies models that haven't been updated in max_model_age_days.
    """
    try:
        logger.info(
            "Checking scheduled retraining needs",
            tenant_id=tenant_id,
            max_model_age_days=request.max_model_age_days,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)
        return await service.check_and_trigger_scheduled_retraining(
            tenant_id=tenant_id,
            max_model_age_days=request.max_model_age_days
        )
    except HTTPException:
        # Preserve intentional HTTP status codes raised below this layer.
        raise
    except Exception as e:
        raise _internal_error(
            "Failed to check scheduled retraining",
            "Failed to check scheduled retraining",
            e,
            tenant_id=tenant_id,
        ) from e