Start fixing forecast service API 7
This commit is contained in:
@@ -33,26 +33,19 @@ forecasting_service = ForecastingService()
|
||||
async def create_single_forecast(
|
||||
request: ForecastRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
tenant_id: str = Path(..., description="Tenant ID")
|
||||
):
|
||||
"""Generate a single product forecast"""
|
||||
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(request.tenant_id) != tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to this tenant"
|
||||
)
|
||||
|
||||
# Generate forecast
|
||||
forecast = await forecasting_service.generate_forecast(request, db)
|
||||
forecast = await forecasting_service.generate_forecast(tenant_id, request, db)
|
||||
|
||||
# Convert to response model
|
||||
return ForecastResponse(
|
||||
id=str(forecast.id),
|
||||
tenant_id=str(forecast.tenant_id),
|
||||
tenant_id=tenant_id,
|
||||
product_name=forecast.product_name,
|
||||
location=forecast.location,
|
||||
forecast_date=forecast.forecast_date,
|
||||
|
||||
@@ -38,19 +38,19 @@ class ForecastingService:
|
||||
self.model_client = ModelClient()
|
||||
self.data_client = DataClient()
|
||||
|
||||
async def generate_forecast(self, request: ForecastRequest, db: AsyncSession) -> Forecast:
|
||||
async def generate_forecast(self, tenant_id: str, request: ForecastRequest, db: AsyncSession) -> Forecast:
|
||||
"""Generate a single forecast for a product"""
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
logger.info("Generating forecast",
|
||||
tenant_id=request.tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
product=request.product_name,
|
||||
date=request.forecast_date)
|
||||
|
||||
# Get the latest trained model for this tenant/product
|
||||
model_info = await self._get_latest_model(
|
||||
request.tenant_id,
|
||||
tenant_id,
|
||||
request.product_name,
|
||||
)
|
||||
|
||||
@@ -58,7 +58,7 @@ class ForecastingService:
|
||||
raise ValueError(f"No trained model found for {request.product_name}")
|
||||
|
||||
# Prepare features for prediction
|
||||
features = await self._prepare_forecast_features(request)
|
||||
features = await self._prepare_forecast_features(tenant_id, request)
|
||||
|
||||
# Generate prediction using ML service
|
||||
prediction_result = await self.prediction_service.predict(
|
||||
@@ -69,7 +69,7 @@ class ForecastingService:
|
||||
|
||||
# Create forecast record
|
||||
forecast = Forecast(
|
||||
tenant_id=uuid.UUID(request.tenant_id),
|
||||
tenant_id=uuid.UUID(tenant_id),
|
||||
product_name=request.product_name,
|
||||
forecast_date=datetime.combine(request.forecast_date, datetime.min.time()),
|
||||
|
||||
@@ -115,7 +115,7 @@ class ForecastingService:
|
||||
# Publish event
|
||||
await publish_forecast_completed({
|
||||
"forecast_id": str(forecast.id),
|
||||
"tenant_id": request.tenant_id,
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": request.product_name,
|
||||
"predicted_demand": forecast.predicted_demand
|
||||
})
|
||||
@@ -256,7 +256,7 @@ class ForecastingService:
|
||||
logger.error("Error getting latest model", error=str(e))
|
||||
raise
|
||||
|
||||
async def _prepare_forecast_features(self, request: ForecastRequest) -> Dict[str, Any]:
|
||||
async def _prepare_forecast_features(self, tenant_id: str, request: ForecastRequest) -> Dict[str, Any]:
|
||||
"""Prepare features for forecasting model"""
|
||||
|
||||
features = {
|
||||
@@ -269,7 +269,7 @@ class ForecastingService:
|
||||
features["is_holiday"] = await self._is_spanish_holiday(request.forecast_date)
|
||||
|
||||
|
||||
weather_data = await self._get_weather_forecast(request.tenant_id, 1)
|
||||
weather_data = await self._get_weather_forecast(tenant_id, 1)
|
||||
features.update(weather_data)
|
||||
|
||||
return features
|
||||
|
||||
Reference in New Issue
Block a user