Fix forecasting service

This commit is contained in:
Urtzi Alfaro
2025-07-21 20:43:17 +02:00
parent 0e7ca10a29
commit 153ae3f154
11 changed files with 107 additions and 534 deletions

View File

@@ -11,7 +11,6 @@ from app.core.database import get_db
from app.schemas.auth import UserResponse, PasswordChange
from app.schemas.users import UserUpdate
from app.services.user_service import UserService
from app.core.auth import get_current_user
from app.models.users import User
# Import unified authentication from shared library
@@ -53,7 +52,7 @@ async def get_current_user_info(
@router.put("/me", response_model=UserResponse)
async def update_current_user(
user_update: UserUpdate,
current_user: User = Depends(get_current_user),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""Update current user information"""
@@ -83,7 +82,7 @@ async def update_current_user(
@router.post("/change-password")
async def change_password(
password_data: PasswordChange,
current_user: User = Depends(get_current_user),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""Change user password"""
@@ -106,7 +105,7 @@ async def change_password(
@router.delete("/me")
async def delete_current_user(
current_user: User = Depends(get_current_user),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""Delete current user account"""

View File

@@ -10,7 +10,6 @@ from datetime import datetime, timedelta
import structlog
from app.core.database import get_db
from app.core.auth import get_current_user, AuthInfo
from app.services.traffic_service import TrafficService
from app.services.messaging import data_publisher
from app.schemas.external import (

View File

@@ -1,72 +0,0 @@
from fastapi import HTTPException, Depends, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import httpx
import structlog
from typing import Dict, Any, Optional
from app.core.config import settings
logger = structlog.get_logger()
security = HTTPBearer(auto_error=False) # ✅ Don't auto-error, we'll handle manually
class AuthInfo:
"""Authentication information"""
def __init__(self, user_id: str, email: str, tenant_id: str, roles: list):
self.user_id = user_id
self.email = email
self.tenant_id = tenant_id
self.roles = roles
async def get_current_user(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> AuthInfo:
"""Get current user from gateway headers or token verification"""
# ✅ OPTION 1: Check for gateway headers (preferred when using gateway)
user_id = request.headers.get("X-User-ID")
email = request.headers.get("X-User-Email")
tenant_id = request.headers.get("X-Tenant-ID")
roles_header = request.headers.get("X-User-Roles", "")
if user_id and email and tenant_id:
# Gateway already authenticated the user
roles = roles_header.split(",") if roles_header else ["user"]
logger.info("Authenticated via gateway headers", user_id=user_id, email=email)
return AuthInfo(user_id, email, tenant_id, roles)
# ✅ OPTION 2: Direct token verification (when not using gateway)
if not credentials:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required (no token or gateway headers)"
)
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.post(
f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
headers={"Authorization": f"Bearer {credentials.credentials}"}
)
if response.status_code == 200:
user_data = response.json()
logger.info("Authenticated via direct token", user_id=user_data.get("user_id"))
return AuthInfo(
user_id=user_data["user_id"],
email=user_data["email"],
tenant_id=user_data["tenant_id"],
roles=user_data.get("roles", ["user"])
)
else:
logger.warning("Token verification failed", status_code=response.status_code)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials"
)
except httpx.RequestError as e:
logger.error("Auth service unavailable", error=str(e))
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Authentication service unavailable"
)

View File

@@ -12,7 +12,10 @@ from typing import List, Optional
from datetime import date
from app.core.database import get_db
from app.core.auth import get_current_user_from_headers
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep
)
from app.services.forecasting_service import ForecastingService
from app.schemas.forecasts import (
ForecastRequest, ForecastResponse, BatchForecastRequest,
@@ -30,13 +33,14 @@ forecasting_service = ForecastingService()
async def create_single_forecast(
request: ForecastRequest,
db: AsyncSession = Depends(get_db),
current_user: dict = Depends(get_current_user_from_headers)
tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: dict = Depends(get_current_user_dep)
):
"""Generate a single product forecast"""
try:
# Verify tenant access
if str(request.tenant_id) != str(current_user.get("tenant_id")):
if str(request.tenant_id) != tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to this tenant"
@@ -88,13 +92,14 @@ async def create_single_forecast(
async def create_batch_forecast(
request: BatchForecastRequest,
db: AsyncSession = Depends(get_db),
current_user: dict = Depends(get_current_user_from_headers)
tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: dict = Depends(get_current_user_dep)
):
"""Generate batch forecasts for multiple products"""
try:
# Verify tenant access
if str(request.tenant_id) != str(current_user.get("tenant_id")):
if str(request.tenant_id) != tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to this tenant"
@@ -172,12 +177,11 @@ async def list_forecasts(
end_date: Optional[date] = Query(None),
product_name: Optional[str] = Query(None),
db: AsyncSession = Depends(get_db),
current_user: dict = Depends(get_current_user_from_headers)
tenant_id: str = Depends(get_current_tenant_id_dep)
):
"""List forecasts with filtering"""
try:
tenant_id = str(current_user.get("tenant_id"))
# Get forecasts
forecasts = await forecasting_service.get_forecasts(
@@ -230,15 +234,14 @@ async def list_forecasts(
async def get_forecast_alerts(
active_only: bool = Query(True),
db: AsyncSession = Depends(get_db),
current_user: dict = Depends(get_current_user_from_headers)
tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: dict = Depends(get_current_user_dep)
):
"""Get forecast alerts for tenant"""
try:
from sqlalchemy import select, and_
tenant_id = current_user.get("tenant_id")
# Build query
query = select(ForecastAlert).where(
ForecastAlert.tenant_id == tenant_id
@@ -281,7 +284,8 @@ async def get_forecast_alerts(
async def acknowledge_alert(
alert_id: str,
db: AsyncSession = Depends(get_db),
current_user: dict = Depends(get_current_user_from_headers)
tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: dict = Depends(get_current_user_dep)
):
"""Acknowledge a forecast alert"""
@@ -289,8 +293,6 @@ async def acknowledge_alert(
from sqlalchemy import select, update
from datetime import datetime
tenant_id = current_user.get("tenant_id")
# Get alert
result = await db.execute(
select(ForecastAlert).where(

View File

@@ -12,7 +12,10 @@ from typing import List, Dict, Any
from datetime import date, datetime, timedelta
from app.core.database import get_db
from app.core.auth import get_current_user_from_headers
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep
)
from app.services.prediction_service import PredictionService
from app.schemas.forecasts import ForecastRequest
@@ -28,12 +31,11 @@ async def get_realtime_prediction(
location: str,
forecast_date: date,
features: Dict[str, Any],
current_user: dict = Depends(get_current_user_from_headers)
tenant_id: str = Depends(get_current_tenant_id_dep)
):
"""Get real-time prediction without storing in database"""
try:
tenant_id = str(current_user.get("tenant_id"))
# Get latest model
from app.services.forecasting_service import ForecastingService
@@ -83,13 +85,12 @@ async def get_quick_prediction(
product_name: str,
location: str = Query(...),
days_ahead: int = Query(1, ge=1, le=7),
current_user: dict = Depends(get_current_user_from_headers)
tenant_id: str = Depends(get_current_tenant_id_dep)
):
"""Get quick prediction for next few days"""
try:
tenant_id = str(current_user.get("tenant_id"))
# Generate predictions for the next N days
predictions = []

View File

@@ -1,48 +0,0 @@
# ================================================================
# services/forecasting/app/core/auth.py
# ================================================================
"""
Authentication utilities for forecasting service
"""
import structlog
from fastapi import HTTPException, status, Request
from typing import Dict, Any, Optional
logger = structlog.get_logger()
async def get_current_user_from_headers(request: Request) -> Dict[str, Any]:
"""
Get current user from gateway headers
Gateway middleware adds user context to headers after JWT verification
"""
try:
# Extract user information from headers set by API Gateway
user_id = request.headers.get("X-User-ID")
user_email = request.headers.get("X-User-Email")
tenant_id = request.headers.get("X-Tenant-ID")
user_roles = request.headers.get("X-User-Roles", "").split(",")
if not user_id or not tenant_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required"
)
return {
"user_id": user_id,
"email": user_email,
"tenant_id": tenant_id,
"roles": [role.strip() for role in user_roles if role.strip()]
}
except HTTPException:
raise
except Exception as e:
logger.error("Error extracting user from headers", error=str(e))
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication"
)

View File

@@ -1,5 +1,5 @@
# ================================================================
# services/forecasting/app/services/messaging.py
# services/forecasting/app/services/messaging.py
# ================================================================
"""
Messaging service for event publishing and consuming
@@ -10,89 +10,126 @@ import json
from typing import Dict, Any
import asyncio
from shared.messaging.rabbitmq import RabbitMQPublisher, RabbitMQConsumer
from shared.messaging.rabbitmq import RabbitMQClient
from shared.messaging.events import (
TrainingCompletedEvent,
DataImportedEvent,
ForecastGeneratedEvent,
)
from app.core.config import settings
logger = structlog.get_logger()
# Global messaging instances
publisher = None
consumer = None
# Global messaging instance
rabbitmq_client = None
async def setup_messaging():
"""Initialize messaging services"""
global publisher, consumer
global rabbitmq_client
try:
# Initialize publisher
publisher = RabbitMQPublisher(settings.RABBITMQ_URL)
await publisher.connect()
# Initialize consumer
consumer = RabbitMQConsumer(settings.RABBITMQ_URL)
await consumer.connect()
rabbitmq_client = RabbitMQClient(settings.RABBITMQ_URL, service_name="forecasting_service")
await rabbitmq_client.connect()
# Set up event handlers
await consumer.subscribe("training.model.updated", handle_model_updated)
await consumer.subscribe("data.weather.updated", handle_weather_updated)
# We need to adapt the callback to accept aio_pika.IncomingMessage
await rabbitmq_client.consume_events(
exchange_name="training.events",
queue_name="forecasting_model_update_queue",
routing_key="training.completed", # Assuming model updates are part of training.completed events
callback=handle_model_updated_message
)
await rabbitmq_client.consume_events(
exchange_name="data.events",
queue_name="forecasting_weather_update_queue",
routing_key="data.weather.updated", # This needs to match the actual event type if different
callback=handle_weather_updated_message
)
logger.info("Messaging setup completed")
except Exception as e:
logger.error("Failed to setup messaging", error=str(e))
raise
async def cleanup_messaging():
"""Cleanup messaging connections"""
global publisher, consumer
global rabbitmq_client
try:
if consumer:
await consumer.close()
if publisher:
await publisher.close()
if rabbitmq_client:
await rabbitmq_client.disconnect()
logger.info("Messaging cleanup completed")
except Exception as e:
logger.error("Error during messaging cleanup", error=str(e))
async def publish_forecast_completed(data: Dict[str, Any]):
"""Publish forecast completed event"""
if publisher:
await publisher.publish("forecasting.forecast.completed", data)
if rabbitmq_client:
event = ForecastGeneratedEvent(service_name="forecasting_service", data=data, event_type="forecast.completed")
await rabbitmq_client.publish_forecast_event(event_type="completed", forecast_data=event.to_dict())
async def publish_alert_created(data: Dict[str, Any]):
"""Publish alert created event"""
if publisher:
await publisher.publish("forecasting.alert.created", data)
# Assuming 'alert.created' is a type of forecast event, or define a new exchange/publisher method
if rabbitmq_client:
event = ForecastGeneratedEvent(service_name="forecasting_service", data=data, event_type="alert.created")
await rabbitmq_client.publish_forecast_event(event_type="alert.created", forecast_data=event.to_dict())
async def publish_batch_completed(data: Dict[str, Any]):
"""Publish batch forecast completed event"""
if publisher:
await publisher.publish("forecasting.batch.completed", data)
if rabbitmq_client:
event = ForecastGeneratedEvent(service_name="forecasting_service", data=data, event_type="forecast.batch.completed")
await rabbitmq_client.publish_forecast_event(event_type="batch.completed", forecast_data=event.to_dict())
# Event handlers
# Event handler wrappers for aio_pika messages
async def handle_model_updated_message(message: Any):
async with message.process():
try:
event_data = json.loads(message.body.decode())
# Assuming the actual event data is nested under a 'data' key within the event dictionary
await handle_model_updated(event_data.get("data", {}))
except json.JSONDecodeError as e:
logger.error("Failed to decode model updated message JSON", error=str(e), body=message.body)
except Exception as e:
logger.error("Error processing model updated message", error=str(e), body=message.body)
async def handle_weather_updated_message(message: Any):
async with message.process():
try:
event_data = json.loads(message.body.decode())
# Assuming the actual event data is nested under a 'data' key within the event dictionary
await handle_weather_updated(event_data.get("data", {}))
except json.JSONDecodeError as e:
logger.error("Failed to decode weather updated message JSON", error=str(e), body=message.body)
except Exception as e:
logger.error("Error processing weather updated message", error=str(e), body=message.body)
# Original Event handlers (now called from the message wrappers)
async def handle_model_updated(data: Dict[str, Any]):
"""Handle model updated event from training service"""
try:
logger.info("Received model updated event",
logger.info("Received model updated event",
model_id=data.get("model_id"),
tenant_id=data.get("tenant_id"))
# Clear model cache for this model
# This will be handled by PredictionService
except Exception as e:
logger.error("Error handling model updated event", error=str(e))
async def handle_weather_updated(data: Dict[str, Any]):
"""Handle weather data updated event"""
try:
logger.info("Received weather updated event",
logger.info("Received weather updated event",
date=data.get("date"))
# Could trigger re-forecasting if needed
except Exception as e:
logger.error("Error handling weather updated event", error=str(e))
logger.error("Error handling weather updated event", error=str(e))

View File

@@ -1,22 +0,0 @@
"""
Authentication configuration for notification service
"""
from shared.auth.jwt_handler import JWTHandler
from shared.auth.decorators import require_auth, require_role
from app.core.config import settings
# Initialize JWT handler
jwt_handler = JWTHandler(
secret_key=settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM,
access_token_expire_minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
)
# Export commonly used functions
verify_token = jwt_handler.verify_token
create_access_token = jwt_handler.create_access_token
get_current_user = jwt_handler.get_current_user
# Export decorators
__all__ = ['verify_token', 'create_access_token', 'get_current_user', 'require_auth', 'require_role']

View File

@@ -1,22 +0,0 @@
"""
Authentication configuration for tenant service
"""
from shared.auth.jwt_handler import JWTHandler
from shared.auth.decorators import require_auth, require_role
from app.core.config import settings
# Initialize JWT handler
jwt_handler = JWTHandler(
secret_key=settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM,
access_token_expire_minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
)
# Export commonly used functions
verify_token = jwt_handler.verify_token
create_access_token = jwt_handler.create_access_token
get_current_user = jwt_handler.get_current_user
# Export decorators
__all__ = ['verify_token', 'create_access_token', 'get_current_user', 'require_auth', 'require_role']

View File

@@ -8,10 +8,12 @@ from typing import List
import structlog
from app.core.database import get_db
from app.core.auth import get_current_tenant_id
from app.schemas.training import TrainedModelResponse
from app.services.training_service import TrainingService
from shared.auth.decorators import (
get_current_tenant_id_dep
)
logger = structlog.get_logger()
router = APIRouter()
@@ -20,7 +22,7 @@ training_service = TrainingService()
@router.get("/", response_model=List[TrainedModelResponse])
async def get_trained_models(
tenant_id: str = Depends(get_current_tenant_id),
tenant_id: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)
):
"""Get trained models"""

View File

@@ -1,303 +0,0 @@
# services/training/app/core/auth.py
"""
Authentication and authorization for training service
"""
import structlog
from typing import Optional
from fastapi import HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import httpx
from app.core.config import settings
logger = structlog.get_logger()
# HTTP Bearer token scheme
security = HTTPBearer(auto_error=False)
class AuthenticationError(Exception):
"""Custom exception for authentication errors"""
pass
class AuthorizationError(Exception):
"""Custom exception for authorization errors"""
pass
async def verify_token(token: str) -> dict:
"""
Verify JWT token with auth service
Args:
token: JWT token to verify
Returns:
dict: Token payload with user and tenant information
Raises:
AuthenticationError: If token is invalid
"""
try:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{settings.AUTH_SERVICE_URL}/auth/verify",
headers={"Authorization": f"Bearer {token}"},
timeout=10.0
)
if response.status_code == 200:
token_data = response.json()
logger.debug("Token verified successfully", user_id=token_data.get("user_id"))
return token_data
elif response.status_code == 401:
logger.warning("Invalid token provided")
raise AuthenticationError("Invalid or expired token")
else:
logger.error("Auth service error", status_code=response.status_code)
raise AuthenticationError("Authentication service unavailable")
except httpx.TimeoutException:
logger.error("Auth service timeout")
raise AuthenticationError("Authentication service timeout")
except httpx.RequestError as e:
logger.error("Auth service request error", error=str(e))
raise AuthenticationError("Authentication service unavailable")
except AuthenticationError:
raise
except Exception as e:
logger.error("Unexpected auth error", error=str(e))
raise AuthenticationError("Authentication failed")
async def get_current_user(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> dict:
"""
Get current authenticated user
Args:
credentials: HTTP Bearer credentials
Returns:
dict: User information
Raises:
HTTPException: If authentication fails
"""
if not credentials:
logger.warning("No credentials provided")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication credentials required",
headers={"WWW-Authenticate": "Bearer"},
)
try:
token_data = await verify_token(credentials.credentials)
return token_data
except AuthenticationError as e:
logger.warning("Authentication failed", error=str(e))
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
headers={"WWW-Authenticate": "Bearer"},
)
async def get_current_tenant_id(
current_user: dict = Depends(get_current_user)
) -> str:
"""
Get current tenant ID from authenticated user
Args:
current_user: Current authenticated user data
Returns:
str: Tenant ID
Raises:
HTTPException: If tenant ID is missing
"""
tenant_id = current_user.get("tenant_id")
if not tenant_id:
logger.error("Missing tenant_id in token", user_data=current_user)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid token: missing tenant information"
)
return tenant_id
async def require_admin_role(
current_user: dict = Depends(get_current_user)
) -> dict:
"""
Require admin role for endpoint access
Args:
current_user: Current authenticated user data
Returns:
dict: User information
Raises:
HTTPException: If user is not admin
"""
user_role = current_user.get("role", "").lower()
if user_role != "admin":
logger.warning("Access denied - admin role required",
user_id=current_user.get("user_id"),
role=user_role)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin role required"
)
return current_user
async def require_training_permission(
current_user: dict = Depends(get_current_user)
) -> dict:
"""
Require training permission for endpoint access
Args:
current_user: Current authenticated user data
Returns:
dict: User information
Raises:
HTTPException: If user doesn't have training permission
"""
permissions = current_user.get("permissions", [])
if "training" not in permissions and current_user.get("role", "").lower() != "admin":
logger.warning("Access denied - training permission required",
user_id=current_user.get("user_id"),
permissions=permissions)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Training permission required"
)
return current_user
# Optional authentication for development/testing
async def get_current_user_optional(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[dict]:
"""
Get current user but don't require authentication (for development)
Args:
credentials: HTTP Bearer credentials
Returns:
dict or None: User information if authenticated, None otherwise
"""
if not credentials:
return None
try:
token_data = await verify_token(credentials.credentials)
return token_data
except AuthenticationError:
return None
async def get_tenant_id_optional(
current_user: Optional[dict] = Depends(get_current_user_optional)
) -> Optional[str]:
"""
Get tenant ID but don't require authentication (for development)
Args:
current_user: Current user data (optional)
Returns:
str or None: Tenant ID if available, None otherwise
"""
if not current_user:
return None
return current_user.get("tenant_id")
# Development/testing auth bypass
async def get_test_tenant_id() -> str:
"""
Get test tenant ID for development/testing
Only works when DEBUG is enabled
Returns:
str: Test tenant ID
"""
if settings.DEBUG:
return "test-tenant-development"
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Test authentication only available in debug mode"
)
# Token validation utility
def validate_token_structure(token_data: dict) -> bool:
"""
Validate that token data has required structure
Args:
token_data: Token payload data
Returns:
bool: True if valid structure, False otherwise
"""
required_fields = ["user_id", "tenant_id"]
for field in required_fields:
if field not in token_data:
logger.warning("Invalid token structure - missing field", field=field)
return False
return True
# Role checking utilities
def has_role(user_data: dict, required_role: str) -> bool:
"""
Check if user has required role
Args:
user_data: User data from token
required_role: Required role name
Returns:
bool: True if user has role, False otherwise
"""
user_role = user_data.get("role", "").lower()
return user_role == required_role.lower()
def has_permission(user_data: dict, required_permission: str) -> bool:
"""
Check if user has required permission
Args:
user_data: User data from token
required_permission: Required permission name
Returns:
bool: True if user has permission, False otherwise
"""
permissions = user_data.get("permissions", [])
return required_permission in permissions or has_role(user_data, "admin")
# Export commonly used items
__all__ = [
'get_current_user',
'get_current_tenant_id',
'require_admin_role',
'require_training_permission',
'get_current_user_optional',
'get_tenant_id_optional',
'get_test_tenant_id',
'has_role',
'has_permission',
'AuthenticationError',
'AuthorizationError'
]