New token arch

This commit is contained in:
Urtzi Alfaro
2026-01-10 21:45:37 +01:00
parent cc53037552
commit bf1db7cb9e
26 changed files with 1751 additions and 107 deletions

View File

@@ -12,6 +12,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
import httpx
from typing import Dict, Any, Optional, List
import asyncio
from datetime import datetime, timezone
from app.core.config import settings
from app.utils.subscription_error_responses import create_upgrade_required_response
@@ -30,9 +31,10 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
- Operations routes (/api/v1/tenants/{tenant_id}/{service}/operations/*): ALL tiers (role-based)
"""
def __init__(self, app, tenant_service_url: str):
def __init__(self, app, tenant_service_url: str, redis_client=None):
super().__init__(app)
self.tenant_service_url = tenant_service_url.rstrip('/')
self.redis_client = redis_client # Optional Redis client for abuse detection
# Define route patterns that require subscription validation
# Using new standardized URL structure
@@ -236,20 +238,60 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
Dict with 'allowed' boolean and additional metadata
"""
try:
# Use the same authentication pattern as gateway routes
# Check if JWT already has subscription
if hasattr(request.state, 'user') and request.state.user:
user_context = request.state.user
user_id = user_context.get('user_id', 'unknown')
if user_context.get("subscription_from_jwt"):
# Use JWT data directly - NO HTTP CALL!
current_tier = user_context.get("subscription_tier", "starter")
logger.debug("Using subscription tier from JWT (no HTTP call)",
tenant_id=tenant_id,
current_tier=current_tier,
minimum_tier=minimum_tier,
allowed_tiers=allowed_tiers)
if current_tier not in [tier.lower() for tier in allowed_tiers]:
tier_names = ', '.join(allowed_tiers)
return {
'allowed': False,
'message': f'This feature requires a {tier_names} subscription plan',
'current_tier': current_tier
}
await self._log_subscription_access(
tenant_id,
user_id,
feature,
current_tier,
True,
"jwt"
)
return {
'allowed': True,
'message': 'Access granted (JWT subscription)',
'current_tier': current_tier
}
# Use the same authentication pattern as gateway routes for fallback
headers = dict(request.headers)
headers.pop("host", None)
# Extract user_id for logging (fallback path)
user_id = 'unknown'
# Add user context headers if available
if hasattr(request.state, 'user') and request.state.user:
user = request.state.user
headers["x-user-id"] = str(user.get('user_id', ''))
user_id = str(user.get('user_id', 'unknown'))
headers["x-user-id"] = user_id
headers["x-user-email"] = str(user.get('email', ''))
headers["x-user-role"] = str(user.get('role', 'user'))
headers["x-user-full-name"] = str(user.get('full_name', ''))
headers["x-tenant-id"] = str(user.get('tenant_id', ''))
# Call tenant service fast tier endpoint with caching
# Call tenant service fast tier endpoint with caching (fallback for old tokens)
timeout_config = httpx.Timeout(
connect=1.0, # Connection timeout - very short for cached endpoint
read=5.0, # Read timeout - short for cached lookup
@@ -291,6 +333,15 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
# Check if current tier is in allowed tiers
if current_tier not in [tier.lower() for tier in allowed_tiers]:
tier_names = ', '.join(allowed_tiers)
await self._log_subscription_access(
tenant_id,
user_id,
feature,
current_tier,
False,
"jwt"
)
return {
'allowed': False,
'message': f'This feature requires a {tier_names} subscription plan',
@@ -298,6 +349,15 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
}
# Tier check passed
await self._log_subscription_access(
tenant_id,
user_id,
feature,
current_tier,
True,
"database"
)
return {
'allowed': True,
'message': 'Access granted',
@@ -343,3 +403,64 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
'current_plan': 'unknown'
}
async def _log_subscription_access(
self,
tenant_id: str,
user_id: str,
requested_feature: str,
current_tier: str,
access_granted: bool,
source: str # "jwt" or "database"
):
"""
Log all subscription-gated access attempts for audit and anomaly detection.
"""
logger.info(
"Subscription access check",
tenant_id=tenant_id,
user_id=user_id,
feature=requested_feature,
tier=current_tier,
granted=access_granted,
source=source,
timestamp=datetime.now(timezone.utc).isoformat()
)
# For denied access, check for suspicious patterns
if not access_granted:
await self._check_for_abuse_patterns(tenant_id, user_id, requested_feature)
async def _check_for_abuse_patterns(
self,
tenant_id: str,
user_id: str,
feature: str
):
"""
Detect potential abuse patterns like repeated premium feature access attempts.
"""
if not self.redis_client:
return
# Track denied attempts in a sliding window
key = f"subscription_denied:{tenant_id}:{user_id}:{feature}"
try:
attempts = await self.redis_client.incr(key)
if attempts == 1:
await self.redis_client.expire(key, 3600) # 1 hour window
# Alert if too many denied attempts (potential bypass attempt)
if attempts > 10:
logger.warning(
"SECURITY: Excessive premium feature access attempts detected",
tenant_id=tenant_id,
user_id=user_id,
feature=feature,
attempts=attempts,
window="1 hour"
)
# Could trigger alert to security team here
except Exception as e:
logger.warning("Failed to track abuse patterns", error=str(e))