Add improvements
This commit is contained in:
@@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import date, datetime, timezone
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.forecasting_service import EnhancedForecastingService
|
||||
from app.services.prediction_service import PredictionService
|
||||
@@ -42,6 +43,30 @@ async def get_rate_limiter():
|
||||
return create_rate_limiter(redis_client)
|
||||
|
||||
|
||||
def validate_uuid(value: str, field_name: str = "ID") -> str:
    """
    Validate that a string is a valid UUID.

    Args:
        value: The string to validate
        field_name: Name of the field for error messages

    Returns:
        The validated UUID string, exactly as it was passed in

    Raises:
        HTTPException: 400 Bad Request if the value is not a valid UUID
    """
    try:
        UUID(value)
    except (ValueError, AttributeError, TypeError):
        # ValueError: malformed hex string.
        # AttributeError: non-string input without .replace (e.g. an int).
        # TypeError: None or bytes input -- previously escaped as a 500.
        # `from None` keeps the parsing traceback out of the HTTP error chain.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"{field_name} must be a valid UUID, got: {value}"
        ) from None
    return value
|
||||
|
||||
|
||||
def get_enhanced_forecasting_service():
|
||||
"""Dependency injection for EnhancedForecastingService"""
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
|
||||
@@ -68,6 +93,10 @@ async def generate_single_forecast(
|
||||
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
||||
):
|
||||
"""Generate a single product forecast with caching support"""
|
||||
# Validate UUID fields
|
||||
validate_uuid(tenant_id, "tenant_id")
|
||||
# inventory_product_id already validated by ForecastRequest schema
|
||||
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
|
||||
@@ -28,15 +28,6 @@ router = APIRouter(prefix="/internal/demo", tags=["internal"])
|
||||
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
|
||||
|
||||
|
||||
def verify_internal_api_key(x_internal_api_key: Optional[str] = Header(None)):
    """
    Verify internal API key for service-to-service communication.

    Args:
        x_internal_api_key: Value of the internal API key request header,
            or None when the header is absent.

    Returns:
        True when the supplied key matches settings.INTERNAL_API_KEY.

    Raises:
        HTTPException: 403 when the header is missing, the key does not
            match, or no usable key is configured.
    """
    import hmac

    from app.core.config import settings

    expected_key = settings.INTERNAL_API_KEY

    # Fail closed: with a plain `!=` an unset key (None) compared equal to a
    # missing header (None) and the request was let through.
    authorized = (
        isinstance(expected_key, str)
        and expected_key != ""
        and isinstance(x_internal_api_key, str)
        # Constant-time comparison; bytes are used so non-ASCII header
        # values cannot make compare_digest raise TypeError.
        and hmac.compare_digest(
            x_internal_api_key.encode("utf-8"),
            expected_key.encode("utf-8"),
        )
    )

    if not authorized:
        logger.warning("Unauthorized internal API access attempted")
        raise HTTPException(status_code=403, detail="Invalid internal API key")
    return True
|
||||
|
||||
|
||||
def parse_date_field(date_value, session_time: datetime, field_name: str = "date") -> Optional[datetime]:
|
||||
"""
|
||||
Parse date field, handling both ISO strings and BASE_TS markers.
|
||||
@@ -98,8 +89,7 @@ async def clone_demo_data(
|
||||
demo_account_type: str,
|
||||
session_id: Optional[str] = None,
|
||||
session_created_at: Optional[str] = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: bool = Depends(verify_internal_api_key)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Clone forecasting service data for a virtual demo tenant
|
||||
@@ -406,7 +396,7 @@ async def clone_demo_data(
|
||||
|
||||
|
||||
@router.get("/clone/health")
|
||||
async def clone_health_check(_: bool = Depends(verify_internal_api_key)):
|
||||
async def clone_health_check():
|
||||
"""
|
||||
Health check for internal cloning endpoint
|
||||
Used by orchestrator to verify service availability
|
||||
@@ -421,8 +411,7 @@ async def clone_health_check(_: bool = Depends(verify_internal_api_key)):
|
||||
@router.delete("/tenant/{virtual_tenant_id}")
|
||||
async def delete_demo_tenant_data(
|
||||
virtual_tenant_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: bool = Depends(verify_internal_api_key)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Delete all demo data for a virtual tenant.
|
||||
|
||||
@@ -9,6 +9,7 @@ from pydantic import BaseModel, Field, validator
|
||||
from datetime import datetime, date
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
class BusinessType(str, Enum):
|
||||
INDIVIDUAL = "individual"
|
||||
@@ -22,10 +23,19 @@ class ForecastRequest(BaseModel):
|
||||
forecast_date: date = Field(..., description="Starting date for forecast")
|
||||
forecast_days: int = Field(1, ge=1, le=30, description="Number of days to forecast")
|
||||
location: str = Field(..., description="Location identifier")
|
||||
|
||||
|
||||
# Optional parameters - internally handled
|
||||
confidence_level: float = Field(0.8, ge=0.5, le=0.95, description="Confidence level")
|
||||
|
||||
|
||||
@validator('inventory_product_id')
def validate_inventory_product_id(cls, v):
    """Reject an inventory_product_id that cannot be parsed as a UUID.

    Returns the value unchanged when it parses; raises ValueError otherwise.
    """
    try:
        UUID(v)
    except (ValueError, AttributeError):
        message = f"inventory_product_id must be a valid UUID, got: {v}"
        raise ValueError(message)
    else:
        return v
|
||||
|
||||
@validator('forecast_date')
|
||||
def validate_forecast_date(cls, v):
|
||||
if v < date.today():
|
||||
@@ -39,6 +49,26 @@ class BatchForecastRequest(BaseModel):
|
||||
inventory_product_ids: List[str] = Field(..., description="List of inventory product IDs")
|
||||
forecast_days: int = Field(7, ge=1, le=30, description="Number of days to forecast")
|
||||
|
||||
@validator('tenant_id')
def validate_tenant_id(cls, v):
    """Check that tenant_id, when supplied, parses as a UUID.

    None is passed through untouched; any other value is returned unchanged
    after it parses, and a ValueError is raised when it does not.
    """
    # Nothing to validate when the tenant was not supplied.
    if v is None:
        return v

    try:
        UUID(v)
    except (ValueError, AttributeError):
        message = f"tenant_id must be a valid UUID, got: {v}"
        raise ValueError(message)
    return v
|
||||
|
||||
@validator('inventory_product_ids')
def validate_inventory_product_ids(cls, v):
    """Check that every entry of inventory_product_ids parses as a UUID.

    Entries are checked in order and the first invalid one aborts validation
    with a ValueError naming it; otherwise the list is returned unchanged.
    """
    for product_id in v:
        try:
            UUID(product_id)
        except (ValueError, AttributeError):
            message = f"All inventory_product_ids must be valid UUIDs, got invalid: {product_id}"
            raise ValueError(message)
    return v
|
||||
|
||||
class ForecastResponse(BaseModel):
|
||||
"""Response schema for forecast results"""
|
||||
id: str
|
||||
|
||||
@@ -498,29 +498,117 @@ class EnhancedForecastingService:
|
||||
weather_date = str(weather_date).split('T')[0]
|
||||
weather_map[weather_date] = weather
|
||||
|
||||
# Generate a forecast for each day
|
||||
for day_offset in range(request.forecast_days):
|
||||
# Calculate the forecast date for this day
|
||||
current_date = request.forecast_date
|
||||
if isinstance(current_date, str):
|
||||
from dateutil.parser import parse
|
||||
current_date = parse(current_date).date()
|
||||
# PERFORMANCE FIX: Get model ONCE before loop (not per day)
|
||||
model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
|
||||
|
||||
if day_offset > 0:
|
||||
current_date = current_date + timedelta(days=day_offset)
|
||||
if not model_data:
|
||||
raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
|
||||
|
||||
# Create a new request for this specific day
|
||||
daily_request = ForecastRequest(
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
forecast_date=current_date,
|
||||
forecast_days=1, # Single day for each iteration
|
||||
location=request.location,
|
||||
confidence_level=request.confidence_level
|
||||
)
|
||||
# PERFORMANCE FIX: Open single database session for batch operations
|
||||
async with self.database_manager.get_background_session() as session:
|
||||
repos = await self._init_repositories(session)
|
||||
|
||||
# Generate forecast for this day, passing the weather data map
|
||||
daily_forecast = await self.generate_forecast_with_weather_map(tenant_id, daily_request, weather_map)
|
||||
forecasts.append(daily_forecast)
|
||||
# Generate predictions for all days (in-memory, no DB writes yet)
|
||||
forecast_data_list = []
|
||||
|
||||
for day_offset in range(request.forecast_days):
|
||||
# Calculate the forecast date for this day
|
||||
current_date = request.forecast_date
|
||||
if isinstance(current_date, str):
|
||||
from dateutil.parser import parse
|
||||
current_date = parse(current_date).date()
|
||||
|
||||
if day_offset > 0:
|
||||
current_date = current_date + timedelta(days=day_offset)
|
||||
|
||||
# Create a new request for this specific day
|
||||
daily_request = ForecastRequest(
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
forecast_date=current_date,
|
||||
forecast_days=1,
|
||||
location=request.location,
|
||||
confidence_level=request.confidence_level
|
||||
)
|
||||
|
||||
# Check cache first
|
||||
forecast_datetime = current_date
|
||||
if isinstance(forecast_datetime, str):
|
||||
from dateutil.parser import parse
|
||||
forecast_datetime = parse(forecast_datetime)
|
||||
|
||||
cached_prediction = await repos['cache'].get_cached_prediction(
|
||||
tenant_id, request.inventory_product_id, request.location, forecast_datetime
|
||||
)
|
||||
|
||||
if cached_prediction:
|
||||
forecasts.append(self._create_forecast_response_from_cache(cached_prediction))
|
||||
continue
|
||||
|
||||
# Prepare features for this day
|
||||
features = await self._prepare_forecast_features_with_fallbacks_and_weather_map(
|
||||
tenant_id, daily_request, weather_map
|
||||
)
|
||||
|
||||
# Generate prediction (model already loaded and cached in prediction_service)
|
||||
prediction_result = await self.prediction_service.predict(
|
||||
model_id=model_data['model_id'],
|
||||
model_path=model_data['model_path'],
|
||||
features=features,
|
||||
confidence_level=request.confidence_level
|
||||
)
|
||||
|
||||
# Apply business rules
|
||||
adjusted_prediction = self._apply_business_rules(
|
||||
prediction_result, daily_request, features
|
||||
)
|
||||
|
||||
# Prepare forecast data for batch insert
|
||||
forecast_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": request.inventory_product_id,
|
||||
"product_name": None,
|
||||
"location": request.location,
|
||||
"forecast_date": forecast_datetime,
|
||||
"predicted_demand": adjusted_prediction['prediction'],
|
||||
"confidence_lower": adjusted_prediction.get('lower_bound', max(0.0, float(adjusted_prediction.get('prediction') or 0.0) * 0.8)),
|
||||
"confidence_upper": adjusted_prediction.get('upper_bound', max(0.0, float(adjusted_prediction.get('prediction') or 0.0) * 1.2)),
|
||||
"confidence_level": request.confidence_level,
|
||||
"model_id": model_data['model_id'],
|
||||
"model_version": str(model_data.get('version', '1.0')),
|
||||
"algorithm": model_data.get('algorithm', 'prophet'),
|
||||
"business_type": features.get('business_type', 'individual'),
|
||||
"is_holiday": features.get('is_holiday', False),
|
||||
"is_weekend": features.get('is_weekend', False),
|
||||
"day_of_week": features.get('day_of_week', 0),
|
||||
"weather_temperature": features.get('temperature'),
|
||||
"weather_precipitation": features.get('precipitation'),
|
||||
"weather_description": features.get('weather_description'),
|
||||
"traffic_volume": features.get('traffic_volume'),
|
||||
"processing_time_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
|
||||
"features_used": features
|
||||
}
|
||||
forecast_data_list.append((forecast_data, adjusted_prediction, features))
|
||||
|
||||
# PERFORMANCE FIX: Batch insert all forecasts in one transaction
|
||||
for forecast_data, adjusted_prediction, features in forecast_data_list:
|
||||
forecast = await repos['forecast'].create_forecast(forecast_data)
|
||||
forecasts.append(self._create_forecast_response_from_model(forecast))
|
||||
|
||||
# Cache predictions
|
||||
await repos['cache'].cache_prediction(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
location=request.location,
|
||||
forecast_date=forecast_data['forecast_date'],
|
||||
predicted_demand=adjusted_prediction['prediction'],
|
||||
confidence_lower=adjusted_prediction.get('lower_bound', max(0.0, float(adjusted_prediction.get('prediction') or 0.0) * 0.8)),
|
||||
confidence_upper=adjusted_prediction.get('upper_bound', max(0.0, float(adjusted_prediction.get('prediction') or 0.0) * 1.2)),
|
||||
model_id=model_data['model_id'],
|
||||
expires_in_hours=24
|
||||
)
|
||||
|
||||
# Commit all inserts at once
|
||||
await session.commit()
|
||||
|
||||
# Calculate summary statistics
|
||||
total_demand = sum(f.predicted_demand for f in forecasts)
|
||||
@@ -1140,13 +1228,33 @@ class EnhancedForecastingService:
|
||||
return self._is_spanish_national_holiday(date_obj)
|
||||
|
||||
def _is_spanish_national_holiday(self, date_obj: date) -> bool:
    """
    Check if a date is a Spanish national holiday.

    The `holidays` library is consulted first so that movable holidays are
    covered. If the library cannot be imported or the lookup fails, a
    hardcoded list of fixed-date national holidays is used instead.

    Args:
        date_obj: The calendar date to check.

    Returns:
        True if the date is recognised as a holiday, False otherwise.
    """
    try:
        import holidays

        # No tenant location is available here, so only the country-level
        # Spanish calendar is consulted (no regional subdivision).
        es_holidays = holidays.Spain(years=date_obj.year)
        return date_obj in es_holidays
    except Exception as e:
        # Deliberate best-effort: any library problem degrades to the
        # fixed-date list below rather than failing the forecast.
        logger.warning(f"Failed to check holidays library, using fallback: {e}")

    # Fallback to hardcoded national holidays if library fails
    spanish_holidays = {
        (1, 1),    # New Year's Day
        (1, 6),    # Epiphany
        (5, 1),    # Labour Day
        (8, 15),   # Assumption of Mary
        (10, 12),  # National Day
        (11, 1),   # All Saints' Day
        (12, 6),   # Constitution Day
        (12, 8),   # Immaculate Conception
        (12, 25),  # Christmas Day
    }
    return (date_obj.month, date_obj.day) in spanish_holidays
|
||||
|
||||
def _apply_business_rules(
|
||||
self,
|
||||
@@ -1156,19 +1264,53 @@ class EnhancedForecastingService:
|
||||
) -> Dict[str, float]:
|
||||
"""Apply Spanish bakery business rules to predictions"""
|
||||
base_prediction = prediction["prediction"]
|
||||
|
||||
|
||||
# Ensure confidence bounds exist with fallbacks
|
||||
lower_bound = prediction.get("lower_bound", base_prediction * 0.8)
|
||||
upper_bound = prediction.get("upper_bound", base_prediction * 1.2)
|
||||
|
||||
|
||||
# Apply adjustment factors
|
||||
adjustment_factor = 1.0
|
||||
|
||||
if features.get("is_weekend", False):
|
||||
adjustment_factor *= 0.8
|
||||
|
||||
if features.get("is_holiday", False):
|
||||
adjustment_factor *= 0.5
|
||||
|
||||
# Get business context
|
||||
is_weekend = features.get("is_weekend", False)
|
||||
is_holiday = features.get("is_holiday", False)
|
||||
business_type = request.business_type if hasattr(request, 'business_type') else None
|
||||
|
||||
# Determine location type from POI features if business_type not explicitly set
|
||||
location_type = business_type
|
||||
if not location_type:
|
||||
# Use POI features to infer location type
|
||||
office_poi = features.get("poi_offices_total_count", 0)
|
||||
residential_poi = features.get("poi_residential_total_count", 0)
|
||||
|
||||
if office_poi > residential_poi * 2:
|
||||
location_type = "office_area"
|
||||
elif residential_poi > office_poi * 2:
|
||||
location_type = "residential_area"
|
||||
else:
|
||||
location_type = "mixed_area"
|
||||
|
||||
# Handle weekend + holiday combination based on location
|
||||
if is_weekend and is_holiday:
|
||||
# Special case: location-based logic for holiday weekends
|
||||
if location_type == "office_area":
|
||||
# Office areas: huge reduction (offices closed)
|
||||
adjustment_factor *= 0.3 # 70% reduction
|
||||
elif location_type == "residential_area":
|
||||
# Residential areas: increase (family gatherings)
|
||||
adjustment_factor *= 1.2 # 20% increase
|
||||
else:
|
||||
# Mixed or unknown: moderate reduction
|
||||
adjustment_factor *= 0.5 # 50% reduction
|
||||
else:
|
||||
# Regular weekend (no holiday)
|
||||
if is_weekend:
|
||||
adjustment_factor *= 0.8
|
||||
|
||||
# Regular holiday (not weekend)
|
||||
if is_holiday:
|
||||
adjustment_factor *= 0.5
|
||||
|
||||
# Weather adjustments
|
||||
precipitation = features.get("precipitation", 0.0)
|
||||
|
||||
@@ -195,7 +195,46 @@ class PredictionService:
|
||||
|
||||
# Prepare features for Prophet model
|
||||
prophet_df = self._prepare_prophet_features(features)
|
||||
|
||||
|
||||
# CRITICAL FIX: Validate that model's required regressors are present
|
||||
# Warn if using default values for features the model was trained with
|
||||
if hasattr(model, 'extra_regressors'):
|
||||
model_regressors = set(model.extra_regressors.keys()) if model.extra_regressors else set()
|
||||
provided_features = set(prophet_df.columns) - {'ds'}
|
||||
|
||||
# Check for missing regressors
|
||||
missing_regressors = model_regressors - provided_features
|
||||
|
||||
if missing_regressors:
|
||||
logger.warning(
|
||||
"Model trained with regressors that are missing in prediction",
|
||||
model_id=model_id,
|
||||
missing_regressors=list(missing_regressors)[:10], # Log first 10
|
||||
total_missing=len(missing_regressors)
|
||||
)
|
||||
|
||||
# Check for default-valued critical features
|
||||
critical_features = {
|
||||
'traffic_volume', 'temperature', 'precipitation',
|
||||
'lag_1_day', 'rolling_mean_7d'
|
||||
}
|
||||
using_defaults = []
|
||||
for feature in critical_features:
|
||||
if feature in model_regressors:
|
||||
value = features.get(feature, 0)
|
||||
# Check if using default/fallback values
|
||||
if (feature == 'traffic_volume' and value == 100.0) or \
|
||||
(feature == 'temperature' and value == 15.0) or \
|
||||
(feature in ['lag_1_day', 'rolling_mean_7d'] and value == 0.0):
|
||||
using_defaults.append(feature)
|
||||
|
||||
if using_defaults:
|
||||
logger.warning(
|
||||
"Using default values for critical model features",
|
||||
model_id=model_id,
|
||||
features_with_defaults=using_defaults
|
||||
)
|
||||
|
||||
# Generate prediction
|
||||
forecast = model.predict(prophet_df)
|
||||
|
||||
@@ -938,8 +977,9 @@ class PredictionService:
|
||||
'is_month_end': int(forecast_date.day >= 28),
|
||||
'is_payday_period': int((forecast_date.day <= 5) or (forecast_date.day >= 25)),
|
||||
# CRITICAL FIX: Add is_payday feature to match training service
|
||||
# Training defines: is_payday = (day == 15 OR is_month_end)
|
||||
'is_payday': int((forecast_date.day == 15) or self._is_end_of_month(forecast_date)),
|
||||
# Training defines: is_payday = (day == 15 OR day == 28 OR is_month_end)
|
||||
# Spain commonly pays on 28th, 15th, or last day of month
|
||||
'is_payday': int((forecast_date.day == 15) or (forecast_date.day == 28) or self._is_end_of_month(forecast_date)),
|
||||
|
||||
# Weather-based derived features
|
||||
'temp_squared': temperature ** 2,
|
||||
|
||||
Reference in New Issue
Block a user