# Files
# bakery-ia/services/forecasting/app/ml/model_loader.py
#
# 102 lines
# 3.3 KiB
# Python
# Raw Normal View History
#
# 2025-07-21 19:48:56 +02:00
# ================================================================
# services/forecasting/app/ml/model_loader.py
# ================================================================
"""
Model loading and management utilities
"""
import structlog
from typing import Dict, Any, Optional
import pickle
import json
from pathlib import Path
from datetime import datetime
import httpx
from app.core.config import settings
logger = structlog.get_logger()
class ModelLoader:
    """
    Utility class for loading and managing ML models.

    Models and their metadata are fetched over HTTP from the training service
    at ``settings.TRAINING_SERVICE_URL``, authenticated with the
    ``X-Service-Auth`` header. Every loader logs failures and returns ``None``
    instead of raising.
    """

    def __init__(self):
        # NOTE(review): neither cache is read or written by the methods in this
        # class; presumably reserved for future caching — verify against
        # callers before relying on it.
        self.model_cache = {}
        self.metadata_cache = {}

    async def load_model_with_metadata(self, model_id: str) -> Optional[Dict[str, Any]]:
        """Load a model together with its metadata.

        Args:
            model_id: Identifier of the model in the training service.

        Returns:
            A dict with keys ``"model"`` (the unpickled model object),
            ``"metadata"`` (the metadata returned by the training service)
            and ``"loaded_at"`` (naive local ``datetime`` taken when loading
            finished), or ``None`` if the metadata is missing/empty, the
            model could not be downloaded, or any error occurred.
        """
        try:
            # Metadata is fetched first so that a missing model bails out
            # before the binary download is attempted.
            metadata = await self._get_model_metadata(model_id)
            if not metadata:
                return None

            model = await self._load_model_binary(model_id)
            # Compare against None explicitly: an arbitrary model object may
            # define __len__/__bool__ and be falsy even though it loaded fine.
            if model is None:
                return None

            return {
                "model": model,
                "metadata": metadata,
                "loaded_at": datetime.now(),
            }
        except Exception as e:
            logger.error("Error loading model with metadata",
                         model_id=model_id,
                         error=str(e))
            return None

    async def _get_model_metadata(self, model_id: str) -> Optional[Dict[str, Any]]:
        """Fetch model metadata from the training service.

        Args:
            model_id: Identifier of the model in the training service.

        Returns:
            The decoded JSON body on HTTP 200, otherwise ``None``. Network
            and decoding errors are logged and also yield ``None``.
        """
        try:
            # No explicit timeout here: httpx's default client timeout applies.
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/metadata",
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN},
                )
                if response.status_code == 200:
                    return response.json()
                logger.warning("Model metadata not found",
                               model_id=model_id,
                               status_code=response.status_code)
                return None
        except Exception as e:
            logger.error("Error getting model metadata",
                         model_id=model_id,
                         error=str(e))
            return None

    async def _load_model_binary(self, model_id: str) -> Any:
        """Download and unpickle a model binary from the training service.

        Args:
            model_id: Identifier of the model in the training service.

        Returns:
            The unpickled model object on HTTP 200, otherwise ``None``.
            Network and unpickling errors are logged and also yield ``None``.
        """
        try:
            # Explicit 60 s timeout (the metadata request uses httpx's
            # default), presumably because binaries are much larger.
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.get(
                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/download",
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN},
                )
                if response.status_code == 200:
                    # SECURITY: pickle.loads can execute arbitrary code embedded
                    # in the payload. This is only acceptable while the
                    # training service is fully trusted; never point
                    # TRAINING_SERVICE_URL at an untrusted host.
                    return pickle.loads(response.content)
                logger.error("Failed to download model binary",
                             model_id=model_id,
                             status_code=response.status_code)
                return None
        except Exception as e:
            logger.error("Error loading model binary",
                         model_id=model_id,
                         error=str(e))
            return None