2025-07-17 13:09:24 +02:00
|
|
|
import logging
|
2025-07-17 19:54:04 +02:00
|
|
|
from fastapi import Request
|
2025-07-17 13:09:24 +02:00
|
|
|
from fastapi.responses import JSONResponse
|
2025-07-17 19:54:04 +02:00
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
|
from starlette.responses import Response
|
2025-07-17 13:09:24 +02:00
|
|
|
import httpx
|
|
|
|
|
from typing import Optional
|
2025-07-18 16:48:49 +02:00
|
|
|
import json
|
2025-07-17 13:09:24 +02:00
|
|
|
|
|
|
|
|
from app.core.config import settings
|
|
|
|
|
from shared.auth.jwt_handler import JWTHandler
|
|
|
|
|
|
|
|
|
|
# Module-scoped logger, named after this module so its records can be
# filtered or configured independently.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
# JWT handler
|
|
|
|
|
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
|
|
|
|
|
|
|
|
|
# Paths that AuthMiddleware lets through without a bearer token
# (see AuthMiddleware._is_public_route for how a path is matched).
PUBLIC_ROUTES = [
    # Health / monitoring
    "/health",
    "/metrics",
    # API documentation and schema
    "/docs",
    "/redoc",
    "/openapi.json",
    # Authentication endpoints (the middleware itself calls the auth
    # service's /api/v1/auth/verify to check tokens)
    "/api/v1/auth/login",
    "/api/v1/auth/register",
    "/api/v1/auth/refresh",
    "/api/v1/auth/verify",
]
|
|
|
|
|
|
2025-07-17 19:54:04 +02:00
|
|
|
class AuthMiddleware(BaseHTTPMiddleware):
    """Require a valid bearer token on every non-public route.

    Requests whose path matches ``PUBLIC_ROUTES`` are forwarded untouched.
    All other requests must send ``Authorization: Bearer <token>``. The token
    is first verified locally with ``jwt_handler``. If that yields no payload,
    the auth service is asked to verify it instead. On success the claims are
    stored on ``request.state.user`` and the request continues down the
    stack. On failure a 401 JSON response is returned and the route handler
    is never called.
    """

    # Claims a locally verified token must carry before it is accepted.
    REQUIRED_FIELDS = ("user_id", "email", "tenant_id")

    async def dispatch(self, request: Request, call_next) -> Response:
        """Authenticate ``request`` and pass it down the middleware stack.

        Args:
            request: The incoming request.
            call_next: Callable supplied by ``BaseHTTPMiddleware`` that runs
                the rest of the application and returns its response.

        Returns:
            The downstream response, or a 401 ``JSONResponse`` when the
            request cannot be authenticated.
        """
        path = request.url.path
        if self._is_public_route(path):
            return await call_next(request)

        token = self._extract_token(request)
        if not token:
            logger.warning("Missing token for %s", path)
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"},
            )

        # Only the verification step is guarded. ``call_next`` runs after
        # the ``try`` so that exceptions raised by route handlers propagate
        # normally instead of being reported as 401 "Authentication failed".
        try:
            payload = jwt_handler.verify_token(token)
            if payload:
                missing_fields = [
                    field for field in self.REQUIRED_FIELDS if field not in payload
                ]
                if missing_fields:
                    logger.warning("Token missing required fields: %s", missing_fields)
                    return JSONResponse(
                        status_code=401,
                        content={"detail": f"Invalid token: missing {missing_fields}"},
                    )
                user = payload
                logger.debug(
                    "Authenticated user: %s (tenant: %s)",
                    payload.get("email"),
                    payload.get("tenant_id"),
                )
            else:
                # Local verification returned no payload; fall back to the
                # auth service.
                logger.info("Local token verification failed, trying auth service")
                user = await self._verify_with_auth_service(token)
                if not user:
                    logger.warning("Auth service verification also failed")
                    return JSONResponse(
                        status_code=401,
                        content={"detail": "Invalid or expired token"},
                    )
        except Exception as e:
            logger.exception("Authentication error: %s", e)
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication failed"},
            )

        request.state.user = user
        return await call_next(request)

    def _is_public_route(self, path: str) -> bool:
        """Return True when ``path`` is a public route or lies beneath one.

        Matching respects path-segment boundaries. ``/docs`` also covers
        ``/docs/oauth2-redirect``, but ``/health`` does not cover
        ``/healthcheck``. A bare ``startswith`` would let any path that
        merely begins with a public route's text skip authentication.
        """
        return any(
            path == route or path.startswith(route.rstrip("/") + "/")
            for route in PUBLIC_ROUTES
        )

    def _extract_token(self, request: Request) -> Optional[str]:
        """Return the token from an ``Authorization: Bearer <token>`` header.

        The scheme name is compared case-insensitively, as HTTP auth schemes
        are case-insensitive.

        Returns:
            The token, or ``None`` when the header is missing, uses another
            scheme, or is not exactly ``<scheme> <token>``.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            return None
        return parts[1]

    async def _verify_with_auth_service(self, token: str) -> Optional[dict]:
        """Ask the auth service to verify ``token``.

        Posts to ``{settings.AUTH_SERVICE_URL}/api/v1/auth/verify`` with the
        token as a bearer header, using a 5 second timeout.

        Returns:
            The decoded JSON object on HTTP 200. ``None`` on a non-200
            status, a transport error, an undecodable body, or a body that
            is not a JSON object.
        """
        # NOTE(review): unlike locally verified tokens, this response is not
        # checked against REQUIRED_FIELDS. Confirm the auth service's schema.
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.post(
                    f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
                    headers={"Authorization": f"Bearer {token}"},
                )

                if response.status_code != 200:
                    logger.warning(
                        "Auth service verification failed: %s", response.status_code
                    )
                    return None

                user_info = response.json()
                if not isinstance(user_info, dict):
                    logger.warning(
                        "Auth service returned unexpected payload type: %s",
                        type(user_info).__name__,
                    )
                    return None

                logger.debug(
                    "Auth service verification successful: %s", user_info.get("email")
                )
                return user_info
        except Exception as e:
            # Best effort: any failure here (network, timeout, bad JSON) is
            # treated as "not verified" rather than crashing the request.
            logger.error("Auth service verification failed: %s", e)
            return None
|