Start fixing forecast service API 13
This commit is contained in:
@@ -1,101 +0,0 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/ml/model_loader.py
|
||||
# ================================================================
|
||||
"""
|
||||
Model loading and management utilities
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, Any, Optional
|
||||
import pickle
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class ModelLoader:
|
||||
"""
|
||||
Utility class for loading and managing ML models
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.model_cache = {}
|
||||
self.metadata_cache = {}
|
||||
|
||||
async def load_model_with_metadata(self, model_id: str) -> Dict[str, Any]:
|
||||
"""Load model along with its metadata"""
|
||||
|
||||
try:
|
||||
# Get model metadata first
|
||||
metadata = await self._get_model_metadata(model_id)
|
||||
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
# Load the actual model
|
||||
model = await self._load_model_binary(model_id)
|
||||
|
||||
if not model:
|
||||
return None
|
||||
|
||||
return {
|
||||
"model": model,
|
||||
"metadata": metadata,
|
||||
"loaded_at": datetime.now()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error loading model with metadata",
|
||||
model_id=model_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def _get_model_metadata(self, model_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get model metadata from training service"""
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/metadata",
|
||||
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.warning("Model metadata not found",
|
||||
model_id=model_id,
|
||||
status_code=response.status_code)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting model metadata", error=str(e))
|
||||
return None
|
||||
|
||||
async def _load_model_binary(self, model_id: str):
|
||||
"""Load model binary from training service"""
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/download",
|
||||
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
model = pickle.loads(response.content)
|
||||
return model
|
||||
else:
|
||||
logger.error("Failed to download model binary",
|
||||
model_id=model_id,
|
||||
status_code=response.status_code)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error loading model binary", error=str(e))
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user