487 lines
18 KiB
Python
487 lines
18 KiB
Python
# ================================================================
# services/forecasting/app/services/retraining_trigger_service.py
# ================================================================
"""
Retraining Trigger Service

Automatically triggers model retraining based on performance metrics,
accuracy degradation, or data availability.
"""
|
|
|
|
from typing import Dict, Any, List, Optional
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from datetime import datetime, timezone
|
|
import structlog
|
|
import uuid
|
|
|
|
from app.services.performance_monitoring_service import PerformanceMonitoringService
|
|
from shared.clients.training_client import TrainingServiceClient
|
|
from shared.config.base import BaseServiceSettings
|
|
from shared.database.exceptions import DatabaseError
|
|
|
|
# Module-level structlog logger; RetrainingTriggerService emits all of its log events through it.
logger = structlog.get_logger()
|
|
|
|
|
|
class RetrainingTriggerService:
    """Service for triggering automatic model retraining.

    Uses ``PerformanceMonitoringService`` reports to decide which models need
    retraining and requests it either through ``TrainingServiceClient`` (per
    product) or by publishing ``training.retrain.requested`` events to
    RabbitMQ (scheduled age-based checks).
    """

    def __init__(self, db_session: AsyncSession):
        """Bind the service to a database session and build its collaborators.

        Args:
            db_session: Async SQLAlchemy session; also handed to the
                performance monitoring service.
        """
        self.db = db_session
        self.performance_service = PerformanceMonitoringService(db_session)

        # Initialize training client
        config = BaseServiceSettings()
        self.training_client = TrainingServiceClient(config, calling_service_name="forecasting")

    async def evaluate_and_trigger_retraining(
        self,
        tenant_id: uuid.UUID,
        auto_trigger: bool = True
    ) -> Dict[str, Any]:
        """
        Evaluate performance and trigger retraining if needed

        A 30-day performance report is generated. When the report does not
        require action, a ``"no_action_needed"`` result carrying the report
        under the ``"report"`` key is returned. Otherwise every product listed
        in a ``"retrain_poor_performers"`` recommendation is retrained (when
        ``auto_trigger`` is true) and an ``"evaluated"`` result carrying the
        report under ``"performance_report"`` is returned.

        Args:
            tenant_id: Tenant identifier
            auto_trigger: Whether to automatically trigger retraining

        Returns:
            Evaluation results and retraining actions taken.
            ``retraining_triggered`` counts only results whose status is
            ``"triggered"``; failed attempts remain visible in
            ``retraining_results``.

        Raises:
            DatabaseError: If report generation or evaluation fails; the
                original exception is attached as ``__cause__``.
        """
        try:
            logger.info(
                "Evaluating retraining needs",
                tenant_id=tenant_id,
                auto_trigger=auto_trigger
            )

            # Generate performance report
            report = await self.performance_service.generate_performance_report(
                tenant_id=tenant_id,
                days=30
            )

            if not report.get("requires_action"):
                # Tolerate a report without a "summary" section instead of
                # turning a healthy result into a KeyError.
                health_status = (report.get("summary") or {}).get("health_status")
                logger.info(
                    "No retraining required",
                    tenant_id=tenant_id,
                    health_status=health_status
                )
                return {
                    "status": "no_action_needed",
                    "tenant_id": str(tenant_id),
                    "health_status": health_status,
                    "report": report
                }

            # Extract products that need retraining
            products_to_retrain = []
            for rec in report.get("recommendations", []):
                if rec.get("action") == "retrain_poor_performers":
                    products_to_retrain.extend(rec.get("products", []))

            if not products_to_retrain and auto_trigger:
                # If degradation detected but no specific products, consider retraining all.
                # NOTE: this branch only logs; no full retraining is triggered here.
                degradation = report.get("degradation_analysis", {})
                if degradation.get("is_degrading") and degradation.get("severity") in ["high", "medium"]:
                    logger.info(
                        "General degradation detected, considering full retraining",
                        tenant_id=tenant_id,
                        severity=degradation.get("severity")
                    )

            retraining_results = []

            if auto_trigger and products_to_retrain:
                # Trigger retraining for poor performers
                for product in products_to_retrain:
                    # Read the id with .get() so the error handler below can
                    # never raise a second KeyError for a malformed entry.
                    product_id = product.get("inventory_product_id")
                    try:
                        result = await self._trigger_product_retraining(
                            tenant_id=tenant_id,
                            # str() makes this work for both string ids and UUID objects.
                            inventory_product_id=uuid.UUID(str(product_id)),
                            reason=f"MAPE {product['avg_mape']}% exceeds threshold",
                            priority="high"
                        )
                        retraining_results.append(result)
                    except Exception as e:
                        logger.error(
                            "Failed to trigger retraining for product",
                            product_id=product_id,
                            error=str(e)
                        )
                        retraining_results.append({
                            "product_id": product_id,
                            "status": "failed",
                            "error": str(e)
                        })

            # Only results that actually reached the training service count as
            # triggered (same rule as trigger_bulk_retraining).
            triggered_count = sum(
                1 for r in retraining_results if r.get("status") == "triggered"
            )

            logger.info(
                "Retraining evaluation complete",
                tenant_id=tenant_id,
                products_evaluated=len(products_to_retrain),
                retraining_triggered=triggered_count
            )

            return {
                "status": "evaluated",
                "tenant_id": str(tenant_id),
                "requires_action": report.get("requires_action"),
                "products_needing_retraining": len(products_to_retrain),
                "retraining_triggered": triggered_count,
                "auto_trigger_enabled": auto_trigger,
                "retraining_results": retraining_results,
                "performance_report": report
            }

        except Exception as e:
            logger.error(
                "Failed to evaluate and trigger retraining",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to evaluate retraining: {str(e)}") from e

    async def _trigger_product_retraining(
        self,
        tenant_id: uuid.UUID,
        inventory_product_id: uuid.UUID,
        reason: str,
        priority: str = "normal"
    ) -> Dict[str, Any]:
        """
        Trigger retraining for a specific product

        Never raises: any exception from the training client is logged and
        reported as a ``"failed"`` result.

        Args:
            tenant_id: Tenant identifier
            inventory_product_id: Product to retrain
            reason: Reason for retraining
            priority: Priority level (low, normal, high)

        Returns:
            Retraining trigger result with status ``"triggered"``,
            ``"no_response"`` (client returned a falsy value) or ``"failed"``.
        """
        try:
            logger.info(
                "Triggering product retraining",
                tenant_id=tenant_id,
                product_id=inventory_product_id,
                reason=reason,
                priority=priority
            )

            # Call training service to trigger retraining
            result = await self.training_client.trigger_retrain(
                tenant_id=str(tenant_id),
                inventory_product_id=str(inventory_product_id),
                reason=reason,
                priority=priority
            )

            if not result:
                logger.warning(
                    "Retraining trigger returned no result",
                    tenant_id=tenant_id,
                    product_id=inventory_product_id
                )
                return {
                    "status": "no_response",
                    "product_id": str(inventory_product_id),
                    "reason": reason
                }

            training_job_id = result.get("training_job_id")
            logger.info(
                "Retraining triggered successfully",
                tenant_id=tenant_id,
                product_id=inventory_product_id,
                training_job_id=training_job_id
            )
            return {
                "status": "triggered",
                "product_id": str(inventory_product_id),
                "training_job_id": training_job_id,
                "reason": reason,
                "priority": priority,
                "triggered_at": datetime.now(timezone.utc).isoformat()
            }

        except Exception as e:
            logger.error(
                "Failed to trigger product retraining",
                tenant_id=tenant_id,
                product_id=inventory_product_id,
                error=str(e)
            )
            return {
                "status": "failed",
                "product_id": str(inventory_product_id),
                "error": str(e),
                "reason": reason
            }

    async def trigger_bulk_retraining(
        self,
        tenant_id: uuid.UUID,
        product_ids: List[uuid.UUID],
        reason: str = "Bulk retraining requested"
    ) -> Dict[str, Any]:
        """
        Trigger retraining for multiple products

        Products are processed one after another with ``"normal"`` priority.

        Args:
            tenant_id: Tenant identifier
            product_ids: List of products to retrain
            reason: Reason for bulk retraining

        Returns:
            Bulk retraining results; ``successful`` counts results whose
            status is ``"triggered"`` and ``failed`` is the remainder.

        Raises:
            DatabaseError: If the bulk operation itself fails; the original
                exception is attached as ``__cause__``.
        """
        try:
            logger.info(
                "Triggering bulk retraining",
                tenant_id=tenant_id,
                product_count=len(product_ids)
            )

            results = []
            for product_id in product_ids:
                result = await self._trigger_product_retraining(
                    tenant_id=tenant_id,
                    inventory_product_id=product_id,
                    reason=reason,
                    priority="normal"
                )
                results.append(result)

            total = len(product_ids)
            successful = sum(1 for r in results if r["status"] == "triggered")
            failed = total - successful

            logger.info(
                "Bulk retraining completed",
                tenant_id=tenant_id,
                total=total,
                successful=successful,
                failed=failed
            )

            return {
                "status": "completed",
                "tenant_id": str(tenant_id),
                "total_products": total,
                "successful": successful,
                "failed": failed,
                "results": results
            }

        except Exception as e:
            logger.error(
                "Bulk retraining failed",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Bulk retraining failed: {str(e)}") from e

    @staticmethod
    def _build_retraining_event(tenant_id: uuid.UUID, model: Dict[str, Any]) -> Dict[str, Any]:
        """Build the ``training.retrain.requested`` event payload for one model.

        Args:
            tenant_id: Tenant the model belongs to.
            model: Model description returned by the training service; read
                through ``.get()`` so missing keys become ``None`` or the
                documented defaults.

        Returns:
            Event dict with a fresh UUID4 ``event_id`` and a timezone-aware
            UTC ISO-8601 ``timestamp``.
        """
        return {
            "event_id": str(uuid.uuid4()),
            "event_type": "training.retrain.requested",
            # Timezone-aware UTC, consistent with the other timestamps of this service.
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "tenant_id": str(tenant_id),
            "data": {
                "model_id": model.get('id'),
                "product_id": model.get('product_id'),
                "model_type": model.get('model_type'),
                "current_accuracy": model.get('accuracy'),
                "model_age_days": model.get('age_days'),
                "new_data_points": model.get('new_data_points', 0),
                "trigger_reason": model.get('trigger_reason', 'scheduled_check'),
                "priority": model.get('priority', 'normal'),
                "requested_by": "system_scheduled_check"
            }
        }

    async def check_and_trigger_scheduled_retraining(
        self,
        tenant_id: uuid.UUID,
        max_model_age_days: int = 30,
        min_accuracy: float = 0.85,
        min_new_data_points: int = 1000
    ) -> Dict[str, Any]:
        """
        Check model ages and trigger retraining for outdated models

        Args:
            tenant_id: Tenant identifier
            max_model_age_days: Maximum acceptable model age
            min_accuracy: Accuracy threshold passed to the training service
                when listing outdated models
            min_new_data_points: New-data threshold passed to the training
                service when listing outdated models

        Returns:
            Scheduled retraining results. ``status`` is one of
            ``"no_action_needed"``, ``"no_models_found"``,
            ``"retraining_triggered"`` or ``"trigger_failed"``. The latter is
            returned when listing/publishing raises or when no RabbitMQ
            client is available, since nothing could be triggered.

        Raises:
            DatabaseError: If the model age analysis fails; the original
                exception is attached as ``__cause__``.
        """
        try:
            logger.info(
                "Checking for scheduled retraining needs",
                tenant_id=tenant_id,
                max_model_age_days=max_model_age_days
            )

            # Get model age analysis
            model_age_analysis = await self.performance_service.check_model_age(
                tenant_id=tenant_id,
                max_age_days=max_model_age_days
            )

            outdated_count = model_age_analysis.get("outdated_models", 0)

            if outdated_count == 0:
                logger.info(
                    "No outdated models found",
                    tenant_id=tenant_id
                )
                return {
                    "status": "no_action_needed",
                    "tenant_id": str(tenant_id),
                    "outdated_models": 0
                }

            # Trigger retraining for outdated models
            try:
                # Kept as function-local imports, as in the original code
                # (presumably to avoid import-time coupling; TODO confirm).
                from shared.config.base import get_settings
                from shared.messaging import get_rabbitmq_client

                config = get_settings()
                training_client = TrainingServiceClient(config, "forecasting")

                # Get list of models that need retraining
                outdated_models = await training_client.get_outdated_models(
                    tenant_id=str(tenant_id),
                    max_age_days=max_model_age_days,
                    min_accuracy=min_accuracy,
                    min_new_data_points=min_new_data_points
                )

                if not outdated_models:
                    logger.info("No specific models returned for retraining", tenant_id=tenant_id)
                    return {
                        "status": "no_models_found",
                        "tenant_id": str(tenant_id),
                        "outdated_models": outdated_count
                    }

                # Publish retraining events to RabbitMQ for each model
                rabbitmq_client = get_rabbitmq_client()

                if not rabbitmq_client:
                    logger.warning(
                        "RabbitMQ client not available, cannot trigger retraining",
                        tenant_id=tenant_id
                    )
                    # Nothing was published, so do not report success.
                    return {
                        "status": "trigger_failed",
                        "tenant_id": str(tenant_id),
                        "outdated_models": outdated_count,
                        "triggered_count": 0,
                        "triggered_models": [],
                        "error": "RabbitMQ client not available",
                        "message": "Analysis complete but failed to trigger retraining"
                    }

                triggered_models = []

                for model in outdated_models:
                    try:
                        retraining_event = self._build_retraining_event(tenant_id, model)

                        await rabbitmq_client.publish_event(
                            exchange_name="training.events",
                            routing_key="training.retrain.requested",
                            event_data=retraining_event
                        )

                        triggered_models.append({
                            'model_id': model.get('id'),
                            'product_id': model.get('product_id'),
                            'event_id': retraining_event['event_id']
                        })

                        logger.info(
                            "Published retraining request",
                            model_id=model.get('id'),
                            product_id=model.get('product_id'),
                            event_id=retraining_event['event_id'],
                            trigger_reason=model.get('trigger_reason')
                        )

                    except Exception as publish_error:
                        logger.error(
                            "Failed to publish retraining event",
                            model_id=model.get('id'),
                            error=str(publish_error)
                        )
                        # Continue with other models even if one fails

                return {
                    "status": "retraining_triggered",
                    "tenant_id": str(tenant_id),
                    "outdated_models": outdated_count,
                    "triggered_count": len(triggered_models),
                    "triggered_models": triggered_models,
                    "message": f"Triggered retraining for {len(triggered_models)} models"
                }

            except Exception as trigger_error:
                logger.error(
                    "Failed to trigger retraining",
                    tenant_id=tenant_id,
                    error=str(trigger_error),
                    exc_info=True
                )
                # Return analysis result even if triggering failed
                return {
                    "status": "trigger_failed",
                    "tenant_id": str(tenant_id),
                    "outdated_models": outdated_count,
                    "error": str(trigger_error),
                    "message": "Analysis complete but failed to trigger retraining"
                }

        except Exception as e:
            logger.error(
                "Scheduled retraining check failed",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Scheduled retraining check failed: {str(e)}") from e

    async def get_retraining_recommendations(
        self,
        tenant_id: uuid.UUID
    ) -> Dict[str, Any]:
        """
        Get retraining recommendations without triggering

        Args:
            tenant_id: Tenant identifier

        Returns:
            Recommendations for manual review

        Raises:
            DatabaseError: If the evaluation fails; the original exception is
                attached as ``__cause__``.
        """
        try:
            # Evaluate without auto-triggering
            result = await self.evaluate_and_trigger_retraining(
                tenant_id=tenant_id,
                auto_trigger=False
            )

            # Extract just the recommendations. The evaluation returns the
            # report under "performance_report" when action is required but
            # under "report" when no action is needed; accept both so healthy
            # tenants still get their summary and recommendations.
            report = result.get("performance_report") or result.get("report") or {}
            recommendations = report.get("recommendations", [])

            return {
                "tenant_id": str(tenant_id),
                "generated_at": datetime.now(timezone.utc).isoformat(),
                "requires_action": result.get("requires_action", False),
                "recommendations": recommendations,
                "summary": report.get("summary", {}),
                "degradation_detected": report.get("degradation_analysis", {}).get("is_degrading", False)
            }

        except Exception as e:
            logger.error(
                "Failed to get retraining recommendations",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to get recommendations: {str(e)}") from e
|