184 lines
6.6 KiB
Python
184 lines
6.6 KiB
Python
# services/forecasting/app/services/model_client.py
|
|
"""
|
|
Forecast Service Model Client

Demonstrates calling the training and data services to list, select,
validate, and retrain forecasting models.
|
|
"""
|
|
|
|
import structlog
|
|
from typing import Dict, Any, List, Optional
|
|
|
|
# Import shared clients - no more code duplication!
|
|
from shared.clients import get_service_clients, get_training_client, get_data_client
|
|
from app.core.config import settings
|
|
|
|
# Module-level structlog logger; the ModelClient methods log through it with
# tenant_id / model_id passed as keyword context.
logger = structlog.get_logger()
|
|
|
|
class ModelClient:
    """
    Client for managing models in the forecasting service.

    Wraps the shared service clients returned by ``get_service_clients`` and
    calls the training service (``self.clients.training``) and the data
    service (``self.clients.data``). Every public coroutine is best-effort:
    any exception is logged and turned into an "empty" result (``[]``,
    ``None`` or an ``{"is_valid": False, "error": ...}`` dict) instead of
    being raised to the caller.
    """

    # Accuracy a model's metrics must exceed before the model is used for
    # forecasting (0.7 == 70%, assuming accuracy is reported as a fraction).
    DEFAULT_MIN_ACCURACY = 0.7
    # Minimum data points assumed when a model's metadata does not specify one.
    DEFAULT_MIN_DATA_POINTS = 30
    # Data-point threshold sent with retraining jobs (a higher threshold for
    # forecasting than the validation default).
    RETRAIN_MIN_DATA_POINTS = 50

    def __init__(self):
        # Build all shared service clients at once. Individual clients could
        # instead be created with get_training_client / get_data_client.
        self.clients = get_service_clients(settings, "forecasting")

    async def get_available_models(
        self,
        tenant_id: str,
        model_type: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Get deployed models for a tenant from the training service.

        Args:
            tenant_id: Tenant whose models are listed.
            model_type: Optional model-type filter, forwarded unchanged.

        Returns:
            The models returned by the training service, or ``[]`` when it
            returned nothing or the call failed (the error is logged).
        """
        try:
            models = await self.clients.training.list_models(
                tenant_id=tenant_id,
                status="deployed",  # Only get deployed models
                model_type=model_type
            )

            if not models:
                logger.warning("No available models found", tenant_id=tenant_id)
                return []

            logger.info(f"Found {len(models)} available models",
                        tenant_id=tenant_id, model_type=model_type)
            return models

        except Exception as e:
            logger.error(f"Error fetching available models: {e}", tenant_id=tenant_id)
            return []

    async def get_best_model_for_forecasting(
        self,
        tenant_id: str,
        product_id: Optional[str] = None,
        *,
        min_accuracy: float = DEFAULT_MIN_ACCURACY
    ) -> Optional[Dict[str, Any]]:
        """
        Return the tenant's active forecasting model if its accuracy is high enough.

        Fetches the active ``"forecasting"`` model from the training service,
        then that model's metrics, and accepts the model only when the
        reported ``accuracy`` is a number strictly greater than
        ``min_accuracy``.

        Args:
            tenant_id: Tenant whose model is looked up.
            product_id: Optional product filter, forwarded to the training
                service as ``product_name``.
            min_accuracy: Exclusive accuracy threshold; defaults to
                ``DEFAULT_MIN_ACCURACY`` (0.7).

        Returns:
            The model dict, or ``None`` when no model exists, metrics are
            missing, accuracy is missing / non-numeric / too low, or a
            service call fails (the error is logged).
        """
        try:
            # NOTE(review): product_id is passed as ``product_name`` — confirm
            # the training client expects the id in that argument.
            latest_model = await self.clients.training.get_active_model_for_product(
                tenant_id=tenant_id,
                model_type="forecasting",
                product_name=product_id
            )

            if not latest_model:
                logger.warning("No trained models found", tenant_id=tenant_id)
                return None

            # Get model metrics to validate quality
            metrics = await self.clients.training.get_model_metrics(
                tenant_id=tenant_id,
                model_id=latest_model["id"]
            )

            # Metrics may be None/empty and "accuracy" may be absent or null;
            # treat all of those as "unknown" instead of letting ``.get`` on
            # None or a ``None > float`` comparison raise.
            accuracy = metrics.get("accuracy") if metrics else None

            if isinstance(accuracy, (int, float)) and accuracy > min_accuracy:
                logger.info(f"Selected model {latest_model['id']} with accuracy {accuracy}",
                            tenant_id=tenant_id)
                return latest_model

            shown_accuracy = accuracy if accuracy is not None else "unknown"
            logger.warning(f"Model accuracy too low: {shown_accuracy}",
                           tenant_id=tenant_id)
            return None

        except Exception as e:
            logger.error(f"Error selecting best model: {e}", tenant_id=tenant_id)
            return None

    async def validate_model_data_compatibility(
        self,
        tenant_id: str,
        model_id: str,
        forecast_start_date: str,
        forecast_end_date: str
    ) -> Dict[str, Any]:
        """
        Check that enough data exists for a model to forecast a date range.

        Fetches the model from the training service and data statistics for
        the range from the data service, then compares the available record
        count against the model's ``metadata["min_data_points"]`` (falling
        back to ``DEFAULT_MIN_DATA_POINTS`` when absent or null).

        Args:
            tenant_id: Tenant owning the model and data.
            model_id: Model to validate.
            forecast_start_date: Range start, forwarded to the data service.
            forecast_end_date: Range end, forwarded to the data service.

        Returns:
            On a completed check: ``is_valid``, ``model_id``,
            ``required_points``, ``available_points`` and ``data_coverage``,
            plus ``error`` when invalid. On failure: ``{"is_valid": False,
            "error": <message>}``.
        """
        try:
            # Get model details from training service
            model = await self.clients.training.get_model(
                tenant_id=tenant_id,
                model_id=model_id
            )

            if not model:
                return {"is_valid": False, "error": "Model not found"}

            # Get data statistics from data service
            data_stats = await self.clients.data.get_data_statistics(
                tenant_id=tenant_id,
                start_date=forecast_start_date,
                end_date=forecast_end_date
            )

            if not data_stats:
                return {"is_valid": False, "error": "Could not retrieve data statistics"}

            # ``model.get("metadata", {})`` still yields None when the key is
            # present with a null value, so fall back with ``or`` instead.
            metadata = model.get("metadata") or {}
            min_required = metadata.get("min_data_points")
            if min_required is None:
                min_required = self.DEFAULT_MIN_DATA_POINTS
            # A null record count is treated as "no data" rather than crashing
            # the comparison below.
            available_points = data_stats.get("total_records") or 0

            is_valid = available_points >= min_required

            result = {
                "is_valid": is_valid,
                "model_id": model_id,
                "required_points": min_required,
                "available_points": available_points,
                "data_coverage": data_stats.get("coverage_percentage", 0)
            }

            if not is_valid:
                result["error"] = f"Insufficient data: need {min_required}, have {available_points}"

            logger.info("Model data compatibility check completed",
                        tenant_id=tenant_id, model_id=model_id, is_valid=is_valid)

            return result

        except Exception as e:
            logger.error(f"Error validating model compatibility: {e}",
                         tenant_id=tenant_id, model_id=model_id)
            return {"is_valid": False, "error": str(e)}

    async def trigger_model_retraining(
        self,
        tenant_id: str,
        include_weather: bool = True,
        include_traffic: bool = False,
        *,
        min_data_points: int = RETRAIN_MIN_DATA_POINTS
    ) -> Optional[Dict[str, Any]]:
        """
        Create a new training job for the tenant via the training service.

        This method does not check whether the current model is outdated;
        callers decide when retraining is needed.

        Args:
            tenant_id: Tenant to retrain for.
            include_weather: Forwarded to the training job request.
            include_traffic: Forwarded to the training job request.
            min_data_points: Data-point threshold for the job; defaults to
                ``RETRAIN_MIN_DATA_POINTS`` (50).

        Returns:
            The job returned by the training service, or ``None`` when no
            job was returned or the call failed (the error is logged).
        """
        try:
            # Create training job through training service
            job = await self.clients.training.create_training_job(
                tenant_id=tenant_id,
                include_weather=include_weather,
                include_traffic=include_traffic,
                min_data_points=min_data_points
            )

            if not job:
                logger.error("Failed to create training job", tenant_id=tenant_id)
                return None

            # Use .get so a response without "job_id" is still returned to the
            # caller instead of being discarded by a KeyError in the log line.
            logger.info(f"Training job created: {job.get('job_id')}", tenant_id=tenant_id)
            return job

        except Exception as e:
            logger.error(f"Error triggering model retraining: {e}", tenant_id=tenant_id)
            return None
|
|
|
|
# Global instance shared by importers of this module. Creating it runs
# ModelClient.__init__, so the service clients are built at import time.
model_client = ModelClient()