Fix forecasting service
This commit is contained in:
@@ -11,7 +11,6 @@ from app.core.database import get_db
|
|||||||
from app.schemas.auth import UserResponse, PasswordChange
|
from app.schemas.auth import UserResponse, PasswordChange
|
||||||
from app.schemas.users import UserUpdate
|
from app.schemas.users import UserUpdate
|
||||||
from app.services.user_service import UserService
|
from app.services.user_service import UserService
|
||||||
from app.core.auth import get_current_user
|
|
||||||
from app.models.users import User
|
from app.models.users import User
|
||||||
|
|
||||||
# Import unified authentication from shared library
|
# Import unified authentication from shared library
|
||||||
@@ -53,7 +52,7 @@ async def get_current_user_info(
|
|||||||
@router.put("/me", response_model=UserResponse)
|
@router.put("/me", response_model=UserResponse)
|
||||||
async def update_current_user(
|
async def update_current_user(
|
||||||
user_update: UserUpdate,
|
user_update: UserUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""Update current user information"""
|
"""Update current user information"""
|
||||||
@@ -83,7 +82,7 @@ async def update_current_user(
|
|||||||
@router.post("/change-password")
|
@router.post("/change-password")
|
||||||
async def change_password(
|
async def change_password(
|
||||||
password_data: PasswordChange,
|
password_data: PasswordChange,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""Change user password"""
|
"""Change user password"""
|
||||||
@@ -106,7 +105,7 @@ async def change_password(
|
|||||||
|
|
||||||
@router.delete("/me")
|
@router.delete("/me")
|
||||||
async def delete_current_user(
|
async def delete_current_user(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""Delete current user account"""
|
"""Delete current user account"""
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from datetime import datetime, timedelta
|
|||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.auth import get_current_user, AuthInfo
|
|
||||||
from app.services.traffic_service import TrafficService
|
from app.services.traffic_service import TrafficService
|
||||||
from app.services.messaging import data_publisher
|
from app.services.messaging import data_publisher
|
||||||
from app.schemas.external import (
|
from app.schemas.external import (
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
from fastapi import HTTPException, Depends, status, Request
|
|
||||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
||||||
import httpx
|
|
||||||
import structlog
|
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
|
||||||
security = HTTPBearer(auto_error=False) # ✅ Don't auto-error, we'll handle manually
|
|
||||||
|
|
||||||
class AuthInfo:
|
|
||||||
"""Authentication information"""
|
|
||||||
def __init__(self, user_id: str, email: str, tenant_id: str, roles: list):
|
|
||||||
self.user_id = user_id
|
|
||||||
self.email = email
|
|
||||||
self.tenant_id = tenant_id
|
|
||||||
self.roles = roles
|
|
||||||
|
|
||||||
async def get_current_user(
|
|
||||||
request: Request,
|
|
||||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
|
||||||
) -> AuthInfo:
|
|
||||||
"""Get current user from gateway headers or token verification"""
|
|
||||||
|
|
||||||
# ✅ OPTION 1: Check for gateway headers (preferred when using gateway)
|
|
||||||
user_id = request.headers.get("X-User-ID")
|
|
||||||
email = request.headers.get("X-User-Email")
|
|
||||||
tenant_id = request.headers.get("X-Tenant-ID")
|
|
||||||
roles_header = request.headers.get("X-User-Roles", "")
|
|
||||||
|
|
||||||
if user_id and email and tenant_id:
|
|
||||||
# Gateway already authenticated the user
|
|
||||||
roles = roles_header.split(",") if roles_header else ["user"]
|
|
||||||
logger.info("Authenticated via gateway headers", user_id=user_id, email=email)
|
|
||||||
return AuthInfo(user_id, email, tenant_id, roles)
|
|
||||||
|
|
||||||
# ✅ OPTION 2: Direct token verification (when not using gateway)
|
|
||||||
if not credentials:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Authentication required (no token or gateway headers)"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
|
||||||
response = await client.post(
|
|
||||||
f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
|
|
||||||
headers={"Authorization": f"Bearer {credentials.credentials}"}
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
user_data = response.json()
|
|
||||||
logger.info("Authenticated via direct token", user_id=user_data.get("user_id"))
|
|
||||||
return AuthInfo(
|
|
||||||
user_id=user_data["user_id"],
|
|
||||||
email=user_data["email"],
|
|
||||||
tenant_id=user_data["tenant_id"],
|
|
||||||
roles=user_data.get("roles", ["user"])
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning("Token verification failed", status_code=response.status_code)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid authentication credentials"
|
|
||||||
)
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.error("Auth service unavailable", error=str(e))
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
||||||
detail="Authentication service unavailable"
|
|
||||||
)
|
|
||||||
@@ -12,7 +12,10 @@ from typing import List, Optional
|
|||||||
from datetime import date
|
from datetime import date
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.auth import get_current_user_from_headers
|
from shared.auth.decorators import (
|
||||||
|
get_current_user_dep,
|
||||||
|
get_current_tenant_id_dep
|
||||||
|
)
|
||||||
from app.services.forecasting_service import ForecastingService
|
from app.services.forecasting_service import ForecastingService
|
||||||
from app.schemas.forecasts import (
|
from app.schemas.forecasts import (
|
||||||
ForecastRequest, ForecastResponse, BatchForecastRequest,
|
ForecastRequest, ForecastResponse, BatchForecastRequest,
|
||||||
@@ -30,13 +33,14 @@ forecasting_service = ForecastingService()
|
|||||||
async def create_single_forecast(
|
async def create_single_forecast(
|
||||||
request: ForecastRequest,
|
request: ForecastRequest,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
current_user: dict = Depends(get_current_user_from_headers)
|
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||||
|
current_user: dict = Depends(get_current_user_dep)
|
||||||
):
|
):
|
||||||
"""Generate a single product forecast"""
|
"""Generate a single product forecast"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Verify tenant access
|
# Verify tenant access
|
||||||
if str(request.tenant_id) != str(current_user.get("tenant_id")):
|
if str(request.tenant_id) != tenant_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Access denied to this tenant"
|
detail="Access denied to this tenant"
|
||||||
@@ -88,13 +92,14 @@ async def create_single_forecast(
|
|||||||
async def create_batch_forecast(
|
async def create_batch_forecast(
|
||||||
request: BatchForecastRequest,
|
request: BatchForecastRequest,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
current_user: dict = Depends(get_current_user_from_headers)
|
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||||
|
current_user: dict = Depends(get_current_user_dep)
|
||||||
):
|
):
|
||||||
"""Generate batch forecasts for multiple products"""
|
"""Generate batch forecasts for multiple products"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Verify tenant access
|
# Verify tenant access
|
||||||
if str(request.tenant_id) != str(current_user.get("tenant_id")):
|
if str(request.tenant_id) != tenant_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Access denied to this tenant"
|
detail="Access denied to this tenant"
|
||||||
@@ -172,12 +177,11 @@ async def list_forecasts(
|
|||||||
end_date: Optional[date] = Query(None),
|
end_date: Optional[date] = Query(None),
|
||||||
product_name: Optional[str] = Query(None),
|
product_name: Optional[str] = Query(None),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
current_user: dict = Depends(get_current_user_from_headers)
|
tenant_id: str = Depends(get_current_tenant_id_dep)
|
||||||
):
|
):
|
||||||
"""List forecasts with filtering"""
|
"""List forecasts with filtering"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tenant_id = str(current_user.get("tenant_id"))
|
|
||||||
|
|
||||||
# Get forecasts
|
# Get forecasts
|
||||||
forecasts = await forecasting_service.get_forecasts(
|
forecasts = await forecasting_service.get_forecasts(
|
||||||
@@ -230,15 +234,14 @@ async def list_forecasts(
|
|||||||
async def get_forecast_alerts(
|
async def get_forecast_alerts(
|
||||||
active_only: bool = Query(True),
|
active_only: bool = Query(True),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
current_user: dict = Depends(get_current_user_from_headers)
|
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||||
|
current_user: dict = Depends(get_current_user_dep)
|
||||||
):
|
):
|
||||||
"""Get forecast alerts for tenant"""
|
"""Get forecast alerts for tenant"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sqlalchemy import select, and_
|
from sqlalchemy import select, and_
|
||||||
|
|
||||||
tenant_id = current_user.get("tenant_id")
|
|
||||||
|
|
||||||
# Build query
|
# Build query
|
||||||
query = select(ForecastAlert).where(
|
query = select(ForecastAlert).where(
|
||||||
ForecastAlert.tenant_id == tenant_id
|
ForecastAlert.tenant_id == tenant_id
|
||||||
@@ -281,7 +284,8 @@ async def get_forecast_alerts(
|
|||||||
async def acknowledge_alert(
|
async def acknowledge_alert(
|
||||||
alert_id: str,
|
alert_id: str,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
current_user: dict = Depends(get_current_user_from_headers)
|
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||||
|
current_user: dict = Depends(get_current_user_dep)
|
||||||
):
|
):
|
||||||
"""Acknowledge a forecast alert"""
|
"""Acknowledge a forecast alert"""
|
||||||
|
|
||||||
@@ -289,8 +293,6 @@ async def acknowledge_alert(
|
|||||||
from sqlalchemy import select, update
|
from sqlalchemy import select, update
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
tenant_id = current_user.get("tenant_id")
|
|
||||||
|
|
||||||
# Get alert
|
# Get alert
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(ForecastAlert).where(
|
select(ForecastAlert).where(
|
||||||
|
|||||||
@@ -12,7 +12,10 @@ from typing import List, Dict, Any
|
|||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.auth import get_current_user_from_headers
|
from shared.auth.decorators import (
|
||||||
|
get_current_user_dep,
|
||||||
|
get_current_tenant_id_dep
|
||||||
|
)
|
||||||
from app.services.prediction_service import PredictionService
|
from app.services.prediction_service import PredictionService
|
||||||
from app.schemas.forecasts import ForecastRequest
|
from app.schemas.forecasts import ForecastRequest
|
||||||
|
|
||||||
@@ -28,12 +31,11 @@ async def get_realtime_prediction(
|
|||||||
location: str,
|
location: str,
|
||||||
forecast_date: date,
|
forecast_date: date,
|
||||||
features: Dict[str, Any],
|
features: Dict[str, Any],
|
||||||
current_user: dict = Depends(get_current_user_from_headers)
|
tenant_id: str = Depends(get_current_tenant_id_dep)
|
||||||
):
|
):
|
||||||
"""Get real-time prediction without storing in database"""
|
"""Get real-time prediction without storing in database"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tenant_id = str(current_user.get("tenant_id"))
|
|
||||||
|
|
||||||
# Get latest model
|
# Get latest model
|
||||||
from app.services.forecasting_service import ForecastingService
|
from app.services.forecasting_service import ForecastingService
|
||||||
@@ -83,12 +85,11 @@ async def get_quick_prediction(
|
|||||||
product_name: str,
|
product_name: str,
|
||||||
location: str = Query(...),
|
location: str = Query(...),
|
||||||
days_ahead: int = Query(1, ge=1, le=7),
|
days_ahead: int = Query(1, ge=1, le=7),
|
||||||
current_user: dict = Depends(get_current_user_from_headers)
|
tenant_id: str = Depends(get_current_tenant_id_dep)
|
||||||
):
|
):
|
||||||
"""Get quick prediction for next few days"""
|
"""Get quick prediction for next few days"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tenant_id = str(current_user.get("tenant_id"))
|
|
||||||
|
|
||||||
# Generate predictions for the next N days
|
# Generate predictions for the next N days
|
||||||
predictions = []
|
predictions = []
|
||||||
|
|||||||
@@ -1,48 +0,0 @@
|
|||||||
# ================================================================
|
|
||||||
# services/forecasting/app/core/auth.py
|
|
||||||
# ================================================================
|
|
||||||
"""
|
|
||||||
Authentication utilities for forecasting service
|
|
||||||
"""
|
|
||||||
|
|
||||||
import structlog
|
|
||||||
from fastapi import HTTPException, status, Request
|
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
|
||||||
|
|
||||||
async def get_current_user_from_headers(request: Request) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get current user from gateway headers
|
|
||||||
Gateway middleware adds user context to headers after JWT verification
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Extract user information from headers set by API Gateway
|
|
||||||
user_id = request.headers.get("X-User-ID")
|
|
||||||
user_email = request.headers.get("X-User-Email")
|
|
||||||
tenant_id = request.headers.get("X-Tenant-ID")
|
|
||||||
user_roles = request.headers.get("X-User-Roles", "").split(",")
|
|
||||||
|
|
||||||
if not user_id or not tenant_id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Authentication required"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"user_id": user_id,
|
|
||||||
"email": user_email,
|
|
||||||
"tenant_id": tenant_id,
|
|
||||||
"roles": [role.strip() for role in user_roles if role.strip()]
|
|
||||||
}
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error extracting user from headers", error=str(e))
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid authentication"
|
|
||||||
)
|
|
||||||
|
|
||||||
@@ -10,31 +10,41 @@ import json
|
|||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from shared.messaging.rabbitmq import RabbitMQPublisher, RabbitMQConsumer
|
from shared.messaging.rabbitmq import RabbitMQClient
|
||||||
|
from shared.messaging.events import (
|
||||||
|
TrainingCompletedEvent,
|
||||||
|
DataImportedEvent,
|
||||||
|
ForecastGeneratedEvent,
|
||||||
|
)
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
# Global messaging instances
|
# Global messaging instance
|
||||||
publisher = None
|
rabbitmq_client = None
|
||||||
consumer = None
|
|
||||||
|
|
||||||
async def setup_messaging():
|
async def setup_messaging():
|
||||||
"""Initialize messaging services"""
|
"""Initialize messaging services"""
|
||||||
global publisher, consumer
|
global rabbitmq_client
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize publisher
|
rabbitmq_client = RabbitMQClient(settings.RABBITMQ_URL, service_name="forecasting_service")
|
||||||
publisher = RabbitMQPublisher(settings.RABBITMQ_URL)
|
await rabbitmq_client.connect()
|
||||||
await publisher.connect()
|
|
||||||
|
|
||||||
# Initialize consumer
|
|
||||||
consumer = RabbitMQConsumer(settings.RABBITMQ_URL)
|
|
||||||
await consumer.connect()
|
|
||||||
|
|
||||||
# Set up event handlers
|
# Set up event handlers
|
||||||
await consumer.subscribe("training.model.updated", handle_model_updated)
|
# We need to adapt the callback to accept aio_pika.IncomingMessage
|
||||||
await consumer.subscribe("data.weather.updated", handle_weather_updated)
|
await rabbitmq_client.consume_events(
|
||||||
|
exchange_name="training.events",
|
||||||
|
queue_name="forecasting_model_update_queue",
|
||||||
|
routing_key="training.completed", # Assuming model updates are part of training.completed events
|
||||||
|
callback=handle_model_updated_message
|
||||||
|
)
|
||||||
|
await rabbitmq_client.consume_events(
|
||||||
|
exchange_name="data.events",
|
||||||
|
queue_name="forecasting_weather_update_queue",
|
||||||
|
routing_key="data.weather.updated", # This needs to match the actual event type if different
|
||||||
|
callback=handle_weather_updated_message
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Messaging setup completed")
|
logger.info("Messaging setup completed")
|
||||||
|
|
||||||
@@ -44,13 +54,11 @@ async def setup_messaging():
|
|||||||
|
|
||||||
async def cleanup_messaging():
|
async def cleanup_messaging():
|
||||||
"""Cleanup messaging connections"""
|
"""Cleanup messaging connections"""
|
||||||
global publisher, consumer
|
global rabbitmq_client
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if consumer:
|
if rabbitmq_client:
|
||||||
await consumer.close()
|
await rabbitmq_client.disconnect()
|
||||||
if publisher:
|
|
||||||
await publisher.close()
|
|
||||||
|
|
||||||
logger.info("Messaging cleanup completed")
|
logger.info("Messaging cleanup completed")
|
||||||
|
|
||||||
@@ -59,20 +67,49 @@ async def cleanup_messaging():
|
|||||||
|
|
||||||
async def publish_forecast_completed(data: Dict[str, Any]):
|
async def publish_forecast_completed(data: Dict[str, Any]):
|
||||||
"""Publish forecast completed event"""
|
"""Publish forecast completed event"""
|
||||||
if publisher:
|
if rabbitmq_client:
|
||||||
await publisher.publish("forecasting.forecast.completed", data)
|
event = ForecastGeneratedEvent(service_name="forecasting_service", data=data, event_type="forecast.completed")
|
||||||
|
await rabbitmq_client.publish_forecast_event(event_type="completed", forecast_data=event.to_dict())
|
||||||
|
|
||||||
async def publish_alert_created(data: Dict[str, Any]):
|
async def publish_alert_created(data: Dict[str, Any]):
|
||||||
"""Publish alert created event"""
|
"""Publish alert created event"""
|
||||||
if publisher:
|
# Assuming 'alert.created' is a type of forecast event, or define a new exchange/publisher method
|
||||||
await publisher.publish("forecasting.alert.created", data)
|
if rabbitmq_client:
|
||||||
|
event = ForecastGeneratedEvent(service_name="forecasting_service", data=data, event_type="alert.created")
|
||||||
|
await rabbitmq_client.publish_forecast_event(event_type="alert.created", forecast_data=event.to_dict())
|
||||||
|
|
||||||
async def publish_batch_completed(data: Dict[str, Any]):
|
async def publish_batch_completed(data: Dict[str, Any]):
|
||||||
"""Publish batch forecast completed event"""
|
"""Publish batch forecast completed event"""
|
||||||
if publisher:
|
if rabbitmq_client:
|
||||||
await publisher.publish("forecasting.batch.completed", data)
|
event = ForecastGeneratedEvent(service_name="forecasting_service", data=data, event_type="forecast.batch.completed")
|
||||||
|
await rabbitmq_client.publish_forecast_event(event_type="batch.completed", forecast_data=event.to_dict())
|
||||||
|
|
||||||
# Event handlers
|
|
||||||
|
# Event handler wrappers for aio_pika messages
|
||||||
|
async def handle_model_updated_message(message: Any):
|
||||||
|
async with message.process():
|
||||||
|
try:
|
||||||
|
event_data = json.loads(message.body.decode())
|
||||||
|
# Assuming the actual event data is nested under a 'data' key within the event dictionary
|
||||||
|
await handle_model_updated(event_data.get("data", {}))
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error("Failed to decode model updated message JSON", error=str(e), body=message.body)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error processing model updated message", error=str(e), body=message.body)
|
||||||
|
|
||||||
|
async def handle_weather_updated_message(message: Any):
|
||||||
|
async with message.process():
|
||||||
|
try:
|
||||||
|
event_data = json.loads(message.body.decode())
|
||||||
|
# Assuming the actual event data is nested under a 'data' key within the event dictionary
|
||||||
|
await handle_weather_updated(event_data.get("data", {}))
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error("Failed to decode weather updated message JSON", error=str(e), body=message.body)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error processing weather updated message", error=str(e), body=message.body)
|
||||||
|
|
||||||
|
|
||||||
|
# Original Event handlers (now called from the message wrappers)
|
||||||
async def handle_model_updated(data: Dict[str, Any]):
|
async def handle_model_updated(data: Dict[str, Any]):
|
||||||
"""Handle model updated event from training service"""
|
"""Handle model updated event from training service"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
"""
|
|
||||||
Authentication configuration for notification service
|
|
||||||
"""
|
|
||||||
|
|
||||||
from shared.auth.jwt_handler import JWTHandler
|
|
||||||
from shared.auth.decorators import require_auth, require_role
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
# Initialize JWT handler
|
|
||||||
jwt_handler = JWTHandler(
|
|
||||||
secret_key=settings.JWT_SECRET_KEY,
|
|
||||||
algorithm=settings.JWT_ALGORITHM,
|
|
||||||
access_token_expire_minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
|
|
||||||
)
|
|
||||||
|
|
||||||
# Export commonly used functions
|
|
||||||
verify_token = jwt_handler.verify_token
|
|
||||||
create_access_token = jwt_handler.create_access_token
|
|
||||||
get_current_user = jwt_handler.get_current_user
|
|
||||||
|
|
||||||
# Export decorators
|
|
||||||
__all__ = ['verify_token', 'create_access_token', 'get_current_user', 'require_auth', 'require_role']
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
"""
|
|
||||||
Authentication configuration for tenant service
|
|
||||||
"""
|
|
||||||
|
|
||||||
from shared.auth.jwt_handler import JWTHandler
|
|
||||||
from shared.auth.decorators import require_auth, require_role
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
# Initialize JWT handler
|
|
||||||
jwt_handler = JWTHandler(
|
|
||||||
secret_key=settings.JWT_SECRET_KEY,
|
|
||||||
algorithm=settings.JWT_ALGORITHM,
|
|
||||||
access_token_expire_minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
|
|
||||||
)
|
|
||||||
|
|
||||||
# Export commonly used functions
|
|
||||||
verify_token = jwt_handler.verify_token
|
|
||||||
create_access_token = jwt_handler.create_access_token
|
|
||||||
get_current_user = jwt_handler.get_current_user
|
|
||||||
|
|
||||||
# Export decorators
|
|
||||||
__all__ = ['verify_token', 'create_access_token', 'get_current_user', 'require_auth', 'require_role']
|
|
||||||
@@ -8,10 +8,12 @@ from typing import List
|
|||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.auth import get_current_tenant_id
|
|
||||||
from app.schemas.training import TrainedModelResponse
|
from app.schemas.training import TrainedModelResponse
|
||||||
from app.services.training_service import TrainingService
|
from app.services.training_service import TrainingService
|
||||||
|
|
||||||
|
from shared.auth.decorators import (
|
||||||
|
get_current_tenant_id_dep
|
||||||
|
)
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -20,7 +22,7 @@ training_service = TrainingService()
|
|||||||
|
|
||||||
@router.get("/", response_model=List[TrainedModelResponse])
|
@router.get("/", response_model=List[TrainedModelResponse])
|
||||||
async def get_trained_models(
|
async def get_trained_models(
|
||||||
tenant_id: str = Depends(get_current_tenant_id),
|
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""Get trained models"""
|
"""Get trained models"""
|
||||||
|
|||||||
@@ -1,303 +0,0 @@
|
|||||||
# services/training/app/core/auth.py
|
|
||||||
"""
|
|
||||||
Authentication and authorization for training service
|
|
||||||
"""
|
|
||||||
|
|
||||||
import structlog
|
|
||||||
from typing import Optional
|
|
||||||
from fastapi import HTTPException, Depends, status
|
|
||||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
|
||||||
|
|
||||||
# HTTP Bearer token scheme
|
|
||||||
security = HTTPBearer(auto_error=False)
|
|
||||||
|
|
||||||
class AuthenticationError(Exception):
|
|
||||||
"""Custom exception for authentication errors"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
class AuthorizationError(Exception):
|
|
||||||
"""Custom exception for authorization errors"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def verify_token(token: str) -> dict:
|
|
||||||
"""
|
|
||||||
Verify JWT token with auth service
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token: JWT token to verify
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Token payload with user and tenant information
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AuthenticationError: If token is invalid
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.post(
|
|
||||||
f"{settings.AUTH_SERVICE_URL}/auth/verify",
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
timeout=10.0
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
token_data = response.json()
|
|
||||||
logger.debug("Token verified successfully", user_id=token_data.get("user_id"))
|
|
||||||
return token_data
|
|
||||||
elif response.status_code == 401:
|
|
||||||
logger.warning("Invalid token provided")
|
|
||||||
raise AuthenticationError("Invalid or expired token")
|
|
||||||
else:
|
|
||||||
logger.error("Auth service error", status_code=response.status_code)
|
|
||||||
raise AuthenticationError("Authentication service unavailable")
|
|
||||||
|
|
||||||
except httpx.TimeoutException:
|
|
||||||
logger.error("Auth service timeout")
|
|
||||||
raise AuthenticationError("Authentication service timeout")
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.error("Auth service request error", error=str(e))
|
|
||||||
raise AuthenticationError("Authentication service unavailable")
|
|
||||||
except AuthenticationError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Unexpected auth error", error=str(e))
|
|
||||||
raise AuthenticationError("Authentication failed")
|
|
||||||
|
|
||||||
async def get_current_user(
|
|
||||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Get current authenticated user
|
|
||||||
|
|
||||||
Args:
|
|
||||||
credentials: HTTP Bearer credentials
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: User information
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If authentication fails
|
|
||||||
"""
|
|
||||||
if not credentials:
|
|
||||||
logger.warning("No credentials provided")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Authentication credentials required",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
token_data = await verify_token(credentials.credentials)
|
|
||||||
return token_data
|
|
||||||
|
|
||||||
except AuthenticationError as e:
|
|
||||||
logger.warning("Authentication failed", error=str(e))
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=str(e),
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_current_tenant_id(
|
|
||||||
current_user: dict = Depends(get_current_user)
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Get current tenant ID from authenticated user
|
|
||||||
|
|
||||||
Args:
|
|
||||||
current_user: Current authenticated user data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Tenant ID
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If tenant ID is missing
|
|
||||||
"""
|
|
||||||
tenant_id = current_user.get("tenant_id")
|
|
||||||
if not tenant_id:
|
|
||||||
logger.error("Missing tenant_id in token", user_data=current_user)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Invalid token: missing tenant information"
|
|
||||||
)
|
|
||||||
|
|
||||||
return tenant_id
|
|
||||||
|
|
||||||
async def require_admin_role(
|
|
||||||
current_user: dict = Depends(get_current_user)
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Require admin role for endpoint access
|
|
||||||
|
|
||||||
Args:
|
|
||||||
current_user: Current authenticated user data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: User information
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If user is not admin
|
|
||||||
"""
|
|
||||||
user_role = current_user.get("role", "").lower()
|
|
||||||
if user_role != "admin":
|
|
||||||
logger.warning("Access denied - admin role required",
|
|
||||||
user_id=current_user.get("user_id"),
|
|
||||||
role=user_role)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Admin role required"
|
|
||||||
)
|
|
||||||
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
async def require_training_permission(
|
|
||||||
current_user: dict = Depends(get_current_user)
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Require training permission for endpoint access
|
|
||||||
|
|
||||||
Args:
|
|
||||||
current_user: Current authenticated user data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: User information
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If user doesn't have training permission
|
|
||||||
"""
|
|
||||||
permissions = current_user.get("permissions", [])
|
|
||||||
if "training" not in permissions and current_user.get("role", "").lower() != "admin":
|
|
||||||
logger.warning("Access denied - training permission required",
|
|
||||||
user_id=current_user.get("user_id"),
|
|
||||||
permissions=permissions)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Training permission required"
|
|
||||||
)
|
|
||||||
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
# Optional authentication for development/testing
|
|
||||||
async def get_current_user_optional(
|
|
||||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
|
||||||
) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
Get current user but don't require authentication (for development)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
credentials: HTTP Bearer credentials
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict or None: User information if authenticated, None otherwise
|
|
||||||
"""
|
|
||||||
if not credentials:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
token_data = await verify_token(credentials.credentials)
|
|
||||||
return token_data
|
|
||||||
except AuthenticationError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_tenant_id_optional(
|
|
||||||
current_user: Optional[dict] = Depends(get_current_user_optional)
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get tenant ID but don't require authentication (for development)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
current_user: Current user data (optional)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str or None: Tenant ID if available, None otherwise
|
|
||||||
"""
|
|
||||||
if not current_user:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return current_user.get("tenant_id")
|
|
||||||
|
|
||||||
# Development/testing auth bypass
|
|
||||||
async def get_test_tenant_id() -> str:
|
|
||||||
"""
|
|
||||||
Get test tenant ID for development/testing
|
|
||||||
Only works when DEBUG is enabled
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Test tenant ID
|
|
||||||
"""
|
|
||||||
if settings.DEBUG:
|
|
||||||
return "test-tenant-development"
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Test authentication only available in debug mode"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Token validation utility
|
|
||||||
def validate_token_structure(token_data: dict) -> bool:
|
|
||||||
"""
|
|
||||||
Validate that token data has required structure
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token_data: Token payload data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if valid structure, False otherwise
|
|
||||||
"""
|
|
||||||
required_fields = ["user_id", "tenant_id"]
|
|
||||||
|
|
||||||
for field in required_fields:
|
|
||||||
if field not in token_data:
|
|
||||||
logger.warning("Invalid token structure - missing field", field=field)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Role checking utilities
|
|
||||||
def has_role(user_data: dict, required_role: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if user has required role
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_data: User data from token
|
|
||||||
required_role: Required role name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if user has role, False otherwise
|
|
||||||
"""
|
|
||||||
user_role = user_data.get("role", "").lower()
|
|
||||||
return user_role == required_role.lower()
|
|
||||||
|
|
||||||
def has_permission(user_data: dict, required_permission: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if user has required permission
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_data: User data from token
|
|
||||||
required_permission: Required permission name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if user has permission, False otherwise
|
|
||||||
"""
|
|
||||||
permissions = user_data.get("permissions", [])
|
|
||||||
return required_permission in permissions or has_role(user_data, "admin")
|
|
||||||
|
|
||||||
# Export commonly used items
|
|
||||||
__all__ = [
|
|
||||||
'get_current_user',
|
|
||||||
'get_current_tenant_id',
|
|
||||||
'require_admin_role',
|
|
||||||
'require_training_permission',
|
|
||||||
'get_current_user_optional',
|
|
||||||
'get_tenant_id_optional',
|
|
||||||
'get_test_tenant_id',
|
|
||||||
'has_role',
|
|
||||||
'has_permission',
|
|
||||||
'AuthenticationError',
|
|
||||||
'AuthorizationError'
|
|
||||||
]
|
|
||||||
Reference in New Issue
Block a user