Improve the UI and training
This commit is contained in:
@@ -6,6 +6,7 @@ Repository for forecast operations
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text, desc, func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from datetime import datetime, timedelta, date, timezone
|
||||
import structlog
|
||||
|
||||
@@ -24,18 +25,24 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
super().__init__(Forecast, session, cache_ttl)
|
||||
|
||||
async def create_forecast(self, forecast_data: Dict[str, Any]) -> Forecast:
|
||||
"""Create a new forecast with validation"""
|
||||
"""
|
||||
Create a new forecast with validation.
|
||||
|
||||
Handles duplicate forecast race condition gracefully:
|
||||
If a forecast already exists for the same (tenant, product, date, location),
|
||||
it will be updated instead of creating a duplicate.
|
||||
"""
|
||||
try:
|
||||
# Validate forecast data
|
||||
validation_result = self._validate_forecast_data(
|
||||
forecast_data,
|
||||
["tenant_id", "inventory_product_id", "location", "forecast_date",
|
||||
forecast_data,
|
||||
["tenant_id", "inventory_product_id", "location", "forecast_date",
|
||||
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
|
||||
)
|
||||
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid forecast data: {validation_result['errors']}")
|
||||
|
||||
|
||||
# Set default values
|
||||
if "confidence_level" not in forecast_data:
|
||||
forecast_data["confidence_level"] = 0.8
|
||||
@@ -43,26 +50,109 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
forecast_data["algorithm"] = "prophet"
|
||||
if "business_type" not in forecast_data:
|
||||
forecast_data["business_type"] = "individual"
|
||||
|
||||
# Create forecast
|
||||
forecast = await self.create(forecast_data)
|
||||
|
||||
logger.info("Forecast created successfully",
|
||||
forecast_id=forecast.id,
|
||||
tenant_id=forecast.tenant_id,
|
||||
inventory_product_id=forecast.inventory_product_id,
|
||||
forecast_date=forecast.forecast_date.isoformat())
|
||||
|
||||
return forecast
|
||||
|
||||
|
||||
# Try to create forecast
|
||||
try:
|
||||
forecast = await self.create(forecast_data)
|
||||
|
||||
logger.info("Forecast created successfully",
|
||||
forecast_id=forecast.id,
|
||||
tenant_id=forecast.tenant_id,
|
||||
inventory_product_id=forecast.inventory_product_id,
|
||||
forecast_date=forecast.forecast_date.isoformat())
|
||||
|
||||
return forecast
|
||||
|
||||
except IntegrityError as ie:
|
||||
# Handle unique constraint violation (duplicate forecast)
|
||||
error_msg = str(ie).lower()
|
||||
if "unique constraint" in error_msg or "duplicate" in error_msg or "uq_forecast_tenant_product_date_location" in error_msg:
|
||||
logger.warning("Forecast already exists (race condition), updating instead",
|
||||
tenant_id=forecast_data.get("tenant_id"),
|
||||
inventory_product_id=forecast_data.get("inventory_product_id"),
|
||||
forecast_date=str(forecast_data.get("forecast_date")))
|
||||
|
||||
# Rollback the failed insert
|
||||
await self.session.rollback()
|
||||
|
||||
# Fetch the existing forecast
|
||||
existing_forecast = await self.get_existing_forecast(
|
||||
tenant_id=forecast_data["tenant_id"],
|
||||
inventory_product_id=forecast_data["inventory_product_id"],
|
||||
forecast_date=forecast_data["forecast_date"],
|
||||
location=forecast_data["location"]
|
||||
)
|
||||
|
||||
if existing_forecast:
|
||||
# Update existing forecast with new prediction data
|
||||
update_data = {
|
||||
"predicted_demand": forecast_data["predicted_demand"],
|
||||
"confidence_lower": forecast_data["confidence_lower"],
|
||||
"confidence_upper": forecast_data["confidence_upper"],
|
||||
"confidence_level": forecast_data.get("confidence_level", 0.8),
|
||||
"model_id": forecast_data["model_id"],
|
||||
"model_version": forecast_data.get("model_version"),
|
||||
"algorithm": forecast_data.get("algorithm", "prophet"),
|
||||
"processing_time_ms": forecast_data.get("processing_time_ms"),
|
||||
"features_used": forecast_data.get("features_used"),
|
||||
"weather_temperature": forecast_data.get("weather_temperature"),
|
||||
"weather_precipitation": forecast_data.get("weather_precipitation"),
|
||||
"weather_description": forecast_data.get("weather_description"),
|
||||
}
|
||||
|
||||
updated_forecast = await self.update(str(existing_forecast.id), update_data)
|
||||
|
||||
logger.info("Existing forecast updated after duplicate detection",
|
||||
forecast_id=updated_forecast.id,
|
||||
tenant_id=updated_forecast.tenant_id,
|
||||
inventory_product_id=updated_forecast.inventory_product_id)
|
||||
|
||||
return updated_forecast
|
||||
else:
|
||||
# This shouldn't happen, but log it
|
||||
logger.error("Duplicate forecast detected but not found in database")
|
||||
raise DatabaseError("Duplicate forecast detected but not found")
|
||||
else:
|
||||
# Different integrity error, re-raise
|
||||
raise
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except IntegrityError as ie:
|
||||
# Re-raise integrity errors that weren't handled above
|
||||
logger.error("Database integrity error creating forecast",
|
||||
tenant_id=forecast_data.get("tenant_id"),
|
||||
error=str(ie))
|
||||
raise DatabaseError(f"Database integrity error: {str(ie)}")
|
||||
except Exception as e:
|
||||
logger.error("Failed to create forecast",
|
||||
tenant_id=forecast_data.get("tenant_id"),
|
||||
inventory_product_id=forecast_data.get("inventory_product_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create forecast: {str(e)}")
|
||||
|
||||
async def get_existing_forecast(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
forecast_date: datetime,
|
||||
location: str
|
||||
) -> Optional[Forecast]:
|
||||
"""Get an existing forecast by unique key (tenant, product, date, location)"""
|
||||
try:
|
||||
query = select(Forecast).where(
|
||||
and_(
|
||||
Forecast.tenant_id == tenant_id,
|
||||
Forecast.inventory_product_id == inventory_product_id,
|
||||
Forecast.forecast_date == forecast_date,
|
||||
Forecast.location == location
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Failed to get existing forecast", error=str(e))
|
||||
return None
|
||||
|
||||
async def get_forecasts_by_date_range(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user