Add forecasting service
@@ -1,72 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from typing import List, Optional, Dict, Any
from datetime import datetime, date
import structlog

from app.schemas.forecast import (
    ForecastRequest,
    ForecastResponse,
    BatchForecastRequest,
    ForecastPerformanceResponse
)
from app.services.forecast_service import ForecastService
from app.services.messaging import publish_forecast_generated

# Import unified authentication
from shared.auth.decorators import (
    get_current_user_dep,
    get_current_tenant_id_dep
)

router = APIRouter(prefix="/forecasts", tags=["forecasting"])
logger = structlog.get_logger()

@router.post("/generate", response_model=ForecastResponse)
async def generate_forecast(
    request: ForecastRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
):
    """Generate forecast for products"""
    try:
        logger.info("Generating forecast",
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"],
                    products=len(request.products) if request.products else "all")

        forecast_service = ForecastService()

        # Ensure products belong to tenant
        if request.products:
            valid_products = await forecast_service.validate_products(
                tenant_id, request.products
            )
            if len(valid_products) != len(request.products):
                raise HTTPException(
                    status_code=400,
                    detail="Some products not found or not accessible"
                )

        # Generate forecast
        forecast = await forecast_service.generate_forecast(
            tenant_id=tenant_id,
            request=request,
            user_id=current_user["user_id"]
        )

        # Publish event
        background_tasks.add_task(
            publish_forecast_generated,
            forecast_id=forecast.id,
            tenant_id=tenant_id,
            user_id=current_user["user_id"]
        )

        return forecast

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to generate forecast", error=str(e))
        raise HTTPException(status_code=500, detail=str(e))
326 services/forecasting/app/api/forecasts.py Normal file
@@ -0,0 +1,326 @@
# ================================================================
# services/forecasting/app/api/forecasts.py
# ================================================================
"""
Forecast API endpoints
"""

import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from datetime import date

from app.core.database import get_db
from app.core.auth import get_current_user_from_headers
from app.services.forecasting_service import ForecastingService
from app.schemas.forecasts import (
    ForecastRequest, ForecastResponse, BatchForecastRequest,
    BatchForecastResponse, AlertResponse
)
from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert

logger = structlog.get_logger()
router = APIRouter()

# Initialize service
forecasting_service = ForecastingService()

@router.post("/single", response_model=ForecastResponse)
async def create_single_forecast(
    request: ForecastRequest,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user_from_headers)
):
    """Generate a single product forecast"""
    try:
        # Verify tenant access
        if str(request.tenant_id) != str(current_user.get("tenant_id")):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to this tenant"
            )

        # Generate forecast
        forecast = await forecasting_service.generate_forecast(request, db)

        # Convert to response model
        return ForecastResponse(
            id=str(forecast.id),
            tenant_id=str(forecast.tenant_id),
            product_name=forecast.product_name,
            location=forecast.location,
            forecast_date=forecast.forecast_date,
            predicted_demand=forecast.predicted_demand,
            confidence_lower=forecast.confidence_lower,
            confidence_upper=forecast.confidence_upper,
            confidence_level=forecast.confidence_level,
            model_id=str(forecast.model_id),
            model_version=forecast.model_version,
            algorithm=forecast.algorithm,
            business_type=forecast.business_type,
            is_holiday=forecast.is_holiday,
            is_weekend=forecast.is_weekend,
            day_of_week=forecast.day_of_week,
            weather_temperature=forecast.weather_temperature,
            weather_precipitation=forecast.weather_precipitation,
            weather_description=forecast.weather_description,
            traffic_volume=forecast.traffic_volume,
            created_at=forecast.created_at,
            processing_time_ms=forecast.processing_time_ms,
            features_used=forecast.features_used
        )

    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        logger.error("Error creating single forecast", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )

@router.post("/batch", response_model=BatchForecastResponse)
async def create_batch_forecast(
    request: BatchForecastRequest,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user_from_headers)
):
    """Generate batch forecasts for multiple products"""
    try:
        # Verify tenant access
        if str(request.tenant_id) != str(current_user.get("tenant_id")):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to this tenant"
            )

        # Generate batch forecast
        batch = await forecasting_service.generate_batch_forecast(request, db)

        # Get associated forecasts
        forecasts = await forecasting_service.get_forecasts(
            tenant_id=request.tenant_id,
            location=request.location,
            db=db
        )

        # Convert forecasts to response models
        forecast_responses = []
        for forecast in forecasts[:batch.total_products]:  # Limit to batch size
            forecast_responses.append(ForecastResponse(
                id=str(forecast.id),
                tenant_id=str(forecast.tenant_id),
                product_name=forecast.product_name,
                location=forecast.location,
                forecast_date=forecast.forecast_date,
                predicted_demand=forecast.predicted_demand,
                confidence_lower=forecast.confidence_lower,
                confidence_upper=forecast.confidence_upper,
                confidence_level=forecast.confidence_level,
                model_id=str(forecast.model_id),
                model_version=forecast.model_version,
                algorithm=forecast.algorithm,
                business_type=forecast.business_type,
                is_holiday=forecast.is_holiday,
                is_weekend=forecast.is_weekend,
                day_of_week=forecast.day_of_week,
                weather_temperature=forecast.weather_temperature,
                weather_precipitation=forecast.weather_precipitation,
                weather_description=forecast.weather_description,
                traffic_volume=forecast.traffic_volume,
                created_at=forecast.created_at,
                processing_time_ms=forecast.processing_time_ms,
                features_used=forecast.features_used
            ))

        return BatchForecastResponse(
            id=str(batch.id),
            tenant_id=str(batch.tenant_id),
            batch_name=batch.batch_name,
            status=batch.status,
            total_products=batch.total_products,
            completed_products=batch.completed_products,
            failed_products=batch.failed_products,
            requested_at=batch.requested_at,
            completed_at=batch.completed_at,
            processing_time_ms=batch.processing_time_ms,
            forecasts=forecast_responses
        )

    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        logger.error("Error creating batch forecast", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )

@router.get("/list", response_model=List[ForecastResponse])
async def list_forecasts(
    location: str,
    start_date: Optional[date] = Query(None),
    end_date: Optional[date] = Query(None),
    product_name: Optional[str] = Query(None),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user_from_headers)
):
    """List forecasts with filtering"""
    try:
        tenant_id = str(current_user.get("tenant_id"))

        # Get forecasts
        forecasts = await forecasting_service.get_forecasts(
            tenant_id=tenant_id,
            location=location,
            start_date=start_date,
            end_date=end_date,
            product_name=product_name,
            db=db
        )

        # Convert to response models
        return [
            ForecastResponse(
                id=str(forecast.id),
                tenant_id=str(forecast.tenant_id),
                product_name=forecast.product_name,
                location=forecast.location,
                forecast_date=forecast.forecast_date,
                predicted_demand=forecast.predicted_demand,
                confidence_lower=forecast.confidence_lower,
                confidence_upper=forecast.confidence_upper,
                confidence_level=forecast.confidence_level,
                model_id=str(forecast.model_id),
                model_version=forecast.model_version,
                algorithm=forecast.algorithm,
                business_type=forecast.business_type,
                is_holiday=forecast.is_holiday,
                is_weekend=forecast.is_weekend,
                day_of_week=forecast.day_of_week,
                weather_temperature=forecast.weather_temperature,
                weather_precipitation=forecast.weather_precipitation,
                weather_description=forecast.weather_description,
                traffic_volume=forecast.traffic_volume,
                created_at=forecast.created_at,
                processing_time_ms=forecast.processing_time_ms,
                features_used=forecast.features_used
            )
            for forecast in forecasts
        ]

    except Exception as e:
        logger.error("Error listing forecasts", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )

@router.get("/alerts", response_model=List[AlertResponse])
async def get_forecast_alerts(
    active_only: bool = Query(True),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user_from_headers)
):
    """Get forecast alerts for tenant"""
    try:
        from sqlalchemy import select, and_

        tenant_id = current_user.get("tenant_id")

        # Build query
        query = select(ForecastAlert).where(
            ForecastAlert.tenant_id == tenant_id
        )

        if active_only:
            query = query.where(ForecastAlert.is_active == True)

        query = query.order_by(ForecastAlert.created_at.desc())

        # Execute query
        result = await db.execute(query)
        alerts = result.scalars().all()

        # Convert to response models
        return [
            AlertResponse(
                id=str(alert.id),
                tenant_id=str(alert.tenant_id),
                forecast_id=str(alert.forecast_id),
                alert_type=alert.alert_type,
                severity=alert.severity,
                message=alert.message,
                is_active=alert.is_active,
                created_at=alert.created_at,
                acknowledged_at=alert.acknowledged_at,
                notification_sent=alert.notification_sent
            )
            for alert in alerts
        ]

    except Exception as e:
        logger.error("Error getting forecast alerts", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )

@router.put("/alerts/{alert_id}/acknowledge")
async def acknowledge_alert(
    alert_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user_from_headers)
):
    """Acknowledge a forecast alert"""
    try:
        from sqlalchemy import select, and_
        from datetime import datetime, timezone

        tenant_id = current_user.get("tenant_id")

        # Get alert
        result = await db.execute(
            select(ForecastAlert).where(
                and_(
                    ForecastAlert.id == alert_id,
                    ForecastAlert.tenant_id == tenant_id
                )
            )
        )
        alert = result.scalar_one_or_none()

        if not alert:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Alert not found"
            )

        # Update alert (timezone-aware to match the DateTime(timezone=True) column)
        alert.acknowledged_at = datetime.now(timezone.utc)
        alert.is_active = False

        await db.commit()

        return {"message": "Alert acknowledged successfully"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error acknowledging alert", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )
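The 23-field ForecastResponse construction above is repeated verbatim in three handlers. A small mapper would keep them in sync when a column changes; this is a hedged sketch (the _to_response helper is illustrative, not part of the commit):

def _to_response(forecast: Forecast) -> ForecastResponse:
    """Map a Forecast ORM row to its API response model (hypothetical helper)."""
    return ForecastResponse(
        # UUID columns need explicit str() conversion
        id=str(forecast.id),
        tenant_id=str(forecast.tenant_id),
        model_id=str(forecast.model_id),
        # Remaining columns share names with the schema fields, so they can be
        # copied mechanically (a pydantic v1 alternative: orm_mode + from_orm)
        **{
            field: getattr(forecast, field)
            for field in (
                "product_name", "location", "forecast_date", "predicted_demand",
                "confidence_lower", "confidence_upper", "confidence_level",
                "model_version", "algorithm", "business_type", "is_holiday",
                "is_weekend", "day_of_week", "weather_temperature",
                "weather_precipitation", "weather_description", "traffic_volume",
                "created_at", "processing_time_ms", "features_used",
            )
        },
    )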
141 services/forecasting/app/api/predictions.py Normal file
@@ -0,0 +1,141 @@
# ================================================================
# services/forecasting/app/api/predictions.py
# ================================================================
"""
Prediction API endpoints - Real-time prediction capabilities
"""

import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Dict, Any
from datetime import date, datetime, timedelta

from app.core.database import get_db
from app.core.auth import get_current_user_from_headers
from app.services.prediction_service import PredictionService
from app.schemas.forecasts import ForecastRequest

logger = structlog.get_logger()
router = APIRouter()

# Initialize service
prediction_service = PredictionService()

@router.post("/realtime")
async def get_realtime_prediction(
    product_name: str,
    location: str,
    forecast_date: date,
    features: Dict[str, Any],
    current_user: dict = Depends(get_current_user_from_headers)
):
    """Get real-time prediction without storing in database"""
    try:
        tenant_id = str(current_user.get("tenant_id"))

        # Get latest model
        from app.services.forecasting_service import ForecastingService
        forecasting_service = ForecastingService()

        model_info = await forecasting_service._get_latest_model(
            tenant_id, product_name, location
        )

        if not model_info:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No trained model found for {product_name}"
            )

        # Generate prediction
        prediction = await prediction_service.predict(
            model_id=model_info["model_id"],
            features=features,
            confidence_level=0.8
        )

        return {
            "product_name": product_name,
            "location": location,
            "forecast_date": forecast_date,
            "predicted_demand": prediction["demand"],
            "confidence_lower": prediction["lower_bound"],
            "confidence_upper": prediction["upper_bound"],
            "model_id": model_info["model_id"],
            "model_version": model_info["version"],
            "generated_at": datetime.now(),
            "features_used": features
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error getting realtime prediction", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )

@router.get("/quick/{product_name}")
async def get_quick_prediction(
    product_name: str,
    location: str = Query(...),
    days_ahead: int = Query(1, ge=1, le=7),
    current_user: dict = Depends(get_current_user_from_headers)
):
    """Get quick prediction for next few days"""
    try:
        tenant_id = str(current_user.get("tenant_id"))

        # Look up the model once; it is the same for every day in the loop
        from app.services.forecasting_service import ForecastingService
        forecasting_service = ForecastingService()

        model_info = await forecasting_service._get_latest_model(
            tenant_id, product_name, location
        )

        # Generate predictions for the next N days
        predictions = []

        for day in range(1, days_ahead + 1):
            forecast_date = date.today() + timedelta(days=day)

            # Prepare basic features
            features = {
                "date": forecast_date.isoformat(),
                "day_of_week": forecast_date.weekday(),
                "is_weekend": forecast_date.weekday() >= 5,
                "business_type": "individual"
            }

            if model_info:
                prediction = await prediction_service.predict(
                    model_id=model_info["model_id"],
                    features=features
                )

                predictions.append({
                    "date": forecast_date,
                    "predicted_demand": prediction["demand"],
                    "confidence_lower": prediction["lower_bound"],
                    "confidence_upper": prediction["upper_bound"]
                })

        return {
            "product_name": product_name,
            "location": location,
            "predictions": predictions,
            "generated_at": datetime.now()
        }

    except Exception as e:
        logger.error("Error getting quick prediction", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )
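For reference, the quick-prediction endpoint can be exercised like this; a minimal client sketch (assumes the service listens on localhost:8000 and that the gateway headers from app/core/auth.py are trusted as-is; all IDs are made up):

import asyncio
import httpx

async def main() -> None:
    headers = {
        "X-User-ID": "user-123",       # hypothetical values for illustration
        "X-Tenant-ID": "tenant-456",
        "X-User-Roles": "admin",
    }
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.get(
            "/api/v1/predictions/quick/baguette",
            params={"location": "madrid-centro", "days_ahead": 3},
            headers=headers,
        )
        resp.raise_for_status()
        print(resp.json()["predictions"])

asyncio.run(main())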
@@ -1,22 +1,48 @@
 # ================================================================
 # services/forecasting/app/core/auth.py
 # ================================================================
 """
-Authentication configuration for forecasting service
+Authentication utilities for forecasting service
 """

-from shared.auth.jwt_handler import JWTHandler
-from shared.auth.decorators import require_auth, require_role
-from app.core.config import settings
+import structlog
+from fastapi import HTTPException, status, Request
+from typing import Dict, Any, Optional

-# Initialize JWT handler
-jwt_handler = JWTHandler(
-    secret_key=settings.JWT_SECRET_KEY,
-    algorithm=settings.JWT_ALGORITHM,
-    access_token_expire_minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
-)
+logger = structlog.get_logger()

-# Export commonly used functions
-verify_token = jwt_handler.verify_token
-create_access_token = jwt_handler.create_access_token
-get_current_user = jwt_handler.get_current_user
-
-# Export decorators
-__all__ = ['verify_token', 'create_access_token', 'get_current_user', 'require_auth', 'require_role']
+async def get_current_user_from_headers(request: Request) -> Dict[str, Any]:
+    """
+    Get current user from gateway headers
+    Gateway middleware adds user context to headers after JWT verification
+    """
+    try:
+        # Extract user information from headers set by API Gateway
+        user_id = request.headers.get("X-User-ID")
+        user_email = request.headers.get("X-User-Email")
+        tenant_id = request.headers.get("X-Tenant-ID")
+        user_roles = request.headers.get("X-User-Roles", "").split(",")
+
+        if not user_id or not tenant_id:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Authentication required"
+            )
+
+        return {
+            "user_id": user_id,
+            "email": user_email,
+            "tenant_id": tenant_id,
+            "roles": [role.strip() for role in user_roles if role.strip()]
+        }
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error("Error extracting user from headers", error=str(e))
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid authentication"
+        )
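Because the dependency raises before any database access when headers are absent, it can be smoke-tested cheaply; a hedged sketch with FastAPI's TestClient (assumes importing app.main is side-effect safe in the test environment):

from fastapi.testclient import TestClient
from app.main import app

client = TestClient(app)

# Without the gateway headers the dependency raises 401 before any DB access
resp = client.get("/api/v1/forecasts/alerts")
assert resp.status_code == 401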
@@ -1,12 +1,73 @@
 # ================================================================
 # services/forecasting/app/core/database.py
 # ================================================================
 """
 Database configuration for forecasting service
 """

-from shared.database.base import DatabaseManager
+import structlog
+from typing import AsyncIterator
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
+from sqlalchemy.pool import NullPool
+from sqlalchemy import text

 from app.core.config import settings
+from shared.database.base import Base

-# Initialize database manager
-database_manager = DatabaseManager(settings.DATABASE_URL)
+logger = structlog.get_logger()

-# Alias for convenience
-get_db = database_manager.get_db
+# Create async engine
+engine = create_async_engine(
+    settings.DATABASE_URL,
+    poolclass=NullPool,
+    echo=settings.DEBUG,
+    future=True
+)
+
+# Create session factory
+AsyncSessionLocal = async_sessionmaker(
+    engine,
+    class_=AsyncSession,
+    expire_on_commit=False,
+    autoflush=False,
+    autocommit=False
+)
+
+class DatabaseManager:
+    """Database management operations"""
+
+    async def create_tables(self):
+        """Create database tables"""
+        async with engine.begin() as conn:
+            await conn.run_sync(Base.metadata.create_all)
+        logger.info("Forecasting database tables created successfully")
+
+    async def get_session(self) -> AsyncIterator[AsyncSession]:
+        """Yield a database session (async generator, not a plain coroutine)"""
+        async with AsyncSessionLocal() as session:
+            try:
+                yield session
+                await session.commit()
+            except Exception as e:
+                await session.rollback()
+                logger.error(f"Database session error: {e}")
+                raise
+            finally:
+                await session.close()
+
+# Global database manager instance
+database_manager = DatabaseManager()
+
+async def get_db() -> AsyncIterator[AsyncSession]:
+    """Database dependency"""
+    async for session in database_manager.get_session():
+        yield session
+
+async def get_db_health() -> bool:
+    """Check database health"""
+    try:
+        async with AsyncSessionLocal() as session:
+            await session.execute(text("SELECT 1"))
+            return True
+    except Exception as e:
+        logger.error(f"Database health check failed: {e}")
+        return False
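For reference, the get_db dependency wires into an endpoint like this; a minimal sketch (the /ping-db route is illustrative, not part of the commit):

from fastapi import APIRouter, Depends
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db

router = APIRouter()

@router.get("/ping-db")
async def ping_db(db: AsyncSession = Depends(get_db)):
    # FastAPI drives the async generator: the session opens before the
    # handler runs, commits after it returns, and rolls back if this raises
    await db.execute(text("SELECT 1"))
    return {"db": "ok"}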
@@ -1,61 +1,116 @@
 # ================================================================
 # services/forecasting/app/main.py
 # ================================================================
 """
-uLuforecasting Service
+Forecasting Service Main Application
+Demand prediction and forecasting service for bakery operations
 """

 import structlog
-from fastapi import FastAPI
+from contextlib import asynccontextmanager
+from datetime import datetime, timezone
+from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse

 from app.core.config import settings
-from app.core.database import database_manager
+from app.core.database import database_manager, get_db_health
 from app.api import forecasts, predictions
+from app.services.messaging import setup_messaging, cleanup_messaging
 from shared.monitoring.logging import setup_logging
 from shared.monitoring.metrics import MetricsCollector

-# Setup logging
-setup_logging("forecasting-service", "INFO")
+# Setup structured logging
+setup_logging("forecasting-service", settings.LOG_LEVEL)
 logger = structlog.get_logger()

-# Create FastAPI app
-app = FastAPI(
-    title="uLuforecasting Service",
-    description="uLuforecasting service for bakery forecasting",
-    version="1.0.0"
-)
-
 # Initialize metrics collector
 metrics_collector = MetricsCollector("forecasting-service")

+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Application lifespan manager for startup and shutdown events"""
+    # Startup
+    logger.info("Starting Forecasting Service", version="1.0.0")
+
+    try:
+        # Initialize database
+        logger.info("Initializing database connection")
+        await database_manager.create_tables()
+        logger.info("Database initialized successfully")
+
+        # Initialize messaging
+        logger.info("Setting up messaging")
+        await setup_messaging()
+        logger.info("Messaging initialized")
+
+        # Register custom metrics
+        metrics_collector.register_counter("forecasts_generated_total", "Total forecasts generated")
+        metrics_collector.register_counter("predictions_served_total", "Total predictions served")
+        metrics_collector.register_histogram("forecast_processing_time_seconds", "Time to process forecast request")
+        metrics_collector.register_gauge("active_models_count", "Number of active models")
+
+        # Start metrics server
+        metrics_collector.start_metrics_server(8080)
+
+        logger.info("Forecasting Service started successfully")
+
+        yield
+
+    except Exception as e:
+        logger.error("Failed to start Forecasting Service", error=str(e))
+        raise
+    finally:
+        # Shutdown
+        logger.info("Shutting down Forecasting Service")
+
+        try:
+            await cleanup_messaging()
+            logger.info("Messaging cleanup completed")
+        except Exception as e:
+            logger.error("Error during messaging cleanup", error=str(e))
+
+# Create FastAPI app with lifespan
+app = FastAPI(
+    title="Bakery Forecasting Service",
+    description="AI-powered demand prediction and forecasting service for bakery operations",
+    version="1.0.0",
+    docs_url="/docs",
+    redoc_url="/redoc",
+    lifespan=lifespan
+)
+
 # CORS middleware
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=settings.CORS_ORIGINS_LIST,
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
 )

-@app.on_event("startup")
-async def startup_event():
-    """Application startup"""
-    logger.info("Starting uLuforecasting Service")
-
-    # Create database tables
-    await database_manager.create_tables()
-
-    # Start metrics server
-    metrics_collector.start_metrics_server(8080)
-
-    logger.info("uLuforecasting Service started successfully")
-
 # Include API routers
 app.include_router(forecasts.router, prefix="/api/v1/forecasts", tags=["forecasts"])
 app.include_router(predictions.router, prefix="/api/v1/predictions", tags=["predictions"])

 @app.get("/health")
 async def health_check():
     """Health check endpoint"""
+    db_health = await get_db_health()
+
     return {
-        "status": "healthy",
+        "status": "healthy" if db_health else "unhealthy",
         "service": "forecasting-service",
-        "version": "1.0.0"
+        "version": "1.0.0",
+        "database": "connected" if db_health else "disconnected",
+        "timestamp": datetime.now(timezone.utc).isoformat()
     }

 @app.get("/metrics")
 async def get_metrics():
     """Metrics endpoint for Prometheus"""
     return metrics_collector.generate_latest()

 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)
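To smoke-test the wiring above, the lifespan has to be entered explicitly; a hedged sketch (assumes a reachable database and message broker, otherwise startup raises or status flips to unhealthy):

from fastapi.testclient import TestClient
from app.main import app

# Using the client as a context manager runs the lifespan (DB init, messaging)
with TestClient(app) as client:
    payload = client.get("/health").json()
    print(payload["status"], payload["database"])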
101 services/forecasting/app/ml/model_loader.py Normal file
@@ -0,0 +1,101 @@
# ================================================================
# services/forecasting/app/ml/model_loader.py
# ================================================================
"""
Model loading and management utilities
"""

import structlog
from typing import Dict, Any, Optional
import pickle
import json
from pathlib import Path
from datetime import datetime
import httpx

from app.core.config import settings

logger = structlog.get_logger()

class ModelLoader:
    """
    Utility class for loading and managing ML models
    """

    def __init__(self):
        self.model_cache = {}
        self.metadata_cache = {}

    async def load_model_with_metadata(self, model_id: str) -> Optional[Dict[str, Any]]:
        """Load model along with its metadata; returns None if either is missing"""
        try:
            # Get model metadata first
            metadata = await self._get_model_metadata(model_id)

            if not metadata:
                return None

            # Load the actual model
            model = await self._load_model_binary(model_id)

            if not model:
                return None

            return {
                "model": model,
                "metadata": metadata,
                "loaded_at": datetime.now()
            }

        except Exception as e:
            logger.error("Error loading model with metadata",
                         model_id=model_id,
                         error=str(e))
            return None

    async def _get_model_metadata(self, model_id: str) -> Optional[Dict[str, Any]]:
        """Get model metadata from training service"""
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/metadata",
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
                )

                if response.status_code == 200:
                    return response.json()
                else:
                    logger.warning("Model metadata not found",
                                   model_id=model_id,
                                   status_code=response.status_code)
                    return None

        except Exception as e:
            logger.error("Error getting model metadata", error=str(e))
            return None

    async def _load_model_binary(self, model_id: str):
        """Load model binary from training service"""
        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.get(
                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/download",
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
                )

                if response.status_code == 200:
                    # Unpickling is only safe because the training service is trusted
                    model = pickle.loads(response.content)
                    return model
                else:
                    logger.error("Failed to download model binary",
                                 model_id=model_id,
                                 status_code=response.status_code)
                    return None

        except Exception as e:
            logger.error("Error loading model binary", error=str(e))
            return None
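ModelLoader declares a model_cache dict that no method in the commit consults; a caller can layer memoization on top of it. A hedged usage sketch (the get_model_cached helper and the UUID are hypothetical):

import asyncio
from app.ml.model_loader import ModelLoader

loader = ModelLoader()

async def get_model_cached(model_id: str):
    # Simple memoization over the otherwise unused model_cache dict
    if model_id not in loader.model_cache:
        loaded = await loader.load_model_with_metadata(model_id)
        if loaded is None:
            raise LookupError(f"model {model_id} unavailable")
        loader.model_cache[model_id] = loaded
    return loader.model_cache[model_id]

# asyncio.run(get_model_cached("00000000-0000-0000-0000-000000000000"))  # placeholder ID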
305 services/forecasting/app/ml/predictor.py Normal file
@@ -0,0 +1,305 @@
# ================================================================
# services/forecasting/app/ml/predictor.py
# ================================================================
"""
Enhanced predictor module with advanced forecasting capabilities
"""

import structlog
from typing import Dict, List, Any, Optional, Tuple
import pandas as pd
import numpy as np
from datetime import datetime, date, timedelta
import pickle
import json

from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector

logger = structlog.get_logger()
metrics = MetricsCollector("forecasting-service")

class BakeryPredictor:
    """
    Advanced predictor for bakery demand forecasting
    Handles Prophet models and business-specific logic
    """

    def __init__(self):
        self.model_cache = {}
        self.business_rules = BakeryBusinessRules()

    async def predict_demand(self, model, features: Dict[str, Any],
                             business_type: str = "individual") -> Dict[str, float]:
        """Generate demand prediction with business rules applied"""
        try:
            # Generate base prediction
            base_prediction = await self._generate_base_prediction(model, features)

            # Apply business rules
            adjusted_prediction = self.business_rules.apply_rules(
                base_prediction, features, business_type
            )

            # Add uncertainty estimation
            final_prediction = self._add_uncertainty_bands(adjusted_prediction, features)

            return final_prediction

        except Exception as e:
            logger.error("Error in demand prediction", error=str(e))
            raise

    async def _generate_base_prediction(self, model, features: Dict[str, Any]) -> Dict[str, float]:
        """Generate base prediction from Prophet model"""
        try:
            # Convert features to Prophet DataFrame
            df = self._prepare_prophet_dataframe(features)

            # Generate forecast
            forecast = model.predict(df)

            if len(forecast) > 0:
                row = forecast.iloc[0]
                return {
                    "yhat": float(row['yhat']),
                    "yhat_lower": float(row['yhat_lower']),
                    "yhat_upper": float(row['yhat_upper']),
                    "trend": float(row.get('trend', 0)),
                    "seasonal": float(row.get('seasonal', 0)),
                    "weekly": float(row.get('weekly', 0)),
                    "yearly": float(row.get('yearly', 0)),
                    "holidays": float(row.get('holidays', 0))
                }
            else:
                raise ValueError("No prediction generated from model")

        except Exception as e:
            logger.error("Error generating base prediction", error=str(e))
            raise

    def _prepare_prophet_dataframe(self, features: Dict[str, Any]) -> pd.DataFrame:
        """Convert features to Prophet-compatible DataFrame"""
        try:
            # Create base DataFrame
            df = pd.DataFrame({
                'ds': [pd.to_datetime(features['date'])]
            })

            # Add regressor features
            feature_mapping = {
                'temperature': 'temperature',
                'precipitation': 'precipitation',
                'humidity': 'humidity',
                'wind_speed': 'wind_speed',
                'traffic_volume': 'traffic_volume',
                'pedestrian_count': 'pedestrian_count'
            }

            for feature_key, df_column in feature_mapping.items():
                if feature_key in features and features[feature_key] is not None:
                    df[df_column] = float(features[feature_key])
                else:
                    df[df_column] = 0.0

            # Add categorical features
            df['day_of_week'] = int(features.get('day_of_week', 0))
            df['is_weekend'] = int(features.get('is_weekend', False))
            df['is_holiday'] = int(features.get('is_holiday', False))

            # Business type
            business_type = features.get('business_type', 'individual')
            df['is_central_workshop'] = int(business_type == 'central_workshop')

            return df

        except Exception as e:
            logger.error("Error preparing Prophet dataframe", error=str(e))
            raise

    def _add_uncertainty_bands(self, prediction: Dict[str, float],
                               features: Dict[str, Any]) -> Dict[str, float]:
        """Add uncertainty estimation based on external factors"""
        try:
            base_demand = prediction["yhat"]
            base_lower = prediction["yhat_lower"]
            base_upper = prediction["yhat_upper"]

            # Weather uncertainty
            weather_uncertainty = self._calculate_weather_uncertainty(features)

            # Holiday uncertainty
            holiday_uncertainty = self._calculate_holiday_uncertainty(features)

            # Weekend uncertainty
            weekend_uncertainty = self._calculate_weekend_uncertainty(features)

            # Total uncertainty factor
            total_uncertainty = 1.0 + weather_uncertainty + holiday_uncertainty + weekend_uncertainty

            # Adjust bounds
            uncertainty_range = (base_upper - base_lower) * total_uncertainty
            center_point = base_demand

            adjusted_lower = center_point - (uncertainty_range / 2)
            adjusted_upper = center_point + (uncertainty_range / 2)

            return {
                "demand": max(0, base_demand),  # Never predict negative demand
                "lower_bound": max(0, adjusted_lower),
                "upper_bound": adjusted_upper,
                "uncertainty_factor": total_uncertainty,
                "trend": prediction.get("trend", 0),
                "seasonal": prediction.get("seasonal", 0),
                "holiday_effect": prediction.get("holidays", 0)
            }

        except Exception as e:
            logger.error("Error adding uncertainty bands", error=str(e))
            # Return basic prediction if uncertainty calculation fails
            return {
                "demand": max(0, prediction["yhat"]),
                "lower_bound": max(0, prediction["yhat_lower"]),
                "upper_bound": prediction["yhat_upper"],
                "uncertainty_factor": 1.0
            }

    def _calculate_weather_uncertainty(self, features: Dict[str, Any]) -> float:
        """Calculate weather-based uncertainty"""
        uncertainty = 0.0

        # Temperature extremes add uncertainty
        temp = features.get('temperature')
        if temp is not None:
            if temp < settings.TEMPERATURE_THRESHOLD_COLD or temp > settings.TEMPERATURE_THRESHOLD_HOT:
                uncertainty += 0.1

        # Rain adds uncertainty
        precipitation = features.get('precipitation')
        if precipitation is not None and precipitation > 0:
            uncertainty += 0.05 * min(precipitation, 10)  # Cap at 50mm

        return uncertainty

    def _calculate_holiday_uncertainty(self, features: Dict[str, Any]) -> float:
        """Calculate holiday-based uncertainty"""
        if features.get('is_holiday', False):
            return 0.2  # 20% additional uncertainty on holidays
        return 0.0

    def _calculate_weekend_uncertainty(self, features: Dict[str, Any]) -> float:
        """Calculate weekend-based uncertainty"""
        if features.get('is_weekend', False):
            return 0.1  # 10% additional uncertainty on weekends
        return 0.0


class BakeryBusinessRules:
    """
    Business rules for Spanish bakeries
    Applies domain-specific adjustments to predictions
    """

    def apply_rules(self, prediction: Dict[str, float], features: Dict[str, Any],
                    business_type: str) -> Dict[str, float]:
        """Apply all business rules to prediction"""
        adjusted_prediction = prediction.copy()

        # Apply weather rules
        adjusted_prediction = self._apply_weather_rules(adjusted_prediction, features)

        # Apply time-based rules
        adjusted_prediction = self._apply_time_rules(adjusted_prediction, features)

        # Apply business type rules
        adjusted_prediction = self._apply_business_type_rules(adjusted_prediction, business_type)

        # Apply Spanish-specific rules
        adjusted_prediction = self._apply_spanish_rules(adjusted_prediction, features)

        return adjusted_prediction

    def _apply_weather_rules(self, prediction: Dict[str, float],
                             features: Dict[str, Any]) -> Dict[str, float]:
        """Apply weather-based business rules"""
        # Rain reduces foot traffic
        precipitation = features.get('precipitation', 0)
        if precipitation > 0:
            rain_factor = settings.RAIN_IMPACT_FACTOR
            prediction["yhat"] *= rain_factor
            prediction["yhat_lower"] *= rain_factor
            prediction["yhat_upper"] *= rain_factor

        # Extreme temperatures affect different products differently
        temperature = features.get('temperature')
        if temperature is not None:
            if temperature > settings.TEMPERATURE_THRESHOLD_HOT:
                # Hot weather reduces bread sales, increases cold drinks
                prediction["yhat"] *= 0.9
            elif temperature < settings.TEMPERATURE_THRESHOLD_COLD:
                # Cold weather increases hot beverage sales
                prediction["yhat"] *= 1.1

        return prediction

    def _apply_time_rules(self, prediction: Dict[str, float],
                          features: Dict[str, Any]) -> Dict[str, float]:
        """Apply time-based business rules"""
        # Weekend adjustment
        if features.get('is_weekend', False):
            weekend_factor = settings.WEEKEND_ADJUSTMENT_FACTOR
            prediction["yhat"] *= weekend_factor
            prediction["yhat_lower"] *= weekend_factor
            prediction["yhat_upper"] *= weekend_factor

        # Holiday adjustment
        if features.get('is_holiday', False):
            holiday_factor = settings.HOLIDAY_ADJUSTMENT_FACTOR
            prediction["yhat"] *= holiday_factor
            prediction["yhat_lower"] *= holiday_factor
            prediction["yhat_upper"] *= holiday_factor

        return prediction

    def _apply_business_type_rules(self, prediction: Dict[str, float],
                                   business_type: str) -> Dict[str, float]:
        """Apply business type specific rules"""
        if business_type == "central_workshop":
            # Central workshops have more stable demand
            uncertainty_reduction = 0.8
            center = prediction["yhat"]
            lower = prediction["yhat_lower"]
            upper = prediction["yhat_upper"]

            # Reduce uncertainty band
            new_range = (upper - lower) * uncertainty_reduction
            prediction["yhat_lower"] = center - (new_range / 2)
            prediction["yhat_upper"] = center + (new_range / 2)

        return prediction

    def _apply_spanish_rules(self, prediction: Dict[str, float],
                             features: Dict[str, Any]) -> Dict[str, float]:
        """Apply Spanish bakery specific rules"""
        # Spanish siesta time considerations
        current_date = pd.to_datetime(features['date'])
        day_of_week = current_date.weekday()

        # Reduced activity during typical siesta hours (14:00-17:00)
        # This affects afternoon sales planning
        if day_of_week < 5:  # Weekdays
            prediction["yhat"] *= 0.95  # Slight reduction for siesta effect

        return prediction
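The band arithmetic in _add_uncertainty_bands is easy to sanity-check by hand; a small standalone worked example (the numbers are made up):

# Worked example of _add_uncertainty_bands' arithmetic
yhat, lower, upper = 100.0, 80.0, 120.0  # base Prophet outputs
uncertainty = 1.0 + 0.1 + 0.2 + 0.1      # cold day + holiday + weekend

band = (upper - lower) * uncertainty      # 40 * 1.4 = 56
adj_lower = max(0, yhat - band / 2)       # 100 - 28 = 72
adj_upper = yhat + band / 2               # 100 + 28 = 128

assert (adj_lower, adj_upper) == (72.0, 128.0)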
112 services/forecasting/app/models/forecasts.py Normal file
@@ -0,0 +1,112 @@
# ================================================================
# services/forecasting/app/models/forecasts.py
# ================================================================
"""
Forecast models for the forecasting service
"""

from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text, JSON
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime, timezone
import uuid

from shared.database.base import Base

class Forecast(Base):
    """Forecast model for storing prediction results"""
    __tablename__ = "forecasts"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    product_name = Column(String(255), nullable=False, index=True)
    location = Column(String(255), nullable=False, index=True)

    # Forecast period
    forecast_date = Column(DateTime(timezone=True), nullable=False, index=True)
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))

    # Prediction results
    predicted_demand = Column(Float, nullable=False)
    confidence_lower = Column(Float, nullable=False)
    confidence_upper = Column(Float, nullable=False)
    confidence_level = Column(Float, default=0.8)

    # Model information
    model_id = Column(UUID(as_uuid=True), nullable=False)
    model_version = Column(String(50), nullable=False)
    algorithm = Column(String(50), default="prophet")

    # Business context
    business_type = Column(String(50), default="individual")  # individual or central_workshop
    day_of_week = Column(Integer, nullable=False)
    is_holiday = Column(Boolean, default=False)
    is_weekend = Column(Boolean, default=False)

    # External factors
    weather_temperature = Column(Float)
    weather_precipitation = Column(Float)
    weather_description = Column(String(100))
    traffic_volume = Column(Integer)

    # Metadata
    processing_time_ms = Column(Integer)
    features_used = Column(JSON)

    def __repr__(self):
        return f"<Forecast(id={self.id}, product={self.product_name}, date={self.forecast_date})>"

class PredictionBatch(Base):
    """Batch prediction requests"""
    __tablename__ = "prediction_batches"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)

    # Batch information
    batch_name = Column(String(255), nullable=False)
    requested_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    completed_at = Column(DateTime(timezone=True))

    # Status
    status = Column(String(50), default="pending")  # pending, processing, completed, failed
    total_products = Column(Integer, default=0)
    completed_products = Column(Integer, default=0)
    failed_products = Column(Integer, default=0)

    # Configuration
    forecast_days = Column(Integer, default=7)
    business_type = Column(String(50), default="individual")

    # Results
    error_message = Column(Text)
    processing_time_ms = Column(Integer)

    def __repr__(self):
        return f"<PredictionBatch(id={self.id}, status={self.status})>"

class ForecastAlert(Base):
    """Alerts based on forecast results"""
    __tablename__ = "forecast_alerts"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    forecast_id = Column(UUID(as_uuid=True), nullable=False)

    # Alert information
    alert_type = Column(String(50), nullable=False)  # high_demand, low_demand, stockout_risk
    severity = Column(String(20), default="medium")  # low, medium, high, critical
    message = Column(Text, nullable=False)

    # Status
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    acknowledged_at = Column(DateTime(timezone=True))
    resolved_at = Column(DateTime(timezone=True))
    is_active = Column(Boolean, default=True)

    # Notification
    notification_sent = Column(Boolean, default=False)
    notification_method = Column(String(50))  # email, whatsapp, sms

    def __repr__(self):
        return f"<ForecastAlert(id={self.id}, type={self.alert_type})>"
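The API's get_forecasts queries filter on tenant, product, location, and date together; the single-column indexes above cover each field individually, and a composite index is a possible refinement. A hedged sketch (not in the commit; the index name is hypothetical):

from sqlalchemy import Index
from app.models.forecasts import Forecast

# Hypothetical composite index matching the combined WHERE clause; attaching
# column objects registers it on the forecasts table's metadata
ix_forecasts_lookup = Index(
    "ix_forecasts_tenant_product_date",
    Forecast.tenant_id,
    Forecast.product_name,
    Forecast.forecast_date,
)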
67 services/forecasting/app/models/predictions.py Normal file
@@ -0,0 +1,67 @@
# ================================================================
# services/forecasting/app/models/predictions.py
# ================================================================
"""
Additional prediction models for the forecasting service
"""

from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text, JSON
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime, timezone
import uuid

from shared.database.base import Base

class ModelPerformanceMetric(Base):
    """Track model performance over time"""
    __tablename__ = "model_performance_metrics"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    model_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    product_name = Column(String(255), nullable=False)

    # Performance metrics
    mae = Column(Float)   # Mean Absolute Error
    mape = Column(Float)  # Mean Absolute Percentage Error
    rmse = Column(Float)  # Root Mean Square Error
    accuracy_score = Column(Float)

    # Evaluation period
    evaluation_date = Column(DateTime(timezone=True), nullable=False)
    evaluation_period_start = Column(DateTime(timezone=True))
    evaluation_period_end = Column(DateTime(timezone=True))

    # Metadata
    sample_size = Column(Integer)
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))

    def __repr__(self):
        return f"<ModelPerformanceMetric(model_id={self.model_id}, mae={self.mae})>"

class PredictionCache(Base):
    """Cache frequently requested predictions"""
    __tablename__ = "prediction_cache"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    cache_key = Column(String(255), unique=True, nullable=False, index=True)

    # Cached data
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    product_name = Column(String(255), nullable=False)
    location = Column(String(255), nullable=False)
    forecast_date = Column(DateTime(timezone=True), nullable=False)

    # Cached results
    predicted_demand = Column(Float, nullable=False)
    confidence_lower = Column(Float, nullable=False)
    confidence_upper = Column(Float, nullable=False)
    model_id = Column(UUID(as_uuid=True), nullable=False)

    # Cache metadata
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    expires_at = Column(DateTime(timezone=True), nullable=False)
    hit_count = Column(Integer, default=0)

    def __repr__(self):
        return f"<PredictionCache(key={self.cache_key}, product={self.product_name})>"
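cache_key is a unique string, but the commit does not show how it is derived; one plausible scheme (an assumption, not from the source) hashes the request identity so it fits the 255-character column:

import hashlib

def make_cache_key(tenant_id: str, product_name: str, location: str, forecast_date: str) -> str:
    # Deterministic, collision-resistant key for PredictionCache.cache_key
    raw = f"{tenant_id}:{product_name}:{location}:{forecast_date}"
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()

print(make_cache_key("t1", "baguette", "madrid-centro", "2025-01-15"))  # illustrative values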
123 services/forecasting/app/schemas/forecasts.py Normal file
@@ -0,0 +1,123 @@
# ================================================================
# services/forecasting/app/schemas/forecasts.py
# ================================================================
"""
Forecast schemas for request/response validation
"""

from pydantic import BaseModel, Field, validator
from datetime import datetime, date
from typing import Optional, List, Dict, Any
from enum import Enum

class BusinessType(str, Enum):
    INDIVIDUAL = "individual"
    CENTRAL_WORKSHOP = "central_workshop"

class AlertType(str, Enum):
    HIGH_DEMAND = "high_demand"
    LOW_DEMAND = "low_demand"
    STOCKOUT_RISK = "stockout_risk"
    OVERPRODUCTION = "overproduction"

class ForecastRequest(BaseModel):
    """Request schema for generating forecasts"""
    tenant_id: str = Field(..., description="Tenant ID")
    product_name: str = Field(..., description="Product name")
    location: str = Field(..., description="Location identifier")
    forecast_date: date = Field(..., description="Date for which to generate forecast")
    business_type: BusinessType = Field(BusinessType.INDIVIDUAL, description="Business model type")

    # Optional context
    include_weather: bool = Field(True, description="Include weather data in forecast")
    include_traffic: bool = Field(True, description="Include traffic data in forecast")
    confidence_level: float = Field(0.8, ge=0.5, le=0.95, description="Confidence level for intervals")

    @validator('forecast_date')
    def validate_forecast_date(cls, v):
        if v < date.today():
            raise ValueError("Forecast date cannot be in the past")
        return v

class BatchForecastRequest(BaseModel):
    """Request schema for batch forecasting"""
    tenant_id: str = Field(..., description="Tenant ID")
    batch_name: str = Field(..., description="Batch name for tracking")
    products: List[str] = Field(..., description="List of product names")
    location: str = Field(..., description="Location identifier")
    forecast_days: int = Field(7, ge=1, le=30, description="Number of days to forecast")
    business_type: BusinessType = Field(BusinessType.INDIVIDUAL, description="Business model type")

    # Options
    include_weather: bool = Field(True, description="Include weather data")
    include_traffic: bool = Field(True, description="Include traffic data")
    confidence_level: float = Field(0.8, ge=0.5, le=0.95, description="Confidence level")

class ForecastResponse(BaseModel):
    """Response schema for forecast results"""
    id: str
    tenant_id: str
    product_name: str
    location: str
    forecast_date: datetime

    # Predictions
    predicted_demand: float
    confidence_lower: float
    confidence_upper: float
    confidence_level: float

    # Model info
    model_id: str
    model_version: str
    algorithm: str

    # Context
    business_type: str
    is_holiday: bool
    is_weekend: bool
    day_of_week: int

    # External factors
    weather_temperature: Optional[float]
    weather_precipitation: Optional[float]
    weather_description: Optional[str]
    traffic_volume: Optional[int]

    # Metadata
    created_at: datetime
    processing_time_ms: Optional[int]
    features_used: Optional[Dict[str, Any]]

class BatchForecastResponse(BaseModel):
    """Response schema for batch forecast requests"""
    id: str
    tenant_id: str
    batch_name: str
    status: str
    total_products: int
    completed_products: int
    failed_products: int

    # Timing
    requested_at: datetime
    completed_at: Optional[datetime]
    processing_time_ms: Optional[int]

    # Results
    forecasts: Optional[List[ForecastResponse]]
    error_message: Optional[str]

class AlertResponse(BaseModel):
    """Response schema for forecast alerts"""
    id: str
    tenant_id: str
    forecast_id: str
    alert_type: str
    severity: str
    message: str
    is_active: bool
    created_at: datetime
    acknowledged_at: Optional[datetime]
    notification_sent: bool
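The forecast_date validator rejects past dates at parse time, before any handler runs; a quick demonstration (pydantic v1 style, matching the validator import above):

from datetime import date, timedelta
from pydantic import ValidationError

from app.schemas.forecasts import ForecastRequest

try:
    ForecastRequest(
        tenant_id="t1",
        product_name="croissant",
        location="madrid-centro",
        forecast_date=date.today() - timedelta(days=1),  # in the past
    )
except ValidationError as e:
    print(e)  # reports "Forecast date cannot be in the past"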
438 services/forecasting/app/services/forecasting_service.py Normal file
@@ -0,0 +1,438 @@
# ================================================================
# services/forecasting/app/services/forecasting_service.py
# ================================================================
"""
Main forecasting service business logic

Orchestrates demand prediction operations
"""

import structlog
from typing import Dict, List, Any, Optional
from datetime import datetime, date, timedelta
import uuid

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, desc
import httpx

from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert
from app.schemas.forecasts import ForecastRequest, BatchForecastRequest, BusinessType
from app.services.prediction_service import PredictionService
from app.services.messaging import publish_forecast_completed, publish_alert_created
from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector

logger = structlog.get_logger()
metrics = MetricsCollector("forecasting-service")


class ForecastingService:
    """
    Main service class for managing forecasting operations.
    Handles demand prediction, batch processing, and alert generation.
    """
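
    # Illustrative call pattern (hypothetical caller, e.g. an API handler;
    # `db` is an AsyncSession provided by the request scope):
    #
    #   service = ForecastingService()
    #   forecast = await service.generate_forecast(request, db)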

    def __init__(self):
        self.prediction_service = PredictionService()

    async def generate_forecast(self, request: ForecastRequest, db: AsyncSession) -> Forecast:
        """Generate a single forecast for a product"""
        start_time = datetime.now()

        try:
            logger.info("Generating forecast",
                        tenant_id=request.tenant_id,
                        product=request.product_name,
                        date=request.forecast_date)

            # Get the latest trained model for this tenant/product
            model_info = await self._get_latest_model(
                request.tenant_id,
                request.product_name,
                request.location
            )

            if not model_info:
                raise ValueError(f"No trained model found for {request.product_name}")

            # Prepare features for prediction
            features = await self._prepare_forecast_features(request)

            # Generate prediction using ML service
            prediction_result = await self.prediction_service.predict(
                model_id=model_info["model_id"],
                features=features,
                confidence_level=request.confidence_level
            )

            # Create forecast record
            forecast = Forecast(
                tenant_id=uuid.UUID(request.tenant_id),
                product_name=request.product_name,
                location=request.location,
                forecast_date=datetime.combine(request.forecast_date, datetime.min.time()),

                # Prediction results
                predicted_demand=prediction_result["demand"],
                confidence_lower=prediction_result["lower_bound"],
                confidence_upper=prediction_result["upper_bound"],
                confidence_level=request.confidence_level,

                # Model information
                model_id=uuid.UUID(model_info["model_id"]),
                model_version=model_info["version"],
                algorithm=model_info.get("algorithm", "prophet"),

                # Context
                business_type=request.business_type.value,
                day_of_week=request.forecast_date.weekday(),
                is_holiday=features.get("is_holiday", False),
                is_weekend=request.forecast_date.weekday() >= 5,

                # External factors
                weather_temperature=features.get("temperature"),
                weather_precipitation=features.get("precipitation"),
                weather_description=features.get("weather_description"),
                traffic_volume=features.get("traffic_volume"),

                # Metadata
                processing_time_ms=int((datetime.now() - start_time).total_seconds() * 1000),
                features_used=features
            )

            db.add(forecast)
            await db.commit()
            await db.refresh(forecast)

            # Check for alerts
            await self._check_and_create_alerts(forecast, db)

            # Update metrics
            metrics.increment_counter("forecasts_generated_total",
                                      {"product": request.product_name, "location": request.location})

            # Publish event
            await publish_forecast_completed({
                "forecast_id": str(forecast.id),
                "tenant_id": request.tenant_id,
                "product_name": request.product_name,
                "predicted_demand": forecast.predicted_demand
            })

            logger.info("Forecast generated successfully",
                        forecast_id=str(forecast.id),
                        predicted_demand=forecast.predicted_demand)

            return forecast

        except Exception as e:
            logger.error("Error generating forecast",
                         error=str(e),
                         tenant_id=request.tenant_id,
                         product=request.product_name)
            raise

    async def generate_batch_forecast(self, request: BatchForecastRequest, db: AsyncSession) -> PredictionBatch:
        """Generate forecasts for multiple products over multiple days"""

        try:
            logger.info("Starting batch forecast generation",
                        tenant_id=request.tenant_id,
                        batch_name=request.batch_name,
                        products_count=len(request.products),
                        forecast_days=request.forecast_days)

            # Create batch record
            batch = PredictionBatch(
                tenant_id=uuid.UUID(request.tenant_id),
                batch_name=request.batch_name,
                status="processing",
                total_products=len(request.products) * request.forecast_days,
                business_type=request.business_type.value,
                forecast_days=request.forecast_days
            )

            db.add(batch)
            await db.commit()
            await db.refresh(batch)

            # Generate forecasts for each product and day
            completed_count = 0
            failed_count = 0

            for product in request.products:
                for day_offset in range(request.forecast_days):
                    forecast_date = date.today() + timedelta(days=day_offset + 1)

                    try:
                        forecast_request = ForecastRequest(
                            tenant_id=request.tenant_id,
                            product_name=product,
                            location=request.location,
                            forecast_date=forecast_date,
                            business_type=request.business_type,
                            include_weather=request.include_weather,
                            include_traffic=request.include_traffic,
                            confidence_level=request.confidence_level
                        )

                        await self.generate_forecast(forecast_request, db)
                        completed_count += 1

                    except Exception as e:
                        logger.warning("Failed to generate forecast for product",
                                       product=product,
                                       date=forecast_date,
                                       error=str(e))
                        failed_count += 1

            # Update batch status
            batch.status = "completed" if failed_count == 0 else "partial"
            batch.completed_products = completed_count
            batch.failed_products = failed_count
            batch.completed_at = datetime.now()

            await db.commit()

            logger.info("Batch forecast generation completed",
                        batch_id=str(batch.id),
                        completed=completed_count,
                        failed=failed_count)

            return batch

        except Exception as e:
            logger.error("Error in batch forecast generation", error=str(e))
            raise
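
    # Note: the batch loop above runs forecasts sequentially, which keeps the
    # shared AsyncSession safe. If throughput matters, a variant could run
    # products concurrently with one session per task; a minimal sketch,
    # assuming a `session_factory` callable that is not part of this module:
    #
    #   async def _one(req):
    #       async with session_factory() as task_db:
    #           return await self.generate_forecast(req, task_db)
    #
    #   results = await asyncio.gather(*(_one(r) for r in requests),
    #                                  return_exceptions=True)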

    async def get_forecasts(self, tenant_id: str, location: str,
                            start_date: Optional[date] = None,
                            end_date: Optional[date] = None,
                            product_name: Optional[str] = None,
                            db: AsyncSession = None) -> List[Forecast]:
        """Retrieve forecasts with filtering"""

        try:
            query = select(Forecast).where(
                and_(
                    Forecast.tenant_id == uuid.UUID(tenant_id),
                    Forecast.location == location
                )
            )

            if start_date:
                query = query.where(Forecast.forecast_date >= datetime.combine(start_date, datetime.min.time()))

            if end_date:
                query = query.where(Forecast.forecast_date <= datetime.combine(end_date, datetime.max.time()))

            if product_name:
                query = query.where(Forecast.product_name == product_name)

            query = query.order_by(desc(Forecast.forecast_date))

            result = await db.execute(query)
            forecasts = result.scalars().all()

            logger.info("Retrieved forecasts",
                        tenant_id=tenant_id,
                        count=len(forecasts))

            return list(forecasts)

        except Exception as e:
            logger.error("Error retrieving forecasts", error=str(e))
            raise

    async def _get_latest_model(self, tenant_id: str, product_name: str, location: str) -> Optional[Dict[str, Any]]:
        """Get the latest trained model for a tenant/product combination"""

        try:
            # Call training service to get model information
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/latest",
                    params={
                        "tenant_id": tenant_id,
                        "product_name": product_name,
                        "location": location
                    },
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
                )

            if response.status_code == 200:
                return response.json()
            elif response.status_code == 404:
                logger.warning("No model found",
                               tenant_id=tenant_id,
                               product=product_name)
                return None
            else:
                response.raise_for_status()

        except Exception as e:
            logger.error("Error getting latest model", error=str(e))
            raise

    async def _prepare_forecast_features(self, request: ForecastRequest) -> Dict[str, Any]:
        """Prepare features for forecasting model"""

        features = {
            "date": request.forecast_date.isoformat(),
            "day_of_week": request.forecast_date.weekday(),
            "is_weekend": request.forecast_date.weekday() >= 5,
            "business_type": request.business_type.value
        }

        # Add Spanish holidays
        features["is_holiday"] = await self._is_spanish_holiday(request.forecast_date)

        # Add weather data if requested
        if request.include_weather:
            weather_data = await self._get_weather_forecast(request.forecast_date)
            features.update(weather_data)

        # Add traffic data if requested
        if request.include_traffic:
            traffic_data = await self._get_traffic_forecast(request.forecast_date, request.location)
            features.update(traffic_data)

        return features
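
    # Example of the feature dict this method produces (hypothetical values):
    #
    #   {
    #       "date": "2024-06-01", "day_of_week": 5, "is_weekend": True,
    #       "business_type": "individual", "is_holiday": False,
    #       "temperature": 24.5, "precipitation": 0.0,
    #       "weather_description": "clear", "traffic_volume": 1200,
    #   }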

    async def _is_spanish_holiday(self, forecast_date: date) -> bool:
        """Check if date is a Spanish holiday"""

        try:
            # Call data service for holiday information
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{settings.DATA_SERVICE_URL}/api/v1/holidays/check",
                    params={"date": forecast_date.isoformat()},
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
                )

            if response.status_code == 200:
                return response.json().get("is_holiday", False)
            else:
                return False

        except Exception as e:
            logger.warning("Error checking holiday status", error=str(e))
            return False

    async def _get_weather_forecast(self, forecast_date: date) -> Dict[str, Any]:
        """Get weather forecast for the date"""

        try:
            # Call data service for weather forecast
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{settings.DATA_SERVICE_URL}/api/v1/weather/forecast",
                    params={"date": forecast_date.isoformat()},
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
                )

            if response.status_code == 200:
                weather = response.json()
                return {
                    "temperature": weather.get("temperature"),
                    "precipitation": weather.get("precipitation"),
                    "humidity": weather.get("humidity"),
                    "weather_description": weather.get("description")
                }
            else:
                return {}

        except Exception as e:
            logger.warning("Error getting weather forecast", error=str(e))
            return {}

    async def _get_traffic_forecast(self, forecast_date: date, location: str) -> Dict[str, Any]:
        """Get traffic forecast for the date and location"""

        try:
            # Call data service for traffic forecast
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{settings.DATA_SERVICE_URL}/api/v1/traffic/forecast",
                    params={
                        "date": forecast_date.isoformat(),
                        "location": location
                    },
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
                )

            if response.status_code == 200:
                traffic = response.json()
                return {
                    "traffic_volume": traffic.get("volume"),
                    "pedestrian_count": traffic.get("pedestrian_count")
                }
            else:
                return {}

        except Exception as e:
            logger.warning("Error getting traffic forecast", error=str(e))
            return {}

    async def _check_and_create_alerts(self, forecast: Forecast, db: AsyncSession):
        """Check forecast and create alerts if needed"""

        try:
            alerts_to_create = []

            # High demand alert
            if forecast.predicted_demand > settings.HIGH_DEMAND_THRESHOLD * 100:  # assuming a base of 100 units
                alerts_to_create.append({
                    "type": "high_demand",
                    "severity": "medium",
                    "message": f"High demand predicted for {forecast.product_name}: {forecast.predicted_demand:.0f} units"
                })

            # Low demand alert
            if forecast.predicted_demand < settings.LOW_DEMAND_THRESHOLD * 100:
                alerts_to_create.append({
                    "type": "low_demand",
                    "severity": "low",
                    "message": f"Low demand predicted for {forecast.product_name}: {forecast.predicted_demand:.0f} units"
                })

            # Stockout risk alert
            if forecast.confidence_upper > settings.STOCKOUT_RISK_THRESHOLD * forecast.predicted_demand:
                alerts_to_create.append({
                    "type": "stockout_risk",
                    "severity": "high",
                    "message": f"Stockout risk for {forecast.product_name}. Upper confidence: {forecast.confidence_upper:.0f}"
                })

            # Create alerts
            for alert_data in alerts_to_create:
                alert = ForecastAlert(
                    tenant_id=forecast.tenant_id,
                    forecast_id=forecast.id,
                    alert_type=alert_data["type"],
                    severity=alert_data["severity"],
                    message=alert_data["message"]
                )

                db.add(alert)
                # Flush so the alert has its primary key before it is published
                await db.flush()

                # Publish alert event
                await publish_alert_created({
                    "alert_id": str(alert.id),
                    "tenant_id": str(forecast.tenant_id),
                    "product_name": forecast.product_name,
                    "alert_type": alert_data["type"],
                    "severity": alert_data["severity"],
                    "message": alert_data["message"]
                })

            await db.commit()

            if alerts_to_create:
                logger.info("Created forecast alerts",
                            forecast_id=str(forecast.id),
                            alerts_count=len(alerts_to_create))

        except Exception as e:
            logger.error("Error creating alerts", error=str(e))
            # Don't raise - alerts are not critical for forecast generation
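
    # The three thresholds above come from app.core.config.settings; as an
    # illustration only (values hypothetical, not defined in this commit):
    #
    #   HIGH_DEMAND_THRESHOLD = 1.5    # alert above 150 units on the 100-unit base
    #   LOW_DEMAND_THRESHOLD = 0.2     # alert below 20 units
    #   STOCKOUT_RISK_THRESHOLD = 1.3  # upper bound > 130% of the point forecast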
98
services/forecasting/app/services/messaging.py
Normal file
@@ -0,0 +1,98 @@
# ================================================================
# services/forecasting/app/services/messaging.py
# ================================================================
"""
Messaging service for event publishing and consuming
"""

import structlog
from typing import Dict, Any

from shared.messaging.rabbitmq import RabbitMQPublisher, RabbitMQConsumer
from app.core.config import settings

logger = structlog.get_logger()

# Global messaging instances
publisher = None
consumer = None


async def setup_messaging():
    """Initialize messaging services"""
    global publisher, consumer

    try:
        # Initialize publisher
        publisher = RabbitMQPublisher(settings.RABBITMQ_URL)
        await publisher.connect()

        # Initialize consumer
        consumer = RabbitMQConsumer(settings.RABBITMQ_URL)
        await consumer.connect()

        # Set up event handlers
        await consumer.subscribe("training.model.updated", handle_model_updated)
        await consumer.subscribe("data.weather.updated", handle_weather_updated)

        logger.info("Messaging setup completed")

    except Exception as e:
        logger.error("Failed to setup messaging", error=str(e))
        raise


async def cleanup_messaging():
    """Cleanup messaging connections"""
    global publisher, consumer

    try:
        if consumer:
            await consumer.close()
        if publisher:
            await publisher.close()

        logger.info("Messaging cleanup completed")

    except Exception as e:
        logger.error("Error during messaging cleanup", error=str(e))


async def publish_forecast_completed(data: Dict[str, Any]):
    """Publish forecast completed event"""
    if publisher:
        await publisher.publish("forecasting.forecast.completed", data)


async def publish_alert_created(data: Dict[str, Any]):
    """Publish alert created event"""
    if publisher:
        await publisher.publish("forecasting.alert.created", data)


async def publish_batch_completed(data: Dict[str, Any]):
    """Publish batch forecast completed event"""
    if publisher:
        await publisher.publish("forecasting.batch.completed", data)


# Event handlers
async def handle_model_updated(data: Dict[str, Any]):
    """Handle model updated event from training service"""
    try:
        logger.info("Received model updated event",
                    model_id=data.get("model_id"),
                    tenant_id=data.get("tenant_id"))

        # Clear model cache for this model
        # This will be handled by PredictionService

    except Exception as e:
        logger.error("Error handling model updated event", error=str(e))


async def handle_weather_updated(data: Dict[str, Any]):
    """Handle weather data updated event"""
    try:
        logger.info("Received weather updated event",
                    date=data.get("date"))

        # Could trigger re-forecasting if needed

    except Exception as e:
        logger.error("Error handling weather updated event", error=str(e))
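
# Typical wiring (illustrative; assumes a FastAPI lifespan hook defined
# elsewhere, which is not part of this module):
#
#   @asynccontextmanager
#   async def lifespan(app):
#       await setup_messaging()
#       yield
#       await cleanup_messaging()
#
# so connections are opened at startup and closed on shutdown.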
166
services/forecasting/app/services/prediction_service.py
Normal file
@@ -0,0 +1,166 @@
# ================================================================
# services/forecasting/app/services/prediction_service.py
# ================================================================
"""
Prediction service for loading models and generating predictions

Handles the actual ML prediction logic
"""

import structlog
from typing import Dict, Any
import pickle
from datetime import datetime
import pandas as pd
import httpx

from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector

logger = structlog.get_logger()
metrics = MetricsCollector("forecasting-service")


class PredictionService:
    """
    Service for loading ML models and generating predictions
    Interfaces with trained Prophet models from the training service
    """

    def __init__(self):
        self.model_cache = {}
        self.cache_ttl = 3600  # model cache lifetime in seconds (1 hour)
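
    # The cache maps model_id -> (model, loaded_at). Entries are refreshed
    # lazily on the next predict() call after the TTL expires; the
    # "training.model.updated" event (see messaging.py) is the intended hook
    # for proactive invalidation.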

    async def predict(self, model_id: str, features: Dict[str, Any],
                      confidence_level: float = 0.8) -> Dict[str, float]:
        """Generate prediction using trained model"""

        start_time = datetime.now()

        try:
            logger.info("Generating prediction",
                        model_id=model_id,
                        features_count=len(features))

            # Load model
            model = await self._load_model(model_id)

            if not model:
                raise ValueError(f"Model {model_id} not found or failed to load")

            # Prepare features for Prophet
            df = self._prepare_prophet_features(features)

            # Generate prediction
            forecast = model.predict(df)

            # Extract prediction results
            if len(forecast) > 0:
                row = forecast.iloc[0]
                result = {
                    "demand": float(row['yhat']),
                    "lower_bound": float(row['yhat_lower']),
                    "upper_bound": float(row['yhat_upper']),
                    "trend": float(row.get('trend', 0)),
                    "seasonal": float(row.get('seasonal', 0)),
                    "holiday": float(row.get('holidays', 0))
                }
            else:
                raise ValueError("No prediction generated from model")

            # Update metrics
            processing_time = (datetime.now() - start_time).total_seconds()
            metrics.histogram_observe("forecast_processing_time_seconds", processing_time)

            logger.info("Prediction generated successfully",
                        model_id=model_id,
                        predicted_demand=result["demand"],
                        processing_time_ms=int(processing_time * 1000))

            return result

        except Exception as e:
            logger.error("Error generating prediction",
                         model_id=model_id,
                         error=str(e))
            raise

    async def _load_model(self, model_id: str):
        """Load model from cache or training service"""

        # Check cache first
        if model_id in self.model_cache:
            cached_model, cached_time = self.model_cache[model_id]
            # total_seconds() so the TTL check also holds across day boundaries
            if (datetime.now() - cached_time).total_seconds() < self.cache_ttl:
                logger.debug("Using cached model", model_id=model_id)
                return cached_model

        try:
            # Download model from training service
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.get(
                    f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/download",
                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
                )

            if response.status_code == 200:
                # Deserialize the model bytes; the payload comes from the
                # internal training service, so unpickling is trusted here
                model_data = response.content
                model = pickle.loads(model_data)

                # Cache the model
                self.model_cache[model_id] = (model, datetime.now())

                logger.info("Model loaded successfully", model_id=model_id)
                return model
            else:
                logger.error("Failed to download model",
                             model_id=model_id,
                             status_code=response.status_code)
                return None

        except Exception as e:
            logger.error("Error loading model", model_id=model_id, error=str(e))
            return None
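
    # Note: Prophet only consumes the extra columns built below if the model
    # was fitted with matching add_regressor(...) calls; columns the model was
    # not fitted with are ignored at predict time, while a fitted regressor
    # missing from the frame raises an error. The zero-fill defaults keep the
    # frame schema stable either way.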

    def _prepare_prophet_features(self, features: Dict[str, Any]) -> pd.DataFrame:
        """Convert features to Prophet-compatible DataFrame"""

        try:
            # Create base DataFrame with required 'ds' column
            df = pd.DataFrame({
                'ds': [pd.to_datetime(features['date'])]
            })

            # Add numeric features
            numeric_features = [
                'temperature', 'precipitation', 'humidity', 'wind_speed',
                'traffic_volume', 'pedestrian_count'
            ]

            for feature in numeric_features:
                if feature in features and features[feature] is not None:
                    df[feature] = float(features[feature])
                else:
                    df[feature] = 0.0

            # Add categorical features
            df['day_of_week'] = int(features.get('day_of_week', 0))
            df['is_weekend'] = int(features.get('is_weekend', False))
            df['is_holiday'] = int(features.get('is_holiday', False))

            # Business type encoding
            business_type = features.get('business_type', 'individual')
            df['is_central_workshop'] = int(business_type == 'central_workshop')

            logger.debug("Prepared Prophet features",
                         features_count=len(df.columns),
                         date=features['date'])

            return df

        except Exception as e:
            logger.error("Error preparing Prophet features", error=str(e))
            raise