Fix forecasting service
This commit is contained in:
@@ -8,10 +8,12 @@ from typing import List
|
||||
import structlog
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import get_current_tenant_id
|
||||
from app.schemas.training import TrainedModelResponse
|
||||
from app.services.training_service import TrainingService
|
||||
|
||||
from shared.auth.decorators import (
|
||||
get_current_tenant_id_dep
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
@@ -20,7 +22,7 @@ training_service = TrainingService()
|
||||
|
||||
@router.get("/", response_model=List[TrainedModelResponse])
|
||||
async def get_trained_models(
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get trained models"""
|
||||
|
||||
@@ -1,303 +0,0 @@
|
||||
# services/training/app/core/auth.py
|
||||
"""
|
||||
Authentication and authorization for training service
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Depends, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# HTTP Bearer token scheme
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Custom exception for authentication errors"""
|
||||
pass
|
||||
|
||||
class AuthorizationError(Exception):
|
||||
"""Custom exception for authorization errors"""
|
||||
pass
|
||||
|
||||
async def verify_token(token: str) -> dict:
|
||||
"""
|
||||
Verify JWT token with auth service
|
||||
|
||||
Args:
|
||||
token: JWT token to verify
|
||||
|
||||
Returns:
|
||||
dict: Token payload with user and tenant information
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If token is invalid
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{settings.AUTH_SERVICE_URL}/auth/verify",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
token_data = response.json()
|
||||
logger.debug("Token verified successfully", user_id=token_data.get("user_id"))
|
||||
return token_data
|
||||
elif response.status_code == 401:
|
||||
logger.warning("Invalid token provided")
|
||||
raise AuthenticationError("Invalid or expired token")
|
||||
else:
|
||||
logger.error("Auth service error", status_code=response.status_code)
|
||||
raise AuthenticationError("Authentication service unavailable")
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("Auth service timeout")
|
||||
raise AuthenticationError("Authentication service timeout")
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Auth service request error", error=str(e))
|
||||
raise AuthenticationError("Authentication service unavailable")
|
||||
except AuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Unexpected auth error", error=str(e))
|
||||
raise AuthenticationError("Authentication failed")
|
||||
|
||||
async def get_current_user(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
||||
) -> dict:
|
||||
"""
|
||||
Get current authenticated user
|
||||
|
||||
Args:
|
||||
credentials: HTTP Bearer credentials
|
||||
|
||||
Returns:
|
||||
dict: User information
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails
|
||||
"""
|
||||
if not credentials:
|
||||
logger.warning("No credentials provided")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication credentials required",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
token_data = await verify_token(credentials.credentials)
|
||||
return token_data
|
||||
|
||||
except AuthenticationError as e:
|
||||
logger.warning("Authentication failed", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def get_current_tenant_id(
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> str:
|
||||
"""
|
||||
Get current tenant ID from authenticated user
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user data
|
||||
|
||||
Returns:
|
||||
str: Tenant ID
|
||||
|
||||
Raises:
|
||||
HTTPException: If tenant ID is missing
|
||||
"""
|
||||
tenant_id = current_user.get("tenant_id")
|
||||
if not tenant_id:
|
||||
logger.error("Missing tenant_id in token", user_data=current_user)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid token: missing tenant information"
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
|
||||
async def require_admin_role(
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> dict:
|
||||
"""
|
||||
Require admin role for endpoint access
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user data
|
||||
|
||||
Returns:
|
||||
dict: User information
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is not admin
|
||||
"""
|
||||
user_role = current_user.get("role", "").lower()
|
||||
if user_role != "admin":
|
||||
logger.warning("Access denied - admin role required",
|
||||
user_id=current_user.get("user_id"),
|
||||
role=user_role)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin role required"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
async def require_training_permission(
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> dict:
|
||||
"""
|
||||
Require training permission for endpoint access
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user data
|
||||
|
||||
Returns:
|
||||
dict: User information
|
||||
|
||||
Raises:
|
||||
HTTPException: If user doesn't have training permission
|
||||
"""
|
||||
permissions = current_user.get("permissions", [])
|
||||
if "training" not in permissions and current_user.get("role", "").lower() != "admin":
|
||||
logger.warning("Access denied - training permission required",
|
||||
user_id=current_user.get("user_id"),
|
||||
permissions=permissions)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Training permission required"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
# Optional authentication for development/testing
|
||||
async def get_current_user_optional(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get current user but don't require authentication (for development)
|
||||
|
||||
Args:
|
||||
credentials: HTTP Bearer credentials
|
||||
|
||||
Returns:
|
||||
dict or None: User information if authenticated, None otherwise
|
||||
"""
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
try:
|
||||
token_data = await verify_token(credentials.credentials)
|
||||
return token_data
|
||||
except AuthenticationError:
|
||||
return None
|
||||
|
||||
async def get_tenant_id_optional(
|
||||
current_user: Optional[dict] = Depends(get_current_user_optional)
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get tenant ID but don't require authentication (for development)
|
||||
|
||||
Args:
|
||||
current_user: Current user data (optional)
|
||||
|
||||
Returns:
|
||||
str or None: Tenant ID if available, None otherwise
|
||||
"""
|
||||
if not current_user:
|
||||
return None
|
||||
|
||||
return current_user.get("tenant_id")
|
||||
|
||||
# Development/testing auth bypass
|
||||
async def get_test_tenant_id() -> str:
|
||||
"""
|
||||
Get test tenant ID for development/testing
|
||||
Only works when DEBUG is enabled
|
||||
|
||||
Returns:
|
||||
str: Test tenant ID
|
||||
"""
|
||||
if settings.DEBUG:
|
||||
return "test-tenant-development"
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Test authentication only available in debug mode"
|
||||
)
|
||||
|
||||
# Token validation utility
|
||||
def validate_token_structure(token_data: dict) -> bool:
|
||||
"""
|
||||
Validate that token data has required structure
|
||||
|
||||
Args:
|
||||
token_data: Token payload data
|
||||
|
||||
Returns:
|
||||
bool: True if valid structure, False otherwise
|
||||
"""
|
||||
required_fields = ["user_id", "tenant_id"]
|
||||
|
||||
for field in required_fields:
|
||||
if field not in token_data:
|
||||
logger.warning("Invalid token structure - missing field", field=field)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# Role checking utilities
|
||||
def has_role(user_data: dict, required_role: str) -> bool:
|
||||
"""
|
||||
Check if user has required role
|
||||
|
||||
Args:
|
||||
user_data: User data from token
|
||||
required_role: Required role name
|
||||
|
||||
Returns:
|
||||
bool: True if user has role, False otherwise
|
||||
"""
|
||||
user_role = user_data.get("role", "").lower()
|
||||
return user_role == required_role.lower()
|
||||
|
||||
def has_permission(user_data: dict, required_permission: str) -> bool:
|
||||
"""
|
||||
Check if user has required permission
|
||||
|
||||
Args:
|
||||
user_data: User data from token
|
||||
required_permission: Required permission name
|
||||
|
||||
Returns:
|
||||
bool: True if user has permission, False otherwise
|
||||
"""
|
||||
permissions = user_data.get("permissions", [])
|
||||
return required_permission in permissions or has_role(user_data, "admin")
|
||||
|
||||
# Export commonly used items
|
||||
__all__ = [
|
||||
'get_current_user',
|
||||
'get_current_tenant_id',
|
||||
'require_admin_role',
|
||||
'require_training_permission',
|
||||
'get_current_user_optional',
|
||||
'get_tenant_id_optional',
|
||||
'get_test_tenant_id',
|
||||
'has_role',
|
||||
'has_permission',
|
||||
'AuthenticationError',
|
||||
'AuthorizationError'
|
||||
]
|
||||
Reference in New Issue
Block a user