Improve AI logic
This commit is contained in:
@@ -394,34 +394,80 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
error=str(e))
|
||||
return {"error": f"Failed to get forecast summary: {str(e)}"}
|
||||
|
||||
async def get_forecasts_by_date(
|
||||
self,
|
||||
tenant_id: str,
|
||||
forecast_date: date,
|
||||
inventory_product_id: str = None
|
||||
) -> List[Forecast]:
|
||||
"""
|
||||
Get all forecasts for a specific date.
|
||||
Used for forecast validation against actual sales.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
forecast_date: Date to get forecasts for
|
||||
inventory_product_id: Optional product filter
|
||||
|
||||
Returns:
|
||||
List of forecasts for the date
|
||||
"""
|
||||
try:
|
||||
query = select(Forecast).where(
|
||||
and_(
|
||||
Forecast.tenant_id == tenant_id,
|
||||
func.date(Forecast.forecast_date) == forecast_date
|
||||
)
|
||||
)
|
||||
|
||||
if inventory_product_id:
|
||||
query = query.where(Forecast.inventory_product_id == inventory_product_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
forecasts = result.scalars().all()
|
||||
|
||||
logger.info("Retrieved forecasts by date",
|
||||
tenant_id=tenant_id,
|
||||
forecast_date=forecast_date.isoformat(),
|
||||
count=len(forecasts))
|
||||
|
||||
return list(forecasts)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get forecasts by date",
|
||||
tenant_id=tenant_id,
|
||||
forecast_date=forecast_date.isoformat(),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get forecasts: {str(e)}")
|
||||
|
||||
async def bulk_create_forecasts(self, forecasts_data: List[Dict[str, Any]]) -> List[Forecast]:
|
||||
"""Bulk create multiple forecasts"""
|
||||
try:
|
||||
created_forecasts = []
|
||||
|
||||
|
||||
for forecast_data in forecasts_data:
|
||||
# Validate each forecast
|
||||
validation_result = self._validate_forecast_data(
|
||||
forecast_data,
|
||||
["tenant_id", "inventory_product_id", "location", "forecast_date",
|
||||
["tenant_id", "inventory_product_id", "location", "forecast_date",
|
||||
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
|
||||
)
|
||||
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
logger.warning("Skipping invalid forecast data",
|
||||
errors=validation_result["errors"],
|
||||
data=forecast_data)
|
||||
continue
|
||||
|
||||
|
||||
forecast = await self.create(forecast_data)
|
||||
created_forecasts.append(forecast)
|
||||
|
||||
|
||||
logger.info("Bulk created forecasts",
|
||||
requested_count=len(forecasts_data),
|
||||
created_count=len(created_forecasts))
|
||||
|
||||
|
||||
return created_forecasts
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to bulk create forecasts",
|
||||
requested_count=len(forecasts_data),
|
||||
|
||||
Reference in New Issue
Block a user