Fix forecasting service
This commit is contained in:
@@ -11,7 +11,6 @@ from app.core.database import get_db
|
||||
from app.schemas.auth import UserResponse, PasswordChange
|
||||
from app.schemas.users import UserUpdate
|
||||
from app.services.user_service import UserService
|
||||
from app.core.auth import get_current_user
|
||||
from app.models.users import User
|
||||
|
||||
# Import unified authentication from shared library
|
||||
@@ -53,7 +52,7 @@ async def get_current_user_info(
|
||||
@router.put("/me", response_model=UserResponse)
|
||||
async def update_current_user(
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Update current user information"""
|
||||
@@ -83,7 +82,7 @@ async def update_current_user(
|
||||
@router.post("/change-password")
|
||||
async def change_password(
|
||||
password_data: PasswordChange,
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Change user password"""
|
||||
@@ -106,7 +105,7 @@ async def change_password(
|
||||
|
||||
@router.delete("/me")
|
||||
async def delete_current_user(
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Delete current user account"""
|
||||
|
||||
@@ -10,7 +10,6 @@ from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import get_current_user, AuthInfo
|
||||
from app.services.traffic_service import TrafficService
|
||||
from app.services.messaging import data_publisher
|
||||
from app.schemas.external import (
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
from fastapi import HTTPException, Depends, status, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
import httpx
|
||||
import structlog
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
security = HTTPBearer(auto_error=False) # ✅ Don't auto-error, we'll handle manually
|
||||
|
||||
class AuthInfo:
|
||||
"""Authentication information"""
|
||||
def __init__(self, user_id: str, email: str, tenant_id: str, roles: list):
|
||||
self.user_id = user_id
|
||||
self.email = email
|
||||
self.tenant_id = tenant_id
|
||||
self.roles = roles
|
||||
|
||||
async def get_current_user(
|
||||
request: Request,
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
||||
) -> AuthInfo:
|
||||
"""Get current user from gateway headers or token verification"""
|
||||
|
||||
# ✅ OPTION 1: Check for gateway headers (preferred when using gateway)
|
||||
user_id = request.headers.get("X-User-ID")
|
||||
email = request.headers.get("X-User-Email")
|
||||
tenant_id = request.headers.get("X-Tenant-ID")
|
||||
roles_header = request.headers.get("X-User-Roles", "")
|
||||
|
||||
if user_id and email and tenant_id:
|
||||
# Gateway already authenticated the user
|
||||
roles = roles_header.split(",") if roles_header else ["user"]
|
||||
logger.info("Authenticated via gateway headers", user_id=user_id, email=email)
|
||||
return AuthInfo(user_id, email, tenant_id, roles)
|
||||
|
||||
# ✅ OPTION 2: Direct token verification (when not using gateway)
|
||||
if not credentials:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required (no token or gateway headers)"
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
|
||||
headers={"Authorization": f"Bearer {credentials.credentials}"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
user_data = response.json()
|
||||
logger.info("Authenticated via direct token", user_id=user_data.get("user_id"))
|
||||
return AuthInfo(
|
||||
user_id=user_data["user_id"],
|
||||
email=user_data["email"],
|
||||
tenant_id=user_data["tenant_id"],
|
||||
roles=user_data.get("roles", ["user"])
|
||||
)
|
||||
else:
|
||||
logger.warning("Token verification failed", status_code=response.status_code)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials"
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Auth service unavailable", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication service unavailable"
|
||||
)
|
||||
@@ -12,7 +12,10 @@ from typing import List, Optional
|
||||
from datetime import date
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import get_current_user_from_headers
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
get_current_tenant_id_dep
|
||||
)
|
||||
from app.services.forecasting_service import ForecastingService
|
||||
from app.schemas.forecasts import (
|
||||
ForecastRequest, ForecastResponse, BatchForecastRequest,
|
||||
@@ -30,13 +33,14 @@ forecasting_service = ForecastingService()
|
||||
async def create_single_forecast(
|
||||
request: ForecastRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
):
|
||||
"""Generate a single product forecast"""
|
||||
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(request.tenant_id) != str(current_user.get("tenant_id")):
|
||||
if str(request.tenant_id) != tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to this tenant"
|
||||
@@ -88,13 +92,14 @@ async def create_single_forecast(
|
||||
async def create_batch_forecast(
|
||||
request: BatchForecastRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
):
|
||||
"""Generate batch forecasts for multiple products"""
|
||||
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(request.tenant_id) != str(current_user.get("tenant_id")):
|
||||
if str(request.tenant_id) != tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to this tenant"
|
||||
@@ -172,12 +177,11 @@ async def list_forecasts(
|
||||
end_date: Optional[date] = Query(None),
|
||||
product_name: Optional[str] = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep)
|
||||
):
|
||||
"""List forecasts with filtering"""
|
||||
|
||||
try:
|
||||
tenant_id = str(current_user.get("tenant_id"))
|
||||
|
||||
# Get forecasts
|
||||
forecasts = await forecasting_service.get_forecasts(
|
||||
@@ -230,15 +234,14 @@ async def list_forecasts(
|
||||
async def get_forecast_alerts(
|
||||
active_only: bool = Query(True),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
):
|
||||
"""Get forecast alerts for tenant"""
|
||||
|
||||
try:
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
tenant_id = current_user.get("tenant_id")
|
||||
|
||||
# Build query
|
||||
query = select(ForecastAlert).where(
|
||||
ForecastAlert.tenant_id == tenant_id
|
||||
@@ -281,7 +284,8 @@ async def get_forecast_alerts(
|
||||
async def acknowledge_alert(
|
||||
alert_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
):
|
||||
"""Acknowledge a forecast alert"""
|
||||
|
||||
@@ -289,8 +293,6 @@ async def acknowledge_alert(
|
||||
from sqlalchemy import select, update
|
||||
from datetime import datetime
|
||||
|
||||
tenant_id = current_user.get("tenant_id")
|
||||
|
||||
# Get alert
|
||||
result = await db.execute(
|
||||
select(ForecastAlert).where(
|
||||
|
||||
@@ -12,7 +12,10 @@ from typing import List, Dict, Any
|
||||
from datetime import date, datetime, timedelta
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import get_current_user_from_headers
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
get_current_tenant_id_dep
|
||||
)
|
||||
from app.services.prediction_service import PredictionService
|
||||
from app.schemas.forecasts import ForecastRequest
|
||||
|
||||
@@ -28,12 +31,11 @@ async def get_realtime_prediction(
|
||||
location: str,
|
||||
forecast_date: date,
|
||||
features: Dict[str, Any],
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep)
|
||||
):
|
||||
"""Get real-time prediction without storing in database"""
|
||||
|
||||
try:
|
||||
tenant_id = str(current_user.get("tenant_id"))
|
||||
|
||||
# Get latest model
|
||||
from app.services.forecasting_service import ForecastingService
|
||||
@@ -83,13 +85,12 @@ async def get_quick_prediction(
|
||||
product_name: str,
|
||||
location: str = Query(...),
|
||||
days_ahead: int = Query(1, ge=1, le=7),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep)
|
||||
):
|
||||
"""Get quick prediction for next few days"""
|
||||
|
||||
try:
|
||||
tenant_id = str(current_user.get("tenant_id"))
|
||||
|
||||
|
||||
# Generate predictions for the next N days
|
||||
predictions = []
|
||||
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/core/auth.py
|
||||
# ================================================================
|
||||
"""
|
||||
Authentication utilities for forecasting service
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import HTTPException, status, Request
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
async def get_current_user_from_headers(request: Request) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current user from gateway headers
|
||||
Gateway middleware adds user context to headers after JWT verification
|
||||
"""
|
||||
|
||||
try:
|
||||
# Extract user information from headers set by API Gateway
|
||||
user_id = request.headers.get("X-User-ID")
|
||||
user_email = request.headers.get("X-User-Email")
|
||||
tenant_id = request.headers.get("X-Tenant-ID")
|
||||
user_roles = request.headers.get("X-User-Roles", "").split(",")
|
||||
|
||||
if not user_id or not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"email": user_email,
|
||||
"tenant_id": tenant_id,
|
||||
"roles": [role.strip() for role in user_roles if role.strip()]
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error extracting user from headers", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication"
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/messaging.py
|
||||
# services/forecasting/app/services/messaging.py
|
||||
# ================================================================
|
||||
"""
|
||||
Messaging service for event publishing and consuming
|
||||
@@ -10,89 +10,126 @@ import json
|
||||
from typing import Dict, Any
|
||||
import asyncio
|
||||
|
||||
from shared.messaging.rabbitmq import RabbitMQPublisher, RabbitMQConsumer
|
||||
from shared.messaging.rabbitmq import RabbitMQClient
|
||||
from shared.messaging.events import (
|
||||
TrainingCompletedEvent,
|
||||
DataImportedEvent,
|
||||
ForecastGeneratedEvent,
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Global messaging instances
|
||||
publisher = None
|
||||
consumer = None
|
||||
# Global messaging instance
|
||||
rabbitmq_client = None
|
||||
|
||||
async def setup_messaging():
|
||||
"""Initialize messaging services"""
|
||||
global publisher, consumer
|
||||
|
||||
global rabbitmq_client
|
||||
|
||||
try:
|
||||
# Initialize publisher
|
||||
publisher = RabbitMQPublisher(settings.RABBITMQ_URL)
|
||||
await publisher.connect()
|
||||
|
||||
# Initialize consumer
|
||||
consumer = RabbitMQConsumer(settings.RABBITMQ_URL)
|
||||
await consumer.connect()
|
||||
|
||||
rabbitmq_client = RabbitMQClient(settings.RABBITMQ_URL, service_name="forecasting_service")
|
||||
await rabbitmq_client.connect()
|
||||
|
||||
# Set up event handlers
|
||||
await consumer.subscribe("training.model.updated", handle_model_updated)
|
||||
await consumer.subscribe("data.weather.updated", handle_weather_updated)
|
||||
|
||||
# We need to adapt the callback to accept aio_pika.IncomingMessage
|
||||
await rabbitmq_client.consume_events(
|
||||
exchange_name="training.events",
|
||||
queue_name="forecasting_model_update_queue",
|
||||
routing_key="training.completed", # Assuming model updates are part of training.completed events
|
||||
callback=handle_model_updated_message
|
||||
)
|
||||
await rabbitmq_client.consume_events(
|
||||
exchange_name="data.events",
|
||||
queue_name="forecasting_weather_update_queue",
|
||||
routing_key="data.weather.updated", # This needs to match the actual event type if different
|
||||
callback=handle_weather_updated_message
|
||||
)
|
||||
|
||||
logger.info("Messaging setup completed")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to setup messaging", error=str(e))
|
||||
raise
|
||||
|
||||
async def cleanup_messaging():
|
||||
"""Cleanup messaging connections"""
|
||||
global publisher, consumer
|
||||
|
||||
global rabbitmq_client
|
||||
|
||||
try:
|
||||
if consumer:
|
||||
await consumer.close()
|
||||
if publisher:
|
||||
await publisher.close()
|
||||
|
||||
if rabbitmq_client:
|
||||
await rabbitmq_client.disconnect()
|
||||
|
||||
logger.info("Messaging cleanup completed")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error during messaging cleanup", error=str(e))
|
||||
|
||||
async def publish_forecast_completed(data: Dict[str, Any]):
|
||||
"""Publish forecast completed event"""
|
||||
if publisher:
|
||||
await publisher.publish("forecasting.forecast.completed", data)
|
||||
if rabbitmq_client:
|
||||
event = ForecastGeneratedEvent(service_name="forecasting_service", data=data, event_type="forecast.completed")
|
||||
await rabbitmq_client.publish_forecast_event(event_type="completed", forecast_data=event.to_dict())
|
||||
|
||||
async def publish_alert_created(data: Dict[str, Any]):
|
||||
"""Publish alert created event"""
|
||||
if publisher:
|
||||
await publisher.publish("forecasting.alert.created", data)
|
||||
# Assuming 'alert.created' is a type of forecast event, or define a new exchange/publisher method
|
||||
if rabbitmq_client:
|
||||
event = ForecastGeneratedEvent(service_name="forecasting_service", data=data, event_type="alert.created")
|
||||
await rabbitmq_client.publish_forecast_event(event_type="alert.created", forecast_data=event.to_dict())
|
||||
|
||||
async def publish_batch_completed(data: Dict[str, Any]):
|
||||
"""Publish batch forecast completed event"""
|
||||
if publisher:
|
||||
await publisher.publish("forecasting.batch.completed", data)
|
||||
if rabbitmq_client:
|
||||
event = ForecastGeneratedEvent(service_name="forecasting_service", data=data, event_type="forecast.batch.completed")
|
||||
await rabbitmq_client.publish_forecast_event(event_type="batch.completed", forecast_data=event.to_dict())
|
||||
|
||||
# Event handlers
|
||||
|
||||
# Event handler wrappers for aio_pika messages
|
||||
async def handle_model_updated_message(message: Any):
|
||||
async with message.process():
|
||||
try:
|
||||
event_data = json.loads(message.body.decode())
|
||||
# Assuming the actual event data is nested under a 'data' key within the event dictionary
|
||||
await handle_model_updated(event_data.get("data", {}))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error("Failed to decode model updated message JSON", error=str(e), body=message.body)
|
||||
except Exception as e:
|
||||
logger.error("Error processing model updated message", error=str(e), body=message.body)
|
||||
|
||||
async def handle_weather_updated_message(message: Any):
|
||||
async with message.process():
|
||||
try:
|
||||
event_data = json.loads(message.body.decode())
|
||||
# Assuming the actual event data is nested under a 'data' key within the event dictionary
|
||||
await handle_weather_updated(event_data.get("data", {}))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error("Failed to decode weather updated message JSON", error=str(e), body=message.body)
|
||||
except Exception as e:
|
||||
logger.error("Error processing weather updated message", error=str(e), body=message.body)
|
||||
|
||||
|
||||
# Original Event handlers (now called from the message wrappers)
|
||||
async def handle_model_updated(data: Dict[str, Any]):
|
||||
"""Handle model updated event from training service"""
|
||||
try:
|
||||
logger.info("Received model updated event",
|
||||
logger.info("Received model updated event",
|
||||
model_id=data.get("model_id"),
|
||||
tenant_id=data.get("tenant_id"))
|
||||
|
||||
|
||||
# Clear model cache for this model
|
||||
# This will be handled by PredictionService
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error handling model updated event", error=str(e))
|
||||
|
||||
async def handle_weather_updated(data: Dict[str, Any]):
|
||||
"""Handle weather data updated event"""
|
||||
try:
|
||||
logger.info("Received weather updated event",
|
||||
logger.info("Received weather updated event",
|
||||
date=data.get("date"))
|
||||
|
||||
|
||||
# Could trigger re-forecasting if needed
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error handling weather updated event", error=str(e))
|
||||
logger.error("Error handling weather updated event", error=str(e))
|
||||
@@ -1,22 +0,0 @@
|
||||
"""
|
||||
Authentication configuration for notification service
|
||||
"""
|
||||
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from shared.auth.decorators import require_auth, require_role
|
||||
from app.core.config import settings
|
||||
|
||||
# Initialize JWT handler
|
||||
jwt_handler = JWTHandler(
|
||||
secret_key=settings.JWT_SECRET_KEY,
|
||||
algorithm=settings.JWT_ALGORITHM,
|
||||
access_token_expire_minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
# Export commonly used functions
|
||||
verify_token = jwt_handler.verify_token
|
||||
create_access_token = jwt_handler.create_access_token
|
||||
get_current_user = jwt_handler.get_current_user
|
||||
|
||||
# Export decorators
|
||||
__all__ = ['verify_token', 'create_access_token', 'get_current_user', 'require_auth', 'require_role']
|
||||
@@ -1,22 +0,0 @@
|
||||
"""
|
||||
Authentication configuration for tenant service
|
||||
"""
|
||||
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from shared.auth.decorators import require_auth, require_role
|
||||
from app.core.config import settings
|
||||
|
||||
# Initialize JWT handler
|
||||
jwt_handler = JWTHandler(
|
||||
secret_key=settings.JWT_SECRET_KEY,
|
||||
algorithm=settings.JWT_ALGORITHM,
|
||||
access_token_expire_minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
# Export commonly used functions
|
||||
verify_token = jwt_handler.verify_token
|
||||
create_access_token = jwt_handler.create_access_token
|
||||
get_current_user = jwt_handler.get_current_user
|
||||
|
||||
# Export decorators
|
||||
__all__ = ['verify_token', 'create_access_token', 'get_current_user', 'require_auth', 'require_role']
|
||||
@@ -8,10 +8,12 @@ from typing import List
|
||||
import structlog
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import get_current_tenant_id
|
||||
from app.schemas.training import TrainedModelResponse
|
||||
from app.services.training_service import TrainingService
|
||||
|
||||
from shared.auth.decorators import (
|
||||
get_current_tenant_id_dep
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
@@ -20,7 +22,7 @@ training_service = TrainingService()
|
||||
|
||||
@router.get("/", response_model=List[TrainedModelResponse])
|
||||
async def get_trained_models(
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get trained models"""
|
||||
|
||||
@@ -1,303 +0,0 @@
|
||||
# services/training/app/core/auth.py
|
||||
"""
|
||||
Authentication and authorization for training service
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Depends, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# HTTP Bearer token scheme
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Custom exception for authentication errors"""
|
||||
pass
|
||||
|
||||
class AuthorizationError(Exception):
|
||||
"""Custom exception for authorization errors"""
|
||||
pass
|
||||
|
||||
async def verify_token(token: str) -> dict:
|
||||
"""
|
||||
Verify JWT token with auth service
|
||||
|
||||
Args:
|
||||
token: JWT token to verify
|
||||
|
||||
Returns:
|
||||
dict: Token payload with user and tenant information
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If token is invalid
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.AUTH_SERVICE_URL}/auth/verify",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
token_data = response.json()
|
||||
logger.debug("Token verified successfully", user_id=token_data.get("user_id"))
|
||||
return token_data
|
||||
elif response.status_code == 401:
|
||||
logger.warning("Invalid token provided")
|
||||
raise AuthenticationError("Invalid or expired token")
|
||||
else:
|
||||
logger.error("Auth service error", status_code=response.status_code)
|
||||
raise AuthenticationError("Authentication service unavailable")
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("Auth service timeout")
|
||||
raise AuthenticationError("Authentication service timeout")
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Auth service request error", error=str(e))
|
||||
raise AuthenticationError("Authentication service unavailable")
|
||||
except AuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Unexpected auth error", error=str(e))
|
||||
raise AuthenticationError("Authentication failed")
|
||||
|
||||
async def get_current_user(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
||||
) -> dict:
|
||||
"""
|
||||
Get current authenticated user
|
||||
|
||||
Args:
|
||||
credentials: HTTP Bearer credentials
|
||||
|
||||
Returns:
|
||||
dict: User information
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails
|
||||
"""
|
||||
if not credentials:
|
||||
logger.warning("No credentials provided")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication credentials required",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
token_data = await verify_token(credentials.credentials)
|
||||
return token_data
|
||||
|
||||
except AuthenticationError as e:
|
||||
logger.warning("Authentication failed", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def get_current_tenant_id(
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> str:
|
||||
"""
|
||||
Get current tenant ID from authenticated user
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user data
|
||||
|
||||
Returns:
|
||||
str: Tenant ID
|
||||
|
||||
Raises:
|
||||
HTTPException: If tenant ID is missing
|
||||
"""
|
||||
tenant_id = current_user.get("tenant_id")
|
||||
if not tenant_id:
|
||||
logger.error("Missing tenant_id in token", user_data=current_user)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid token: missing tenant information"
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
|
||||
async def require_admin_role(
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> dict:
|
||||
"""
|
||||
Require admin role for endpoint access
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user data
|
||||
|
||||
Returns:
|
||||
dict: User information
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is not admin
|
||||
"""
|
||||
user_role = current_user.get("role", "").lower()
|
||||
if user_role != "admin":
|
||||
logger.warning("Access denied - admin role required",
|
||||
user_id=current_user.get("user_id"),
|
||||
role=user_role)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin role required"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
async def require_training_permission(
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> dict:
|
||||
"""
|
||||
Require training permission for endpoint access
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user data
|
||||
|
||||
Returns:
|
||||
dict: User information
|
||||
|
||||
Raises:
|
||||
HTTPException: If user doesn't have training permission
|
||||
"""
|
||||
permissions = current_user.get("permissions", [])
|
||||
if "training" not in permissions and current_user.get("role", "").lower() != "admin":
|
||||
logger.warning("Access denied - training permission required",
|
||||
user_id=current_user.get("user_id"),
|
||||
permissions=permissions)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Training permission required"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
# Optional authentication for development/testing
|
||||
async def get_current_user_optional(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get current user but don't require authentication (for development)
|
||||
|
||||
Args:
|
||||
credentials: HTTP Bearer credentials
|
||||
|
||||
Returns:
|
||||
dict or None: User information if authenticated, None otherwise
|
||||
"""
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
try:
|
||||
token_data = await verify_token(credentials.credentials)
|
||||
return token_data
|
||||
except AuthenticationError:
|
||||
return None
|
||||
|
||||
async def get_tenant_id_optional(
|
||||
current_user: Optional[dict] = Depends(get_current_user_optional)
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get tenant ID but don't require authentication (for development)
|
||||
|
||||
Args:
|
||||
current_user: Current user data (optional)
|
||||
|
||||
Returns:
|
||||
str or None: Tenant ID if available, None otherwise
|
||||
"""
|
||||
if not current_user:
|
||||
return None
|
||||
|
||||
return current_user.get("tenant_id")
|
||||
|
||||
# Development/testing auth bypass
|
||||
async def get_test_tenant_id() -> str:
|
||||
"""
|
||||
Get test tenant ID for development/testing
|
||||
Only works when DEBUG is enabled
|
||||
|
||||
Returns:
|
||||
str: Test tenant ID
|
||||
"""
|
||||
if settings.DEBUG:
|
||||
return "test-tenant-development"
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Test authentication only available in debug mode"
|
||||
)
|
||||
|
||||
# Token validation utility
|
||||
def validate_token_structure(token_data: dict) -> bool:
|
||||
"""
|
||||
Validate that token data has required structure
|
||||
|
||||
Args:
|
||||
token_data: Token payload data
|
||||
|
||||
Returns:
|
||||
bool: True if valid structure, False otherwise
|
||||
"""
|
||||
required_fields = ["user_id", "tenant_id"]
|
||||
|
||||
for field in required_fields:
|
||||
if field not in token_data:
|
||||
logger.warning("Invalid token structure - missing field", field=field)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# Role checking utilities
|
||||
def has_role(user_data: dict, required_role: str) -> bool:
|
||||
"""
|
||||
Check if user has required role
|
||||
|
||||
Args:
|
||||
user_data: User data from token
|
||||
required_role: Required role name
|
||||
|
||||
Returns:
|
||||
bool: True if user has role, False otherwise
|
||||
"""
|
||||
user_role = user_data.get("role", "").lower()
|
||||
return user_role == required_role.lower()
|
||||
|
||||
def has_permission(user_data: dict, required_permission: str) -> bool:
|
||||
"""
|
||||
Check if user has required permission
|
||||
|
||||
Args:
|
||||
user_data: User data from token
|
||||
required_permission: Required permission name
|
||||
|
||||
Returns:
|
||||
bool: True if user has permission, False otherwise
|
||||
"""
|
||||
permissions = user_data.get("permissions", [])
|
||||
return required_permission in permissions or has_role(user_data, "admin")
|
||||
|
||||
# Export commonly used items
|
||||
__all__ = [
|
||||
'get_current_user',
|
||||
'get_current_tenant_id',
|
||||
'require_admin_role',
|
||||
'require_training_permission',
|
||||
'get_current_user_optional',
|
||||
'get_tenant_id_optional',
|
||||
'get_test_tenant_id',
|
||||
'has_role',
|
||||
'has_permission',
|
||||
'AuthenticationError',
|
||||
'AuthorizationError'
|
||||
]
|
||||
Reference in New Issue
Block a user