Improve backend
This commit is contained in:
384
services/forecasting/app/services/retraining_trigger_service.py
Normal file
384
services/forecasting/app/services/retraining_trigger_service.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/retraining_trigger_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Retraining Trigger Service
|
||||
|
||||
Automatically triggers model retraining based on performance metrics,
|
||||
accuracy degradation, or data availability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.services.performance_monitoring_service import PerformanceMonitoringService
|
||||
from shared.clients.training_client import TrainingServiceClient
|
||||
from shared.config.base import BaseServiceSettings
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RetrainingTriggerService:
|
||||
"""Service for triggering automatic model retraining"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db = db_session
|
||||
self.performance_service = PerformanceMonitoringService(db_session)
|
||||
|
||||
# Initialize training client
|
||||
config = BaseServiceSettings()
|
||||
self.training_client = TrainingServiceClient(config, calling_service_name="forecasting")
|
||||
|
||||
async def evaluate_and_trigger_retraining(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
auto_trigger: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Evaluate performance and trigger retraining if needed
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
auto_trigger: Whether to automatically trigger retraining
|
||||
|
||||
Returns:
|
||||
Evaluation results and retraining actions taken
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Evaluating retraining needs",
|
||||
tenant_id=tenant_id,
|
||||
auto_trigger=auto_trigger
|
||||
)
|
||||
|
||||
# Generate performance report
|
||||
report = await self.performance_service.generate_performance_report(
|
||||
tenant_id=tenant_id,
|
||||
days=30
|
||||
)
|
||||
|
||||
if not report.get("requires_action"):
|
||||
logger.info(
|
||||
"No retraining required",
|
||||
tenant_id=tenant_id,
|
||||
health_status=report["summary"].get("health_status")
|
||||
)
|
||||
return {
|
||||
"status": "no_action_needed",
|
||||
"tenant_id": str(tenant_id),
|
||||
"health_status": report["summary"].get("health_status"),
|
||||
"report": report
|
||||
}
|
||||
|
||||
# Extract products that need retraining
|
||||
products_to_retrain = []
|
||||
recommendations = report.get("recommendations", [])
|
||||
|
||||
for rec in recommendations:
|
||||
if rec.get("action") == "retrain_poor_performers":
|
||||
products_to_retrain.extend(rec.get("products", []))
|
||||
|
||||
if not products_to_retrain and auto_trigger:
|
||||
# If degradation detected but no specific products, consider retraining all
|
||||
degradation = report.get("degradation_analysis", {})
|
||||
if degradation.get("is_degrading") and degradation.get("severity") in ["high", "medium"]:
|
||||
logger.info(
|
||||
"General degradation detected, considering full retraining",
|
||||
tenant_id=tenant_id,
|
||||
severity=degradation.get("severity")
|
||||
)
|
||||
|
||||
retraining_results = []
|
||||
|
||||
if auto_trigger and products_to_retrain:
|
||||
# Trigger retraining for poor performers
|
||||
for product in products_to_retrain:
|
||||
try:
|
||||
result = await self._trigger_product_retraining(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=uuid.UUID(product["inventory_product_id"]),
|
||||
reason=f"MAPE {product['avg_mape']}% exceeds threshold",
|
||||
priority="high"
|
||||
)
|
||||
retraining_results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to trigger retraining for product",
|
||||
product_id=product["inventory_product_id"],
|
||||
error=str(e)
|
||||
)
|
||||
retraining_results.append({
|
||||
"product_id": product["inventory_product_id"],
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Retraining evaluation complete",
|
||||
tenant_id=tenant_id,
|
||||
products_evaluated=len(products_to_retrain),
|
||||
retraining_triggered=len(retraining_results)
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "evaluated",
|
||||
"tenant_id": str(tenant_id),
|
||||
"requires_action": report.get("requires_action"),
|
||||
"products_needing_retraining": len(products_to_retrain),
|
||||
"retraining_triggered": len(retraining_results),
|
||||
"auto_trigger_enabled": auto_trigger,
|
||||
"retraining_results": retraining_results,
|
||||
"performance_report": report
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to evaluate and trigger retraining",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to evaluate retraining: {str(e)}")
|
||||
|
||||
async def _trigger_product_retraining(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
inventory_product_id: uuid.UUID,
|
||||
reason: str,
|
||||
priority: str = "normal"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Trigger retraining for a specific product
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product to retrain
|
||||
reason: Reason for retraining
|
||||
priority: Priority level (low, normal, high)
|
||||
|
||||
Returns:
|
||||
Retraining trigger result
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Triggering product retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id,
|
||||
reason=reason,
|
||||
priority=priority
|
||||
)
|
||||
|
||||
# Call training service to trigger retraining
|
||||
result = await self.training_client.trigger_retrain(
|
||||
tenant_id=str(tenant_id),
|
||||
inventory_product_id=str(inventory_product_id),
|
||||
reason=reason,
|
||||
priority=priority
|
||||
)
|
||||
|
||||
if result:
|
||||
logger.info(
|
||||
"Retraining triggered successfully",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id,
|
||||
training_job_id=result.get("training_job_id")
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "triggered",
|
||||
"product_id": str(inventory_product_id),
|
||||
"training_job_id": result.get("training_job_id"),
|
||||
"reason": reason,
|
||||
"priority": priority,
|
||||
"triggered_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
else:
|
||||
logger.warning(
|
||||
"Retraining trigger returned no result",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id
|
||||
)
|
||||
return {
|
||||
"status": "no_response",
|
||||
"product_id": str(inventory_product_id),
|
||||
"reason": reason
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to trigger product retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id,
|
||||
error=str(e)
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"product_id": str(inventory_product_id),
|
||||
"error": str(e),
|
||||
"reason": reason
|
||||
}
|
||||
|
||||
async def trigger_bulk_retraining(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
product_ids: List[uuid.UUID],
|
||||
reason: str = "Bulk retraining requested"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Trigger retraining for multiple products
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
product_ids: List of products to retrain
|
||||
reason: Reason for bulk retraining
|
||||
|
||||
Returns:
|
||||
Bulk retraining results
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Triggering bulk retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_count=len(product_ids)
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
for product_id in product_ids:
|
||||
result = await self._trigger_product_retraining(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=product_id,
|
||||
reason=reason,
|
||||
priority="normal"
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
successful = sum(1 for r in results if r["status"] == "triggered")
|
||||
|
||||
logger.info(
|
||||
"Bulk retraining completed",
|
||||
tenant_id=tenant_id,
|
||||
total=len(product_ids),
|
||||
successful=successful,
|
||||
failed=len(product_ids) - successful
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"tenant_id": str(tenant_id),
|
||||
"total_products": len(product_ids),
|
||||
"successful": successful,
|
||||
"failed": len(product_ids) - successful,
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Bulk retraining failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Bulk retraining failed: {str(e)}")
|
||||
|
||||
async def check_and_trigger_scheduled_retraining(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
max_model_age_days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Check model ages and trigger retraining for outdated models
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
max_model_age_days: Maximum acceptable model age
|
||||
|
||||
Returns:
|
||||
Scheduled retraining results
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Checking for scheduled retraining needs",
|
||||
tenant_id=tenant_id,
|
||||
max_model_age_days=max_model_age_days
|
||||
)
|
||||
|
||||
# Get model age analysis
|
||||
model_age_analysis = await self.performance_service.check_model_age(
|
||||
tenant_id=tenant_id,
|
||||
max_age_days=max_model_age_days
|
||||
)
|
||||
|
||||
outdated_count = model_age_analysis.get("outdated_models", 0)
|
||||
|
||||
if outdated_count == 0:
|
||||
logger.info(
|
||||
"No outdated models found",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return {
|
||||
"status": "no_action_needed",
|
||||
"tenant_id": str(tenant_id),
|
||||
"outdated_models": 0
|
||||
}
|
||||
|
||||
# TODO: Trigger retraining for outdated models
|
||||
# Would need to get list of outdated products from training service
|
||||
|
||||
return {
|
||||
"status": "analyzed",
|
||||
"tenant_id": str(tenant_id),
|
||||
"outdated_models": outdated_count,
|
||||
"message": "Scheduled retraining analysis complete"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Scheduled retraining check failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Scheduled retraining check failed: {str(e)}")
|
||||
|
||||
async def get_retraining_recommendations(
|
||||
self,
|
||||
tenant_id: uuid.UUID
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get retraining recommendations without triggering
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
|
||||
Returns:
|
||||
Recommendations for manual review
|
||||
"""
|
||||
try:
|
||||
# Evaluate without auto-triggering
|
||||
result = await self.evaluate_and_trigger_retraining(
|
||||
tenant_id=tenant_id,
|
||||
auto_trigger=False
|
||||
)
|
||||
|
||||
# Extract just the recommendations
|
||||
report = result.get("performance_report", {})
|
||||
recommendations = report.get("recommendations", [])
|
||||
|
||||
return {
|
||||
"tenant_id": str(tenant_id),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"requires_action": result.get("requires_action", False),
|
||||
"recommendations": recommendations,
|
||||
"summary": report.get("summary", {}),
|
||||
"degradation_detected": report.get("degradation_analysis", {}).get("is_degrading", False)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get retraining recommendations",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to get recommendations: {str(e)}")
|
||||
Reference in New Issue
Block a user