298 lines
8.9 KiB
Python
298 lines
8.9 KiB
Python
# ================================================================
|
|
# services/forecasting/app/api/retraining.py
|
|
# ================================================================
|
|
"""
|
|
Retraining API - Trigger and manage model retraining based on performance
|
|
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
|
from typing import Dict, Any, List
|
|
from uuid import UUID
|
|
import structlog
|
|
|
|
from pydantic import BaseModel, Field
|
|
from app.services.retraining_trigger_service import RetrainingTriggerService
|
|
from shared.auth.decorators import get_current_user_dep
|
|
from shared.auth.access_control import require_user_role
|
|
from shared.routing import RouteBuilder
|
|
from app.core.database import get_db
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
# Module-level wiring for the retraining endpoints: a structured logger,
# the router grouped under the "retraining" tag, and a route builder that
# prefixes every path for the 'forecasting' service.
logger = structlog.get_logger()
router = APIRouter(tags=["retraining"])
route_builder = RouteBuilder('forecasting')
|
|
|
|
|
|
# ================================================================
|
|
# Request/Response Schemas
|
|
# ================================================================
|
|
|
|
class EvaluateRetrainingRequest(BaseModel):
    """Request body for evaluating which products need retraining"""

    # When True, retraining is started for flagged products instead of
    # only being reported; defaults to report-only.
    auto_trigger: bool = Field(
        False,
        description="Automatically trigger retraining for poor performers",
    )
|
|
|
|
|
|
class TriggerProductRetrainingRequest(BaseModel):
    """Request body for manually retraining a single product's model"""

    inventory_product_id: UUID = Field(..., description="Product to retrain")
    reason: str = Field(..., description="Reason for retraining")
    # NOTE(review): free-form string; nothing here restricts it to the
    # documented low/normal/high values.
    priority: str = Field(
        "normal",
        description="Priority level: low, normal, high",
    )
|
|
|
|
|
|
class TriggerBulkRetrainingRequest(BaseModel):
    """Request body for retraining several products in one call"""

    product_ids: List[UUID] = Field(
        ...,
        description="List of products to retrain",
    )
    # Optional; a generic reason is recorded when the caller omits one.
    reason: str = Field(
        "Bulk retraining requested",
        description="Reason for bulk retraining",
    )
|
|
|
|
|
|
class ScheduledRetrainingCheckRequest(BaseModel):
    """Request body for the age-based (scheduled) retraining check"""

    # Accepted range is 1..90 days inclusive; 30 days when omitted.
    max_model_age_days: int = Field(
        30,
        ge=1,
        le=90,
        description="Maximum acceptable model age",
    )
|
|
|
|
|
|
# ================================================================
|
|
# Endpoints
|
|
# ================================================================
|
|
|
|
@router.post(
    route_builder.build_base_route("retraining/evaluate"),
    status_code=status.HTTP_200_OK,
)
@require_user_role(['admin', 'owner'])
async def evaluate_retraining_needs(
    request: EvaluateRetrainingRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db),
):
    """
    Evaluate performance and optionally trigger retraining

    Analyzes 30-day performance and identifies products needing retraining.
    If auto_trigger=true, automatically triggers retraining for poor performers.

    Returns the retraining service's evaluation result unchanged. An
    HTTPException raised by the service keeps its own status code; any other
    failure is reported as 500.
    """
    logger.info(
        "Evaluating retraining needs",
        tenant_id=tenant_id,
        auto_trigger=request.auto_trigger,
        user_id=current_user.get("user_id"),
    )

    try:
        service = RetrainingTriggerService(db)
        return await service.evaluate_and_trigger_retraining(
            tenant_id=tenant_id,
            auto_trigger=request.auto_trigger,
        )
    except HTTPException:
        # Already carries a deliberate status/detail; re-wrapping it would
        # turn e.g. a 404 into a misleading 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which str(e) alone loses.
        logger.exception(
            "Failed to evaluate retraining needs",
            tenant_id=tenant_id,
            error=str(e),
        )
        # NOTE(review): detail echoes the raw exception text to the client,
        # which may expose internals; consider a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to evaluate retraining: {e}",
        ) from e
|
|
|
|
|
|
@router.post(
    route_builder.build_base_route("retraining/trigger-product"),
    status_code=status.HTTP_200_OK,
)
@require_user_role(['admin', 'owner'])
async def trigger_product_retraining(
    request: TriggerProductRetrainingRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db),
):
    """
    Trigger retraining for a specific product

    Manually trigger model retraining for a single product.

    Returns the retraining service's result unchanged. An HTTPException
    raised by the service keeps its own status code; any other failure is
    reported as 500.
    """
    logger.info(
        "Triggering product retraining",
        tenant_id=tenant_id,
        product_id=request.inventory_product_id,
        reason=request.reason,
        user_id=current_user.get("user_id"),
    )

    try:
        service = RetrainingTriggerService(db)
        # NOTE(review): this reaches into a private service method; the
        # service should expose a public equivalent for API callers.
        return await service._trigger_product_retraining(
            tenant_id=tenant_id,
            inventory_product_id=request.inventory_product_id,
            reason=request.reason,
            priority=request.priority,
        )
    except HTTPException:
        # Preserve intentional HTTP errors instead of masking them as 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which str(e) alone loses.
        logger.exception(
            "Failed to trigger product retraining",
            tenant_id=tenant_id,
            product_id=request.inventory_product_id,
            error=str(e),
        )
        # NOTE(review): detail echoes the raw exception text to the client,
        # which may expose internals; consider a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to trigger retraining: {e}",
        ) from e
|
|
|
|
|
|
@router.post(
    route_builder.build_base_route("retraining/trigger-bulk"),
    status_code=status.HTTP_200_OK,
)
@require_user_role(['admin', 'owner'])
async def trigger_bulk_retraining(
    request: TriggerBulkRetrainingRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db),
):
    """
    Trigger retraining for multiple products

    Bulk retraining operation for multiple products at once.

    Returns the retraining service's result unchanged. An HTTPException
    raised by the service keeps its own status code; any other failure is
    reported as 500.
    """
    product_count = len(request.product_ids)
    logger.info(
        "Triggering bulk retraining",
        tenant_id=tenant_id,
        product_count=product_count,
        reason=request.reason,
        user_id=current_user.get("user_id"),
    )

    try:
        service = RetrainingTriggerService(db)
        return await service.trigger_bulk_retraining(
            tenant_id=tenant_id,
            product_ids=request.product_ids,
            reason=request.reason,
        )
    except HTTPException:
        # Preserve intentional HTTP errors instead of masking them as 500.
        raise
    except Exception as e:
        # Include the batch size so a failure can be correlated with the
        # request; logger.exception also records the traceback.
        logger.exception(
            "Failed to trigger bulk retraining",
            tenant_id=tenant_id,
            product_count=product_count,
            error=str(e),
        )
        # NOTE(review): detail echoes the raw exception text to the client,
        # which may expose internals; consider a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to trigger bulk retraining: {e}",
        ) from e
|
|
|
|
|
|
@router.get(
    route_builder.build_base_route("retraining/recommendations"),
    status_code=status.HTTP_200_OK,
)
@require_user_role(['admin', 'owner', 'member'])
async def get_retraining_recommendations(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db),
):
    """
    Get retraining recommendations without triggering

    Returns recommendations for manual review and decision-making.

    The retraining service's recommendations are returned unchanged. An
    HTTPException raised by the service keeps its own status code; any other
    failure is reported as 500.
    """
    logger.info(
        "Getting retraining recommendations",
        tenant_id=tenant_id,
        user_id=current_user.get("user_id"),
    )

    try:
        service = RetrainingTriggerService(db)
        return await service.get_retraining_recommendations(
            tenant_id=tenant_id,
        )
    except HTTPException:
        # Preserve intentional HTTP errors instead of masking them as 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which str(e) alone loses.
        logger.exception(
            "Failed to get recommendations",
            tenant_id=tenant_id,
            error=str(e),
        )
        # NOTE(review): detail echoes the raw exception text to the client,
        # which may expose internals; consider a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get recommendations: {e}",
        ) from e
|
|
|
|
|
|
@router.post(
    route_builder.build_base_route("retraining/check-scheduled"),
    status_code=status.HTTP_200_OK,
)
@require_user_role(['admin', 'owner'])
async def check_scheduled_retraining(
    request: ScheduledRetrainingCheckRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db),
):
    """
    Check for models needing scheduled retraining based on age

    Identifies models that haven't been updated in max_model_age_days.

    Returns the retraining service's result unchanged. An HTTPException
    raised by the service keeps its own status code; any other failure is
    reported as 500.
    """
    logger.info(
        "Checking scheduled retraining needs",
        tenant_id=tenant_id,
        max_model_age_days=request.max_model_age_days,
        user_id=current_user.get("user_id"),
    )

    try:
        service = RetrainingTriggerService(db)
        return await service.check_and_trigger_scheduled_retraining(
            tenant_id=tenant_id,
            max_model_age_days=request.max_model_age_days,
        )
    except HTTPException:
        # Preserve intentional HTTP errors instead of masking them as 500.
        raise
    except Exception as e:
        # Include the age threshold for context; logger.exception also
        # records the traceback.
        logger.exception(
            "Failed to check scheduled retraining",
            tenant_id=tenant_id,
            max_model_age_days=request.max_model_age_days,
            error=str(e),
        )
        # NOTE(review): detail echoes the raw exception text to the client,
        # which may expose internals; consider a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to check scheduled retraining: {e}",
        ) from e
|