Start fixing forecast service API 13
@@ -1,101 +0,0 @@
-# ================================================================
-# services/forecasting/app/ml/model_loader.py
-# ================================================================
-"""
-Model loading and management utilities
-"""
-
-import structlog
-from typing import Dict, Any, Optional
-import pickle
-import json
-from pathlib import Path
-from datetime import datetime
-import httpx
-
-from app.core.config import settings
-
-logger = structlog.get_logger()
-
-
-class ModelLoader:
-    """
-    Utility class for loading and managing ML models
-    """
-
-    def __init__(self):
-        self.model_cache = {}
-        self.metadata_cache = {}
-
-    async def load_model_with_metadata(self, model_id: str) -> Dict[str, Any]:
-        """Load model along with its metadata"""
-
-        try:
-            # Get model metadata first
-            metadata = await self._get_model_metadata(model_id)
-
-            if not metadata:
-                return None
-
-            # Load the actual model
-            model = await self._load_model_binary(model_id)
-
-            if not model:
-                return None
-
-            return {
-                "model": model,
-                "metadata": metadata,
-                "loaded_at": datetime.now()
-            }
-
-        except Exception as e:
-            logger.error("Error loading model with metadata",
-                         model_id=model_id,
-                         error=str(e))
-            return None
-
-    async def _get_model_metadata(self, model_id: str) -> Optional[Dict[str, Any]]:
-        """Get model metadata from training service"""
-
-        try:
-            async with httpx.AsyncClient() as client:
-                response = await client.get(
-                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/metadata",
-                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
-                )
-
-                if response.status_code == 200:
-                    return response.json()
-                else:
-                    logger.warning("Model metadata not found",
-                                   model_id=model_id,
-                                   status_code=response.status_code)
-                    return None
-
-        except Exception as e:
-            logger.error("Error getting model metadata", error=str(e))
-            return None
-
-    async def _load_model_binary(self, model_id: str):
-        """Load model binary from training service"""
-
-        try:
-            async with httpx.AsyncClient(timeout=60.0) as client:
-                response = await client.get(
-                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/download",
-                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
-                )
-
-                if response.status_code == 200:
-                    model = pickle.loads(response.content)
-                    return model
-                else:
-                    logger.error("Failed to download model binary",
-                                 model_id=model_id,
-                                 status_code=response.status_code)
-                    return None
-
-        except Exception as e:
-            logger.error("Error loading model binary", error=str(e))
-            return None
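The deleted ModelLoader fetched both the metadata and the pickled binary over HTTP from the training service, and its model_cache and metadata_cache attributes were never actually consulted, so removing it loses no caching behavior. For context, a caller would have used it roughly like this (a sketch inferred from the deleted code above; the model ID is hypothetical):

    async def example() -> None:
        loader = ModelLoader()
        bundle = await loader.load_model_with_metadata("demand-model-v3")  # hypothetical ID
        if bundle is not None:
            model = bundle["model"]        # the unpickled estimator
            metadata = bundle["metadata"]  # dict from the /metadata endpoint

The remaining hunks in this commit replace that HTTP round-trip with direct reads from a shared volume; see the _load_model rewrite below.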
@@ -63,6 +63,7 @@ class ForecastingService:
         # Generate prediction using ML service
         prediction_result = await self.prediction_service.predict(
             model_id=model_info["model_id"],
+            model_path=model_info["model_path"],
             features=features,
             confidence_level=request.confidence_level
         )
@@ -268,7 +269,7 @@ class ForecastingService:
         }
 
         # Add Spanish holidays
-        features["is_holiday"] = await self._is_spanish_holiday(request.forecast_date)
+        features["is_holiday"] = self._is_spanish_holiday(request.forecast_date)
 
 
         weather_data = await self._get_weather_forecast(tenant_id, 1)
@@ -276,26 +277,26 @@ class ForecastingService:
 
         return features
 
-    async def _is_spanish_holiday(self, forecast_date: date) -> bool:
-        """Check if date is a Spanish holiday"""
-
-        try:
-            # Call data service for holiday information
-            async with httpx.AsyncClient() as client:
-                response = await client.get(
-                    f"{settings.DATA_SERVICE_URL}/api/v1/holidays/check",
-                    params={"date": forecast_date.isoformat()},
-                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
-                )
-
-                if response.status_code == 200:
-                    return response.json().get("is_holiday", False)
-                else:
-                    return False
-
-        except Exception as e:
-            logger.warning("Error checking holiday status", error=str(e))
-            return False
+    def _is_spanish_holiday(self, date: datetime) -> bool:
+        """Check if a date is a major Spanish holiday"""
+        month_day = (date.month, date.day)
+
+        # Major Spanish holidays that affect bakery sales
+        spanish_holidays = [
+            (1, 1),   # New Year
+            (1, 6),   # Epiphany (Reyes)
+            (5, 1),   # Labour Day
+            (8, 15),  # Assumption
+            (10, 12), # National Day
+            (11, 1),  # All Saints
+            (12, 6),  # Constitution
+            (12, 8),  # Immaculate Conception
+            (12, 25), # Christmas
+            (5, 15),  # San Isidro (Madrid patron saint)
+            (5, 2),   # Madrid Community Day
+        ]
+
+        return month_day in spanish_holidays
 
     async def _get_weather_forecast(self, tenant_id: str, days: str) -> Dict[str, Any]:
         """Get weather forecast for the date"""
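The rewritten check is synchronous and needs no data-service call, which is why the await was dropped at the call site in the previous hunk. A minimal self-contained sketch of the same lookup (holiday tuples copied from the diff; note that fixed (month, day) pairs cannot represent movable feasts such as Good Friday or Easter Monday, and the Madrid-specific dates assume a Madrid tenant):

    from datetime import datetime

    SPANISH_HOLIDAYS = {
        (1, 1), (1, 6), (5, 1), (5, 2), (5, 15), (8, 15),
        (10, 12), (11, 1), (12, 6), (12, 8), (12, 25),
    }

    def is_spanish_holiday(d: datetime) -> bool:
        # Compare only month and day, so the check works for any year
        return (d.month, d.day) in SPANISH_HOLIDAYS

    assert is_spanish_holiday(datetime(2025, 1, 6))       # Reyes
    assert not is_spanish_holiday(datetime(2025, 1, 7))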
@@ -16,6 +16,7 @@ import numpy as np
 import pandas as pd
 import httpx
 from pathlib import Path
+import os
 
 from app.core.config import settings
 from shared.monitoring.metrics import MetricsCollector
@@ -33,7 +34,7 @@ class PredictionService:
         self.model_cache = {}
         self.cache_ttl = 3600  # 1 hour cache
 
-    async def predict(self, model_id: str, features: Dict[str, Any],
+    async def predict(self, model_id: str, model_path: str, features: Dict[str, Any],
                       confidence_level: float = 0.8) -> Dict[str, float]:
         """Generate prediction using trained model"""
@@ -45,7 +46,7 @@ class PredictionService:
                          features_count=len(features))
 
             # Load model
-            model = await self._load_model(model_id)
+            model = await self._load_model(model_id, model_path)
 
             if not model:
                 raise ValueError(f"Model {model_id} not found or failed to load")
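Together with the call-site change in the first ForecastingService hunk, these two hunks thread a filesystem path from the model registry metadata down to the loader: model_id stays the cache key, while model_path tells _load_model where to read the pickle on the shared volume. A sketch of the resulting flow (the shape of model_info is assumed from the field names in the diff, and prediction_service stands in for self.prediction_service inside ForecastingService):

    # model_info comes from the training service's model registry (assumed shape)
    model_info = {"model_id": "m-42", "model_path": "/models/m-42/model.pkl"}

    prediction_result = await prediction_service.predict(
        model_id=model_info["model_id"],      # cache key
        model_path=model_info["model_path"],  # shared-volume location, new in this commit
        features=features,
        confidence_level=request.confidence_level,
    )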
@@ -87,42 +88,31 @@ class PredictionService:
                          error=str(e))
             raise
 
-    async def _load_model(self, model_id: str):
-        """Load model from cache or training service"""
+    async def _load_model(self, model_id: str, model_path: str):
+        """Load model from shared volume using API metadata"""
 
         # Check cache first
         if model_id in self.model_cache:
             cached_model, cached_time = self.model_cache[model_id]
             if (datetime.now() - cached_time).seconds < self.cache_ttl:
-                logger.debug("Using cached model", model_id=model_id)
                 return cached_model
 
         try:
-            # Download model from training service
-            async with httpx.AsyncClient(timeout=30.0) as client:
-                response = await client.get(
-                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/download",
-                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
-                )
-
-                if response.status_code == 200:
-                    # Load model from bytes
-                    model_data = response.content
-                    model = pickle.loads(model_data)
-
-                    # Cache the model
-                    self.model_cache[model_id] = (model, datetime.now())
-
-                    logger.info("Model loaded successfully", model_id=model_id)
-                    return model
-                else:
-                    logger.error("Failed to download model",
-                                 model_id=model_id,
-                                 status_code=response.status_code)
-                    return None
+            # Load model directly from shared volume (fast!)
+            if os.path.exists(model_path):
+                with open(model_path, 'rb') as f:
+                    model = pickle.load(f)
+
+                # Cache the model
+                self.model_cache[model_id] = (model, datetime.now())
+
+                logger.info(f"Model loaded from shared volume: {model_path}")
+                return model
+            else:
+                logger.error(f"Model file not found: {model_path}")
+                return None
 
         except Exception as e:
-            logger.error("Error loading model", model_id=model_id, error=str(e))
+            logger.error(f"Error loading model: {e}")
             return None
 
     def _prepare_prophet_features(self, features: Dict[str, Any]) -> pd.DataFrame:
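One pre-existing wrinkle the rewrite keeps: the cache-freshness test uses timedelta.seconds, which is only the seconds component of the delta and wraps at day boundaries, so a day-old cache entry can look fresh. total_seconds() would be the robust comparison. A quick demonstration:

    from datetime import datetime, timedelta

    cache_ttl = 3600  # 1 hour, as in PredictionService
    cached_time = datetime.now() - timedelta(days=1, seconds=10)

    delta = datetime.now() - cached_time
    print(delta.seconds)          # ~10: wraps past 24h, so the entry looks fresh
    print(delta.total_seconds())  # ~86410: correctly identifies a stale entry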
@@ -251,7 +251,7 @@ class DataServiceClient(BaseServiceClient):
         # Use POST request with extended timeout
         result = await self._make_request(
             "POST",
-            "weather/historical",
+            "weather/forecast",
             tenant_id=tenant_id,
             data=payload,
             timeout=2000.0  # Match original timeout
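The last hunk fixes the endpoint path: the forecasting flow was calling the historical-weather route. BaseServiceClient._make_request is not part of this diff, so the following is only a hypothetical flattening of the corrected call, assuming it composes URLs the way the raw httpx calls elsewhere in this commit do (tenant routing is elided, since _make_request's handling of tenant_id is not shown):

    import httpx

    async def fetch_weather_forecast(base_url: str, payload: dict, token: str) -> dict:
        # POST {DATA_SERVICE_URL}/api/v1/weather/forecast (was weather/historical)
        async with httpx.AsyncClient(timeout=2000.0) as client:
            response = await client.post(
                f"{base_url}/api/v1/weather/forecast",
                json=payload,
                headers={"X-Service-Auth": token},
            )
            response.raise_for_status()
            return response.json()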