Improve the UI and training

This commit is contained in:
Urtzi Alfaro
2025-11-15 15:20:10 +01:00
parent c349b845a6
commit 843cd2bf5c
19 changed files with 2073 additions and 233 deletions

View File

@@ -15,6 +15,7 @@ from app.schemas.forecasts import ForecastRequest, ForecastResponse
from app.services.prediction_service import PredictionService
from app.services.model_client import ModelClient
from app.services.data_client import DataClient
from app.utils.distributed_lock import get_forecast_lock, get_batch_forecast_lock, LockAcquisitionError
# Import repositories
from app.repositories import (
@@ -291,107 +292,165 @@ class EnhancedForecastingService:
) -> ForecastResponse:
"""
Generate forecast using repository pattern with caching.
CRITICAL FIXES:
1. External HTTP calls are performed BEFORE opening database session
to prevent connection pool exhaustion and blocking.
2. Advisory locks prevent concurrent forecast generation for same product/date
to avoid duplicate work and race conditions.
"""
start_time = datetime.now(timezone.utc)
try:
logger.info("Generating enhanced forecast",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
date=request.forecast_date.isoformat())
# Get session and initialize repositories
# CRITICAL FIX: Get model BEFORE opening database session
# This prevents holding database connections during potentially slow external API calls
logger.debug("Fetching model data before opening database session",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id)
model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
if not model_data:
raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
logger.debug("Model data fetched successfully",
tenant_id=tenant_id,
model_id=model_data.get('model_id'))
# Step 3: Prepare features with fallbacks (includes external API calls for weather)
features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)
# Now open database session AFTER external HTTP calls are complete
# CRITICAL FIX: Acquire distributed lock to prevent concurrent forecast generation
async with self.database_manager.get_background_session() as session:
repos = await self._init_repositories(session)
# Step 1: Check cache first
cached_prediction = await repos['cache'].get_cached_prediction(
tenant_id, request.inventory_product_id, request.location, request.forecast_date
)
if cached_prediction:
logger.debug("Using cached prediction",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id)
return self._create_forecast_response_from_cache(cached_prediction)
# Step 2: Get model with validation
model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
if not model_data:
raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
# Step 3: Prepare features with fallbacks
features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)
# Step 4: Generate prediction
prediction_result = await self.prediction_service.predict(
model_id=model_data['model_id'],
model_path=model_data['model_path'],
features=features,
confidence_level=request.confidence_level
)
# Step 5: Apply business rules
adjusted_prediction = self._apply_business_rules(
prediction_result, request, features
)
# Step 6: Save forecast using repository
# Convert forecast_date to datetime if it's a string
forecast_datetime = request.forecast_date
if isinstance(forecast_datetime, str):
from dateutil.parser import parse
forecast_datetime = parse(forecast_datetime)
forecast_data = {
"tenant_id": tenant_id,
"inventory_product_id": request.inventory_product_id,
"product_name": None, # Field is now nullable, use inventory_product_id as reference
"location": request.location,
"forecast_date": forecast_datetime,
"predicted_demand": adjusted_prediction['prediction'],
"confidence_lower": adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
"confidence_upper": adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
"confidence_level": request.confidence_level,
"model_id": model_data['model_id'],
"model_version": str(model_data.get('version', '1.0')),
"algorithm": model_data.get('algorithm', 'prophet'),
"business_type": features.get('business_type', 'individual'),
"is_holiday": features.get('is_holiday', False),
"is_weekend": features.get('is_weekend', False),
"day_of_week": features.get('day_of_week', 0),
"weather_temperature": features.get('temperature'),
"weather_precipitation": features.get('precipitation'),
"weather_description": features.get('weather_description'),
"traffic_volume": features.get('traffic_volume'),
"processing_time_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"features_used": features
}
forecast = await repos['forecast'].create_forecast(forecast_data)
# Step 6: Cache the prediction
await repos['cache'].cache_prediction(
# Get lock for this specific forecast (tenant + product + date)
forecast_date_str = request.forecast_date.isoformat().split('T')[0] if hasattr(request.forecast_date, 'isoformat') else str(request.forecast_date).split('T')[0]
lock = get_forecast_lock(
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
location=request.location,
forecast_date=forecast_datetime,
predicted_demand=adjusted_prediction['prediction'],
confidence_lower=adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
confidence_upper=adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
model_id=model_data['model_id'],
expires_in_hours=24
product_id=str(request.inventory_product_id),
forecast_date=forecast_date_str
)
logger.info("Enhanced forecast generated successfully",
forecast_id=forecast.id,
tenant_id=tenant_id,
prediction=adjusted_prediction['prediction'])
return self._create_forecast_response_from_model(forecast)
try:
async with lock.acquire(session):
repos = await self._init_repositories(session)
# Step 1: Check cache first (inside lock for consistency)
# If another request generated the forecast while we waited for the lock,
# we'll find it in the cache
cached_prediction = await repos['cache'].get_cached_prediction(
tenant_id, request.inventory_product_id, request.location, request.forecast_date
)
if cached_prediction:
logger.info("Found cached prediction after acquiring lock (concurrent request completed first)",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id)
return self._create_forecast_response_from_cache(cached_prediction)
# Step 2: Model data already fetched above (before session opened)
# Step 4: Generate prediction (in-memory operation)
prediction_result = await self.prediction_service.predict(
model_id=model_data['model_id'],
model_path=model_data['model_path'],
features=features,
confidence_level=request.confidence_level
)
# Step 5: Apply business rules
adjusted_prediction = self._apply_business_rules(
prediction_result, request, features
)
# Step 6: Save forecast using repository
# Convert forecast_date to datetime if it's a string
forecast_datetime = request.forecast_date
if isinstance(forecast_datetime, str):
from dateutil.parser import parse
forecast_datetime = parse(forecast_datetime)
forecast_data = {
"tenant_id": tenant_id,
"inventory_product_id": request.inventory_product_id,
"product_name": None, # Field is now nullable, use inventory_product_id as reference
"location": request.location,
"forecast_date": forecast_datetime,
"predicted_demand": adjusted_prediction['prediction'],
"confidence_lower": adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
"confidence_upper": adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
"confidence_level": request.confidence_level,
"model_id": model_data['model_id'],
"model_version": str(model_data.get('version', '1.0')),
"algorithm": model_data.get('algorithm', 'prophet'),
"business_type": features.get('business_type', 'individual'),
"is_holiday": features.get('is_holiday', False),
"is_weekend": features.get('is_weekend', False),
"day_of_week": features.get('day_of_week', 0),
"weather_temperature": features.get('temperature'),
"weather_precipitation": features.get('precipitation'),
"weather_description": features.get('weather_description'),
"traffic_volume": features.get('traffic_volume'),
"processing_time_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"features_used": features
}
forecast = await repos['forecast'].create_forecast(forecast_data)
await session.commit()
# Step 7: Cache the prediction
await repos['cache'].cache_prediction(
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
location=request.location,
forecast_date=forecast_datetime,
predicted_demand=adjusted_prediction['prediction'],
confidence_lower=adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
confidence_upper=adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
model_id=model_data['model_id'],
expires_in_hours=24
)
logger.info("Enhanced forecast generated successfully",
forecast_id=forecast.id,
tenant_id=tenant_id,
prediction=adjusted_prediction['prediction'])
return self._create_forecast_response_from_model(forecast)
except LockAcquisitionError:
# Could not acquire lock - another forecast request is in progress
logger.warning("Could not acquire forecast lock, checking cache for concurrent request result",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
forecast_date=forecast_date_str)
# Wait a moment and check cache - maybe the concurrent request finished
await asyncio.sleep(1)
repos = await self._init_repositories(session)
cached_prediction = await repos['cache'].get_cached_prediction(
tenant_id, request.inventory_product_id, request.location, request.forecast_date
)
if cached_prediction:
logger.info("Found forecast in cache after lock timeout (concurrent request completed)",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id)
return self._create_forecast_response_from_cache(cached_prediction)
# No cached result, raise error
raise ValueError(
f"Forecast generation already in progress for product {request.inventory_product_id}. "
"Please try again in a few seconds."
)
except Exception as e:
processing_time = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
logger.error("Error generating enhanced forecast",