Improve the UI and training
@@ -4,7 +4,8 @@ Forecasting Operations API - Business operations for forecast generation and pre
"""

import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Request
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Request, BackgroundTasks
from fastapi.responses import JSONResponse
from typing import List, Dict, Any, Optional
from datetime import date, datetime, timezone
import uuid
@@ -202,6 +203,97 @@ async def generate_multi_day_forecast(
    )


async def execute_batch_forecast_background(
    tenant_id: str,
    batch_id: str,
    inventory_product_ids: List[str],
    forecast_days: int,
    batch_name: str
):
    """
    Background task for batch forecast generation.
    Prevents blocking the API thread for long-running batch operations.
    """
    logger.info("Starting background batch forecast",
                batch_id=batch_id,
                tenant_id=tenant_id,
                product_count=len(inventory_product_ids))

    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    forecasting_service = EnhancedForecastingService(database_manager)

    try:
        # Update batch status to running
        async with database_manager.get_session() as session:
            from app.repositories import PredictionBatchRepository
            batch_repo = PredictionBatchRepository(session)

            await batch_repo.update(
                batch_id,
                {"status": "processing", "completed_products": 0}
            )
            await session.commit()

        # Generate forecasts for all products
        from app.schemas.forecasts import BatchForecastRequest
        batch_request = BatchForecastRequest(
            tenant_id=tenant_id,
            batch_name=batch_name,
            inventory_product_ids=inventory_product_ids,
            forecast_days=forecast_days
        )

        result = await forecasting_service.generate_batch_forecasts(
            tenant_id=tenant_id,
            request=batch_request
        )

        # Update batch status to completed
        async with database_manager.get_session() as session:
            from app.repositories import PredictionBatchRepository
            batch_repo = PredictionBatchRepository(session)

            await batch_repo.update(
                batch_id,
                {
                    "status": "completed",
                    "completed_at": datetime.now(timezone.utc),
                    "completed_products": result.get("successful_forecasts", 0),
                    "failed_products": result.get("failed_forecasts", 0)
                }
            )
            await session.commit()

        logger.info("Background batch forecast completed",
                    batch_id=batch_id,
                    successful=result.get("successful_forecasts", 0),
                    failed=result.get("failed_forecasts", 0))

    except Exception as e:
        logger.error("Background batch forecast failed",
                     batch_id=batch_id,
                     error=str(e))

        try:
            async with database_manager.get_session() as session:
                from app.repositories import PredictionBatchRepository
                batch_repo = PredictionBatchRepository(session)

                await batch_repo.update(
                    batch_id,
                    {
                        "status": "failed",
                        "completed_at": datetime.now(timezone.utc),
                        "error_message": str(e)
                    }
                )
                await session.commit()
        except Exception as update_error:
            logger.error("Failed to update batch status after error",
                         batch_id=batch_id,
                         error=str(update_error))


@router.post(
    route_builder.build_operations_route("batch"),
    response_model=BatchForecastResponse
@@ -211,11 +303,17 @@ async def generate_multi_day_forecast(
async def generate_batch_forecast(
    request: BatchForecastRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    request_obj: Request = None,
    current_user: dict = Depends(get_current_user_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Generate forecasts for multiple products in batch (Admin+ only, quota enforced)"""
    """
    Generate forecasts for multiple products in batch (Admin+ only, quota enforced).

    IMPROVEMENT: Now uses background tasks for large batches to prevent API timeouts.
    Returns immediately with batch_id for status tracking.
    """
    metrics = get_metrics_collector(request_obj)

    try:
@@ -258,48 +356,104 @@ async def generate_batch_forecast(
            error_message=None
        )

        # Skip rate limiting for service-to-service calls (orchestrator)
        # Rate limiting is handled at the gateway level for user requests
        # IMPROVEMENT: For large batches (>5 products), use background task
        # For small batches, execute synchronously for immediate results
        batch_name = getattr(request, 'batch_name', f"batch-{datetime.now().strftime('%Y%m%d_%H%M%S')}")
        forecast_days = getattr(request, 'forecast_days', 7)

        # Create a copy of the request with the actual list of product IDs to forecast
        # (whether originally provided or fetched from inventory service)
        from app.schemas.forecasts import BatchForecastRequest
        updated_request = BatchForecastRequest(
            tenant_id=tenant_id,  # Use the tenant_id from the path parameter
            batch_name=getattr(request, 'batch_name', f"orchestrator-batch-{datetime.now().strftime('%Y%m%d')}"),
            inventory_product_ids=inventory_product_ids,
            forecast_days=getattr(request, 'forecast_days', 7)
        )

        batch_result = await enhanced_forecasting_service.generate_batch_forecasts(
            tenant_id=tenant_id,
            request=updated_request
        )

        if metrics:
            metrics.increment_counter("batch_forecasts_success_total")

        logger.info("Batch forecast generated successfully",
                    tenant_id=tenant_id,
                    total_forecasts=batch_result.get('total_forecasts', 0))

        # Convert the service result to BatchForecastResponse format
        from app.schemas.forecasts import BatchForecastResponse
        # Create batch record first
        batch_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)
        return BatchForecastResponse(
            id=batch_result.get('id', str(uuid.uuid4())),  # Use 'id' field (UUID) instead of 'batch_id' (string)
            tenant_id=tenant_id,
            batch_name=updated_request.batch_name,
            status="completed",
            total_products=batch_result.get('total_forecasts', 0),
            completed_products=batch_result.get('successful_forecasts', 0),
            failed_products=batch_result.get('failed_forecasts', 0),
            requested_at=now,
            completed_at=now,
            processing_time_ms=0,
            forecasts=[],
            error_message=None
        )
        async with enhanced_forecasting_service.database_manager.get_session() as session:
            from app.repositories import PredictionBatchRepository
            batch_repo = PredictionBatchRepository(session)

            batch_data = {
                "tenant_id": tenant_id,
                "batch_name": batch_name,
                "total_products": len(inventory_product_ids),
                "forecast_days": forecast_days,
                "status": "pending"
            }

            batch = await batch_repo.create_batch(batch_data)
            batch_id = str(batch.id)
            await session.commit()

        # Use background task for large batches to prevent API timeout
        use_background = len(inventory_product_ids) > 5

        if use_background:
            # Queue background task
            background_tasks.add_task(
                execute_batch_forecast_background,
                tenant_id=tenant_id,
                batch_id=batch_id,
                inventory_product_ids=inventory_product_ids,
                forecast_days=forecast_days,
                batch_name=batch_name
            )

            logger.info("Batch forecast queued for background processing",
                        tenant_id=tenant_id,
                        batch_id=batch_id,
                        product_count=len(inventory_product_ids))

            # Return immediately with pending status
            from app.schemas.forecasts import BatchForecastResponse
            return BatchForecastResponse(
                id=batch_id,
                tenant_id=tenant_id,
                batch_name=batch_name,
                status="pending",
                total_products=len(inventory_product_ids),
                completed_products=0,
                failed_products=0,
                requested_at=now,
                completed_at=None,
                processing_time_ms=0,
                forecasts=None,
                error_message=None
            )
        else:
            # Small batch - execute synchronously
            from app.schemas.forecasts import BatchForecastRequest
            updated_request = BatchForecastRequest(
                tenant_id=tenant_id,
                batch_name=batch_name,
                inventory_product_ids=inventory_product_ids,
                forecast_days=forecast_days
            )

            batch_result = await enhanced_forecasting_service.generate_batch_forecasts(
                tenant_id=tenant_id,
                request=updated_request
            )

            if metrics:
                metrics.increment_counter("batch_forecasts_success_total")

            logger.info("Batch forecast completed synchronously",
                        tenant_id=tenant_id,
                        total_forecasts=batch_result.get('total_forecasts', 0))

            # Convert the service result to BatchForecastResponse format
            from app.schemas.forecasts import BatchForecastResponse
            return BatchForecastResponse(
                id=batch_id,
                tenant_id=tenant_id,
                batch_name=batch_name,
                status="completed",
                total_products=batch_result.get('total_forecasts', 0),
                completed_products=batch_result.get('successful_forecasts', 0),
                failed_products=batch_result.get('failed_forecasts', 0),
                requested_at=now,
                completed_at=datetime.now(timezone.utc),
                processing_time_ms=0,
                forecasts=[],
                error_message=None
            )

    except ValueError as e:
        if metrics:
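Note, added for illustration and not part of this commit: the endpoint above applies a simple size threshold when deciding between background and inline execution. A minimal sketch of that dispatch rule using FastAPI's real BackgroundTasks API; the constant and helper names below are hypothetical.

from fastapi import BackgroundTasks

LARGE_BATCH_THRESHOLD = 5  # hypothetical constant; the route above hard-codes > 5

async def dispatch_batch(product_ids, background_tasks: BackgroundTasks, run_batch, run_inline):
    # Large batches are handed to the background task queue so the HTTP request
    # returns a "pending" record before any forecasts are generated.
    if len(product_ids) > LARGE_BATCH_THRESHOLD:
        background_tasks.add_task(run_batch, product_ids)
        return {"status": "pending"}
    # Small batches run synchronously and return final results in the same response.
    return await run_inline(product_ids)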
@@ -806,3 +960,50 @@ async def preview_tenant_data_deletion(
            status_code=500,
            detail=f"Failed to preview tenant data deletion: {str(e)}"
        )


@router.get("/health/database")
async def database_health():
    """
    Database health check endpoint with connection pool monitoring.

    Returns detailed connection pool statistics for monitoring and alerting.
    Useful for detecting connection pool exhaustion before it causes issues.
    """
    from app.core.database import get_db_health, get_connection_pool_stats
    from datetime import datetime

    try:
        # Check database connectivity
        db_healthy = await get_db_health()

        # Get connection pool statistics
        pool_stats = await get_connection_pool_stats()

        response = {
            "service": "forecasting",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "database_connected": db_healthy,
            "connection_pool": pool_stats,
            "overall_status": "healthy" if db_healthy and pool_stats.get("status") == "healthy" else "degraded"
        }

        # Return appropriate status code based on health
        if not db_healthy or pool_stats.get("status") == "critical":
            return JSONResponse(status_code=503, content=response)
        elif pool_stats.get("status") == "warning":
            return JSONResponse(status_code=200, content=response)
        else:
            return response

    except Exception as e:
        logger.error("Health check failed", error=str(e))
        return JSONResponse(
            status_code=503,
            content={
                "service": "forecasting",
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "overall_status": "unhealthy",
                "error": str(e)
            }
        )
@@ -70,6 +70,47 @@ async def get_db_health() -> bool:
        logger.error("Database health check failed", error=str(e))
        return False


async def get_connection_pool_stats() -> dict:
    """
    Get current connection pool statistics for monitoring.

    Returns:
        Dictionary with pool statistics including usage and capacity
    """
    try:
        pool = async_engine.pool

        # Get pool stats
        stats = {
            "pool_size": pool.size(),
            "checked_in_connections": pool.checkedin(),
            "checked_out_connections": pool.checkedout(),
            "overflow_connections": pool.overflow(),
            "total_connections": pool.size() + pool.overflow(),
            "max_capacity": 10 + 20,  # pool_size + max_overflow
            "usage_percentage": round(((pool.size() + pool.overflow()) / 30) * 100, 2)
        }

        # Add health status
        if stats["usage_percentage"] > 90:
            stats["status"] = "critical"
            stats["message"] = "Connection pool near capacity"
        elif stats["usage_percentage"] > 80:
            stats["status"] = "warning"
            stats["message"] = "Connection pool usage high"
        else:
            stats["status"] = "healthy"
            stats["message"] = "Connection pool healthy"

        return stats
    except Exception as e:
        logger.error("Failed to get connection pool stats", error=str(e))
        return {
            "status": "error",
            "message": f"Failed to get pool stats: {str(e)}"
        }


# Database manager instance for service_base compatibility
database_manager = DatabaseManager(
    database_url=settings.DATABASE_URL,
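For context, the hard-coded 10 + 20 capacity above is pool_size plus max_overflow. A small sketch, not part of this commit, of how the percentage and status thresholds relate; the constants are assumptions that should match whatever values the engine is actually created with.

POOL_SIZE = 10      # assumed to match the engine's pool_size
MAX_OVERFLOW = 20   # assumed to match the engine's max_overflow

def pool_usage_percentage(total_connections: int) -> float:
    """Share of total capacity (pool_size + max_overflow) currently allocated."""
    capacity = POOL_SIZE + MAX_OVERFLOW  # 30 in the endpoint above
    return round(total_connections / capacity * 100, 2)

# Example: 27 of 30 connections allocated -> 90.0, which the check above reports as "warning".
assert pool_usage_percentage(27) == 90.0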
@@ -5,7 +5,7 @@
Forecast models for the forecasting service
"""

from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text, JSON
from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text, JSON, UniqueConstraint, Index
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime, timezone
import uuid
@@ -15,7 +15,18 @@ from shared.database.base import Base
class Forecast(Base):
    """Forecast model for storing prediction results"""
    __tablename__ = "forecasts"

    __table_args__ = (
        # Unique constraint to prevent duplicate forecasts
        # Ensures only one forecast per (tenant, product, date, location) combination
        UniqueConstraint(
            'tenant_id', 'inventory_product_id', 'forecast_date', 'location',
            name='uq_forecast_tenant_product_date_location'
        ),
        # Composite index for common query patterns
        Index('ix_forecasts_tenant_product_date', 'tenant_id', 'inventory_product_id', 'forecast_date'),
    )

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    inventory_product_id = Column(UUID(as_uuid=True), nullable=False, index=True)  # Reference to inventory service
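This unique constraint is what the repository's duplicate handling (further down in this commit) relies on. As a hedged alternative sketch, not what this commit does, the same "update instead of duplicating" behaviour could be expressed with PostgreSQL's native upsert through SQLAlchemy:

from sqlalchemy.dialects.postgresql import insert

async def upsert_forecast(session, forecast_data: dict):
    # Insert a forecast row; if the (tenant, product, date, location) key already
    # exists, overwrite the prediction fields instead of raising IntegrityError.
    stmt = insert(Forecast).values(**forecast_data).on_conflict_do_update(
        constraint="uq_forecast_tenant_product_date_location",
        set_={
            "predicted_demand": forecast_data["predicted_demand"],
            "confidence_lower": forecast_data["confidence_lower"],
            "confidence_upper": forecast_data["confidence_upper"],
        },
    )
    await session.execute(stmt)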
@@ -6,6 +6,7 @@ Repository for forecast operations
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc, func
from sqlalchemy.exc import IntegrityError
from datetime import datetime, timedelta, date, timezone
import structlog

@@ -24,18 +25,24 @@ class ForecastRepository(ForecastingBaseRepository):
        super().__init__(Forecast, session, cache_ttl)

    async def create_forecast(self, forecast_data: Dict[str, Any]) -> Forecast:
        """Create a new forecast with validation"""
        """
        Create a new forecast with validation.

        Handles duplicate forecast race condition gracefully:
        If a forecast already exists for the same (tenant, product, date, location),
        it will be updated instead of creating a duplicate.
        """
        try:
            # Validate forecast data
            validation_result = self._validate_forecast_data(
                forecast_data,
                ["tenant_id", "inventory_product_id", "location", "forecast_date",
                 "predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid forecast data: {validation_result['errors']}")

            # Set default values
            if "confidence_level" not in forecast_data:
                forecast_data["confidence_level"] = 0.8
@@ -43,26 +50,109 @@ class ForecastRepository(ForecastingBaseRepository):
                forecast_data["algorithm"] = "prophet"
            if "business_type" not in forecast_data:
                forecast_data["business_type"] = "individual"

            # Create forecast
            forecast = await self.create(forecast_data)

            logger.info("Forecast created successfully",
                        forecast_id=forecast.id,
                        tenant_id=forecast.tenant_id,
                        inventory_product_id=forecast.inventory_product_id,
                        forecast_date=forecast.forecast_date.isoformat())

            return forecast

            # Try to create forecast
            try:
                forecast = await self.create(forecast_data)

                logger.info("Forecast created successfully",
                            forecast_id=forecast.id,
                            tenant_id=forecast.tenant_id,
                            inventory_product_id=forecast.inventory_product_id,
                            forecast_date=forecast.forecast_date.isoformat())

                return forecast

            except IntegrityError as ie:
                # Handle unique constraint violation (duplicate forecast)
                error_msg = str(ie).lower()
                if "unique constraint" in error_msg or "duplicate" in error_msg or "uq_forecast_tenant_product_date_location" in error_msg:
                    logger.warning("Forecast already exists (race condition), updating instead",
                                   tenant_id=forecast_data.get("tenant_id"),
                                   inventory_product_id=forecast_data.get("inventory_product_id"),
                                   forecast_date=str(forecast_data.get("forecast_date")))

                    # Rollback the failed insert
                    await self.session.rollback()

                    # Fetch the existing forecast
                    existing_forecast = await self.get_existing_forecast(
                        tenant_id=forecast_data["tenant_id"],
                        inventory_product_id=forecast_data["inventory_product_id"],
                        forecast_date=forecast_data["forecast_date"],
                        location=forecast_data["location"]
                    )

                    if existing_forecast:
                        # Update existing forecast with new prediction data
                        update_data = {
                            "predicted_demand": forecast_data["predicted_demand"],
                            "confidence_lower": forecast_data["confidence_lower"],
                            "confidence_upper": forecast_data["confidence_upper"],
                            "confidence_level": forecast_data.get("confidence_level", 0.8),
                            "model_id": forecast_data["model_id"],
                            "model_version": forecast_data.get("model_version"),
                            "algorithm": forecast_data.get("algorithm", "prophet"),
                            "processing_time_ms": forecast_data.get("processing_time_ms"),
                            "features_used": forecast_data.get("features_used"),
                            "weather_temperature": forecast_data.get("weather_temperature"),
                            "weather_precipitation": forecast_data.get("weather_precipitation"),
                            "weather_description": forecast_data.get("weather_description"),
                        }

                        updated_forecast = await self.update(str(existing_forecast.id), update_data)

                        logger.info("Existing forecast updated after duplicate detection",
                                    forecast_id=updated_forecast.id,
                                    tenant_id=updated_forecast.tenant_id,
                                    inventory_product_id=updated_forecast.inventory_product_id)

                        return updated_forecast
                    else:
                        # This shouldn't happen, but log it
                        logger.error("Duplicate forecast detected but not found in database")
                        raise DatabaseError("Duplicate forecast detected but not found")
                else:
                    # Different integrity error, re-raise
                    raise

        except ValidationError:
            raise
        except IntegrityError as ie:
            # Re-raise integrity errors that weren't handled above
            logger.error("Database integrity error creating forecast",
                         tenant_id=forecast_data.get("tenant_id"),
                         error=str(ie))
            raise DatabaseError(f"Database integrity error: {str(ie)}")
        except Exception as e:
            logger.error("Failed to create forecast",
                         tenant_id=forecast_data.get("tenant_id"),
                         inventory_product_id=forecast_data.get("inventory_product_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create forecast: {str(e)}")

    async def get_existing_forecast(
        self,
        tenant_id: str,
        inventory_product_id: str,
        forecast_date: datetime,
        location: str
    ) -> Optional[Forecast]:
        """Get an existing forecast by unique key (tenant, product, date, location)"""
        try:
            query = select(Forecast).where(
                and_(
                    Forecast.tenant_id == tenant_id,
                    Forecast.inventory_product_id == inventory_product_id,
                    Forecast.forecast_date == forecast_date,
                    Forecast.location == location
                )
            )
            result = await self.session.execute(query)
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error("Failed to get existing forecast", error=str(e))
            return None

    async def get_forecasts_by_date_range(
        self,
@@ -15,6 +15,7 @@ from app.schemas.forecasts import ForecastRequest, ForecastResponse
from app.services.prediction_service import PredictionService
from app.services.model_client import ModelClient
from app.services.data_client import DataClient
from app.utils.distributed_lock import get_forecast_lock, get_batch_forecast_lock, LockAcquisitionError

# Import repositories
from app.repositories import (
@@ -291,107 +292,165 @@ class EnhancedForecastingService:
    ) -> ForecastResponse:
        """
        Generate forecast using repository pattern with caching.

        CRITICAL FIXES:
        1. External HTTP calls are performed BEFORE opening database session
           to prevent connection pool exhaustion and blocking.
        2. Advisory locks prevent concurrent forecast generation for same product/date
           to avoid duplicate work and race conditions.
        """
        start_time = datetime.now(timezone.utc)

        try:
            logger.info("Generating enhanced forecast",
                        tenant_id=tenant_id,
                        inventory_product_id=request.inventory_product_id,
                        date=request.forecast_date.isoformat())

            # Get session and initialize repositories

            # CRITICAL FIX: Get model BEFORE opening database session
            # This prevents holding database connections during potentially slow external API calls
            logger.debug("Fetching model data before opening database session",
                         tenant_id=tenant_id,
                         inventory_product_id=request.inventory_product_id)

            model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)

            if not model_data:
                raise ValueError(f"No valid model available for product: {request.inventory_product_id}")

            logger.debug("Model data fetched successfully",
                         tenant_id=tenant_id,
                         model_id=model_data.get('model_id'))

            # Step 3: Prepare features with fallbacks (includes external API calls for weather)
            features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)

            # Now open database session AFTER external HTTP calls are complete
            # CRITICAL FIX: Acquire distributed lock to prevent concurrent forecast generation
            async with self.database_manager.get_background_session() as session:
                repos = await self._init_repositories(session)

                # Step 1: Check cache first
                cached_prediction = await repos['cache'].get_cached_prediction(
                    tenant_id, request.inventory_product_id, request.location, request.forecast_date
                )

                if cached_prediction:
                    logger.debug("Using cached prediction",
                                 tenant_id=tenant_id,
                                 inventory_product_id=request.inventory_product_id)
                    return self._create_forecast_response_from_cache(cached_prediction)

                # Step 2: Get model with validation
                model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)

                if not model_data:
                    raise ValueError(f"No valid model available for product: {request.inventory_product_id}")

                # Step 3: Prepare features with fallbacks
                features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)

                # Step 4: Generate prediction
                prediction_result = await self.prediction_service.predict(
                    model_id=model_data['model_id'],
                    model_path=model_data['model_path'],
                    features=features,
                    confidence_level=request.confidence_level
                )

                # Step 5: Apply business rules
                adjusted_prediction = self._apply_business_rules(
                    prediction_result, request, features
                )

                # Step 6: Save forecast using repository
                # Convert forecast_date to datetime if it's a string
                forecast_datetime = request.forecast_date
                if isinstance(forecast_datetime, str):
                    from dateutil.parser import parse
                    forecast_datetime = parse(forecast_datetime)

                forecast_data = {
                    "tenant_id": tenant_id,
                    "inventory_product_id": request.inventory_product_id,
                    "product_name": None,  # Field is now nullable, use inventory_product_id as reference
                    "location": request.location,
                    "forecast_date": forecast_datetime,
                    "predicted_demand": adjusted_prediction['prediction'],
                    "confidence_lower": adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
                    "confidence_upper": adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
                    "confidence_level": request.confidence_level,
                    "model_id": model_data['model_id'],
                    "model_version": str(model_data.get('version', '1.0')),
                    "algorithm": model_data.get('algorithm', 'prophet'),
                    "business_type": features.get('business_type', 'individual'),
                    "is_holiday": features.get('is_holiday', False),
                    "is_weekend": features.get('is_weekend', False),
                    "day_of_week": features.get('day_of_week', 0),
                    "weather_temperature": features.get('temperature'),
                    "weather_precipitation": features.get('precipitation'),
                    "weather_description": features.get('weather_description'),
                    "traffic_volume": features.get('traffic_volume'),
                    "processing_time_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
                    "features_used": features
                }

                forecast = await repos['forecast'].create_forecast(forecast_data)
                # Step 6: Cache the prediction
                await repos['cache'].cache_prediction(
                    tenant_id=tenant_id,
                    inventory_product_id=request.inventory_product_id,
                    location=request.location,
                    forecast_date=forecast_datetime,
                    predicted_demand=adjusted_prediction['prediction'],
                    confidence_lower=adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
                    confidence_upper=adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
                    model_id=model_data['model_id'],
                    expires_in_hours=24
                )

                # Get lock for this specific forecast (tenant + product + date)
                forecast_date_str = request.forecast_date.isoformat().split('T')[0] if hasattr(request.forecast_date, 'isoformat') else str(request.forecast_date).split('T')[0]
                lock = get_forecast_lock(
                    tenant_id=tenant_id,
                    product_id=str(request.inventory_product_id),
                    forecast_date=forecast_date_str
                )

logger.info("Enhanced forecast generated successfully",
|
||||
forecast_id=forecast.id,
|
||||
tenant_id=tenant_id,
|
||||
prediction=adjusted_prediction['prediction'])
|
||||
|
||||
return self._create_forecast_response_from_model(forecast)
|
||||
|
||||
|
||||
                try:
                    async with lock.acquire(session):
                        repos = await self._init_repositories(session)

                        # Step 1: Check cache first (inside lock for consistency)
                        # If another request generated the forecast while we waited for the lock,
                        # we'll find it in the cache
                        cached_prediction = await repos['cache'].get_cached_prediction(
                            tenant_id, request.inventory_product_id, request.location, request.forecast_date
                        )

                        if cached_prediction:
                            logger.info("Found cached prediction after acquiring lock (concurrent request completed first)",
                                        tenant_id=tenant_id,
                                        inventory_product_id=request.inventory_product_id)
                            return self._create_forecast_response_from_cache(cached_prediction)

                        # Step 2: Model data already fetched above (before session opened)

                        # Step 4: Generate prediction (in-memory operation)
                        prediction_result = await self.prediction_service.predict(
                            model_id=model_data['model_id'],
                            model_path=model_data['model_path'],
                            features=features,
                            confidence_level=request.confidence_level
                        )

                        # Step 5: Apply business rules
                        adjusted_prediction = self._apply_business_rules(
                            prediction_result, request, features
                        )

                        # Step 6: Save forecast using repository
                        # Convert forecast_date to datetime if it's a string
                        forecast_datetime = request.forecast_date
                        if isinstance(forecast_datetime, str):
                            from dateutil.parser import parse
                            forecast_datetime = parse(forecast_datetime)

                        forecast_data = {
                            "tenant_id": tenant_id,
                            "inventory_product_id": request.inventory_product_id,
                            "product_name": None,  # Field is now nullable, use inventory_product_id as reference
                            "location": request.location,
                            "forecast_date": forecast_datetime,
                            "predicted_demand": adjusted_prediction['prediction'],
                            "confidence_lower": adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
                            "confidence_upper": adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
                            "confidence_level": request.confidence_level,
                            "model_id": model_data['model_id'],
                            "model_version": str(model_data.get('version', '1.0')),
                            "algorithm": model_data.get('algorithm', 'prophet'),
                            "business_type": features.get('business_type', 'individual'),
                            "is_holiday": features.get('is_holiday', False),
                            "is_weekend": features.get('is_weekend', False),
                            "day_of_week": features.get('day_of_week', 0),
                            "weather_temperature": features.get('temperature'),
                            "weather_precipitation": features.get('precipitation'),
                            "weather_description": features.get('weather_description'),
                            "traffic_volume": features.get('traffic_volume'),
                            "processing_time_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
                            "features_used": features
                        }

                        forecast = await repos['forecast'].create_forecast(forecast_data)
                        await session.commit()

                        # Step 7: Cache the prediction
                        await repos['cache'].cache_prediction(
                            tenant_id=tenant_id,
                            inventory_product_id=request.inventory_product_id,
                            location=request.location,
                            forecast_date=forecast_datetime,
                            predicted_demand=adjusted_prediction['prediction'],
                            confidence_lower=adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
                            confidence_upper=adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
                            model_id=model_data['model_id'],
                            expires_in_hours=24
                        )

                        logger.info("Enhanced forecast generated successfully",
                                    forecast_id=forecast.id,
                                    tenant_id=tenant_id,
                                    prediction=adjusted_prediction['prediction'])

                        return self._create_forecast_response_from_model(forecast)

                except LockAcquisitionError:
                    # Could not acquire lock - another forecast request is in progress
                    logger.warning("Could not acquire forecast lock, checking cache for concurrent request result",
                                   tenant_id=tenant_id,
                                   inventory_product_id=request.inventory_product_id,
                                   forecast_date=forecast_date_str)

                    # Wait a moment and check cache - maybe the concurrent request finished
                    await asyncio.sleep(1)

                    repos = await self._init_repositories(session)
                    cached_prediction = await repos['cache'].get_cached_prediction(
                        tenant_id, request.inventory_product_id, request.location, request.forecast_date
                    )

                    if cached_prediction:
                        logger.info("Found forecast in cache after lock timeout (concurrent request completed)",
                                    tenant_id=tenant_id,
                                    inventory_product_id=request.inventory_product_id)
                        return self._create_forecast_response_from_cache(cached_prediction)

                    # No cached result, raise error
                    raise ValueError(
                        f"Forecast generation already in progress for product {request.inventory_product_id}. "
                        "Please try again in a few seconds."
                    )

        except Exception as e:
            processing_time = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
            logger.error("Error generating enhanced forecast",
services/forecasting/app/utils/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
"""
Utility modules for forecasting service
"""

services/forecasting/app/utils/distributed_lock.py (new file, 258 lines)
@@ -0,0 +1,258 @@
"""
Distributed Locking Mechanisms for Forecasting Service
Prevents concurrent forecast generation for the same product/date
"""

import asyncio
import time
from typing import Optional
import logging
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timezone, timedelta

logger = logging.getLogger(__name__)


class LockAcquisitionError(Exception):
    """Raised when lock cannot be acquired"""
    pass


class DatabaseLock:
    """
    Database-based distributed lock using PostgreSQL advisory locks.
    Works across multiple service instances.
    """

    def __init__(self, lock_name: str, timeout: float = 30.0):
        """
        Initialize database lock.

        Args:
            lock_name: Unique identifier for the lock
            timeout: Maximum seconds to wait for lock acquisition
        """
        self.lock_name = lock_name
        self.timeout = timeout
        self.lock_id = self._hash_lock_name(lock_name)

    def _hash_lock_name(self, name: str) -> int:
        """Convert lock name to integer ID for PostgreSQL advisory lock"""
        # Use hash and modulo to get a positive 32-bit integer
        return abs(hash(name)) % (2**31)

    @asynccontextmanager
    async def acquire(self, session: AsyncSession):
        """
        Acquire distributed lock as async context manager.

        Args:
            session: Database session for lock operations

        Raises:
            LockAcquisitionError: If lock cannot be acquired within timeout
        """
        acquired = False
        start_time = time.time()

        try:
            # Try to acquire lock with timeout
            while time.time() - start_time < self.timeout:
                # Try non-blocking lock acquisition
                result = await session.execute(
                    text("SELECT pg_try_advisory_lock(:lock_id)"),
                    {"lock_id": self.lock_id}
                )
                acquired = result.scalar()

                if acquired:
                    logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
                    break

                # Wait a bit before retrying
                await asyncio.sleep(0.1)

            if not acquired:
                raise LockAcquisitionError(
                    f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
                )

            yield

        finally:
            if acquired:
                # Release lock
                await session.execute(
                    text("SELECT pg_advisory_unlock(:lock_id)"),
                    {"lock_id": self.lock_id}
                )
                logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")
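One caveat about _hash_lock_name above, noted here for illustration rather than as part of the commit: Python's built-in hash() for strings is randomized per process (PYTHONHASHSEED), so two service instances can map the same lock name to different advisory-lock IDs. A hedged sketch of a process-stable alternative:

import zlib

def stable_lock_id(name: str) -> int:
    # CRC32 is deterministic across processes and machines, unlike str hash(),
    # and the mask keeps the value inside the positive 32-bit advisory-lock key space.
    return zlib.crc32(name.encode("utf-8")) & 0x7FFFFFFF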
class SimpleDatabaseLock:
    """
    Simple table-based distributed lock.
    Alternative to advisory locks, uses a dedicated locks table.
    """

    def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
        """
        Initialize simple database lock.

        Args:
            lock_name: Unique identifier for the lock
            timeout: Maximum seconds to wait for lock acquisition
            ttl: Time-to-live for stale lock cleanup (seconds)
        """
        self.lock_name = lock_name
        self.timeout = timeout
        self.ttl = ttl

    async def _ensure_lock_table(self, session: AsyncSession):
        """Ensure locks table exists"""
        create_table_sql = """
            CREATE TABLE IF NOT EXISTS distributed_locks (
                lock_name VARCHAR(255) PRIMARY KEY,
                acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
                acquired_by VARCHAR(255),
                expires_at TIMESTAMP WITH TIME ZONE NOT NULL
            )
        """
        await session.execute(text(create_table_sql))
        await session.commit()

    async def _cleanup_stale_locks(self, session: AsyncSession):
        """Remove expired locks"""
        cleanup_sql = """
            DELETE FROM distributed_locks
            WHERE expires_at < :now
        """
        await session.execute(
            text(cleanup_sql),
            {"now": datetime.now(timezone.utc)}
        )
        await session.commit()

    @asynccontextmanager
    async def acquire(self, session: AsyncSession, owner: str = "forecasting-service"):
        """
        Acquire simple database lock.

        Args:
            session: Database session
            owner: Identifier for lock owner

        Raises:
            LockAcquisitionError: If lock cannot be acquired
        """
        await self._ensure_lock_table(session)
        await self._cleanup_stale_locks(session)

        acquired = False
        start_time = time.time()

        try:
            # Try to acquire lock
            while time.time() - start_time < self.timeout:
                now = datetime.now(timezone.utc)
                expires_at = now + timedelta(seconds=self.ttl)

                try:
                    # Try to insert lock record
                    insert_sql = """
                        INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
                        VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
                        ON CONFLICT (lock_name) DO NOTHING
                        RETURNING lock_name
                    """

                    result = await session.execute(
                        text(insert_sql),
                        {
                            "lock_name": self.lock_name,
                            "acquired_at": now,
                            "acquired_by": owner,
                            "expires_at": expires_at
                        }
                    )
                    await session.commit()

                    if result.rowcount > 0:
                        acquired = True
                        logger.info(f"Acquired simple lock: {self.lock_name}")
                        break

                except Exception as e:
                    logger.debug(f"Lock acquisition attempt failed: {e}")
                    await session.rollback()

                # Wait before retrying
                await asyncio.sleep(0.5)

            if not acquired:
                raise LockAcquisitionError(
                    f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
                )

            yield

        finally:
            if acquired:
                # Release lock
                delete_sql = """
                    DELETE FROM distributed_locks
                    WHERE lock_name = :lock_name
                """
                await session.execute(
                    text(delete_sql),
                    {"lock_name": self.lock_name}
                )
                await session.commit()
                logger.info(f"Released simple lock: {self.lock_name}")
def get_forecast_lock(
    tenant_id: str,
    product_id: str,
    forecast_date: str,
    use_advisory: bool = True
) -> DatabaseLock:
    """
    Get distributed lock for generating a forecast for a specific product and date.

    Args:
        tenant_id: Tenant identifier
        product_id: Product identifier
        forecast_date: Forecast date (ISO format)
        use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)

    Returns:
        Lock instance
    """
    lock_name = f"forecast:{tenant_id}:{product_id}:{forecast_date}"

    if use_advisory:
        return DatabaseLock(lock_name, timeout=30.0)
    else:
        return SimpleDatabaseLock(lock_name, timeout=30.0, ttl=300.0)


def get_batch_forecast_lock(tenant_id: str, use_advisory: bool = True) -> DatabaseLock:
    """
    Get distributed lock for batch forecast generation for a tenant.

    Args:
        tenant_id: Tenant identifier
        use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)

    Returns:
        Lock instance
    """
    lock_name = f"forecast_batch:{tenant_id}"

    if use_advisory:
        return DatabaseLock(lock_name, timeout=60.0)
    else:
        return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)
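A short usage sketch of the helpers above, added for illustration; it mirrors how the forecasting service acquires the lock, and the tenant/product/date values are placeholders.

# Illustrative only: assumes an AsyncSession is passed in as `session`.
lock = get_forecast_lock(
    tenant_id="tenant-123",
    product_id="prod-456",
    forecast_date="2024-01-15",
)

async def generate_with_lock(session):
    try:
        async with lock.acquire(session):
            ...  # generate and persist the forecast while holding the advisory lock
    except LockAcquisitionError:
        ...  # another instance is already generating this forecast; fall back to the cache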
@@ -12,7 +12,9 @@ from shared.subscription.plans import (
    SubscriptionPlanMetadata,
    PlanPricing,
    QuotaLimits,
    PlanFeatures
    PlanFeatures,
    FeatureCategories,
    UserFacingFeatures
)

logger = structlog.get_logger()
@@ -55,22 +57,21 @@ async def get_available_plans():
            # Convert Decimal to float for JSON serialization
            plans_data[tier.value] = {
                "name": metadata["name"],
                "description": metadata["description"],
                "tagline": metadata["tagline"],
                "description_key": metadata["description_key"],
                "tagline_key": metadata["tagline_key"],
                "popular": metadata["popular"],
                "monthly_price": float(metadata["monthly_price"]),
                "yearly_price": float(metadata["yearly_price"]),
                "trial_days": metadata["trial_days"],
                "features": metadata["features"],
                "limits": {
                    "users": metadata["limits"]["users"],
                    "locations": metadata["limits"]["locations"],
                    "products": metadata["limits"]["products"],
                    "forecasts_per_day": metadata["limits"]["forecasts_per_day"],
                },
                "support": metadata["support"],
                "recommended_for": metadata["recommended_for"],
                "hero_features": metadata.get("hero_features", []),
                "roi_badge": metadata.get("roi_badge"),
                "business_metrics": metadata.get("business_metrics"),
                "limits": metadata["limits"],
                "support_key": metadata["support_key"],
                "recommended_for_key": metadata["recommended_for_key"],
                "contact_sales": metadata.get("contact_sales", False),
                "custom_pricing": metadata.get("custom_pricing", False),
            }

        logger.info("subscription_plans_fetched", tier_count=len(plans_data))
@@ -110,22 +111,21 @@ async def get_plan_by_tier(tier: str):
        plan_data = {
            "tier": tier_enum.value,
            "name": metadata["name"],
            "description": metadata["description"],
            "tagline": metadata["tagline"],
            "description_key": metadata["description_key"],
            "tagline_key": metadata["tagline_key"],
            "popular": metadata["popular"],
            "monthly_price": float(metadata["monthly_price"]),
            "yearly_price": float(metadata["yearly_price"]),
            "trial_days": metadata["trial_days"],
            "features": metadata["features"],
            "limits": {
                "users": metadata["limits"]["users"],
                "locations": metadata["limits"]["locations"],
                "products": metadata["limits"]["products"],
                "forecasts_per_day": metadata["limits"]["forecasts_per_day"],
            },
            "support": metadata["support"],
            "recommended_for": metadata["recommended_for"],
            "hero_features": metadata.get("hero_features", []),
            "roi_badge": metadata.get("roi_badge"),
            "business_metrics": metadata.get("business_metrics"),
            "limits": metadata["limits"],
            "support_key": metadata["support_key"],
            "recommended_for_key": metadata["recommended_for_key"],
            "contact_sales": metadata.get("contact_sales", False),
            "custom_pricing": metadata.get("custom_pricing", False),
        }

        logger.info("subscription_plan_fetched", tier=tier)
@@ -233,6 +233,50 @@ async def get_plan_limits(tier: str):
        )


@router.get("/feature-categories")
async def get_feature_categories():
    """
    Get all feature categories with icons and translation keys

    **Public endpoint** - No authentication required

    Returns:
        Dictionary of feature categories
    """
    try:
        return {
            "categories": FeatureCategories.CATEGORIES
        }
    except Exception as e:
        logger.error("failed_to_fetch_feature_categories", error=str(e))
        raise HTTPException(
            status_code=500,
            detail="Failed to fetch feature categories"
        )


@router.get("/feature-descriptions")
async def get_feature_descriptions():
    """
    Get user-facing feature descriptions with translation keys

    **Public endpoint** - No authentication required

    Returns:
        Dictionary of feature descriptions mapped by feature key
    """
    try:
        return {
            "features": UserFacingFeatures.FEATURE_DISPLAY
        }
    except Exception as e:
        logger.error("failed_to_fetch_feature_descriptions", error=str(e))
        raise HTTPException(
            status_code=500,
            detail="Failed to fetch feature descriptions"
        )


@router.get("/compare")
async def compare_plans():
    """
@@ -158,6 +158,56 @@ async def start_training_job(
        # Continue with job creation but log the error

    try:
        # CRITICAL FIX: Check for existing running jobs before starting new one
        # This prevents duplicate tenant-level training jobs
        async with enhanced_training_service.database_manager.get_session() as check_session:
            from app.repositories.training_log_repository import TrainingLogRepository
            log_repo = TrainingLogRepository(check_session)

            # Check for active jobs (running or pending)
            active_jobs = await log_repo.get_active_jobs(tenant_id=tenant_id)
            pending_jobs = await log_repo.get_logs_by_tenant(
                tenant_id=tenant_id,
                status="pending",
                limit=10
            )

            all_active = active_jobs + pending_jobs

            if all_active:
                # Training job already in progress, return existing job info
                existing_job = all_active[0]
                logger.info("Training job already in progress, returning existing job",
                            existing_job_id=existing_job.job_id,
                            tenant_id=tenant_id,
                            status=existing_job.status)

                return TrainingJobResponse(
                    job_id=existing_job.job_id,
                    tenant_id=tenant_id,
                    status=existing_job.status,
                    message=f"Training job already in progress (started {existing_job.created_at.isoformat() if existing_job.created_at else 'recently'})",
                    created_at=existing_job.created_at or datetime.now(timezone.utc),
                    estimated_duration_minutes=existing_job.config.get("estimated_duration_minutes", 15) if existing_job.config else 15,
                    training_results={
                        "total_products": 0,
                        "successful_trainings": 0,
                        "failed_trainings": 0,
                        "products": [],
                        "overall_training_time_seconds": 0.0
                    },
                    data_summary=None,
                    completed_at=None,
                    error_details=None,
                    processing_metadata={
                        "background_task": True,
                        "async_execution": True,
                        "existing_job": True,
                        "deduplication": True
                    }
                )

        # No existing job, proceed with creating new one
        # Generate enhanced job ID
        job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
@@ -407,6 +457,7 @@ async def start_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: str = Path(..., description="Inventory product UUID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    request_obj: Request = None,
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
@@ -421,6 +472,7 @@ async def start_single_product_training(
    - Enhanced error handling and validation
    - Metrics tracking
    - Transactional operations
    - Background execution to prevent blocking
    """
    metrics = get_metrics_collector(request_obj)

@@ -429,6 +481,53 @@ async def start_single_product_training(
                inventory_product_id=inventory_product_id,
                tenant_id=tenant_id)

    # CRITICAL FIX: Check if this product is currently being trained
    # This prevents duplicate training from rapid-click scenarios
    async with enhanced_training_service.database_manager.get_session() as check_session:
        from app.repositories.training_log_repository import TrainingLogRepository
        log_repo = TrainingLogRepository(check_session)

        # Check for active jobs for this specific product
        active_jobs = await log_repo.get_active_jobs(tenant_id=tenant_id)
        pending_jobs = await log_repo.get_logs_by_tenant(
            tenant_id=tenant_id,
            status="pending",
            limit=20
        )

        all_active = active_jobs + pending_jobs

        # Filter for jobs that include this specific product
        product_jobs = [
            job for job in all_active
            if job.config and (
                # Single product job for this product
                job.config.get("product_id") == inventory_product_id or
                # Tenant-wide job that would include this product
                job.config.get("job_type") == "tenant_training"
            )
        ]

        if product_jobs:
            existing_job = product_jobs[0]
            logger.warning("Product training already in progress, rejecting duplicate request",
                           existing_job_id=existing_job.job_id,
                           tenant_id=tenant_id,
                           inventory_product_id=inventory_product_id,
                           status=existing_job.status)

            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail={
                    "error": "Product training already in progress",
                    "message": f"Product {inventory_product_id} is currently being trained in job {existing_job.job_id}",
                    "existing_job_id": existing_job.job_id,
                    "status": existing_job.status,
                    "started_at": existing_job.created_at.isoformat() if existing_job.created_at else None
                }
            )
    # No existing job, proceed with training
    # Record metrics
    if metrics:
        metrics.increment_counter("enhanced_single_product_training_total")
@@ -436,22 +535,60 @@ async def start_single_product_training(
        # Generate enhanced job ID
        job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"

        # Delegate to enhanced training service
        result = await enhanced_training_service.start_single_product_training(
        # CRITICAL FIX: Add initial training log entry
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="pending",
            progress=0,
            current_step="Initializing single product training",
            tenant_id=tenant_id
        )

        # Add enhanced background task for single product training
        background_tasks.add_task(
            execute_single_product_training_background,
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            job_id=job_id,
            bakery_location=request.bakery_location or (40.4168, -3.7038)
            bakery_location=request.bakery_location or (40.4168, -3.7038),
            database_manager=enhanced_training_service.database_manager
        )

        if metrics:
            metrics.increment_counter("enhanced_single_product_training_success_total")
        # Return immediate response with job info
        response_data = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "status": "pending",
            "message": "Enhanced single product training started successfully",
            "created_at": datetime.now(timezone.utc),
            "estimated_duration_minutes": 15,  # Default estimate for single product
            "training_results": {
                "total_products": 1,
                "successful_trainings": 0,
                "failed_trainings": 0,
                "products": [],
                "overall_training_time_seconds": 0.0
            },
            "data_summary": None,
            "completed_at": None,
            "error_details": None,
            "processing_metadata": {
                "background_task": True,
                "async_execution": True,
                "enhanced_features": True,
                "repository_pattern": True,
                "dependency_injection": True
            }
        }

        logger.info("Enhanced single product training completed",
        logger.info("Enhanced single product training queued successfully",
                    inventory_product_id=inventory_product_id,
                    job_id=job_id)

        return TrainingJobResponse(**result)
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_queued_total")

        return TrainingJobResponse(**response_data)

    except ValueError as e:
        if metrics:
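To illustrate the two deduplication behaviours added above (tenant-level requests get the existing job back, product-level requests are rejected with 409), a hedged client-side sketch; the URL is a placeholder, not the service's actual route.

import httpx

async def start_product_training(client: httpx.AsyncClient, url: str, payload: dict) -> dict:
    # `url` stands in for the single-product training route shown in this commit.
    resp = await client.post(url, json=payload)
    if resp.status_code == 409:
        # Duplicate request: surface the job that is already running instead of failing hard.
        detail = resp.json().get("detail", {})
        return {"job_id": detail.get("existing_job_id"), "status": detail.get("status", "running")}
    resp.raise_for_status()
    return resp.json()  # contains the newly queued job_id with status "pending"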
@@ -475,6 +612,74 @@ async def start_single_product_training(
        )


async def execute_single_product_training_background(
    tenant_id: str,
    inventory_product_id: str,
    job_id: str,
    bakery_location: tuple,
    database_manager
):
    """
    Enhanced background task that executes single product training using repository pattern.
    Uses a separate service instance to avoid session conflicts.
    """
    logger.info("Enhanced background single product training started",
                job_id=job_id,
                tenant_id=tenant_id,
                inventory_product_id=inventory_product_id)

    # Create a new service instance with a fresh database session to avoid conflicts
    from app.services.training_service import EnhancedTrainingService
    fresh_training_service = EnhancedTrainingService(database_manager)

    try:
        # Update job status to running
        await fresh_training_service._update_job_status_repository(
            job_id=job_id,
            status="running",
            progress=0,
            current_step="Starting single product training",
            tenant_id=tenant_id
        )

        # Execute the enhanced single product training with repository pattern
        result = await fresh_training_service.start_single_product_training(
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            job_id=job_id,
            bakery_location=bakery_location
        )

        logger.info("Enhanced background single product training completed successfully",
                    job_id=job_id,
                    inventory_product_id=inventory_product_id)

    except Exception as training_error:
        logger.error("Enhanced single product training failed",
                     job_id=job_id,
                     inventory_product_id=inventory_product_id,
                     error=str(training_error))

        try:
            await fresh_training_service._update_job_status_repository(
                job_id=job_id,
                status="failed",
                progress=0,
                current_step="Single product training failed",
                error_message=str(training_error),
                tenant_id=tenant_id
            )
        except Exception as status_error:
            logger.error("Failed to update job status after training error",
                         job_id=job_id,
                         status_error=str(status_error))

    finally:
        logger.info("Enhanced background single product training cleanup completed",
                    job_id=job_id,
                    inventory_product_id=inventory_product_id)


@router.get("/health")
async def health_check():
    """Health check endpoint for the training operations"""