Add forecasting service
This commit is contained in:
438
services/forecasting/app/services/forecasting_service.py
Normal file
438
services/forecasting/app/services/forecasting_service.py
Normal file
@@ -0,0 +1,438 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/forecasting_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Main forecasting service business logic
|
||||
Orchestrates demand prediction operations
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime, date, timedelta
|
||||
import asyncio
|
||||
import uuid
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, desc
|
||||
import httpx
|
||||
|
||||
from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert
|
||||
from app.schemas.forecasts import ForecastRequest, BatchForecastRequest, BusinessType
|
||||
from app.services.prediction_service import PredictionService
|
||||
from app.services.messaging import publish_forecast_completed, publish_alert_created
|
||||
from app.core.config import settings
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
|
||||
logger = structlog.get_logger()
|
||||
metrics = MetricsCollector("forecasting-service")
|
||||
|
||||
class ForecastingService:
|
||||
"""
|
||||
Main service class for managing forecasting operations.
|
||||
Handles demand prediction, batch processing, and alert generation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.prediction_service = PredictionService()
|
||||
|
||||
async def generate_forecast(self, request: ForecastRequest, db: AsyncSession) -> Forecast:
|
||||
"""Generate a single forecast for a product"""
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
logger.info("Generating forecast",
|
||||
tenant_id=request.tenant_id,
|
||||
product=request.product_name,
|
||||
date=request.forecast_date)
|
||||
|
||||
# Get the latest trained model for this tenant/product
|
||||
model_info = await self._get_latest_model(
|
||||
request.tenant_id,
|
||||
request.product_name,
|
||||
request.location
|
||||
)
|
||||
|
||||
if not model_info:
|
||||
raise ValueError(f"No trained model found for {request.product_name}")
|
||||
|
||||
# Prepare features for prediction
|
||||
features = await self._prepare_forecast_features(request)
|
||||
|
||||
# Generate prediction using ML service
|
||||
prediction_result = await self.prediction_service.predict(
|
||||
model_id=model_info["model_id"],
|
||||
features=features,
|
||||
confidence_level=request.confidence_level
|
||||
)
|
||||
|
||||
# Create forecast record
|
||||
forecast = Forecast(
|
||||
tenant_id=uuid.UUID(request.tenant_id),
|
||||
product_name=request.product_name,
|
||||
location=request.location,
|
||||
forecast_date=datetime.combine(request.forecast_date, datetime.min.time()),
|
||||
|
||||
# Prediction results
|
||||
predicted_demand=prediction_result["demand"],
|
||||
confidence_lower=prediction_result["lower_bound"],
|
||||
confidence_upper=prediction_result["upper_bound"],
|
||||
confidence_level=request.confidence_level,
|
||||
|
||||
# Model information
|
||||
model_id=uuid.UUID(model_info["model_id"]),
|
||||
model_version=model_info["version"],
|
||||
algorithm=model_info.get("algorithm", "prophet"),
|
||||
|
||||
# Context
|
||||
business_type=request.business_type.value,
|
||||
day_of_week=request.forecast_date.weekday(),
|
||||
is_holiday=features.get("is_holiday", False),
|
||||
is_weekend=request.forecast_date.weekday() >= 5,
|
||||
|
||||
# External factors
|
||||
weather_temperature=features.get("temperature"),
|
||||
weather_precipitation=features.get("precipitation"),
|
||||
weather_description=features.get("weather_description"),
|
||||
traffic_volume=features.get("traffic_volume"),
|
||||
|
||||
# Metadata
|
||||
processing_time_ms=int((datetime.now() - start_time).total_seconds() * 1000),
|
||||
features_used=features
|
||||
)
|
||||
|
||||
db.add(forecast)
|
||||
await db.commit()
|
||||
await db.refresh(forecast)
|
||||
|
||||
# Check for alerts
|
||||
await self._check_and_create_alerts(forecast, db)
|
||||
|
||||
# Update metrics
|
||||
metrics.increment_counter("forecasts_generated_total",
|
||||
{"product": request.product_name, "location": request.location})
|
||||
|
||||
# Publish event
|
||||
await publish_forecast_completed({
|
||||
"forecast_id": str(forecast.id),
|
||||
"tenant_id": request.tenant_id,
|
||||
"product_name": request.product_name,
|
||||
"predicted_demand": forecast.predicted_demand
|
||||
})
|
||||
|
||||
logger.info("Forecast generated successfully",
|
||||
forecast_id=str(forecast.id),
|
||||
predicted_demand=forecast.predicted_demand)
|
||||
|
||||
return forecast
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error generating forecast",
|
||||
error=str(e),
|
||||
tenant_id=request.tenant_id,
|
||||
product=request.product_name)
|
||||
raise
|
||||
|
||||
async def generate_batch_forecast(self, request: BatchForecastRequest, db: AsyncSession) -> PredictionBatch:
|
||||
"""Generate forecasts for multiple products over multiple days"""
|
||||
|
||||
try:
|
||||
logger.info("Starting batch forecast generation",
|
||||
tenant_id=request.tenant_id,
|
||||
batch_name=request.batch_name,
|
||||
products_count=len(request.products),
|
||||
forecast_days=request.forecast_days)
|
||||
|
||||
# Create batch record
|
||||
batch = PredictionBatch(
|
||||
tenant_id=uuid.UUID(request.tenant_id),
|
||||
batch_name=request.batch_name,
|
||||
status="processing",
|
||||
total_products=len(request.products) * request.forecast_days,
|
||||
business_type=request.business_type.value,
|
||||
forecast_days=request.forecast_days
|
||||
)
|
||||
|
||||
db.add(batch)
|
||||
await db.commit()
|
||||
await db.refresh(batch)
|
||||
|
||||
# Generate forecasts for each product and day
|
||||
completed_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for product in request.products:
|
||||
for day_offset in range(request.forecast_days):
|
||||
forecast_date = date.today() + timedelta(days=day_offset + 1)
|
||||
|
||||
try:
|
||||
forecast_request = ForecastRequest(
|
||||
tenant_id=request.tenant_id,
|
||||
product_name=product,
|
||||
location=request.location,
|
||||
forecast_date=forecast_date,
|
||||
business_type=request.business_type,
|
||||
include_weather=request.include_weather,
|
||||
include_traffic=request.include_traffic,
|
||||
confidence_level=request.confidence_level
|
||||
)
|
||||
|
||||
await self.generate_forecast(forecast_request, db)
|
||||
completed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to generate forecast for product",
|
||||
product=product,
|
||||
date=forecast_date,
|
||||
error=str(e))
|
||||
failed_count += 1
|
||||
|
||||
# Update batch status
|
||||
batch.status = "completed" if failed_count == 0 else "partial"
|
||||
batch.completed_products = completed_count
|
||||
batch.failed_products = failed_count
|
||||
batch.completed_at = datetime.now()
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info("Batch forecast generation completed",
|
||||
batch_id=str(batch.id),
|
||||
completed=completed_count,
|
||||
failed=failed_count)
|
||||
|
||||
return batch
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in batch forecast generation", error=str(e))
|
||||
raise
|
||||
|
||||
async def get_forecasts(self, tenant_id: str, location: str,
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
product_name: Optional[str] = None,
|
||||
db: AsyncSession = None) -> List[Forecast]:
|
||||
"""Retrieve forecasts with filtering"""
|
||||
|
||||
try:
|
||||
query = select(Forecast).where(
|
||||
and_(
|
||||
Forecast.tenant_id == uuid.UUID(tenant_id),
|
||||
Forecast.location == location
|
||||
)
|
||||
)
|
||||
|
||||
if start_date:
|
||||
query = query.where(Forecast.forecast_date >= datetime.combine(start_date, datetime.min.time()))
|
||||
|
||||
if end_date:
|
||||
query = query.where(Forecast.forecast_date <= datetime.combine(end_date, datetime.max.time()))
|
||||
|
||||
if product_name:
|
||||
query = query.where(Forecast.product_name == product_name)
|
||||
|
||||
query = query.order_by(desc(Forecast.forecast_date))
|
||||
|
||||
result = await db.execute(query)
|
||||
forecasts = result.scalars().all()
|
||||
|
||||
logger.info("Retrieved forecasts",
|
||||
tenant_id=tenant_id,
|
||||
count=len(forecasts))
|
||||
|
||||
return list(forecasts)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error retrieving forecasts", error=str(e))
|
||||
raise
|
||||
|
||||
async def _get_latest_model(self, tenant_id: str, product_name: str, location: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get the latest trained model for a tenant/product combination"""
|
||||
|
||||
try:
|
||||
# Call training service to get model information
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.TRAINING_SERVICE_URL}/api/v1/models/latest",
|
||||
params={
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"location": location
|
||||
},
|
||||
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
elif response.status_code == 404:
|
||||
logger.warning("No model found",
|
||||
tenant_id=tenant_id,
|
||||
product=product_name)
|
||||
return None
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting latest model", error=str(e))
|
||||
raise
|
||||
|
||||
async def _prepare_forecast_features(self, request: ForecastRequest) -> Dict[str, Any]:
|
||||
"""Prepare features for forecasting model"""
|
||||
|
||||
features = {
|
||||
"date": request.forecast_date.isoformat(),
|
||||
"day_of_week": request.forecast_date.weekday(),
|
||||
"is_weekend": request.forecast_date.weekday() >= 5,
|
||||
"business_type": request.business_type.value
|
||||
}
|
||||
|
||||
# Add Spanish holidays
|
||||
features["is_holiday"] = await self._is_spanish_holiday(request.forecast_date)
|
||||
|
||||
# Add weather data if requested
|
||||
if request.include_weather:
|
||||
weather_data = await self._get_weather_forecast(request.forecast_date)
|
||||
features.update(weather_data)
|
||||
|
||||
# Add traffic data if requested
|
||||
if request.include_traffic:
|
||||
traffic_data = await self._get_traffic_forecast(request.forecast_date, request.location)
|
||||
features.update(traffic_data)
|
||||
|
||||
return features
|
||||
|
||||
async def _is_spanish_holiday(self, forecast_date: date) -> bool:
|
||||
"""Check if date is a Spanish holiday"""
|
||||
|
||||
try:
|
||||
# Call data service for holiday information
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.DATA_SERVICE_URL}/api/v1/holidays/check",
|
||||
params={"date": forecast_date.isoformat()},
|
||||
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json().get("is_holiday", False)
|
||||
else:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Error checking holiday status", error=str(e))
|
||||
return False
|
||||
|
||||
async def _get_weather_forecast(self, forecast_date: date) -> Dict[str, Any]:
|
||||
"""Get weather forecast for the date"""
|
||||
|
||||
try:
|
||||
# Call data service for weather forecast
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.DATA_SERVICE_URL}/api/v1/weather/forecast",
|
||||
params={"date": forecast_date.isoformat()},
|
||||
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
weather = response.json()
|
||||
return {
|
||||
"temperature": weather.get("temperature"),
|
||||
"precipitation": weather.get("precipitation"),
|
||||
"humidity": weather.get("humidity"),
|
||||
"weather_description": weather.get("description")
|
||||
}
|
||||
else:
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Error getting weather forecast", error=str(e))
|
||||
return {}
|
||||
|
||||
async def _get_traffic_forecast(self, forecast_date: date, location: str) -> Dict[str, Any]:
|
||||
"""Get traffic forecast for the date and location"""
|
||||
|
||||
try:
|
||||
# Call data service for traffic forecast
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.DATA_SERVICE_URL}/api/v1/traffic/forecast",
|
||||
params={
|
||||
"date": forecast_date.isoformat(),
|
||||
"location": location
|
||||
},
|
||||
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
traffic = response.json()
|
||||
return {
|
||||
"traffic_volume": traffic.get("volume"),
|
||||
"pedestrian_count": traffic.get("pedestrian_count")
|
||||
}
|
||||
else:
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Error getting traffic forecast", error=str(e))
|
||||
return {}
|
||||
|
||||
async def _check_and_create_alerts(self, forecast: Forecast, db: AsyncSession):
|
||||
"""Check forecast and create alerts if needed"""
|
||||
|
||||
try:
|
||||
alerts_to_create = []
|
||||
|
||||
# High demand alert
|
||||
if forecast.predicted_demand > settings.HIGH_DEMAND_THRESHOLD * 100: # Assuming base of 100 units
|
||||
alerts_to_create.append({
|
||||
"type": "high_demand",
|
||||
"severity": "medium",
|
||||
"message": f"High demand predicted for {forecast.product_name}: {forecast.predicted_demand:.0f} units"
|
||||
})
|
||||
|
||||
# Low demand alert
|
||||
if forecast.predicted_demand < settings.LOW_DEMAND_THRESHOLD * 100:
|
||||
alerts_to_create.append({
|
||||
"type": "low_demand",
|
||||
"severity": "low",
|
||||
"message": f"Low demand predicted for {forecast.product_name}: {forecast.predicted_demand:.0f} units"
|
||||
})
|
||||
|
||||
# Stockout risk alert
|
||||
if forecast.confidence_upper > settings.STOCKOUT_RISK_THRESHOLD * forecast.predicted_demand:
|
||||
alerts_to_create.append({
|
||||
"type": "stockout_risk",
|
||||
"severity": "high",
|
||||
"message": f"Stockout risk for {forecast.product_name}. Upper confidence: {forecast.confidence_upper:.0f}"
|
||||
})
|
||||
|
||||
# Create alerts
|
||||
for alert_data in alerts_to_create:
|
||||
alert = ForecastAlert(
|
||||
tenant_id=forecast.tenant_id,
|
||||
forecast_id=forecast.id,
|
||||
alert_type=alert_data["type"],
|
||||
severity=alert_data["severity"],
|
||||
message=alert_data["message"]
|
||||
)
|
||||
|
||||
db.add(alert)
|
||||
|
||||
# Publish alert event
|
||||
await publish_alert_created({
|
||||
"alert_id": str(alert.id),
|
||||
"tenant_id": str(forecast.tenant_id),
|
||||
"product_name": forecast.product_name,
|
||||
"alert_type": alert_data["type"],
|
||||
"severity": alert_data["severity"],
|
||||
"message": alert_data["message"]
|
||||
})
|
||||
|
||||
await db.commit()
|
||||
|
||||
if alerts_to_create:
|
||||
logger.info("Created forecast alerts",
|
||||
forecast_id=str(forecast.id),
|
||||
alerts_count=len(alerts_to_create))
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error creating alerts", error=str(e))
|
||||
# Don't raise - alerts are not critical for forecast generation
|
||||
Reference in New Issue
Block a user