Add UUID validation to forecasting endpoints and schemas, batch multi-day forecast DB writes into one session, use the holidays library for Spanish holidays, refine weekend/holiday business rules, and remove internal API key checks from demo endpoints

This commit is contained in:
Urtzi Alfaro
2026-01-12 14:24:14 +01:00
parent 6037faaf8c
commit 230bbe6a19
61 changed files with 1668 additions and 894 deletions

View File

@@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse
from typing import List, Dict, Any, Optional
from datetime import date, datetime, timezone
import uuid
from uuid import UUID
from app.services.forecasting_service import EnhancedForecastingService
from app.services.prediction_service import PredictionService
@@ -42,6 +43,30 @@ async def get_rate_limiter():
return create_rate_limiter(redis_client)
def validate_uuid(value: str, field_name: str = "ID") -> str:
    """
    Validate that a string is a valid UUID.

    Args:
        value: The string to validate
        field_name: Name of the field for error messages

    Returns:
        The validated UUID string

    Raises:
        HTTPException: 400 if the value is not a valid UUID
    """
    try:
        UUID(value)
        return value
    # TypeError added: uuid.UUID(None) (or any non-string such as an int)
    # raises TypeError, which previously escaped as an unhandled 500 error.
    except (ValueError, AttributeError, TypeError):
        # `from None` suppresses the internal parse traceback so the API
        # error response chain stays clean.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"{field_name} must be a valid UUID, got: {value}"
        ) from None
def get_enhanced_forecasting_service():
"""Dependency injection for EnhancedForecastingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
@@ -68,6 +93,10 @@ async def generate_single_forecast(
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Generate a single product forecast with caching support"""
# Validate UUID fields
validate_uuid(tenant_id, "tenant_id")
# inventory_product_id already validated by ForecastRequest schema
metrics = get_metrics_collector(request_obj)
try:

View File

@@ -28,15 +28,6 @@ router = APIRouter(prefix="/internal/demo", tags=["internal"])
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
def verify_internal_api_key(x_internal_api_key: Optional[str] = Header(None)):
    """Verify internal API key for service-to-service communication.

    Args:
        x_internal_api_key: Value of the ``X-Internal-Api-Key`` request header.

    Returns:
        True when the supplied key matches the configured internal key.

    Raises:
        HTTPException: 403 when the key is missing, unconfigured, or wrong.
    """
    import secrets

    from app.core.config import settings

    # SECURITY FIX: the previous `x_internal_api_key != settings.INTERNAL_API_KEY`
    # check granted access when BOTH values were None (key unconfigured and
    # header absent). Require both to be present, then compare with
    # secrets.compare_digest to avoid a timing side channel.
    if (
        not settings.INTERNAL_API_KEY
        or not x_internal_api_key
        or not secrets.compare_digest(x_internal_api_key, settings.INTERNAL_API_KEY)
    ):
        logger.warning("Unauthorized internal API access attempted")
        raise HTTPException(status_code=403, detail="Invalid internal API key")
    return True
def parse_date_field(date_value, session_time: datetime, field_name: str = "date") -> Optional[datetime]:
"""
Parse date field, handling both ISO strings and BASE_TS markers.
@@ -98,8 +89,7 @@ async def clone_demo_data(
demo_account_type: str,
session_id: Optional[str] = None,
session_created_at: Optional[str] = None,
db: AsyncSession = Depends(get_db),
_: bool = Depends(verify_internal_api_key)
db: AsyncSession = Depends(get_db)
):
"""
Clone forecasting service data for a virtual demo tenant
@@ -406,7 +396,7 @@ async def clone_demo_data(
@router.get("/clone/health")
async def clone_health_check(_: bool = Depends(verify_internal_api_key)):
async def clone_health_check():
"""
Health check for internal cloning endpoint
Used by orchestrator to verify service availability
@@ -421,8 +411,7 @@ async def clone_health_check(_: bool = Depends(verify_internal_api_key)):
@router.delete("/tenant/{virtual_tenant_id}")
async def delete_demo_tenant_data(
virtual_tenant_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
_: bool = Depends(verify_internal_api_key)
db: AsyncSession = Depends(get_db)
):
"""
Delete all demo data for a virtual tenant.

View File

@@ -9,6 +9,7 @@ from pydantic import BaseModel, Field, validator
from datetime import datetime, date
from typing import Optional, List, Dict, Any
from enum import Enum
from uuid import UUID
class BusinessType(str, Enum):
INDIVIDUAL = "individual"
@@ -22,10 +23,19 @@ class ForecastRequest(BaseModel):
forecast_date: date = Field(..., description="Starting date for forecast")
forecast_days: int = Field(1, ge=1, le=30, description="Number of days to forecast")
location: str = Field(..., description="Location identifier")
# Optional parameters - internally handled
confidence_level: float = Field(0.8, ge=0.5, le=0.95, description="Confidence level")
@validator('inventory_product_id')
def validate_inventory_product_id(cls, v):
    """Reject any inventory_product_id that does not parse as a UUID."""
    try:
        UUID(v)
        return v
    except (ValueError, AttributeError):
        raise ValueError(f"inventory_product_id must be a valid UUID, got: {v}")
@validator('forecast_date')
def validate_forecast_date(cls, v):
if v < date.today():
@@ -39,6 +49,26 @@ class BatchForecastRequest(BaseModel):
inventory_product_ids: List[str] = Field(..., description="List of inventory product IDs")
forecast_days: int = Field(7, ge=1, le=30, description="Number of days to forecast")
@validator('tenant_id')
def validate_tenant_id(cls, v):
    """Accept a missing tenant_id; otherwise require it to parse as a UUID."""
    if v is None:
        return v
    try:
        UUID(v)
    except (ValueError, AttributeError):
        raise ValueError(f"tenant_id must be a valid UUID, got: {v}")
    return v
@validator('inventory_product_ids')
def validate_inventory_product_ids(cls, v):
    """Require every entry of inventory_product_ids to parse as a UUID."""
    for candidate in v:
        try:
            UUID(candidate)
        except (ValueError, AttributeError):
            raise ValueError(
                f"All inventory_product_ids must be valid UUIDs, got invalid: {candidate}"
            )
    return v
class ForecastResponse(BaseModel):
"""Response schema for forecast results"""
id: str

View File

@@ -498,29 +498,117 @@ class EnhancedForecastingService:
weather_date = str(weather_date).split('T')[0]
weather_map[weather_date] = weather
# Generate a forecast for each day
for day_offset in range(request.forecast_days):
# Calculate the forecast date for this day
current_date = request.forecast_date
if isinstance(current_date, str):
from dateutil.parser import parse
current_date = parse(current_date).date()
# PERFORMANCE FIX: Get model ONCE before loop (not per day)
model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
if day_offset > 0:
current_date = current_date + timedelta(days=day_offset)
if not model_data:
raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
# Create a new request for this specific day
daily_request = ForecastRequest(
inventory_product_id=request.inventory_product_id,
forecast_date=current_date,
forecast_days=1, # Single day for each iteration
location=request.location,
confidence_level=request.confidence_level
)
# PERFORMANCE FIX: Open single database session for batch operations
async with self.database_manager.get_background_session() as session:
repos = await self._init_repositories(session)
# Generate forecast for this day, passing the weather data map
daily_forecast = await self.generate_forecast_with_weather_map(tenant_id, daily_request, weather_map)
forecasts.append(daily_forecast)
# Generate predictions for all days (in-memory, no DB writes yet)
forecast_data_list = []
for day_offset in range(request.forecast_days):
# Calculate the forecast date for this day
current_date = request.forecast_date
if isinstance(current_date, str):
from dateutil.parser import parse
current_date = parse(current_date).date()
if day_offset > 0:
current_date = current_date + timedelta(days=day_offset)
# Create a new request for this specific day
daily_request = ForecastRequest(
inventory_product_id=request.inventory_product_id,
forecast_date=current_date,
forecast_days=1,
location=request.location,
confidence_level=request.confidence_level
)
# Check cache first
forecast_datetime = current_date
if isinstance(forecast_datetime, str):
from dateutil.parser import parse
forecast_datetime = parse(forecast_datetime)
cached_prediction = await repos['cache'].get_cached_prediction(
tenant_id, request.inventory_product_id, request.location, forecast_datetime
)
if cached_prediction:
forecasts.append(self._create_forecast_response_from_cache(cached_prediction))
continue
# Prepare features for this day
features = await self._prepare_forecast_features_with_fallbacks_and_weather_map(
tenant_id, daily_request, weather_map
)
# Generate prediction (model already loaded and cached in prediction_service)
prediction_result = await self.prediction_service.predict(
model_id=model_data['model_id'],
model_path=model_data['model_path'],
features=features,
confidence_level=request.confidence_level
)
# Apply business rules
adjusted_prediction = self._apply_business_rules(
prediction_result, daily_request, features
)
# Prepare forecast data for batch insert
forecast_data = {
"tenant_id": tenant_id,
"inventory_product_id": request.inventory_product_id,
"product_name": None,
"location": request.location,
"forecast_date": forecast_datetime,
"predicted_demand": adjusted_prediction['prediction'],
"confidence_lower": adjusted_prediction.get('lower_bound', max(0.0, float(adjusted_prediction.get('prediction') or 0.0) * 0.8)),
"confidence_upper": adjusted_prediction.get('upper_bound', max(0.0, float(adjusted_prediction.get('prediction') or 0.0) * 1.2)),
"confidence_level": request.confidence_level,
"model_id": model_data['model_id'],
"model_version": str(model_data.get('version', '1.0')),
"algorithm": model_data.get('algorithm', 'prophet'),
"business_type": features.get('business_type', 'individual'),
"is_holiday": features.get('is_holiday', False),
"is_weekend": features.get('is_weekend', False),
"day_of_week": features.get('day_of_week', 0),
"weather_temperature": features.get('temperature'),
"weather_precipitation": features.get('precipitation'),
"weather_description": features.get('weather_description'),
"traffic_volume": features.get('traffic_volume'),
"processing_time_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"features_used": features
}
forecast_data_list.append((forecast_data, adjusted_prediction, features))
# PERFORMANCE FIX: Batch insert all forecasts in one transaction
for forecast_data, adjusted_prediction, features in forecast_data_list:
forecast = await repos['forecast'].create_forecast(forecast_data)
forecasts.append(self._create_forecast_response_from_model(forecast))
# Cache predictions
await repos['cache'].cache_prediction(
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
location=request.location,
forecast_date=forecast_data['forecast_date'],
predicted_demand=adjusted_prediction['prediction'],
confidence_lower=adjusted_prediction.get('lower_bound', max(0.0, float(adjusted_prediction.get('prediction') or 0.0) * 0.8)),
confidence_upper=adjusted_prediction.get('upper_bound', max(0.0, float(adjusted_prediction.get('prediction') or 0.0) * 1.2)),
model_id=model_data['model_id'],
expires_in_hours=24
)
# Commit all inserts at once
await session.commit()
# Calculate summary statistics
total_demand = sum(f.predicted_demand for f in forecasts)
@@ -1140,13 +1228,33 @@ class EnhancedForecastingService:
return self._is_spanish_national_holiday(date_obj)
def _is_spanish_national_holiday(self, date_obj: date) -> bool:
    """
    Check if a date is a Spanish national holiday.

    Uses the `holidays` library for comprehensive coverage, including
    movable holidays; falls back to a hardcoded list of fixed-date
    national holidays if the library is unavailable or fails.

    NOTE: the stale pre-refactor body (the bare hardcoded-list version)
    returned before this logic could run, making it unreachable dead
    code; it has been removed.

    Args:
        date_obj: Calendar date to check.

    Returns:
        True if the date is a recognized Spanish holiday.
    """
    try:
        import holidays

        # No tenant location info is available here, so we use national
        # holidays only — still better than a fixed-date list because it
        # covers movable holidays.
        es_holidays = holidays.Spain(years=date_obj.year)
        return date_obj in es_holidays
    except Exception as e:
        logger.warning(f"Failed to check holidays library, using fallback: {e}")
        # Fallback: fixed-date national holidays (no movable feasts).
        spanish_holidays = {
            (1, 1),    # New Year's Day
            (1, 6),    # Epiphany
            (5, 1),    # Labour Day
            (8, 15),   # Assumption of Mary
            (10, 12),  # National Day
            (11, 1),   # All Saints' Day
            (12, 6),   # Constitution Day
            (12, 8),   # Immaculate Conception
            (12, 25),  # Christmas Day
        }
        return (date_obj.month, date_obj.day) in spanish_holidays
def _apply_business_rules(
self,
@@ -1156,19 +1264,53 @@ class EnhancedForecastingService:
) -> Dict[str, float]:
"""Apply Spanish bakery business rules to predictions"""
base_prediction = prediction["prediction"]
# Ensure confidence bounds exist with fallbacks
lower_bound = prediction.get("lower_bound", base_prediction * 0.8)
upper_bound = prediction.get("upper_bound", base_prediction * 1.2)
# Apply adjustment factors
adjustment_factor = 1.0
if features.get("is_weekend", False):
adjustment_factor *= 0.8
if features.get("is_holiday", False):
adjustment_factor *= 0.5
# Get business context
is_weekend = features.get("is_weekend", False)
is_holiday = features.get("is_holiday", False)
business_type = request.business_type if hasattr(request, 'business_type') else None
# Determine location type from POI features if business_type not explicitly set
location_type = business_type
if not location_type:
# Use POI features to infer location type
office_poi = features.get("poi_offices_total_count", 0)
residential_poi = features.get("poi_residential_total_count", 0)
if office_poi > residential_poi * 2:
location_type = "office_area"
elif residential_poi > office_poi * 2:
location_type = "residential_area"
else:
location_type = "mixed_area"
# Handle weekend + holiday combination based on location
if is_weekend and is_holiday:
# Special case: location-based logic for holiday weekends
if location_type == "office_area":
# Office areas: huge reduction (offices closed)
adjustment_factor *= 0.3 # 70% reduction
elif location_type == "residential_area":
# Residential areas: increase (family gatherings)
adjustment_factor *= 1.2 # 20% increase
else:
# Mixed or unknown: moderate reduction
adjustment_factor *= 0.5 # 50% reduction
else:
# Regular weekend (no holiday)
if is_weekend:
adjustment_factor *= 0.8
# Regular holiday (not weekend)
if is_holiday:
adjustment_factor *= 0.5
# Weather adjustments
precipitation = features.get("precipitation", 0.0)

View File

@@ -195,7 +195,46 @@ class PredictionService:
# Prepare features for Prophet model
prophet_df = self._prepare_prophet_features(features)
# CRITICAL FIX: Validate that model's required regressors are present
# Warn if using default values for features the model was trained with
if hasattr(model, 'extra_regressors'):
model_regressors = set(model.extra_regressors.keys()) if model.extra_regressors else set()
provided_features = set(prophet_df.columns) - {'ds'}
# Check for missing regressors
missing_regressors = model_regressors - provided_features
if missing_regressors:
logger.warning(
"Model trained with regressors that are missing in prediction",
model_id=model_id,
missing_regressors=list(missing_regressors)[:10], # Log first 10
total_missing=len(missing_regressors)
)
# Check for default-valued critical features
critical_features = {
'traffic_volume', 'temperature', 'precipitation',
'lag_1_day', 'rolling_mean_7d'
}
using_defaults = []
for feature in critical_features:
if feature in model_regressors:
value = features.get(feature, 0)
# Check if using default/fallback values
if (feature == 'traffic_volume' and value == 100.0) or \
(feature == 'temperature' and value == 15.0) or \
(feature in ['lag_1_day', 'rolling_mean_7d'] and value == 0.0):
using_defaults.append(feature)
if using_defaults:
logger.warning(
"Using default values for critical model features",
model_id=model_id,
features_with_defaults=using_defaults
)
# Generate prediction
forecast = model.predict(prophet_df)
@@ -938,8 +977,9 @@ class PredictionService:
'is_month_end': int(forecast_date.day >= 28),
'is_payday_period': int((forecast_date.day <= 5) or (forecast_date.day >= 25)),
# CRITICAL FIX: Add is_payday feature to match training service
# Training defines: is_payday = (day == 15 OR is_month_end)
'is_payday': int((forecast_date.day == 15) or self._is_end_of_month(forecast_date)),
# Training defines: is_payday = (day == 15 OR day == 28 OR is_month_end)
# Spain commonly pays on 28th, 15th, or last day of month
'is_payday': int((forecast_date.day == 15) or (forecast_date.day == 28) or self._is_end_of_month(forecast_date)),
# Weather-based derived features
'temp_squared': temperature ** 2,