# ================================================================
# services/forecasting/app/services/retraining_trigger_service.py
# ================================================================

"""
|
||
|
|
Retraining Trigger Service
|
||
|
|
|
||
|
|
Automatically triggers model retraining based on performance metrics,
|
||
|
|
accuracy degradation, or data availability.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import uuid
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional

import structlog
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.performance_monitoring_service import PerformanceMonitoringService
from shared.clients.training_client import TrainingServiceClient
from shared.config.base import BaseServiceSettings
from shared.database.exceptions import DatabaseError

# Module-level structlog logger; the service below passes tenant/product
# context as keyword fields on each individual log call.
logger = structlog.get_logger()


class RetrainingTriggerService:
    """Service for triggering automatic model retraining.

    Combines the forecasting service's performance monitoring
    (``PerformanceMonitoringService``) with the training service client
    (``TrainingServiceClient``) to decide when models should be retrained and
    to request that retraining.
    """

    def __init__(self, db_session: AsyncSession):
        """Create the service bound to one database session.

        Args:
            db_session: Async SQLAlchemy session; stored on ``self.db`` and
                passed to the performance monitoring service.
        """
        self.db = db_session
        self.performance_service = PerformanceMonitoringService(db_session)

        # Initialize training client
        config = BaseServiceSettings()
        self.training_client = TrainingServiceClient(config, calling_service_name="forecasting")

    @staticmethod
    def _collect_products_to_retrain(report: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return the products a performance report flags for retraining.

        Gathers the ``products`` of every recommendation whose ``action`` is
        ``"retrain_poor_performers"``. A product listed by several
        recommendations is returned only once (first occurrence wins), so it
        is not retrained twice in one evaluation. Entries without an
        ``inventory_product_id`` are kept so the caller can report them as
        failures instead of dropping them silently.

        Args:
            report: Performance report dict; only its ``recommendations``
                list is read.

        Returns:
            The flagged product dicts, in report order.
        """
        seen_ids = set()
        products: List[Dict[str, Any]] = []
        for rec in report.get("recommendations") or []:
            if rec.get("action") != "retrain_poor_performers":
                continue
            for product in rec.get("products") or []:
                product_id = product.get("inventory_product_id")
                if product_id is not None:
                    key = str(product_id)
                    if key in seen_ids:
                        continue
                    seen_ids.add(key)
                products.append(product)
        return products

    async def evaluate_and_trigger_retraining(
        self,
        tenant_id: uuid.UUID,
        auto_trigger: bool = True
    ) -> Dict[str, Any]:
        """
        Evaluate performance and trigger retraining if needed.

        Generates a 30-day performance report. When the report requires
        action, every product listed under a ``"retrain_poor_performers"``
        recommendation is (if ``auto_trigger``) sent to the training service
        as a high-priority retraining request.

        Args:
            tenant_id: Tenant identifier
            auto_trigger: Whether to automatically trigger retraining. When
                False the report is evaluated but nothing is triggered.

        Returns:
            Evaluation results and retraining actions taken. Both result
            shapes (``"no_action_needed"`` and ``"evaluated"``) carry
            ``requires_action`` and ``performance_report``.
            ``retraining_triggered`` counts only requests that came back with
            status ``"triggered"``; ``retraining_attempted`` counts every
            request made.

        Raises:
            DatabaseError: If generating or evaluating the report fails.
        """
        try:
            logger.info(
                "Evaluating retraining needs",
                tenant_id=tenant_id,
                auto_trigger=auto_trigger
            )

            # Generate performance report
            report = await self.performance_service.generate_performance_report(
                tenant_id=tenant_id,
                days=30
            )
            summary = report.get("summary") or {}

            if not report.get("requires_action"):
                logger.info(
                    "No retraining required",
                    tenant_id=tenant_id,
                    health_status=summary.get("health_status")
                )
                return {
                    "status": "no_action_needed",
                    "tenant_id": str(tenant_id),
                    "requires_action": False,
                    "health_status": summary.get("health_status"),
                    "report": report,
                    # Same report under the key the "evaluated" branch uses,
                    # so callers can read it without caring which branch ran.
                    "performance_report": report,
                }

            # Extract (deduplicated) products that need retraining
            products_to_retrain = self._collect_products_to_retrain(report)

            if not products_to_retrain and auto_trigger:
                # If degradation detected but no specific products, consider retraining all
                degradation = report.get("degradation_analysis") or {}
                if degradation.get("is_degrading") and degradation.get("severity") in ["high", "medium"]:
                    # NOTE: this is only logged; no full retraining is
                    # requested here.
                    logger.info(
                        "General degradation detected, considering full retraining",
                        tenant_id=tenant_id,
                        severity=degradation.get("severity")
                    )

            retraining_results: List[Dict[str, Any]] = []

            if auto_trigger:
                # Trigger retraining for poor performers
                for product in products_to_retrain:
                    product_id = product.get("inventory_product_id")
                    avg_mape = product.get("avg_mape")
                    reason = (
                        f"MAPE {avg_mape}% exceeds threshold"
                        if avg_mape is not None
                        else "Flagged as poor performer"
                    )
                    try:
                        # str() lets this accept UUID objects as well as
                        # strings; uuid.UUID() does not accept a UUID directly.
                        result = await self._trigger_product_retraining(
                            tenant_id=tenant_id,
                            inventory_product_id=uuid.UUID(str(product_id)),
                            reason=reason,
                            priority="high"
                        )
                        retraining_results.append(result)
                    except Exception as e:
                        logger.error(
                            "Failed to trigger retraining for product",
                            product_id=product_id,
                            error=str(e)
                        )
                        retraining_results.append({
                            "product_id": product_id,
                            "status": "failed",
                            "error": str(e)
                        })

            triggered_count = sum(
                1 for r in retraining_results if r.get("status") == "triggered"
            )

            logger.info(
                "Retraining evaluation complete",
                tenant_id=tenant_id,
                products_evaluated=len(products_to_retrain),
                retraining_triggered=triggered_count,
                retraining_attempted=len(retraining_results)
            )

            return {
                "status": "evaluated",
                "tenant_id": str(tenant_id),
                "requires_action": report.get("requires_action"),
                "products_needing_retraining": len(products_to_retrain),
                "retraining_triggered": triggered_count,
                "retraining_attempted": len(retraining_results),
                "auto_trigger_enabled": auto_trigger,
                "retraining_results": retraining_results,
                "performance_report": report
            }

        except Exception as e:
            logger.error(
                "Failed to evaluate and trigger retraining",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to evaluate retraining: {str(e)}") from e

    async def _trigger_product_retraining(
        self,
        tenant_id: uuid.UUID,
        inventory_product_id: uuid.UUID,
        reason: str,
        priority: str = "normal"
    ) -> Dict[str, Any]:
        """
        Trigger retraining for a specific product.

        Never raises: failures are logged and reported in the returned dict.

        Args:
            tenant_id: Tenant identifier
            inventory_product_id: Product to retrain
            reason: Reason for retraining
            priority: Priority level (low, normal, high)

        Returns:
            Retraining trigger result with ``status`` set to ``"triggered"``,
            ``"no_response"`` (client returned a falsy result) or ``"failed"``.
        """
        try:
            logger.info(
                "Triggering product retraining",
                tenant_id=tenant_id,
                product_id=inventory_product_id,
                reason=reason,
                priority=priority
            )

            # Call training service to trigger retraining
            result = await self.training_client.trigger_retrain(
                tenant_id=str(tenant_id),
                inventory_product_id=str(inventory_product_id),
                reason=reason,
                priority=priority
            )

            if not result:
                logger.warning(
                    "Retraining trigger returned no result",
                    tenant_id=tenant_id,
                    product_id=inventory_product_id
                )
                return {
                    "status": "no_response",
                    "product_id": str(inventory_product_id),
                    "reason": reason
                }

            training_job_id = result.get("training_job_id")
            logger.info(
                "Retraining triggered successfully",
                tenant_id=tenant_id,
                product_id=inventory_product_id,
                training_job_id=training_job_id
            )

            return {
                "status": "triggered",
                "product_id": str(inventory_product_id),
                "training_job_id": training_job_id,
                "reason": reason,
                "priority": priority,
                "triggered_at": datetime.now(timezone.utc).isoformat()
            }

        except Exception as e:
            logger.error(
                "Failed to trigger product retraining",
                tenant_id=tenant_id,
                product_id=inventory_product_id,
                error=str(e)
            )
            return {
                "status": "failed",
                "product_id": str(inventory_product_id),
                "error": str(e),
                "reason": reason
            }

    async def trigger_bulk_retraining(
        self,
        tenant_id: uuid.UUID,
        product_ids: List[uuid.UUID],
        reason: str = "Bulk retraining requested"
    ) -> Dict[str, Any]:
        """
        Trigger retraining for multiple products.

        Requests are sent one after another at ``"normal"`` priority.

        Args:
            tenant_id: Tenant identifier
            product_ids: List of products to retrain
            reason: Reason for bulk retraining

        Returns:
            Bulk retraining results; ``failed`` counts every request whose
            status is not ``"triggered"`` (including ``"no_response"``).

        Raises:
            DatabaseError: If the bulk operation fails unexpectedly.
        """
        try:
            logger.info(
                "Triggering bulk retraining",
                tenant_id=tenant_id,
                product_count=len(product_ids)
            )

            results = []

            for product_id in product_ids:
                result = await self._trigger_product_retraining(
                    tenant_id=tenant_id,
                    inventory_product_id=product_id,
                    reason=reason,
                    priority="normal"
                )
                results.append(result)

            successful = sum(1 for r in results if r.get("status") == "triggered")
            failed = len(product_ids) - successful

            logger.info(
                "Bulk retraining completed",
                tenant_id=tenant_id,
                total=len(product_ids),
                successful=successful,
                failed=failed
            )

            return {
                "status": "completed",
                "tenant_id": str(tenant_id),
                "total_products": len(product_ids),
                "successful": successful,
                "failed": failed,
                "results": results
            }

        except Exception as e:
            logger.error(
                "Bulk retraining failed",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Bulk retraining failed: {str(e)}") from e

    async def check_and_trigger_scheduled_retraining(
        self,
        tenant_id: uuid.UUID,
        max_model_age_days: int = 30
    ) -> Dict[str, Any]:
        """
        Check model ages and report outdated models.

        Currently this only counts outdated models (see the TODO below); it
        does not trigger any retraining yet.

        Args:
            tenant_id: Tenant identifier
            max_model_age_days: Maximum acceptable model age

        Returns:
            Scheduled retraining analysis with the ``outdated_models`` count.

        Raises:
            DatabaseError: If the model age check fails.
        """
        try:
            logger.info(
                "Checking for scheduled retraining needs",
                tenant_id=tenant_id,
                max_model_age_days=max_model_age_days
            )

            # Get model age analysis
            model_age_analysis = await self.performance_service.check_model_age(
                tenant_id=tenant_id,
                max_age_days=max_model_age_days
            )

            outdated_count = model_age_analysis.get("outdated_models", 0)

            if outdated_count == 0:
                logger.info(
                    "No outdated models found",
                    tenant_id=tenant_id
                )
                return {
                    "status": "no_action_needed",
                    "tenant_id": str(tenant_id),
                    "outdated_models": 0
                }

            # TODO: Trigger retraining for outdated models
            # Would need to get list of outdated products from training service

            return {
                "status": "analyzed",
                "tenant_id": str(tenant_id),
                "outdated_models": outdated_count,
                "message": "Scheduled retraining analysis complete"
            }

        except Exception as e:
            logger.error(
                "Scheduled retraining check failed",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Scheduled retraining check failed: {str(e)}") from e

    async def get_retraining_recommendations(
        self,
        tenant_id: uuid.UUID
    ) -> Dict[str, Any]:
        """
        Get retraining recommendations without triggering.

        Runs :meth:`evaluate_and_trigger_retraining` with
        ``auto_trigger=False`` and extracts the recommendation data from the
        performance report, whichever result branch produced it.

        Args:
            tenant_id: Tenant identifier

        Returns:
            Recommendations for manual review

        Raises:
            DatabaseError: If the evaluation fails.
        """
        try:
            # Evaluate without auto-triggering
            result = await self.evaluate_and_trigger_retraining(
                tenant_id=tenant_id,
                auto_trigger=False
            )

            # The "no_action_needed" result historically stored the report
            # under "report" only, so accept either key.
            report = result.get("performance_report") or result.get("report") or {}
            recommendations = report.get("recommendations") or []
            degradation = report.get("degradation_analysis") or {}

            return {
                "tenant_id": str(tenant_id),
                "generated_at": datetime.now(timezone.utc).isoformat(),
                "requires_action": result.get("requires_action", False),
                "recommendations": recommendations,
                "summary": report.get("summary", {}),
                "degradation_detected": degradation.get("is_degrading", False)
            }

        except Exception as e:
            logger.error(
                "Failed to get retraining recommendations",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to get recommendations: {str(e)}") from e