Improve AI logic

Urtzi Alfaro
2025-11-05 13:34:56 +01:00
parent 5c87fbcf48
commit 394ad3aea4
218 changed files with 30627 additions and 7658 deletions


@@ -213,8 +213,7 @@ async def generate_batch_forecast(
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
current_user: dict = Depends(get_current_user_dep),
-enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service),
-rate_limiter = Depends(get_rate_limiter)
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Generate forecasts for multiple products in batch (Admin+ only, quota enforced)"""
metrics = get_metrics_collector(request_obj)
@@ -227,30 +226,47 @@ async def generate_batch_forecast(
if metrics:
metrics.increment_counter("batch_forecasts_total")
-if not request.inventory_product_ids:
-raise ValueError("inventory_product_ids cannot be empty")
# Check if we need to get all products instead of specific ones
inventory_product_ids = request.inventory_product_ids
if inventory_product_ids is None or len(inventory_product_ids) == 0:
# If no specific products requested, fetch all products for the tenant
# from the inventory service to generate forecasts for all of them
from shared.clients.inventory_client import InventoryServiceClient
from app.core.config import settings
inventory_client = InventoryServiceClient(settings)
all_ingredients = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
inventory_product_ids = [str(ingredient['id']) for ingredient in all_ingredients] if all_ingredients else []
# If still no products, return early with success response
if not inventory_product_ids:
logger.info("No products found for forecasting", tenant_id=tenant_id)
from app.schemas.forecasts import BatchForecastResponse
now = datetime.now(timezone.utc)
return BatchForecastResponse(
id=str(uuid.uuid4()),
tenant_id=tenant_id,
batch_name=getattr(request, 'batch_name', None),
status="completed",
total_products=0,
completed_products=0,
failed_products=0,
requested_at=now,
completed_at=now,
processing_time_ms=0,
forecasts=[],
error_message=None
)
-# Get subscription tier and enforce quotas
-tier = current_user.get('subscription_tier', 'starter')
# Skip rate limiting for service-to-service calls (orchestrator)
# Rate limiting is handled at the gateway level for user requests
-# Check daily quota for forecast generation
-quota_limit = get_forecast_quota(tier)
-quota_result = await rate_limiter.check_and_increment_quota(
-tenant_id,
-"forecast_generation",
-quota_limit,
-period=86400 # 24 hours
# Create a copy of the request with the actual list of product IDs to forecast
# (whether originally provided or fetched from inventory service)
from app.schemas.forecasts import BatchForecastRequest
updated_request = BatchForecastRequest(
tenant_id=tenant_id, # Use the tenant_id from the path parameter
batch_name=getattr(request, 'batch_name', f"orchestrator-batch-{datetime.now().strftime('%Y%m%d')}"),
inventory_product_ids=inventory_product_ids,
forecast_days=getattr(request, 'forecast_days', 7)
)
-# Validate forecast horizon if specified
-if request.horizon_days:
-await rate_limiter.validate_forecast_horizon(
-tenant_id, request.horizon_days, tier
-)
-batch_result = await enhanced_forecasting_service.generate_batch_forecast(
batch_result = await enhanced_forecasting_service.generate_batch_forecasts(
tenant_id=tenant_id,
-request=request
request=updated_request
)
if metrics:
@@ -258,9 +274,25 @@ async def generate_batch_forecast(
logger.info("Batch forecast generated successfully",
tenant_id=tenant_id,
-total_forecasts=batch_result.total_forecasts)
total_forecasts=batch_result.get('total_forecasts', 0))
-return batch_result
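# The service result is assumed to be a plain dict; the field names are
# inferred from the .get() calls below, e.g. (hypothetical values):
#   {"batch_id": "...", "total_forecasts": 12,
#    "successful_forecasts": 11, "failed_forecasts": 1}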
# Convert the service result to BatchForecastResponse format
from app.schemas.forecasts import BatchForecastResponse
now = datetime.now(timezone.utc)
return BatchForecastResponse(
id=batch_result.get('batch_id', str(uuid.uuid4())),
tenant_id=tenant_id,
batch_name=updated_request.batch_name,
status="completed",
total_products=batch_result.get('total_forecasts', 0),
completed_products=batch_result.get('successful_forecasts', 0),
failed_products=batch_result.get('failed_forecasts', 0),
requested_at=now,
completed_at=now,
processing_time_ms=0,
forecasts=[],
error_message=None
)
except ValueError as e:
if metrics:
@@ -484,6 +516,174 @@ async def clear_prediction_cache(
)
@router.post(
route_builder.build_operations_route("validate-forecasts"),
response_model=dict
)
@service_only_access
@track_execution_time("validate_forecasts_duration_seconds", "forecasting-service")
async def validate_forecasts(
validation_date: date = Query(..., description="Date to validate forecasts for"),
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
current_user: dict = Depends(get_current_user_dep),
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""
Validate forecasts for a specific date against actual sales.
Calculates MAPE, RMSE, and MAE, and identifies products with poor accuracy.
This endpoint is called by the orchestrator during Step 5 to validate
yesterday's forecasts and trigger retraining if needed.
Args:
validation_date: Date to validate forecasts for
tenant_id: Tenant ID
Returns:
Dict with overall metrics and a list of poor-accuracy products:
- overall_mape: Mean Absolute Percentage Error across all products
- overall_rmse: Root Mean Squared Error across all products
- overall_mae: Mean Absolute Error across all products
- products_validated: Number of products validated
- poor_accuracy_products: List of products with MAPE > 30%
"""
metrics = get_metrics_collector(request_obj)
try:
logger.info("Validating forecasts for date",
tenant_id=tenant_id,
validation_date=validation_date.isoformat())
if metrics:
metrics.increment_counter("forecast_validations_total")
# Get all forecasts for the validation date
from app.repositories.forecast_repository import ForecastRepository
from shared.clients.sales_client import SalesServiceClient
db_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
async with db_manager.get_session() as session:
forecast_repo = ForecastRepository(session)
# Get forecasts for the validation date
forecasts = await forecast_repo.get_forecasts_by_date(
tenant_id=uuid.UUID(tenant_id),
forecast_date=validation_date
)
if not forecasts:
logger.warning("No forecasts found for validation date",
tenant_id=tenant_id,
validation_date=validation_date.isoformat())
return {
"overall_mape": 0,
"overall_rmse": 0,
"overall_mae": 0,
"products_validated": 0,
"poor_accuracy_products": []
}
# Get actual sales for the validation date from sales service
sales_client = SalesServiceClient(settings, "forecasting-service")
actual_sales_response = await sales_client.get_sales_by_date_range(
tenant_id=tenant_id,
start_date=validation_date,
end_date=validation_date
)
# Create sales lookup dict
sales_dict = {}
if actual_sales_response and 'sales' in actual_sales_response:
for sale in actual_sales_response['sales']:
product_id = sale.get('inventory_product_id')
quantity = sale.get('quantity', 0)
if product_id:
# Aggregate quantities for the same product
sales_dict[product_id] = sales_dict.get(product_id, 0) + quantity
# Calculate metrics for each product
import numpy as np
mape_list = []
rmse_list = []
mae_list = []
poor_accuracy_products = []
for forecast in forecasts:
product_id = str(forecast.inventory_product_id)
actual_quantity = sales_dict.get(product_id)
# Skip if no actual sales data
if actual_quantity is None:
continue
predicted_quantity = forecast.predicted_demand
# Calculate errors
absolute_error = abs(predicted_quantity - actual_quantity)
squared_error = (predicted_quantity - actual_quantity) ** 2
# Calculate percentage error (avoid division by zero)
if actual_quantity > 0:
percentage_error = (absolute_error / actual_quantity) * 100
else:
# If actual is 0 but predicted is not, treat as 100% error
percentage_error = 100 if predicted_quantity > 0 else 0
mape_list.append(percentage_error)
rmse_list.append(squared_error)
mae_list.append(absolute_error)
# Track products with poor accuracy
if percentage_error > 30:
poor_accuracy_products.append({
"product_id": product_id,
"mape": round(percentage_error, 2),
"predicted": round(predicted_quantity, 2),
"actual": round(actual_quantity, 2)
})
# Calculate overall metrics
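# Note: rmse_list holds *squared* errors; RMSE is the square root of
# their mean, i.e. sqrt(mean((predicted - actual)^2)).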
overall_mape = np.mean(mape_list) if mape_list else 0
overall_rmse = np.sqrt(np.mean(rmse_list)) if rmse_list else 0
overall_mae = np.mean(mae_list) if mae_list else 0
result = {
"overall_mape": round(overall_mape, 2),
"overall_rmse": round(overall_rmse, 2),
"overall_mae": round(overall_mae, 2),
"products_validated": len(mape_list),
"poor_accuracy_products": poor_accuracy_products
}
logger.info("Forecast validation complete",
tenant_id=tenant_id,
validation_date=validation_date.isoformat(),
overall_mape=result["overall_mape"],
products_validated=result["products_validated"],
poor_accuracy_count=len(poor_accuracy_products))
if metrics:
metrics.increment_counter("forecast_validations_completed_total")
metrics.observe_histogram("forecast_validation_mape", overall_mape)
return result
except Exception as e:
logger.error("Failed to validate forecasts",
error=str(e),
tenant_id=tenant_id,
validation_date=validation_date.isoformat())
if metrics:
metrics.increment_counter("forecast_validations_failed_total")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to validate forecasts: {str(e)}"
)
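# Illustrative only: a minimal sketch of the orchestrator's Step 5 call to
# this endpoint, assuming httpx. The base URL is a placeholder, the route
# shape mirrors the router prefix and build_operations_route but is an
# assumption, and any service-to-service auth headers are omitted.
import httpx
from datetime import datetime, timedelta, timezone

async def validate_yesterdays_forecasts(base_url: str, tenant_id: str) -> dict:
    yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date()
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{base_url}/api/v1/tenants/{tenant_id}/forecasting"
            f"/operations/validate-forecasts",
            params={"validation_date": yesterday.isoformat()},
        )
        resp.raise_for_status()
        # Products whose per-product MAPE exceeded 30% arrive in
        # "poor_accuracy_products" and are candidates for retraining.
        return resp.json()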
# ============================================================================
# Tenant Data Deletion Operations (Internal Service Only)
# ============================================================================


@@ -0,0 +1,279 @@
"""
ML Insights API Endpoints for Forecasting Service
Provides endpoints to trigger ML insight generation for:
- Dynamic business rules learning
- Demand pattern analysis
- Seasonal trend detection
"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
from typing import Optional, List
from uuid import UUID
from datetime import datetime, timedelta
import structlog
import pandas as pd
from app.core.database import get_db
from sqlalchemy.ext.asyncio import AsyncSession
logger = structlog.get_logger()
router = APIRouter(
prefix="/api/v1/tenants/{tenant_id}/forecasting/ml/insights",
tags=["ML Insights"]
)
# ================================================================
# REQUEST/RESPONSE SCHEMAS
# ================================================================
class RulesGenerationRequest(BaseModel):
"""Request schema for rules generation"""
product_ids: Optional[List[str]] = Field(
None,
description="Specific product IDs to analyze. If None, analyzes all products"
)
lookback_days: int = Field(
90,
description="Days of historical data to analyze",
ge=30,
le=365
)
min_samples: int = Field(
10,
description="Minimum samples required for rule learning",
ge=5,
le=100
)
class RulesGenerationResponse(BaseModel):
"""Response schema for rules generation"""
success: bool
message: str
tenant_id: str
products_analyzed: int
total_insights_generated: int
total_insights_posted: int
insights_by_product: dict
errors: List[str] = []
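# Example request/response pair for these schemas (hypothetical values):
#   request:  {"product_ids": null, "lookback_days": 90, "min_samples": 10}
#   response: {"success": true, "message": "...", "tenant_id": "...",
#              "products_analyzed": 10, "total_insights_generated": 14,
#              "total_insights_posted": 12, "insights_by_product": {...},
#              "errors": []}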
# ================================================================
# API ENDPOINTS
# ================================================================
@router.post("/generate-rules", response_model=RulesGenerationResponse)
async def trigger_rules_generation(
tenant_id: str,
request_data: RulesGenerationRequest,
db: AsyncSession = Depends(get_db)
):
"""
Trigger dynamic business rules learning from historical sales data.
This endpoint:
1. Fetches historical sales data for specified products
2. Runs the RulesOrchestrator to learn patterns
3. Generates insights about optimal business rules
4. Posts insights to AI Insights Service
Args:
tenant_id: Tenant UUID
request_data: Rules generation parameters
db: Database session
Returns:
RulesGenerationResponse with generation results
"""
logger.info(
"ML insights rules generation requested",
tenant_id=tenant_id,
product_ids=request_data.product_ids,
lookback_days=request_data.lookback_days
)
try:
# Import ML orchestrator and clients
from app.ml.rules_orchestrator import RulesOrchestrator
from shared.clients.sales_client import SalesServiceClient
from shared.clients.inventory_client import InventoryServiceClient
from app.core.config import settings
# Initialize orchestrator and clients
orchestrator = RulesOrchestrator()
inventory_client = InventoryServiceClient(settings)
# Get products to analyze from inventory service via API
if request_data.product_ids:
# Fetch specific products
products = []
for product_id in request_data.product_ids:
product = await inventory_client.get_ingredient_by_id(
ingredient_id=UUID(product_id),
tenant_id=tenant_id
)
if product:
products.append(product)
else:
# Fetch all products for tenant (limit to 10)
all_products = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
products = all_products[:10] # Limit to prevent timeout
if not products:
return RulesGenerationResponse(
success=False,
message="No products found for analysis",
tenant_id=tenant_id,
products_analyzed=0,
total_insights_generated=0,
total_insights_posted=0,
insights_by_product={},
errors=["No products found"]
)
# Initialize sales client to fetch historical data
sales_client = SalesServiceClient(config=settings, calling_service_name="forecasting")
# Calculate date range
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=request_data.lookback_days)
# Process each product
total_insights_generated = 0
total_insights_posted = 0
insights_by_product = {}
errors = []
for product in products:
try:
product_id = str(product['id'])
product_name = product.get('name', 'Unknown')
logger.info(f"Analyzing product {product_name} ({product_id})")
# Fetch sales data for product
sales_data = await sales_client.get_sales_data(
tenant_id=tenant_id,
product_id=product_id,
start_date=start_date.strftime('%Y-%m-%d'),
end_date=end_date.strftime('%Y-%m-%d')
)
if not sales_data:
logger.warning(f"No sales data for product {product_id}")
continue
# Convert to DataFrame
sales_df = pd.DataFrame(sales_data)
if len(sales_df) < request_data.min_samples:
logger.warning(
f"Insufficient data for product {product_id}: "
f"{len(sales_df)} samples < {request_data.min_samples} required"
)
continue
# Check what columns are available and map to expected format
logger.debug(f"Sales data columns for product {product_id}: {sales_df.columns.tolist()}")
# Map common field names to 'quantity' and 'date'
if 'quantity' not in sales_df.columns:
if 'total_quantity' in sales_df.columns:
sales_df['quantity'] = sales_df['total_quantity']
elif 'amount' in sales_df.columns:
sales_df['quantity'] = sales_df['amount']
else:
logger.warning(f"No quantity field found for product {product_id}, skipping")
continue
if 'date' not in sales_df.columns:
if 'sale_date' in sales_df.columns:
sales_df['date'] = sales_df['sale_date']
else:
logger.warning(f"No date field found for product {product_id}, skipping")
continue
# Prepare sales data with required columns
sales_df['date'] = pd.to_datetime(sales_df['date'])
sales_df['quantity'] = sales_df['quantity'].astype(float)
sales_df['day_of_week'] = sales_df['date'].dt.dayofweek
sales_df['is_holiday'] = False # TODO: Add holiday detection
sales_df['weather'] = 'unknown' # TODO: Add weather data
# Run rules learning
results = await orchestrator.learn_and_post_rules(
tenant_id=tenant_id,
inventory_product_id=product_id,
sales_data=sales_df,
external_data=None,
min_samples=request_data.min_samples
)
# Track results
total_insights_generated += results['insights_generated']
total_insights_posted += results['insights_posted']
insights_by_product[product_id] = {
'product_name': product_name,
'insights_posted': results['insights_posted'],
'rules_learned': len(results['rules'])
}
logger.info(
f"Product {product_id} analysis complete",
insights_posted=results['insights_posted']
)
except Exception as e:
error_msg = f"Error analyzing product {product_id}: {str(e)}"
logger.error(error_msg, exc_info=True)
errors.append(error_msg)
# Close orchestrator
await orchestrator.close()
# Build response
response = RulesGenerationResponse(
success=total_insights_posted > 0,
message=f"Successfully generated {total_insights_posted} insights from {len(products)} products",
tenant_id=tenant_id,
products_analyzed=len(products),
total_insights_generated=total_insights_generated,
total_insights_posted=total_insights_posted,
insights_by_product=insights_by_product,
errors=errors
)
logger.info(
"ML insights rules generation complete",
tenant_id=tenant_id,
total_insights=total_insights_posted
)
return response
except Exception as e:
logger.error(
"ML insights rules generation failed",
tenant_id=tenant_id,
error=str(e),
exc_info=True
)
raise HTTPException(
status_code=500,
detail=f"Rules generation failed: {str(e)}"
)
@router.get("/health")
async def ml_insights_health():
"""Health check for ML insights endpoints"""
return {
"status": "healthy",
"service": "forecasting-ml-insights",
"endpoints": [
"POST /ml/insights/generate-rules"
]
}
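# Illustrative only: triggering rules generation for a tenant. The service
# URL and tenant UUID are placeholders, and auth is omitted.
import asyncio
import httpx

async def main() -> None:
    url = (
        "http://forecasting:8000/api/v1/tenants/"
        "00000000-0000-0000-0000-000000000000/forecasting/ml/insights/generate-rules"
    )
    payload = {"product_ids": None, "lookback_days": 90, "min_samples": 10}
    async with httpx.AsyncClient(timeout=300.0) as client:
        resp = await client.post(url, json=payload)
        resp.raise_for_status()
        print(resp.json()["total_insights_posted"], "insights posted")

asyncio.run(main())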