Start fixing forecast service API 13
@@ -16,6 +16,7 @@ import numpy as np
 import pandas as pd
 import httpx
 from pathlib import Path
+import os

 from app.core.config import settings
 from shared.monitoring.metrics import MetricsCollector
@@ -33,7 +34,7 @@ class PredictionService:
         self.model_cache = {}
         self.cache_ttl = 3600  # 1 hour cache

-    async def predict(self, model_id: str, features: Dict[str, Any],
+    async def predict(self, model_id: str, model_path: str, features: Dict[str, Any],
                       confidence_level: float = 0.8) -> Dict[str, float]:
         """Generate prediction using trained model"""

@@ -45,7 +46,7 @@ class PredictionService:
                     features_count=len(features))

         # Load model
-        model = await self._load_model(model_id)
+        model = await self._load_model(model_id, model_path)

         if not model:
             raise ValueError(f"Model {model_id} not found or failed to load")
@@ -87,42 +88,31 @@ class PredictionService:
                          error=str(e))
             raise

-    async def _load_model(self, model_id: str):
-        """Load model from cache or training service"""
+    async def _load_model(self, model_id: str, model_path: str):
+        """Load model from shared volume using API metadata"""

         # Check cache first
         if model_id in self.model_cache:
             cached_model, cached_time = self.model_cache[model_id]
             if (datetime.now() - cached_time).seconds < self.cache_ttl:
                 logger.debug("Using cached model", model_id=model_id)
                 return cached_model

         try:
-            # Download model from training service
-            async with httpx.AsyncClient(timeout=30.0) as client:
-                response = await client.get(
-                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/download",
-                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
-                )
+            # Load model directly from shared volume (fast!)
+            if os.path.exists(model_path):
+                with open(model_path, 'rb') as f:
+                    model = pickle.load(f)
+
+                # Cache the model
+                self.model_cache[model_id] = (model, datetime.now())
+                logger.info(f"Model loaded from shared volume: {model_path}")
+                return model
+            else:
+                logger.error(f"Model file not found: {model_path}")
+                return None

-            if response.status_code == 200:
-                # Load model from bytes
-                model_data = response.content
-                model = pickle.loads(model_data)
-
-                # Cache the model
-                self.model_cache[model_id] = (model, datetime.now())
-
-                logger.info("Model loaded successfully", model_id=model_id)
-                return model
-            else:
-                logger.error("Failed to download model",
-                             model_id=model_id,
-                             status_code=response.status_code)
-                return None

         except Exception as e:
-            logger.error("Error loading model", model_id=model_id, error=str(e))
+            logger.error(f"Error loading model: {e}")
             return None

     def _prepare_prophet_features(self, features: Dict[str, Any]) -> pd.DataFrame:
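Taken together, the hunks move model loading off the training-service download path: predict now requires the model's location on the shared volume and forwards it to _load_model, which unpickles the file directly instead of fetching bytes over HTTP. A minimal sketch of what a caller has to change, assuming nothing beyond the signatures in this diff (the import path, the request shape, and where model_path comes from are illustrative assumptions, not part of the commit):

    from typing import Any, Dict

    # Hypothetical wiring -- the import path and payload fields below are
    # assumptions for illustration; only predict()'s signature is from the diff.
    from app.services.prediction import PredictionService


    async def handle_forecast(service: PredictionService,
                              payload: Dict[str, Any]) -> Dict[str, float]:
        # model_path is the new required argument: the path of the pickled
        # model on the shared volume, assumed to be provided by the training
        # service's model metadata together with model_id.
        return await service.predict(
            model_id=payload["model_id"],
            model_path=payload["model_path"],
            features=payload["features"],
            confidence_level=payload.get("confidence_level", 0.8),
        )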
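One caveat in the unchanged cache-check context lines: (datetime.now() - cached_time).seconds reads only the seconds component of the timedelta and ignores whole days, so an entry cached more than a day earlier can look fresh again. A small standalone sketch of the difference (an observation about the surrounding code, not something this commit touches; total_seconds() is the robust check):

    from datetime import datetime, timedelta

    CACHE_TTL = 3600  # same 1-hour TTL as PredictionService

    def is_fresh(cached_time: datetime) -> bool:
        # total_seconds() measures the full elapsed time; timedelta.seconds
        # wraps back below the TTL once a whole day has passed.
        return (datetime.now() - cached_time).total_seconds() < CACHE_TTL

    cached_time = datetime.now() - timedelta(hours=24, minutes=30)
    age = datetime.now() - cached_time
    print(age.seconds)            # 1800 -- looks "fresh" under the .seconds check
    print(age.total_seconds())    # ~88200 -- actually 24.5 hours old
    print(is_fresh(cached_time))  # False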