Add improvements 2
This commit is contained in:
120
DOCKER_MAINTENANCE.md
Normal file
@@ -0,0 +1,120 @@
# Docker Maintenance Guide for Local Development

## The Problem

When developing with Tilt and local Kubernetes (Kind), Docker accumulates:
- **Multiple image versions** from each code change (Tilt rebuilds)
- **Unused volumes** from previous cluster runs
- **Build cache** that grows over time

This quickly fills up disk space, causing pods to fail with "No space left on device" errors.
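
Before pruning anything, it can help to confirm where the space is actually going. A quick check (the `bakery/*` image prefix is the one this repo's Tiltfile uses for its builds):

```bash
# Overall breakdown: images vs. containers vs. volumes vs. build cache
docker system df

# List the locally built service images and count them
docker images 'bakery/*' --format '{{.Size}}\t{{.Repository}}:{{.Tag}}'
docker images 'bakery/*' -q | wc -l
```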

## Quick Fix (When You Hit Disk Issues)

```bash
# Clean up all unused Docker resources
docker system prune -a --volumes -f
```

This removes:
- All unused images
- All unused volumes
- All build cache

**Expected recovery**: 60-100GB

## Regular Maintenance

### Option 1: Use the Cleanup Script (Recommended)

Run the maintenance script weekly:

```bash
./scripts/cleanup-docker.sh
```

Or run it automatically without confirmation:

```bash
./scripts/cleanup-docker.sh --auto
```
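
The script itself is not reproduced in this guide, so here is a minimal sketch of what `scripts/cleanup-docker.sh` is assumed to do; treating `--auto` as "skip the confirmation prompt" is an assumption based on the usage above, not a description of the actual repo script:

```bash
#!/usr/bin/env bash
# Minimal sketch of cleanup-docker.sh -- NOT the actual repo script.
# Assumption: --auto skips the interactive confirmation prompt.
set -euo pipefail

AUTO=false
[[ "${1:-}" == "--auto" ]] && AUTO=true

echo "Docker disk usage before cleanup:"
docker system df

if [[ "$AUTO" != true ]]; then
  read -r -p "Prune unused images, volumes and build cache? [y/N] " answer
  [[ "$answer" =~ ^[Yy]$ ]] || exit 0
fi

# Keep anything created in the last 24h so the current Tilt session survives
docker image prune -af --filter "until=24h"
docker volume prune -f
docker builder prune -af

echo "Docker disk usage after cleanup:"
docker system df
```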

### Option 2: Manual Commands

```bash
# Remove images older than 24 hours
docker image prune -af --filter "until=24h"

# Remove unused volumes
docker volume prune -f

# Remove build cache
docker builder prune -af
```

### Option 3: Set Up Automated Cleanup

Add to your crontab (run every Sunday at 2 AM):

```bash
crontab -e
# Add this line:
0 2 * * 0 /Users/urtzialfaro/Documents/bakery-ia/scripts/cleanup-docker.sh --auto >> /tmp/docker-cleanup.log 2>&1
```

## Monitoring Disk Usage

### Check Docker disk usage:

```bash
docker system df
```

### Check Kind node disk usage:

```bash
docker exec bakery-ia-local-control-plane df -h /var
```

### Alert thresholds:
- **< 70%**: Healthy ✅
- **70-85%**: Consider cleanup soon ⚠️
- **> 85%**: Run cleanup immediately 🚨
- **> 95%**: Critical - pods will fail ❌
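
To make those thresholds actionable, a small check along these lines can go into the cleanup script or a shell profile; the node name matches the Kind cluster used above, and parsing field 5 of `df -P` output is an assumption about the node's `df` implementation:

```bash
# Warn when the Kind node's /var usage crosses the cleanup thresholds
USAGE=$(docker exec bakery-ia-local-control-plane df -P /var | awk 'NR==2 {gsub("%",""); print $5}')

if   [ "$USAGE" -ge 95 ]; then echo "❌ ${USAGE}% - critical, pods will start failing"
elif [ "$USAGE" -ge 85 ]; then echo "🚨 ${USAGE}% - run cleanup now"
elif [ "$USAGE" -ge 70 ]; then echo "⚠️ ${USAGE}% - consider cleanup soon"
else echo "✅ ${USAGE}% - healthy"
fi
```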

## Prevention Tips

1. **Run cleanup weekly** to prevent accumulation
2. **Monitor disk usage** before long dev sessions
3. **Delete old Kind clusters** when switching projects:
   ```bash
   kind delete cluster --name bakery-ia-local
   ```
4. **Increase Docker disk allocation** in Docker Desktop settings if you frequently rebuild many services

## Troubleshooting

### Pods in CrashLoopBackOff after disk issues:

1. Run cleanup (see Quick Fix above)
2. Restart failed pods:
   ```bash
   kubectl get pods -n bakery-ia | grep -E "(CrashLoopBackOff|Error)" | awk '{print $1}' | xargs kubectl delete pod -n bakery-ia
   ```

### Cleanup didn't free enough space:

If disk usage is still above 90% after cleanup:

```bash
# Nuclear option - rebuild everything
kind delete cluster --name bakery-ia-local
docker system prune -a --volumes -f
# Then recreate the cluster with your setup scripts
```

## What Happened Today (2026-01-12)

- **Issue**: Disk was 100% full (113GB/113GB), causing database pods to crash
- **Root cause**: 122 unused Docker images + 16 unused volumes + 6GB of build cache
- **Solution**: Ran `docker system prune -a --volumes -f`
- **Result**: Freed 89GB; disk is now at 22% usage (24GB/113GB)
- **Outcome**: All services recovered successfully
@@ -1,185 +0,0 @@

# List of Proposed Improvements

## 1. New Section: Security and Cybersecurity (add after 5.2)

### 5.3. Security Architecture and European Regulatory Compliance

**Authentication and Authorization:**
- JWT with token rotation every 15 minutes
- Role-based access control (RBAC)
- Rate limiting (300 req/min per client)
- Multi-factor authentication (MFA) planned for Q1 2026

**Data Protection:**
- AES-256 encryption at rest
- Mandatory HTTPS with auto-renewing Let's Encrypt certificates
- Multi-tenant isolation at the database level
- SQL injection prevention via Pydantic schemas
- XSS protection and CORS

**Monitoring and Traceability:**
- OpenTelemetry for end-to-end distributed tracing
- SigNoz as the unified observability platform
- Centralized logs with trace correlation
- Full auditing of access and changes
- Real-time alerts (email and Slack)

**GDPR Compliance:**
- Privacy by design
- Right to be forgotten implemented
- Data export in CSV/JSON
- Consent management with history
- Anonymization of analytics data

**Secure Infrastructure:**
- Kubernetes with pod security policies
- Automatic security updates
- Daily encrypted backups
- Disaster recovery (DR) plan
- PostgreSQL 17 with hardened configuration
## 2. Update Section 4.4 (Competition) - Add Security Advantage

**Competitive Advantage in Security:**
- "Security-First" architecture vs. vulnerable legacy solutions
- GDPR compliance by design (competitors are retrofitting)
- Full observability (OpenTelemetry + SigNoz) vs. black boxes
- Planned security certifications (ISO 27001, ENS)
- Alignment with the NIS2 Directive (mandatory from 2024 for the food supply chain)

## 3. Update Section 8.1 (Funding) - New Grant Lines

### Add subsection: "European Cybersecurity Funding 2026-2027"

**European Programmes Identified (2026-2027):**

1. **Digital Europe Programme - Cybersecurity (€390M total)**
   - **UPTAKE Program**: €15M for SMEs on regulatory compliance
   - **AI-Based Cybersecurity**: €15M for AI-based security systems
   - Co-funding: up to 75% of project costs
   - Projects: €3-5M each, 36-month duration
   - **Eligibility**: Bakery-IA qualifies as a technology SME
   - **Estimated application**: €200,000 (Q1 2026)

2. **INCIBE EMPRENDE Program (España Digital 2026)**
   - Budget: €191M (2023-2026)
   - 34 partner entities in Spain
   - Ideation, incubation and acceleration in cybersecurity
   - Next Generation EU funds
   - **Estimated application**: €50,000 (Q2 2026)

3. **ENISA Emprendedoras Digitales**
   - Participative loans, up to €51M mobilizable
   - Specific lines for digital entrepreneurship
   - No personal guarantees, flexible terms
   - **Estimated application**: €75,000 (Q2 2026)

**Alignment with EU Priorities:**
- NIS2 Directive (food-sector security)
- Cyber Resilience Act (secure digital products)
- AI Act (transparent, auditable AI)
- GDPR (data protection by design)

## 4. Update Section 5.2 (Technical Architecture)

### Add details from the technical documentation:

**Microservices Architecture (21 independent services):**
- Centralized API Gateway with JWT and Redis caching (95% hit rate)
- Frontend: React 18 + TypeScript (mobile-first PWA)
- 18 PostgreSQL 17 databases (database-per-service pattern)
- Redis 7.4 for caching and RabbitMQ 4.1 for events
- Kubernetes on a VPS (horizontal scalability)

**Notable Technical Innovation:**
- **Enriched Alerts System** (3 levels: Alerts/Notifications/Recommendations)
- **Smart Prioritization** with 0-100 scoring (4 weighted factors)
- **Temporal Escalation** (+10 at 48h, +20 at 72h, +30 near the deadline)
- **Causal Chaining** (stock shortage → production delay → order at risk)
- **Deduplication** (95% reduction in alert spam)
- **SSE + WebSocket** for real-time updates
- **Prophet ML** with 20+ features (AEMET weather, Madrid traffic, public holidays)

**Enterprise-Grade Observability:**
- **SigNoz**: unified traces, metrics and logs
- **OpenTelemetry**: auto-instrumentation of 18 services
- **ClickHouse**: high-performance analytics backend
- **Alerting**: multi-channel (email, Slack) via AlertManager
- Monitoring: 18 PostgreSQL DBs, Redis, RabbitMQ, Kubernetes

## 5. Update Executive Summary (Section 0)

### Add bullet:

- **Security and Compliance:** Security-First architecture with GDPR compliance, full observability (OpenTelemetry/SigNoz), and alignment with European regulations (NIS2, Cyber Resilience Act). Eligible for the Digital Europe Programme's €390M cybersecurity envelope.

## 6. Update Section 9 (Decalogue) - Add Funding Opportunities

**6. Strategic Alignment with European Priorities 2026-2027:**
- **€390M available** in the Digital Europe Programme for cybersecurity
- **€191M** from the INCIBE EMPRENDE programme (España Digital 2026)
- Bakery-IA qualifies for **3 simultaneous funding lines**:
  * UPTAKE (€200K) - SME regulatory compliance
  * INCIBE EMPRENDE (€50K) - cybersecurity startup acceleration
  * ENISA Digital (€75K) - participative loan
- **Total potential**: €325,000 in additional non-dilutive funding
- Competitive advantage: Security-First vs. legacy competitors

## 7. New Operating Costs Table - Add Security Line Items

| **Security and Compliance** | | | |
| SSL certificate (Let's Encrypt) | Free | €0 | Automatic renewal |
| SigNoz observability (self-hosted) | Included in VPS | €0 | vs. €500+/month for SaaS |
| Annual GDPR audit | External | €1,200/year | Mandatory compliance |
| Encrypted backups (off-site) | Backblaze B2 | €5/month | €60/year |

## 8. Update Roadmap (Section 9) - Add Security Milestones

**Q1 2026:**
- ✅ MFA (Multi-Factor Authentication) implementation
- ✅ Digital Europe Programme application (UPTAKE)
- ✅ External GDPR audit

**Q2 2026:**
- 📋 ISO 27001 certification (start of process)
- 📋 NIS2 compliance implementation
- 📋 INCIBE EMPRENDE application

**Q3 2026:**
- 📋 External penetration testing
- 📋 ENS certification (Esquema Nacional de Seguridad)

## 9. Update Specific Request to VUE (Section 9.2)

### Add new point 5:

**5. Connection with European Cybersecurity Programmes:**
- Guidance for the Digital Europe Programme application (UPTAKE: €200K)
- Introduction to INCIBE EMPRENDE and its network of 34 partner entities
- Advice on preparing technical proposals for EU funds
- Contact with CDTI for the NEOTEC programme (R&D + cybersecurity)

## 10. Add Annex 7 - Compliance and Certifications

## ANNEX 7: COMPLIANCE AND CERTIFICATION ROADMAP

**Applicable Regulations:**
- ✅ GDPR (General Data Protection Regulation) - implemented
- 📋 NIS2 Directive (network and information systems security) - Q2 2026
- 📋 Cyber Resilience Act (secure digital products) - Q3 2026
- 📋 AI Act (AI transparency and auditing) - Q4 2026

**Planned Certifications:**
- 📋 ISO 27001 (Information Security Management) - 12-18 months
- 📋 ENS Medium (Esquema Nacional de Seguridad) - 6-9 months
- 📋 SOC 2 Type II (for Enterprise clients) - 18-24 months

**Estimated Compliance Investment (3 years):** €25,000
**Expected ROI:** Access to Enterprise clients (potential +€150K ARR)

## Summary of Quantitative Changes:
- New funding identified: €325,000 (vs. the original €18,000)
- New programmes: 3 European cybersecurity lines
- New sections: 2 (Security 5.3, Compliance Annex 7)
- Updates: 8 existing sections improved
- Competitive advantage: Security-First emphasized in 4 places
File diff suppressed because it is too large
38
Tiltfile
@@ -7,6 +7,13 @@
# - PostgreSQL pgcrypto extension and audit logging
# - Organized resource dependencies and live-reload capabilities
# - Local registry for faster image builds and deployments
#
# Build Optimization:
# - Services only rebuild when their specific code changes (not all services)
# - Shared folder changes trigger rebuild of ALL services (as they all depend on it)
# - Uses 'only' parameter to watch only relevant files per service
# - Frontend only rebuilds when frontend/ code changes
# - Gateway only rebuilds when gateway/ or shared/ code changes
# =============================================================================

# =============================================================================
@@ -197,16 +204,25 @@ k8s_yaml(kustomize('infrastructure/kubernetes/overlays/dev'))
# =============================================================================

# Helper function for Python services with live updates
# This function ensures services only rebuild when their specific code changes,
# but all services rebuild when shared/ folder changes
def build_python_service(service_name, service_path):
    docker_build(
        'bakery/' + service_name,
        context='.',
        dockerfile='./services/' + service_path + '/Dockerfile',
        # Only watch files relevant to this specific service + shared code
        only=[
            './services/' + service_path,
            './shared',
            './scripts',
        ],
        live_update=[
            # Fall back to full image build if Dockerfile or requirements change
            fall_back_on([
                './services/' + service_path + '/Dockerfile',
                './services/' + service_path + '/requirements.txt'
                './services/' + service_path + '/requirements.txt',
                './shared/requirements-tracing.txt',
            ]),

            # Sync service code
@@ -290,10 +306,21 @@ docker_build(
    'bakery/gateway',
    context='.',
    dockerfile='./gateway/Dockerfile',
    # Only watch gateway-specific files and shared code
    only=[
        './gateway',
        './shared',
        './scripts',
    ],
    live_update=[
        fall_back_on(['./gateway/Dockerfile', './gateway/requirements.txt']),
        fall_back_on([
            './gateway/Dockerfile',
            './gateway/requirements.txt',
            './shared/requirements-tracing.txt',
        ]),
        sync('./gateway', '/app'),
        sync('./shared', '/app/shared'),
        sync('./scripts', '/app/scripts'),
        run('kill -HUP 1', trigger=['./gateway/**/*.py', './shared/**/*.py']),
    ],
    ignore=[
@@ -680,6 +707,13 @@ Documentation:
  docs/SECURITY_IMPLEMENTATION_COMPLETE.md
  docs/DATABASE_SECURITY_ANALYSIS_REPORT.md

Build Optimization Active:
  ✅ Services only rebuild when their code changes
  ✅ Shared folder changes trigger ALL services (as expected)
  ✅ Reduces unnecessary rebuilds and disk usage
  💡 Edit service code: only that service rebuilds
  💡 Edit shared/ code: all services rebuild (required)

Useful Commands:
  # Work on specific services only
  tilt up <service-name> <service-name>

@@ -198,16 +198,27 @@ export const RegisterTenantStep: React.FC<RegisterTenantStepProps> = ({
      // Trigger POI detection in the background (non-blocking)
      // This replaces the removed POI Detection step
      // POI detection will be cached for 90 days and reused during training
      const bakeryLocation = wizardContext.state.bakeryLocation;
      if (bakeryLocation?.latitude && bakeryLocation?.longitude && tenant.id) {
        console.log(`🔍 Triggering background POI detection for tenant ${tenant.id}...`);

        // Run POI detection asynchronously without blocking the wizard flow
        // This ensures POI data is ready before the training step
        poiContextApi.detectPOIs(
          tenant.id,
          bakeryLocation.latitude,
          bakeryLocation.longitude,
          false // use_cache = false for initial detection
          false // force_refresh = false, will use cache if available
        ).then((result) => {
          console.log(`✅ POI detection completed automatically for tenant ${tenant.id}:`, result.summary);
          const source = result.source || 'unknown';
          console.log(`✅ POI detection completed for tenant ${tenant.id} (source: ${source})`);

          if (result.poi_context) {
            const totalPois = result.poi_context.total_pois_detected || 0;
            const relevantCategories = result.poi_context.relevant_categories?.length || 0;
            console.log(`📍 POI Summary: ${totalPois} POIs detected, ${relevantCategories} relevant categories`);
          }

          // Phase 3: Handle calendar suggestion if available
          if (result.calendar_suggestion) {
@@ -229,9 +240,12 @@ export const RegisterTenantStep: React.FC<RegisterTenantStepProps> = ({
            }
          }
        }).catch((error) => {
          console.warn('⚠️ Background POI detection failed (non-blocking):', error);
          // This is non-critical, so we don't block the user
          console.warn('⚠️ Background POI detection failed (non-blocking):', error);
          console.warn('Training will continue without POI features if detection is not complete.');
          // This is non-critical - training service will continue without POI features
        });
      } else {
        console.warn('⚠️ Cannot trigger POI detection: missing location data or tenant ID');
      }

      // Update the wizard context with tenant info

@@ -352,6 +352,25 @@ headers = {
- **Caching**: Gateway caches validated service tokens for 5 minutes
- **No Additional HTTP Calls**: Service auth happens locally at gateway

### Unified Header Management System

The gateway uses a **centralized HeaderManager** for consistent header handling across all middleware and proxy layers.

**Key Features:**
- Standardized header names and conventions
- Automatic header sanitization to prevent spoofing
- Unified header injection and forwarding
- Cross-middleware header access via `request.state.injected_headers`
- Consistent logging and error handling

**Standard Headers:**
- `x-user-id`, `x-user-email`, `x-user-role`, `x-user-type`
- `x-service-name`, `x-tenant-id`
- `x-subscription-tier`, `x-subscription-status`
- `x-is-demo`, `x-demo-session-id`, `x-demo-account-type`
- `x-tenant-access-type`, `x-can-view-children`, `x-parent-tenant-id`
- `x-forwarded-by`, `x-request-id`

### Context Header Injection

When a service token is validated, the gateway injects these headers for downstream services:

345
gateway/app/core/header_manager.py
Normal file
@@ -0,0 +1,345 @@
"""
Unified Header Management System for API Gateway
Centralized header injection, forwarding, and validation
"""

import structlog
from fastapi import Request
from typing import Dict, Any, Optional, List

logger = structlog.get_logger()


class HeaderManager:
    """
    Centralized header management for consistent header handling across gateway
    """

    # Standard header names (lowercase for consistency)
    STANDARD_HEADERS = {
        'user_id': 'x-user-id',
        'user_email': 'x-user-email',
        'user_role': 'x-user-role',
        'user_type': 'x-user-type',
        'service_name': 'x-service-name',
        'tenant_id': 'x-tenant-id',
        'subscription_tier': 'x-subscription-tier',
        'subscription_status': 'x-subscription-status',
        'is_demo': 'x-is-demo',
        'demo_session_id': 'x-demo-session-id',
        'demo_account_type': 'x-demo-account-type',
        'tenant_access_type': 'x-tenant-access-type',
        'can_view_children': 'x-can-view-children',
        'parent_tenant_id': 'x-parent-tenant-id',
        'forwarded_by': 'x-forwarded-by',
        'request_id': 'x-request-id'
    }

    # Headers that should be sanitized/removed from incoming requests
    SANITIZED_HEADERS = [
        'x-subscription-',
        'x-user-',
        'x-tenant-',
        'x-demo-',
        'x-forwarded-by'
    ]

    # Headers that should be forwarded to downstream services
    FORWARDABLE_HEADERS = [
        'authorization',
        'content-type',
        'accept',
        'accept-language',
        'user-agent',
        'x-internal-service'  # Required for internal service-to-service ML/alert triggers
    ]

    def __init__(self):
        self._initialized = False

    def initialize(self):
        """Initialize header manager"""
        if not self._initialized:
            logger.info("HeaderManager initialized")
            self._initialized = True

    def sanitize_incoming_headers(self, request: Request) -> None:
        """
        Remove sensitive headers from incoming request to prevent spoofing
        """
        if not hasattr(request.headers, '_list'):
            return

        # Filter out headers that start with sanitized prefixes
        sanitized_headers = [
            (k, v) for k, v in request.headers.raw
            if not any(k.decode().lower().startswith(prefix.lower())
                       for prefix in self.SANITIZED_HEADERS)
        ]

        request.headers.__dict__["_list"] = sanitized_headers
        logger.debug("Sanitized incoming headers")

    def inject_context_headers(self, request: Request, user_context: Dict[str, Any],
                               tenant_id: Optional[str] = None) -> Dict[str, str]:
        """
        Inject standardized context headers into request
        Returns dict of injected headers for reference
        """
        injected_headers = {}

        # Ensure headers list exists
        if not hasattr(request.headers, '_list'):
            request.headers.__dict__["_list"] = []

        # Store headers in request.state for cross-middleware access
        request.state.injected_headers = {}

        # User context headers
        if user_context.get('user_id'):
            header_name = self.STANDARD_HEADERS['user_id']
            header_value = str(user_context['user_id'])
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        if user_context.get('email'):
            header_name = self.STANDARD_HEADERS['user_email']
            header_value = str(user_context['email'])
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        if user_context.get('role'):
            header_name = self.STANDARD_HEADERS['user_role']
            header_value = str(user_context['role'])
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        # User type (service vs regular user)
        if user_context.get('type'):
            header_name = self.STANDARD_HEADERS['user_type']
            header_value = str(user_context['type'])
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        # Service name for service tokens
        if user_context.get('service'):
            header_name = self.STANDARD_HEADERS['service_name']
            header_value = str(user_context['service'])
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        # Tenant context
        if tenant_id:
            header_name = self.STANDARD_HEADERS['tenant_id']
            header_value = str(tenant_id)
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        # Subscription context
        if user_context.get('subscription_tier'):
            header_name = self.STANDARD_HEADERS['subscription_tier']
            header_value = str(user_context['subscription_tier'])
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        if user_context.get('subscription_status'):
            header_name = self.STANDARD_HEADERS['subscription_status']
            header_value = str(user_context['subscription_status'])
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        # Demo session context
        is_demo = user_context.get('is_demo', False)
        if is_demo:
            header_name = self.STANDARD_HEADERS['is_demo']
            header_value = "true"
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        if user_context.get('demo_session_id'):
            header_name = self.STANDARD_HEADERS['demo_session_id']
            header_value = str(user_context['demo_session_id'])
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        if user_context.get('demo_account_type'):
            header_name = self.STANDARD_HEADERS['demo_account_type']
            header_value = str(user_context['demo_account_type'])
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        # Hierarchical access context
        if tenant_id:
            tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
            can_view_children = getattr(request.state, 'can_view_children', False)

            header_name = self.STANDARD_HEADERS['tenant_access_type']
            header_value = str(tenant_access_type)
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

            header_name = self.STANDARD_HEADERS['can_view_children']
            header_value = str(can_view_children).lower()
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

            # Parent tenant ID if hierarchical access
            parent_tenant_id = getattr(request.state, 'parent_tenant_id', None)
            if parent_tenant_id:
                header_name = self.STANDARD_HEADERS['parent_tenant_id']
                header_value = str(parent_tenant_id)
                self._add_header(request, header_name, header_value)
                injected_headers[header_name] = header_value
                request.state.injected_headers[header_name] = header_value

        # Gateway identification
        header_name = self.STANDARD_HEADERS['forwarded_by']
        header_value = "bakery-gateway"
        self._add_header(request, header_name, header_value)
        injected_headers[header_name] = header_value
        request.state.injected_headers[header_name] = header_value

        # Request ID if available
        request_id = getattr(request.state, 'request_id', None)
        if request_id:
            header_name = self.STANDARD_HEADERS['request_id']
            header_value = str(request_id)
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        logger.info("🔧 Injected context headers",
                    user_id=user_context.get('user_id'),
                    user_type=user_context.get('type', ''),
                    service_name=user_context.get('service', ''),
                    role=user_context.get('role', ''),
                    tenant_id=tenant_id,
                    is_demo=is_demo,
                    demo_session_id=user_context.get('demo_session_id', ''),
                    path=request.url.path)

        return injected_headers

    def _add_header(self, request: Request, header_name: str, header_value: str) -> None:
        """
        Safely add header to request
        """
        try:
            request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
        except Exception as e:
            logger.warning(f"Failed to add header {header_name}: {e}")

    def get_forwardable_headers(self, request: Request) -> Dict[str, str]:
        """
        Get headers that should be forwarded to downstream services
        Includes both original request headers and injected context headers
        """
        forwardable_headers = {}

        # Add forwardable original headers
        for header_name in self.FORWARDABLE_HEADERS:
            header_value = request.headers.get(header_name)
            if header_value:
                forwardable_headers[header_name] = header_value

        # Add injected context headers from request.state
        if hasattr(request.state, 'injected_headers'):
            for header_name, header_value in request.state.injected_headers.items():
                forwardable_headers[header_name] = header_value

        # Add authorization header if present
        auth_header = request.headers.get('authorization')
        if auth_header:
            forwardable_headers['authorization'] = auth_header

        return forwardable_headers

    def get_all_headers_for_proxy(self, request: Request,
                                  additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
        """
        Get complete set of headers for proxying to downstream services
        """
        headers = self.get_forwardable_headers(request)

        # Add any additional headers
        if additional_headers:
            headers.update(additional_headers)

        # Remove host header as it will be set by httpx
        headers.pop('host', None)

        return headers

    def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool:
        """
        Validate that required headers are present
        """
        missing_headers = []

        for header_name in required_headers:
            # Check in injected headers first
            if hasattr(request.state, 'injected_headers'):
                if header_name in request.state.injected_headers:
                    continue

            # Check in request headers
            if request.headers.get(header_name):
                continue

            missing_headers.append(header_name)

        if missing_headers:
            logger.warning(f"Missing required headers: {missing_headers}")
            return False

        return True

    def get_header_value(self, request: Request, header_name: str,
                         default: Optional[str] = None) -> Optional[str]:
        """
        Get header value from either injected headers or request headers
        """
        # Check injected headers first
        if hasattr(request.state, 'injected_headers'):
            if header_name in request.state.injected_headers:
                return request.state.injected_headers[header_name]

        # Check request headers
        return request.headers.get(header_name, default)

    def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None:
        """
        Allow middleware to add headers to the unified header system
        This ensures all headers are available for proxying
        """
        # Ensure injected_headers exists
        if not hasattr(request.state, 'injected_headers'):
            request.state.injected_headers = {}

        # Add header to injected_headers
        request.state.injected_headers[header_name] = header_value

        # Also add to actual request headers for compatibility
        try:
            request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
        except Exception as e:
            logger.warning(f"Failed to add header {header_name} to request headers: {e}")

        logger.debug(f"Middleware added header: {header_name} = {header_value}")


# Global instance for easy access
header_manager = HeaderManager()
@@ -16,6 +16,7 @@ from shared.redis_utils import initialize_redis, close_redis, get_redis_client
from shared.service_base import StandardFastAPIService

from app.core.config import settings
from app.core.header_manager import header_manager
from app.middleware.request_id import RequestIDMiddleware
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
@@ -50,6 +51,10 @@ class GatewayService(StandardFastAPIService):
        """Custom startup logic for Gateway"""
        global redis_client

        # Initialize HeaderManager
        header_manager.initialize()
        logger.info("HeaderManager initialized")

        # Initialize Redis
        try:
            await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)

@@ -14,6 +14,7 @@ import httpx
import json

from app.core.config import settings
from app.core.header_manager import header_manager
from shared.auth.jwt_handler import JWTHandler
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path

@@ -60,15 +61,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
        if request.method == "OPTIONS":
            return await call_next(request)

        # SECURITY: Remove any incoming x-subscription-* headers
        # These will be re-injected from verified JWT only
        sanitized_headers = [
            (k, v) for k, v in request.headers.raw
            if not k.decode().lower().startswith('x-subscription-')
            and not k.decode().lower().startswith('x-user-')
            and not k.decode().lower().startswith('x-tenant-')
        ]
        request.headers.__dict__["_list"] = sanitized_headers
        # SECURITY: Remove any incoming sensitive headers using HeaderManager
        header_manager.sanitize_incoming_headers(request)

        # Skip authentication for public routes
        if self._is_public_route(request.url.path):
@@ -573,109 +567,13 @@ class AuthMiddleware(BaseHTTPMiddleware):
    async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
        """
        Inject user and tenant context headers for downstream services
        ENHANCED: Added logging to verify header injection
        Inject user and tenant context headers for downstream services using unified HeaderManager
        """
        # Enhanced logging for debugging
        logger.info(
            "🔧 Injecting context headers",
            user_id=user_context.get("user_id"),
            user_type=user_context.get("type", ""),
            service_name=user_context.get("service", ""),
            role=user_context.get("role", ""),
            tenant_id=tenant_id,
            is_demo=user_context.get("is_demo", False),
            demo_session_id=user_context.get("demo_session_id", ""),
            path=request.url.path
        )

        # Add user context headers
        logger.debug(f"DEBUG: Injecting headers for user: {user_context.get('user_id')}, is_demo: {user_context.get('is_demo', False)}")
        logger.debug(f"DEBUG: request.headers object id: {id(request.headers)}, _list id: {id(request.headers.__dict__.get('_list', []))}")

        # Store headers in request.state for cross-middleware access
        request.state.injected_headers = {
            "x-user-id": user_context["user_id"],
            "x-user-email": user_context["email"],
            "x-user-role": user_context.get("role", "user")
        }

        request.headers.__dict__["_list"].append((
            b"x-user-id", user_context["user_id"].encode()
        ))
        request.headers.__dict__["_list"].append((
            b"x-user-email", user_context["email"].encode()
        ))

        user_role = user_context.get("role", "user")
        request.headers.__dict__["_list"].append((
            b"x-user-role", user_role.encode()
        ))

        user_type = user_context.get("type", "")
        if user_type:
            request.headers.__dict__["_list"].append((
                b"x-user-type", user_type.encode()
            ))

        service_name = user_context.get("service", "")
        if service_name:
            request.headers.__dict__["_list"].append((
                b"x-service-name", service_name.encode()
            ))

        # Add tenant context if available
        if tenant_id:
            request.headers.__dict__["_list"].append((
                b"x-tenant-id", tenant_id.encode()
            ))

        # Add subscription tier if available
        subscription_tier = user_context.get("subscription_tier", "")
        if subscription_tier:
            request.headers.__dict__["_list"].append((
                b"x-subscription-tier", subscription_tier.encode()
            ))

        # Add is_demo flag for demo sessions
        is_demo = user_context.get("is_demo", False)
        logger.debug(f"DEBUG: is_demo value: {is_demo}, type: {type(is_demo)}")
        if is_demo:
            logger.info(f"🎭 Adding demo session headers",
                        demo_session_id=user_context.get("demo_session_id", ""),
                        demo_account_type=user_context.get("demo_account_type", ""),
                        path=request.url.path)
            request.headers.__dict__["_list"].append((
                b"x-is-demo", b"true"
            ))
        else:
            logger.debug(f"DEBUG: Not adding demo headers because is_demo is: {is_demo}")

        # Add demo session context headers for backend services
        demo_session_id = user_context.get("demo_session_id", "")
        if demo_session_id:
            request.headers.__dict__["_list"].append((
                b"x-demo-session-id", demo_session_id.encode()
            ))

        demo_account_type = user_context.get("demo_account_type", "")
        if demo_account_type:
            request.headers.__dict__["_list"].append((
                b"x-demo-account-type", demo_account_type.encode()
            ))
        # Use unified HeaderManager for consistent header injection
        injected_headers = header_manager.inject_context_headers(request, user_context, tenant_id)

        # Add hierarchical access headers if tenant context exists
        if tenant_id:
            tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
            can_view_children = getattr(request.state, 'can_view_children', False)

            request.headers.__dict__["_list"].append((
                b"x-tenant-access-type", tenant_access_type.encode()
            ))
            request.headers.__dict__["_list"].append((
                b"x-can-view-children", str(can_view_children).encode()
            ))

            # If this is hierarchical access, include parent tenant ID
            # Get parent tenant ID from the auth service if available
            try:
@@ -689,17 +587,16 @@ class AuthMiddleware(BaseHTTPMiddleware):
                hierarchy_data = response.json()
                parent_tenant_id = hierarchy_data.get("parent_tenant_id")
                if parent_tenant_id:
                    request.headers.__dict__["_list"].append((
                        b"x-parent-tenant-id", parent_tenant_id.encode()
                    ))
                    # Add parent tenant ID using HeaderManager for consistency
                    header_name = header_manager.STANDARD_HEADERS['parent_tenant_id']
                    header_value = str(parent_tenant_id)
                    header_manager.add_header_for_middleware(request, header_name, header_value)
                    logger.info(f"Added parent tenant ID header: {parent_tenant_id}")
            except Exception as e:
                logger.warning(f"Failed to get parent tenant ID: {e}")
                pass

        # Add gateway identification
        request.headers.__dict__["_list"].append((
            b"x-forwarded-by", b"bakery-gateway"
        ))
        return injected_headers

    async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
        """

@@ -45,8 +45,17 @@ class APIRateLimitMiddleware(BaseHTTPMiddleware):
            return await call_next(request)

        try:
            # Get subscription tier
            subscription_tier = await self._get_subscription_tier(tenant_id, request)
            # Get subscription tier from headers (added by AuthMiddleware)
            subscription_tier = request.headers.get("x-subscription-tier")

            if not subscription_tier:
                # Fallback: get from request state if headers not available
                subscription_tier = getattr(request.state, "subscription_tier", None)

            if not subscription_tier:
                # Final fallback: get from tenant service (should rarely happen)
                subscription_tier = await self._get_subscription_tier(tenant_id, request)
                logger.warning(f"Subscription tier not found in headers or state, fetched from tenant service: {subscription_tier}")

            # Get quota limit for tier
            quota_limit = self._get_quota_limit(subscription_tier)

@@ -9,6 +9,8 @@ from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

from app.core.header_manager import header_manager

logger = structlog.get_logger()


@@ -40,11 +42,9 @@ class RequestIDMiddleware(BaseHTTPMiddleware):
        # Bind request ID to structured logger context
        logger_ctx = logger.bind(request_id=request_id)

        # Inject request ID header for downstream services
        # This is done by modifying the headers that will be forwarded
        request.headers.__dict__["_list"].append((
            b"x-request-id", request_id.encode()
        ))
        # Inject request ID header for downstream services using HeaderManager
        # Note: This runs early in middleware chain, so we use add_header_for_middleware
        header_manager.add_header_for_middleware(request, "x-request-id", request_id)

        # Log request start
        logger_ctx.info(

@@ -15,6 +15,7 @@ import asyncio
from datetime import datetime, timezone

from app.core.config import settings
from app.core.header_manager import header_manager
from app.utils.subscription_error_responses import create_upgrade_required_response

logger = structlog.get_logger()
@@ -178,7 +179,10 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
            r'/api/v1/subscriptions/.*',  # Subscription management itself
            r'/api/v1/tenants/[^/]+/members.*',  # Basic tenant info
            r'/docs.*',
            r'/openapi\.json'
            r'/openapi\.json',
            # Training monitoring endpoints (WebSocket and status checks)
            r'/api/v1/tenants/[^/]+/training/jobs/.*/live.*',  # WebSocket endpoint
            r'/api/v1/tenants/[^/]+/training/jobs/.*/status.*',  # Status polling endpoint
        ]

        # Skip OPTIONS requests (CORS preflight)
@@ -275,21 +279,11 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
                'current_tier': current_tier
            }

            # Use the same authentication pattern as gateway routes for fallback
            headers = dict(request.headers)
            headers.pop("host", None)
            # Use unified HeaderManager for consistent header handling
            headers = header_manager.get_all_headers_for_proxy(request)

            # Extract user_id for logging (fallback path)
            user_id = 'unknown'
            # Add user context headers if available
            if hasattr(request.state, 'user') and request.state.user:
                user = request.state.user
                user_id = str(user.get('user_id', 'unknown'))
                headers["x-user-id"] = user_id
                headers["x-user-email"] = str(user.get('email', ''))
                headers["x-user-role"] = str(user.get('role', 'user'))
                headers["x-user-full-name"] = str(user.get('full_name', ''))
                headers["x-tenant-id"] = str(user.get('tenant_id', ''))
            user_id = header_manager.get_header_value(request, 'x-user-id', 'unknown')

            # Call tenant service fast tier endpoint with caching (fallback for old tokens)
            timeout_config = httpx.Timeout(

@@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse
from typing import Dict, Any

from app.core.config import settings
from app.core.header_manager import header_manager
from app.core.service_discovery import ServiceDiscovery
from shared.monitoring.metrics import MetricsCollector

@@ -136,107 +137,32 @@ class AuthProxy:
        return AUTH_SERVICE_URL

    def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
        """Prepare headers for forwarding (remove hop-by-hop headers)"""
        # Remove hop-by-hop headers
        hop_by_hop_headers = {
            'connection', 'keep-alive', 'proxy-authenticate',
            'proxy-authorization', 'te', 'trailers', 'upgrade'
        }

        # Convert headers to dict - get ALL headers including those added by middleware
        # Middleware adds headers to _list, so we need to read from there
        logger.debug(f"DEBUG: headers type: {type(headers)}, has _list: {hasattr(headers, '_list')}, has raw: {hasattr(headers, 'raw')}")
        logger.debug(f"DEBUG: headers.__dict__ keys: {list(headers.__dict__.keys())}")
        logger.debug(f"DEBUG: '_list' in headers.__dict__: {'_list' in headers.__dict__}")

        if hasattr(headers, '_list'):
            logger.debug(f"DEBUG: Entering _list branch")
            logger.debug(f"DEBUG: headers object id: {id(headers)}, _list id: {id(headers.__dict__.get('_list', []))}")
            # Get headers from the _list where middleware adds them
            all_headers_list = headers.__dict__.get('_list', [])
            logger.debug(f"DEBUG: _list length: {len(all_headers_list)}")

            # Debug: Show first few headers in the list
            debug_headers = []
            for i, (k, v) in enumerate(all_headers_list):
                if i < 5:  # Show first 5 headers for debugging
        """Prepare headers for forwarding using unified HeaderManager"""
        # Use unified HeaderManager to get all headers
        if request:
            all_headers = header_manager.get_all_headers_for_proxy(request)
            logger.debug(f"DEBUG: Added headers from HeaderManager: {list(all_headers.keys())}")
        else:
            # Fallback: convert headers to dict manually
            all_headers = {}
            if hasattr(headers, '_list'):
                for k, v in headers.__dict__.get('_list', []):
                    key = k.decode() if isinstance(k, bytes) else k
                    value = v.decode() if isinstance(v, bytes) else v
                    debug_headers.append(f"{key}: {value}")
            logger.debug(f"DEBUG: First headers in _list: {debug_headers}")
                    all_headers[key] = value
            elif hasattr(headers, 'raw'):
                for k, v in headers.raw:
                    key = k.decode() if isinstance(k, bytes) else k
                    value = v.decode() if isinstance(v, bytes) else v
                    all_headers[key] = value
            else:
                # Headers is already a dict
                all_headers = dict(headers)

            # Convert to dict for easier processing
            all_headers = {}
            for k, v in all_headers_list:
                key = k.decode() if isinstance(k, bytes) else k
                value = v.decode() if isinstance(v, bytes) else v
                all_headers[key] = value
        # Debug logging
        logger.info(f"📤 Forwarding headers - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")

            # Debug: Show if x-user-id and x-is-demo are in the dict
            logger.debug(f"DEBUG: x-user-id in all_headers: {'x-user-id' in all_headers}, x-is-demo in all_headers: {'x-is-demo' in all_headers}")
            logger.debug(f"DEBUG: all_headers keys: {list(all_headers.keys())[:10]}...")  # Show first 10 keys

            logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")

            # Check if headers are missing and try to get them from request.state
            if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
                logger.debug(f"DEBUG: Found injected_headers in request.state: {request.state.injected_headers}")
                # Add missing headers from request.state
                if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
                    all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
                    logger.debug(f"DEBUG: Added x-user-id from request.state: {all_headers['x-user-id']}")
                if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
                    all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
                    logger.debug(f"DEBUG: Added x-user-email from request.state: {all_headers['x-user-email']}")
                if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
                    all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
                    logger.debug(f"DEBUG: Added x-user-role from request.state: {all_headers['x-user-role']}")

                # Add is_demo flag if this is a demo session
                if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
                    all_headers['x-is-demo'] = 'true'
                    logger.debug(f"DEBUG: Added x-is-demo from request.state.is_demo_session")

            # Filter out hop-by-hop headers
            filtered_headers = {
                k: v for k, v in all_headers.items()
                if k.lower() not in hop_by_hop_headers
            }
        elif hasattr(headers, 'raw'):
            logger.debug(f"DEBUG: Entering raw branch")

        # Filter out hop-by-hop headers
        filtered_headers = {
            k: v for k, v in all_headers.items()
            if k.lower() not in hop_by_hop_headers
        }
        elif hasattr(headers, 'raw'):
            # Fallback to raw headers if _list not available
            all_headers = {
                k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
                for k, v in headers.raw
            }
            logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")

            filtered_headers = {
                k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
                for k, v in headers.raw
                if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
            }
        else:
            # Handle case where headers is already a dict
            logger.info(f"📤 Forwarding headers to auth service - x_user_id: {headers.get('x-user-id', 'MISSING')}, x_is_demo: {headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {headers.get('x-demo-session-id', 'MISSING')}, headers: {list(headers.keys())}")

            filtered_headers = {
                k: v for k, v in headers.items()
                if k.lower() not in hop_by_hop_headers
            }

        # Add gateway identifier
        filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
        filtered_headers['X-Gateway-Version'] = '1.0.0'

        return filtered_headers
        return all_headers

    def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
        """Prepare response headers"""

@@ -8,6 +8,7 @@ import httpx
import structlog

from app.core.config import settings
from app.core.header_manager import header_manager

logger = structlog.get_logger()

@@ -29,12 +30,8 @@ async def proxy_demo_service(path: str, request: Request):
    if request.method in ["POST", "PUT", "PATCH"]:
        body = await request.body()

    # Forward headers (excluding host)
    headers = {
        key: value
        for key, value in request.headers.items()
        if key.lower() not in ["host", "content-length"]
    }
    # Use unified HeaderManager for consistent header forwarding
    headers = header_manager.get_all_headers_for_proxy(request)

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:

@@ -5,6 +5,7 @@ from fastapi.responses import JSONResponse
import httpx
import structlog
from app.core.config import settings
from app.core.header_manager import header_manager

logger = structlog.get_logger()
router = APIRouter()
@@ -26,12 +27,8 @@ async def proxy_geocoding(request: Request, path: str):
    if request.method in ["POST", "PUT", "PATCH"]:
        body = await request.body()

    # Forward headers (excluding host)
    headers = {
        key: value
        for key, value in request.headers.items()
        if key.lower() not in ["host", "content-length"]
    }
    # Use unified HeaderManager for consistent header forwarding
    headers = header_manager.get_all_headers_for_proxy(request)

    # Make the proxied request
    async with httpx.AsyncClient(timeout=30.0) as client:

@@ -8,6 +8,7 @@ from fastapi.responses import JSONResponse
import httpx
import structlog
from app.core.config import settings
from app.core.header_manager import header_manager

logger = structlog.get_logger()
router = APIRouter()
@@ -44,12 +45,8 @@ async def proxy_poi_context(request: Request, path: str):
    if request.method in ["POST", "PUT", "PATCH"]:
        body = await request.body()

    # Copy headers (exclude host and content-length as they'll be set by httpx)
    headers = {
        key: value
        for key, value in request.headers.items()
        if key.lower() not in ["host", "content-length"]
    }
    # Use unified HeaderManager for consistent header forwarding
    headers = header_manager.get_all_headers_for_proxy(request)

    # Make the request to the external service
    async with httpx.AsyncClient(timeout=60.0) as client:

@@ -8,6 +8,7 @@ import httpx
import logging

from app.core.config import settings
from app.core.header_manager import header_manager

logger = logging.getLogger(__name__)
router = APIRouter()
@@ -45,9 +46,8 @@ async def _proxy_to_pos_service(request: Request, target_path: str):
    try:
        url = f"{settings.POS_SERVICE_URL}{target_path}"

        # Forward headers
        headers = dict(request.headers)
        headers.pop("host", None)
        # Use unified HeaderManager for consistent header forwarding
        headers = header_manager.get_all_headers_for_proxy(request)

        # Add query parameters
        params = dict(request.query_params)

@@ -9,6 +9,7 @@ import logging
from typing import Optional

from app.core.config import settings
from app.core.header_manager import header_manager

logger = logging.getLogger(__name__)
router = APIRouter()
@@ -98,29 +99,13 @@ async def _proxy_request(request: Request, target_path: str, service_url: str):
    try:
        url = f"{service_url}{target_path}"

        # Forward headers and add user/tenant context
        headers = dict(request.headers)
        headers.pop("host", None)
        # Use unified HeaderManager for consistent header forwarding
        headers = header_manager.get_all_headers_for_proxy(request)

        # Add user context headers if available
        if hasattr(request.state, 'user') and request.state.user:
            user = request.state.user
            headers["x-user-id"] = str(user.get('user_id', ''))
            headers["x-user-email"] = str(user.get('email', ''))
            headers["x-user-role"] = str(user.get('role', 'user'))
            headers["x-user-full-name"] = str(user.get('full_name', ''))
            headers["x-tenant-id"] = str(user.get('tenant_id', ''))

            # Add subscription context headers
            if user.get('subscription_tier'):
                headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
                logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")

            if user.get('subscription_status'):
                headers["x-subscription-status"] = str(user.get('subscription_status', ''))
                logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")

            logger.info(f"Forwarding subscription request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, subscription_tier={user.get('subscription_tier', 'not_set')}")
        # Debug logging
        user_context = getattr(request.state, 'user', None)
        if user_context:
            logger.info(f"Forwarding subscription request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
        else:
            logger.warning(f"No user context available when forwarding subscription request to {url}")


@@ -10,6 +10,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
@@ -715,36 +716,18 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
|
||||
try:
|
||||
url = f"{service_url}{target_path}"
|
||||
|
||||
# Forward headers and add user/tenant context
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Add tenant ID header if provided
|
||||
# Add tenant ID header if provided (override if needed)
|
||||
if tenant_id:
|
||||
headers["X-Tenant-ID"] = tenant_id
|
||||
headers["x-tenant-id"] = tenant_id
|
||||
|
||||
# Add user context headers if available
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
user = request.state.user
|
||||
headers["x-user-id"] = str(user.get('user_id', ''))
|
||||
headers["x-user-email"] = str(user.get('email', ''))
|
||||
headers["x-user-role"] = str(user.get('role', 'user'))
|
||||
headers["x-user-full-name"] = str(user.get('full_name', ''))
|
||||
headers["x-tenant-id"] = tenant_id or str(user.get('tenant_id', ''))
|
||||
|
||||
# Add subscription context headers
|
||||
if user.get('subscription_tier'):
|
||||
headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
|
||||
logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
|
||||
|
||||
if user.get('subscription_status'):
|
||||
headers["x-subscription-status"] = str(user.get('subscription_status', ''))
|
||||
logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")
|
||||
|
||||
# Debug logging
|
||||
logger.info(f"Forwarding request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, tenant_id={tenant_id}, subscription_tier={user.get('subscription_tier', 'not_set')}")
|
||||
# Debug logging
|
||||
user_context = getattr(request.state, 'user', None)
|
||||
if user_context:
|
||||
logger.info(f"Forwarding request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, tenant_id={tenant_id}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
|
||||
else:
|
||||
# Debug logging when no user context available
|
||||
logger.warning(f"No user context available when forwarding request to {url}. request.state.user: {getattr(request.state, 'user', 'NOT_SET')}")
|
||||
|
||||
# Get request body if present
|
||||
@@ -782,9 +765,10 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
|
||||
|
||||
logger.info(f"Forwarding multipart request with files={list(files.keys()) if files else None}, data={list(data.keys()) if data else None}")
|
||||
|
||||
# Remove content-type from headers - httpx will set it with new boundary
|
||||
headers.pop("content-type", None)
|
||||
headers.pop("content-length", None)
|
||||
# For multipart requests, we need to get fresh headers since httpx will set content-type
|
||||
# Get all headers again to ensure we have the complete set
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
# httpx will automatically set content-type for multipart, so we don't need to remove it
|
||||
else:
|
||||
# For other content types, use body as before
|
||||
body = await request.body()
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Dict, Any
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from app.core.service_discovery import ServiceDiscovery
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
|
||||
@@ -136,64 +137,28 @@ class UserProxy:
|
||||
return AUTH_SERVICE_URL
|
||||
|
||||
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
|
||||
"""Prepare headers for forwarding (remove hop-by-hop headers)"""
|
||||
# Remove hop-by-hop headers
|
||||
hop_by_hop_headers = {
|
||||
'connection', 'keep-alive', 'proxy-authenticate',
|
||||
'proxy-authorization', 'te', 'trailers', 'upgrade'
|
||||
}
|
||||
|
||||
# Convert headers to dict if it's a Headers object
|
||||
# This ensures we get ALL headers including those added by middleware
|
||||
if hasattr(headers, '_list'):
|
||||
# Get headers from the _list where middleware adds them
|
||||
all_headers_list = headers.__dict__.get('_list', [])
|
||||
|
||||
# Convert to dict for easier processing
|
||||
all_headers = {}
|
||||
for k, v in all_headers_list:
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
|
||||
# Check if headers are missing and try to get them from request.state
|
||||
if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
|
||||
# Add missing headers from request.state
|
||||
if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
|
||||
all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
|
||||
if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
|
||||
all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
|
||||
if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
|
||||
all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
|
||||
|
||||
# Add is_demo flag if this is a demo session
|
||||
if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
|
||||
all_headers['x-is-demo'] = 'true'
|
||||
|
||||
# Filter out hop-by-hop headers
|
||||
filtered_headers = {
|
||||
k: v for k, v in all_headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
elif hasattr(headers, 'raw'):
|
||||
# FastAPI/Starlette Headers object - use raw to get all headers
|
||||
filtered_headers = {
|
||||
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
|
||||
for k, v in headers.raw
|
||||
if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
|
||||
}
|
||||
"""Prepare headers for forwarding using unified HeaderManager"""
|
||||
# Use unified HeaderManager to get all headers
|
||||
if request:
|
||||
all_headers = header_manager.get_all_headers_for_proxy(request)
|
||||
else:
|
||||
# Already a dict
|
||||
filtered_headers = {
|
||||
k: v for k, v in headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
# Fallback: convert headers to dict manually
|
||||
all_headers = {}
|
||||
if hasattr(headers, '_list'):
|
||||
for k, v in headers.__dict__.get('_list', []):
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
elif hasattr(headers, 'raw'):
|
||||
for k, v in headers.raw:
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
else:
|
||||
# Headers is already a dict
|
||||
all_headers = dict(headers)
|
||||
|
||||
# Add gateway identifier
|
||||
filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
|
||||
filtered_headers['X-Gateway-Version'] = '1.0.0'
|
||||
|
||||
return filtered_headers
|
||||
return all_headers
|
||||
|
||||
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Prepare response headers"""
|
||||
|
||||
82
scripts/cleanup-docker.sh
Executable file
@@ -0,0 +1,82 @@
#!/bin/bash
# Docker Cleanup Script for Local Kubernetes Development
# This script helps prevent disk space issues by cleaning up unused Docker resources

set -e

echo "🧹 Docker Cleanup Script for Bakery-IA Local Development"
echo "========================================================="
echo ""

# Check if we should run automatically or ask for confirmation
AUTO_MODE=${1:-""}

# Show current disk usage
echo "📊 Current Docker Disk Usage:"
docker system df
echo ""

# Check Kind node disk usage if cluster is running
if docker ps | grep -q "bakery-ia-local-control-plane"; then
echo "📊 Kind Node Disk Usage:"
docker exec bakery-ia-local-control-plane df -h / /var | grep -E "(Filesystem|overlay|/dev/vdb1)"
echo ""
fi

# Calculate reclaimable space
RECLAIMABLE=$(docker system df | grep "Images" | awk '{print $4}')
echo "💾 Estimated reclaimable space: $RECLAIMABLE"
echo ""

# Ask for confirmation unless in auto mode
if [ "$AUTO_MODE" != "--auto" ]; then
read -p "Do you want to proceed with cleanup? (y/n) " -n 1 -r
echo ""
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
echo "❌ Cleanup cancelled"
exit 0
fi
fi

echo "🚀 Starting cleanup..."
echo ""

# Remove unused images (keep images from last 24 hours)
echo "1️⃣ Removing unused Docker images..."
docker image prune -af --filter "until=24h" || true
echo ""

# Remove unused volumes
echo "2️⃣ Removing unused Docker volumes..."
docker volume prune -f || true
echo ""

# Remove build cache
echo "3️⃣ Removing build cache..."
docker builder prune -af || true
echo ""

# Show results
echo "✅ Cleanup completed!"
echo ""
echo "📊 Final Docker Disk Usage:"
docker system df
echo ""

# Check Kind node disk usage if cluster is running
if docker ps | grep -q "bakery-ia-local-control-plane"; then
echo "📊 Kind Node Disk Usage After Cleanup:"
docker exec bakery-ia-local-control-plane df -h / /var | grep -E "(Filesystem|overlay|/dev/vdb1)"
echo ""

# Warn if still above 80%
USAGE=$(docker exec bakery-ia-local-control-plane df -h /var | tail -1 | awk '{print $5}' | sed 's/%//')
if [ "$USAGE" -gt 80 ]; then
echo "⚠️  Warning: Disk usage is still above 80%. Consider:"
echo "   - Deleting and recreating the Kind cluster"
echo "   - Increasing Docker's disk allocation"
echo "   - Running: docker system prune -a --volumes -f"
fi
fi

echo "🎉 All done!"
|
||||
@@ -1,82 +0,0 @@
|
||||
# services/forecasting/app/clients/inventory_client.py
|
||||
"""
|
||||
Simple client for inventory service integration
|
||||
Used when product names are not available locally
|
||||
"""
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
from typing import Optional, Dict, Any
|
||||
import os
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class InventoryServiceClient:
|
||||
"""Simple client for inventory service interactions"""
|
||||
|
||||
def __init__(self, base_url: str = None):
|
||||
self.base_url = base_url or os.getenv("INVENTORY_SERVICE_URL", "http://inventory-service:8000")
|
||||
self.timeout = aiohttp.ClientTimeout(total=5) # 5 second timeout
|
||||
|
||||
async def get_product_name(self, tenant_id: str, inventory_product_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get product name from inventory service
|
||||
Returns None if service is unavailable or product not found
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=self.timeout) as session:
|
||||
url = f"{self.base_url}/api/v1/products/{inventory_product_id}"
|
||||
headers = {"X-Tenant-ID": tenant_id}
|
||||
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return data.get("name", f"Product-{inventory_product_id}")
|
||||
else:
|
||||
logger.debug("Product not found in inventory service",
|
||||
inventory_product_id=inventory_product_id,
|
||||
status=response.status)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Failed to get product name from inventory service",
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def get_multiple_product_names(self, tenant_id: str, product_ids: list) -> Dict[str, str]:
|
||||
"""
|
||||
Get multiple product names efficiently
|
||||
Returns a mapping of product_id -> product_name
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=self.timeout) as session:
|
||||
url = f"{self.base_url}/api/v1/products/batch"
|
||||
headers = {"X-Tenant-ID": tenant_id}
|
||||
payload = {"product_ids": product_ids}
|
||||
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return {item["id"]: item["name"] for item in data.get("products", [])}
|
||||
else:
|
||||
logger.debug("Batch product lookup failed",
|
||||
product_count=len(product_ids),
|
||||
status=response.status)
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Failed to get product names from inventory service",
|
||||
product_count=len(product_ids),
|
||||
error=str(e))
|
||||
return {}
|
||||
|
||||
# Global client instance
|
||||
_inventory_client = None
|
||||
|
||||
def get_inventory_client() -> InventoryServiceClient:
|
||||
"""Get the global inventory client instance"""
|
||||
global _inventory_client
|
||||
if _inventory_client is None:
|
||||
_inventory_client = InventoryServiceClient()
|
||||
return _inventory_client
|
||||
@@ -12,7 +12,6 @@ from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from shared.messaging import UnifiedEventPublisher
|
||||
from app.clients.inventory_client import get_inventory_client
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
@@ -30,11 +30,11 @@ async def trigger_inventory_alerts(
- Expiring ingredients
- Overstock situations

Security: Protected by X-Internal-Service header check.
Security: Protected by x-internal-service header check.
"""
try:
# Verify internal service header
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
raise HTTPException(
status_code=403,
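The same guard is repeated across the internal endpoints in this commit. A small helper like the following captures the pattern; the header name and accepted caller values are taken from the diff, the helper itself is an illustrative extraction rather than code from the repository.

```python
import logging

from fastapi import HTTPException, Request

logger = logging.getLogger(__name__)

# Caller values accepted by the internal endpoints in the diff.
ALLOWED_INTERNAL_CALLERS = {"demo-session", "internal"}


def require_internal_caller(request: Request, tenant_id: str) -> None:
    """Raise 403 unless the request carries a recognised x-internal-service header."""
    caller = request.headers.get("x-internal-service") if request else None
    if caller not in ALLOWED_INTERNAL_CALLERS:
        logger.warning("Unauthorized internal API call for tenant %s", tenant_id)
        raise HTTPException(status_code=403, detail="Internal service access required")
```

Note that Starlette header lookups are case-insensitive, so the lowercase rename mainly standardises what the services send rather than changing how the check behaves.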
|
||||
|
||||
@@ -350,7 +350,7 @@ async def generate_safety_stock_insights_internal(
|
||||
This endpoint is called by the demo-session service after cloning data.
|
||||
It uses the same ML logic as the public endpoint but with optimized defaults.
|
||||
|
||||
Security: Protected by X-Internal-Service header check.
|
||||
Security: Protected by x-internal-service header check.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant UUID
|
||||
@@ -365,7 +365,7 @@ async def generate_safety_stock_insights_internal(
|
||||
}
|
||||
"""
|
||||
# Verify internal service header
|
||||
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
|
||||
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
|
||||
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
||||
@@ -29,7 +29,7 @@ async def trigger_delivery_tracking(
|
||||
This endpoint is called by the demo session cloning process after POs are seeded
|
||||
to generate realistic delivery alerts (arriving soon, overdue, etc.).
|
||||
|
||||
Security: Protected by X-Internal-Service header check.
|
||||
Security: Protected by x-internal-service header check.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID to check deliveries for
|
||||
@@ -49,7 +49,7 @@ async def trigger_delivery_tracking(
|
||||
"""
|
||||
try:
|
||||
# Verify internal service header
|
||||
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
|
||||
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
|
||||
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
||||
@@ -566,7 +566,7 @@ async def generate_price_insights_internal(
|
||||
This endpoint is called by the demo-session service after cloning data.
|
||||
It uses the same ML logic as the public endpoint but with optimized defaults.
|
||||
|
||||
Security: Protected by X-Internal-Service header check.
|
||||
Security: Protected by x-internal-service header check.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant UUID
|
||||
@@ -581,7 +581,7 @@ async def generate_price_insights_internal(
|
||||
}
|
||||
"""
|
||||
# Verify internal service header
|
||||
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
|
||||
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
|
||||
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
||||
@@ -1,42 +1,45 @@
"""
FastAPI Dependencies for Procurement Service
Uses shared authentication infrastructure with UUID validation
"""

from fastapi import Header, HTTPException, status
from fastapi import Depends, HTTPException, status
from uuid import UUID
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession

from .database import get_db
from shared.auth.decorators import get_current_tenant_id_dep


async def get_current_tenant_id(
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID")
tenant_id: Optional[str] = Depends(get_current_tenant_id_dep)
) -> UUID:
"""
Extract and validate tenant ID from request header.
Extract and validate tenant ID from request using shared infrastructure.
Adds UUID validation to ensure tenant ID format is correct.

Args:
x_tenant_id: Tenant ID from X-Tenant-ID header
tenant_id: Tenant ID from shared dependency

Returns:
UUID: Validated tenant ID

Raises:
HTTPException: If tenant ID is missing or invalid
HTTPException: If tenant ID is missing or invalid UUID format
"""
if not x_tenant_id:
if not tenant_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="X-Tenant-ID header is required"
detail="x-tenant-id header is required"
)

try:
return UUID(x_tenant_id)
return UUID(tenant_id)
except (ValueError, AttributeError):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid tenant ID format: {x_tenant_id}"
detail=f"Invalid tenant ID format: {tenant_id}"
)
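For context, this is how the refactored dependency above would typically be consumed in a route; the route path, import path, and payload are assumptions for illustration, only `get_current_tenant_id` and its behaviour come from the diff.

```python
from uuid import UUID

from fastapi import APIRouter, Depends

from app.dependencies import get_current_tenant_id  # hypothetical import path

router = APIRouter()


@router.get("/purchase-orders")
async def list_purchase_orders(tenant_id: UUID = Depends(get_current_tenant_id)) -> dict:
    # tenant_id is already a validated UUID here; a missing or malformed
    # x-tenant-id header was rejected with HTTP 400 by the dependency.
    return {"tenant_id": str(tenant_id), "items": []}
```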
|
||||
|
||||
|
||||
|
||||
@@ -31,11 +31,11 @@ async def trigger_production_alerts(
|
||||
- Equipment maintenance alerts
|
||||
- Batch start delays
|
||||
|
||||
Security: Protected by X-Internal-Service header check.
|
||||
Security: Protected by x-internal-service header check.
|
||||
"""
|
||||
try:
|
||||
# Verify internal service header
|
||||
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
|
||||
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
|
||||
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
||||
@@ -331,7 +331,7 @@ async def generate_yield_insights_internal(
|
||||
This endpoint is called by the demo-session service after cloning data.
|
||||
It uses the same ML logic as the public endpoint but with optimized defaults.
|
||||
|
||||
Security: Protected by X-Internal-Service header check.
|
||||
Security: Protected by x-internal-service header check.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant UUID
|
||||
@@ -346,7 +346,7 @@ async def generate_yield_insights_internal(
|
||||
}
|
||||
"""
|
||||
# Verify internal service header
|
||||
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
|
||||
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
|
||||
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
||||
@@ -204,7 +204,7 @@ class TenantMemberRepository(TenantBaseRepository):
|
||||
f"{auth_service_url}/api/v1/auth/users/batch",
|
||||
json={"user_ids": user_ids},
|
||||
timeout=10.0,
|
||||
headers={"X-Internal-Service": "tenant-service"}
|
||||
headers={"x-internal-service": "tenant-service"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -226,7 +226,7 @@ class TenantMemberRepository(TenantBaseRepository):
|
||||
response = await client.get(
|
||||
f"{auth_service_url}/api/v1/auth/users/{user_id}",
|
||||
timeout=5.0,
|
||||
headers={"X-Internal-Service": "tenant-service"}
|
||||
headers={"x-internal-service": "tenant-service"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
user_data = response.json()
|
||||
@@ -243,7 +243,7 @@ class TenantMemberRepository(TenantBaseRepository):
|
||||
response = await client.get(
|
||||
f"{auth_service_url}/api/v1/auth/users/{user_id}",
|
||||
timeout=5.0,
|
||||
headers={"X-Internal-Service": "tenant-service"}
|
||||
headers={"x-internal-service": "tenant-service"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
user_data = response.json()
|
||||
|
||||
@@ -216,17 +216,24 @@ class HybridProphetXGBoost:
|
||||
Get Prophet predictions for given dataframe.
|
||||
|
||||
Args:
|
||||
prophet_result: Prophet model result from training
|
||||
prophet_result: Prophet model result from training (contains model_path)
|
||||
df: DataFrame with 'ds' column
|
||||
|
||||
Returns:
|
||||
Array of predictions
|
||||
"""
|
||||
# Get the Prophet model from result
|
||||
prophet_model = prophet_result.get('model')
|
||||
# Get the model path from result instead of expecting the model object directly
|
||||
model_path = prophet_result.get('model_path')
|
||||
|
||||
if prophet_model is None:
|
||||
raise ValueError("Prophet model not found in result")
|
||||
if model_path is None:
|
||||
raise ValueError("Prophet model path not found in result")
|
||||
|
||||
# Load the actual Prophet model from the stored path
|
||||
try:
|
||||
import joblib
|
||||
prophet_model = joblib.load(model_path)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load Prophet model from path {model_path}: {str(e)}")
|
||||
|
||||
# Prepare dataframe for prediction
|
||||
pred_df = df[['ds']].copy()
|
||||
@@ -273,7 +280,8 @@ class HybridProphetXGBoost:
|
||||
'reg_lambda': 1.0, # L2 regularization
|
||||
'objective': 'reg:squarederror',
|
||||
'random_state': 42,
|
||||
'n_jobs': -1
|
||||
'n_jobs': -1,
|
||||
'early_stopping_rounds': 10
|
||||
}
|
||||
|
||||
# Initialize model
|
||||
@@ -285,7 +293,6 @@ class HybridProphetXGBoost:
|
||||
model.fit,
|
||||
X_train, y_train,
|
||||
eval_set=[(X_val, y_val)],
|
||||
early_stopping_rounds=10,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
@@ -303,109 +310,86 @@ class HybridProphetXGBoost:
|
||||
train_prophet_pred: np.ndarray,
|
||||
val_prophet_pred: np.ndarray,
|
||||
prophet_result: Dict[str, Any]
|
||||
) -> Dict[str, float]:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Evaluate hybrid model vs Prophet-only on validation set.
|
||||
|
||||
Args:
|
||||
train_df: Training data
|
||||
val_df: Validation data
|
||||
train_prophet_pred: Prophet predictions on training set
|
||||
val_prophet_pred: Prophet predictions on validation set
|
||||
prophet_result: Prophet training result
|
||||
|
||||
Returns:
|
||||
Dictionary of metrics
|
||||
Evaluate the overall performance of the hybrid model using threading for metrics.
|
||||
"""
|
||||
# Get actual values
|
||||
train_actual = train_df['y'].values
|
||||
val_actual = val_df['y'].values
|
||||
import asyncio
|
||||
|
||||
# Get XGBoost predictions on residuals
|
||||
# Get XGBoost predictions on training and validation
|
||||
X_train = train_df[self.feature_columns].values
|
||||
X_val = val_df[self.feature_columns].values
|
||||
|
||||
# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
|
||||
import asyncio
|
||||
train_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_train)
|
||||
val_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_val)
|
||||
|
||||
# Hybrid predictions = Prophet + XGBoost residual correction
|
||||
# Hybrid prediction = Prophet prediction + XGBoost residual prediction
|
||||
train_hybrid_pred = train_prophet_pred + train_xgb_pred
|
||||
val_hybrid_pred = val_prophet_pred + val_xgb_pred
|
||||
|
||||
# Calculate metrics for Prophet-only
|
||||
prophet_train_mae = mean_absolute_error(train_actual, train_prophet_pred)
|
||||
prophet_val_mae = mean_absolute_error(val_actual, val_prophet_pred)
|
||||
prophet_train_mape = mean_absolute_percentage_error(train_actual, train_prophet_pred) * 100
|
||||
prophet_val_mape = mean_absolute_percentage_error(val_actual, val_prophet_pred) * 100
|
||||
actual_train = train_df['y'].values
|
||||
actual_val = val_df['y'].values
|
||||
|
||||
# Calculate metrics for Hybrid
|
||||
hybrid_train_mae = mean_absolute_error(train_actual, train_hybrid_pred)
|
||||
hybrid_val_mae = mean_absolute_error(val_actual, val_hybrid_pred)
|
||||
hybrid_train_mape = mean_absolute_percentage_error(train_actual, train_hybrid_pred) * 100
|
||||
hybrid_val_mape = mean_absolute_percentage_error(val_actual, val_hybrid_pred) * 100
|
||||
# Basic RMSE calculation
|
||||
train_rmse = float(np.sqrt(np.mean((actual_train - train_hybrid_pred)**2)))
|
||||
val_rmse = float(np.sqrt(np.mean((actual_val - val_hybrid_pred)**2)))
|
||||
|
||||
# MAE
|
||||
train_mae = float(np.mean(np.abs(actual_train - train_hybrid_pred)))
|
||||
val_mae = float(np.mean(np.abs(actual_val - val_hybrid_pred)))
|
||||
|
||||
# MAPE (with safety for zero sales)
|
||||
train_mape = float(np.mean(np.abs((actual_train - train_hybrid_pred) / np.maximum(actual_train, 1))))
|
||||
val_mape = float(np.mean(np.abs((actual_val - val_hybrid_pred) / np.maximum(actual_val, 1))))
|
||||
|
||||
# Calculate improvement
|
||||
mae_improvement = ((prophet_val_mae - hybrid_val_mae) / prophet_val_mae) * 100
|
||||
mape_improvement = ((prophet_val_mape - hybrid_val_mape) / prophet_val_mape) * 100
|
||||
prophet_metrics = prophet_result.get("metrics", {})
|
||||
prophet_val_mae = prophet_metrics.get("val_mae", val_mae) # Fallback to hybrid if missing
|
||||
prophet_val_mape = prophet_metrics.get("val_mape", val_mape)
|
||||
|
||||
improvement_pct = 0.0
|
||||
if prophet_val_mape > 0:
|
||||
improvement_pct = ((prophet_val_mape - val_mape) / prophet_val_mape) * 100
|
||||
|
||||
metrics = {
|
||||
'prophet_train_mae': float(prophet_train_mae),
|
||||
'prophet_val_mae': float(prophet_val_mae),
|
||||
'prophet_train_mape': float(prophet_train_mape),
|
||||
'prophet_val_mape': float(prophet_val_mape),
|
||||
'hybrid_train_mae': float(hybrid_train_mae),
|
||||
'hybrid_val_mae': float(hybrid_val_mae),
|
||||
'hybrid_train_mape': float(hybrid_train_mape),
|
||||
'hybrid_val_mape': float(hybrid_val_mape),
|
||||
'mae_improvement_pct': float(mae_improvement),
|
||||
'mape_improvement_pct': float(mape_improvement),
|
||||
'improvement_percentage': float(mape_improvement) # Primary metric
|
||||
"train_rmse": train_rmse,
|
||||
"val_rmse": val_rmse,
|
||||
"train_mae": train_mae,
|
||||
"val_mae": val_mae,
|
||||
"train_mape": train_mape,
|
||||
"val_mape": val_mape,
|
||||
"prophet_val_mape": prophet_val_mape,
|
||||
"hybrid_val_mape": val_mape,
|
||||
"improvement_percentage": float(improvement_pct),
|
||||
"prophet_metrics": prophet_metrics
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Hybrid model evaluation complete",
|
||||
val_rmse=val_rmse,
|
||||
val_mae=val_mae,
|
||||
val_mape=val_mape,
|
||||
improvement=improvement_pct
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
def _package_hybrid_model(
|
||||
self,
|
||||
prophet_result: Dict[str, Any],
|
||||
metrics: Dict[str, float],
|
||||
metrics: Dict[str, Any],
|
||||
tenant_id: str,
|
||||
inventory_product_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Package hybrid model for storage.
|
||||
|
||||
Args:
|
||||
prophet_result: Prophet model result
|
||||
metrics: Hybrid model metrics
|
||||
tenant_id: Tenant ID
|
||||
inventory_product_id: Product ID
|
||||
|
||||
Returns:
|
||||
Model package dictionary
|
||||
"""
|
||||
return {
|
||||
'model_type': 'hybrid_prophet_xgboost',
|
||||
'prophet_model': prophet_result.get('model'),
|
||||
'prophet_model_path': prophet_result.get('model_path'),
|
||||
'xgboost_model': self.xgb_model,
|
||||
'feature_columns': self.feature_columns,
|
||||
'prophet_metrics': {
|
||||
'train_mae': metrics['prophet_train_mae'],
|
||||
'val_mae': metrics['prophet_val_mae'],
|
||||
'train_mape': metrics['prophet_train_mape'],
|
||||
'val_mape': metrics['prophet_val_mape']
|
||||
},
|
||||
'hybrid_metrics': {
|
||||
'train_mae': metrics['hybrid_train_mae'],
|
||||
'val_mae': metrics['hybrid_val_mae'],
|
||||
'train_mape': metrics['hybrid_train_mape'],
|
||||
'val_mape': metrics['hybrid_val_mape']
|
||||
},
|
||||
'improvement_metrics': {
|
||||
'mae_improvement_pct': metrics['mae_improvement_pct'],
|
||||
'mape_improvement_pct': metrics['mape_improvement_pct']
|
||||
},
|
||||
'metrics': metrics,
|
||||
'tenant_id': tenant_id,
|
||||
'inventory_product_id': inventory_product_id,
|
||||
'trained_at': datetime.now(timezone.utc).isoformat()
|
||||
@@ -426,8 +410,18 @@ class HybridProphetXGBoost:
Returns:
DataFrame with predictions
"""
# Step 1: Get Prophet predictions
prophet_model = model_data['prophet_model']
# Step 1: Get Prophet model from path and make predictions
prophet_model_path = model_data.get('prophet_model_path')
if prophet_model_path is None:
raise ValueError("Prophet model path not found in model data")

# Load the Prophet model from the stored path
try:
import joblib
prophet_model = joblib.load(prophet_model_path)
except Exception as e:
raise ValueError(f"Failed to load Prophet model from path {prophet_model_path}: {str(e)}")

# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
import asyncio
prophet_forecast = await asyncio.to_thread(prophet_model.predict, future_df)
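Condensed, the new prediction path looks roughly like this: the Prophet model is re-loaded from its persisted path, both blocking `predict()` calls run in a worker thread, and the XGBoost output is added as a residual correction. The dictionary keys (`prophet_model_path`, `xgboost_model`, `feature_columns`) follow the packaging code in the diff; everything else is a sketch, not the exact implementation.

```python
import asyncio

import joblib
import numpy as np
import pandas as pd


async def hybrid_predict(model_data: dict, future_df: pd.DataFrame) -> np.ndarray:
    # Re-load the persisted Prophet model rather than expecting a live object.
    prophet_model = joblib.load(model_data["prophet_model_path"])

    # Blocking predict() calls go through a worker thread so the event loop stays free.
    prophet_forecast = await asyncio.to_thread(prophet_model.predict, future_df[["ds"]])

    xgb_model = model_data["xgboost_model"]
    features = future_df[model_data["feature_columns"]].values
    residuals = await asyncio.to_thread(xgb_model.predict, features)

    # Hybrid prediction = Prophet trend/seasonality + XGBoost residual correction.
    return prophet_forecast["yhat"].values + residuals
```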
|
||||
|
||||
@@ -43,86 +43,79 @@ class POIFeatureIntegrator:
|
||||
force_refresh: bool = False
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch POI features for tenant location.
|
||||
Fetch POI features for tenant location (optimized for training).
|
||||
|
||||
First checks if POI context exists, if not, triggers detection.
|
||||
First checks if POI context exists. If not, returns None without triggering detection.
|
||||
POI detection should be triggered during tenant registration, not during training.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
latitude: Bakery latitude
|
||||
longitude: Bakery longitude
|
||||
force_refresh: Force re-detection
|
||||
force_refresh: Force re-detection (only use if POI context already exists)
|
||||
|
||||
Returns:
|
||||
Dictionary with POI features or None if detection fails
|
||||
Dictionary with POI features or None if not available
|
||||
"""
|
||||
try:
|
||||
# Try to get existing POI context first
|
||||
if not force_refresh:
|
||||
existing_context = await self.external_client.get_poi_context(tenant_id)
|
||||
if existing_context:
|
||||
poi_context = existing_context.get("poi_context", {})
|
||||
ml_features = poi_context.get("ml_features", {})
|
||||
existing_context = await self.external_client.get_poi_context(tenant_id)
|
||||
|
||||
# Check if stale
|
||||
is_stale = existing_context.get("is_stale", False)
|
||||
if not is_stale:
|
||||
if existing_context:
|
||||
poi_context = existing_context.get("poi_context", {})
|
||||
ml_features = poi_context.get("ml_features", {})
|
||||
|
||||
# Check if stale and force_refresh is requested
|
||||
is_stale = existing_context.get("is_stale", False)
|
||||
|
||||
if not is_stale or not force_refresh:
|
||||
logger.info(
|
||||
"Using existing POI context",
|
||||
tenant_id=tenant_id,
|
||||
is_stale=is_stale,
|
||||
feature_count=len(ml_features)
|
||||
)
|
||||
return ml_features
|
||||
else:
|
||||
logger.info(
|
||||
"POI context is stale and force_refresh=True, refreshing",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
# Only refresh if explicitly requested and context exists
|
||||
detection_result = await self.external_client.detect_poi_for_tenant(
|
||||
tenant_id=tenant_id,
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
force_refresh=True
|
||||
)
|
||||
|
||||
if detection_result:
|
||||
poi_context = detection_result.get("poi_context", {})
|
||||
ml_features = poi_context.get("ml_features", {})
|
||||
logger.info(
|
||||
"Using existing POI context",
|
||||
tenant_id=tenant_id
|
||||
"POI refresh completed",
|
||||
tenant_id=tenant_id,
|
||||
feature_count=len(ml_features)
|
||||
)
|
||||
return ml_features
|
||||
else:
|
||||
logger.info(
|
||||
"POI context is stale, refreshing",
|
||||
logger.warning(
|
||||
"POI refresh failed, returning existing features",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
force_refresh = True
|
||||
else:
|
||||
logger.info(
|
||||
"No existing POI context, will detect",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Detect or refresh POIs
|
||||
logger.info(
|
||||
"Detecting POIs for tenant",
|
||||
tenant_id=tenant_id,
|
||||
location=(latitude, longitude)
|
||||
)
|
||||
|
||||
detection_result = await self.external_client.detect_poi_for_tenant(
|
||||
tenant_id=tenant_id,
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
force_refresh=force_refresh
|
||||
)
|
||||
|
||||
if detection_result:
|
||||
poi_context = detection_result.get("poi_context", {})
|
||||
ml_features = poi_context.get("ml_features", {})
|
||||
|
||||
logger.info(
|
||||
"POI detection completed",
|
||||
tenant_id=tenant_id,
|
||||
total_pois=poi_context.get("total_pois_detected", 0),
|
||||
feature_count=len(ml_features)
|
||||
)
|
||||
|
||||
return ml_features
|
||||
return ml_features
|
||||
else:
|
||||
logger.error(
|
||||
"POI detection failed",
|
||||
logger.info(
|
||||
"No existing POI context found - POI detection should be triggered during tenant registration",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error fetching POI features",
|
||||
logger.warning(
|
||||
"Error fetching POI features - returning None",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
error=str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -29,16 +29,15 @@ class DataClient:
|
||||
self.sales_client = get_sales_client(settings, "training")
|
||||
self.external_client = get_external_client(settings, "training")
|
||||
|
||||
# ExternalServiceClient always has get_stored_traffic_data_for_training method
|
||||
self.supports_stored_traffic_data = True
|
||||
|
||||
# Configure timeouts for HTTP clients
|
||||
self._configure_timeouts()
|
||||
|
||||
# Initialize circuit breakers for external services
|
||||
self._init_circuit_breakers()
|
||||
|
||||
# Check if the new method is available for stored traffic data
|
||||
if hasattr(self.external_client, 'get_stored_traffic_data_for_training'):
|
||||
self.supports_stored_traffic_data = True
|
||||
|
||||
def _configure_timeouts(self):
|
||||
"""Configure appropriate timeouts for HTTP clients"""
|
||||
timeout = httpx.Timeout(
|
||||
@@ -49,14 +48,12 @@ class DataClient:
|
||||
)
|
||||
|
||||
# Apply timeout to clients if they have httpx clients
|
||||
# Note: BaseServiceClient manages its own HTTP client internally
|
||||
if hasattr(self.sales_client, 'client') and isinstance(self.sales_client.client, httpx.AsyncClient):
|
||||
self.sales_client.client.timeout = timeout
|
||||
|
||||
if hasattr(self.external_client, 'client') and isinstance(self.external_client.client, httpx.AsyncClient):
|
||||
self.external_client.client.timeout = timeout
|
||||
else:
|
||||
self.supports_stored_traffic_data = False
|
||||
logger.warning("Stored traffic data method not available in external client")
|
||||
|
||||
def _init_circuit_breakers(self):
|
||||
"""Initialize circuit breakers for external service calls"""
|
||||
|
||||
@@ -404,22 +404,32 @@ class TrainingDataOrchestrator:
tenant_id: str
) -> Dict[str, Any]:
"""
Collect POI features for bakery location.
Collect POI features for bakery location (non-blocking).

POI features are static (location-based, not time-varying).
This method is non-blocking with a short timeout to prevent training delays.
If POI detection hasn't been run yet, training continues without POI features.

Returns:
Dictionary with POI features or empty dict if unavailable
"""
try:
logger.info(
"Collecting POI features",
"Collecting POI features (non-blocking)",
tenant_id=tenant_id,
location=(lat, lon)
)

poi_features = await self.poi_feature_integrator.fetch_poi_features(
tenant_id=tenant_id,
latitude=lat,
longitude=lon,
force_refresh=False
# Set a short timeout to prevent blocking training
# POI detection should have been triggered during tenant registration
poi_features = await asyncio.wait_for(
self.poi_feature_integrator.fetch_poi_features(
tenant_id=tenant_id,
latitude=lat,
longitude=lon,
force_refresh=False
),
timeout=15.0  # 15 second timeout - POI should be cached from registration
)

if poi_features:
@@ -430,18 +440,24 @@ class TrainingDataOrchestrator:
)
else:
logger.warning(
"No POI features collected (service may be unavailable)",
"No POI features collected (service may be unavailable or not yet detected)",
tenant_id=tenant_id
)

return poi_features or {}

except asyncio.TimeoutError:
logger.warning(
"POI collection timeout (15s) - continuing training without POI features. "
"POI detection should be triggered during tenant registration for best results.",
tenant_id=tenant_id
)
return {}
except Exception as e:
logger.error(
"Failed to collect POI features, continuing without them",
logger.warning(
"Failed to collect POI features (non-blocking) - continuing training without them",
tenant_id=tenant_id,
error=str(e),
exc_info=True
error=str(e)
)
return {}
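The pattern above generalises to any optional feature source: cap the fetch with a short timeout and degrade to an empty dict so training never stalls. A compact sketch, with the 15-second budget from the diff and an illustrative callable standing in for `fetch_poi_features`:

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict


async def fetch_optional_features(
    fetch: Callable[[], Awaitable[Dict[str, Any]]],
    timeout_s: float = 15.0,
) -> Dict[str, Any]:
    """Return features from `fetch`, or an empty dict if it is slow or fails."""
    try:
        return (await asyncio.wait_for(fetch(), timeout=timeout_s)) or {}
    except asyncio.TimeoutError:
        # Optional context not cached yet - the caller continues without it.
        return {}
    except Exception:
        # Any other failure is deliberately non-fatal.
        return {}
```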
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ class ServiceAuthenticator:
|
||||
}
|
||||
|
||||
if tenant_id:
|
||||
headers["X-Tenant-ID"] = str(tenant_id)
|
||||
headers["x-tenant-id"] = str(tenant_id)
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
@@ -351,7 +351,7 @@ class ForecastServiceClient(BaseServiceClient):
|
||||
"""
|
||||
Trigger demand forecasting insights for a tenant (internal service use only).
|
||||
|
||||
This method calls the internal endpoint which is protected by X-Internal-Service header.
|
||||
This method calls the internal endpoint which is protected by x-internal-service header.
|
||||
Used by demo-session service after cloning to generate AI insights from seeded data.
|
||||
|
||||
Args:
|
||||
@@ -366,7 +366,7 @@ class ForecastServiceClient(BaseServiceClient):
|
||||
endpoint=f"forecasting/internal/ml/generate-demand-insights",
|
||||
tenant_id=tenant_id,
|
||||
data={"tenant_id": tenant_id},
|
||||
headers={"X-Internal-Service": "demo-session"}
|
||||
headers={"x-internal-service": "demo-session"}
|
||||
)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -766,7 +766,7 @@ class InventoryServiceClient(BaseServiceClient):
|
||||
"""
|
||||
Trigger inventory alerts for a tenant (internal service use only).
|
||||
|
||||
This method calls the internal endpoint which is protected by X-Internal-Service header.
|
||||
This method calls the internal endpoint which is protected by x-internal-service header.
|
||||
The endpoint should trigger alerts specifically for the given tenant.
|
||||
|
||||
Args:
|
||||
@@ -783,7 +783,7 @@ class InventoryServiceClient(BaseServiceClient):
|
||||
endpoint="inventory/internal/alerts/trigger",
|
||||
tenant_id=tenant_id,
|
||||
data={},
|
||||
headers={"X-Internal-Service": "demo-session"}
|
||||
headers={"x-internal-service": "demo-session"}
|
||||
)
|
||||
|
||||
if result:
|
||||
@@ -819,7 +819,7 @@ class InventoryServiceClient(BaseServiceClient):
|
||||
"""
|
||||
Trigger safety stock optimization insights for a tenant (internal service use only).
|
||||
|
||||
This method calls the internal endpoint which is protected by X-Internal-Service header.
|
||||
This method calls the internal endpoint which is protected by x-internal-service header.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID to trigger insights for
|
||||
@@ -833,7 +833,7 @@ class InventoryServiceClient(BaseServiceClient):
|
||||
endpoint="inventory/internal/ml/generate-safety-stock-insights",
|
||||
tenant_id=tenant_id,
|
||||
data={"tenant_id": tenant_id},
|
||||
headers={"X-Internal-Service": "demo-session"}
|
||||
headers={"x-internal-service": "demo-session"}
|
||||
)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -580,7 +580,7 @@ class ProcurementServiceClient(BaseServiceClient):
|
||||
"""
|
||||
Trigger delivery tracking for a tenant (internal service use only).
|
||||
|
||||
This method calls the internal endpoint which is protected by X-Internal-Service header.
|
||||
This method calls the internal endpoint which is protected by x-internal-service header.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID to trigger delivery tracking for
|
||||
@@ -596,7 +596,7 @@ class ProcurementServiceClient(BaseServiceClient):
|
||||
endpoint="procurement/internal/delivery-tracking/trigger",
|
||||
tenant_id=tenant_id,
|
||||
data={},
|
||||
headers={"X-Internal-Service": "demo-session"}
|
||||
headers={"x-internal-service": "demo-session"}
|
||||
)
|
||||
|
||||
if result:
|
||||
@@ -632,7 +632,7 @@ class ProcurementServiceClient(BaseServiceClient):
|
||||
"""
|
||||
Trigger price forecasting insights for a tenant (internal service use only).
|
||||
|
||||
This method calls the internal endpoint which is protected by X-Internal-Service header.
|
||||
This method calls the internal endpoint which is protected by x-internal-service header.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID to trigger insights for
|
||||
@@ -646,7 +646,7 @@ class ProcurementServiceClient(BaseServiceClient):
|
||||
endpoint="procurement/internal/ml/generate-price-insights",
|
||||
tenant_id=tenant_id,
|
||||
data={"tenant_id": tenant_id},
|
||||
headers={"X-Internal-Service": "demo-session"}
|
||||
headers={"x-internal-service": "demo-session"}
|
||||
)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -630,7 +630,7 @@ class ProductionServiceClient(BaseServiceClient):
|
||||
"""
|
||||
Trigger production alerts for a tenant (internal service use only).
|
||||
|
||||
This method calls the internal endpoint which is protected by X-Internal-Service header.
|
||||
This method calls the internal endpoint which is protected by x-internal-service header.
|
||||
Includes both production alerts and equipment maintenance checks.
|
||||
|
||||
Args:
|
||||
@@ -647,7 +647,7 @@ class ProductionServiceClient(BaseServiceClient):
|
||||
endpoint="production/internal/alerts/trigger",
|
||||
tenant_id=tenant_id,
|
||||
data={},
|
||||
headers={"X-Internal-Service": "demo-session"}
|
||||
headers={"x-internal-service": "demo-session"}
|
||||
)
|
||||
|
||||
if result:
|
||||
@@ -683,7 +683,7 @@ class ProductionServiceClient(BaseServiceClient):
|
||||
"""
|
||||
Trigger yield improvement insights for a tenant (internal service use only).
|
||||
|
||||
This method calls the internal endpoint which is protected by X-Internal-Service header.
|
||||
This method calls the internal endpoint which is protected by x-internal-service header.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID to trigger insights for
|
||||
@@ -697,7 +697,7 @@ class ProductionServiceClient(BaseServiceClient):
|
||||
endpoint="production/internal/ml/generate-yield-insights",
|
||||
tenant_id=tenant_id,
|
||||
data={"tenant_id": tenant_id},
|
||||
headers={"X-Internal-Service": "demo-session"}
|
||||
headers={"x-internal-service": "demo-session"}
|
||||
)
|
||||
|
||||
if result:
|
||||
|
||||