Improve backend
This commit is contained in:
297
services/forecasting/app/api/retraining.py
Normal file
297
services/forecasting/app/api/retraining.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/retraining.py
|
||||
# ================================================================
|
||||
"""
|
||||
Retraining API - Trigger and manage model retraining based on performance
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any, List
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.retraining_trigger_service import RetrainingTriggerService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
route_builder = RouteBuilder('forecasting')
|
||||
router = APIRouter(tags=["retraining"])
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class EvaluateRetrainingRequest(BaseModel):
|
||||
"""Request model for retraining evaluation"""
|
||||
auto_trigger: bool = Field(
|
||||
default=False,
|
||||
description="Automatically trigger retraining for poor performers"
|
||||
)
|
||||
|
||||
|
||||
class TriggerProductRetrainingRequest(BaseModel):
|
||||
"""Request model for single product retraining"""
|
||||
inventory_product_id: UUID = Field(..., description="Product to retrain")
|
||||
reason: str = Field(..., description="Reason for retraining")
|
||||
priority: str = Field(
|
||||
default="normal",
|
||||
description="Priority level: low, normal, high"
|
||||
)
|
||||
|
||||
|
||||
class TriggerBulkRetrainingRequest(BaseModel):
|
||||
"""Request model for bulk retraining"""
|
||||
product_ids: List[UUID] = Field(..., description="List of products to retrain")
|
||||
reason: str = Field(
|
||||
default="Bulk retraining requested",
|
||||
description="Reason for bulk retraining"
|
||||
)
|
||||
|
||||
|
||||
class ScheduledRetrainingCheckRequest(BaseModel):
|
||||
"""Request model for scheduled retraining check"""
|
||||
max_model_age_days: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
le=90,
|
||||
description="Maximum acceptable model age"
|
||||
)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("retraining/evaluate"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner'])
|
||||
async def evaluate_retraining_needs(
|
||||
request: EvaluateRetrainingRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Evaluate performance and optionally trigger retraining
|
||||
|
||||
Analyzes 30-day performance and identifies products needing retraining.
|
||||
If auto_trigger=true, automatically triggers retraining for poor performers.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Evaluating retraining needs",
|
||||
tenant_id=tenant_id,
|
||||
auto_trigger=request.auto_trigger,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = RetrainingTriggerService(db)
|
||||
|
||||
result = await service.evaluate_and_trigger_retraining(
|
||||
tenant_id=tenant_id,
|
||||
auto_trigger=request.auto_trigger
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to evaluate retraining needs",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to evaluate retraining: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("retraining/trigger-product"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner'])
|
||||
async def trigger_product_retraining(
|
||||
request: TriggerProductRetrainingRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Trigger retraining for a specific product
|
||||
|
||||
Manually trigger model retraining for a single product.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Triggering product retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_id=request.inventory_product_id,
|
||||
reason=request.reason,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = RetrainingTriggerService(db)
|
||||
|
||||
result = await service._trigger_product_retraining(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
reason=request.reason,
|
||||
priority=request.priority
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to trigger product retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_id=request.inventory_product_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to trigger retraining: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("retraining/trigger-bulk"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner'])
|
||||
async def trigger_bulk_retraining(
|
||||
request: TriggerBulkRetrainingRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Trigger retraining for multiple products
|
||||
|
||||
Bulk retraining operation for multiple products at once.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Triggering bulk retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_count=len(request.product_ids),
|
||||
reason=request.reason,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = RetrainingTriggerService(db)
|
||||
|
||||
result = await service.trigger_bulk_retraining(
|
||||
tenant_id=tenant_id,
|
||||
product_ids=request.product_ids,
|
||||
reason=request.reason
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to trigger bulk retraining",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to trigger bulk retraining: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("retraining/recommendations"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def get_retraining_recommendations(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get retraining recommendations without triggering
|
||||
|
||||
Returns recommendations for manual review and decision-making.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Getting retraining recommendations",
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = RetrainingTriggerService(db)
|
||||
|
||||
recommendations = await service.get_retraining_recommendations(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
return recommendations
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get recommendations",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get recommendations: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("retraining/check-scheduled"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner'])
|
||||
async def check_scheduled_retraining(
|
||||
request: ScheduledRetrainingCheckRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Check for models needing scheduled retraining based on age
|
||||
|
||||
Identifies models that haven't been updated in max_model_age_days.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Checking scheduled retraining needs",
|
||||
tenant_id=tenant_id,
|
||||
max_model_age_days=request.max_model_age_days,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = RetrainingTriggerService(db)
|
||||
|
||||
result = await service.check_and_trigger_scheduled_retraining(
|
||||
tenant_id=tenant_id,
|
||||
max_model_age_days=request.max_model_age_days
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to check scheduled retraining",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to check scheduled retraining: {str(e)}"
|
||||
)
|
||||
Reference in New Issue
Block a user