Improve the UI and training
This commit is contained in:
@@ -15,6 +15,7 @@ from app.schemas.forecasts import ForecastRequest, ForecastResponse
|
||||
from app.services.prediction_service import PredictionService
|
||||
from app.services.model_client import ModelClient
|
||||
from app.services.data_client import DataClient
|
||||
from app.utils.distributed_lock import get_forecast_lock, get_batch_forecast_lock, LockAcquisitionError
|
||||
|
||||
# Import repositories
|
||||
from app.repositories import (
|
||||
@@ -291,107 +292,165 @@ class EnhancedForecastingService:
|
||||
) -> ForecastResponse:
|
||||
"""
|
||||
Generate forecast using repository pattern with caching.
|
||||
|
||||
CRITICAL FIXES:
|
||||
1. External HTTP calls are performed BEFORE opening database session
|
||||
to prevent connection pool exhaustion and blocking.
|
||||
2. Advisory locks prevent concurrent forecast generation for same product/date
|
||||
to avoid duplicate work and race conditions.
|
||||
"""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
try:
|
||||
logger.info("Generating enhanced forecast",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
date=request.forecast_date.isoformat())
|
||||
|
||||
# Get session and initialize repositories
|
||||
|
||||
# CRITICAL FIX: Get model BEFORE opening database session
|
||||
# This prevents holding database connections during potentially slow external API calls
|
||||
logger.debug("Fetching model data before opening database session",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=request.inventory_product_id)
|
||||
|
||||
model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
|
||||
|
||||
if not model_data:
|
||||
raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
|
||||
|
||||
logger.debug("Model data fetched successfully",
|
||||
tenant_id=tenant_id,
|
||||
model_id=model_data.get('model_id'))
|
||||
|
||||
# Step 3: Prepare features with fallbacks (includes external API calls for weather)
|
||||
features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)
|
||||
|
||||
# Now open database session AFTER external HTTP calls are complete
|
||||
# CRITICAL FIX: Acquire distributed lock to prevent concurrent forecast generation
|
||||
async with self.database_manager.get_background_session() as session:
|
||||
repos = await self._init_repositories(session)
|
||||
|
||||
# Step 1: Check cache first
|
||||
cached_prediction = await repos['cache'].get_cached_prediction(
|
||||
tenant_id, request.inventory_product_id, request.location, request.forecast_date
|
||||
)
|
||||
|
||||
if cached_prediction:
|
||||
logger.debug("Using cached prediction",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=request.inventory_product_id)
|
||||
return self._create_forecast_response_from_cache(cached_prediction)
|
||||
|
||||
# Step 2: Get model with validation
|
||||
model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
|
||||
|
||||
if not model_data:
|
||||
raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
|
||||
|
||||
# Step 3: Prepare features with fallbacks
|
||||
features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)
|
||||
|
||||
# Step 4: Generate prediction
|
||||
prediction_result = await self.prediction_service.predict(
|
||||
model_id=model_data['model_id'],
|
||||
model_path=model_data['model_path'],
|
||||
features=features,
|
||||
confidence_level=request.confidence_level
|
||||
)
|
||||
|
||||
# Step 5: Apply business rules
|
||||
adjusted_prediction = self._apply_business_rules(
|
||||
prediction_result, request, features
|
||||
)
|
||||
|
||||
# Step 6: Save forecast using repository
|
||||
# Convert forecast_date to datetime if it's a string
|
||||
forecast_datetime = request.forecast_date
|
||||
if isinstance(forecast_datetime, str):
|
||||
from dateutil.parser import parse
|
||||
forecast_datetime = parse(forecast_datetime)
|
||||
|
||||
forecast_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": request.inventory_product_id,
|
||||
"product_name": None, # Field is now nullable, use inventory_product_id as reference
|
||||
"location": request.location,
|
||||
"forecast_date": forecast_datetime,
|
||||
"predicted_demand": adjusted_prediction['prediction'],
|
||||
"confidence_lower": adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
|
||||
"confidence_upper": adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
|
||||
"confidence_level": request.confidence_level,
|
||||
"model_id": model_data['model_id'],
|
||||
"model_version": str(model_data.get('version', '1.0')),
|
||||
"algorithm": model_data.get('algorithm', 'prophet'),
|
||||
"business_type": features.get('business_type', 'individual'),
|
||||
"is_holiday": features.get('is_holiday', False),
|
||||
"is_weekend": features.get('is_weekend', False),
|
||||
"day_of_week": features.get('day_of_week', 0),
|
||||
"weather_temperature": features.get('temperature'),
|
||||
"weather_precipitation": features.get('precipitation'),
|
||||
"weather_description": features.get('weather_description'),
|
||||
"traffic_volume": features.get('traffic_volume'),
|
||||
"processing_time_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
|
||||
"features_used": features
|
||||
}
|
||||
|
||||
forecast = await repos['forecast'].create_forecast(forecast_data)
|
||||
|
||||
# Step 6: Cache the prediction
|
||||
await repos['cache'].cache_prediction(
|
||||
# Get lock for this specific forecast (tenant + product + date)
|
||||
forecast_date_str = request.forecast_date.isoformat().split('T')[0] if hasattr(request.forecast_date, 'isoformat') else str(request.forecast_date).split('T')[0]
|
||||
lock = get_forecast_lock(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
location=request.location,
|
||||
forecast_date=forecast_datetime,
|
||||
predicted_demand=adjusted_prediction['prediction'],
|
||||
confidence_lower=adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
|
||||
confidence_upper=adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
|
||||
model_id=model_data['model_id'],
|
||||
expires_in_hours=24
|
||||
product_id=str(request.inventory_product_id),
|
||||
forecast_date=forecast_date_str
|
||||
)
|
||||
|
||||
|
||||
logger.info("Enhanced forecast generated successfully",
|
||||
forecast_id=forecast.id,
|
||||
tenant_id=tenant_id,
|
||||
prediction=adjusted_prediction['prediction'])
|
||||
|
||||
return self._create_forecast_response_from_model(forecast)
|
||||
|
||||
|
||||
try:
|
||||
async with lock.acquire(session):
|
||||
repos = await self._init_repositories(session)
|
||||
|
||||
# Step 1: Check cache first (inside lock for consistency)
|
||||
# If another request generated the forecast while we waited for the lock,
|
||||
# we'll find it in the cache
|
||||
cached_prediction = await repos['cache'].get_cached_prediction(
|
||||
tenant_id, request.inventory_product_id, request.location, request.forecast_date
|
||||
)
|
||||
|
||||
if cached_prediction:
|
||||
logger.info("Found cached prediction after acquiring lock (concurrent request completed first)",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=request.inventory_product_id)
|
||||
return self._create_forecast_response_from_cache(cached_prediction)
|
||||
|
||||
# Steps 2-3: Model data and features were already fetched above (before the session was opened)
|
||||
|
||||
# Step 4: Generate prediction (in-memory operation)
|
||||
prediction_result = await self.prediction_service.predict(
|
||||
model_id=model_data['model_id'],
|
||||
model_path=model_data['model_path'],
|
||||
features=features,
|
||||
confidence_level=request.confidence_level
|
||||
)
|
||||
|
||||
# Step 5: Apply business rules
|
||||
adjusted_prediction = self._apply_business_rules(
|
||||
prediction_result, request, features
|
||||
)
|
||||
|
||||
# Step 6: Save forecast using repository
|
||||
# Convert forecast_date to datetime if it's a string
|
||||
forecast_datetime = request.forecast_date
|
||||
if isinstance(forecast_datetime, str):
|
||||
from dateutil.parser import parse
|
||||
forecast_datetime = parse(forecast_datetime)
|
||||
|
||||
forecast_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": request.inventory_product_id,
|
||||
"product_name": None, # Field is now nullable, use inventory_product_id as reference
|
||||
"location": request.location,
|
||||
"forecast_date": forecast_datetime,
|
||||
"predicted_demand": adjusted_prediction['prediction'],
|
||||
"confidence_lower": adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
|
||||
"confidence_upper": adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
|
||||
"confidence_level": request.confidence_level,
|
||||
"model_id": model_data['model_id'],
|
||||
"model_version": str(model_data.get('version', '1.0')),
|
||||
"algorithm": model_data.get('algorithm', 'prophet'),
|
||||
"business_type": features.get('business_type', 'individual'),
|
||||
"is_holiday": features.get('is_holiday', False),
|
||||
"is_weekend": features.get('is_weekend', False),
|
||||
"day_of_week": features.get('day_of_week', 0),
|
||||
"weather_temperature": features.get('temperature'),
|
||||
"weather_precipitation": features.get('precipitation'),
|
||||
"weather_description": features.get('weather_description'),
|
||||
"traffic_volume": features.get('traffic_volume'),
|
||||
"processing_time_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
|
||||
"features_used": features
|
||||
}
|
||||
|
||||
forecast = await repos['forecast'].create_forecast(forecast_data)
|
||||
await session.commit()
|
||||
|
||||
# Step 7: Cache the prediction
|
||||
await repos['cache'].cache_prediction(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
location=request.location,
|
||||
forecast_date=forecast_datetime,
|
||||
predicted_demand=adjusted_prediction['prediction'],
|
||||
confidence_lower=adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
|
||||
confidence_upper=adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
|
||||
model_id=model_data['model_id'],
|
||||
expires_in_hours=24
|
||||
)
|
||||
|
||||
|
||||
logger.info("Enhanced forecast generated successfully",
|
||||
forecast_id=forecast.id,
|
||||
tenant_id=tenant_id,
|
||||
prediction=adjusted_prediction['prediction'])
|
||||
|
||||
return self._create_forecast_response_from_model(forecast)
|
||||
|
||||
except LockAcquisitionError:
|
||||
# Could not acquire lock - another forecast request is in progress
|
||||
logger.warning("Could not acquire forecast lock, checking cache for concurrent request result",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
forecast_date=forecast_date_str)
|
||||
|
||||
# Wait a moment and check cache - maybe the concurrent request finished
|
||||
await asyncio.sleep(1)
|
||||
|
||||
repos = await self._init_repositories(session)
|
||||
cached_prediction = await repos['cache'].get_cached_prediction(
|
||||
tenant_id, request.inventory_product_id, request.location, request.forecast_date
|
||||
)
|
||||
|
||||
if cached_prediction:
|
||||
logger.info("Found forecast in cache after lock timeout (concurrent request completed)",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=request.inventory_product_id)
|
||||
return self._create_forecast_response_from_cache(cached_prediction)
|
||||
|
||||
# No cached result, raise error
|
||||
raise ValueError(
|
||||
f"Forecast generation already in progress for product {request.inventory_product_id}. "
|
||||
"Please try again in a few seconds."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
processing_time = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
|
||||
logger.error("Error generating enhanced forecast",
|
||||
|
||||
Reference in New Issue
Block a user