Improve AI logic
@@ -213,8 +213,7 @@ async def generate_batch_forecast(
     tenant_id: str = Path(..., description="Tenant ID"),
     request_obj: Request = None,
     current_user: dict = Depends(get_current_user_dep),
-    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service),
-    rate_limiter = Depends(get_rate_limiter)
+    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
 ):
     """Generate forecasts for multiple products in batch (Admin+ only, quota enforced)"""
     metrics = get_metrics_collector(request_obj)
@@ -227,30 +226,47 @@ async def generate_batch_forecast(
         if metrics:
             metrics.increment_counter("batch_forecasts_total")

-        if not request.inventory_product_ids:
-            raise ValueError("inventory_product_ids cannot be empty")
+        # Check if we need to get all products instead of specific ones
+        inventory_product_ids = request.inventory_product_ids
+        if inventory_product_ids is None or len(inventory_product_ids) == 0:
+            # If no specific products requested, fetch all products for the tenant
+            # from the inventory service to generate forecasts for all of them
+            from shared.clients.inventory_client import InventoryServiceClient
+            from app.core.config import settings
+
+            inventory_client = InventoryServiceClient(settings)
+            all_ingredients = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
+            inventory_product_ids = [str(ingredient['id']) for ingredient in all_ingredients] if all_ingredients else []
+
+            # If still no products, return early with success response
+            if not inventory_product_ids:
+                logger.info("No products found for forecasting", tenant_id=tenant_id)
+                from app.schemas.forecasts import BatchForecastResponse
+                return BatchForecastResponse(
+                    batch_id=str(uuid.uuid4()),
+                    tenant_id=tenant_id,
+                    products_processed=0,
+                    forecasts_generated=0,
+                    success=True,
+                    message="No products found for forecasting"
+                )

         # Get subscription tier and enforce quotas
         tier = current_user.get('subscription_tier', 'starter')
+        # Skip rate limiting for service-to-service calls (orchestrator)
+        # Rate limiting is handled at the gateway level for user requests

-        # Check daily quota for forecast generation
-        quota_limit = get_forecast_quota(tier)
-        quota_result = await rate_limiter.check_and_increment_quota(
-            tenant_id,
-            "forecast_generation",
-            quota_limit,
-            period=86400  # 24 hours
+        # Create a copy of the request with the actual list of product IDs to forecast
+        # (whether originally provided or fetched from inventory service)
+        from app.schemas.forecasts import BatchForecastRequest
+        updated_request = BatchForecastRequest(
+            tenant_id=tenant_id,  # Use the tenant_id from the path parameter
+            batch_name=getattr(request, 'batch_name', f"orchestrator-batch-{datetime.now().strftime('%Y%m%d')}"),
+            inventory_product_ids=inventory_product_ids,
+            forecast_days=getattr(request, 'forecast_days', 7)
         )

         # Validate forecast horizon if specified
         if request.horizon_days:
             await rate_limiter.validate_forecast_horizon(
                 tenant_id, request.horizon_days, tier
             )

-        batch_result = await enhanced_forecasting_service.generate_batch_forecast(
+        batch_result = await enhanced_forecasting_service.generate_batch_forecasts(
             tenant_id=tenant_id,
-            request=request
+            request=updated_request
         )

         if metrics:
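With this change, a batch request that leaves inventory_product_ids empty fans out to every ingredient the tenant stocks. Illustrative request bodies (a sketch; only the field names come from the diff, the values are made up):

    # Forecast two specific products for the next 7 days (IDs are placeholders).
    explicit = {
        "inventory_product_ids": ["prod-a", "prod-b"],
        "forecast_days": 7,
    }

    # Empty list: the endpoint fetches all ingredient IDs from the inventory
    # service via InventoryServiceClient.get_all_ingredients() and forecasts each.
    fan_out = {
        "inventory_product_ids": [],
        "forecast_days": 7,
    }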
@@ -258,9 +274,25 @@ async def generate_batch_forecast(

         logger.info("Batch forecast generated successfully",
                     tenant_id=tenant_id,
-                    total_forecasts=batch_result.total_forecasts)
+                    total_forecasts=batch_result.get('total_forecasts', 0))

-        return batch_result
+        # Convert the service result to BatchForecastResponse format
+        from app.schemas.forecasts import BatchForecastResponse
+        now = datetime.now(timezone.utc)
+        return BatchForecastResponse(
+            id=batch_result.get('batch_id', str(uuid.uuid4())),
+            tenant_id=tenant_id,
+            batch_name=updated_request.batch_name,
+            status="completed",
+            total_products=batch_result.get('total_forecasts', 0),
+            completed_products=batch_result.get('successful_forecasts', 0),
+            failed_products=batch_result.get('failed_forecasts', 0),
+            requested_at=now,
+            completed_at=now,
+            processing_time_ms=0,
+            forecasts=[],
+            error_message=None
+        )

     except ValueError as e:
         if metrics:
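The service now returns a plain dict that the handler maps onto the response schema. A minimal sketch of that mapping, using only the keys the handler reads above (sample values invented):

    # Shape of the dict returned by generate_batch_forecasts(), as read above.
    batch_result = {
        "batch_id": "b-123",         # -> BatchForecastResponse.id (uuid4 fallback)
        "total_forecasts": 12,       # -> total_products
        "successful_forecasts": 11,  # -> completed_products
        "failed_forecasts": 1,       # -> failed_products
    }
    assert batch_result.get("successful_forecasts", 0) == 11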
@@ -484,6 +516,174 @@ async def clear_prediction_cache(
         )


+@router.post(
+    route_builder.build_operations_route("validate-forecasts"),
+    response_model=dict
+)
+@service_only_access
+@track_execution_time("validate_forecasts_duration_seconds", "forecasting-service")
+async def validate_forecasts(
+    validation_date: date = Query(..., description="Date to validate forecasts for"),
+    tenant_id: str = Path(..., description="Tenant ID"),
+    request_obj: Request = None,
+    current_user: dict = Depends(get_current_user_dep),
+    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
+):
+    """
+    Validate forecasts for a specific date against actual sales.
+    Calculates MAPE, RMSE, MAE and identifies products with poor accuracy.
+
+    This endpoint is called by the orchestrator during Step 5 to validate
+    yesterday's forecasts and trigger retraining if needed.
+
+    Args:
+        validation_date: Date to validate forecasts for
+        tenant_id: Tenant ID
+
+    Returns:
+        Dict with overall metrics and poor accuracy products list:
+        - overall_mape: Mean Absolute Percentage Error across all products
+        - overall_rmse: Root Mean Squared Error across all products
+        - overall_mae: Mean Absolute Error across all products
+        - products_validated: Number of products validated
+        - poor_accuracy_products: List of products with MAPE > 30%
+    """
+    metrics = get_metrics_collector(request_obj)
+
+    try:
+        logger.info("Validating forecasts for date",
+                    tenant_id=tenant_id,
+                    validation_date=validation_date.isoformat())
+
+        if metrics:
+            metrics.increment_counter("forecast_validations_total")
+
+        # Get all forecasts for the validation date
+        from app.repositories.forecast_repository import ForecastRepository
+        from shared.clients.sales_client import SalesServiceClient
+
+        db_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
+
+        async with db_manager.get_session() as session:
+            forecast_repo = ForecastRepository(session)
+
+            # Get forecasts for the validation date
+            forecasts = await forecast_repo.get_forecasts_by_date(
+                tenant_id=uuid.UUID(tenant_id),
+                forecast_date=validation_date
+            )
+
+            if not forecasts:
+                logger.warning("No forecasts found for validation date",
+                               tenant_id=tenant_id,
+                               validation_date=validation_date.isoformat())
+                return {
+                    "overall_mape": 0,
+                    "overall_rmse": 0,
+                    "overall_mae": 0,
+                    "products_validated": 0,
+                    "poor_accuracy_products": []
+                }
+
+            # Get actual sales for the validation date from sales service
+            sales_client = SalesServiceClient(settings, "forecasting-service")
+            actual_sales_response = await sales_client.get_sales_by_date_range(
+                tenant_id=tenant_id,
+                start_date=validation_date,
+                end_date=validation_date
+            )
+
+            # Create sales lookup dict
+            sales_dict = {}
+            if actual_sales_response and 'sales' in actual_sales_response:
+                for sale in actual_sales_response['sales']:
+                    product_id = sale.get('inventory_product_id')
+                    quantity = sale.get('quantity', 0)
+                    if product_id:
+                        # Aggregate quantities for the same product
+                        sales_dict[product_id] = sales_dict.get(product_id, 0) + quantity
+
+            # Calculate metrics for each product
+            import numpy as np
+
+            mape_list = []
+            rmse_list = []
+            mae_list = []
+            poor_accuracy_products = []
+
+            for forecast in forecasts:
+                product_id = str(forecast.inventory_product_id)
+                actual_quantity = sales_dict.get(product_id)
+
+                # Skip if no actual sales data
+                if actual_quantity is None:
+                    continue
+
+                predicted_quantity = forecast.predicted_demand
+
+                # Calculate errors
+                absolute_error = abs(predicted_quantity - actual_quantity)
+                squared_error = (predicted_quantity - actual_quantity) ** 2
+
+                # Calculate percentage error (avoid division by zero)
+                if actual_quantity > 0:
+                    percentage_error = (absolute_error / actual_quantity) * 100
+                else:
+                    # If actual is 0 but predicted is not, treat as 100% error
+                    percentage_error = 100 if predicted_quantity > 0 else 0
+
+                mape_list.append(percentage_error)
+                rmse_list.append(squared_error)
+                mae_list.append(absolute_error)
+
+                # Track products with poor accuracy
+                if percentage_error > 30:
+                    poor_accuracy_products.append({
+                        "product_id": product_id,
+                        "mape": round(percentage_error, 2),
+                        "predicted": round(predicted_quantity, 2),
+                        "actual": round(actual_quantity, 2)
+                    })
+
+            # Calculate overall metrics
+            overall_mape = np.mean(mape_list) if mape_list else 0
+            overall_rmse = np.sqrt(np.mean(rmse_list)) if rmse_list else 0
+            overall_mae = np.mean(mae_list) if mae_list else 0
+
+            result = {
+                "overall_mape": round(overall_mape, 2),
+                "overall_rmse": round(overall_rmse, 2),
+                "overall_mae": round(overall_mae, 2),
+                "products_validated": len(mape_list),
+                "poor_accuracy_products": poor_accuracy_products
+            }
+
+            logger.info("Forecast validation complete",
+                        tenant_id=tenant_id,
+                        validation_date=validation_date.isoformat(),
+                        overall_mape=result["overall_mape"],
+                        products_validated=result["products_validated"],
+                        poor_accuracy_count=len(poor_accuracy_products))
+
+            if metrics:
+                metrics.increment_counter("forecast_validations_completed_total")
+                metrics.observe_histogram("forecast_validation_mape", overall_mape)
+
+            return result
+
+    except Exception as e:
+        logger.error("Failed to validate forecasts",
+                     error=str(e),
+                     tenant_id=tenant_id,
+                     validation_date=validation_date.isoformat())
+        if metrics:
+            metrics.increment_counter("forecast_validations_failed_total")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to validate forecasts: {str(e)}"
+        )


 # ============================================================================
 # Tenant Data Deletion Operations (Internal Service Only)
 # ============================================================================
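A quick worked example of the accuracy arithmetic this endpoint performs (numbers invented for illustration):

    import numpy as np

    # Two products on the validation date:
    #   A: predicted 10, actual 8 -> abs error 2, pct error 25.0,  sq error 4
    #   B: predicted  5, actual 2 -> abs error 3, pct error 150.0, sq error 9 (flagged: MAPE > 30)
    overall_mape = np.mean([25.0, 150.0])    # 87.5
    overall_rmse = np.sqrt(np.mean([4, 9]))  # ~2.55
    overall_mae = np.mean([2, 3])            # 2.5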
services/forecasting/app/api/ml_insights.py (new file, 279 lines)
@@ -0,0 +1,279 @@
"""
ML Insights API Endpoints for Forecasting Service

Provides endpoints to trigger ML insight generation for:
- Dynamic business rules learning
- Demand pattern analysis
- Seasonal trend detection
"""

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
from typing import Optional, List
from uuid import UUID
from datetime import datetime, timedelta
import structlog
import pandas as pd

from app.core.database import get_db
from sqlalchemy.ext.asyncio import AsyncSession

logger = structlog.get_logger()

router = APIRouter(
    prefix="/api/v1/tenants/{tenant_id}/forecasting/ml/insights",
    tags=["ML Insights"]
)


# ================================================================
# REQUEST/RESPONSE SCHEMAS
# ================================================================

class RulesGenerationRequest(BaseModel):
    """Request schema for rules generation"""
    product_ids: Optional[List[str]] = Field(
        None,
        description="Specific product IDs to analyze. If None, analyzes all products"
    )
    lookback_days: int = Field(
        90,
        description="Days of historical data to analyze",
        ge=30,
        le=365
    )
    min_samples: int = Field(
        10,
        description="Minimum samples required for rule learning",
        ge=5,
        le=100
    )


class RulesGenerationResponse(BaseModel):
    """Response schema for rules generation"""
    success: bool
    message: str
    tenant_id: str
    products_analyzed: int
    total_insights_generated: int
    total_insights_posted: int
    insights_by_product: dict
    errors: List[str] = []

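# Illustrative example (placeholders, not part of the committed file): a
# request that analyzes two products over one quarter, requiring 20 samples
# per product. Values must respect the Field bounds above
# (lookback_days 30-365, min_samples 5-100); the UUIDs are made up.
#
#     RulesGenerationRequest(
#         product_ids=["11111111-1111-1111-1111-111111111111",
#                      "22222222-2222-2222-2222-222222222222"],
#         lookback_days=90,
#         min_samples=20,
#     )
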
# ================================================================
# API ENDPOINTS
# ================================================================

@router.post("/generate-rules", response_model=RulesGenerationResponse)
async def trigger_rules_generation(
    tenant_id: str,
    request_data: RulesGenerationRequest,
    db: AsyncSession = Depends(get_db)
):
    """
    Trigger dynamic business rules learning from historical sales data.

    This endpoint:
    1. Fetches historical sales data for specified products
    2. Runs the RulesOrchestrator to learn patterns
    3. Generates insights about optimal business rules
    4. Posts insights to AI Insights Service

    Args:
        tenant_id: Tenant UUID
        request_data: Rules generation parameters
        db: Database session

    Returns:
        RulesGenerationResponse with generation results
    """
    logger.info(
        "ML insights rules generation requested",
        tenant_id=tenant_id,
        product_ids=request_data.product_ids,
        lookback_days=request_data.lookback_days
    )

    try:
        # Import ML orchestrator and clients
        from app.ml.rules_orchestrator import RulesOrchestrator
        from shared.clients.sales_client import SalesServiceClient
        from shared.clients.inventory_client import InventoryServiceClient
        from app.core.config import settings

        # Initialize orchestrator and clients
        orchestrator = RulesOrchestrator()
        inventory_client = InventoryServiceClient(settings)

        # Get products to analyze from inventory service via API
        if request_data.product_ids:
            # Fetch specific products
            products = []
            for product_id in request_data.product_ids:
                product = await inventory_client.get_ingredient_by_id(
                    ingredient_id=UUID(product_id),
                    tenant_id=tenant_id
                )
                if product:
                    products.append(product)
        else:
            # Fetch all products for tenant (limit to 10)
            all_products = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
            products = all_products[:10]  # Limit to prevent timeout

        if not products:
            return RulesGenerationResponse(
                success=False,
                message="No products found for analysis",
                tenant_id=tenant_id,
                products_analyzed=0,
                total_insights_generated=0,
                total_insights_posted=0,
                insights_by_product={},
                errors=["No products found"]
            )

        # Initialize sales client to fetch historical data
        sales_client = SalesServiceClient(config=settings, calling_service_name="forecasting")

        # Calculate date range
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=request_data.lookback_days)

        # Process each product
        total_insights_generated = 0
        total_insights_posted = 0
        insights_by_product = {}
        errors = []

        for product in products:
            try:
                product_id = str(product['id'])
                product_name = product.get('name', 'Unknown')
                logger.info(f"Analyzing product {product_name} ({product_id})")

                # Fetch sales data for product
                sales_data = await sales_client.get_sales_data(
                    tenant_id=tenant_id,
                    product_id=product_id,
                    start_date=start_date.strftime('%Y-%m-%d'),
                    end_date=end_date.strftime('%Y-%m-%d')
                )

                if not sales_data:
                    logger.warning(f"No sales data for product {product_id}")
                    continue

                # Convert to DataFrame
                sales_df = pd.DataFrame(sales_data)

                if len(sales_df) < request_data.min_samples:
                    logger.warning(
                        f"Insufficient data for product {product_id}: "
                        f"{len(sales_df)} samples < {request_data.min_samples} required"
                    )
                    continue

                # Check what columns are available and map to expected format
                logger.debug(f"Sales data columns for product {product_id}: {sales_df.columns.tolist()}")

                # Map common field names to 'quantity' and 'date'
                if 'quantity' not in sales_df.columns:
                    if 'total_quantity' in sales_df.columns:
                        sales_df['quantity'] = sales_df['total_quantity']
                    elif 'amount' in sales_df.columns:
                        sales_df['quantity'] = sales_df['amount']
                    else:
                        logger.warning(f"No quantity field found for product {product_id}, skipping")
                        continue

                if 'date' not in sales_df.columns:
                    if 'sale_date' in sales_df.columns:
                        sales_df['date'] = sales_df['sale_date']
                    else:
                        logger.warning(f"No date field found for product {product_id}, skipping")
                        continue

                # Prepare sales data with required columns
                sales_df['date'] = pd.to_datetime(sales_df['date'])
                sales_df['quantity'] = sales_df['quantity'].astype(float)
                sales_df['day_of_week'] = sales_df['date'].dt.dayofweek
                sales_df['is_holiday'] = False  # TODO: Add holiday detection
                sales_df['weather'] = 'unknown'  # TODO: Add weather data

                # Run rules learning
                results = await orchestrator.learn_and_post_rules(
                    tenant_id=tenant_id,
                    inventory_product_id=product_id,
                    sales_data=sales_df,
                    external_data=None,
                    min_samples=request_data.min_samples
                )

                # Track results
                total_insights_generated += results['insights_generated']
                total_insights_posted += results['insights_posted']
                insights_by_product[product_id] = {
                    'product_name': product_name,
                    'insights_posted': results['insights_posted'],
                    'rules_learned': len(results['rules'])
                }

                logger.info(
                    f"Product {product_id} analysis complete",
                    insights_posted=results['insights_posted']
                )

            except Exception as e:
                error_msg = f"Error analyzing product {product_id}: {str(e)}"
                logger.error(error_msg, exc_info=True)
                errors.append(error_msg)

        # Close orchestrator
        await orchestrator.close()

        # Build response
        response = RulesGenerationResponse(
            success=total_insights_posted > 0,
            message=f"Successfully generated {total_insights_posted} insights from {len(products)} products",
            tenant_id=tenant_id,
            products_analyzed=len(products),
            total_insights_generated=total_insights_generated,
            total_insights_posted=total_insights_posted,
            insights_by_product=insights_by_product,
            errors=errors
        )

        logger.info(
            "ML insights rules generation complete",
            tenant_id=tenant_id,
            total_insights=total_insights_posted
        )

        return response

    except Exception as e:
        logger.error(
            "ML insights rules generation failed",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=500,
            detail=f"Rules generation failed: {str(e)}"
        )


@router.get("/health")
async def ml_insights_health():
    """Health check for ML insights endpoints"""
    return {
        "status": "healthy",
        "service": "forecasting-ml-insights",
        "endpoints": [
            "POST /ml/insights/generate-rules"
        ]
    }
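A minimal sketch of how the orchestrator, or a test, might drive the new endpoint. It assumes an httpx client and that the service is reachable at the router prefix above; the host, port, and tenant UUID are placeholders:

    import asyncio
    import httpx

    async def main() -> None:
        tenant_id = "3fa85f64-5717-4562-b3fc-2c963f66afa6"  # placeholder tenant UUID
        url = (
            f"http://forecasting:8000/api/v1/tenants/{tenant_id}"
            "/forecasting/ml/insights/generate-rules"
        )
        async with httpx.AsyncClient(timeout=120.0) as client:
            # Omitting product_ids analyzes all products (capped at 10 server-side).
            resp = await client.post(url, json={"lookback_days": 90, "min_samples": 10})
            resp.raise_for_status()
            body = resp.json()
            print(body["total_insights_posted"], body["errors"])

    asyncio.run(main())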