REFACTOR ALL APIs

This commit is contained in:
Urtzi Alfaro
2025-10-06 15:27:01 +02:00
parent dc8221bd2f
commit 38fb98bc27
166 changed files with 18454 additions and 13605 deletions

View File

@@ -291,7 +291,7 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
# Build WebSocket URL to training service
training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
training_ws_url = f"{training_ws_url}/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
training_ws = None
heartbeat_task = None
@@ -348,12 +348,20 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
try:
# Use longer timeout to avoid conflicts with frontend 30s heartbeat
# Frontend sends ping every 30s, so we need to allow for some latency
message = await asyncio.wait_for(websocket.receive_text(), timeout=45.0)
data = await asyncio.wait_for(websocket.receive(), timeout=45.0)
last_activity = asyncio.get_event_loop().time()
# Forward the message to training service
await training_ws.send(message)
logger.debug(f"Forwarded message to training service for job {job_id}: {message[:100]}...")
# Handle different message types
if data.get("type") == "websocket.receive":
if "text" in data:
message = data["text"]
# Forward text messages to training service
await training_ws.send(message)
logger.debug(f"Forwarded message to training service for job {job_id}: {message[:100]}...")
elif "bytes" in data:
# Forward binary messages if needed
await training_ws.send(data["bytes"])
# Ping/pong frames are automatically handled by Starlette/FastAPI
except asyncio.TimeoutError:
# No message received in 45 seconds, continue loop

View File

@@ -108,6 +108,11 @@ class AuthMiddleware(BaseHTTPMiddleware):
content={"detail": f"Access denied to tenant {tenant_id}"}
)
# Get tenant subscription tier and inject into user context
subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
if subscription_tier:
user_context["subscription_tier"] = subscription_tier
# Set tenant context in request state
request.state.tenant_id = tenant_id
request.state.tenant_verified = True
@@ -115,6 +120,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
logger.debug(f"Tenant access verified",
user_id=user_context["user_id"],
tenant_id=tenant_id,
subscription_tier=subscription_tier,
path=request.url.path)
# ✅ STEP 5: Inject user context into request
@@ -386,7 +392,72 @@ class AuthMiddleware(BaseHTTPMiddleware):
b"x-tenant-id", tenant_id.encode()
))
# Add subscription tier if available
subscription_tier = user_context.get("subscription_tier", "")
if subscription_tier:
request.headers.__dict__["_list"].append((
b"x-subscription-tier", subscription_tier.encode()
))
# Add gateway identification
request.headers.__dict__["_list"].append((
b"x-forwarded-by", b"bakery-gateway"
))
))
async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
"""
Get tenant subscription tier from tenant service
Args:
tenant_id: Tenant ID
request: FastAPI request for headers
Returns:
Subscription tier string or None
"""
try:
# Check cache first
if self.redis_client:
cache_key = f"tenant:tier:{tenant_id}"
try:
cached_tier = await self.redis_client.get(cache_key)
if cached_tier:
if isinstance(cached_tier, bytes):
cached_tier = cached_tier.decode()
logger.debug("Subscription tier from cache", tenant_id=tenant_id, tier=cached_tier)
return cached_tier
except Exception as e:
logger.warning(f"Cache lookup failed for tenant tier: {e}")
# Get from tenant service
async with httpx.AsyncClient(timeout=5.0) as client:
headers = {"Authorization": request.headers.get("Authorization", "")}
response = await client.get(
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}",
headers=headers
)
if response.status_code == 200:
tenant_data = response.json()
subscription_tier = tenant_data.get("subscription_tier", "basic")
# Cache for 5 minutes
if self.redis_client:
try:
await self.redis_client.setex(
f"tenant:tier:{tenant_id}",
300, # 5 minutes
subscription_tier
)
except Exception as e:
logger.warning(f"Failed to cache tenant tier: {e}")
logger.debug("Subscription tier from service", tenant_id=tenant_id, tier=subscription_tier)
return subscription_tier
else:
logger.warning(f"Failed to get tenant subscription tier: {response.status_code}")
return "basic" # Default to basic
except Exception as e:
logger.error(f"Error getting tenant subscription tier: {e}")
return "basic" # Default to basic on error

View File

@@ -1,5 +1,6 @@
"""
Subscription Middleware - Enforces subscription limits and feature access
Updated to support standardized URL structure with tier-based access control
"""
import re
@@ -9,7 +10,7 @@ from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import httpx
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, List
import asyncio
from app.core.config import settings
@@ -18,48 +19,71 @@ logger = structlog.get_logger()
class SubscriptionMiddleware(BaseHTTPMiddleware):
"""Middleware to enforce subscription-based access control"""
"""
Middleware to enforce subscription-based access control
Supports standardized URL structure:
- Base routes (/api/v1/tenants/{tenant_id}/{service}/{resource}): ALL tiers
- Dashboard routes (/api/v1/tenants/{tenant_id}/{service}/dashboard/*): ALL tiers
- Analytics routes (/api/v1/tenants/{tenant_id}/{service}/analytics/*): PROFESSIONAL+
- Operations routes (/api/v1/tenants/{tenant_id}/{service}/operations/*): ALL tiers (role-based)
"""
def __init__(self, app, tenant_service_url: str):
super().__init__(app)
self.tenant_service_url = tenant_service_url.rstrip('/')
# Define route patterns that require subscription validation
# Using new standardized URL structure
self.protected_routes = {
# Analytics routes - require different levels based on actual app routes
r'/api/v1/tenants/[^/]+/analytics/.*': {
# ===== ANALYTICS ROUTES - PROFESSIONAL/ENTERPRISE ONLY =====
# Any service analytics endpoint
r'^/api/v1/tenants/[^/]+/[^/]+/analytics/.*': {
'feature': 'analytics',
'minimum_level': 'basic' # Changed to basic to allow all tiers access to analytics
'minimum_tier': 'professional',
'allowed_tiers': ['professional', 'enterprise'],
'description': 'Analytics features (Professional/Enterprise only)'
},
r'/api/v1/tenants/[^/]+/forecasts/.*': {
'feature': 'analytics',
'minimum_level': 'basic' # Changed to basic to allow all tiers access to forecasting
# ===== TRAINING SERVICE - ALL TIERS =====
r'^/api/v1/tenants/[^/]+/training/.*': {
'feature': 'ml_training',
'minimum_tier': 'basic',
'allowed_tiers': ['basic', 'professional', 'enterprise'],
'description': 'Machine learning model training (Available for all tiers)'
},
r'/api/v1/tenants/[^/]+/predictions/.*': {
'feature': 'analytics',
'minimum_level': 'basic' # Changed to basic to allow all tiers access to predictions
# ===== ADVANCED FEATURES - PROFESSIONAL/ENTERPRISE =====
# Advanced reporting and exports
r'^/api/v1/tenants/[^/]+/[^/]+/export/advanced.*': {
'feature': 'advanced_exports',
'minimum_tier': 'professional',
'allowed_tiers': ['professional', 'enterprise'],
'description': 'Advanced export formats (Professional/Enterprise only)'
},
# Training and AI models - Now available to all tiers
r'/api/v1/tenants/[^/]+/training/.*': {
'feature': 'analytics',
'minimum_level': 'basic' # Changed to basic to allow all tiers access to training
# Bulk operations
r'^/api/v1/tenants/[^/]+/[^/]+/bulk/.*': {
'feature': 'bulk_operations',
'minimum_tier': 'professional',
'allowed_tiers': ['professional', 'enterprise'],
'description': 'Bulk operations (Professional/Enterprise only)'
},
r'/api/v1/tenants/[^/]+/models/.*': {
'feature': 'analytics',
'minimum_level': 'basic' # Changed to basic to allow all tiers access to models
},
# Advanced production features - Professional+
r'/api/v1/tenants/[^/]+/production/optimization/.*': {
'feature': 'analytics',
'minimum_level': 'basic'
},
# Enterprise-only features
r'/api/v1/tenants/[^/]+/statistics.*': {
'feature': 'analytics',
'minimum_level': 'basic' # Advanced stats for Enterprise only
}
}
# Routes that are explicitly allowed for all tiers (no check needed)
self.public_tier_routes = [
# Base CRUD operations - ALL TIERS
r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/?$',
r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/[^/]+/?$',
# Dashboard routes - ALL TIERS
r'^/api/v1/tenants/[^/]+/[^/]+/dashboard/.*',
# Operations routes - ALL TIERS (role-based control applies)
r'^/api/v1/tenants/[^/]+/[^/]+/operations/.*',
]
async def dispatch(self, request: Request, call_next):
"""Process the request and check subscription requirements"""
@@ -67,6 +91,10 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
if self._should_skip_subscription_check(request):
return await call_next(request)
# Check if route is explicitly allowed for all tiers
if self._is_public_tier_route(request.url.path):
return await call_next(request)
# Check if route requires subscription validation
subscription_requirement = self._get_subscription_requirement(request.url.path)
if not subscription_requirement:
@@ -84,25 +112,28 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
}
)
# Validate subscription
validation_result = await self._validate_subscription(
# Validate subscription with new tier-based system
validation_result = await self._validate_subscription_tier(
request,
tenant_id,
subscription_requirement['feature'],
subscription_requirement['minimum_level']
subscription_requirement.get('feature'),
subscription_requirement.get('minimum_tier'),
subscription_requirement.get('allowed_tiers', [])
)
if not validation_result['allowed']:
return JSONResponse(
status_code=403,
status_code=402, # Payment Required for tier limitations
content={
"error": "subscription_required",
"error": "subscription_tier_insufficient",
"message": validation_result['message'],
"code": "SUBSCRIPTION_UPGRADE_REQUIRED",
"details": {
"required_feature": subscription_requirement['feature'],
"required_level": subscription_requirement['minimum_level'],
"current_plan": validation_result.get('current_plan', 'unknown'),
"required_feature": subscription_requirement.get('feature'),
"minimum_tier": subscription_requirement.get('minimum_tier'),
"allowed_tiers": subscription_requirement.get('allowed_tiers', []),
"current_tier": validation_result.get('current_tier', 'unknown'),
"description": subscription_requirement.get('description', ''),
"upgrade_url": "/app/settings/profile"
}
}
@@ -112,6 +143,22 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
response = await call_next(request)
return response
def _is_public_tier_route(self, path: str) -> bool:
"""
Check if route is explicitly allowed for all subscription tiers
Args:
path: Request path
Returns:
True if route is allowed for all tiers
"""
for pattern in self.public_tier_routes:
if re.match(pattern, path):
logger.debug("Route allowed for all tiers", path=path, pattern=pattern)
return True
return False
def _should_skip_subscription_check(self, request: Request) -> bool:
"""Check if subscription validation should be skipped"""
path = request.url.path
@@ -163,20 +210,33 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
return None
async def _validate_subscription(
async def _validate_subscription_tier(
self,
request: Request,
tenant_id: str,
feature: str,
minimum_level: str
feature: Optional[str],
minimum_tier: str,
allowed_tiers: List[str]
) -> Dict[str, Any]:
"""Validate subscription feature access using the same pattern as other gateway services"""
"""
Validate subscription tier access using tenant service
Args:
request: FastAPI request
tenant_id: Tenant ID
feature: Feature name (optional, for additional checks)
minimum_tier: Minimum required subscription tier
allowed_tiers: List of allowed subscription tiers
Returns:
Dict with 'allowed' boolean and additional metadata
"""
try:
# Use the same authentication pattern as gateway routes
headers = dict(request.headers)
headers.pop("host", None)
# Add user context headers if available (same as _proxy_request)
# Add user context headers if available
if hasattr(request.state, 'user') and request.state.user:
user = request.state.user
headers["x-user-id"] = str(user.get('user_id', ''))
@@ -185,64 +245,58 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
headers["x-user-full-name"] = str(user.get('full_name', ''))
headers["x-tenant-id"] = str(user.get('tenant_id', ''))
# Call tenant service to check subscription with gateway-appropriate timeout
# Call tenant service to get subscription tier with gateway-appropriate timeout
timeout_config = httpx.Timeout(
connect=2.0, # Connection timeout - short for gateway
read=10.0, # Read timeout
write=2.0, # Write timeout
pool=2.0 # Pool timeout
)
async with httpx.AsyncClient(timeout=timeout_config) as client:
# Check feature access
feature_response = await client.get(
f"{settings.TENANT_SERVICE_URL}/api/v1/subscriptions/{tenant_id}/features/{feature}",
# Get tenant subscription information
tenant_response = await client.get(
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}",
headers=headers
)
if feature_response.status_code != 200:
if tenant_response.status_code != 200:
logger.warning(
"Failed to check feature access",
"Failed to get tenant subscription",
tenant_id=tenant_id,
feature=feature,
status_code=feature_response.status_code,
response_text=feature_response.text,
url=f"{settings.TENANT_SERVICE_URL}/api/v1/subscriptions/{tenant_id}/features/{feature}"
status_code=tenant_response.status_code,
response_text=tenant_response.text
)
# Fail open for availability (let service handle detailed check if needed)
# Fail open for availability
return {
'allowed': True,
'message': 'Access granted (validation service unavailable)',
'current_plan': 'unknown'
'current_tier': 'unknown'
}
feature_data = feature_response.json()
logger.info("Feature check response",
tenant_id=tenant_id,
feature=feature,
response=feature_data)
tenant_data = tenant_response.json()
current_tier = tenant_data.get('subscription_tier', 'basic').lower()
if not feature_data.get('has_feature'):
logger.debug("Subscription tier check",
tenant_id=tenant_id,
current_tier=current_tier,
minimum_tier=minimum_tier,
allowed_tiers=allowed_tiers)
# Check if current tier is in allowed tiers
if current_tier not in [tier.lower() for tier in allowed_tiers]:
tier_names = ', '.join(allowed_tiers)
return {
'allowed': False,
'message': f'Feature "{feature}" not available in your current plan',
'current_plan': feature_data.get('plan', 'unknown')
'message': f'This feature requires a {tier_names} subscription plan',
'current_tier': current_tier
}
# Check feature level if it's analytics
if feature == 'analytics':
feature_level = feature_data.get('feature_value', 'basic')
if not self._check_analytics_level(feature_level, minimum_level):
return {
'allowed': False,
'message': f'Analytics level "{minimum_level}" required. Current level: "{feature_level}"',
'current_plan': feature_data.get('plan', 'unknown')
}
# Tier check passed
return {
'allowed': True,
'message': 'Access granted',
'current_plan': feature_data.get('plan', 'unknown')
'current_tier': current_tier
}
except asyncio.TimeoutError:
@@ -284,15 +338,3 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
'current_plan': 'unknown'
}
def _check_analytics_level(self, current_level: str, required_level: str) -> bool:
"""Check if current analytics level meets the requirement"""
level_hierarchy = {
'basic': 1,
'advanced': 2,
'predictive': 3
}
current_rank = level_hierarchy.get(current_level, 0)
required_rank = level_hierarchy.get(required_level, 0)
return current_rank >= required_rank

View File

@@ -214,7 +214,7 @@ async def change_password(request: Request):
# CATCH-ALL ROUTE for any other auth endpoints
# ================================================================
@router.api_route("/auth/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_auth_requests(path: str, request: Request):
"""Catch-all proxy for auth requests"""
return await auth_proxy.forward_request(request.method, path, request)

View File

@@ -391,25 +391,65 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
# Get request body if present
body = None
files = None
data = None
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
content_type = request.headers.get("content-type", "")
logger.info(f"Processing {request.method} request with content-type: {content_type}")
# Handle multipart/form-data (file uploads)
if "multipart/form-data" in content_type:
logger.info("Detected multipart/form-data, parsing form...")
# For multipart/form-data, we need to re-parse and forward as files
form = await request.form()
logger.info(f"Form parsed, found {len(form)} fields: {list(form.keys())}")
# Extract files and form fields separately
files_dict = {}
data_dict = {}
for key, value in form.items():
if hasattr(value, 'file'): # It's a file
# Read file content
file_content = await value.read()
files_dict[key] = (value.filename, file_content, value.content_type)
logger.info(f"Found file field '{key}': filename={value.filename}, size={len(file_content)}, type={value.content_type}")
else: # It's a regular form field
data_dict[key] = value
logger.info(f"Found form field '{key}': value={value}")
files = files_dict if files_dict else None
data = data_dict if data_dict else None
logger.info(f"Forwarding multipart request with files={list(files.keys()) if files else None}, data={list(data.keys()) if data else None}")
# Remove content-type from headers - httpx will set it with new boundary
headers.pop("content-type", None)
headers.pop("content-length", None)
else:
# For other content types, use body as before
body = await request.body()
logger.info(f"Using raw body, size: {len(body)} bytes")
# Add query parameters
params = dict(request.query_params)
timeout_config = httpx.Timeout(
connect=30.0, # Connection timeout
read=600.0, # Read timeout: 10 minutes (was 30s)
write=30.0, # Write timeout
pool=30.0 # Pool timeout
)
async with httpx.AsyncClient(timeout=timeout_config) as client:
response = await client.request(
method=request.method,
url=url,
headers=headers,
content=body,
files=files,
data=data,
params=params
)