240 lines
9.4 KiB
Python
240 lines
9.4 KiB
Python
# services/forecasting/app/services/model_client.py
|
|
"""
Forecast Service Model Client

Demonstrates calling the training service to get models.
"""
|
|
|
|
import structlog
|
|
from typing import Dict, Any, List, Optional
|
|
|
|
# Import shared clients - no more code duplication!
|
|
from shared.clients import get_service_clients, get_training_client, get_sales_client
|
|
from shared.database.base import create_database_manager
|
|
from app.core.config import settings
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
class ModelClient:
    """
    Client the forecasting service uses to look up and validate trained models.

    All remote calls go through the shared client bundle stored in
    ``self.clients``. Every public coroutine is best-effort: failures are
    logged and turned into an empty / ``None`` / ``is_valid=False`` result
    instead of being raised to the caller.
    """

    # Accuracy above which a model is logged as "selected". Models at or below
    # it are still returned; the threshold only affects what gets logged.
    MIN_ACCURACY_THRESHOLD = 0.7

    # Data points required when the model metadata does not specify a minimum.
    DEFAULT_MIN_DATA_POINTS = 30

    # Minimum data points requested when a retraining job is created.
    RETRAINING_MIN_DATA_POINTS = 50

    def __init__(self, database_manager=None):
        """
        Store the database manager and build the shared service clients.

        Args:
            database_manager: Optional pre-built manager. When omitted (or
                falsy) one is created from ``settings.DATABASE_URL``.
        """
        self.database_manager = database_manager or create_database_manager(
            settings.DATABASE_URL, "forecasting-service"
        )

        # One bundle exposing every service client. Individual clients could
        # instead be built with get_training_client / get_sales_client.
        self.clients = get_service_clients(settings, "forecasting")

    async def get_available_models(
        self,
        tenant_id: str,
        model_type: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Get the tenant's models from the training service, filtered to status "deployed".

        Args:
            tenant_id: Tenant whose models are listed.
            model_type: Optional filter forwarded unchanged to the training client.

        Returns:
            The models returned by the training client, or an empty list when
            nothing is returned or the call fails.
        """
        try:
            models = await self.clients.training.list_models(
                tenant_id=tenant_id,
                status="deployed",
                model_type=model_type,
            )

            if not models:
                logger.warning("No available models found", tenant_id=tenant_id)
                return []

            logger.info(f"Found {len(models)} available models",
                        tenant_id=tenant_id, model_type=model_type)
            return models

        except Exception as e:
            logger.error(f"Error fetching available models: {e}", tenant_id=tenant_id)
            return []

    async def get_best_model_for_forecasting(
        self,
        tenant_id: str,
        inventory_product_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Get the active model for a product and log how good it is.

        The model is returned whenever it exists and carries a ``model_id``;
        its metrics are fetched only to log a quality verdict and never cause
        the model to be withheld.

        Args:
            tenant_id: Tenant that owns the model.
            inventory_product_id: Product whose active model is requested.

        Returns:
            The model dict, or ``None`` when no model exists, the response has
            no ``model_id``, or the lookup itself fails.
        """
        try:
            latest_model = await self.clients.training.get_active_model_for_product(
                tenant_id=tenant_id,
                inventory_product_id=inventory_product_id,
            )

            if not latest_model:
                logger.warning("No trained models found", tenant_id=tenant_id)
                return None

            # The identifier is read from "model_id" (not "id").
            model_id = latest_model.get("model_id")
            if not model_id:
                logger.error("Model response missing model_id field", tenant_id=tenant_id)
                return None

            await self._log_model_quality(tenant_id, model_id)
            return latest_model

        except Exception as e:
            logger.error(f"Error selecting best model: {e}", tenant_id=tenant_id)
            return None

    async def _log_model_quality(self, tenant_id: str, model_id: str) -> None:
        """Fetch the model's metrics and log a quality verdict; never raises."""
        try:
            metrics = await self.clients.training.get_model_metrics(
                tenant_id=tenant_id,
                model_id=model_id,
            )

            if not metrics:
                logger.warning("No metrics returned from training service", tenant_id=tenant_id)
                return

            accuracy = metrics.get("accuracy")
            # A missing or null accuracy counts as "not good enough" instead of
            # blowing up the comparison and being reported as an endpoint failure.
            if isinstance(accuracy, (int, float)) and accuracy > self.MIN_ACCURACY_THRESHOLD:
                logger.info(f"Selected model {model_id} with accuracy {accuracy}",
                            tenant_id=tenant_id)
                return

            shown_accuracy = "unknown" if accuracy is None else accuracy
            logger.warning(f"Model accuracy too low: {shown_accuracy}",
                           tenant_id=tenant_id)
            # A weak model is still preferred over producing no prediction.
            logger.info("Returning model despite low accuracy - no alternative available",
                        tenant_id=tenant_id)

        except Exception as metrics_error:
            # A metrics failure must not block forecasting with the model.
            logger.warning(f"Failed to get model metrics: {metrics_error}", tenant_id=tenant_id)
            logger.info("Proceeding with model despite metrics failure", tenant_id=tenant_id)

    async def get_any_model_for_tenant(
        self,
        tenant_id: str
    ) -> Optional[Dict[str, Any]]:
        """
        Get any available model for a tenant, used as a fallback when no
        product-specific model is found.

        Picks the model with the greatest ``created_at`` value; models whose
        ``created_at`` is missing or null sort last.

        Args:
            tenant_id: Tenant whose models are searched.

        Returns:
            The chosen model dict, or ``None`` when the tenant has no
            available models or the lookup fails.
        """
        try:
            models = await self.get_available_models(tenant_id)

            if not models:
                logger.warning("No fallback models available for tenant", tenant_id=tenant_id)
                return None

            # `or ""` keeps a null created_at comparable with string timestamps.
            # NOTE(review): assumes created_at values sort chronologically as
            # strings (e.g. ISO-8601) - confirm against the training service.
            best_model = max(models, key=lambda m: m.get("created_at") or "")

            logger.info("Found fallback model for tenant",
                        tenant_id=tenant_id,
                        # Prefer "model_id", the field the rest of this client
                        # reads; keep "id" as a fallback for older payloads.
                        model_id=best_model.get("model_id") or best_model.get("id", "unknown"),
                        inventory_product_id=best_model.get("inventory_product_id", "unknown"))
            return best_model

        except Exception as e:
            logger.error("Error getting fallback model for tenant",
                         tenant_id=tenant_id,
                         error=str(e))
            return None

    async def validate_model_data_compatibility(
        self,
        tenant_id: str,
        model_id: str,
        forecast_start_date: str,
        forecast_end_date: str
    ) -> Dict[str, Any]:
        """
        Check that enough data points exist for the model to make forecasts.

        Compares the ``total_records`` reported by the data client for the
        given date range with the model's ``metadata.min_data_points``
        (``DEFAULT_MIN_DATA_POINTS`` when the model does not specify one).

        Args:
            tenant_id: Tenant that owns the model and the data.
            model_id: Model to validate.
            forecast_start_date: Start of the range passed to the data client.
            forecast_end_date: End of the range passed to the data client.

        Returns:
            A dict that always has ``is_valid``. On a completed check it also
            has ``model_id``, ``required_points``, ``available_points`` and
            ``data_coverage``; ``error`` is present whenever ``is_valid`` is
            false.
        """
        try:
            model = await self.clients.training.get_model(
                tenant_id=tenant_id,
                model_id=model_id,
            )
            if not model:
                return {"is_valid": False, "error": "Model not found"}

            # NOTE(review): relies on the shared client bundle exposing a
            # `data` client - confirm it exists in shared.clients.
            data_stats = await self.clients.data.get_data_statistics(
                tenant_id=tenant_id,
                start_date=forecast_start_date,
                end_date=forecast_end_date,
            )
            if not data_stats:
                return {"is_valid": False, "error": "Could not retrieve data statistics"}

            # Null "metadata" / "min_data_points" / "total_records" values fall
            # back to defaults instead of failing the whole check.
            metadata = model.get("metadata") or {}
            min_required = metadata.get("min_data_points")
            if min_required is None:
                min_required = self.DEFAULT_MIN_DATA_POINTS
            available_points = data_stats.get("total_records") or 0

            is_valid = available_points >= min_required

            result = {
                "is_valid": is_valid,
                "model_id": model_id,
                "required_points": min_required,
                "available_points": available_points,
                "data_coverage": data_stats.get("coverage_percentage", 0),
            }
            if not is_valid:
                result["error"] = f"Insufficient data: need {min_required}, have {available_points}"

            logger.info("Model data compatibility check completed",
                        tenant_id=tenant_id, model_id=model_id, is_valid=is_valid)
            return result

        except Exception as e:
            logger.error(f"Error validating model compatibility: {e}",
                         tenant_id=tenant_id, model_id=model_id)
            return {"is_valid": False, "error": str(e)}

    async def trigger_model_retraining(
        self,
        tenant_id: str,
        include_weather: bool = True,
        include_traffic: bool = False
    ) -> Optional[Dict[str, Any]]:
        """
        Ask the training service to create a new training job.

        Args:
            tenant_id: Tenant to retrain for.
            include_weather: Forwarded unchanged to the training client.
            include_traffic: Forwarded unchanged to the training client.

        Returns:
            The job returned by the training client, or ``None`` when no job
            is returned or the call fails.
        """
        try:
            job = await self.clients.training.create_training_job(
                tenant_id=tenant_id,
                include_weather=include_weather,
                include_traffic=include_traffic,
                min_data_points=self.RETRAINING_MIN_DATA_POINTS,
            )

            if not job:
                logger.error("Failed to create training job", tenant_id=tenant_id)
                return None

            # .get() so that a response without "job_id" cannot turn a
            # successfully created job into a reported failure.
            logger.info(f"Training job created: {job.get('job_id', 'unknown')}",
                        tenant_id=tenant_id)
            return job

        except Exception as e:
            logger.error(f"Error triggering model retraining: {e}", tenant_id=tenant_id)
            return None
|
|
|
|
# Shared module-level client; built once, when this module is imported.
model_client: ModelClient = ModelClient()