# ================================================================
# services/forecasting/app/ml/model_loader.py
# ================================================================
"""
Model loading and management utilities.

Fetches ML models and their metadata from the training service.
"""
|
|
|
|
import structlog
|
|
from typing import Dict, Any, Optional
|
|
import pickle
|
|
import json
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
import httpx
|
|
|
|
from app.core.config import settings
|
|
|
|
# Module-level structlog logger used by the ModelLoader methods below.
logger = structlog.get_logger()
|
|
|
|
class ModelLoader:
    """
    Utility class for loading and managing ML models.

    Models and their metadata are fetched over HTTP from the training
    service at ``settings.TRAINING_SERVICE_URL``, authenticating every
    request with the ``X-Service-Auth`` header. Failures are logged and
    reported to the caller as ``None`` rather than raised.
    """

    def __init__(self):
        # NOTE(review): neither cache is read or written by any method in this
        # class; presumably reserved for callers or future caching — confirm.
        self.model_cache = {}
        self.metadata_cache = {}

    async def load_model_with_metadata(self, model_id: str) -> Optional[Dict[str, Any]]:
        """Load a model along with its metadata.

        Args:
            model_id: Identifier of the model in the training service.

        Returns:
            A dict ``{"model": <unpickled model>, "metadata": <decoded JSON>,
            "loaded_at": <datetime>}`` on success, or ``None`` when the
            metadata or the model binary could not be retrieved. Errors are
            logged and reported as ``None`` instead of being raised.
        """
        try:
            # Fetch metadata first so a missing model is detected before the
            # binary download is attempted.
            metadata = await self._get_model_metadata(model_id)

            # An empty metadata payload is treated the same as "not found".
            if not metadata:
                return None

            model = await self._load_model_binary(model_id)

            # Compare against None explicitly: truth-testing an arbitrary model
            # object is unsafe (objects defining __len__ can be falsy, and
            # array-like objects raise on bool()).
            if model is None:
                return None

            return {
                "model": model,
                "metadata": metadata,
                # datetime.now() without a tz yields a naive local timestamp.
                "loaded_at": datetime.now(),
            }

        except Exception as e:
            logger.error("Error loading model with metadata",
                         model_id=model_id,
                         error=str(e))
            return None

    async def _get_model_metadata(self, model_id: str) -> Optional[Dict[str, Any]]:
        """Get model metadata from the training service.

        Args:
            model_id: Identifier of the model in the training service.

        Returns:
            The decoded JSON body when the service answers HTTP 200,
            otherwise ``None``. Network and JSON-decoding errors are logged
            and also yield ``None``.
        """
        try:
            # No explicit timeout: httpx's default client timeout applies.
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/metadata",
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN},
                )

                if response.status_code == 200:
                    return response.json()

                # Any non-200 status (not only 404) is reported here.
                logger.warning("Model metadata not found",
                               model_id=model_id,
                               status_code=response.status_code)
                return None

        except Exception as e:
            logger.error("Error getting model metadata",
                         model_id=model_id,
                         error=str(e))
            return None

    async def _load_model_binary(self, model_id: str) -> Optional[Any]:
        """Download and unpickle a model binary from the training service.

        Args:
            model_id: Identifier of the model in the training service.

        Returns:
            The unpickled model object when the service answers HTTP 200,
            otherwise ``None``. Network and unpickling errors are logged and
            also yield ``None``.
        """
        try:
            # Longer timeout than the metadata request, since this call
            # transfers the whole serialized model.
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.get(
                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/download",
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN},
                )

                if response.status_code == 200:
                    # SECURITY(review): pickle.loads executes arbitrary code
                    # embedded in the payload. This is only acceptable while the
                    # training service is fully trusted and the channel is
                    # authenticated; consider signed payloads or a safer format.
                    return pickle.loads(response.content)

                logger.error("Failed to download model binary",
                             model_id=model_id,
                             status_code=response.status_code)
                return None

        except Exception as e:
            logger.error("Error loading model binary",
                         model_id=model_id,
                         error=str(e))
            return None
|
|
|