Improve AI logic

This commit is contained in:
Urtzi Alfaro
2025-11-05 13:34:56 +01:00
parent 5c87fbcf48
commit 394ad3aea4
218 changed files with 30627 additions and 7658 deletions

View File

@@ -213,8 +213,7 @@ async def generate_batch_forecast(
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
current_user: dict = Depends(get_current_user_dep),
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service),
rate_limiter = Depends(get_rate_limiter)
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Generate forecasts for multiple products in batch (Admin+ only, quota enforced)"""
metrics = get_metrics_collector(request_obj)
@@ -227,30 +226,47 @@ async def generate_batch_forecast(
if metrics:
metrics.increment_counter("batch_forecasts_total")
if not request.inventory_product_ids:
raise ValueError("inventory_product_ids cannot be empty")
# Check if we need to get all products instead of specific ones
inventory_product_ids = request.inventory_product_ids
if inventory_product_ids is None or len(inventory_product_ids) == 0:
# If no specific products requested, fetch all products for the tenant
# from the inventory service to generate forecasts for all of them
from shared.clients.inventory_client import InventoryServiceClient
from app.core.config import settings
inventory_client = InventoryServiceClient(settings)
all_ingredients = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
inventory_product_ids = [str(ingredient['id']) for ingredient in all_ingredients] if all_ingredients else []
# If still no products, return early with success response
if not inventory_product_ids:
logger.info("No products found for forecasting", tenant_id=tenant_id)
from app.schemas.forecasts import BatchForecastResponse
return BatchForecastResponse(
batch_id=str(uuid.uuid4()),
tenant_id=tenant_id,
products_processed=0,
forecasts_generated=0,
success=True,
message="No products found for forecasting"
)
# Get subscription tier and enforce quotas
tier = current_user.get('subscription_tier', 'starter')
# Skip rate limiting for service-to-service calls (orchestrator)
# Rate limiting is handled at the gateway level for user requests
# Check daily quota for forecast generation
quota_limit = get_forecast_quota(tier)
quota_result = await rate_limiter.check_and_increment_quota(
tenant_id,
"forecast_generation",
quota_limit,
period=86400 # 24 hours
# Create a copy of the request with the actual list of product IDs to forecast
# (whether originally provided or fetched from inventory service)
from app.schemas.forecasts import BatchForecastRequest
updated_request = BatchForecastRequest(
tenant_id=tenant_id, # Use the tenant_id from the path parameter
batch_name=getattr(request, 'batch_name', f"orchestrator-batch-{datetime.now().strftime('%Y%m%d')}"),
inventory_product_ids=inventory_product_ids,
forecast_days=getattr(request, 'forecast_days', 7)
)
# Validate forecast horizon if specified
if request.horizon_days:
await rate_limiter.validate_forecast_horizon(
tenant_id, request.horizon_days, tier
)
batch_result = await enhanced_forecasting_service.generate_batch_forecast(
batch_result = await enhanced_forecasting_service.generate_batch_forecasts(
tenant_id=tenant_id,
request=request
request=updated_request
)
if metrics:
@@ -258,9 +274,25 @@ async def generate_batch_forecast(
logger.info("Batch forecast generated successfully",
tenant_id=tenant_id,
total_forecasts=batch_result.total_forecasts)
total_forecasts=batch_result.get('total_forecasts', 0))
return batch_result
# Convert the service result to BatchForecastResponse format
from app.schemas.forecasts import BatchForecastResponse
now = datetime.now(timezone.utc)
return BatchForecastResponse(
id=batch_result.get('batch_id', str(uuid.uuid4())),
tenant_id=tenant_id,
batch_name=updated_request.batch_name,
status="completed",
total_products=batch_result.get('total_forecasts', 0),
completed_products=batch_result.get('successful_forecasts', 0),
failed_products=batch_result.get('failed_forecasts', 0),
requested_at=now,
completed_at=now,
processing_time_ms=0,
forecasts=[],
error_message=None
)
except ValueError as e:
if metrics:
@@ -484,6 +516,174 @@ async def clear_prediction_cache(
)
def _compute_forecast_accuracy(forecasts, sales_dict):
    """Compare predicted demand against actual sales and compute accuracy metrics.

    Pure function (no I/O) so the math can be unit-tested without a database
    or HTTP stack.

    Args:
        forecasts: Iterable of forecast records exposing ``inventory_product_id``
            and ``predicted_demand`` attributes (ORM rows in production).
        sales_dict: Mapping of product id (str) -> aggregated actual quantity
            sold on the validation date. Products absent from this mapping are
            skipped (no sales data is not treated as an error).

    Returns:
        Dict with ``overall_mape``, ``overall_rmse``, ``overall_mae`` as plain
        Python floats rounded to 2 decimals (never numpy scalars, which are not
        JSON-serializable by default), ``products_validated`` (int) and
        ``poor_accuracy_products`` (list of per-product dicts where MAPE > 30%).
    """
    import math

    mape_list = []
    squared_errors = []
    mae_list = []
    poor_accuracy_products = []

    for forecast in forecasts:
        product_id = str(forecast.inventory_product_id)
        actual_quantity = sales_dict.get(product_id)

        # Skip products with no recorded sales for the date.
        if actual_quantity is None:
            continue

        predicted_quantity = forecast.predicted_demand
        absolute_error = abs(predicted_quantity - actual_quantity)
        squared_error = (predicted_quantity - actual_quantity) ** 2

        # Percentage error, guarding division by zero: predicting demand where
        # none occurred counts as a 100% miss; a correct zero is a perfect hit.
        if actual_quantity > 0:
            percentage_error = (absolute_error / actual_quantity) * 100
        else:
            percentage_error = 100 if predicted_quantity > 0 else 0

        mape_list.append(percentage_error)
        squared_errors.append(squared_error)
        mae_list.append(absolute_error)

        # MAPE above 30% flags the product as a retraining candidate.
        if percentage_error > 30:
            poor_accuracy_products.append({
                "product_id": product_id,
                "mape": round(percentage_error, 2),
                "predicted": round(predicted_quantity, 2),
                "actual": round(actual_quantity, 2)
            })

    # Stdlib aggregation (mean / root-mean-square) instead of numpy so the
    # response dict contains only native JSON-serializable types.
    n = len(mape_list)
    overall_mape = sum(mape_list) / n if n else 0
    overall_rmse = math.sqrt(sum(squared_errors) / n) if n else 0
    overall_mae = sum(mae_list) / n if n else 0

    return {
        "overall_mape": round(float(overall_mape), 2),
        "overall_rmse": round(float(overall_rmse), 2),
        "overall_mae": round(float(overall_mae), 2),
        "products_validated": n,
        "poor_accuracy_products": poor_accuracy_products
    }


@router.post(
    route_builder.build_operations_route("validate-forecasts"),
    response_model=dict
)
@service_only_access
@track_execution_time("validate_forecasts_duration_seconds", "forecasting-service")
async def validate_forecasts(
    validation_date: date = Query(..., description="Date to validate forecasts for"),
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_user: dict = Depends(get_current_user_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """
    Validate forecasts for a specific date against actual sales.

    Calculates MAPE, RMSE, MAE and identifies products with poor accuracy.
    This endpoint is called by the orchestrator during Step 5 to validate
    yesterday's forecasts and trigger retraining if needed.

    Args:
        validation_date: Date to validate forecasts for
        tenant_id: Tenant ID

    Returns:
        Dict with overall metrics and poor accuracy products list:
        - overall_mape: Mean Absolute Percentage Error across all products
        - overall_rmse: Root Mean Squared Error across all products
        - overall_mae: Mean Absolute Error across all products
        - products_validated: Number of products validated
        - poor_accuracy_products: List of products with MAPE > 30%

    Raises:
        HTTPException: 500 if forecast retrieval, sales lookup, or metric
            computation fails for any reason.
    """
    metrics = get_metrics_collector(request_obj)
    try:
        logger.info("Validating forecasts for date",
                    tenant_id=tenant_id,
                    validation_date=validation_date.isoformat())
        if metrics:
            metrics.increment_counter("forecast_validations_total")

        # Local imports keep heavyweight project modules off the import path
        # until the endpoint is actually invoked (matches file convention).
        from app.repositories.forecast_repository import ForecastRepository
        from shared.clients.sales_client import SalesServiceClient

        db_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
        async with db_manager.get_session() as session:
            forecast_repo = ForecastRepository(session)

            # Get forecasts for the validation date
            forecasts = await forecast_repo.get_forecasts_by_date(
                tenant_id=uuid.UUID(tenant_id),
                forecast_date=validation_date
            )

            if not forecasts:
                # Nothing to validate: report zeroed metrics rather than erroring,
                # so the orchestrator's Step 5 can proceed without special-casing.
                logger.warning("No forecasts found for validation date",
                               tenant_id=tenant_id,
                               validation_date=validation_date.isoformat())
                return {
                    "overall_mape": 0,
                    "overall_rmse": 0,
                    "overall_mae": 0,
                    "products_validated": 0,
                    "poor_accuracy_products": []
                }

            # Get actual sales for the validation date from sales service
            sales_client = SalesServiceClient(settings, "forecasting-service")
            actual_sales_response = await sales_client.get_sales_by_date_range(
                tenant_id=tenant_id,
                start_date=validation_date,
                end_date=validation_date
            )

            # Aggregate sales per product into a lookup dict; a product may
            # appear in multiple sale records for the same day.
            sales_dict = {}
            if actual_sales_response and 'sales' in actual_sales_response:
                for sale in actual_sales_response['sales']:
                    product_id = sale.get('inventory_product_id')
                    quantity = sale.get('quantity', 0)
                    if product_id:
                        sales_dict[product_id] = sales_dict.get(product_id, 0) + quantity

            # Pure computation extracted to a testable helper.
            result = _compute_forecast_accuracy(forecasts, sales_dict)

            logger.info("Forecast validation complete",
                        tenant_id=tenant_id,
                        validation_date=validation_date.isoformat(),
                        overall_mape=result["overall_mape"],
                        products_validated=result["products_validated"],
                        poor_accuracy_count=len(result["poor_accuracy_products"]))

            if metrics:
                metrics.increment_counter("forecast_validations_completed_total")
                # NOTE: histogram now records the rounded (2-decimal) MAPE;
                # negligible precision loss for monitoring purposes.
                metrics.observe_histogram("forecast_validation_mape", result["overall_mape"])

            return result

    except Exception as e:
        logger.error("Failed to validate forecasts",
                     error=str(e),
                     tenant_id=tenant_id,
                     validation_date=validation_date.isoformat())
        if metrics:
            metrics.increment_counter("forecast_validations_failed_total")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to validate forecasts: {str(e)}"
        )
# ============================================================================
# Tenant Data Deletion Operations (Internal Service Only)
# ============================================================================