Improve the UI and training

Urtzi Alfaro
2025-11-15 15:20:10 +01:00
parent c349b845a6
commit 843cd2bf5c
19 changed files with 2073 additions and 233 deletions

View File

@@ -4,7 +4,8 @@ Forecasting Operations API - Business operations for forecast generation and pre
"""
import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Request
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Request, BackgroundTasks
from fastapi.responses import JSONResponse
from typing import List, Dict, Any, Optional
from datetime import date, datetime, timezone
import uuid
@@ -202,6 +203,97 @@ async def generate_multi_day_forecast(
)
async def execute_batch_forecast_background(
tenant_id: str,
batch_id: str,
inventory_product_ids: List[str],
forecast_days: int,
batch_name: str
):
"""
Background task for batch forecast generation.
Keeps long-running batch generation off the request path so API calls don't block or time out.
"""
logger.info("Starting background batch forecast",
batch_id=batch_id,
tenant_id=tenant_id,
product_count=len(inventory_product_ids))
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
forecasting_service = EnhancedForecastingService(database_manager)
try:
# Update batch status to running
async with database_manager.get_session() as session:
from app.repositories import PredictionBatchRepository
batch_repo = PredictionBatchRepository(session)
await batch_repo.update(
batch_id,
{"status": "processing", "completed_products": 0}
)
await session.commit()
# Generate forecasts for all products
from app.schemas.forecasts import BatchForecastRequest
batch_request = BatchForecastRequest(
tenant_id=tenant_id,
batch_name=batch_name,
inventory_product_ids=inventory_product_ids,
forecast_days=forecast_days
)
result = await forecasting_service.generate_batch_forecasts(
tenant_id=tenant_id,
request=batch_request
)
# Update batch status to completed
async with database_manager.get_session() as session:
from app.repositories import PredictionBatchRepository
batch_repo = PredictionBatchRepository(session)
await batch_repo.update(
batch_id,
{
"status": "completed",
"completed_at": datetime.now(timezone.utc),
"completed_products": result.get("successful_forecasts", 0),
"failed_products": result.get("failed_forecasts", 0)
}
)
await session.commit()
logger.info("Background batch forecast completed",
batch_id=batch_id,
successful=result.get("successful_forecasts", 0),
failed=result.get("failed_forecasts", 0))
except Exception as e:
logger.error("Background batch forecast failed",
batch_id=batch_id,
error=str(e))
try:
async with database_manager.get_session() as session:
from app.repositories import PredictionBatchRepository
batch_repo = PredictionBatchRepository(session)
await batch_repo.update(
batch_id,
{
"status": "failed",
"completed_at": datetime.now(timezone.utc),
"error_message": str(e)
}
)
await session.commit()
except Exception as update_error:
logger.error("Failed to update batch status after error",
batch_id=batch_id,
error=str(update_error))
@router.post(
route_builder.build_operations_route("batch"),
response_model=BatchForecastResponse
@@ -211,11 +303,17 @@ async def generate_multi_day_forecast(
async def generate_batch_forecast(
request: BatchForecastRequest,
tenant_id: str = Path(..., description="Tenant ID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
request_obj: Request = None,
current_user: dict = Depends(get_current_user_dep),
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Generate forecasts for multiple products in batch (Admin+ only, quota enforced)"""
"""
Generate forecasts for multiple products in batch (Admin+ only, quota enforced).
IMPROVEMENT: Now uses background tasks for large batches to prevent API timeouts.
Returns immediately with batch_id for status tracking.
"""
metrics = get_metrics_collector(request_obj)
try:
@@ -258,48 +356,104 @@ async def generate_batch_forecast(
error_message=None
)
# Skip rate limiting for service-to-service calls (orchestrator)
# Rate limiting is handled at the gateway level for user requests
# IMPROVEMENT: For large batches (>5 products), use background task
# For small batches, execute synchronously for immediate results
batch_name = getattr(request, 'batch_name', f"batch-{datetime.now().strftime('%Y%m%d_%H%M%S')}")
forecast_days = getattr(request, 'forecast_days', 7)
# Create a copy of the request with the actual list of product IDs to forecast
# (whether originally provided or fetched from inventory service)
from app.schemas.forecasts import BatchForecastRequest
updated_request = BatchForecastRequest(
tenant_id=tenant_id, # Use the tenant_id from the path parameter
batch_name=getattr(request, 'batch_name', f"orchestrator-batch-{datetime.now().strftime('%Y%m%d')}"),
inventory_product_ids=inventory_product_ids,
forecast_days=getattr(request, 'forecast_days', 7)
)
batch_result = await enhanced_forecasting_service.generate_batch_forecasts(
tenant_id=tenant_id,
request=updated_request
)
if metrics:
metrics.increment_counter("batch_forecasts_success_total")
logger.info("Batch forecast generated successfully",
tenant_id=tenant_id,
total_forecasts=batch_result.get('total_forecasts', 0))
# Convert the service result to BatchForecastResponse format
from app.schemas.forecasts import BatchForecastResponse
# Create batch record first
batch_id = str(uuid.uuid4())
now = datetime.now(timezone.utc)
return BatchForecastResponse(
id=batch_result.get('id', str(uuid.uuid4())), # Use 'id' field (UUID) instead of 'batch_id' (string)
tenant_id=tenant_id,
batch_name=updated_request.batch_name,
status="completed",
total_products=batch_result.get('total_forecasts', 0),
completed_products=batch_result.get('successful_forecasts', 0),
failed_products=batch_result.get('failed_forecasts', 0),
requested_at=now,
completed_at=now,
processing_time_ms=0,
forecasts=[],
error_message=None
)
async with enhanced_forecasting_service.database_manager.get_session() as session:
from app.repositories import PredictionBatchRepository
batch_repo = PredictionBatchRepository(session)
batch_data = {
"tenant_id": tenant_id,
"batch_name": batch_name,
"total_products": len(inventory_product_ids),
"forecast_days": forecast_days,
"status": "pending"
}
batch = await batch_repo.create_batch(batch_data)
batch_id = str(batch.id)
await session.commit()
# Use background task for large batches to prevent API timeout
use_background = len(inventory_product_ids) > 5
if use_background:
# Queue background task
background_tasks.add_task(
execute_batch_forecast_background,
tenant_id=tenant_id,
batch_id=batch_id,
inventory_product_ids=inventory_product_ids,
forecast_days=forecast_days,
batch_name=batch_name
)
logger.info("Batch forecast queued for background processing",
tenant_id=tenant_id,
batch_id=batch_id,
product_count=len(inventory_product_ids))
# Return immediately with pending status
from app.schemas.forecasts import BatchForecastResponse
return BatchForecastResponse(
id=batch_id,
tenant_id=tenant_id,
batch_name=batch_name,
status="pending",
total_products=len(inventory_product_ids),
completed_products=0,
failed_products=0,
requested_at=now,
completed_at=None,
processing_time_ms=0,
forecasts=None,
error_message=None
)
else:
# Small batch - execute synchronously
from app.schemas.forecasts import BatchForecastRequest
updated_request = BatchForecastRequest(
tenant_id=tenant_id,
batch_name=batch_name,
inventory_product_ids=inventory_product_ids,
forecast_days=forecast_days
)
batch_result = await enhanced_forecasting_service.generate_batch_forecasts(
tenant_id=tenant_id,
request=updated_request
)
if metrics:
metrics.increment_counter("batch_forecasts_success_total")
logger.info("Batch forecast completed synchronously",
tenant_id=tenant_id,
total_forecasts=batch_result.get('total_forecasts', 0))
# Convert the service result to BatchForecastResponse format
from app.schemas.forecasts import BatchForecastResponse
return BatchForecastResponse(
id=batch_id,
tenant_id=tenant_id,
batch_name=batch_name,
status="completed",
total_products=batch_result.get('total_forecasts', 0),
completed_products=batch_result.get('successful_forecasts', 0),
failed_products=batch_result.get('failed_forecasts', 0),
requested_at=now,
completed_at=datetime.now(timezone.utc),
processing_time_ms=0,
forecasts=[],
error_message=None
)
except ValueError as e:
if metrics:
@@ -806,3 +960,50 @@ async def preview_tenant_data_deletion(
status_code=500,
detail=f"Failed to preview tenant data deletion: {str(e)}"
)
@router.get("/health/database")
async def database_health():
"""
Database health check endpoint with connection pool monitoring.
Returns detailed connection pool statistics for monitoring and alerting.
Useful for detecting connection pool exhaustion before it causes issues.
"""
from app.core.database import get_db_health, get_connection_pool_stats
from datetime import datetime
try:
# Check database connectivity
db_healthy = await get_db_health()
# Get connection pool statistics
pool_stats = await get_connection_pool_stats()
response = {
"service": "forecasting",
"timestamp": datetime.now(timezone.utc).isoformat(),
"database_connected": db_healthy,
"connection_pool": pool_stats,
"overall_status": "healthy" if db_healthy and pool_stats.get("status") == "healthy" else "degraded"
}
# Return appropriate status code based on health
if not db_healthy or pool_stats.get("status") == "critical":
return JSONResponse(status_code=503, content=response)
elif pool_stats.get("status") == "warning":
return JSONResponse(status_code=200, content=response)
else:
return response
except Exception as e:
logger.error("Health check failed", error=str(e))
return JSONResponse(
status_code=503,
content={
"service": "forecasting",
"timestamp": datetime.now(timezone.utc).isoformat(),
"overall_status": "unhealthy",
"error": str(e)
}
)
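
Aside: the sync/background split in generate_batch_forecast above is FastAPI's stock BackgroundTasks pattern. A minimal, self-contained sketch of the same flow, where the in-memory batch_store and run_batch are hypothetical stand-ins for this service's prediction_batches table and forecasting logic:

import uuid
from fastapi import BackgroundTasks, FastAPI

app = FastAPI()
batch_store: dict = {}  # hypothetical stand-in for the prediction_batches table

async def run_batch(batch_id: str, product_ids: list[str]) -> None:
    # Executed by Starlette after the HTTP response has been sent, so a slow
    # batch cannot hold the request open until a gateway timeout.
    batch_store[batch_id] = "completed"

@app.post("/batches")
async def create_batch(product_ids: list[str], background_tasks: BackgroundTasks):
    batch_id = str(uuid.uuid4())
    batch_store[batch_id] = "pending"
    if len(product_ids) > 5:
        # Large batch: queue the work, return a pending id the client can poll.
        background_tasks.add_task(run_batch, batch_id, product_ids)
        return {"id": batch_id, "status": "pending"}
    # Small batch: run inline and return the final status immediately.
    await run_batch(batch_id, product_ids)
    return {"id": batch_id, "status": batch_store[batch_id]}
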

View File

@@ -70,6 +70,47 @@ async def get_db_health() -> bool:
logger.error("Database health check failed", error=str(e))
return False
async def get_connection_pool_stats() -> dict:
"""
Get current connection pool statistics for monitoring.
Returns:
Dictionary with pool statistics including usage and capacity
"""
try:
pool = async_engine.pool
# Get pool stats
stats = {
"pool_size": pool.size(),
"checked_in_connections": pool.checkedin(),
"checked_out_connections": pool.checkedout(),
"overflow_connections": pool.overflow(),
"total_connections": pool.size() + pool.overflow(),
"max_capacity": 10 + 20, # pool_size + max_overflow
"usage_percentage": round(((pool.size() + pool.overflow()) / 30) * 100, 2)
}
# Add health status
if stats["usage_percentage"] > 90:
stats["status"] = "critical"
stats["message"] = "Connection pool near capacity"
elif stats["usage_percentage"] > 80:
stats["status"] = "warning"
stats["message"] = "Connection pool usage high"
else:
stats["status"] = "healthy"
stats["message"] = "Connection pool healthy"
return stats
except Exception as e:
logger.error("Failed to get connection pool stats", error=str(e))
return {
"status": "error",
"message": f"Failed to get pool stats: {str(e)}"
}
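
One caveat in get_connection_pool_stats: max_capacity (10 + 20) and the / 30 divisor are hardcoded, so they silently drift if the engine's pool_size or max_overflow ever change. A sketch of the same computation driven by configuration instead; the two constants are assumptions standing in for this service's settings, and QueuePool's checkedout() is used as the exhaustion signal:

POOL_SIZE = 10     # assumed to mirror create_async_engine(pool_size=...)
MAX_OVERFLOW = 20  # assumed to mirror create_async_engine(max_overflow=...)

def pool_usage(engine) -> dict:
    pool = engine.pool                   # QueuePool under the default config
    in_use = pool.checkedout()           # connections handed out right now
    capacity = POOL_SIZE + MAX_OVERFLOW  # ceiling before callers start blocking
    return {
        "checked_out": in_use,
        "capacity": capacity,
        "usage_percentage": round(in_use / capacity * 100, 2),
    }
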
# Database manager instance for service_base compatibility
database_manager = DatabaseManager(
database_url=settings.DATABASE_URL,

View File

@@ -5,7 +5,7 @@
Forecast models for the forecasting service
"""
from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text, JSON
from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text, JSON, UniqueConstraint, Index
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime, timezone
import uuid
@@ -15,7 +15,18 @@ from shared.database.base import Base
class Forecast(Base):
"""Forecast model for storing prediction results"""
__tablename__ = "forecasts"
__table_args__ = (
# Unique constraint to prevent duplicate forecasts
# Ensures only one forecast per (tenant, product, date, location) combination
UniqueConstraint(
'tenant_id', 'inventory_product_id', 'forecast_date', 'location',
name='uq_forecast_tenant_product_date_location'
),
# Composite index for common query patterns
Index('ix_forecasts_tenant_product_date', 'tenant_id', 'inventory_product_id', 'forecast_date'),
)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
inventory_product_id = Column(UUID(as_uuid=True), nullable=False, index=True) # Reference to inventory service
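
Note that declaring UniqueConstraint and Index on the model only affects freshly created schemas; an existing forecasts table needs a migration. A hypothetical Alembic sketch (revision identifiers are placeholders, and any rows already violating the constraint would have to be deduplicated first):

from alembic import op

revision = "forecast_unique_idx"  # placeholder revision id
down_revision = None              # placeholder

def upgrade() -> None:
    op.create_unique_constraint(
        "uq_forecast_tenant_product_date_location",
        "forecasts",
        ["tenant_id", "inventory_product_id", "forecast_date", "location"],
    )
    op.create_index(
        "ix_forecasts_tenant_product_date",
        "forecasts",
        ["tenant_id", "inventory_product_id", "forecast_date"],
    )

def downgrade() -> None:
    op.drop_index("ix_forecasts_tenant_product_date", table_name="forecasts")
    op.drop_constraint(
        "uq_forecast_tenant_product_date_location", "forecasts", type_="unique"
    )
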

View File

@@ -6,6 +6,7 @@ Repository for forecast operations
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc, func
from sqlalchemy.exc import IntegrityError
from datetime import datetime, timedelta, date, timezone
import structlog
@@ -24,18 +25,24 @@ class ForecastRepository(ForecastingBaseRepository):
super().__init__(Forecast, session, cache_ttl)
async def create_forecast(self, forecast_data: Dict[str, Any]) -> Forecast:
"""Create a new forecast with validation"""
"""
Create a new forecast with validation.
Handles duplicate forecast race condition gracefully:
If a forecast already exists for the same (tenant, product, date, location),
it will be updated instead of creating a duplicate.
"""
try:
# Validate forecast data
validation_result = self._validate_forecast_data(
forecast_data,
["tenant_id", "inventory_product_id", "location", "forecast_date",
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid forecast data: {validation_result['errors']}")
# Set default values
if "confidence_level" not in forecast_data:
forecast_data["confidence_level"] = 0.8
@@ -43,26 +50,109 @@ class ForecastRepository(ForecastingBaseRepository):
forecast_data["algorithm"] = "prophet"
if "business_type" not in forecast_data:
forecast_data["business_type"] = "individual"
# Create forecast
forecast = await self.create(forecast_data)
logger.info("Forecast created successfully",
forecast_id=forecast.id,
tenant_id=forecast.tenant_id,
inventory_product_id=forecast.inventory_product_id,
forecast_date=forecast.forecast_date.isoformat())
return forecast
# Try to create forecast
try:
forecast = await self.create(forecast_data)
logger.info("Forecast created successfully",
forecast_id=forecast.id,
tenant_id=forecast.tenant_id,
inventory_product_id=forecast.inventory_product_id,
forecast_date=forecast.forecast_date.isoformat())
return forecast
except IntegrityError as ie:
# Handle unique constraint violation (duplicate forecast)
error_msg = str(ie).lower()
if "unique constraint" in error_msg or "duplicate" in error_msg or "uq_forecast_tenant_product_date_location" in error_msg:
logger.warning("Forecast already exists (race condition), updating instead",
tenant_id=forecast_data.get("tenant_id"),
inventory_product_id=forecast_data.get("inventory_product_id"),
forecast_date=str(forecast_data.get("forecast_date")))
# Rollback the failed insert
await self.session.rollback()
# Fetch the existing forecast
existing_forecast = await self.get_existing_forecast(
tenant_id=forecast_data["tenant_id"],
inventory_product_id=forecast_data["inventory_product_id"],
forecast_date=forecast_data["forecast_date"],
location=forecast_data["location"]
)
if existing_forecast:
# Update existing forecast with new prediction data
update_data = {
"predicted_demand": forecast_data["predicted_demand"],
"confidence_lower": forecast_data["confidence_lower"],
"confidence_upper": forecast_data["confidence_upper"],
"confidence_level": forecast_data.get("confidence_level", 0.8),
"model_id": forecast_data["model_id"],
"model_version": forecast_data.get("model_version"),
"algorithm": forecast_data.get("algorithm", "prophet"),
"processing_time_ms": forecast_data.get("processing_time_ms"),
"features_used": forecast_data.get("features_used"),
"weather_temperature": forecast_data.get("weather_temperature"),
"weather_precipitation": forecast_data.get("weather_precipitation"),
"weather_description": forecast_data.get("weather_description"),
}
updated_forecast = await self.update(str(existing_forecast.id), update_data)
logger.info("Existing forecast updated after duplicate detection",
forecast_id=updated_forecast.id,
tenant_id=updated_forecast.tenant_id,
inventory_product_id=updated_forecast.inventory_product_id)
return updated_forecast
else:
# This shouldn't happen, but log it
logger.error("Duplicate forecast detected but not found in database")
raise DatabaseError("Duplicate forecast detected but not found")
else:
# Different integrity error, re-raise
raise
except ValidationError:
raise
except IntegrityError as ie:
# Re-raise integrity errors that weren't handled above
logger.error("Database integrity error creating forecast",
tenant_id=forecast_data.get("tenant_id"),
error=str(ie))
raise DatabaseError(f"Database integrity error: {str(ie)}")
except Exception as e:
logger.error("Failed to create forecast",
tenant_id=forecast_data.get("tenant_id"),
inventory_product_id=forecast_data.get("inventory_product_id"),
error=str(e))
raise DatabaseError(f"Failed to create forecast: {str(e)}")
async def get_existing_forecast(
self,
tenant_id: str,
inventory_product_id: str,
forecast_date: datetime,
location: str
) -> Optional[Forecast]:
"""Get an existing forecast by unique key (tenant, product, date, location)"""
try:
query = select(Forecast).where(
and_(
Forecast.tenant_id == tenant_id,
Forecast.inventory_product_id == inventory_product_id,
Forecast.forecast_date == forecast_date,
Forecast.location == location
)
)
result = await self.session.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error("Failed to get existing forecast", error=str(e))
return None
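
The catch-IntegrityError-then-update flow above works, but the same race can also be settled atomically in a single statement with PostgreSQL's INSERT ... ON CONFLICT. A sketch of that alternative (not what this commit does), assuming SQLAlchemy 2.x and the Forecast model with the constraint defined earlier:

from sqlalchemy.dialects.postgresql import insert as pg_insert

async def upsert_forecast(session, forecast_data: dict):
    stmt = pg_insert(Forecast).values(**forecast_data)
    stmt = stmt.on_conflict_do_update(
        constraint="uq_forecast_tenant_product_date_location",
        set_={
            # On conflict, overwrite the prediction with the newer values.
            "predicted_demand": stmt.excluded.predicted_demand,
            "confidence_lower": stmt.excluded.confidence_lower,
            "confidence_upper": stmt.excluded.confidence_upper,
            "model_id": stmt.excluded.model_id,
        },
    ).returning(Forecast)
    result = await session.execute(stmt)
    return result.scalar_one()
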
async def get_forecasts_by_date_range(
self,

View File

@@ -15,6 +15,7 @@ from app.schemas.forecasts import ForecastRequest, ForecastResponse
from app.services.prediction_service import PredictionService
from app.services.model_client import ModelClient
from app.services.data_client import DataClient
from app.utils.distributed_lock import get_forecast_lock, get_batch_forecast_lock, LockAcquisitionError
# Import repositories
from app.repositories import (
@@ -291,107 +292,165 @@ class EnhancedForecastingService:
) -> ForecastResponse:
"""
Generate forecast using repository pattern with caching.
CRITICAL FIXES:
1. External HTTP calls are performed BEFORE opening database session
to prevent connection pool exhaustion and blocking.
2. Advisory locks prevent concurrent forecast generation for same product/date
to avoid duplicate work and race conditions.
"""
start_time = datetime.now(timezone.utc)
try:
logger.info("Generating enhanced forecast",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
date=request.forecast_date.isoformat())
# Get session and initialize repositories
# CRITICAL FIX: Get model BEFORE opening database session
# This prevents holding database connections during potentially slow external API calls
logger.debug("Fetching model data before opening database session",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id)
model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
if not model_data:
raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
logger.debug("Model data fetched successfully",
tenant_id=tenant_id,
model_id=model_data.get('model_id'))
# Prepare features with fallbacks before opening the session (includes external weather API calls)
features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)
# Now open database session AFTER external HTTP calls are complete
# CRITICAL FIX: Acquire distributed lock to prevent concurrent forecast generation
async with self.database_manager.get_background_session() as session:
repos = await self._init_repositories(session)
# Step 1: Check cache first
cached_prediction = await repos['cache'].get_cached_prediction(
tenant_id, request.inventory_product_id, request.location, request.forecast_date
)
if cached_prediction:
logger.debug("Using cached prediction",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id)
return self._create_forecast_response_from_cache(cached_prediction)
# Step 2: Get model with validation
model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
if not model_data:
raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
# Step 3: Prepare features with fallbacks
features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)
# Step 4: Generate prediction
prediction_result = await self.prediction_service.predict(
model_id=model_data['model_id'],
model_path=model_data['model_path'],
features=features,
confidence_level=request.confidence_level
)
# Step 5: Apply business rules
adjusted_prediction = self._apply_business_rules(
prediction_result, request, features
)
# Step 6: Save forecast using repository
# Convert forecast_date to datetime if it's a string
forecast_datetime = request.forecast_date
if isinstance(forecast_datetime, str):
from dateutil.parser import parse
forecast_datetime = parse(forecast_datetime)
forecast_data = {
"tenant_id": tenant_id,
"inventory_product_id": request.inventory_product_id,
"product_name": None, # Field is now nullable, use inventory_product_id as reference
"location": request.location,
"forecast_date": forecast_datetime,
"predicted_demand": adjusted_prediction['prediction'],
"confidence_lower": adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
"confidence_upper": adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
"confidence_level": request.confidence_level,
"model_id": model_data['model_id'],
"model_version": str(model_data.get('version', '1.0')),
"algorithm": model_data.get('algorithm', 'prophet'),
"business_type": features.get('business_type', 'individual'),
"is_holiday": features.get('is_holiday', False),
"is_weekend": features.get('is_weekend', False),
"day_of_week": features.get('day_of_week', 0),
"weather_temperature": features.get('temperature'),
"weather_precipitation": features.get('precipitation'),
"weather_description": features.get('weather_description'),
"traffic_volume": features.get('traffic_volume'),
"processing_time_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"features_used": features
}
forecast = await repos['forecast'].create_forecast(forecast_data)
# Step 6: Cache the prediction
await repos['cache'].cache_prediction(
# Get lock for this specific forecast (tenant + product + date)
forecast_date_str = request.forecast_date.isoformat().split('T')[0] if hasattr(request.forecast_date, 'isoformat') else str(request.forecast_date).split('T')[0]
lock = get_forecast_lock(
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
location=request.location,
forecast_date=forecast_datetime,
predicted_demand=adjusted_prediction['prediction'],
confidence_lower=adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
confidence_upper=adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
model_id=model_data['model_id'],
expires_in_hours=24
product_id=str(request.inventory_product_id),
forecast_date=forecast_date_str
)
logger.info("Enhanced forecast generated successfully",
forecast_id=forecast.id,
tenant_id=tenant_id,
prediction=adjusted_prediction['prediction'])
return self._create_forecast_response_from_model(forecast)
try:
async with lock.acquire(session):
repos = await self._init_repositories(session)
# Step 1: Check cache first (inside lock for consistency)
# If another request generated the forecast while we waited for the lock,
# we'll find it in the cache
cached_prediction = await repos['cache'].get_cached_prediction(
tenant_id, request.inventory_product_id, request.location, request.forecast_date
)
if cached_prediction:
logger.info("Found cached prediction after acquiring lock (concurrent request completed first)",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id)
return self._create_forecast_response_from_cache(cached_prediction)
# Step 2: Model data already fetched above (before session opened)
# Step 4: Generate prediction (in-memory operation)
prediction_result = await self.prediction_service.predict(
model_id=model_data['model_id'],
model_path=model_data['model_path'],
features=features,
confidence_level=request.confidence_level
)
# Step 5: Apply business rules
adjusted_prediction = self._apply_business_rules(
prediction_result, request, features
)
# Step 6: Save forecast using repository
# Convert forecast_date to datetime if it's a string
forecast_datetime = request.forecast_date
if isinstance(forecast_datetime, str):
from dateutil.parser import parse
forecast_datetime = parse(forecast_datetime)
forecast_data = {
"tenant_id": tenant_id,
"inventory_product_id": request.inventory_product_id,
"product_name": None, # Field is now nullable, use inventory_product_id as reference
"location": request.location,
"forecast_date": forecast_datetime,
"predicted_demand": adjusted_prediction['prediction'],
"confidence_lower": adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
"confidence_upper": adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
"confidence_level": request.confidence_level,
"model_id": model_data['model_id'],
"model_version": str(model_data.get('version', '1.0')),
"algorithm": model_data.get('algorithm', 'prophet'),
"business_type": features.get('business_type', 'individual'),
"is_holiday": features.get('is_holiday', False),
"is_weekend": features.get('is_weekend', False),
"day_of_week": features.get('day_of_week', 0),
"weather_temperature": features.get('temperature'),
"weather_precipitation": features.get('precipitation'),
"weather_description": features.get('weather_description'),
"traffic_volume": features.get('traffic_volume'),
"processing_time_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"features_used": features
}
forecast = await repos['forecast'].create_forecast(forecast_data)
await session.commit()
# Step 7: Cache the prediction
await repos['cache'].cache_prediction(
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
location=request.location,
forecast_date=forecast_datetime,
predicted_demand=adjusted_prediction['prediction'],
confidence_lower=adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
confidence_upper=adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
model_id=model_data['model_id'],
expires_in_hours=24
)
logger.info("Enhanced forecast generated successfully",
forecast_id=forecast.id,
tenant_id=tenant_id,
prediction=adjusted_prediction['prediction'])
return self._create_forecast_response_from_model(forecast)
except LockAcquisitionError:
# Could not acquire lock - another forecast request is in progress
logger.warning("Could not acquire forecast lock, checking cache for concurrent request result",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
forecast_date=forecast_date_str)
# Wait a moment and check cache - maybe the concurrent request finished
await asyncio.sleep(1)
repos = await self._init_repositories(session)
cached_prediction = await repos['cache'].get_cached_prediction(
tenant_id, request.inventory_product_id, request.location, request.forecast_date
)
if cached_prediction:
logger.info("Found forecast in cache after lock timeout (concurrent request completed)",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id)
return self._create_forecast_response_from_cache(cached_prediction)
# No cached result, raise error
raise ValueError(
f"Forecast generation already in progress for product {request.inventory_product_id}. "
"Please try again in a few seconds."
)
except Exception as e:
processing_time = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
logger.error("Error generating enhanced forecast",

View File

@@ -0,0 +1,3 @@
"""
Utility modules for forecasting service
"""

View File

@@ -0,0 +1,258 @@
"""
Distributed Locking Mechanisms for Forecasting Service
Prevents concurrent forecast generation for the same product/date
"""
import asyncio
import hashlib
import time
from typing import Optional
import logging
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timezone, timedelta
logger = logging.getLogger(__name__)
class LockAcquisitionError(Exception):
"""Raised when lock cannot be acquired"""
pass
class DatabaseLock:
"""
Database-based distributed lock using PostgreSQL advisory locks.
Works across multiple service instances.
"""
def __init__(self, lock_name: str, timeout: float = 30.0):
"""
Initialize database lock.
Args:
lock_name: Unique identifier for the lock
timeout: Maximum seconds to wait for lock acquisition
"""
self.lock_name = lock_name
self.timeout = timeout
self.lock_id = self._hash_lock_name(lock_name)
def _hash_lock_name(self, name: str) -> int:
"""Convert lock name to a stable, positive 32-bit integer ID for PostgreSQL advisory locks"""
# Python's built-in hash() is salted per process (PYTHONHASHSEED), so it would
# map the same lock name to different IDs on different service instances,
# silently breaking the cross-instance lock. Use a deterministic digest instead.
digest = hashlib.sha256(name.encode("utf-8")).digest()
return int.from_bytes(digest[:4], "big") % (2**31)
@asynccontextmanager
async def acquire(self, session: AsyncSession):
"""
Acquire distributed lock as async context manager.
Args:
session: Database session for lock operations
Raises:
LockAcquisitionError: If lock cannot be acquired within timeout
"""
acquired = False
start_time = time.time()
try:
# Try to acquire lock with timeout
while time.time() - start_time < self.timeout:
# Try non-blocking lock acquisition
result = await session.execute(
text("SELECT pg_try_advisory_lock(:lock_id)"),
{"lock_id": self.lock_id}
)
acquired = result.scalar()
if acquired:
logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
break
# Wait a bit before retrying
await asyncio.sleep(0.1)
if not acquired:
raise LockAcquisitionError(
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
)
yield
finally:
if acquired:
# Release lock
await session.execute(
text("SELECT pg_advisory_unlock(:lock_id)"),
{"lock_id": self.lock_id}
)
logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")
class SimpleDatabaseLock:
"""
Simple table-based distributed lock.
Alternative to advisory locks, uses a dedicated locks table.
"""
def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
"""
Initialize simple database lock.
Args:
lock_name: Unique identifier for the lock
timeout: Maximum seconds to wait for lock acquisition
ttl: Time-to-live for stale lock cleanup (seconds)
"""
self.lock_name = lock_name
self.timeout = timeout
self.ttl = ttl
async def _ensure_lock_table(self, session: AsyncSession):
"""Ensure locks table exists"""
create_table_sql = """
CREATE TABLE IF NOT EXISTS distributed_locks (
lock_name VARCHAR(255) PRIMARY KEY,
acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
acquired_by VARCHAR(255),
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
)
"""
await session.execute(text(create_table_sql))
await session.commit()
async def _cleanup_stale_locks(self, session: AsyncSession):
"""Remove expired locks"""
cleanup_sql = """
DELETE FROM distributed_locks
WHERE expires_at < :now
"""
await session.execute(
text(cleanup_sql),
{"now": datetime.now(timezone.utc)}
)
await session.commit()
@asynccontextmanager
async def acquire(self, session: AsyncSession, owner: str = "forecasting-service"):
"""
Acquire simple database lock.
Args:
session: Database session
owner: Identifier for lock owner
Raises:
LockAcquisitionError: If lock cannot be acquired
"""
await self._ensure_lock_table(session)
await self._cleanup_stale_locks(session)
acquired = False
start_time = time.time()
try:
# Try to acquire lock
while time.time() - start_time < self.timeout:
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=self.ttl)
try:
# Try to insert lock record
insert_sql = """
INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
ON CONFLICT (lock_name) DO NOTHING
RETURNING lock_name
"""
result = await session.execute(
text(insert_sql),
{
"lock_name": self.lock_name,
"acquired_at": now,
"acquired_by": owner,
"expires_at": expires_at
}
)
await session.commit()
if result.rowcount > 0:
acquired = True
logger.info(f"Acquired simple lock: {self.lock_name}")
break
except Exception as e:
logger.debug(f"Lock acquisition attempt failed: {e}")
await session.rollback()
# Wait before retrying
await asyncio.sleep(0.5)
if not acquired:
raise LockAcquisitionError(
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
)
yield
finally:
if acquired:
# Release lock
delete_sql = """
DELETE FROM distributed_locks
WHERE lock_name = :lock_name
"""
await session.execute(
text(delete_sql),
{"lock_name": self.lock_name}
)
await session.commit()
logger.info(f"Released simple lock: {self.lock_name}")
def get_forecast_lock(
tenant_id: str,
product_id: str,
forecast_date: str,
use_advisory: bool = True
) -> "DatabaseLock | SimpleDatabaseLock":
"""
Get distributed lock for generating a forecast for a specific product and date.
Args:
tenant_id: Tenant identifier
product_id: Product identifier
forecast_date: Forecast date (ISO format)
use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)
Returns:
Lock instance
"""
lock_name = f"forecast:{tenant_id}:{product_id}:{forecast_date}"
if use_advisory:
return DatabaseLock(lock_name, timeout=30.0)
else:
return SimpleDatabaseLock(lock_name, timeout=30.0, ttl=300.0)
def get_batch_forecast_lock(tenant_id: str, use_advisory: bool = True) -> "DatabaseLock | SimpleDatabaseLock":
"""
Get distributed lock for batch forecast generation for a tenant.
Args:
tenant_id: Tenant identifier
use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)
Returns:
Lock instance
"""
lock_name = f"forecast_batch:{tenant_id}"
if use_advisory:
return DatabaseLock(lock_name, timeout=60.0)
else:
return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)
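
Finally, a usage sketch for these helpers: serializing forecast generation for one (tenant, product, date) across service instances. Here session is assumed to be an open AsyncSession, and do_generate is a caller-supplied placeholder for the real generation logic:

from app.utils.distributed_lock import LockAcquisitionError, get_forecast_lock

async def generate_once(session, tenant_id: str, product_id: str, day: str, do_generate):
    """Run do_generate(session) at most once per (tenant, product, day) across instances."""
    lock = get_forecast_lock(tenant_id=tenant_id, product_id=product_id, forecast_date=day)
    try:
        async with lock.acquire(session):
            # Holder of the advisory lock: safe to generate without creating
            # duplicate forecasts from concurrent requests.
            return await do_generate(session)
    except LockAcquisitionError:
        # Another instance is generating this forecast right now; the caller
        # can retry shortly or fall back to the cache, as the service does.
        return None
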