Add forecasting service
438  services/forecasting/app/services/forecasting_service.py  Normal file
@@ -0,0 +1,438 @@
# ================================================================
# services/forecasting/app/services/forecasting_service.py
# ================================================================
"""
Main forecasting service business logic.

Orchestrates demand prediction operations.
"""

import uuid
from datetime import datetime, date, timedelta
from typing import Dict, List, Any, Optional

import httpx
import structlog
from sqlalchemy import select, and_, desc
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert
from app.schemas.forecasts import ForecastRequest, BatchForecastRequest, BusinessType
from app.services.prediction_service import PredictionService
from app.services.messaging import publish_forecast_completed, publish_alert_created
from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector

logger = structlog.get_logger()
metrics = MetricsCollector("forecasting-service")


class ForecastingService:
    """
    Main service class for managing forecasting operations.

    Handles demand prediction, batch processing, and alert generation.
    """

    def __init__(self):
        self.prediction_service = PredictionService()

    async def generate_forecast(self, request: ForecastRequest, db: AsyncSession) -> Forecast:
        """Generate a single forecast for a product."""
        start_time = datetime.now()

        try:
            logger.info("Generating forecast",
                        tenant_id=request.tenant_id,
                        product=request.product_name,
                        date=request.forecast_date)

            # Get the latest trained model for this tenant/product
            model_info = await self._get_latest_model(
                request.tenant_id,
                request.product_name,
                request.location
            )

            if not model_info:
                raise ValueError(f"No trained model found for {request.product_name}")

            # Prepare features for prediction
            features = await self._prepare_forecast_features(request)

            # Generate prediction using ML service
            prediction_result = await self.prediction_service.predict(
                model_id=model_info["model_id"],
                features=features,
                confidence_level=request.confidence_level
            )

            # Create forecast record
            forecast = Forecast(
                tenant_id=uuid.UUID(request.tenant_id),
                product_name=request.product_name,
                location=request.location,
                forecast_date=datetime.combine(request.forecast_date, datetime.min.time()),

                # Prediction results
                predicted_demand=prediction_result["demand"],
                confidence_lower=prediction_result["lower_bound"],
                confidence_upper=prediction_result["upper_bound"],
                confidence_level=request.confidence_level,

                # Model information
                model_id=uuid.UUID(model_info["model_id"]),
                model_version=model_info["version"],
                algorithm=model_info.get("algorithm", "prophet"),

                # Context
                business_type=request.business_type.value,
                day_of_week=request.forecast_date.weekday(),
                is_holiday=features.get("is_holiday", False),
                is_weekend=request.forecast_date.weekday() >= 5,

                # External factors
                weather_temperature=features.get("temperature"),
                weather_precipitation=features.get("precipitation"),
                weather_description=features.get("weather_description"),
                traffic_volume=features.get("traffic_volume"),

                # Metadata
                processing_time_ms=int((datetime.now() - start_time).total_seconds() * 1000),
                features_used=features
            )

            db.add(forecast)
            await db.commit()
            await db.refresh(forecast)

            # Check for alerts
            await self._check_and_create_alerts(forecast, db)

            # Update metrics
            metrics.increment_counter("forecasts_generated_total",
                                      {"product": request.product_name, "location": request.location})

            # Publish event
            await publish_forecast_completed({
                "forecast_id": str(forecast.id),
                "tenant_id": request.tenant_id,
                "product_name": request.product_name,
                "predicted_demand": forecast.predicted_demand
            })

            logger.info("Forecast generated successfully",
                        forecast_id=str(forecast.id),
                        predicted_demand=forecast.predicted_demand)

            return forecast

        except Exception as e:
            logger.error("Error generating forecast",
                         error=str(e),
                         tenant_id=request.tenant_id,
                         product=request.product_name)
            raise

    async def generate_batch_forecast(self, request: BatchForecastRequest, db: AsyncSession) -> PredictionBatch:
        """Generate forecasts for multiple products over multiple days."""

        try:
            logger.info("Starting batch forecast generation",
                        tenant_id=request.tenant_id,
                        batch_name=request.batch_name,
                        products_count=len(request.products),
                        forecast_days=request.forecast_days)

            # Create batch record
            batch = PredictionBatch(
                tenant_id=uuid.UUID(request.tenant_id),
                batch_name=request.batch_name,
                status="processing",
                total_products=len(request.products) * request.forecast_days,
                business_type=request.business_type.value,
                forecast_days=request.forecast_days
            )

            db.add(batch)
            await db.commit()
            await db.refresh(batch)

            # Generate forecasts for each product and day
            completed_count = 0
            failed_count = 0

            for product in request.products:
                for day_offset in range(request.forecast_days):
                    forecast_date = date.today() + timedelta(days=day_offset + 1)

                    try:
                        forecast_request = ForecastRequest(
                            tenant_id=request.tenant_id,
                            product_name=product,
                            location=request.location,
                            forecast_date=forecast_date,
                            business_type=request.business_type,
                            include_weather=request.include_weather,
                            include_traffic=request.include_traffic,
                            confidence_level=request.confidence_level
                        )

                        await self.generate_forecast(forecast_request, db)
                        completed_count += 1

                    except Exception as e:
                        logger.warning("Failed to generate forecast for product",
                                       product=product,
                                       date=forecast_date,
                                       error=str(e))
                        failed_count += 1

            # Update batch status
            batch.status = "completed" if failed_count == 0 else "partial"
            batch.completed_products = completed_count
            batch.failed_products = failed_count
            batch.completed_at = datetime.now()

            await db.commit()

            logger.info("Batch forecast generation completed",
                        batch_id=str(batch.id),
                        completed=completed_count,
                        failed=failed_count)

            return batch

        except Exception as e:
            logger.error("Error in batch forecast generation", error=str(e))
            raise
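
    # Note: forecasts in a batch are generated sequentially on purpose. The inner
    # loop shares one AsyncSession, which is not safe for concurrent use, so
    # parallelising it (e.g. with asyncio.gather) would need a session per task.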

    async def get_forecasts(self, tenant_id: str, location: str,
                            start_date: Optional[date] = None,
                            end_date: Optional[date] = None,
                            product_name: Optional[str] = None,
                            db: AsyncSession = None) -> List[Forecast]:
        """Retrieve forecasts with filtering."""

        try:
            query = select(Forecast).where(
                and_(
                    Forecast.tenant_id == uuid.UUID(tenant_id),
                    Forecast.location == location
                )
            )

            if start_date:
                query = query.where(Forecast.forecast_date >= datetime.combine(start_date, datetime.min.time()))

            if end_date:
                query = query.where(Forecast.forecast_date <= datetime.combine(end_date, datetime.max.time()))

            if product_name:
                query = query.where(Forecast.product_name == product_name)

            query = query.order_by(desc(Forecast.forecast_date))

            result = await db.execute(query)
            forecasts = result.scalars().all()

            logger.info("Retrieved forecasts",
                        tenant_id=tenant_id,
                        count=len(forecasts))

            return list(forecasts)

        except Exception as e:
            logger.error("Error retrieving forecasts", error=str(e))
            raise

    async def _get_latest_model(self, tenant_id: str, product_name: str, location: str) -> Optional[Dict[str, Any]]:
        """Get the latest trained model for a tenant/product combination."""

        try:
            # Call the training service for model information
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/latest",
                    params={
                        "tenant_id": tenant_id,
                        "product_name": product_name,
                        "location": location
                    },
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
                )

                if response.status_code == 200:
                    return response.json()
                elif response.status_code == 404:
                    logger.warning("No model found",
                                   tenant_id=tenant_id,
                                   product=product_name)
                    return None
                else:
                    response.raise_for_status()

        except Exception as e:
            logger.error("Error getting latest model", error=str(e))
            raise

    async def _prepare_forecast_features(self, request: ForecastRequest) -> Dict[str, Any]:
        """Prepare features for the forecasting model."""

        features = {
            "date": request.forecast_date.isoformat(),
            "day_of_week": request.forecast_date.weekday(),
            "is_weekend": request.forecast_date.weekday() >= 5,
            "business_type": request.business_type.value
        }

        # Add Spanish holidays
        features["is_holiday"] = await self._is_spanish_holiday(request.forecast_date)

        # Add weather data if requested
        if request.include_weather:
            weather_data = await self._get_weather_forecast(request.forecast_date)
            features.update(weather_data)

        # Add traffic data if requested
        if request.include_traffic:
            traffic_data = await self._get_traffic_forecast(request.forecast_date, request.location)
            features.update(traffic_data)

        return features
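
    # Shape of the payload returned above, with illustrative values only
    # (the weather/traffic keys appear only when the external lookups succeed):
    #   {
    #       "date": "2025-06-13", "day_of_week": 4, "is_weekend": False,
    #       "business_type": "individual", "is_holiday": False,
    #       "temperature": 21.5, "precipitation": 0.0, "humidity": 60,
    #       "weather_description": "clear sky",
    #       "traffic_volume": 1200, "pedestrian_count": 340,
    #   }
    # The same dict is persisted on the Forecast row as features_used.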

    async def _is_spanish_holiday(self, forecast_date: date) -> bool:
        """Check whether the date is a Spanish holiday."""

        try:
            # Call the data service for holiday information
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{settings.DATA_SERVICE_URL}/api/v1/holidays/check",
                    params={"date": forecast_date.isoformat()},
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
                )

                if response.status_code == 200:
                    return response.json().get("is_holiday", False)
                else:
                    return False

        except Exception as e:
            logger.warning("Error checking holiday status", error=str(e))
            return False

    async def _get_weather_forecast(self, forecast_date: date) -> Dict[str, Any]:
        """Get the weather forecast for the date."""

        try:
            # Call the data service for the weather forecast
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{settings.DATA_SERVICE_URL}/api/v1/weather/forecast",
                    params={"date": forecast_date.isoformat()},
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
                )

                if response.status_code == 200:
                    weather = response.json()
                    return {
                        "temperature": weather.get("temperature"),
                        "precipitation": weather.get("precipitation"),
                        "humidity": weather.get("humidity"),
                        "weather_description": weather.get("description")
                    }
                else:
                    return {}

        except Exception as e:
            logger.warning("Error getting weather forecast", error=str(e))
            return {}

    async def _get_traffic_forecast(self, forecast_date: date, location: str) -> Dict[str, Any]:
        """Get the traffic forecast for the date and location."""

        try:
            # Call the data service for the traffic forecast
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{settings.DATA_SERVICE_URL}/api/v1/traffic/forecast",
                    params={
                        "date": forecast_date.isoformat(),
                        "location": location
                    },
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
                )

                if response.status_code == 200:
                    traffic = response.json()
                    return {
                        "traffic_volume": traffic.get("volume"),
                        "pedestrian_count": traffic.get("pedestrian_count")
                    }
                else:
                    return {}

        except Exception as e:
            logger.warning("Error getting traffic forecast", error=str(e))
            return {}

    async def _check_and_create_alerts(self, forecast: Forecast, db: AsyncSession):
        """Check the forecast and create alerts if needed."""

        try:
            alerts_to_create = []

            # High demand alert
            if forecast.predicted_demand > settings.HIGH_DEMAND_THRESHOLD * 100:  # assuming a base of 100 units
                alerts_to_create.append({
                    "type": "high_demand",
                    "severity": "medium",
                    "message": f"High demand predicted for {forecast.product_name}: {forecast.predicted_demand:.0f} units"
                })

            # Low demand alert
            if forecast.predicted_demand < settings.LOW_DEMAND_THRESHOLD * 100:
                alerts_to_create.append({
                    "type": "low_demand",
                    "severity": "low",
                    "message": f"Low demand predicted for {forecast.product_name}: {forecast.predicted_demand:.0f} units"
                })

            # Stockout risk alert
            if forecast.confidence_upper > settings.STOCKOUT_RISK_THRESHOLD * forecast.predicted_demand:
                alerts_to_create.append({
                    "type": "stockout_risk",
                    "severity": "high",
                    "message": f"Stockout risk for {forecast.product_name}. Upper confidence: {forecast.confidence_upper:.0f}"
                })

            # Create alerts
            for alert_data in alerts_to_create:
                alert = ForecastAlert(
                    tenant_id=forecast.tenant_id,
                    forecast_id=forecast.id,
                    alert_type=alert_data["type"],
                    severity=alert_data["severity"],
                    message=alert_data["message"]
                )

                db.add(alert)
                # Flush so alert.id is populated before the event is published
                await db.flush()

                # Publish alert event
                await publish_alert_created({
                    "alert_id": str(alert.id),
                    "tenant_id": str(forecast.tenant_id),
                    "product_name": forecast.product_name,
                    "alert_type": alert_data["type"],
                    "severity": alert_data["severity"],
                    "message": alert_data["message"]
                })

            await db.commit()

            if alerts_to_create:
                logger.info("Created forecast alerts",
                            forecast_id=str(forecast.id),
                            alerts_count=len(alerts_to_create))

        except Exception as e:
            logger.error("Error creating alerts", error=str(e))
            # Don't raise - alerts are not critical for forecast generation
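The service above is plain business logic; how it is exposed over HTTP is not part
of this commit. As a minimal sketch, assuming a FastAPI router and a get_db session
dependency (both hypothetical names, not from this repository), an endpoint could
wrap generate_forecast like this:

    from fastapi import APIRouter, Depends, HTTPException
    from sqlalchemy.ext.asyncio import AsyncSession

    from app.schemas.forecasts import ForecastRequest
    from app.services.forecasting_service import ForecastingService
    from app.core.database import get_db  # hypothetical session dependency

    router = APIRouter(prefix="/api/v1/forecasts")
    service = ForecastingService()

    @router.post("")
    async def create_forecast(request: ForecastRequest,
                              db: AsyncSession = Depends(get_db)):
        try:
            forecast = await service.generate_forecast(request, db)
        except ValueError as e:
            # e.g. no trained model exists for the requested product
            raise HTTPException(status_code=404, detail=str(e))
        return {"forecast_id": str(forecast.id),
                "predicted_demand": forecast.predicted_demand}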
98  services/forecasting/app/services/messaging.py  Normal file
@@ -0,0 +1,98 @@
# ================================================================
# services/forecasting/app/services/messaging.py
# ================================================================
"""
Messaging service for event publishing and consuming.
"""

import structlog
from typing import Dict, Any

from shared.messaging.rabbitmq import RabbitMQPublisher, RabbitMQConsumer
from app.core.config import settings

logger = structlog.get_logger()

# Global messaging instances
publisher = None
consumer = None


async def setup_messaging():
    """Initialize messaging services."""
    global publisher, consumer

    try:
        # Initialize publisher
        publisher = RabbitMQPublisher(settings.RABBITMQ_URL)
        await publisher.connect()

        # Initialize consumer
        consumer = RabbitMQConsumer(settings.RABBITMQ_URL)
        await consumer.connect()

        # Set up event handlers
        await consumer.subscribe("training.model.updated", handle_model_updated)
        await consumer.subscribe("data.weather.updated", handle_weather_updated)

        logger.info("Messaging setup completed")

    except Exception as e:
        logger.error("Failed to set up messaging", error=str(e))
        raise


async def cleanup_messaging():
    """Clean up messaging connections."""
    global publisher, consumer

    try:
        if consumer:
            await consumer.close()
        if publisher:
            await publisher.close()

        logger.info("Messaging cleanup completed")

    except Exception as e:
        logger.error("Error during messaging cleanup", error=str(e))


async def publish_forecast_completed(data: Dict[str, Any]):
    """Publish a forecast-completed event."""
    if publisher:
        await publisher.publish("forecasting.forecast.completed", data)


async def publish_alert_created(data: Dict[str, Any]):
    """Publish an alert-created event."""
    if publisher:
        await publisher.publish("forecasting.alert.created", data)


async def publish_batch_completed(data: Dict[str, Any]):
    """Publish a batch-forecast-completed event."""
    if publisher:
        await publisher.publish("forecasting.batch.completed", data)


# Event handlers
async def handle_model_updated(data: Dict[str, Any]):
    """Handle a model-updated event from the training service."""
    try:
        logger.info("Received model updated event",
                    model_id=data.get("model_id"),
                    tenant_id=data.get("tenant_id"))

        # Clear the model cache for this model
        # (to be handled by PredictionService)

    except Exception as e:
        logger.error("Error handling model updated event", error=str(e))


async def handle_weather_updated(data: Dict[str, Any]):
    """Handle a weather-data-updated event."""
    try:
        logger.info("Received weather updated event",
                    date=data.get("date"))

        # Could trigger re-forecasting if needed

    except Exception as e:
        logger.error("Error handling weather updated event", error=str(e))
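setup_messaging and cleanup_messaging are designed to run once per process. A
minimal sketch of wiring them into the service's FastAPI lifespan (the main.py
module shown here is an assumption, not part of this commit):

    # hypothetical services/forecasting/app/main.py
    from contextlib import asynccontextmanager

    from fastapi import FastAPI

    from app.services.messaging import setup_messaging, cleanup_messaging

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Connect publisher/consumer and register event handlers on startup
        await setup_messaging()
        yield
        # Close connections on shutdown
        await cleanup_messaging()

    app = FastAPI(lifespan=lifespan)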
166  services/forecasting/app/services/prediction_service.py  Normal file
@@ -0,0 +1,166 @@
# ================================================================
# services/forecasting/app/services/prediction_service.py
# ================================================================
"""
Prediction service for loading models and generating predictions.

Handles the actual ML prediction logic.
"""

import pickle
from datetime import datetime
from typing import Dict, Any

import httpx
import pandas as pd
import structlog

from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector

logger = structlog.get_logger()
metrics = MetricsCollector("forecasting-service")


class PredictionService:
    """
    Service for loading ML models and generating predictions.

    Interfaces with trained Prophet models from the training service.
    """

    def __init__(self):
        self.model_cache = {}
        self.cache_ttl = 3600  # 1 hour cache

    async def predict(self, model_id: str, features: Dict[str, Any],
                      confidence_level: float = 0.8) -> Dict[str, float]:
        """Generate a prediction using a trained model."""

        start_time = datetime.now()

        try:
            logger.info("Generating prediction",
                        model_id=model_id,
                        features_count=len(features))

            # Load model
            model = await self._load_model(model_id)

            if not model:
                raise ValueError(f"Model {model_id} not found or failed to load")

            # Prepare features for Prophet
            df = self._prepare_prophet_features(features)

            # Generate prediction
            forecast = model.predict(df)

            # Extract prediction results
            if len(forecast) > 0:
                row = forecast.iloc[0]
                result = {
                    "demand": float(row['yhat']),
                    "lower_bound": float(row['yhat_lower']),
                    "upper_bound": float(row['yhat_upper']),
                    "trend": float(row.get('trend', 0)),
                    "seasonal": float(row.get('seasonal', 0)),
                    "holiday": float(row.get('holidays', 0))
                }
            else:
                raise ValueError("No prediction generated from model")

            # Update metrics
            processing_time = (datetime.now() - start_time).total_seconds()
            metrics.histogram_observe("forecast_processing_time_seconds", processing_time)

            logger.info("Prediction generated successfully",
                        model_id=model_id,
                        predicted_demand=result["demand"],
                        processing_time_ms=int(processing_time * 1000))

            return result

        except Exception as e:
            logger.error("Error generating prediction",
                         model_id=model_id,
                         error=str(e))
            raise

    async def _load_model(self, model_id: str):
        """Load a model from the cache or the training service."""

        # Check cache first
        if model_id in self.model_cache:
            cached_model, cached_time = self.model_cache[model_id]
            # total_seconds() rather than .seconds, which wraps around after a day
            if (datetime.now() - cached_time).total_seconds() < self.cache_ttl:
                logger.debug("Using cached model", model_id=model_id)
                return cached_model

        try:
            # Download the model from the training service
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.get(
                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/download",
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
                )

                if response.status_code == 200:
                    # Deserialize the model bytes (pickle assumes the training
                    # service is a trusted peer)
                    model = pickle.loads(response.content)

                    # Cache the model
                    self.model_cache[model_id] = (model, datetime.now())

                    logger.info("Model loaded successfully", model_id=model_id)
                    return model
                else:
                    logger.error("Failed to download model",
                                 model_id=model_id,
                                 status_code=response.status_code)
                    return None

        except Exception as e:
            logger.error("Error loading model", model_id=model_id, error=str(e))
            return None
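
    def invalidate(self, model_id: str) -> None:
        """Drop a cached model so the next predict() re-downloads it.

        Hypothetical helper, not in the original commit: messaging.handle_model_updated
        notes that cache clearing is to be handled by PredictionService, and this is a
        minimal sketch of what that hook could do.
        """
        self.model_cache.pop(model_id, None)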

    def _prepare_prophet_features(self, features: Dict[str, Any]) -> pd.DataFrame:
        """Convert features to a Prophet-compatible DataFrame."""

        try:
            # Create the base DataFrame with the required 'ds' column
            df = pd.DataFrame({
                'ds': [pd.to_datetime(features['date'])]
            })

            # Add numeric features, defaulting missing values to 0.0
            numeric_features = [
                'temperature', 'precipitation', 'humidity', 'wind_speed',
                'traffic_volume', 'pedestrian_count'
            ]

            for feature in numeric_features:
                if feature in features and features[feature] is not None:
                    df[feature] = float(features[feature])
                else:
                    df[feature] = 0.0

            # Add categorical features
            df['day_of_week'] = int(features.get('day_of_week', 0))
            df['is_weekend'] = int(features.get('is_weekend', False))
            df['is_holiday'] = int(features.get('is_holiday', False))

            # Business type encoding
            business_type = features.get('business_type', 'individual')
            df['is_central_workshop'] = int(business_type == 'central_workshop')

            logger.debug("Prepared Prophet features",
                         features_count=len(df.columns),
                         date=features['date'])

            return df

        except Exception as e:
            logger.error("Error preparing Prophet features", error=str(e))
            raise
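The prediction path assumes the pickled object is a fitted Prophet model whose
predict() understands the extra columns built in _prepare_prophet_features. Prophet
only uses such columns if they were registered with add_regressor() before fitting,
and the width of yhat_lower/yhat_upper is fixed by interval_width when the model is
constructed, which is presumably why the confidence_level argument above is
currently unused. A hedged sketch of the training side, which is not part of this
commit:

    # hypothetical sketch of what the training service would do
    import pandas as pd
    from prophet import Prophet

    REGRESSORS = ['temperature', 'precipitation', 'humidity', 'wind_speed',
                  'traffic_volume', 'pedestrian_count',
                  'day_of_week', 'is_weekend', 'is_holiday', 'is_central_workshop']

    def train(history: pd.DataFrame) -> Prophet:
        # history needs 'ds' (date) and 'y' (observed demand) columns,
        # plus one column per regressor above
        model = Prophet(interval_width=0.8)  # yields yhat_lower/yhat_upper at 80%
        for name in REGRESSORS:
            model.add_regressor(name)
        model.fit(history)
        return model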