Initial commit - production deployment

This commit is contained in:
2026-01-21 17:17:16 +01:00
commit c23d00dd92
2289 changed files with 638440 additions and 0 deletions

View File

View File

@@ -0,0 +1,649 @@
# gateway/app/middleware/auth.py
"""
Enhanced Authentication Middleware for API Gateway with Tenant Access Control
FIXED VERSION - Proper JWT verification and token structure handling
"""
import structlog
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from typing import Optional, Dict, Any
import httpx
import json
from app.core.config import settings
from app.core.header_manager import header_manager
from shared.auth.jwt_handler import JWTHandler
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
logger = structlog.get_logger()
# JWT handler for local token validation - using SAME configuration as auth service
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
# Routes that are reachable without authentication.
PUBLIC_ROUTES = [
    # --- Service health, metrics and API documentation ---
    "/health",
    "/metrics",
    "/docs",
    "/redoc",
    "/openapi.json",
    # --- Authentication and registration flows ---
    "/api/v1/auth/login",
    "/api/v1/auth/register",
    "/api/v1/auth/refresh",
    "/api/v1/auth/verify",
    # Registration step 1 - SetupIntent creation
    "/api/v1/auth/start-registration",
    # Registration step 2 - completion after 3DS
    "/api/v1/auth/complete-registration",
    # Newer registration endpoints: payment setup, completion, state check
    "/api/v1/registration/payment-setup",
    "/api/v1/registration/complete",
    "/api/v1/registration/state/",
    # Email verification
    "/api/v1/auth/verify-email",
    # Password reset: request a reset / reset with a token (no auth required)
    "/api/v1/auth/password/reset-request",
    "/api/v1/auth/password/reset",
    # --- Public lookups and demo entry points ---
    "/api/v1/nominatim/search",
    "/api/v1/plans",
    "/api/v1/demo/accounts",
    "/api/v1/demo/sessions",
    # --- Webhooks ---
    # Stripe webhook bypasses auth so the signature can be verified instead
    "/api/v1/webhooks/stripe",
    "/api/v1/webhooks/generic",
    # --- Frontend telemetry (traces, metrics, health) - no auth for performance ---
    "/api/v1/telemetry/v1/traces",
    "/api/v1/telemetry/v1/metrics",
    "/api/v1/telemetry/health",
]

# Routes accessible with a demo session (no JWT required, just the demo
# session header). All tenant endpoints are accessible in demo mode.
DEMO_ACCESSIBLE_ROUTES = [
    "/api/v1/tenants/",
]
class AuthMiddleware(BaseHTTPMiddleware):
"""
Enhanced Authentication Middleware with Tenant Access Control
"""
def __init__(self, app, redis_client=None):
    """
    Create the middleware.

    Args:
        app: The ASGI application wrapped by this middleware.
        redis_client: Optional Redis client. When absent, the helpers of
            this class that read ``self.redis_client`` skip their Redis
            work.
    """
    super().__init__(app)
    # Read by the token cache and token-freshness helpers of this class.
    self.redis_client = redis_client
async def dispatch(self, request: Request, call_next) -> Response:
    """
    Authenticate the request and authorize tenant access where applicable.

    Flow:
      1. ``OPTIONS`` requests and public routes are forwarded untouched
         (incoming sensitive headers are sanitized before the public-route
         check).
      2. A request to ``/api/v1/events`` carrying ``demo_session_id`` in the
         query string is validated against the demo-session service.
      3. A request already marked as a demo session *and* carrying a user
         context gets context headers injected and is forwarded without a
         JWT.
      4. Otherwise a JWT is required: it is extracted, verified, checked
         against the tenant found in the URL path, stored on
         ``request.state`` and injected as context headers.

    Args:
        request: Incoming request.
        call_next: Callable forwarding the request to the next handler.

    Returns:
        The downstream response, or a ``JSONResponse`` with status 401, 403
        or 503 when the request is rejected by this middleware.
    """
    # Skip authentication for OPTIONS requests (CORS preflight)
    if request.method == "OPTIONS":
        return await call_next(request)

    # SECURITY: Remove any incoming sensitive headers using HeaderManager
    header_manager.sanitize_incoming_headers(request)

    # Skip authentication for public routes
    if self._is_public_route(request.url.path):
        return await call_next(request)

    # Demo context may come from the demo middleware (request.state) or,
    # for the SSE endpoint, from a query parameter.
    demo_session_header = request.headers.get("X-Demo-Session-Id")
    demo_session_query = request.query_params.get("demo_session_id")  # For SSE endpoint
    logger.info(f"Auth check - path: {request.url.path}, demo_header: {demo_session_header}, demo_query: {demo_session_query}, has_demo_state: {hasattr(request.state, 'is_demo_session')}")

    # For SSE endpoint with demo_session_id in query params, validate it here
    if (
        request.url.path == "/api/v1/events"
        and demo_session_query
        and not hasattr(request.state, "is_demo_session")
    ):
        rejection = await self._authenticate_sse_demo_session(request, demo_session_query)
        if rejection is not None:
            return rejection

    # Demo session with a user context: no JWT needed, but downstream
    # services still need the context headers.
    if getattr(request.state, "is_demo_session", False) and getattr(request.state, "user", None):
        logger.info(f"✅ Demo session authenticated for route: {request.url.path}")
        user_context = request.state.user
        tenant_id = user_context.get("tenant_id") or getattr(request.state, "tenant_id", None)
        # Ask the tenant service for the tier only when the demo context
        # does not already carry one.
        if not user_context.get("subscription_tier"):
            subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
            # Fallback to enterprise for demo if no tier is found
            user_context["subscription_tier"] = subscription_tier or "enterprise"
            logger.debug(f"Demo session subscription tier set to {user_context['subscription_tier']}", tenant_id=tenant_id)
        await self._inject_context_headers(request, user_context, tenant_id)
        return await call_next(request)

    # STEP 1: Extract the JWT
    token = self._extract_token(request)
    if not token:
        logger.warning(f"❌ Missing token for protected route: {request.url.path}, demo_header: {demo_session_header}")
        return JSONResponse(
            status_code=401,
            content={"detail": "Authentication required"}
        )

    # STEP 2: Verify the token and build the user context
    user_context = await self._verify_token(token, request)
    if not user_context:
        logger.warning(f"Invalid token for route: {request.url.path}")
        return JSONResponse(
            status_code=401,
            content={"detail": "User not authenticated"}
        )

    # STEP 3: Extract tenant context from the URL using the shared utility
    tenant_id = extract_tenant_id_from_path(request.url.path)

    # STEP 4: Verify tenant access if this is a tenant-scoped route
    if tenant_id and is_tenant_scoped_path(request.url.path):
        denial = await self._authorize_tenant_access(request, user_context, tenant_id)
        if denial is not None:
            return denial

    # STEP 5: Expose the user context to the rest of the gateway
    request.state.user = user_context
    request.state.authenticated = True

    # STEP 6: Add context headers for downstream services
    await self._inject_context_headers(request, user_context, tenant_id)
    logger.debug("Authenticated request",
                 user_email=user_context.get("email"),
                 tenant_id=tenant_id,
                 path=request.url.path)

    response = await call_next(request)

    # token_near_expiry is set by _verify_token; hint the client to refresh.
    if getattr(request.state, "token_near_expiry", False):
        response.headers["X-Token-Refresh-Suggested"] = "true"
    return response

async def _authenticate_sse_demo_session(self, request: Request, demo_session_id: str) -> Optional[JSONResponse]:
    """
    Validate a demo session id received in the SSE query string.

    Args:
        request: Request whose state receives the demo context on success.
        demo_session_id: Client-supplied ``demo_session_id`` query value.

    Returns:
        ``None`` when the session is valid (demo user and tenant context
        are then set on ``request.state``); otherwise the 401/503 response
        to send to the client.
    """
    from urllib.parse import quote

    logger.info(f"SSE endpoint with demo_session_id query param: {demo_session_id}")
    try:
        # Service token for gateway-to-demo-session communication
        service_token = jwt_handler.create_service_token(service_name="gateway")
        # SECURITY: the id is client-controlled and this call carries a
        # service token. Percent-encode it so reserved characters such as
        # "/", "?" and "#" cannot change the path or query of the internal
        # request.
        session_url = (
            "http://demo-session-service:8000/api/v1/demo/sessions/"
            + quote(demo_session_id, safe="")
        )
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(
                session_url,
                headers={"Authorization": f"Bearer {service_token}"}
            )
        if response.status_code != 200:
            logger.warning(f"Invalid demo session for SSE: {demo_session_id}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid demo session"}
            )
        session_data = response.json()
        virtual_tenant_id = session_data.get("virtual_tenant_id")
        # Set demo session context
        request.state.is_demo_session = True
        request.state.user = {
            "user_id": f"demo-user-{demo_session_id}",
            "email": f"demo-{demo_session_id}@demo.local",
            "tenant_id": virtual_tenant_id,
            "demo_session_id": demo_session_id,
        }
        request.state.tenant_id = virtual_tenant_id
        logger.info(f"✅ Demo session validated for SSE: {demo_session_id}")
        return None
    except Exception as e:
        # Any failure talking to (or parsing the reply of) the demo-session
        # service is reported as "service unavailable".
        logger.error(f"Failed to validate demo session for SSE: {e}")
        return JSONResponse(
            status_code=503,
            content={"detail": "Demo session service unavailable"}
        )

async def _authorize_tenant_access(self, request: Request, user_context: Dict[str, Any], tenant_id: str) -> Optional[JSONResponse]:
    """
    Verify that the authenticated caller may act on ``tenant_id``.

    Side effects on success: ``user_context["subscription_tier"]`` may be
    filled in, and ``request.state`` receives ``tenant_id``,
    ``tenant_verified``, ``tenant_access_type`` and ``can_view_children``.

    Returns:
        ``None`` when access is granted, otherwise the 403 response.
    """
    # Skip tenant access verification for service tokens
    if user_context.get("type") != "service":
        # Share this middleware's Redis client with the access manager if
        # it has none yet (used for caching the verification).
        if self.redis_client and tenant_access_manager.redis_client is None:
            tenant_access_manager.redis_client = self.redis_client
        has_access = await tenant_access_manager.verify_basic_tenant_access(
            user_context["user_id"],
            tenant_id
        )
        if not has_access:
            logger.warning(f"User {user_context.get('email')} denied access to tenant {tenant_id}")
            return JSONResponse(
                status_code=403,
                content={"detail": f"Access denied to tenant {tenant_id}"}
            )
    else:
        logger.debug(f"Service token granted access to tenant {tenant_id}",
                     service=user_context.get("service"))

    # Subscription tier: prefer the value carried by the JWT, otherwise
    # fetch it over HTTP (only needed for old tokens).
    if user_context.get("subscription_from_jwt"):
        subscription_tier = user_context.get("subscription_tier")
        logger.debug("Using subscription tier from JWT", tier=subscription_tier)
    else:
        subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
        if subscription_tier:
            user_context["subscription_tier"] = subscription_tier

    # Hierarchical access determines the access type and child visibility
    hierarchical_access = await tenant_access_manager.verify_hierarchical_access(
        user_context["user_id"],
        tenant_id
    )

    # Set tenant context in request state
    request.state.tenant_id = tenant_id
    request.state.tenant_verified = True
    request.state.tenant_access_type = hierarchical_access.get("access_type", "direct")
    request.state.can_view_children = hierarchical_access.get("can_view_children", False)
    logger.debug("Tenant access verified",
                 user_id=user_context["user_id"],
                 tenant_id=tenant_id,
                 subscription_tier=subscription_tier,
                 access_type=hierarchical_access.get("access_type"),
                 can_view_children=hierarchical_access.get("can_view_children"),
                 path=request.url.path)
    return None
def _is_public_route(self, path: str) -> bool:
    """
    Return True when ``path`` is public, i.e. may skip authentication.

    A path is public when it equals an entry of ``PUBLIC_ROUTES`` or lies
    beneath one. Matching respects path-segment boundaries: ``/health``
    covers ``/health`` and ``/health/live`` but not ``/healthcheck-admin``.
    A bare ``startswith`` would let any route that merely shares leading
    characters with a public entry bypass authentication.

    Entries that end with ``/`` (e.g. ``/api/v1/registration/state/``) act
    as pure prefixes for the paths beneath them.

    Args:
        path: URL path of the request (no query string).

    Returns:
        True if the request must bypass authentication, False otherwise.
    """
    for route in PUBLIC_ROUTES:
        # Prefix that guarantees the match ends on a "/" boundary.
        subtree_prefix = route if route.endswith("/") else route + "/"
        if path == route or path.startswith(subtree_prefix):
            return True
    return False
def _extract_token(self, request: Request) -> Optional[str]:
    """
    Pull the JWT out of the request.

    * ``/api/v1/events`` (SSE): the token is read from the ``token`` query
      parameter, because the browser EventSource API cannot send custom
      headers.
    * Every other route: the token must arrive as
      ``Authorization: Bearer <token>`` (more secure).

    Security note: query-parameter tokens end up in logs. Use short expiry
    and filter logs.

    Returns:
        The token string, or None when no usable token is present.
    """
    if request.url.path != "/api/v1/events":
        # Standard authentication: Authorization header
        header_value = request.headers.get("Authorization")
        if not header_value or not header_value.startswith("Bearer "):
            return None
        return header_value.split(" ")[1]

    # SSE endpoint exception: token in query param
    sse_token = request.query_params.get("token")
    if not sse_token:
        logger.warning("SSE request missing token in query param")
        return None
    logger.debug("Token extracted from query param for SSE endpoint")
    return sse_token
async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]:
    """
    Resolve a bearer token to a user context.

    Strategies, tried in order:
      1. Local JWT validation (fast, no network).
      2. Redis cache of recently validated tokens.
      3. The auth service (authoritative); successful results are cached.

    A token that validates locally but was issued before the tenant's last
    subscription change is rejected outright by returning ``None``. The
    caller answers 401 and the client is expected to obtain a new token.
    Previously the rejection was raised as an ``HTTPException`` inside a
    ``try`` whose ``except Exception`` swallowed it, so stale tokens were
    always accepted.

    Args:
        token: Raw JWT string.
        request: Current request, if any. Used to run the freshness check
            and to flag ``request.state.token_near_expiry``.

    Returns:
        The user context dict, or ``None`` when the token is rejected.
    """
    import time

    # Strategy 1: Try local JWT validation first (fast)
    try:
        payload = jwt_handler.verify_token(token)
        if payload and self._validate_token_payload(payload):
            logger.debug("Token validated locally")

            # Check token freshness against subscription changes
            if payload.get("tenant_id") and request:
                is_fresh = True
                try:
                    is_fresh = await self._verify_token_freshness(payload, payload["tenant_id"])
                except Exception as e:
                    # Fail open for availability: an error while checking
                    # must not lock users out.
                    logger.warning("Token freshness check failed, allowing token", error=str(e))
                if not is_fresh:
                    logger.warning("Stale token detected - subscription changed since token was issued",
                                   user_id=payload.get("user_id"),
                                   tenant_id=payload.get("tenant_id"))
                    # Do not fall through to the cache / auth service: they
                    # would accept this still-signed token again.
                    return None

            # Flag tokens expiring within 5 minutes so the response can
            # carry a refresh hint.
            if request:
                time_until_expiry = payload.get("exp", 0) - time.time()
                if time_until_expiry < 300:  # 5 minutes
                    request.state.token_near_expiry = True

            # Convert JWT payload to user context format. A payload that
            # fails the integrity check raises here and is handled below.
            return self._jwt_payload_to_user_context(payload)
    except Exception as e:
        logger.debug(f"Local token validation failed: {e}")

    # Strategy 2: Check cache for recently validated tokens
    if self.redis_client:
        try:
            cached_user = await self._get_cached_user(token)
            if cached_user:
                logger.debug("Token found in cache")
                return cached_user
        except Exception as e:
            logger.warning(f"Cache lookup failed: {e}")

    # Strategy 3: Verify with auth service (authoritative)
    try:
        user_context = await self._verify_with_auth_service(token)
        if user_context:
            # Cache successful validation
            if self.redis_client:
                await self._cache_user(token, user_context)
            logger.debug("Token validated by auth service")
            return user_context
    except Exception as e:
        logger.error(f"Auth service validation failed: {e}")

    return None
def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
"""
Validate JWT payload has required fields
FIXED: Updated to match actual token structure from auth service
"""
required_fields = ["user_id", "email", "exp", "type"]
missing_fields = [field for field in required_fields if field not in payload]
if missing_fields:
logger.warning(f"Token payload missing fields: {missing_fields}")
return False
# Validate token type
token_type = payload.get("type")
if token_type not in ["access", "service"]:
logger.warning(f"Invalid token type: {payload.get('type')}")
return False
# Check if token is near expiry (within 5 minutes) and log warning
import time
exp_time = payload.get("exp", 0)
current_time = time.time()
time_until_expiry = exp_time - current_time
if time_until_expiry < 300: # 5 minutes
logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}")
# NEW: Check token freshness for subscription changes
if payload.get("tenant_id"):
try:
# Note: We can't await here because this is a sync function
# Token freshness will be checked in the async dispatch method
# For now, just log that we would check freshness
logger.debug("Token freshness check would be performed in async context",
tenant_id=payload.get("tenant_id"))
except Exception as e:
logger.warning("Token freshness check setup failed", error=str(e))
# FIX: Validate service tokens with tenant context for tenant-scoped routes
if token_type == "service" and payload.get("tenant_id"):
# Service tokens with tenant context are valid for tenant-scoped operations
logger.debug("Service token with tenant context validated",
service=payload.get("service"), tenant_id=payload.get("tenant_id"))
return True
def _validate_jwt_integrity(self, payload: Dict[str, Any]) -> bool:
"""
Validate JWT payload integrity beyond signature verification.
Prevents edge cases where payload might be malformed.
"""
# Required fields must exist
required_fields = ["user_id", "email", "exp", "iat", "iss"]
if not all(field in payload for field in required_fields):
logger.warning("JWT missing required fields", missing=[f for f in required_fields if f not in payload])
return False
# Issuer must be our auth service
if payload.get("iss") != "bakery-auth":
logger.warning("JWT has invalid issuer", issuer=payload.get("iss"))
return False
# Token type must be valid
if payload.get("type") not in ["access", "service"]:
logger.warning("JWT has invalid type", token_type=payload.get("type"))
return False
# Subscription tier must be valid if present
valid_tiers = ["starter", "professional", "enterprise"]
if payload.get("subscription"):
tier = payload["subscription"].get("tier", "").lower()
if tier and tier not in valid_tiers:
logger.warning("JWT has invalid subscription tier", tier=tier)
return False
return True
async def _verify_token_freshness(self, payload: Dict[str, Any], tenant_id: str) -> bool:
    """
    Tell whether the token was issued after the tenant's last subscription change.

    The change timestamp is read from the Redis key
    ``tenant:{tenant_id}:subscription_changed_at`` and compared with the
    token's ``iat`` claim (0 when absent).

    Args:
        payload: Decoded JWT claims.
        tenant_id: Tenant whose subscription change timestamp is consulted.

    Returns:
        False only when a change timestamp exists and the token predates
        it. True in every other case, including: no Redis client, no
        timestamp stored, or any error during the lookup.
    """
    redis = self.redis_client
    if not redis:
        return True  # Skip check if no Redis

    try:
        raw_changed_at = await redis.get(f"tenant:{tenant_id}:subscription_changed_at")
        if not raw_changed_at:
            return True
        changed_at = float(raw_changed_at)
        issued_at = payload.get("iat", 0)
        if issued_at < changed_at:
            logger.warning(
                "Token issued before subscription change",
                token_iat=issued_at,
                subscription_changed=changed_at,
                tenant_id=tenant_id
            )
            return False  # Token is stale
    except Exception as exc:
        logger.warning("Failed to check token freshness", error=str(exc))
    return True
def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]:
    """
    Map a decoded JWT payload onto the gateway's user-context dict.

    The context always carries ``user_id``, ``email``, ``exp``, ``valid``
    and ``role``. Tenant, subscription and tenant-access data are copied
    when the payload has them. For service tokens (payload has
    ``service``) the identity fields are replaced by a synthetic
    ``<service>-service`` identity with the ``admin`` role.

    Args:
        payload: Decoded JWT claims.

    Returns:
        The user context dict.

    Raises:
        HTTPException: 401 when the payload fails ``_validate_jwt_integrity``.
    """
    # Validate JWT integrity before processing
    if not self._validate_jwt_integrity(payload):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid JWT payload"
        )

    context: Dict[str, Any] = {
        "user_id": payload["user_id"],
        "email": payload["email"],
        "exp": payload["exp"],
        "valid": True,
        "role": payload.get("role", "user"),
    }

    # Tenant context
    tenant_id = payload.get("tenant_id")
    if tenant_id:
        context.update(
            tenant_id=tenant_id,
            tenant_role=payload.get("tenant_role", "member"),
        )

    # Subscription carried by the JWT; the flag lets callers skip the
    # HTTP lookup of the tier.
    subscription = payload.get("subscription")
    if subscription:
        context.update(
            subscription_tier=subscription.get("tier", "starter"),
            subscription_status=subscription.get("status", "active"),
            subscription_from_jwt=True,
        )

    tenant_access = payload.get("tenant_access")
    if tenant_access:
        context["tenant_access"] = tenant_access

    # Service-to-service tokens
    service_name = payload.get("service")
    if service_name:
        context.update(
            service=service_name,
            type="service",
            role="admin",
            user_id=f"{service_name}-service",
            email=f"{service_name}-service@internal",
        )
        # A tenant-bound service token keeps the tenant_id copied above.
        if tenant_id:
            logger.debug(f"Service authentication with tenant context: {service_name}, tenant_id: {payload['tenant_id']}")
        else:
            logger.debug(f"Service authentication: {service_name}")

    return context
async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
    """
    Ask the auth service whether ``token`` is valid.

    Sends ``POST {AUTH_SERVICE_URL}/api/v1/auth/verify`` with the token as
    bearer credential (5 second timeout).

    Args:
        token: Raw JWT string.

    Returns:
        ``{"user_id", "email", "exp", "valid": True}`` when the service
        answers 200 with a truthy ``valid`` and ``user_id``; ``None`` for
        any other answer, on timeout, or on any other error.
    """
    verify_url = f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify"
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.post(
                verify_url,
                headers={"Authorization": f"Bearer {token}"}
            )
            if response.status_code != 200:
                logger.warning(f"Auth service returned {response.status_code}: {response.text}")
                return None

            auth_response = response.json()
            # Validate auth service response structure
            if not (auth_response.get("valid") and auth_response.get("user_id")):
                logger.warning(f"Auth service returned invalid response: {auth_response}")
                return None

            return {
                "user_id": auth_response["user_id"],
                "email": auth_response["email"],
                "exp": auth_response.get("exp"),
                "valid": True
            }
    except httpx.TimeoutException:
        logger.error("Auth service timeout during token verification")
        return None
    except Exception as exc:
        logger.error(f"Auth service error: {exc}")
        return None
async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
"""
Get user context from cache
FIXED: Better error handling and JSON parsing
"""
if not self.redis_client:
return None
cache_key = f"auth:token:{hash(token) % 1000000}" # Use modulo for shorter keys
try:
cached_data = await self.redis_client.get(cache_key)
if cached_data:
if isinstance(cached_data, bytes):
cached_data = cached_data.decode()
return json.loads(cached_data)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse cached user data: {e}")
except Exception as e:
logger.warning(f"Cache lookup error: {e}")
return None
async def _cache_user(self, token: str, user_context: Dict[str, Any]) -> None:
"""
Cache user context
FIXED: Better error handling and expiration
"""
if not self.redis_client:
return
cache_key = f"auth:token:{hash(token) % 1000000}"
try:
# Cache for 5 minutes (shorter than token expiry)
await self.redis_client.setex(
cache_key,
300, # 5 minutes
json.dumps(user_context)
)
except Exception as e:
logger.warning(f"Failed to cache user context: {e}")
async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
    """
    Attach user and tenant context headers for downstream services.

    Header injection is delegated to the shared ``header_manager``. When a
    tenant id is given, the tenant service's ``/hierarchy`` endpoint is
    also queried (3 second timeout, forwarding the request's Authorization
    header) and, if it reports a parent tenant, the parent-tenant header
    is added too. Failures of that lookup are logged and ignored.

    Args:
        request: Request receiving the headers.
        user_context: Authenticated user context.
        tenant_id: Tenant the request targets, if any.

    Returns:
        Whatever ``header_manager.inject_context_headers`` returned.
    """
    # Use unified HeaderManager for consistent header injection
    injected_headers = header_manager.inject_context_headers(request, user_context, tenant_id)
    if not tenant_id:
        return injected_headers

    # Hierarchical access: include the parent tenant ID when there is one
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            response = await client.get(
                f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/hierarchy",
                headers={"Authorization": request.headers.get("Authorization", "")}
            )
            if response.status_code == 200:
                parent_tenant_id = response.json().get("parent_tenant_id")
                if parent_tenant_id:
                    # Add via HeaderManager for consistency with the other headers
                    header_manager.add_header_for_middleware(
                        request,
                        header_manager.STANDARD_HEADERS['parent_tenant_id'],
                        str(parent_tenant_id),
                    )
                    logger.info(f"Added parent tenant ID header: {parent_tenant_id}")
    except Exception as exc:
        # Best effort: the request proceeds without the parent header.
        logger.warning(f"Failed to get parent tenant ID: {exc}")

    return injected_headers
async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
    """
    Fetch the tenant's subscription tier from the tenant service.

    Calls ``GET {TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier``
    (3 second timeout), forwarding the request's Authorization header.

    Args:
        tenant_id: Tenant ID.
        request: Current request; its Authorization header is forwarded.

    Returns:
        The ``tier`` value of a 200 answer (``"starter"`` when the answer
        has no ``tier``). Also ``"starter"`` for any non-200 answer or on
        any error.
    """
    tier_url = f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier"
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            forwarded_headers = {"Authorization": request.headers.get("Authorization", "")}
            response = await client.get(tier_url, headers=forwarded_headers)
            if response.status_code != 200:
                logger.warning(f"Failed to get tenant subscription tier: {response.status_code}")
                return "starter"  # Default to starter

            tier_data = response.json()
            tier = tier_data.get("tier", "starter")
            logger.debug("Subscription tier from cached endpoint",
                         tenant_id=tenant_id,
                         tier=tier,
                         cached=tier_data.get("cached", False))
            return tier
    except Exception as exc:
        logger.error(f"Error getting tenant subscription tier: {exc}")
        return "starter"  # Default to starter on error

View File

@@ -0,0 +1,384 @@
"""
Demo Session Middleware
Handles demo account restrictions and virtual tenant injection
"""
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from typing import Optional
import uuid
import httpx
import structlog
import json
logger = structlog.get_logger()
# Fixed demo tenant IDs: the template tenants that get cloned for sessions.

# Professional demo (merged from San Pablo + La Espiga)
DEMO_TENANT_PROFESSIONAL = uuid.UUID("a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6")

# Enterprise chain demo: one parent tenant plus three child tenants
DEMO_TENANT_ENTERPRISE_CHAIN = uuid.UUID("c3d4e5f6-a7b8-49c0-d1e2-f3a4b5c6d7e8")
DEMO_TENANT_CHILD_1 = uuid.UUID("d4e5f6a7-b8c9-40d1-e2f3-a4b5c6d7e8f9")
DEMO_TENANT_CHILD_2 = uuid.UUID("e5f6a7b8-c9d0-41e2-f3a4-b5c6d7e8f9a0")
DEMO_TENANT_CHILD_3 = uuid.UUID("f6a7b8c9-d0e1-42f3-a4b5-c6d7e8f9a0b1")

# String forms of every base template tenant ID, for membership tests
# against tenant IDs received as text.
DEMO_TENANT_IDS = {
    str(template_id)
    for template_id in (
        DEMO_TENANT_PROFESSIONAL,
        DEMO_TENANT_ENTERPRISE_CHAIN,
        DEMO_TENANT_CHILD_1,
        DEMO_TENANT_CHILD_2,
        DEMO_TENANT_CHILD_3,
    )
}

# Demo account type -> user UUID. Per the original note these are the owner
# IDs found in the fixture files:
#   professional -> professional/01-tenant.json owner.id (María García López)
#   enterprise   -> enterprise/parent/01-tenant.json owner.id (Director)
DEMO_USER_IDS = dict(
    professional="c1a2b3c4-d5e6-47a8-b9c0-d1e2f3a4b5c6",
    enterprise="d2e3f4a5-b6c7-48d9-e0f1-a2b3c4d5e6f7",
)
# Allowed operations for demo accounts (limited write).
# Maps an HTTP method to the path patterns a demo session may call with it;
# "*" alone allows every path. A method missing from the mapping has no
# allowed paths.
DEMO_ALLOWED_OPERATIONS = {
    # Read-only methods - all allowed. HEAD and OPTIONS are listed next to
    # GET so they are not rejected for lack of an entry; the template-tenant
    # branch of this middleware likewise treats GET/HEAD/OPTIONS as reads.
    "GET": ["*"],
    "HEAD": ["*"],
    "OPTIONS": ["*"],
    # Limited write operations for realistic testing
    "POST": [
        "/api/v1/pos/sales",
        "/api/v1/pos/sessions",
        "/api/v1/orders",
        "/api/v1/inventory/adjustments",
        "/api/v1/sales",
        "/api/v1/production/batches",
        "/api/v1/tenants/batch/sales-summary",
        "/api/v1/tenants/batch/production-summary",
        "/api/v1/auth/me/onboarding/complete",  # Allow completing onboarding (no-op for demos)
        "/api/v1/tenants/*/notifications/send",  # Allow notifications (ML insights, alerts, etc.)
        # Note: Forecast generation is explicitly blocked (see DEMO_BLOCKED_PATHS)
    ],
    "PUT": [
        "/api/v1/pos/sales/*",
        "/api/v1/orders/*",
        "/api/v1/inventory/stock/*",
        "/api/v1/auth/me/onboarding/step",  # Allow onboarding step updates (no-op for demos)
    ],
    # Blocked operations
    "DELETE": [],  # No deletes allowed
    "PATCH": [],  # No patches allowed
}
# Paths demo accounts may never call, even when the HTTP method would be
# allowed. Per the original note, these endpoints require trained AI
# models, which demo tenants don't have.
DEMO_BLOCKED_PATHS = [
    "/api/v1/forecasts/single",
    "/api/v1/forecasts/multi-day",
    "/api/v1/forecasts/batch",
]

# User-facing explanations (Spanish and English) per blocked-path category.
DEMO_BLOCKED_PATH_MESSAGE = {
    "forecasts": {
        "message": " ".join((
            "La generación de pronósticos no está disponible para cuentas demo.",
            "Las cuentas demo no tienen modelos de IA entrenados.",
        )),
        "message_en": " ".join((
            "Forecast generation is not available for demo accounts.",
            "Demo accounts do not have trained AI models.",
        )),
    },
}
class DemoMiddleware(BaseHTTPMiddleware):
"""Middleware to handle demo session logic with Redis caching"""
def __init__(self, app, demo_session_url: str = "http://demo-session-service:8000"):
    """
    Create the middleware.

    Args:
        app: The ASGI application wrapped by this middleware.
        demo_session_url: Base URL of the demo-session service.
    """
    super().__init__(app)
    # Redis client slot, filled on first use by _get_redis_client().
    self._redis_client = None
    self.demo_session_url = demo_session_url
async def _get_redis_client(self):
    """
    Return the Redis client, creating it on first use.

    ``self._redis_client`` has three states: ``None`` (not tried yet), a
    client object, or ``False`` (initialization failed; caching stays
    disabled and no retry is attempted).

    Returns:
        The Redis client, or ``None`` when caching is unavailable.
    """
    if self._redis_client is None:
        try:
            from shared.redis_utils import get_redis_client
            self._redis_client = await get_redis_client()
            logger.debug("Demo middleware: Redis client initialized")
        except Exception as exc:
            logger.warning(f"Demo middleware: Failed to get Redis client: {exc}. Caching disabled.")
            # Sentinel value to avoid retrying
            self._redis_client = False

    current = self._redis_client
    if current is False:
        return None
    return current
async def _get_cached_session(self, session_id: str) -> Optional[dict]:
    """
    Read session info from the Redis cache.

    Args:
        session_id: Demo session ID; the cache key is
            ``demo_session:{session_id}``.

    Returns:
        The decoded session info, or ``None`` on a cache miss, when Redis
        is unavailable, or when reading/decoding fails.
    """
    try:
        redis = await self._get_redis_client()
        if not redis:
            return None

        raw_entry = await redis.get(f"demo_session:{session_id}")
        if not raw_entry:
            logger.debug("Demo middleware: Cache MISS", session_id=session_id)
            return None

        logger.debug("Demo middleware: Cache HIT", session_id=session_id)
        return json.loads(raw_entry)
    except Exception as exc:
        logger.warning(f"Demo middleware: Redis cache read error: {exc}")
        return None
async def _cache_session(self, session_id: str, session_info: dict, ttl: int = 30):
    """
    Store session info in the Redis cache.

    Args:
        session_id: Demo session ID; the cache key is
            ``demo_session:{session_id}``.
        session_info: Session data; serialized with ``json.dumps``.
        ttl: Entry lifetime in seconds.

    Nothing is written when Redis is unavailable; write errors are logged
    and otherwise ignored.
    """
    try:
        redis = await self._get_redis_client()
        if redis:
            await redis.setex(f"demo_session:{session_id}", ttl, json.dumps(session_info))
            logger.debug(f"Demo middleware: Cached session {session_id} (TTL: {ttl}s)")
    except Exception as exc:
        logger.warning(f"Demo middleware: Redis cache write error: {exc}")
async def dispatch(self, request: Request, call_next) -> Response:
    """
    Apply demo-session handling to a request.

    * Requests to the demo service's own endpoints pass straight through.
    * With a session ID (``X-Demo-Session-Id`` header or
      ``demo_session_id`` cookie): the session is looked up (Redis cache
      first, then the demo service). A usable session gets the virtual
      tenant and a demo user context written into the request state, then
      the blocked-path and allowed-operation rules are enforced (403).
      An unusable session, or any error, yields 401.
    * Without a session ID but with ``X-Tenant-Id`` naming a template demo
      tenant: only GET/HEAD/OPTIONS are forwarded, anything else gets 403.

    Responses of demo requests carry ``X-Demo-Session: true``.
    """
    # Skip demo middleware for demo service endpoints
    passthrough_prefixes = (
        "/api/v1/demo/accounts",
        "/api/v1/demo/sessions",
        "/api/v1/demo/stats",
        "/api/v1/demo/operations",
    )
    if request.url.path.startswith(passthrough_prefixes):
        return await call_next(request)

    # Session ID from header or cookie; tenant ID from header
    session_id = request.headers.get("X-Demo-Session-Id") or request.cookies.get("demo_session_id")
    logger.info(f"🎭 DemoMiddleware - path: {request.url.path}, session_id: {session_id}")
    tenant_id = request.headers.get("X-Tenant-Id")

    if session_id:
        try:
            # PERFORMANCE: check the Redis cache before making an HTTP call
            session_info = await self._get_cached_session(session_id)
            if not session_info:
                logger.debug("Demo middleware: Fetching from demo service", session_id=session_id)
                session_info = await self._get_session_info(session_id)
                # 30s TTL to balance freshness vs performance
                if session_info:
                    await self._cache_session(session_id, session_info, ttl=30)

            # Accepted statuses: pending, ready, partial, failed (usable if
            # some services succeeded) and active (deprecated).
            accepted_statuses = ("pending", "ready", "partial", "failed", "active")
            current_status = session_info.get("status") if session_info else None
            if not (session_info and current_status in accepted_statuses):
                # Session missing, expired, or in an unusable state
                logger.warning("Invalid demo session state", session_id=session_id, status=current_status)
                return JSONResponse(
                    status_code=401,
                    content={
                        "error": "session_expired",
                        "message": "Tu sesión demo ha expirado. Crea una nueva sesión para continuar.",
                        "message_en": "Your demo session has expired. Create a new session to continue.",
                        "session_status": current_status
                    }
                )

            # Inject the virtual tenant. The scope's state dict is used
            # directly to avoid potential state property issues.
            request.scope.setdefault("state", {})
            state = request.scope["state"]
            state["tenant_id"] = session_info["virtual_tenant_id"]
            state["is_demo_session"] = True
            state["demo_account_type"] = session_info["demo_account_type"]
            state["demo_session_status"] = current_status  # Track status for monitoring

            # Demo user context; lets the request pass through AuthMiddleware.
            # Unknown account types fall back to the professional demo user.
            account_type = session_info.get("demo_account_type", "professional")
            demo_user_id = DEMO_USER_IDS.get(account_type, DEMO_USER_IDS["professional"])
            if account_type == "enterprise":
                subscription_tier = "enterprise"
            else:
                subscription_tier = "professional"
            state["user"] = {
                "user_id": demo_user_id,  # Actual demo user UUID
                "email": f"demo-{session_id}@demo.local",
                "tenant_id": session_info["virtual_tenant_id"],
                "role": "owner",  # Demo users have owner role
                "is_demo": True,
                "demo_session_id": session_id,
                "demo_account_type": account_type,
                "demo_session_status": current_status,
                # Subscription context, so no HTTP call is needed later
                "subscription_tier": subscription_tier,
                "subscription_status": "active",
                "subscription_from_jwt": True  # Flag to skip HTTP calls
            }

            # Update activity
            await self._update_session_activity(session_id)

            # Explicitly blocked paths win over the per-method allow list
            blocked_reason = self._check_blocked_path(request.url.path)
            if blocked_reason:
                return JSONResponse(
                    status_code=403,
                    content={
                        "error": "demo_restriction",
                        **blocked_reason,
                        "upgrade_url": "/pricing",
                        "session_expires_at": session_info.get("expires_at")
                    }
                )

            # Check if operation is allowed
            if not self._is_operation_allowed(request.method, request.url.path):
                return JSONResponse(
                    status_code=403,
                    content={
                        "error": "demo_restriction",
                        "message": "Esta operación no está permitida en cuentas demo. "
                                   "Las sesiones demo se eliminan automáticamente después de 30 minutos. "
                                   "Suscríbete para obtener acceso completo.",
                        "message_en": "This operation is not allowed in demo accounts. "
                                      "Demo sessions are automatically deleted after 30 minutes. "
                                      "Subscribe for full access.",
                        "upgrade_url": "/pricing",
                        "session_expires_at": session_info.get("expires_at")
                    }
                )
        except Exception as exc:
            logger.error("Demo middleware error", error=str(exc), session_id=session_id, path=request.url.path)
            # On error, return 401 instead of continuing
            return JSONResponse(
                status_code=401,
                content={
                    "error": "session_error",
                    "message": "Error validando sesión demo. Por favor, inténtalo de nuevo.",
                    "message_en": "Error validating demo session. Please try again."
                }
            )
    elif tenant_id in DEMO_TENANT_IDS:
        # Direct access to a template demo tenant without a session: block writes
        request.scope.setdefault("state", {})
        state = request.scope["state"]
        state["is_demo_session"] = True
        state["tenant_id"] = tenant_id
        if request.method not in ("GET", "HEAD", "OPTIONS"):
            return JSONResponse(
                status_code=403,
                content={
                    "error": "demo_restriction",
                    "message": "Acceso directo al tenant demo no permitido. Crea una sesión demo.",
                    "message_en": "Direct access to demo tenant not allowed. Create a demo session."
                }
            )

    # Proceed with request
    response = await call_next(request)

    # Mark responses that belong to a demo session
    if getattr(request.state, "is_demo_session", False):
        response.headers["X-Demo-Session"] = "true"
    return response
async def _get_session_info(self, session_id: str) -> Optional[dict]:
    """Fetch a demo session record from the demo service.

    A service JWT is minted for the gateway and sent as a bearer token.

    Args:
        session_id: Identifier of the demo session to look up.

    Returns:
        The decoded JSON body on HTTP 200; ``None`` on any other status or
        on any exception (both cases are logged).
    """
    try:
        # Imported lazily, exactly as the surrounding module does.
        from shared.auth.jwt_handler import JWTHandler
        from app.core.config import settings

        token_factory = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
        bearer = token_factory.create_service_token(service_name="gateway")

        async with httpx.AsyncClient(timeout=5.0) as http:
            reply = await http.get(
                f"{self.demo_session_url}/api/v1/demo/sessions/{session_id}",
                headers={"Authorization": f"Bearer {bearer}"},
            )
            if reply.status_code != 200:
                # Only the first 200 characters of the body are logged.
                logger.warning(
                    "Demo session fetch failed",
                    session_id=session_id,
                    status_code=reply.status_code,
                    response_text=getattr(reply, "text", "")[:200],
                )
                return None
            return reply.json()
    except Exception as exc:
        logger.error("Failed to get session info", session_id=session_id, error=str(exc))
        return None
async def _update_session_activity(self, session_id: str):
"""Update session activity timestamp"""
# Note: Activity tracking is handled by the demo service internally
# No explicit endpoint needed - activity is updated on session access
pass
def _check_blocked_path(self, path: str) -> Optional[dict]:
    """Return the denial message for a path blocked in demo accounts.

    The first entry of ``DEMO_BLOCKED_PATHS`` that occurs anywhere inside
    ``path`` (substring test) decides the outcome.

    Returns:
        ``None`` when no blocked entry matches; the "forecasts" message when
        the matching entry contains "forecast"; otherwise a generic message.
    """
    hit = next(
        (candidate for candidate in DEMO_BLOCKED_PATHS if candidate in path),
        None,
    )
    if hit is None:
        return None

    # Category-specific messages; more categories can be added here.
    if "forecast" in hit:
        return DEMO_BLOCKED_PATH_MESSAGE["forecasts"]

    return {
        "message": "Esta funcionalidad no está disponible para cuentas demo.",
        "message_en": "This functionality is not available for demo accounts."
    }
def _is_operation_allowed(self, method: str, path: str) -> bool:
    """Tell whether a demo account may perform ``method`` on ``path``.

    The rules for ``method`` come from ``DEMO_ALLOWED_OPERATIONS``:
    a bare ``"*"`` entry allows every path, an entry ending in ``"*"`` is a
    prefix rule (``/api/orders/*`` matches ``/api/orders/123``), and any
    other entry must equal the path exactly.
    """
    rules = DEMO_ALLOWED_OPERATIONS.get(method, [])

    if "*" in rules:
        return True

    return any(
        path.startswith(rule[:-1]) if rule.endswith("*") else path == rule
        for rule in rules
    )

View File

@@ -0,0 +1,57 @@
"""
Logging middleware for gateway
"""
import logging
import time
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
import uuid
logger = logging.getLogger(__name__)
class LoggingMiddleware(BaseHTTPMiddleware):
    """Log every request on arrival and every response on completion.

    Each request carries a request ID, stored on ``request.state.request_id``
    and attached to all log records emitted here so they can be correlated.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Log the request, run the downstream handler, then log the outcome.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, unmodified.

        Raises:
            Exception: Whatever the downstream handler raises; it is logged
                with its duration and then re-raised unchanged.
        """
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump when the wall clock is adjusted.
        started = time.perf_counter()

        # Reuse an ID already assigned earlier in the stack or supplied by
        # the caller; minting a fresh one here would break log correlation.
        request_id = (
            getattr(request.state, "request_id", None)
            or request.headers.get("X-Request-ID")
            or str(uuid.uuid4())
        )
        request.state.request_id = request_id

        method = request.method
        path = request.url.path

        logger.info(
            "Request: %s %s",
            method,
            path,
            extra={
                "method": method,
                "url": path,
                "query_params": str(request.query_params),
                "client_host": request.client.host if request.client else "unknown",
                "user_agent": request.headers.get("user-agent", ""),
                "request_id": request_id
            }
        )

        try:
            response = await call_next(request)
        except Exception:
            # Without this, a failing request would leave only the
            # "Request" line behind and no record of how long it ran.
            logger.exception(
                "Request failed: %s %s after %.3fs",
                method,
                path,
                time.perf_counter() - started,
                extra={
                    "method": method,
                    "url": path,
                    "request_id": request_id
                }
            )
            raise

        duration = time.perf_counter() - started
        logger.info(
            "Response: %s in %.3fs",
            response.status_code,
            duration,
            extra={
                "status_code": response.status_code,
                "duration": duration,
                "method": method,
                "url": path,
                "request_id": request_id
            }
        )
        return response

View File

@@ -0,0 +1,93 @@
"""
Rate limiting middleware for gateway
"""
import logging
import time
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from typing import Dict, Optional
import asyncio
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """In-memory sliding-window rate limiter (per process).

    Request timestamps are kept per client in ``self.requests``. A client is
    rejected with HTTP 429 once it has ``calls_per_minute`` requests inside
    the last 60 seconds. Clients that go idle are evicted so the table does
    not grow without bound.
    """

    # Length of the sliding window, in seconds.
    _WINDOW_SECONDS = 60

    def __init__(self, app, calls_per_minute: int = 60):
        """
        Args:
            app: The ASGI application to wrap.
            calls_per_minute: Maximum requests per client per 60 seconds.
        """
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        # client id -> timestamps (time.time()) of requests in the window
        self.requests: Dict[str, list] = {}
        # Retained for backward compatibility; no background task is used.
        self._cleanup_task = None
        # When the last full sweep of idle clients ran.
        self._last_sweep = time.time()

    async def dispatch(self, request: Request, call_next) -> Response:
        """Reject the request with 429 if the client is over its limit."""
        # Skip rate limiting for health checks
        if request.url.path in ["/health", "/metrics"]:
            return await call_next(request)

        client_id = self._get_client_id(request)

        if self._is_rate_limited(client_id):
            logger.warning("Rate limit exceeded for client: %s", client_id)
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded"}
            )

        self._record_request(client_id)
        return await call_next(request)

    def _get_client_id(self, request: Request) -> str:
        """Identify the caller: ``user:<id>`` if authenticated, else ``ip:<host>``."""
        # request.state.user is only present if something earlier set it.
        if hasattr(request.state, 'user') and request.state.user:
            return f"user:{request.state.user.get('user_id', 'unknown')}"
        return f"ip:{request.client.host if request.client else 'unknown'}"

    def _prune(self, client_id: str, now: float) -> list:
        """Drop timestamps older than the window; forget the client if none remain."""
        cutoff = now - self._WINDOW_SECONDS
        recent = [stamp for stamp in self.requests.get(client_id, ()) if stamp > cutoff]
        if recent:
            self.requests[client_id] = recent
        else:
            # An empty list carries no information; keeping the key would leak.
            self.requests.pop(client_id, None)
        return recent

    def _evict_idle_clients(self, now: float) -> None:
        """At most once per window, remove every client with no recent request."""
        if now - self._last_sweep < self._WINDOW_SECONDS:
            return
        self._last_sweep = now
        cutoff = now - self._WINDOW_SECONDS
        # Timestamps are appended in arrival order, so the last one is the newest.
        idle = [
            client_id
            for client_id, stamps in self.requests.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for client_id in idle:
            del self.requests[client_id]

    def _is_rate_limited(self, client_id: str) -> bool:
        """Return True when the client already used its quota for the window."""
        if client_id not in self.requests:
            return False
        recent = self._prune(client_id, time.time())
        return len(recent) >= self.calls_per_minute

    def _record_request(self, client_id: str):
        """Record one request for ``client_id`` at the current time."""
        now = time.time()
        self.requests.setdefault(client_id, []).append(now)
        # Keep only last minute of requests
        self._prune(client_id, now)
        # Clients that never come back are only reachable through a sweep.
        self._evict_idle_clients(now)

View File

@@ -0,0 +1,269 @@
"""
API Rate Limiting Middleware for Gateway
Enforces subscription-based API call quotas per hour
"""
import structlog
import shared.redis_utils
from datetime import datetime, timezone
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Optional
logger = structlog.get_logger()
class APIRateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware to enforce API rate limits based on subscription tier.

    Quota limits per hour:
    - Starter: 100 calls/hour
    - Professional: 1,000 calls/hour
    - Enterprise: 10,000 calls/hour

    Uses Redis to track API calls with hourly buckets. When Redis is not
    configured, or any part of the quota check fails, the request is allowed
    (fail open).
    """

    # Path prefixes that are never rate limited.
    _SKIP_PREFIXES = (
        "/health",
        "/metrics",
        "/docs",
        "/openapi.json",
        "/api/v1/auth/",
        "/api/v1/plans",  # Public pricing info
    )

    # API calls per hour for each subscription tier.
    _QUOTA_BY_TIER = {
        "starter": 100,
        "professional": 1000,
        "enterprise": 10000,
        "demo": 1000,  # Same as professional
    }

    # Quota applied to unknown or missing tiers.
    _DEFAULT_QUOTA = 100

    def __init__(self, app, redis_client=None):
        """
        Args:
            app: The ASGI application to wrap.
            redis_client: Async Redis client used for counters; ``None``
                disables quota enforcement.
        """
        super().__init__(app)
        self.redis_client = redis_client

    async def dispatch(self, request: Request, call_next):
        """
        Check API rate limit before processing request.

        Returns a 429 JSON response when the tenant's hourly quota is used
        up; otherwise forwards the request and adds ``X-RateLimit-*`` headers
        to the response.
        """
        if self._should_skip_rate_limit(request.url.path):
            return await call_next(request)

        tenant_id = self._extract_tenant_id(request)
        if not tenant_id:
            # No tenant ID - skip rate limiting for auth/public endpoints
            return await call_next(request)

        # Only the quota check is guarded. The downstream call must stay
        # outside this try: otherwise an exception raised by the application
        # would be treated as a rate-limit failure and the request would be
        # executed a second time by the fail-open path.
        try:
            # NOTE(review): the tier header is expected to be set by the auth
            # layer; confirm client-supplied values are stripped upstream.
            subscription_tier = request.headers.get("x-subscription-tier")
            if not subscription_tier:
                # Fallback: get from request state if headers not available
                subscription_tier = getattr(request.state, "subscription_tier", None)
            if not subscription_tier:
                # Final fallback: get from tenant service (should rarely happen)
                subscription_tier = await self._get_subscription_tier(tenant_id, request)
                logger.warning(f"Subscription tier not found in headers or state, fetched from tenant service: {subscription_tier}")

            quota_limit = self._get_quota_limit(subscription_tier)
            allowed, current_count = await self._check_and_increment_quota(
                tenant_id,
                quota_limit
            )
        except Exception as e:
            logger.error(
                "Rate limiting check failed, allowing request",
                tenant_id=tenant_id,
                error=str(e),
                path=request.url.path
            )
            # Fail open - allow request if rate limiting fails
            return await call_next(request)

        reset_time = self._get_reset_time()

        if not allowed:
            logger.warning(
                "API rate limit exceeded",
                tenant_id=tenant_id,
                subscription_tier=subscription_tier,
                current_count=current_count,
                quota_limit=quota_limit,
                path=request.url.path
            )
            # An HTTPException raised from middleware is not seen by the
            # application's exception handlers and surfaces as a 500, so the
            # 429 is returned directly. The body keeps the {"detail": ...}
            # envelope an HTTPException would have produced.
            from fastapi.responses import JSONResponse

            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": {
                        "error": "rate_limit_exceeded",
                        "message": f"API rate limit exceeded. Maximum {quota_limit} calls per hour allowed for {subscription_tier} plan.",
                        "current_count": current_count,
                        "quota_limit": quota_limit,
                        "reset_time": reset_time,
                        "upgrade_required": str(subscription_tier).lower() in ('starter', 'professional')
                    }
                },
                headers={
                    "X-RateLimit-Limit": str(quota_limit),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": reset_time,
                }
            )

        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(quota_limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, quota_limit - current_count))
        response.headers["X-RateLimit-Reset"] = reset_time
        return response

    def _should_skip_rate_limit(self, path: str) -> bool:
        """
        Determine if path should skip rate limiting.
        """
        return path.startswith(self._SKIP_PREFIXES)

    def _extract_tenant_id(self, request: Request) -> Optional[str]:
        """
        Extract tenant ID from request headers or path.

        The ``x-tenant-id`` header wins; otherwise the path segment that
        follows the first ``tenants`` segment is used. Returns ``None`` when
        neither is available.
        """
        tenant_id = request.headers.get("x-tenant-id")
        if tenant_id:
            return tenant_id

        # Try to extract from path /api/v1/tenants/{tenant_id}/...
        segments = request.url.path.split("/")
        if "tenants" in segments:
            position = segments.index("tenants") + 1
            if position < len(segments):
                return segments[position]
        return None

    async def _get_subscription_tier(self, tenant_id: str, request: Request) -> str:
        """
        Get subscription tier from request state or the tenant service.

        Always returns a tier name; every failure defaults to ``"starter"``.
        """
        try:
            # A tier already resolved earlier in the stack wins. A falsy
            # value is ignored so it cannot reach the quota lookup.
            state_tier = getattr(request.state, "subscription_tier", None)
            if state_tier:
                return state_tier

            import httpx
            # Other gateway modules import settings as ``app.core.config``;
            # accept either layout so the lookup does not silently fail.
            try:
                from gateway.app.core.config import settings
            except ImportError:
                from app.core.config import settings

            async with httpx.AsyncClient(timeout=2.0) as client:
                response = await client.get(
                    f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
                    headers={
                        "x-service": "gateway"
                    }
                )
            if response.status_code == 200:
                return response.json().get("tier") or "starter"
            logger.warning(
                "Tenant service did not return a subscription tier, defaulting to starter",
                tenant_id=tenant_id,
                status_code=response.status_code
            )
        except Exception as e:
            logger.warning(
                "Failed to get subscription tier, defaulting to starter",
                tenant_id=tenant_id,
                error=str(e)
            )
        return "starter"

    def _get_quota_limit(self, subscription_tier: str) -> int:
        """
        Get API calls per hour quota for subscription tier.

        The lookup is case-insensitive; unknown, empty or ``None`` tiers get
        the default (starter-level) quota.
        """
        tier_key = str(subscription_tier or "").lower()
        return self._QUOTA_BY_TIER.get(tier_key, self._DEFAULT_QUOTA)

    async def _check_and_increment_quota(
        self,
        tenant_id: str,
        quota_limit: int
    ) -> tuple[bool, int]:
        """
        Check current quota usage and increment counter.

        Returns:
            (allowed: bool, current_count: int). ``(True, 0)`` is returned
            when Redis is unavailable or raises (fail open).
        """
        if not self.redis_client:
            # No Redis - fail open
            return True, 0

        try:
            # One counter per tenant per UTC hour.
            current_hour = datetime.now(timezone.utc).strftime("%Y-%m-%d-%H")
            quota_key = f"quota:hourly:api_calls:{tenant_id}:{current_hour}"

            # NOTE(review): GET followed by INCR is not atomic, so concurrent
            # requests near the limit can slightly overshoot the quota.
            stored = await self.redis_client.get(quota_key)
            current_count = int(stored) if stored else 0

            if current_count >= quota_limit:
                return False, current_count

            new_count = await self.redis_client.incr(quota_key)
            # Set expiry (1 hour + 5 minutes buffer)
            await self.redis_client.expire(quota_key, 3900)
            return True, new_count
        except Exception as e:
            logger.error(
                "Redis quota check failed",
                tenant_id=tenant_id,
                error=str(e)
            )
            # Fail open
            return True, 0

    def _get_reset_time(self) -> str:
        """
        Get the reset time for the current hour bucket (top of next hour),
        as an ISO 8601 UTC timestamp.
        """
        from datetime import timedelta

        now = datetime.now(timezone.utc)
        next_hour = (now + timedelta(hours=1)).replace(minute=0, second=0, microsecond=0)
        return next_hour.isoformat()
async def get_rate_limit_middleware(app):
    """
    Factory function to create rate limiting middleware with Redis client.

    Returns an ``APIRateLimitMiddleware`` wrapping ``app``. If settings
    cannot be loaded or Redis cannot be initialised, the middleware is
    created without a Redis client and therefore fails open.
    """
    try:
        # Other gateway modules import settings as ``app.core.config``. If
        # only that layout is importable, the original single import would
        # raise here and silently disable rate limiting.
        try:
            from gateway.app.core.config import settings
        except ImportError:
            from app.core.config import settings

        redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        logger.info("API rate limiting middleware initialized with Redis")
        return APIRateLimitMiddleware(app, redis_client=redis_client)
    except Exception as e:
        logger.warning(
            "Failed to initialize Redis for rate limiting, middleware will fail open",
            error=str(e)
        )
        return APIRateLimitMiddleware(app, redis_client=None)

View File

@@ -0,0 +1,149 @@
"""
Gateway middleware to enforce read-only mode for subscriptions with status:
- pending_cancellation (until cancellation_effective_date)
- inactive (after cancellation or no active subscription)
Allowed operations in read-only mode:
- GET requests (all read operations)
- POST /api/v1/users/me/delete/request (account deletion)
- POST /api/v1/subscriptions/reactivate (subscription reactivation)
- POST /api/v1/subscriptions/* (subscription management)
"""
import httpx
import logging
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Optional
import re
logger = logging.getLogger(__name__)
# Whitelist of POST/PUT/DELETE endpoints allowed in read-only mode
# Each entry is a regular expression tested against the request path with
# re.match, i.e. anchored at the start of the path. Entries without a
# trailing `$` therefore also match any longer path with that prefix.
# NOTE(review): the module docstring mentions /api/v1/subscriptions/... paths;
# none of the patterns below match those - confirm which form is intended.
READ_ONLY_WHITELIST_PATTERNS = [
r'^/api/v1/users/me/delete/request$',
r'^/api/v1/users/me/export.*$',
r'^/api/v1/tenants/.*/subscription/.*', # All tenant subscription endpoints
r'^/api/v1/registration/.*', # Registration flow endpoints
r'^/api/v1/auth/.*', # Allow auth operations
r'^/api/v1/tenants/register$', # Allow new tenant registration (no existing tenant context)
r'^/api/v1/tenants/.*/orchestrator/run-daily-workflow$', # Allow workflow testing
r'^/api/v1/tenants/.*/inventory/ml/insights/.*', # Allow ML insights (safety stock optimization)
r'^/api/v1/tenants/.*/production/ml/insights/.*', # Allow ML insights (yield prediction)
r'^/api/v1/tenants/.*/procurement/ml/insights/.*', # Allow ML insights (supplier analysis, price forecasting)
r'^/api/v1/tenants/.*/forecasting/ml/insights/.*', # Allow ML insights (rules generation)
r'^/api/v1/tenants/.*/forecasting/operations/.*', # Allow forecasting operations
r'^/api/v1/webhooks/.*', # Webhook endpoints - no tenant context
]
class ReadOnlyModeMiddleware(BaseHTTPMiddleware):
    """
    Blocks write requests for tenants whose subscription is read-only.

    The tenant service is asked for the subscription status on every
    non-whitelisted write request that carries both a tenant ID and an
    Authorization header.
    """

    def __init__(self, app, tenant_service_url: str = "http://tenant-service:8000"):
        super().__init__(app)
        self.tenant_service_url = tenant_service_url
        # Declared but not used by any method of this class.
        self.cache = {}
        self.cache_ttl = 60

    async def check_subscription_status(self, tenant_id: str, authorization: str) -> dict:
        """
        Ask the tenant service for the tenant's subscription status.

        Returns the service's JSON on 200, an inactive/read-only record on
        404, and an "unknown", not-read-only record on any other status or
        on any exception (so failures never block the request).
        """
        try:
            async with httpx.AsyncClient(timeout=5.0) as http:
                reply = await http.get(
                    f"{self.tenant_service_url}/api/v1/tenants/{tenant_id}/subscription/status",
                    headers={"Authorization": authorization}
                )
                code = reply.status_code
                if code == 200:
                    return reply.json()
                if code == 404:
                    return {"status": "inactive", "is_read_only": True}

                logger.warning(
                    f"Failed to check subscription status: {reply.status_code}",
                    extra={"tenant_id": tenant_id}
                )
                return {"status": "unknown", "is_read_only": False}
        except Exception as e:
            logger.error(
                f"Error checking subscription status: {e}",
                extra={"tenant_id": tenant_id}
            )
            return {"status": "unknown", "is_read_only": False}

    def is_whitelisted_endpoint(self, path: str) -> bool:
        """
        Return True when ``path`` matches any read-only whitelist pattern.
        """
        return any(
            re.match(allowed, path) for allowed in READ_ONLY_WHITELIST_PATTERNS
        )

    def is_write_operation(self, method: str) -> bool:
        """
        Return True for POST, PUT, DELETE and PATCH (case-insensitive).
        """
        return method.upper() in {'POST', 'PUT', 'DELETE', 'PATCH'}

    async def dispatch(self, request: Request, call_next):
        """
        Forward the request unless it is a non-whitelisted write for a
        tenant in read-only mode, in which case a 403 is returned.
        """
        tenant_id = request.headers.get("X-Tenant-ID")
        authorization = request.headers.get("Authorization")
        path = request.url.path
        method = request.method

        # Nothing to enforce without tenant context and credentials.
        if not tenant_id or not authorization:
            return await call_next(request)

        # Reads (GET and every other non-write method) always pass.
        if not self.is_write_operation(method):
            return await call_next(request)

        if self.is_whitelisted_endpoint(path):
            return await call_next(request)

        subscription = await self.check_subscription_status(tenant_id, authorization)
        if not subscription.get("is_read_only", False):
            return await call_next(request)

        status_detail = subscription.get("status", "inactive")
        effective_date = subscription.get("cancellation_effective_date")

        # A cancellation date switches the user-facing message.
        if effective_date:
            user_message = f"Your subscription is pending cancellation. Read-only mode starts on {effective_date}."
        else:
            user_message = "Your subscription has been cancelled. You can view data but cannot make changes."

        error_message = {
            "detail": "Account is in read-only mode",
            "reason": f"Subscription status: {status_detail}",
            "message": user_message,
            "action_required": "Reactivate your subscription to regain full access",
            "reactivation_url": "/app/settings/subscription"
        }
        if effective_date:
            error_message["read_only_until"] = effective_date

        logger.info(
            "read_only_mode_enforced",
            extra={
                "tenant_id": tenant_id,
                "path": path,
                "method": method,
                "subscription_status": status_detail
            }
        )
        return JSONResponse(
            status_code=status.HTTP_403_FORBIDDEN,
            content=error_message
        )

View File

@@ -0,0 +1,83 @@
"""
Request ID Middleware for distributed tracing
Generates and propagates unique request IDs across all services
"""
import uuid
import structlog
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from app.core.header_manager import header_manager
logger = structlog.get_logger()
class RequestIDMiddleware(BaseHTTPMiddleware):
    """
    Assigns a request ID to every request for distributed tracing.

    The ID is taken from the incoming ``X-Request-ID`` header or generated,
    stored on ``request.state.request_id``, injected as ``x-request-id`` via
    the header manager, bound to the log context and echoed in the response
    ``X-Request-ID`` header.
    """

    def __init__(self, app):
        super().__init__(app)

    async def dispatch(self, request: Request, call_next) -> Response:
        """Tag the request with an ID, log start/end, and re-raise failures."""
        # Caller-supplied ID wins; otherwise mint a UUID4.
        request_id = request.headers.get("X-Request-ID") or str(uuid.uuid4())

        # Make the ID available to route handlers.
        request.state.request_id = request_id

        bound_log = logger.bind(request_id=request_id)

        # Runs early in the middleware chain, hence the middleware-specific
        # HeaderManager call for downstream propagation.
        header_manager.add_header_for_middleware(request, "x-request-id", request_id)

        http_method = request.method
        url_path = request.url.path
        peer = request.client.host if request.client else None

        bound_log.info(
            "Request started",
            method=http_method,
            path=url_path,
            client_ip=peer
        )

        try:
            response = await call_next(request)
            response.headers["X-Request-ID"] = request_id
            bound_log.info(
                "Request completed",
                method=request.method,
                path=request.url.path,
                status_code=response.status_code
            )
            return response
        except Exception as exc:
            bound_log.error(
                "Request failed",
                method=request.method,
                path=request.url.path,
                error=str(exc),
                error_type=type(exc).__name__
            )
            raise

View File

@@ -0,0 +1,462 @@
"""
Subscription Middleware - Enforces subscription limits and feature access
Updated to support standardized URL structure with tier-based access control
"""
import re
import json
import structlog
from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import httpx
from typing import Dict, Any, Optional, List
import asyncio
from datetime import datetime, timezone
from app.core.config import settings
from app.core.header_manager import header_manager
from app.utils.subscription_error_responses import create_upgrade_required_response
logger = structlog.get_logger()
class SubscriptionMiddleware(BaseHTTPMiddleware):
"""
Middleware to enforce subscription-based access control
Supports standardized URL structure:
- Base routes (/api/v1/tenants/{tenant_id}/{service}/{resource}): ALL tiers
- Dashboard routes (/api/v1/tenants/{tenant_id}/{service}/dashboard/*): ALL tiers
- Analytics routes (/api/v1/tenants/{tenant_id}/{service}/analytics/*): PROFESSIONAL+
- Operations routes (/api/v1/tenants/{tenant_id}/{service}/operations/*): ALL tiers (role-based)
"""
def __init__(self, app, tenant_service_url: str, redis_client=None):
    """Configure the tier requirements for protected route patterns.

    Args:
        app: The ASGI application to wrap.
        tenant_service_url: Base URL of the tenant service; a trailing
            slash is removed.
        redis_client: Optional Redis client for abuse detection.
    """
    super().__init__(app)
    self.tenant_service_url = tenant_service_url.rstrip('/')
    self.redis_client = redis_client  # Optional Redis client for abuse detection

    # Route patterns (regex, tested with re.match) that require subscription
    # validation, using the standardized URL structure.
    self.protected_routes = {
        # ===== ANALYTICS ROUTES - PROFESSIONAL/ENTERPRISE ONLY =====
        # Any service analytics endpoint
        r'^/api/v1/tenants/[^/]+/[^/]+/analytics/.*': {
            'feature': 'analytics',
            'minimum_tier': 'professional',
            'allowed_tiers': ['professional', 'enterprise'],
            'description': 'Analytics features (Professional/Enterprise only)'
        },
        # ===== TRAINING SERVICE - ALL TIERS =====
        # The entry tier is called 'starter' everywhere else in this
        # middleware; with only 'basic' listed, starter tenants were
        # rejected on training paths not covered by the public patterns.
        # 'basic' is kept so tenants carrying that label still pass.
        r'^/api/v1/tenants/[^/]+/training/.*': {
            'feature': 'ml_training',
            'minimum_tier': 'starter',
            'allowed_tiers': ['starter', 'basic', 'professional', 'enterprise'],
            'description': 'Machine learning model training (Available for all tiers)'
        },
        # ===== ADVANCED FEATURES - PROFESSIONAL/ENTERPRISE =====
        # Advanced reporting and exports
        r'^/api/v1/tenants/[^/]+/[^/]+/export/advanced.*': {
            'feature': 'advanced_exports',
            'minimum_tier': 'professional',
            'allowed_tiers': ['professional', 'enterprise'],
            'description': 'Advanced export formats (Professional/Enterprise only)'
        },
        # Bulk operations
        r'^/api/v1/tenants/[^/]+/[^/]+/bulk/.*': {
            'feature': 'bulk_operations',
            'minimum_tier': 'professional',
            'allowed_tiers': ['professional', 'enterprise'],
            'description': 'Bulk operations (Professional/Enterprise only)'
        },
    }

    # Routes that are explicitly allowed for all tiers (no check needed)
    self.public_tier_routes = [
        # Base CRUD operations - ALL TIERS
        r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/?$',
        r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/[^/]+/?$',
        # Dashboard routes - ALL TIERS
        r'^/api/v1/tenants/[^/]+/[^/]+/dashboard/.*',
        # Operations routes - ALL TIERS (role-based control applies)
        r'^/api/v1/tenants/[^/]+/[^/]+/operations/.*',
    ]
async def dispatch(self, request: Request, call_next):
    """Enforce the subscription tier required by the requested route.

    Requests pass straight through when the route is skipped, open to all
    tiers, or has no requirement. Otherwise the tenant's tier is validated:
    a missing tenant ID yields a 400, an insufficient tier yields the
    upgrade-required response.
    """
    route = request.url.path

    # Exempt routes and routes open to every tier need no validation.
    if self._should_skip_subscription_check(request) or self._is_public_tier_route(route):
        return await call_next(request)

    requirement = self._get_subscription_requirement(route)
    if not requirement:
        return await call_next(request)

    tenant_id = self._extract_tenant_id(request)
    if not tenant_id:
        return JSONResponse(
            status_code=400,
            content={
                "error": "subscription_validation_failed",
                "message": "Tenant ID required for subscription validation",
                "code": "MISSING_TENANT_ID"
            }
        )

    outcome = await self._validate_subscription_tier(
        request,
        tenant_id,
        requirement.get('feature'),
        requirement.get('minimum_tier'),
        requirement.get('allowed_tiers', [])
    )

    if outcome['allowed']:
        # Subscription validation passed, continue with request
        return await call_next(request)

    # Build the conversion-optimized "upgrade required" payload.
    upgrade = create_upgrade_required_response(
        feature=requirement.get('feature'),
        current_tier=outcome.get('current_tier', 'unknown'),
        required_tier=requirement.get('minimum_tier'),
        allowed_tiers=requirement.get('allowed_tiers', []),
        custom_message=outcome.get('message')
    )
    return JSONResponse(
        status_code=upgrade.status_code,
        content=upgrade.dict()
    )
def _is_public_tier_route(self, path: str) -> bool:
"""
Check if route is explicitly allowed for all subscription tiers
Args:
path: Request path
Returns:
True if route is allowed for all tiers
"""
for pattern in self.public_tier_routes:
if re.match(pattern, path):
logger.debug("Route allowed for all tiers", path=path, pattern=pattern)
return True
return False
def _should_skip_subscription_check(self, request: Request) -> bool:
"""Check if subscription validation should be skipped"""
path = request.url.path
method = request.method
# Skip for health checks, auth, and public routes
skip_patterns = [
r'/health.*',
r'/metrics.*',
r'/api/v1/auth/.*',
r'/api/v1/tenants/[^/]+/subscription/.*', # All tenant subscription endpoints
r'/api/v1/registration/.*', # Registration flow endpoints
r'/api/v1/tenants/[^/]+/members.*', # Basic tenant info
r'/api/v1/webhooks/.*', # Webhook endpoints - no tenant context
r'/docs.*',
r'/openapi\.json',
# Training monitoring endpoints (WebSocket and status checks)
r'/api/v1/tenants/[^/]+/training/jobs/.*/live.*', # WebSocket endpoint
r'/api/v1/tenants/[^/]+/training/jobs/.*/status.*', # Status polling endpoint
]
# Skip OPTIONS requests (CORS preflight)
if method == "OPTIONS":
return True
for pattern in skip_patterns:
if re.match(pattern, path):
return True
return False
def _get_subscription_requirement(self, path: str) -> Optional[Dict[str, str]]:
"""Get subscription requirement for a given path"""
for pattern, requirement in self.protected_routes.items():
if re.match(pattern, path):
return requirement
return None
def _extract_tenant_id(self, request: Request) -> Optional[str]:
"""Extract tenant ID from request"""
# Try to get from URL path first
path_match = re.search(r'/api/v1/tenants/([^/]+)/', request.url.path)
if path_match:
return path_match.group(1)
# Try to get from headers
tenant_id = request.headers.get('x-tenant-id')
if tenant_id:
return tenant_id
# Try to get from user state (set by auth middleware)
if hasattr(request.state, 'user') and request.state.user:
return request.state.user.get('tenant_id')
return None
async def _validate_subscription_tier(
    self,
    request: Request,
    tenant_id: str,
    feature: Optional[str],
    minimum_tier: str,
    allowed_tiers: List[str]
) -> Dict[str, Any]:
    """
    Validate subscription tier access.

    The tier is taken from the JWT-derived user context when available
    (no HTTP call); otherwise it is fetched from the tenant service's tier
    endpoint. Tier names are compared case-insensitively. Every decision,
    granted or denied, is passed to ``_log_subscription_access``.

    Args:
        request: FastAPI request
        tenant_id: Tenant ID
        feature: Feature name (optional, used for logging)
        minimum_tier: Minimum required subscription tier (logged only)
        allowed_tiers: List of allowed subscription tiers

    Returns:
        Dict with 'allowed', 'message' and 'current_tier'. When the tier
        cannot be determined the request is allowed (fail open) with
        'current_tier' set to 'unknown'.
    """

    def fail_open(message: str) -> Dict[str, Any]:
        # 'current_plan' is the key earlier fail-open results used; it is
        # kept next to 'current_tier' for anything still reading it.
        return {
            'allowed': True,
            'message': message,
            'current_tier': 'unknown',
            'current_plan': 'unknown'
        }

    def denied(current_tier: str) -> Dict[str, Any]:
        tier_names = ', '.join(allowed_tiers)
        return {
            'allowed': False,
            'message': f'This feature requires a {tier_names} subscription plan',
            'current_tier': current_tier
        }

    try:
        normalized_allowed = [tier.lower() for tier in allowed_tiers]

        # ----- Fast path: tier already present in the JWT user context -----
        user_context = getattr(request.state, 'user', None)
        if user_context and user_context.get("subscription_from_jwt"):
            user_id = user_context.get('user_id', 'unknown')
            # Lower-cased like the fallback path, so both paths reach the
            # same decision for the same tier; a missing/None value is
            # treated as 'starter'.
            current_tier = str(user_context.get("subscription_tier") or "starter").lower()
            logger.debug("Using subscription tier from JWT (no HTTP call)",
                         tenant_id=tenant_id,
                         current_tier=current_tier,
                         minimum_tier=minimum_tier,
                         allowed_tiers=allowed_tiers)

            granted = current_tier in normalized_allowed
            # Denials are logged too; that is what feeds abuse detection.
            await self._log_subscription_access(
                tenant_id,
                user_id,
                feature,
                current_tier,
                granted,
                "jwt"
            )
            if not granted:
                return denied(current_tier)
            return {
                'allowed': True,
                'message': 'Access granted (JWT subscription)',
                'current_tier': current_tier
            }

        # ----- Fallback: ask the tenant service (e.g. for old tokens) -----
        headers = header_manager.get_all_headers_for_proxy(request)
        user_id = header_manager.get_header_value(request, 'x-user-id', 'unknown')

        timeout_config = httpx.Timeout(
            connect=1.0,  # Connection timeout - very short for cached endpoint
            read=5.0,     # Read timeout - short for cached lookup
            write=1.0,    # Write timeout
            pool=1.0      # Pool timeout
        )
        async with httpx.AsyncClient(timeout=timeout_config) as client:
            tier_response = await client.get(
                f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
                headers=headers
            )

        if tier_response.status_code != 200:
            logger.warning(
                "Failed to get subscription tier from cache",
                tenant_id=tenant_id,
                status_code=tier_response.status_code,
                response_text=tier_response.text
            )
            # Fail open for availability
            return fail_open('Access granted (validation service unavailable)')

        tier_data = tier_response.json()
        current_tier = str(tier_data.get('tier') or 'starter').lower()
        logger.debug("Subscription tier check (cached)",
                     tenant_id=tenant_id,
                     current_tier=current_tier,
                     minimum_tier=minimum_tier,
                     allowed_tiers=allowed_tiers,
                     cached=tier_data.get('cached', False))

        granted = current_tier in normalized_allowed
        # The tier came from the tenant service on this path, so both
        # outcomes are recorded with the "database" source.
        await self._log_subscription_access(
            tenant_id,
            user_id,
            feature,
            current_tier,
            granted,
            "database"
        )
        if not granted:
            return denied(current_tier)
        return {
            'allowed': True,
            'message': 'Access granted',
            'current_tier': current_tier
        }
    except asyncio.TimeoutError:
        logger.error(
            "Timeout validating subscription",
            tenant_id=tenant_id,
            feature=feature
        )
        # Fail open for availability (let service handle detailed check)
        return fail_open('Access granted (validation timeout)')
    except httpx.RequestError as e:
        logger.error(
            "Request error validating subscription",
            tenant_id=tenant_id,
            feature=feature,
            error=str(e)
        )
        # Fail open for availability
        return fail_open('Access granted (validation service unavailable)')
    except Exception as e:
        logger.error(
            "Subscription validation error",
            tenant_id=tenant_id,
            feature=feature,
            error=str(e)
        )
        # Fail open for availability (let service handle detailed check)
        return fail_open('Access granted (validation error)')
async def _log_subscription_access(
    self,
    tenant_id: str,
    user_id: str,
    requested_feature: str,
    current_tier: str,
    access_granted: bool,
    source: str  # "jwt" or "database"
):
    """
    Emit an audit log entry for a subscription-gated access decision.

    Denied decisions are additionally handed to the abuse-pattern check.
    """
    audit_fields = {
        "tenant_id": tenant_id,
        "user_id": user_id,
        "feature": requested_feature,
        "tier": current_tier,
        "granted": access_granted,
        "source": source,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    logger.info("Subscription access check", **audit_fields)

    if access_granted:
        return

    # Only denials are examined for suspicious patterns.
    await self._check_for_abuse_patterns(tenant_id, user_id, requested_feature)
async def _check_for_abuse_patterns(
self,
tenant_id: str,
user_id: str,
feature: str
):
"""
Detect potential abuse patterns like repeated premium feature access attempts.
"""
if not self.redis_client:
return
# Track denied attempts in a sliding window
key = f"subscription_denied:{tenant_id}:{user_id}:{feature}"
try:
attempts = await self.redis_client.incr(key)
if attempts == 1:
await self.redis_client.expire(key, 3600) # 1 hour window
# Alert if too many denied attempts (potential bypass attempt)
if attempts > 10:
logger.warning(
"SECURITY: Excessive premium feature access attempts detected",
tenant_id=tenant_id,
user_id=user_id,
feature=feature,
attempts=attempts,
window="1 hour"
)
# Could trigger alert to security team here
except Exception as e:
logger.warning("Failed to track abuse patterns", error=str(e))