Files
bakery-ia/services/forecasting/app/services/retraining_trigger_service.py
2025-12-05 20:07:01 +01:00

487 lines
18 KiB
Python

# ================================================================
# services/forecasting/app/services/retraining_trigger_service.py
# ================================================================
"""
Retraining Trigger Service
Automatically triggers model retraining based on performance metrics,
accuracy degradation, or data availability.
"""
from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime, timezone
import structlog
import uuid
from app.services.performance_monitoring_service import PerformanceMonitoringService
from shared.clients.training_client import TrainingServiceClient
from shared.config.base import BaseServiceSettings
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class RetrainingTriggerService:
    """Service for triggering automatic model retraining.

    Uses ``PerformanceMonitoringService`` to decide whether retraining is
    needed and ``TrainingServiceClient`` (plus RabbitMQ events for the
    scheduled check) to request it.
    """

    def __init__(self, db_session: AsyncSession):
        """
        Create the service and its collaborators

        Args:
            db_session: Async database session, also handed to the
                performance monitoring service
        """
        self.db = db_session
        self.performance_service = PerformanceMonitoringService(db_session)

        # Initialize training client
        config = BaseServiceSettings()
        self.training_client = TrainingServiceClient(config, calling_service_name="forecasting")

    async def evaluate_and_trigger_retraining(
        self,
        tenant_id: uuid.UUID,
        auto_trigger: bool = True
    ) -> Dict[str, Any]:
        """
        Evaluate performance and trigger retraining if needed

        Args:
            tenant_id: Tenant identifier
            auto_trigger: Whether to automatically trigger retraining

        Returns:
            Evaluation results and retraining actions taken. When the report
            does not require action the dict has status "no_action_needed"
            and carries the report under "report"; otherwise status is
            "evaluated" and the report is under "performance_report".

        Raises:
            DatabaseError: If any step of the evaluation fails
        """
        try:
            logger.info(
                "Evaluating retraining needs",
                tenant_id=tenant_id,
                auto_trigger=auto_trigger
            )

            # Generate performance report
            report = await self.performance_service.generate_performance_report(
                tenant_id=tenant_id,
                days=30
            )
            # Tolerate a report without a summary instead of raising KeyError
            summary = report.get("summary") or {}

            if not report.get("requires_action"):
                logger.info(
                    "No retraining required",
                    tenant_id=tenant_id,
                    health_status=summary.get("health_status")
                )
                return {
                    "status": "no_action_needed",
                    "tenant_id": str(tenant_id),
                    "health_status": summary.get("health_status"),
                    "report": report
                }

            # Extract products that need retraining
            products_to_retrain = []
            for rec in report.get("recommendations") or []:
                if rec.get("action") == "retrain_poor_performers":
                    products_to_retrain.extend(rec.get("products") or [])

            if not products_to_retrain and auto_trigger:
                # If degradation detected but no specific products, consider retraining all.
                # NOTE: this branch only logs; no retraining is triggered here.
                degradation = report.get("degradation_analysis") or {}
                if degradation.get("is_degrading") and degradation.get("severity") in ["high", "medium"]:
                    logger.info(
                        "General degradation detected, considering full retraining",
                        tenant_id=tenant_id,
                        severity=degradation.get("severity")
                    )

            retraining_results = []
            if auto_trigger and products_to_retrain:
                # Trigger retraining for poor performers
                for product in products_to_retrain:
                    # Read the id with .get() so the error handler below can
                    # never fail on the same missing key.
                    product_id = product.get("inventory_product_id")
                    try:
                        result = await self._trigger_product_retraining(
                            tenant_id=tenant_id,
                            # str() first: accepts both UUID objects and strings
                            inventory_product_id=uuid.UUID(str(product_id)),
                            reason=f"MAPE {product['avg_mape']}% exceeds threshold",
                            priority="high"
                        )
                        retraining_results.append(result)
                    except Exception as e:
                        logger.error(
                            "Failed to trigger retraining for product",
                            product_id=product_id,
                            error=str(e)
                        )
                        retraining_results.append({
                            "product_id": product_id,
                            "status": "failed",
                            "error": str(e)
                        })

            # _trigger_product_retraining reports failures through the result
            # status rather than by raising, so count real successes only.
            triggered_count = sum(
                1 for r in retraining_results if r.get("status") == "triggered"
            )

            logger.info(
                "Retraining evaluation complete",
                tenant_id=tenant_id,
                products_evaluated=len(products_to_retrain),
                retraining_triggered=triggered_count
            )

            return {
                "status": "evaluated",
                "tenant_id": str(tenant_id),
                "requires_action": report.get("requires_action"),
                "products_needing_retraining": len(products_to_retrain),
                "retraining_triggered": triggered_count,
                "retraining_attempted": len(retraining_results),
                "auto_trigger_enabled": auto_trigger,
                "retraining_results": retraining_results,
                "performance_report": report
            }

        except Exception as e:
            logger.error(
                "Failed to evaluate and trigger retraining",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to evaluate retraining: {str(e)}") from e

    async def _trigger_product_retraining(
        self,
        tenant_id: uuid.UUID,
        inventory_product_id: uuid.UUID,
        reason: str,
        priority: str = "normal"
    ) -> Dict[str, Any]:
        """
        Trigger retraining for a specific product

        Never raises: any exception from the training client is logged and
        reported as a result with status "failed".

        Args:
            tenant_id: Tenant identifier
            inventory_product_id: Product to retrain
            reason: Reason for retraining
            priority: Priority level (low, normal, high)

        Returns:
            Retraining trigger result with status "triggered", "no_response"
            or "failed"
        """
        try:
            logger.info(
                "Triggering product retraining",
                tenant_id=tenant_id,
                product_id=inventory_product_id,
                reason=reason,
                priority=priority
            )

            # Call training service to trigger retraining
            result = await self.training_client.trigger_retrain(
                tenant_id=str(tenant_id),
                inventory_product_id=str(inventory_product_id),
                reason=reason,
                priority=priority
            )

            if not result:
                logger.warning(
                    "Retraining trigger returned no result",
                    tenant_id=tenant_id,
                    product_id=inventory_product_id
                )
                return {
                    "status": "no_response",
                    "product_id": str(inventory_product_id),
                    "reason": reason
                }

            logger.info(
                "Retraining triggered successfully",
                tenant_id=tenant_id,
                product_id=inventory_product_id,
                training_job_id=result.get("training_job_id")
            )
            return {
                "status": "triggered",
                "product_id": str(inventory_product_id),
                "training_job_id": result.get("training_job_id"),
                "reason": reason,
                "priority": priority,
                "triggered_at": datetime.now(timezone.utc).isoformat()
            }

        except Exception as e:
            logger.error(
                "Failed to trigger product retraining",
                tenant_id=tenant_id,
                product_id=inventory_product_id,
                error=str(e)
            )
            return {
                "status": "failed",
                "product_id": str(inventory_product_id),
                "error": str(e),
                "reason": reason
            }

    async def trigger_bulk_retraining(
        self,
        tenant_id: uuid.UUID,
        product_ids: List[uuid.UUID],
        reason: str = "Bulk retraining requested"
    ) -> Dict[str, Any]:
        """
        Trigger retraining for multiple products

        Products are processed one after another with priority "normal".

        Args:
            tenant_id: Tenant identifier
            product_ids: List of products to retrain
            reason: Reason for bulk retraining

        Returns:
            Bulk retraining results, including per-product results and
            successful/failed counts

        Raises:
            DatabaseError: If the bulk operation itself fails
        """
        try:
            logger.info(
                "Triggering bulk retraining",
                tenant_id=tenant_id,
                product_count=len(product_ids)
            )

            results = []
            for product_id in product_ids:
                result = await self._trigger_product_retraining(
                    tenant_id=tenant_id,
                    inventory_product_id=product_id,
                    reason=reason,
                    priority="normal"
                )
                results.append(result)

            # Anything that is not "triggered" (failed, no_response) counts as failed
            successful = sum(1 for r in results if r.get("status") == "triggered")
            failed = len(product_ids) - successful

            logger.info(
                "Bulk retraining completed",
                tenant_id=tenant_id,
                total=len(product_ids),
                successful=successful,
                failed=failed
            )

            return {
                "status": "completed",
                "tenant_id": str(tenant_id),
                "total_products": len(product_ids),
                "successful": successful,
                "failed": failed,
                "results": results
            }

        except Exception as e:
            logger.error(
                "Bulk retraining failed",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Bulk retraining failed: {str(e)}") from e

    @staticmethod
    def _build_retraining_event(tenant_id: uuid.UUID, model: Dict[str, Any]) -> Dict[str, Any]:
        """Build the "training.retrain.requested" event payload for one model."""
        return {
            "event_id": str(uuid.uuid4()),
            "event_type": "training.retrain.requested",
            # Timezone-aware UTC timestamp, consistent with the other
            # timestamps produced by this service.
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "tenant_id": str(tenant_id),
            "data": {
                "model_id": model.get('id'),
                "product_id": model.get('product_id'),
                "model_type": model.get('model_type'),
                "current_accuracy": model.get('accuracy'),
                "model_age_days": model.get('age_days'),
                "new_data_points": model.get('new_data_points', 0),
                "trigger_reason": model.get('trigger_reason', 'scheduled_check'),
                "priority": model.get('priority', 'normal'),
                "requested_by": "system_scheduled_check"
            }
        }

    async def check_and_trigger_scheduled_retraining(
        self,
        tenant_id: uuid.UUID,
        max_model_age_days: int = 30
    ) -> Dict[str, Any]:
        """
        Check model ages and trigger retraining for outdated models

        For every outdated model returned by the training service a
        "training.retrain.requested" event is published to the
        "training.events" exchange. A failure to publish one event does not
        stop the remaining models from being processed.

        Args:
            tenant_id: Tenant identifier
            max_model_age_days: Maximum acceptable model age

        Returns:
            Scheduled retraining results. Status is one of
            "no_action_needed", "no_models_found", "retraining_triggered"
            or "trigger_failed" (nothing could be published).

        Raises:
            DatabaseError: If the model age analysis fails
        """
        try:
            logger.info(
                "Checking for scheduled retraining needs",
                tenant_id=tenant_id,
                max_model_age_days=max_model_age_days
            )

            # Get model age analysis
            model_age_analysis = await self.performance_service.check_model_age(
                tenant_id=tenant_id,
                max_age_days=max_model_age_days
            )

            outdated_count = model_age_analysis.get("outdated_models", 0)

            if outdated_count == 0:
                logger.info(
                    "No outdated models found",
                    tenant_id=tenant_id
                )
                return {
                    "status": "no_action_needed",
                    "tenant_id": str(tenant_id),
                    "outdated_models": 0
                }

            # Trigger retraining for outdated models
            try:
                # Imported locally, as in the original code; it is not among
                # the module-level imports.
                from shared.messaging import get_rabbitmq_client

                # Get list of models that need retraining. Reuses the client
                # created in __init__ instead of building a new one per call.
                outdated_models = await self.training_client.get_outdated_models(
                    tenant_id=str(tenant_id),
                    max_age_days=max_model_age_days,
                    min_accuracy=0.85,  # Configurable threshold
                    min_new_data_points=1000  # Configurable threshold
                )

                if not outdated_models:
                    logger.info("No specific models returned for retraining", tenant_id=tenant_id)
                    return {
                        "status": "no_models_found",
                        "tenant_id": str(tenant_id),
                        "outdated_models": outdated_count
                    }

                # Publish retraining events to RabbitMQ for each model
                rabbitmq_client = get_rabbitmq_client()
                if not rabbitmq_client:
                    logger.warning(
                        "RabbitMQ client not available, cannot trigger retraining",
                        tenant_id=tenant_id
                    )
                    # Nothing was published, so do not report success
                    return {
                        "status": "trigger_failed",
                        "tenant_id": str(tenant_id),
                        "outdated_models": outdated_count,
                        "triggered_count": 0,
                        "triggered_models": [],
                        "error": "RabbitMQ client not available",
                        "message": "Analysis complete but failed to trigger retraining"
                    }

                triggered_models = []
                failed_models = []
                for model in outdated_models:
                    try:
                        retraining_event = self._build_retraining_event(tenant_id, model)

                        await rabbitmq_client.publish_event(
                            exchange_name="training.events",
                            routing_key="training.retrain.requested",
                            event_data=retraining_event
                        )

                        triggered_models.append({
                            'model_id': model.get('id'),
                            'product_id': model.get('product_id'),
                            'event_id': retraining_event['event_id']
                        })

                        logger.info(
                            "Published retraining request",
                            model_id=model.get('id'),
                            product_id=model.get('product_id'),
                            event_id=retraining_event['event_id'],
                            trigger_reason=model.get('trigger_reason')
                        )

                    except Exception as publish_error:
                        logger.error(
                            "Failed to publish retraining event",
                            model_id=model.get('id'),
                            error=str(publish_error)
                        )
                        # Continue with other models even if one fails
                        failed_models.append({
                            'model_id': model.get('id'),
                            'error': str(publish_error)
                        })

                return {
                    # Only claim success when at least one event went out
                    "status": "retraining_triggered" if triggered_models else "trigger_failed",
                    "tenant_id": str(tenant_id),
                    "outdated_models": outdated_count,
                    "triggered_count": len(triggered_models),
                    "triggered_models": triggered_models,
                    "failed_count": len(failed_models),
                    "failed_models": failed_models,
                    "message": f"Triggered retraining for {len(triggered_models)} models"
                }

            except Exception as trigger_error:
                logger.error(
                    "Failed to trigger retraining",
                    tenant_id=tenant_id,
                    error=str(trigger_error),
                    exc_info=True
                )
                # Return analysis result even if triggering failed
                return {
                    "status": "trigger_failed",
                    "tenant_id": str(tenant_id),
                    "outdated_models": outdated_count,
                    "error": str(trigger_error),
                    "message": "Analysis complete but failed to trigger retraining"
                }

        except Exception as e:
            logger.error(
                "Scheduled retraining check failed",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Scheduled retraining check failed: {str(e)}") from e

    async def get_retraining_recommendations(
        self,
        tenant_id: uuid.UUID
    ) -> Dict[str, Any]:
        """
        Get retraining recommendations without triggering

        Args:
            tenant_id: Tenant identifier

        Returns:
            Recommendations for manual review

        Raises:
            DatabaseError: If the underlying evaluation fails
        """
        try:
            # Evaluate without auto-triggering
            result = await self.evaluate_and_trigger_retraining(
                tenant_id=tenant_id,
                auto_trigger=False
            )

            # Extract just the recommendations. The evaluation stores the
            # report under "performance_report" when action is required and
            # under "report" when it is not, so check both keys.
            report = result.get("performance_report") or result.get("report") or {}
            recommendations = report.get("recommendations", [])
            degradation = report.get("degradation_analysis") or {}

            return {
                "tenant_id": str(tenant_id),
                "generated_at": datetime.now(timezone.utc).isoformat(),
                "requires_action": result.get("requires_action", False),
                "recommendations": recommendations,
                "summary": report.get("summary", {}),
                "degradation_detected": degradation.get("is_degrading", False)
            }

        except Exception as e:
            logger.error(
                "Failed to get retraining recommendations",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to get recommendations: {str(e)}") from e