Fix forecasting service
This commit is contained in:
@@ -12,7 +12,10 @@ from typing import List, Optional
|
||||
from datetime import date
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import get_current_user_from_headers
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
get_current_tenant_id_dep
|
||||
)
|
||||
from app.services.forecasting_service import ForecastingService
|
||||
from app.schemas.forecasts import (
|
||||
ForecastRequest, ForecastResponse, BatchForecastRequest,
|
||||
@@ -30,13 +33,14 @@ forecasting_service = ForecastingService()
|
||||
async def create_single_forecast(
|
||||
request: ForecastRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
):
|
||||
"""Generate a single product forecast"""
|
||||
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(request.tenant_id) != str(current_user.get("tenant_id")):
|
||||
if str(request.tenant_id) != tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to this tenant"
|
||||
@@ -88,13 +92,14 @@ async def create_single_forecast(
|
||||
async def create_batch_forecast(
|
||||
request: BatchForecastRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
):
|
||||
"""Generate batch forecasts for multiple products"""
|
||||
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(request.tenant_id) != str(current_user.get("tenant_id")):
|
||||
if str(request.tenant_id) != tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to this tenant"
|
||||
@@ -172,12 +177,11 @@ async def list_forecasts(
|
||||
end_date: Optional[date] = Query(None),
|
||||
product_name: Optional[str] = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep)
|
||||
):
|
||||
"""List forecasts with filtering"""
|
||||
|
||||
try:
|
||||
tenant_id = str(current_user.get("tenant_id"))
|
||||
|
||||
# Get forecasts
|
||||
forecasts = await forecasting_service.get_forecasts(
|
||||
@@ -230,15 +234,14 @@ async def list_forecasts(
|
||||
async def get_forecast_alerts(
|
||||
active_only: bool = Query(True),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
):
|
||||
"""Get forecast alerts for tenant"""
|
||||
|
||||
try:
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
tenant_id = current_user.get("tenant_id")
|
||||
|
||||
# Build query
|
||||
query = select(ForecastAlert).where(
|
||||
ForecastAlert.tenant_id == tenant_id
|
||||
@@ -281,7 +284,8 @@ async def get_forecast_alerts(
|
||||
async def acknowledge_alert(
|
||||
alert_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
):
|
||||
"""Acknowledge a forecast alert"""
|
||||
|
||||
@@ -289,8 +293,6 @@ async def acknowledge_alert(
|
||||
from sqlalchemy import select, update
|
||||
from datetime import datetime
|
||||
|
||||
tenant_id = current_user.get("tenant_id")
|
||||
|
||||
# Get alert
|
||||
result = await db.execute(
|
||||
select(ForecastAlert).where(
|
||||
|
||||
@@ -12,7 +12,10 @@ from typing import List, Dict, Any
|
||||
from datetime import date, datetime, timedelta
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import get_current_user_from_headers
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
get_current_tenant_id_dep
|
||||
)
|
||||
from app.services.prediction_service import PredictionService
|
||||
from app.schemas.forecasts import ForecastRequest
|
||||
|
||||
@@ -28,12 +31,11 @@ async def get_realtime_prediction(
|
||||
location: str,
|
||||
forecast_date: date,
|
||||
features: Dict[str, Any],
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep)
|
||||
):
|
||||
"""Get real-time prediction without storing in database"""
|
||||
|
||||
try:
|
||||
tenant_id = str(current_user.get("tenant_id"))
|
||||
|
||||
# Get latest model
|
||||
from app.services.forecasting_service import ForecastingService
|
||||
@@ -83,13 +85,12 @@ async def get_quick_prediction(
|
||||
product_name: str,
|
||||
location: str = Query(...),
|
||||
days_ahead: int = Query(1, ge=1, le=7),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep)
|
||||
):
|
||||
"""Get quick prediction for next few days"""
|
||||
|
||||
try:
|
||||
tenant_id = str(current_user.get("tenant_id"))
|
||||
|
||||
|
||||
# Generate predictions for the next N days
|
||||
predictions = []
|
||||
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/core/auth.py
|
||||
# ================================================================
|
||||
"""
|
||||
Authentication utilities for forecasting service
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import HTTPException, status, Request
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
async def get_current_user_from_headers(request: Request) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current user from gateway headers
|
||||
Gateway middleware adds user context to headers after JWT verification
|
||||
"""
|
||||
|
||||
try:
|
||||
# Extract user information from headers set by API Gateway
|
||||
user_id = request.headers.get("X-User-ID")
|
||||
user_email = request.headers.get("X-User-Email")
|
||||
tenant_id = request.headers.get("X-Tenant-ID")
|
||||
user_roles = request.headers.get("X-User-Roles", "").split(",")
|
||||
|
||||
if not user_id or not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"email": user_email,
|
||||
"tenant_id": tenant_id,
|
||||
"roles": [role.strip() for role in user_roles if role.strip()]
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error extracting user from headers", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication"
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/messaging.py
|
||||
# services/forecasting/app/services/messaging.py
|
||||
# ================================================================
|
||||
"""
|
||||
Messaging service for event publishing and consuming
|
||||
@@ -10,89 +10,126 @@ import json
|
||||
from typing import Dict, Any
|
||||
import asyncio
|
||||
|
||||
from shared.messaging.rabbitmq import RabbitMQPublisher, RabbitMQConsumer
|
||||
from shared.messaging.rabbitmq import RabbitMQClient
|
||||
from shared.messaging.events import (
|
||||
TrainingCompletedEvent,
|
||||
DataImportedEvent,
|
||||
ForecastGeneratedEvent,
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Global messaging instances
|
||||
publisher = None
|
||||
consumer = None
|
||||
# Global messaging instance
|
||||
rabbitmq_client = None
|
||||
|
||||
async def setup_messaging():
|
||||
"""Initialize messaging services"""
|
||||
global publisher, consumer
|
||||
|
||||
global rabbitmq_client
|
||||
|
||||
try:
|
||||
# Initialize publisher
|
||||
publisher = RabbitMQPublisher(settings.RABBITMQ_URL)
|
||||
await publisher.connect()
|
||||
|
||||
# Initialize consumer
|
||||
consumer = RabbitMQConsumer(settings.RABBITMQ_URL)
|
||||
await consumer.connect()
|
||||
|
||||
rabbitmq_client = RabbitMQClient(settings.RABBITMQ_URL, service_name="forecasting_service")
|
||||
await rabbitmq_client.connect()
|
||||
|
||||
# Set up event handlers
|
||||
await consumer.subscribe("training.model.updated", handle_model_updated)
|
||||
await consumer.subscribe("data.weather.updated", handle_weather_updated)
|
||||
|
||||
# We need to adapt the callback to accept aio_pika.IncomingMessage
|
||||
await rabbitmq_client.consume_events(
|
||||
exchange_name="training.events",
|
||||
queue_name="forecasting_model_update_queue",
|
||||
routing_key="training.completed", # Assuming model updates are part of training.completed events
|
||||
callback=handle_model_updated_message
|
||||
)
|
||||
await rabbitmq_client.consume_events(
|
||||
exchange_name="data.events",
|
||||
queue_name="forecasting_weather_update_queue",
|
||||
routing_key="data.weather.updated", # This needs to match the actual event type if different
|
||||
callback=handle_weather_updated_message
|
||||
)
|
||||
|
||||
logger.info("Messaging setup completed")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to setup messaging", error=str(e))
|
||||
raise
|
||||
|
||||
async def cleanup_messaging():
|
||||
"""Cleanup messaging connections"""
|
||||
global publisher, consumer
|
||||
|
||||
global rabbitmq_client
|
||||
|
||||
try:
|
||||
if consumer:
|
||||
await consumer.close()
|
||||
if publisher:
|
||||
await publisher.close()
|
||||
|
||||
if rabbitmq_client:
|
||||
await rabbitmq_client.disconnect()
|
||||
|
||||
logger.info("Messaging cleanup completed")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error during messaging cleanup", error=str(e))
|
||||
|
||||
async def publish_forecast_completed(data: Dict[str, Any]):
|
||||
"""Publish forecast completed event"""
|
||||
if publisher:
|
||||
await publisher.publish("forecasting.forecast.completed", data)
|
||||
if rabbitmq_client:
|
||||
event = ForecastGeneratedEvent(service_name="forecasting_service", data=data, event_type="forecast.completed")
|
||||
await rabbitmq_client.publish_forecast_event(event_type="completed", forecast_data=event.to_dict())
|
||||
|
||||
async def publish_alert_created(data: Dict[str, Any]):
|
||||
"""Publish alert created event"""
|
||||
if publisher:
|
||||
await publisher.publish("forecasting.alert.created", data)
|
||||
# Assuming 'alert.created' is a type of forecast event, or define a new exchange/publisher method
|
||||
if rabbitmq_client:
|
||||
event = ForecastGeneratedEvent(service_name="forecasting_service", data=data, event_type="alert.created")
|
||||
await rabbitmq_client.publish_forecast_event(event_type="alert.created", forecast_data=event.to_dict())
|
||||
|
||||
async def publish_batch_completed(data: Dict[str, Any]):
|
||||
"""Publish batch forecast completed event"""
|
||||
if publisher:
|
||||
await publisher.publish("forecasting.batch.completed", data)
|
||||
if rabbitmq_client:
|
||||
event = ForecastGeneratedEvent(service_name="forecasting_service", data=data, event_type="forecast.batch.completed")
|
||||
await rabbitmq_client.publish_forecast_event(event_type="batch.completed", forecast_data=event.to_dict())
|
||||
|
||||
# Event handlers
|
||||
|
||||
# Event handler wrappers for aio_pika messages
|
||||
async def handle_model_updated_message(message: Any):
|
||||
async with message.process():
|
||||
try:
|
||||
event_data = json.loads(message.body.decode())
|
||||
# Assuming the actual event data is nested under a 'data' key within the event dictionary
|
||||
await handle_model_updated(event_data.get("data", {}))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error("Failed to decode model updated message JSON", error=str(e), body=message.body)
|
||||
except Exception as e:
|
||||
logger.error("Error processing model updated message", error=str(e), body=message.body)
|
||||
|
||||
async def handle_weather_updated_message(message: Any):
|
||||
async with message.process():
|
||||
try:
|
||||
event_data = json.loads(message.body.decode())
|
||||
# Assuming the actual event data is nested under a 'data' key within the event dictionary
|
||||
await handle_weather_updated(event_data.get("data", {}))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error("Failed to decode weather updated message JSON", error=str(e), body=message.body)
|
||||
except Exception as e:
|
||||
logger.error("Error processing weather updated message", error=str(e), body=message.body)
|
||||
|
||||
|
||||
# Original Event handlers (now called from the message wrappers)
|
||||
async def handle_model_updated(data: Dict[str, Any]):
|
||||
"""Handle model updated event from training service"""
|
||||
try:
|
||||
logger.info("Received model updated event",
|
||||
logger.info("Received model updated event",
|
||||
model_id=data.get("model_id"),
|
||||
tenant_id=data.get("tenant_id"))
|
||||
|
||||
|
||||
# Clear model cache for this model
|
||||
# This will be handled by PredictionService
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error handling model updated event", error=str(e))
|
||||
|
||||
async def handle_weather_updated(data: Dict[str, Any]):
|
||||
"""Handle weather data updated event"""
|
||||
try:
|
||||
logger.info("Received weather updated event",
|
||||
logger.info("Received weather updated event",
|
||||
date=data.get("date"))
|
||||
|
||||
|
||||
# Could trigger re-forecasting if needed
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error handling weather updated event", error=str(e))
|
||||
logger.error("Error handling weather updated event", error=str(e))
|
||||
Reference in New Issue
Block a user