Add forecasting service

This commit is contained in:
Urtzi Alfaro
2025-07-21 19:48:56 +02:00
parent 2d85dd3e9e
commit 0e7ca10a29
24 changed files with 2937 additions and 179 deletions

View File

@@ -0,0 +1,438 @@
# ================================================================
# services/forecasting/app/services/forecasting_service.py
# ================================================================
"""
Main forecasting service business logic
Orchestrates demand prediction operations
"""
import structlog
from typing import Dict, List, Any, Optional
from datetime import datetime, date, timedelta
import asyncio
import uuid
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, desc
import httpx
from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert
from app.schemas.forecasts import ForecastRequest, BatchForecastRequest, BusinessType
from app.services.prediction_service import PredictionService
from app.services.messaging import publish_forecast_completed, publish_alert_created
from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector
logger = structlog.get_logger()
metrics = MetricsCollector("forecasting-service")
class ForecastingService:
"""
Main service class for managing forecasting operations.
Handles demand prediction, batch processing, and alert generation.
"""
def __init__(self):
self.prediction_service = PredictionService()
async def generate_forecast(self, request: ForecastRequest, db: AsyncSession) -> Forecast:
"""Generate a single forecast for a product"""
start_time = datetime.now()
try:
logger.info("Generating forecast",
tenant_id=request.tenant_id,
product=request.product_name,
date=request.forecast_date)
# Get the latest trained model for this tenant/product
model_info = await self._get_latest_model(
request.tenant_id,
request.product_name,
request.location
)
if not model_info:
raise ValueError(f"No trained model found for {request.product_name}")
# Prepare features for prediction
features = await self._prepare_forecast_features(request)
# Generate prediction using ML service
prediction_result = await self.prediction_service.predict(
model_id=model_info["model_id"],
features=features,
confidence_level=request.confidence_level
)
# Create forecast record
forecast = Forecast(
tenant_id=uuid.UUID(request.tenant_id),
product_name=request.product_name,
location=request.location,
forecast_date=datetime.combine(request.forecast_date, datetime.min.time()),
# Prediction results
predicted_demand=prediction_result["demand"],
confidence_lower=prediction_result["lower_bound"],
confidence_upper=prediction_result["upper_bound"],
confidence_level=request.confidence_level,
# Model information
model_id=uuid.UUID(model_info["model_id"]),
model_version=model_info["version"],
algorithm=model_info.get("algorithm", "prophet"),
# Context
business_type=request.business_type.value,
day_of_week=request.forecast_date.weekday(),
is_holiday=features.get("is_holiday", False),
is_weekend=request.forecast_date.weekday() >= 5,
# External factors
weather_temperature=features.get("temperature"),
weather_precipitation=features.get("precipitation"),
weather_description=features.get("weather_description"),
traffic_volume=features.get("traffic_volume"),
# Metadata
processing_time_ms=int((datetime.now() - start_time).total_seconds() * 1000),
features_used=features
)
db.add(forecast)
await db.commit()
await db.refresh(forecast)
# Check for alerts
await self._check_and_create_alerts(forecast, db)
# Update metrics
metrics.increment_counter("forecasts_generated_total",
{"product": request.product_name, "location": request.location})
# Publish event
await publish_forecast_completed({
"forecast_id": str(forecast.id),
"tenant_id": request.tenant_id,
"product_name": request.product_name,
"predicted_demand": forecast.predicted_demand
})
logger.info("Forecast generated successfully",
forecast_id=str(forecast.id),
predicted_demand=forecast.predicted_demand)
return forecast
except Exception as e:
logger.error("Error generating forecast",
error=str(e),
tenant_id=request.tenant_id,
product=request.product_name)
raise
async def generate_batch_forecast(self, request: BatchForecastRequest, db: AsyncSession) -> PredictionBatch:
"""Generate forecasts for multiple products over multiple days"""
try:
logger.info("Starting batch forecast generation",
tenant_id=request.tenant_id,
batch_name=request.batch_name,
products_count=len(request.products),
forecast_days=request.forecast_days)
# Create batch record
batch = PredictionBatch(
tenant_id=uuid.UUID(request.tenant_id),
batch_name=request.batch_name,
status="processing",
total_products=len(request.products) * request.forecast_days,
business_type=request.business_type.value,
forecast_days=request.forecast_days
)
db.add(batch)
await db.commit()
await db.refresh(batch)
# Generate forecasts for each product and day
completed_count = 0
failed_count = 0
for product in request.products:
for day_offset in range(request.forecast_days):
forecast_date = date.today() + timedelta(days=day_offset + 1)
try:
forecast_request = ForecastRequest(
tenant_id=request.tenant_id,
product_name=product,
location=request.location,
forecast_date=forecast_date,
business_type=request.business_type,
include_weather=request.include_weather,
include_traffic=request.include_traffic,
confidence_level=request.confidence_level
)
await self.generate_forecast(forecast_request, db)
completed_count += 1
except Exception as e:
logger.warning("Failed to generate forecast for product",
product=product,
date=forecast_date,
error=str(e))
failed_count += 1
# Update batch status
batch.status = "completed" if failed_count == 0 else "partial"
batch.completed_products = completed_count
batch.failed_products = failed_count
batch.completed_at = datetime.now()
await db.commit()
logger.info("Batch forecast generation completed",
batch_id=str(batch.id),
completed=completed_count,
failed=failed_count)
return batch
except Exception as e:
logger.error("Error in batch forecast generation", error=str(e))
raise
async def get_forecasts(self, tenant_id: str, location: str,
start_date: Optional[date] = None,
end_date: Optional[date] = None,
product_name: Optional[str] = None,
db: AsyncSession = None) -> List[Forecast]:
"""Retrieve forecasts with filtering"""
try:
query = select(Forecast).where(
and_(
Forecast.tenant_id == uuid.UUID(tenant_id),
Forecast.location == location
)
)
if start_date:
query = query.where(Forecast.forecast_date >= datetime.combine(start_date, datetime.min.time()))
if end_date:
query = query.where(Forecast.forecast_date <= datetime.combine(end_date, datetime.max.time()))
if product_name:
query = query.where(Forecast.product_name == product_name)
query = query.order_by(desc(Forecast.forecast_date))
result = await db.execute(query)
forecasts = result.scalars().all()
logger.info("Retrieved forecasts",
tenant_id=tenant_id,
count=len(forecasts))
return list(forecasts)
except Exception as e:
logger.error("Error retrieving forecasts", error=str(e))
raise
async def _get_latest_model(self, tenant_id: str, product_name: str, location: str) -> Optional[Dict[str, Any]]:
"""Get the latest trained model for a tenant/product combination"""
try:
# Call training service to get model information
async with httpx.AsyncClient() as client:
response = await client.get(
f"{settings.TRAINING_SERVICE_URL}/api/v1/models/latest",
params={
"tenant_id": tenant_id,
"product_name": product_name,
"location": location
},
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
)
if response.status_code == 200:
return response.json()
elif response.status_code == 404:
logger.warning("No model found",
tenant_id=tenant_id,
product=product_name)
return None
else:
response.raise_for_status()
except Exception as e:
logger.error("Error getting latest model", error=str(e))
raise
async def _prepare_forecast_features(self, request: ForecastRequest) -> Dict[str, Any]:
"""Prepare features for forecasting model"""
features = {
"date": request.forecast_date.isoformat(),
"day_of_week": request.forecast_date.weekday(),
"is_weekend": request.forecast_date.weekday() >= 5,
"business_type": request.business_type.value
}
# Add Spanish holidays
features["is_holiday"] = await self._is_spanish_holiday(request.forecast_date)
# Add weather data if requested
if request.include_weather:
weather_data = await self._get_weather_forecast(request.forecast_date)
features.update(weather_data)
# Add traffic data if requested
if request.include_traffic:
traffic_data = await self._get_traffic_forecast(request.forecast_date, request.location)
features.update(traffic_data)
return features
async def _is_spanish_holiday(self, forecast_date: date) -> bool:
"""Check if date is a Spanish holiday"""
try:
# Call data service for holiday information
async with httpx.AsyncClient() as client:
response = await client.get(
f"{settings.DATA_SERVICE_URL}/api/v1/holidays/check",
params={"date": forecast_date.isoformat()},
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
)
if response.status_code == 200:
return response.json().get("is_holiday", False)
else:
return False
except Exception as e:
logger.warning("Error checking holiday status", error=str(e))
return False
async def _get_weather_forecast(self, forecast_date: date) -> Dict[str, Any]:
"""Get weather forecast for the date"""
try:
# Call data service for weather forecast
async with httpx.AsyncClient() as client:
response = await client.get(
f"{settings.DATA_SERVICE_URL}/api/v1/weather/forecast",
params={"date": forecast_date.isoformat()},
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
)
if response.status_code == 200:
weather = response.json()
return {
"temperature": weather.get("temperature"),
"precipitation": weather.get("precipitation"),
"humidity": weather.get("humidity"),
"weather_description": weather.get("description")
}
else:
return {}
except Exception as e:
logger.warning("Error getting weather forecast", error=str(e))
return {}
async def _get_traffic_forecast(self, forecast_date: date, location: str) -> Dict[str, Any]:
"""Get traffic forecast for the date and location"""
try:
# Call data service for traffic forecast
async with httpx.AsyncClient() as client:
response = await client.get(
f"{settings.DATA_SERVICE_URL}/api/v1/traffic/forecast",
params={
"date": forecast_date.isoformat(),
"location": location
},
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
)
if response.status_code == 200:
traffic = response.json()
return {
"traffic_volume": traffic.get("volume"),
"pedestrian_count": traffic.get("pedestrian_count")
}
else:
return {}
except Exception as e:
logger.warning("Error getting traffic forecast", error=str(e))
return {}
async def _check_and_create_alerts(self, forecast: Forecast, db: AsyncSession):
"""Check forecast and create alerts if needed"""
try:
alerts_to_create = []
# High demand alert
if forecast.predicted_demand > settings.HIGH_DEMAND_THRESHOLD * 100: # Assuming base of 100 units
alerts_to_create.append({
"type": "high_demand",
"severity": "medium",
"message": f"High demand predicted for {forecast.product_name}: {forecast.predicted_demand:.0f} units"
})
# Low demand alert
if forecast.predicted_demand < settings.LOW_DEMAND_THRESHOLD * 100:
alerts_to_create.append({
"type": "low_demand",
"severity": "low",
"message": f"Low demand predicted for {forecast.product_name}: {forecast.predicted_demand:.0f} units"
})
# Stockout risk alert
if forecast.confidence_upper > settings.STOCKOUT_RISK_THRESHOLD * forecast.predicted_demand:
alerts_to_create.append({
"type": "stockout_risk",
"severity": "high",
"message": f"Stockout risk for {forecast.product_name}. Upper confidence: {forecast.confidence_upper:.0f}"
})
# Create alerts
for alert_data in alerts_to_create:
alert = ForecastAlert(
tenant_id=forecast.tenant_id,
forecast_id=forecast.id,
alert_type=alert_data["type"],
severity=alert_data["severity"],
message=alert_data["message"]
)
db.add(alert)
# Publish alert event
await publish_alert_created({
"alert_id": str(alert.id),
"tenant_id": str(forecast.tenant_id),
"product_name": forecast.product_name,
"alert_type": alert_data["type"],
"severity": alert_data["severity"],
"message": alert_data["message"]
})
await db.commit()
if alerts_to_create:
logger.info("Created forecast alerts",
forecast_id=str(forecast.id),
alerts_count=len(alerts_to_create))
except Exception as e:
logger.error("Error creating alerts", error=str(e))
# Don't raise - alerts are not critical for forecast generation