Add improvements 2
This commit is contained in:
DOCKER_MAINTENANCE.md (new file, 120 lines)
@@ -0,0 +1,120 @@
# Docker Maintenance Guide for Local Development

## The Problem

When developing with Tilt and local Kubernetes (Kind), Docker accumulates:

- **Multiple image versions** from each code change (Tilt rebuilds)
- **Unused volumes** from previous cluster runs
- **Build cache** that grows over time

This quickly fills up disk space, causing pods to fail with "No space left on device" errors.

## Quick Fix (When You Hit Disk Issues)

```bash
# Clean up all unused Docker resources
docker system prune -a --volumes -f
```

This removes:

- All unused images
- All unused volumes
- All build cache

**Expected recovery**: 60-100GB

## Regular Maintenance

### Option 1: Use the Cleanup Script (Recommended)

Run the maintenance script weekly:

```bash
./scripts/cleanup-docker.sh
```

Or run it automatically without confirmation:

```bash
./scripts/cleanup-docker.sh --auto
```
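The script itself is not reproduced in this guide; as a rough sketch (an assumption, not the actual script), it likely just wraps the manual commands from Option 2 and skips the confirmation prompt when `--auto` is passed:

```bash
#!/usr/bin/env bash
# Hypothetical sketch of scripts/cleanup-docker.sh -- the real script may differ.
set -euo pipefail

if [[ "${1:-}" != "--auto" ]]; then
    read -r -p "Prune unused Docker images, volumes and build cache? [y/N] " answer
    [[ "$answer" == [yY] ]] || exit 0
fi

docker image prune -af --filter "until=24h"   # old images
docker volume prune -f                        # unused volumes
docker builder prune -af                      # build cache
docker system df                              # show what is left
```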
### Option 2: Manual Commands

```bash
# Remove images older than 24 hours
docker image prune -af --filter "until=24h"

# Remove unused volumes
docker volume prune -f

# Remove build cache
docker builder prune -af
```

### Option 3: Set Up Automated Cleanup

Add to your crontab (run every Sunday at 2 AM):

```bash
crontab -e
# Add this line:
0 2 * * 0 /Users/urtzialfaro/Documents/bakery-ia/scripts/cleanup-docker.sh --auto >> /tmp/docker-cleanup.log 2>&1
```

## Monitoring Disk Usage

### Check Docker disk usage:

```bash
docker system df
```

### Check Kind node disk usage:

```bash
docker exec bakery-ia-local-control-plane df -h /var
```

### Alert thresholds:

- **< 70%**: Healthy ✅
- **70-85%**: Consider cleanup soon ⚠️
- **> 85%**: Run cleanup immediately 🚨
- **> 95%**: Critical - pods will fail ❌
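To script these thresholds, a small check along the following lines can be run before a long dev session (a sketch; it reuses the Kind node command shown above):

```bash
# Warn when the Kind node's /var filesystem crosses the cleanup thresholds.
usage=$(docker exec bakery-ia-local-control-plane df --output=pcent /var | tail -1 | tr -dc '0-9')

if   [ "$usage" -ge 95 ]; then echo "❌ ${usage}% used - critical, pods will fail"
elif [ "$usage" -ge 85 ]; then echo "🚨 ${usage}% used - run cleanup immediately"
elif [ "$usage" -ge 70 ]; then echo "⚠️ ${usage}% used - consider cleanup soon"
else                            echo "✅ ${usage}% used - healthy"
fi
```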
## Prevention Tips

1. **Run cleanup weekly** to prevent accumulation
2. **Monitor disk usage** before long dev sessions
3. **Delete old Kind clusters** when switching projects:

```bash
kind delete cluster --name bakery-ia-local
```

4. **Increase Docker disk allocation** in Docker Desktop settings if you frequently rebuild many services

## Troubleshooting

### Pods in CrashLoopBackOff after disk issues:

1. Run cleanup (see Quick Fix above)
2. Restart failed pods:

```bash
kubectl get pods -n bakery-ia | grep -E "(CrashLoopBackOff|Error)" | awk '{print $1}' | xargs kubectl delete pod -n bakery-ia
```

### Cleanup didn't free enough space:

If still above 90% after cleanup:

```bash
# Nuclear option - rebuild everything
kind delete cluster --name bakery-ia-local
docker system prune -a --volumes -f
# Then recreate cluster with your setup scripts
```

## What Happened Today (2026-01-12)

- **Issue**: Disk was 100% full (113GB/113GB), causing database pods to crash
- **Root cause**: 122 unused Docker images + 16 unused volumes + 6GB build cache
- **Solution**: Ran `docker system prune -a --volumes -f`
- **Result**: Freed 89GB, disk now at 22% usage (24GB/113GB)
- **All services recovered successfully**
@@ -1,185 +0,0 @@
# List of Proposed Improvements

## 1. New Section: Security and Cybersecurity (add after 5.2)

### 5.3. Security Architecture and European Regulatory Compliance

**Authentication and Authorization:**
- JWT with token rotation every 15 minutes
- Role-based access control (RBAC)
- Rate limiting (300 req/min per client)
- Multi-factor authentication (MFA) planned for Q1 2026

**Data Protection:**
- AES-256 encryption at rest
- Mandatory HTTPS with auto-renewing Let's Encrypt certificates
- Multi-tenant isolation at the database level
- SQL injection prevention via Pydantic schemas
- XSS protection and CORS

**Monitoring and Traceability:**
- OpenTelemetry for end-to-end distributed tracing
- SigNoz as a unified observability platform
- Centralized logs with trace correlation
- Complete auditing of access and changes
- Real-time alerts (email and Slack)

**GDPR Compliance:**
- Privacy by design
- Right to erasure implemented
- Data export in CSV/JSON
- Consent management with history
- Anonymization of analytics data

**Secure Infrastructure:**
- Kubernetes with pod security policies
- Automatic security updates
- Daily encrypted backups
- Disaster recovery (DR) plan
- PostgreSQL 17 with hardened configurations
## 2. Update Section 4.4 (Competition) - Add Security Advantage

**Competitive Advantage in Security:**
- "Security-First" architecture vs. vulnerable legacy solutions
- GDPR compliance by design (competitors are retrofitting)
- Full observability (OpenTelemetry + SigNoz) vs. black boxes
- Planned security certifications (ISO 27001, ENS)
- Alignment with the NIS2 Directive (mandatory in 2024 for the food supply chain)

## 3. Update Section 8.1 (Funding) - New Support Lines

### Add subsection: "European Cybersecurity Funding 2026-2027"

**Identified European Programs (2026-2027):**

1. **Digital Europe Programme - Cybersecurity (€390M total)**
   - **UPTAKE Program**: €15M for SMEs on regulatory compliance
   - **AI-Based Cybersecurity**: €15M for AI-based security systems
   - Co-financing: up to 75% of project costs
   - Projects: €3-5M per project, 36-month duration
   - **Eligibility**: Bakery-IA qualifies as a technology SME
   - **Estimated application**: €200,000 (Q1 2026)

2. **INCIBE EMPRENDE Program (España Digital 2026)**
   - Budget: €191M (2023-2026)
   - 34 collaborating entities in Spain
   - Ideation, incubation and acceleration in cybersecurity
   - Next Generation EU funds
   - **Estimated application**: €50,000 (Q2 2026)

3. **ENISA Emprendedoras Digitales**
   - Participative loans, with up to €51M available for mobilization
   - Dedicated lines for digital entrepreneurship
   - No personal guarantees, flexible terms
   - **Estimated application**: €75,000 (Q2 2026)

**Alignment with EU Priorities:**
- NIS2 Directive (food sector security)
- Cyber Resilience Act (secure digital products)
- AI Act (transparent and auditable AI)
- GDPR (data protection by design)
## 4. Update Section 5.2 (Technical Architecture)

### Add details from the technical documentation:

**Microservices Architecture (21 independent services):**
- Centralized API Gateway with JWT and Redis caching (95% hit rate)
- Frontend: React 18 + TypeScript (mobile-first PWA)
- 18 PostgreSQL 17 databases (database-per-service pattern)
- Redis 7.4 for caching and RabbitMQ 4.1 for events
- Kubernetes on a VPS (horizontal scalability)

**Notable Technical Innovation:**
- **Enriched Alert System** (3 levels: Alerts/Notifications/Recommendations)
- **Intelligent Prioritization** with 0-100 scoring (4 weighted factors)
- **Time-Based Escalation** (+10 at 48h, +20 at 72h, +30 near the deadline)
- **Causal Chaining** (stock shortage → production delay → order at risk)
- **Deduplication** (95% reduction in alert spam)
- **SSE + WebSocket** for real-time updates
- **Prophet ML** with 20+ features (AEMET, Madrid traffic, public holidays)

**Enterprise-Grade Observability:**
- **SigNoz**: Unified traces, metrics and logs
- **OpenTelemetry**: Auto-instrumentation of 18 services
- **ClickHouse**: High-performance analytics backend
- **Alerting**: Multi-channel (email, Slack) via AlertManager
- Monitoring: 18 PostgreSQL DBs, Redis, RabbitMQ, Kubernetes

## 5. Update Executive Summary (Section 0)

### Add bullet:

- **Security and Compliance:** Security-First architecture with GDPR compliance, full observability (OpenTelemetry/SigNoz), and alignment with European regulations (NIS2, Cyber Resilience Act). Eligible for the €390M Digital Europe Programme cybersecurity budget.

## 6. Update Section 9 (Decalogue) - Add Funding Opportunities

**6. Strategic Alignment with European Priorities 2026-2027:**
- **€390M available** in the Digital Europe Programme for cybersecurity
- **€191M** from the INCIBE EMPRENDE program (España Digital 2026)
- Bakery-IA qualifies for **3 simultaneous funding lines**:
  * UPTAKE (€200K) - SME regulatory compliance
  * INCIBE EMPRENDE (€50K) - Cybersecurity startup acceleration
  * ENISA Digital (€75K) - Participative loan
- **Total potential**: €325,000 in additional non-dilutive funding
- Competitive advantage: Security-First vs. legacy competitors
## 7. New Operating Costs Table - Add Security Lines

| **Security and Compliance** | | | |
| SSL certificate (Let's Encrypt) | Free | €0 | Automatic renewal |
| SigNoz Observability (self-hosted) | Included in VPS | €0 | vs. €500+/month as SaaS |
| Annual GDPR audit | External | €1,200/year | Mandatory compliance |
| Encrypted backups (off-site) | Backblaze B2 | €5/month | €60/year |

## 8. Update Roadmap (Section 9) - Add Security Milestones

**Q1 2026:**
- ✅ MFA (Multi-Factor Authentication) implementation
- ✅ Digital Europe Programme application (UPTAKE)
- ✅ External GDPR audit

**Q2 2026:**
- 📋 ISO 27001 certification (process kick-off)
- 📋 NIS2 compliance implementation
- 📋 INCIBE EMPRENDE application

**Q3 2026:**
- 📋 External penetration testing
- 📋 ENS certification (Esquema Nacional de Seguridad)

## 9. Update the Specific Request to VUE (Section 9.2)

### Add new point 5:

**5. Connection with European Cybersecurity Programs:**
- Guidance for the Digital Europe Programme application (UPTAKE: €200K)
- Introduction to INCIBE EMPRENDE and its network of 34 collaborating entities
- Advice on preparing technical proposals for EU funding
- Contact with CDTI for the NEOTEC program (R&D + cybersecurity)

## 10. Add Annex 7 - Compliance and Certifications

## ANNEX 7: COMPLIANCE AND CERTIFICATIONS ROADMAP

**Applicable Regulations:**
- ✅ GDPR (General Data Protection Regulation) - Implemented
- 📋 NIS2 Directive (network and information systems security) - Q2 2026
- 📋 Cyber Resilience Act (secure digital products) - Q3 2026
- 📋 AI Act (AI transparency and auditability) - Q4 2026

**Planned Certifications:**
- 📋 ISO 27001 (Information Security Management) - 12-18 months
- 📋 ENS Medium level (Esquema Nacional de Seguridad) - 6-9 months
- 📋 SOC 2 Type II (for Enterprise customers) - 18-24 months

**Estimated Compliance Investment (3 years):** €25,000
**Expected ROI:** Access to Enterprise customers (+€150K potential ARR)

## Summary of Quantitative Changes:
- New funding identified: €325,000 (vs. €18,000 originally)
- New programs: 3 European cybersecurity lines
- New sections: 2 (Security 5.3, Compliance Annex 7)
- Updates: 8 existing sections improved
- Competitive advantage: Security-First emphasized in 4 places
File diff suppressed because it is too large

Tiltfile (38 lines changed)
@@ -7,6 +7,13 @@
# - PostgreSQL pgcrypto extension and audit logging
# - Organized resource dependencies and live-reload capabilities
# - Local registry for faster image builds and deployments
#
# Build Optimization:
# - Services only rebuild when their specific code changes (not all services)
# - Shared folder changes trigger rebuild of ALL services (as they all depend on it)
# - Uses 'only' parameter to watch only relevant files per service
# - Frontend only rebuilds when frontend/ code changes
# - Gateway only rebuilds when gateway/ or shared/ code changes
# =============================================================================

# =============================================================================
@@ -197,16 +204,25 @@ k8s_yaml(kustomize('infrastructure/kubernetes/overlays/dev'))
# =============================================================================

# Helper function for Python services with live updates
# This function ensures services only rebuild when their specific code changes,
# but all services rebuild when shared/ folder changes
def build_python_service(service_name, service_path):
    docker_build(
        'bakery/' + service_name,
        context='.',
        dockerfile='./services/' + service_path + '/Dockerfile',
        # Only watch files relevant to this specific service + shared code
        only=[
            './services/' + service_path,
            './shared',
            './scripts',
        ],
        live_update=[
            # Fall back to full image build if Dockerfile or requirements change
            fall_back_on([
                './services/' + service_path + '/Dockerfile',
                './services/' + service_path + '/requirements.txt',
                './shared/requirements-tracing.txt',
            ]),

            # Sync service code
@@ -290,10 +306,21 @@ docker_build(
    'bakery/gateway',
    context='.',
    dockerfile='./gateway/Dockerfile',
    # Only watch gateway-specific files and shared code
    only=[
        './gateway',
        './shared',
        './scripts',
    ],
    live_update=[
        fall_back_on([
            './gateway/Dockerfile',
            './gateway/requirements.txt',
            './shared/requirements-tracing.txt',
        ]),
        sync('./gateway', '/app'),
        sync('./shared', '/app/shared'),
        sync('./scripts', '/app/scripts'),
        run('kill -HUP 1', trigger=['./gateway/**/*.py', './shared/**/*.py']),
    ],
    ignore=[
@@ -680,6 +707,13 @@ Documentation:
docs/SECURITY_IMPLEMENTATION_COMPLETE.md
docs/DATABASE_SECURITY_ANALYSIS_REPORT.md

Build Optimization Active:
✅ Services only rebuild when their code changes
✅ Shared folder changes trigger ALL services (as expected)
✅ Reduces unnecessary rebuilds and disk usage
💡 Edit service code: only that service rebuilds
💡 Edit shared/ code: all services rebuild (required)

Useful Commands:
# Work on specific services only
tilt up <service-name> <service-name>
@@ -198,16 +198,27 @@ export const RegisterTenantStep: React.FC<RegisterTenantStepProps> = ({
      // Trigger POI detection in the background (non-blocking)
      // This replaces the removed POI Detection step
      // POI detection will be cached for 90 days and reused during training
      const bakeryLocation = wizardContext.state.bakeryLocation;
      if (bakeryLocation?.latitude && bakeryLocation?.longitude && tenant.id) {
        console.log(`🔍 Triggering background POI detection for tenant ${tenant.id}...`);

        // Run POI detection asynchronously without blocking the wizard flow
        // This ensures POI data is ready before the training step
        poiContextApi.detectPOIs(
          tenant.id,
          bakeryLocation.latitude,
          bakeryLocation.longitude,
          false // force_refresh = false, will use cache if available
        ).then((result) => {
          const source = result.source || 'unknown';
          console.log(`✅ POI detection completed for tenant ${tenant.id} (source: ${source})`);

          if (result.poi_context) {
            const totalPois = result.poi_context.total_pois_detected || 0;
            const relevantCategories = result.poi_context.relevant_categories?.length || 0;
            console.log(`📍 POI Summary: ${totalPois} POIs detected, ${relevantCategories} relevant categories`);
          }

          // Phase 3: Handle calendar suggestion if available
          if (result.calendar_suggestion) {
@@ -230,8 +241,11 @@ export const RegisterTenantStep: React.FC<RegisterTenantStepProps> = ({
          }
        }).catch((error) => {
          console.warn('⚠️ Background POI detection failed (non-blocking):', error);
          console.warn('Training will continue without POI features if detection is not complete.');
          // This is non-critical - training service will continue without POI features
        });
      } else {
        console.warn('⚠️ Cannot trigger POI detection: missing location data or tenant ID');
      }

      // Update the wizard context with tenant info
@@ -352,6 +352,25 @@ headers = {
- **Caching**: Gateway caches validated service tokens for 5 minutes
- **No Additional HTTP Calls**: Service auth happens locally at gateway

### Unified Header Management System

The gateway uses a **centralized HeaderManager** for consistent header handling across all middleware and proxy layers.

**Key Features:**
- Standardized header names and conventions
- Automatic header sanitization to prevent spoofing
- Unified header injection and forwarding
- Cross-middleware header access via `request.state.injected_headers`
- Consistent logging and error handling

**Standard Headers:**
- `x-user-id`, `x-user-email`, `x-user-role`, `x-user-type`
- `x-service-name`, `x-tenant-id`
- `x-subscription-tier`, `x-subscription-status`
- `x-is-demo`, `x-demo-session-id`, `x-demo-account-type`
- `x-tenant-access-type`, `x-can-view-children`, `x-parent-tenant-id`
- `x-forwarded-by`, `x-request-id`
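As a rough usage sketch (not code from the repository; `forward_to_service` and the target URL are made up for illustration), a proxy layer can ask the HeaderManager for the complete outgoing header set instead of assembling it by hand:

```python
# Hypothetical proxy helper built on the HeaderManager introduced in this commit.
import httpx
from fastapi import Request

from app.core.header_manager import header_manager


async def forward_to_service(request: Request, target_url: str) -> httpx.Response:
    # Merges forwardable request headers with the context headers injected
    # by the auth middleware (x-user-id, x-tenant-id, x-request-id, ...).
    headers = header_manager.get_all_headers_for_proxy(request)
    async with httpx.AsyncClient() as client:
        return await client.request(request.method, target_url, headers=headers)
```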
### Context Header Injection

When a service token is validated, the gateway injects these headers for downstream services:
gateway/app/core/header_manager.py (new file, 345 lines)
@@ -0,0 +1,345 @@
|
|||||||
|
"""
|
||||||
|
Unified Header Management System for API Gateway
|
||||||
|
Centralized header injection, forwarding, and validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from fastapi import Request
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class HeaderManager:
|
||||||
|
"""
|
||||||
|
Centralized header management for consistent header handling across gateway
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Standard header names (lowercase for consistency)
|
||||||
|
STANDARD_HEADERS = {
|
||||||
|
'user_id': 'x-user-id',
|
||||||
|
'user_email': 'x-user-email',
|
||||||
|
'user_role': 'x-user-role',
|
||||||
|
'user_type': 'x-user-type',
|
||||||
|
'service_name': 'x-service-name',
|
||||||
|
'tenant_id': 'x-tenant-id',
|
||||||
|
'subscription_tier': 'x-subscription-tier',
|
||||||
|
'subscription_status': 'x-subscription-status',
|
||||||
|
'is_demo': 'x-is-demo',
|
||||||
|
'demo_session_id': 'x-demo-session-id',
|
||||||
|
'demo_account_type': 'x-demo-account-type',
|
||||||
|
'tenant_access_type': 'x-tenant-access-type',
|
||||||
|
'can_view_children': 'x-can-view-children',
|
||||||
|
'parent_tenant_id': 'x-parent-tenant-id',
|
||||||
|
'forwarded_by': 'x-forwarded-by',
|
||||||
|
'request_id': 'x-request-id'
|
||||||
|
}
|
||||||
|
|
||||||
|
# Headers that should be sanitized/removed from incoming requests
|
||||||
|
SANITIZED_HEADERS = [
|
||||||
|
'x-subscription-',
|
||||||
|
'x-user-',
|
||||||
|
'x-tenant-',
|
||||||
|
'x-demo-',
|
||||||
|
'x-forwarded-by'
|
||||||
|
]
|
||||||
|
|
||||||
|
# Headers that should be forwarded to downstream services
|
||||||
|
FORWARDABLE_HEADERS = [
|
||||||
|
'authorization',
|
||||||
|
'content-type',
|
||||||
|
'accept',
|
||||||
|
'accept-language',
|
||||||
|
'user-agent',
|
||||||
|
'x-internal-service' # Required for internal service-to-service ML/alert triggers
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
"""Initialize header manager"""
|
||||||
|
if not self._initialized:
|
||||||
|
logger.info("HeaderManager initialized")
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
def sanitize_incoming_headers(self, request: Request) -> None:
|
||||||
|
"""
|
||||||
|
Remove sensitive headers from incoming request to prevent spoofing
|
||||||
|
"""
|
||||||
|
if not hasattr(request.headers, '_list'):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Filter out headers that start with sanitized prefixes
|
||||||
|
sanitized_headers = [
|
||||||
|
(k, v) for k, v in request.headers.raw
|
||||||
|
if not any(k.decode().lower().startswith(prefix.lower())
|
||||||
|
for prefix in self.SANITIZED_HEADERS)
|
||||||
|
]
|
||||||
|
|
||||||
|
request.headers.__dict__["_list"] = sanitized_headers
|
||||||
|
logger.debug("Sanitized incoming headers")
|
||||||
|
|
||||||
|
def inject_context_headers(self, request: Request, user_context: Dict[str, Any],
|
||||||
|
tenant_id: Optional[str] = None) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Inject standardized context headers into request
|
||||||
|
Returns dict of injected headers for reference
|
||||||
|
"""
|
||||||
|
injected_headers = {}
|
||||||
|
|
||||||
|
# Ensure headers list exists
|
||||||
|
if not hasattr(request.headers, '_list'):
|
||||||
|
request.headers.__dict__["_list"] = []
|
||||||
|
|
||||||
|
# Store headers in request.state for cross-middleware access
|
||||||
|
request.state.injected_headers = {}
|
||||||
|
|
||||||
|
# User context headers
|
||||||
|
if user_context.get('user_id'):
|
||||||
|
header_name = self.STANDARD_HEADERS['user_id']
|
||||||
|
header_value = str(user_context['user_id'])
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
if user_context.get('email'):
|
||||||
|
header_name = self.STANDARD_HEADERS['user_email']
|
||||||
|
header_value = str(user_context['email'])
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
if user_context.get('role'):
|
||||||
|
header_name = self.STANDARD_HEADERS['user_role']
|
||||||
|
header_value = str(user_context['role'])
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
# User type (service vs regular user)
|
||||||
|
if user_context.get('type'):
|
||||||
|
header_name = self.STANDARD_HEADERS['user_type']
|
||||||
|
header_value = str(user_context['type'])
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
# Service name for service tokens
|
||||||
|
if user_context.get('service'):
|
||||||
|
header_name = self.STANDARD_HEADERS['service_name']
|
||||||
|
header_value = str(user_context['service'])
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
# Tenant context
|
||||||
|
if tenant_id:
|
||||||
|
header_name = self.STANDARD_HEADERS['tenant_id']
|
||||||
|
header_value = str(tenant_id)
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
# Subscription context
|
||||||
|
if user_context.get('subscription_tier'):
|
||||||
|
header_name = self.STANDARD_HEADERS['subscription_tier']
|
||||||
|
header_value = str(user_context['subscription_tier'])
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
if user_context.get('subscription_status'):
|
||||||
|
header_name = self.STANDARD_HEADERS['subscription_status']
|
||||||
|
header_value = str(user_context['subscription_status'])
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
# Demo session context
|
||||||
|
is_demo = user_context.get('is_demo', False)
|
||||||
|
if is_demo:
|
||||||
|
header_name = self.STANDARD_HEADERS['is_demo']
|
||||||
|
header_value = "true"
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
if user_context.get('demo_session_id'):
|
||||||
|
header_name = self.STANDARD_HEADERS['demo_session_id']
|
||||||
|
header_value = str(user_context['demo_session_id'])
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
if user_context.get('demo_account_type'):
|
||||||
|
header_name = self.STANDARD_HEADERS['demo_account_type']
|
||||||
|
header_value = str(user_context['demo_account_type'])
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
# Hierarchical access context
|
||||||
|
if tenant_id:
|
||||||
|
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
|
||||||
|
can_view_children = getattr(request.state, 'can_view_children', False)
|
||||||
|
|
||||||
|
header_name = self.STANDARD_HEADERS['tenant_access_type']
|
||||||
|
header_value = str(tenant_access_type)
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
header_name = self.STANDARD_HEADERS['can_view_children']
|
||||||
|
header_value = str(can_view_children).lower()
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
# Parent tenant ID if hierarchical access
|
||||||
|
parent_tenant_id = getattr(request.state, 'parent_tenant_id', None)
|
||||||
|
if parent_tenant_id:
|
||||||
|
header_name = self.STANDARD_HEADERS['parent_tenant_id']
|
||||||
|
header_value = str(parent_tenant_id)
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
# Gateway identification
|
||||||
|
header_name = self.STANDARD_HEADERS['forwarded_by']
|
||||||
|
header_value = "bakery-gateway"
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
# Request ID if available
|
||||||
|
request_id = getattr(request.state, 'request_id', None)
|
||||||
|
if request_id:
|
||||||
|
header_name = self.STANDARD_HEADERS['request_id']
|
||||||
|
header_value = str(request_id)
|
||||||
|
self._add_header(request, header_name, header_value)
|
||||||
|
injected_headers[header_name] = header_value
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
logger.info("🔧 Injected context headers",
|
||||||
|
user_id=user_context.get('user_id'),
|
||||||
|
user_type=user_context.get('type', ''),
|
||||||
|
service_name=user_context.get('service', ''),
|
||||||
|
role=user_context.get('role', ''),
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
is_demo=is_demo,
|
||||||
|
demo_session_id=user_context.get('demo_session_id', ''),
|
||||||
|
path=request.url.path)
|
||||||
|
|
||||||
|
return injected_headers
|
||||||
|
|
||||||
|
def _add_header(self, request: Request, header_name: str, header_value: str) -> None:
|
||||||
|
"""
|
||||||
|
Safely add header to request
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to add header {header_name}: {e}")
|
||||||
|
|
||||||
|
def get_forwardable_headers(self, request: Request) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Get headers that should be forwarded to downstream services
|
||||||
|
Includes both original request headers and injected context headers
|
||||||
|
"""
|
||||||
|
forwardable_headers = {}
|
||||||
|
|
||||||
|
# Add forwardable original headers
|
||||||
|
for header_name in self.FORWARDABLE_HEADERS:
|
||||||
|
header_value = request.headers.get(header_name)
|
||||||
|
if header_value:
|
||||||
|
forwardable_headers[header_name] = header_value
|
||||||
|
|
||||||
|
# Add injected context headers from request.state
|
||||||
|
if hasattr(request.state, 'injected_headers'):
|
||||||
|
for header_name, header_value in request.state.injected_headers.items():
|
||||||
|
forwardable_headers[header_name] = header_value
|
||||||
|
|
||||||
|
# Add authorization header if present
|
||||||
|
auth_header = request.headers.get('authorization')
|
||||||
|
if auth_header:
|
||||||
|
forwardable_headers['authorization'] = auth_header
|
||||||
|
|
||||||
|
return forwardable_headers
|
||||||
|
|
||||||
|
def get_all_headers_for_proxy(self, request: Request,
|
||||||
|
additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Get complete set of headers for proxying to downstream services
|
||||||
|
"""
|
||||||
|
headers = self.get_forwardable_headers(request)
|
||||||
|
|
||||||
|
# Add any additional headers
|
||||||
|
if additional_headers:
|
||||||
|
headers.update(additional_headers)
|
||||||
|
|
||||||
|
# Remove host header as it will be set by httpx
|
||||||
|
headers.pop('host', None)
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Validate that required headers are present
|
||||||
|
"""
|
||||||
|
missing_headers = []
|
||||||
|
|
||||||
|
for header_name in required_headers:
|
||||||
|
# Check in injected headers first
|
||||||
|
if hasattr(request.state, 'injected_headers'):
|
||||||
|
if header_name in request.state.injected_headers:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check in request headers
|
||||||
|
if request.headers.get(header_name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
missing_headers.append(header_name)
|
||||||
|
|
||||||
|
if missing_headers:
|
||||||
|
logger.warning(f"Missing required headers: {missing_headers}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_header_value(self, request: Request, header_name: str,
|
||||||
|
default: Optional[str] = None) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get header value from either injected headers or request headers
|
||||||
|
"""
|
||||||
|
# Check injected headers first
|
||||||
|
if hasattr(request.state, 'injected_headers'):
|
||||||
|
if header_name in request.state.injected_headers:
|
||||||
|
return request.state.injected_headers[header_name]
|
||||||
|
|
||||||
|
# Check request headers
|
||||||
|
return request.headers.get(header_name, default)
|
||||||
|
|
||||||
|
def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None:
|
||||||
|
"""
|
||||||
|
Allow middleware to add headers to the unified header system
|
||||||
|
This ensures all headers are available for proxying
|
||||||
|
"""
|
||||||
|
# Ensure injected_headers exists
|
||||||
|
if not hasattr(request.state, 'injected_headers'):
|
||||||
|
request.state.injected_headers = {}
|
||||||
|
|
||||||
|
# Add header to injected_headers
|
||||||
|
request.state.injected_headers[header_name] = header_value
|
||||||
|
|
||||||
|
# Also add to actual request headers for compatibility
|
||||||
|
try:
|
||||||
|
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to add header {header_name} to request headers: {e}")
|
||||||
|
|
||||||
|
logger.debug(f"Middleware added header: {header_name} = {header_value}")
|
||||||
|
|
||||||
|
|
||||||
|
# Global instance for easy access
|
||||||
|
header_manager = HeaderManager()
|
||||||
@@ -16,6 +16,7 @@ from shared.redis_utils import initialize_redis, close_redis, get_redis_client
|
|||||||
from shared.service_base import StandardFastAPIService
|
from shared.service_base import StandardFastAPIService
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.core.header_manager import header_manager
|
||||||
from app.middleware.request_id import RequestIDMiddleware
|
from app.middleware.request_id import RequestIDMiddleware
|
||||||
from app.middleware.auth import AuthMiddleware
|
from app.middleware.auth import AuthMiddleware
|
||||||
from app.middleware.logging import LoggingMiddleware
|
from app.middleware.logging import LoggingMiddleware
|
||||||
@@ -50,6 +51,10 @@ class GatewayService(StandardFastAPIService):
|
|||||||
"""Custom startup logic for Gateway"""
|
"""Custom startup logic for Gateway"""
|
||||||
global redis_client
|
global redis_client
|
||||||
|
|
||||||
|
# Initialize HeaderManager
|
||||||
|
header_manager.initialize()
|
||||||
|
logger.info("HeaderManager initialized")
|
||||||
|
|
||||||
# Initialize Redis
|
# Initialize Redis
|
||||||
try:
|
try:
|
||||||
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
|
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import httpx
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.core.header_manager import header_manager
|
||||||
from shared.auth.jwt_handler import JWTHandler
|
from shared.auth.jwt_handler import JWTHandler
|
||||||
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
|
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
|
||||||
|
|
||||||
@@ -60,15 +61,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
# SECURITY: Remove any incoming x-subscription-* headers
|
# SECURITY: Remove any incoming sensitive headers using HeaderManager
|
||||||
# These will be re-injected from verified JWT only
|
header_manager.sanitize_incoming_headers(request)
|
||||||
sanitized_headers = [
|
|
||||||
(k, v) for k, v in request.headers.raw
|
|
||||||
if not k.decode().lower().startswith('x-subscription-')
|
|
||||||
and not k.decode().lower().startswith('x-user-')
|
|
||||||
and not k.decode().lower().startswith('x-tenant-')
|
|
||||||
]
|
|
||||||
request.headers.__dict__["_list"] = sanitized_headers
|
|
||||||
|
|
||||||
# Skip authentication for public routes
|
# Skip authentication for public routes
|
||||||
if self._is_public_route(request.url.path):
|
if self._is_public_route(request.url.path):
|
||||||
@@ -573,109 +567,13 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
|
async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Inject user and tenant context headers for downstream services
|
Inject user and tenant context headers for downstream services using unified HeaderManager
|
||||||
ENHANCED: Added logging to verify header injection
|
|
||||||
"""
|
"""
|
||||||
# Enhanced logging for debugging
|
# Use unified HeaderManager for consistent header injection
|
||||||
logger.info(
|
injected_headers = header_manager.inject_context_headers(request, user_context, tenant_id)
|
||||||
"🔧 Injecting context headers",
|
|
||||||
user_id=user_context.get("user_id"),
|
|
||||||
user_type=user_context.get("type", ""),
|
|
||||||
service_name=user_context.get("service", ""),
|
|
||||||
role=user_context.get("role", ""),
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
is_demo=user_context.get("is_demo", False),
|
|
||||||
demo_session_id=user_context.get("demo_session_id", ""),
|
|
||||||
path=request.url.path
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add user context headers
|
|
||||||
logger.debug(f"DEBUG: Injecting headers for user: {user_context.get('user_id')}, is_demo: {user_context.get('is_demo', False)}")
|
|
||||||
logger.debug(f"DEBUG: request.headers object id: {id(request.headers)}, _list id: {id(request.headers.__dict__.get('_list', []))}")
|
|
||||||
|
|
||||||
# Store headers in request.state for cross-middleware access
|
|
||||||
request.state.injected_headers = {
|
|
||||||
"x-user-id": user_context["user_id"],
|
|
||||||
"x-user-email": user_context["email"],
|
|
||||||
"x-user-role": user_context.get("role", "user")
|
|
||||||
}
|
|
||||||
|
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-user-id", user_context["user_id"].encode()
|
|
||||||
))
|
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-user-email", user_context["email"].encode()
|
|
||||||
))
|
|
||||||
|
|
||||||
user_role = user_context.get("role", "user")
|
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-user-role", user_role.encode()
|
|
||||||
))
|
|
||||||
|
|
||||||
user_type = user_context.get("type", "")
|
|
||||||
if user_type:
|
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-user-type", user_type.encode()
|
|
||||||
))
|
|
||||||
|
|
||||||
service_name = user_context.get("service", "")
|
|
||||||
if service_name:
|
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-service-name", service_name.encode()
|
|
||||||
))
|
|
||||||
|
|
||||||
# Add tenant context if available
|
|
||||||
if tenant_id:
|
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-tenant-id", tenant_id.encode()
|
|
||||||
))
|
|
||||||
|
|
||||||
# Add subscription tier if available
|
|
||||||
subscription_tier = user_context.get("subscription_tier", "")
|
|
||||||
if subscription_tier:
|
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-subscription-tier", subscription_tier.encode()
|
|
||||||
))
|
|
||||||
|
|
||||||
# Add is_demo flag for demo sessions
|
|
||||||
is_demo = user_context.get("is_demo", False)
|
|
||||||
logger.debug(f"DEBUG: is_demo value: {is_demo}, type: {type(is_demo)}")
|
|
||||||
if is_demo:
|
|
||||||
logger.info(f"🎭 Adding demo session headers",
|
|
||||||
demo_session_id=user_context.get("demo_session_id", ""),
|
|
||||||
demo_account_type=user_context.get("demo_account_type", ""),
|
|
||||||
path=request.url.path)
|
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-is-demo", b"true"
|
|
||||||
))
|
|
||||||
else:
|
|
||||||
logger.debug(f"DEBUG: Not adding demo headers because is_demo is: {is_demo}")
|
|
||||||
|
|
||||||
# Add demo session context headers for backend services
|
|
||||||
demo_session_id = user_context.get("demo_session_id", "")
|
|
||||||
if demo_session_id:
|
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-demo-session-id", demo_session_id.encode()
|
|
||||||
))
|
|
||||||
|
|
||||||
demo_account_type = user_context.get("demo_account_type", "")
|
|
||||||
if demo_account_type:
|
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-demo-account-type", demo_account_type.encode()
|
|
||||||
))
|
|
||||||
|
|
||||||
# Add hierarchical access headers if tenant context exists
|
# Add hierarchical access headers if tenant context exists
|
||||||
if tenant_id:
|
if tenant_id:
|
||||||
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
|
|
||||||
can_view_children = getattr(request.state, 'can_view_children', False)
|
|
||||||
|
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-tenant-access-type", tenant_access_type.encode()
|
|
||||||
))
|
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-can-view-children", str(can_view_children).encode()
|
|
||||||
))
|
|
||||||
|
|
||||||
# If this is hierarchical access, include parent tenant ID
|
# If this is hierarchical access, include parent tenant ID
|
||||||
# Get parent tenant ID from the auth service if available
|
# Get parent tenant ID from the auth service if available
|
||||||
try:
|
try:
|
||||||
@@ -689,17 +587,16 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
hierarchy_data = response.json()
|
hierarchy_data = response.json()
|
||||||
parent_tenant_id = hierarchy_data.get("parent_tenant_id")
|
parent_tenant_id = hierarchy_data.get("parent_tenant_id")
|
||||||
if parent_tenant_id:
|
if parent_tenant_id:
|
||||||
request.headers.__dict__["_list"].append((
|
# Add parent tenant ID using HeaderManager for consistency
|
||||||
b"x-parent-tenant-id", parent_tenant_id.encode()
|
header_name = header_manager.STANDARD_HEADERS['parent_tenant_id']
|
||||||
))
|
header_value = str(parent_tenant_id)
|
||||||
|
header_manager.add_header_for_middleware(request, header_name, header_value)
|
||||||
|
logger.info(f"Added parent tenant ID header: {parent_tenant_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get parent tenant ID: {e}")
|
logger.warning(f"Failed to get parent tenant ID: {e}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Add gateway identification
|
return injected_headers
|
||||||
request.headers.__dict__["_list"].append((
|
|
||||||
b"x-forwarded-by", b"bakery-gateway"
|
|
||||||
))
|
|
||||||
|
|
||||||
async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
|
async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -45,8 +45,17 @@ class APIRateLimitMiddleware(BaseHTTPMiddleware):
|
|||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get subscription tier
|
# Get subscription tier from headers (added by AuthMiddleware)
|
||||||
|
subscription_tier = request.headers.get("x-subscription-tier")
|
||||||
|
|
||||||
|
if not subscription_tier:
|
||||||
|
# Fallback: get from request state if headers not available
|
||||||
|
subscription_tier = getattr(request.state, "subscription_tier", None)
|
||||||
|
|
||||||
|
if not subscription_tier:
|
||||||
|
# Final fallback: get from tenant service (should rarely happen)
|
||||||
subscription_tier = await self._get_subscription_tier(tenant_id, request)
|
subscription_tier = await self._get_subscription_tier(tenant_id, request)
|
||||||
|
logger.warning(f"Subscription tier not found in headers or state, fetched from tenant service: {subscription_tier}")
|
||||||
|
|
||||||
# Get quota limit for tier
|
# Get quota limit for tier
|
||||||
quota_limit = self._get_quota_limit(subscription_tier)
|
quota_limit = self._get_quota_limit(subscription_tier)
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ from fastapi import Request
|
|||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
|
from app.core.header_manager import header_manager
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -40,11 +42,9 @@ class RequestIDMiddleware(BaseHTTPMiddleware):
|
|||||||
# Bind request ID to structured logger context
|
# Bind request ID to structured logger context
|
||||||
logger_ctx = logger.bind(request_id=request_id)
|
logger_ctx = logger.bind(request_id=request_id)
|
||||||
|
|
||||||
# Inject request ID header for downstream services
|
# Inject request ID header for downstream services using HeaderManager
|
||||||
# This is done by modifying the headers that will be forwarded
|
# Note: This runs early in middleware chain, so we use add_header_for_middleware
|
||||||
request.headers.__dict__["_list"].append((
|
header_manager.add_header_for_middleware(request, "x-request-id", request_id)
|
||||||
b"x-request-id", request_id.encode()
|
|
||||||
))
|
|
||||||
|
|
||||||
# Log request start
|
# Log request start
|
||||||
logger_ctx.info(
|
logger_ctx.info(
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import asyncio
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.core.header_manager import header_manager
|
||||||
from app.utils.subscription_error_responses import create_upgrade_required_response
|
from app.utils.subscription_error_responses import create_upgrade_required_response
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
@@ -178,7 +179,10 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
|
|||||||
r'/api/v1/subscriptions/.*', # Subscription management itself
|
r'/api/v1/subscriptions/.*', # Subscription management itself
|
||||||
r'/api/v1/tenants/[^/]+/members.*', # Basic tenant info
|
r'/api/v1/tenants/[^/]+/members.*', # Basic tenant info
|
||||||
r'/docs.*',
|
r'/docs.*',
|
||||||
r'/openapi\.json'
|
r'/openapi\.json',
|
||||||
|
# Training monitoring endpoints (WebSocket and status checks)
|
||||||
|
r'/api/v1/tenants/[^/]+/training/jobs/.*/live.*', # WebSocket endpoint
|
||||||
|
r'/api/v1/tenants/[^/]+/training/jobs/.*/status.*', # Status polling endpoint
|
||||||
]
|
]
|
||||||
|
|
||||||
# Skip OPTIONS requests (CORS preflight)
|
# Skip OPTIONS requests (CORS preflight)
|
||||||
@@ -275,21 +279,11 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
|
|||||||
'current_tier': current_tier
|
'current_tier': current_tier
|
||||||
}
|
}
|
||||||
|
|
||||||
# Use the same authentication pattern as gateway routes for fallback
|
# Use unified HeaderManager for consistent header handling
|
||||||
headers = dict(request.headers)
|
headers = header_manager.get_all_headers_for_proxy(request)
|
||||||
headers.pop("host", None)
|
|
||||||
|
|
||||||
# Extract user_id for logging (fallback path)
|
# Extract user_id for logging (fallback path)
|
||||||
user_id = 'unknown'
|
user_id = header_manager.get_header_value(request, 'x-user-id', 'unknown')
|
||||||
# Add user context headers if available
|
|
||||||
if hasattr(request.state, 'user') and request.state.user:
|
|
||||||
user = request.state.user
|
|
||||||
user_id = str(user.get('user_id', 'unknown'))
|
|
||||||
headers["x-user-id"] = user_id
|
|
||||||
headers["x-user-email"] = str(user.get('email', ''))
|
|
||||||
headers["x-user-role"] = str(user.get('role', 'user'))
|
|
||||||
headers["x-user-full-name"] = str(user.get('full_name', ''))
|
|
||||||
headers["x-tenant-id"] = str(user.get('tenant_id', ''))
|
|
||||||
|
|
||||||
# Call tenant service fast tier endpoint with caching (fallback for old tokens)
|
# Call tenant service fast tier endpoint with caching (fallback for old tokens)
|
||||||
timeout_config = httpx.Timeout(
|
timeout_config = httpx.Timeout(
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse
|
|||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.core.header_manager import header_manager
|
||||||
from app.core.service_discovery import ServiceDiscovery
|
from app.core.service_discovery import ServiceDiscovery
|
||||||
from shared.monitoring.metrics import MetricsCollector
|
from shared.monitoring.metrics import MetricsCollector
|
||||||
|
|
||||||
@@ -136,107 +137,32 @@ class AuthProxy:
return AUTH_SERVICE_URL

def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
- """Prepare headers for forwarding (remove hop-by-hop headers)"""
+ """Prepare headers for forwarding using unified HeaderManager"""
- # Remove hop-by-hop headers
+ # Use unified HeaderManager to get all headers
- hop_by_hop_headers = {
+ if request:
- 'connection', 'keep-alive', 'proxy-authenticate',
+ all_headers = header_manager.get_all_headers_for_proxy(request)
- 'proxy-authorization', 'te', 'trailers', 'upgrade'
+ logger.debug(f"DEBUG: Added headers from HeaderManager: {list(all_headers.keys())}")
- }
+ else:
+ # Fallback: convert headers to dict manually
- # Convert headers to dict - get ALL headers including those added by middleware
- # Middleware adds headers to _list, so we need to read from there
- logger.debug(f"DEBUG: headers type: {type(headers)}, has _list: {hasattr(headers, '_list')}, has raw: {hasattr(headers, 'raw')}")
- logger.debug(f"DEBUG: headers.__dict__ keys: {list(headers.__dict__.keys())}")
- logger.debug(f"DEBUG: '_list' in headers.__dict__: {'_list' in headers.__dict__}")
- if hasattr(headers, '_list'):
- logger.debug(f"DEBUG: Entering _list branch")
- logger.debug(f"DEBUG: headers object id: {id(headers)}, _list id: {id(headers.__dict__.get('_list', []))}")
- # Get headers from the _list where middleware adds them
- all_headers_list = headers.__dict__.get('_list', [])
- logger.debug(f"DEBUG: _list length: {len(all_headers_list)}")
- # Debug: Show first few headers in the list
- debug_headers = []
- for i, (k, v) in enumerate(all_headers_list):
- if i < 5:  # Show first 5 headers for debugging
- key = k.decode() if isinstance(k, bytes) else k
- value = v.decode() if isinstance(v, bytes) else v
- debug_headers.append(f"{key}: {value}")
- logger.debug(f"DEBUG: First headers in _list: {debug_headers}")
- # Convert to dict for easier processing
all_headers = {}
- for k, v in all_headers_list:
+ if hasattr(headers, '_list'):
+ for k, v in headers.__dict__.get('_list', []):
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
- # Debug: Show if x-user-id and x-is-demo are in the dict
- logger.debug(f"DEBUG: x-user-id in all_headers: {'x-user-id' in all_headers}, x-is-demo in all_headers: {'x-is-demo' in all_headers}")
- logger.debug(f"DEBUG: all_headers keys: {list(all_headers.keys())[:10]}...")  # Show first 10 keys
- logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
- # Check if headers are missing and try to get them from request.state
- if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
- logger.debug(f"DEBUG: Found injected_headers in request.state: {request.state.injected_headers}")
- # Add missing headers from request.state
- if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
- all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
- logger.debug(f"DEBUG: Added x-user-id from request.state: {all_headers['x-user-id']}")
- if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
- all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
- logger.debug(f"DEBUG: Added x-user-email from request.state: {all_headers['x-user-email']}")
- if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
- all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
- logger.debug(f"DEBUG: Added x-user-role from request.state: {all_headers['x-user-role']}")
- # Add is_demo flag if this is a demo session
- if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
- all_headers['x-is-demo'] = 'true'
- logger.debug(f"DEBUG: Added x-is-demo from request.state.is_demo_session")
- # Filter out hop-by-hop headers
- filtered_headers = {
- k: v for k, v in all_headers.items()
- if k.lower() not in hop_by_hop_headers
- }
elif hasattr(headers, 'raw'):
- logger.debug(f"DEBUG: Entering raw branch")
+ for k, v in headers.raw:
+ key = k.decode() if isinstance(k, bytes) else k
- # Filter out hop-by-hop headers
+ value = v.decode() if isinstance(v, bytes) else v
- filtered_headers = {
+ all_headers[key] = value
- k: v for k, v in all_headers.items()
- if k.lower() not in hop_by_hop_headers
- }
- elif hasattr(headers, 'raw'):
- # Fallback to raw headers if _list not available
- all_headers = {
- k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
- for k, v in headers.raw
- }
- logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
- filtered_headers = {
- k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
- for k, v in headers.raw
- if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
- }
else:
- # Handle case where headers is already a dict
+ # Headers is already a dict
- logger.info(f"📤 Forwarding headers to auth service - x_user_id: {headers.get('x-user-id', 'MISSING')}, x_is_demo: {headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {headers.get('x-demo-session-id', 'MISSING')}, headers: {list(headers.keys())}")
+ all_headers = dict(headers)
- filtered_headers = {
+ # Debug logging
- k: v for k, v in headers.items()
+ logger.info(f"📤 Forwarding headers - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
- if k.lower() not in hop_by_hop_headers
- }
- # Add gateway identifier
+ return all_headers
- filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
- filtered_headers['X-Gateway-Version'] = '1.0.0'
- return filtered_headers

def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
"""Prepare response headers"""
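The `header_manager` module itself is outside this diff, so its exact behaviour is not visible here. A minimal sketch of the contract the gateway proxies now rely on might look like the following; everything beyond the imported name and the call `header_manager.get_all_headers_for_proxy(request)` (the hop-by-hop set, the `request.state.injected_headers` fallback, the gateway tag) is an assumption, not the actual implementation.

```python
# Hypothetical sketch of a HeaderManager-style helper; not the real app.core.header_manager.
from typing import Dict
from fastapi import Request

HOP_BY_HOP = {
    "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
    "te", "trailers", "upgrade", "host", "content-length",
}

class HeaderManager:
    def get_all_headers_for_proxy(self, request: Request) -> Dict[str, str]:
        # Start from the incoming request headers, dropping hop-by-hop entries
        headers = {k: v for k, v in request.headers.items() if k.lower() not in HOP_BY_HOP}
        # Merge headers injected by middleware via request.state (assumed convention)
        headers.update(getattr(request.state, "injected_headers", {}) or {})
        # Tag the request so downstream services can identify the gateway
        headers["x-forwarded-by"] = "bakery-gateway"
        return headers

header_manager = HeaderManager()
```

Centralising this logic removes the `_list`/`raw` introspection and debug logging that each proxy previously repeated.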
@@ -8,6 +8,7 @@ import httpx
import structlog

from app.core.config import settings
+ from app.core.header_manager import header_manager

logger = structlog.get_logger()
@@ -29,12 +30,8 @@ async def proxy_demo_service(path: str, request: Request):
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()

- # Forward headers (excluding host)
+ # Use unified HeaderManager for consistent header forwarding
- headers = {
+ headers = header_manager.get_all_headers_for_proxy(request)
- key: value
- for key, value in request.headers.items()
- if key.lower() not in ["host", "content-length"]
- }

try:
async with httpx.AsyncClient(timeout=30.0) as client:
@@ -5,6 +5,7 @@ from fastapi.responses import JSONResponse
import httpx
import structlog
from app.core.config import settings
+ from app.core.header_manager import header_manager

logger = structlog.get_logger()
router = APIRouter()
@@ -26,12 +27,8 @@ async def proxy_geocoding(request: Request, path: str):
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()

- # Forward headers (excluding host)
+ # Use unified HeaderManager for consistent header forwarding
- headers = {
+ headers = header_manager.get_all_headers_for_proxy(request)
- key: value
- for key, value in request.headers.items()
- if key.lower() not in ["host", "content-length"]
- }

# Make the proxied request
async with httpx.AsyncClient(timeout=30.0) as client:
@@ -8,6 +8,7 @@ from fastapi.responses import JSONResponse
import httpx
import structlog
from app.core.config import settings
+ from app.core.header_manager import header_manager

logger = structlog.get_logger()
router = APIRouter()
@@ -44,12 +45,8 @@ async def proxy_poi_context(request: Request, path: str):
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()

- # Copy headers (exclude host and content-length as they'll be set by httpx)
+ # Use unified HeaderManager for consistent header forwarding
- headers = {
+ headers = header_manager.get_all_headers_for_proxy(request)
- key: value
- for key, value in request.headers.items()
- if key.lower() not in ["host", "content-length"]
- }

# Make the request to the external service
async with httpx.AsyncClient(timeout=60.0) as client:
@@ -8,6 +8,7 @@ import httpx
import logging

from app.core.config import settings
+ from app.core.header_manager import header_manager

logger = logging.getLogger(__name__)
router = APIRouter()
@@ -45,9 +46,8 @@ async def _proxy_to_pos_service(request: Request, target_path: str):
try:
url = f"{settings.POS_SERVICE_URL}{target_path}"

- # Forward headers
+ # Use unified HeaderManager for consistent header forwarding
- headers = dict(request.headers)
+ headers = header_manager.get_all_headers_for_proxy(request)
- headers.pop("host", None)

# Add query parameters
params = dict(request.query_params)
@@ -9,6 +9,7 @@ import logging
from typing import Optional

from app.core.config import settings
+ from app.core.header_manager import header_manager

logger = logging.getLogger(__name__)
router = APIRouter()
@@ -98,29 +99,13 @@ async def _proxy_request(request: Request, target_path: str, service_url: str):
try:
url = f"{service_url}{target_path}"

- # Forward headers and add user/tenant context
+ # Use unified HeaderManager for consistent header forwarding
- headers = dict(request.headers)
+ headers = header_manager.get_all_headers_for_proxy(request)
- headers.pop("host", None)

- # Add user context headers if available
+ # Debug logging
- if hasattr(request.state, 'user') and request.state.user:
+ user_context = getattr(request.state, 'user', None)
- user = request.state.user
+ if user_context:
- headers["x-user-id"] = str(user.get('user_id', ''))
+ logger.info(f"Forwarding subscription request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
- headers["x-user-email"] = str(user.get('email', ''))
- headers["x-user-role"] = str(user.get('role', 'user'))
- headers["x-user-full-name"] = str(user.get('full_name', ''))
- headers["x-tenant-id"] = str(user.get('tenant_id', ''))
- # Add subscription context headers
- if user.get('subscription_tier'):
- headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
- logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
- if user.get('subscription_status'):
- headers["x-subscription-status"] = str(user.get('subscription_status', ''))
- logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")
- logger.info(f"Forwarding subscription request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, subscription_tier={user.get('subscription_tier', 'not_set')}")
else:
logger.warning(f"No user context available when forwarding subscription request to {url}")
@@ -10,6 +10,7 @@ import logging
from typing import Optional

from app.core.config import settings
+ from app.core.header_manager import header_manager

logger = logging.getLogger(__name__)
router = APIRouter()
@@ -715,36 +716,18 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
try:
url = f"{service_url}{target_path}"

- # Forward headers and add user/tenant context
+ # Use unified HeaderManager for consistent header forwarding
- headers = dict(request.headers)
+ headers = header_manager.get_all_headers_for_proxy(request)
- headers.pop("host", None)

- # Add tenant ID header if provided
+ # Add tenant ID header if provided (override if needed)
if tenant_id:
- headers["X-Tenant-ID"] = tenant_id
+ headers["x-tenant-id"] = tenant_id

- # Add user context headers if available
- if hasattr(request.state, 'user') and request.state.user:
- user = request.state.user
- headers["x-user-id"] = str(user.get('user_id', ''))
- headers["x-user-email"] = str(user.get('email', ''))
- headers["x-user-role"] = str(user.get('role', 'user'))
- headers["x-user-full-name"] = str(user.get('full_name', ''))
- headers["x-tenant-id"] = tenant_id or str(user.get('tenant_id', ''))
- # Add subscription context headers
- if user.get('subscription_tier'):
- headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
- logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
- if user.get('subscription_status'):
- headers["x-subscription-status"] = str(user.get('subscription_status', ''))
- logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")

# Debug logging
- logger.info(f"Forwarding request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, tenant_id={tenant_id}, subscription_tier={user.get('subscription_tier', 'not_set')}")
+ user_context = getattr(request.state, 'user', None)
+ if user_context:
+ logger.info(f"Forwarding request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, tenant_id={tenant_id}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
else:
- # Debug logging when no user context available
logger.warning(f"No user context available when forwarding request to {url}. request.state.user: {getattr(request.state, 'user', 'NOT_SET')}")

# Get request body if present
@@ -782,9 +765,10 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
logger.info(f"Forwarding multipart request with files={list(files.keys()) if files else None}, data={list(data.keys()) if data else None}")

- # Remove content-type from headers - httpx will set it with new boundary
+ # For multipart requests, we need to get fresh headers since httpx will set content-type
- headers.pop("content-type", None)
+ # Get all headers again to ensure we have the complete set
- headers.pop("content-length", None)
+ headers = header_manager.get_all_headers_for_proxy(request)
+ # httpx will automatically set content-type for multipart, so we don't need to remove it
else:
# For other content types, use body as before
body = await request.body()
@@ -13,6 +13,7 @@ from typing import Dict, Any
import json

from app.core.config import settings
+ from app.core.header_manager import header_manager
from app.core.service_discovery import ServiceDiscovery
from shared.monitoring.metrics import MetricsCollector
@@ -136,64 +137,28 @@ class UserProxy:
return AUTH_SERVICE_URL

def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
- """Prepare headers for forwarding (remove hop-by-hop headers)"""
+ """Prepare headers for forwarding using unified HeaderManager"""
- # Remove hop-by-hop headers
+ # Use unified HeaderManager to get all headers
- hop_by_hop_headers = {
+ if request:
- 'connection', 'keep-alive', 'proxy-authenticate',
+ all_headers = header_manager.get_all_headers_for_proxy(request)
- 'proxy-authorization', 'te', 'trailers', 'upgrade'
+ else:
- }
+ # Fallback: convert headers to dict manually
- # Convert headers to dict if it's a Headers object
- # This ensures we get ALL headers including those added by middleware
- if hasattr(headers, '_list'):
- # Get headers from the _list where middleware adds them
- all_headers_list = headers.__dict__.get('_list', [])
- # Convert to dict for easier processing
all_headers = {}
- for k, v in all_headers_list:
+ if hasattr(headers, '_list'):
+ for k, v in headers.__dict__.get('_list', []):
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
- # Check if headers are missing and try to get them from request.state
- if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
- # Add missing headers from request.state
- if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
- all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
- if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
- all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
- if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
- all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
- # Add is_demo flag if this is a demo session
- if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
- all_headers['x-is-demo'] = 'true'
- # Filter out hop-by-hop headers
- filtered_headers = {
- k: v for k, v in all_headers.items()
- if k.lower() not in hop_by_hop_headers
- }
elif hasattr(headers, 'raw'):
- # FastAPI/Starlette Headers object - use raw to get all headers
+ for k, v in headers.raw:
- filtered_headers = {
+ key = k.decode() if isinstance(k, bytes) else k
- k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
+ value = v.decode() if isinstance(v, bytes) else v
- for k, v in headers.raw
+ all_headers[key] = value
- if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
- }
else:
- # Already a dict
+ # Headers is already a dict
- filtered_headers = {
+ all_headers = dict(headers)
- k: v for k, v in headers.items()
- if k.lower() not in hop_by_hop_headers
- }
- # Add gateway identifier
+ return all_headers
- filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
- filtered_headers['X-Gateway-Version'] = '1.0.0'
- return filtered_headers

def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
"""Prepare response headers"""
82
scripts/cleanup-docker.sh
Executable file
82
scripts/cleanup-docker.sh
Executable file
@@ -0,0 +1,82 @@
#!/bin/bash
# Docker Cleanup Script for Local Kubernetes Development
# This script helps prevent disk space issues by cleaning up unused Docker resources

set -e

echo "🧹 Docker Cleanup Script for Bakery-IA Local Development"
echo "========================================================="
echo ""

# Check if we should run automatically or ask for confirmation
AUTO_MODE=${1:-""}

# Show current disk usage
echo "📊 Current Docker Disk Usage:"
docker system df
echo ""

# Check Kind node disk usage if cluster is running
if docker ps | grep -q "bakery-ia-local-control-plane"; then
    echo "📊 Kind Node Disk Usage:"
    docker exec bakery-ia-local-control-plane df -h / /var | grep -E "(Filesystem|overlay|/dev/vdb1)"
    echo ""
fi

# Calculate reclaimable space
RECLAIMABLE=$(docker system df | grep "Images" | awk '{print $4}')
echo "💾 Estimated reclaimable space: $RECLAIMABLE"
echo ""

# Ask for confirmation unless in auto mode
if [ "$AUTO_MODE" != "--auto" ]; then
    read -p "Do you want to proceed with cleanup? (y/n) " -n 1 -r
    echo ""
    if [[ ! $REPLY =~ ^[Yy]$ ]]; then
        echo "❌ Cleanup cancelled"
        exit 0
    fi
fi

echo "🚀 Starting cleanup..."
echo ""

# Remove unused images (keep images from last 24 hours)
echo "1️⃣ Removing unused Docker images..."
docker image prune -af --filter "until=24h" || true
echo ""

# Remove unused volumes
echo "2️⃣ Removing unused Docker volumes..."
docker volume prune -f || true
echo ""

# Remove build cache
echo "3️⃣ Removing build cache..."
docker builder prune -af || true
echo ""

# Show results
echo "✅ Cleanup completed!"
echo ""
echo "📊 Final Docker Disk Usage:"
docker system df
echo ""

# Check Kind node disk usage if cluster is running
if docker ps | grep -q "bakery-ia-local-control-plane"; then
    echo "📊 Kind Node Disk Usage After Cleanup:"
    docker exec bakery-ia-local-control-plane df -h / /var | grep -E "(Filesystem|overlay|/dev/vdb1)"
    echo ""

    # Warn if still above 80%
    USAGE=$(docker exec bakery-ia-local-control-plane df -h /var | tail -1 | awk '{print $5}' | sed 's/%//')
    if [ "$USAGE" -gt 80 ]; then
        echo "⚠️ Warning: Disk usage is still above 80%. Consider:"
        echo "   - Deleting and recreating the Kind cluster"
        echo "   - Increasing Docker's disk allocation"
        echo "   - Running: docker system prune -a --volumes -f"
    fi
fi

echo "🎉 All done!"
@@ -1,82 +0,0 @@
# services/forecasting/app/clients/inventory_client.py
"""
Simple client for inventory service integration
Used when product names are not available locally
"""

import aiohttp
import structlog
from typing import Optional, Dict, Any
import os

logger = structlog.get_logger()


class InventoryServiceClient:
    """Simple client for inventory service interactions"""

    def __init__(self, base_url: str = None):
        self.base_url = base_url or os.getenv("INVENTORY_SERVICE_URL", "http://inventory-service:8000")
        self.timeout = aiohttp.ClientTimeout(total=5)  # 5 second timeout

    async def get_product_name(self, tenant_id: str, inventory_product_id: str) -> Optional[str]:
        """
        Get product name from inventory service
        Returns None if service is unavailable or product not found
        """
        try:
            async with aiohttp.ClientSession(timeout=self.timeout) as session:
                url = f"{self.base_url}/api/v1/products/{inventory_product_id}"
                headers = {"X-Tenant-ID": tenant_id}

                async with session.get(url, headers=headers) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("name", f"Product-{inventory_product_id}")
                    else:
                        logger.debug("Product not found in inventory service",
                                     inventory_product_id=inventory_product_id,
                                     status=response.status)
                        return None

        except Exception as e:
            logger.debug("Failed to get product name from inventory service",
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            return None

    async def get_multiple_product_names(self, tenant_id: str, product_ids: list) -> Dict[str, str]:
        """
        Get multiple product names efficiently
        Returns a mapping of product_id -> product_name
        """
        try:
            async with aiohttp.ClientSession(timeout=self.timeout) as session:
                url = f"{self.base_url}/api/v1/products/batch"
                headers = {"X-Tenant-ID": tenant_id}
                payload = {"product_ids": product_ids}

                async with session.post(url, json=payload, headers=headers) as response:
                    if response.status == 200:
                        data = await response.json()
                        return {item["id"]: item["name"] for item in data.get("products", [])}
                    else:
                        logger.debug("Batch product lookup failed",
                                     product_count=len(product_ids),
                                     status=response.status)
                        return {}

        except Exception as e:
            logger.debug("Failed to get product names from inventory service",
                         product_count=len(product_ids),
                         error=str(e))
            return {}


# Global client instance
_inventory_client = None


def get_inventory_client() -> InventoryServiceClient:
    """Get the global inventory client instance"""
    global _inventory_client
    if _inventory_client is None:
        _inventory_client = InventoryServiceClient()
    return _inventory_client
@@ -12,7 +12,6 @@ from datetime import datetime, timedelta
import structlog

from shared.messaging import UnifiedEventPublisher
- from app.clients.inventory_client import get_inventory_client

logger = structlog.get_logger()
@@ -30,11 +30,11 @@ async def trigger_inventory_alerts(
- Expiring ingredients
- Overstock situations

- Security: Protected by X-Internal-Service header check.
+ Security: Protected by x-internal-service header check.
"""
try:
# Verify internal service header
- if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
+ if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
raise HTTPException(
status_code=403,
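The same guard is repeated across the internal trigger endpoints touched by this commit; extracted as a standalone helper for illustration (the function name is hypothetical, the header name and allowed values come from the diff):

```python
from fastapi import HTTPException, Request

ALLOWED_INTERNAL_CALLERS = {"demo-session", "internal"}

def verify_internal_caller(request: Request) -> None:
    """Reject calls that do not carry a recognized x-internal-service header."""
    if request is None or request.headers.get("x-internal-service") not in ALLOWED_INTERNAL_CALLERS:
        raise HTTPException(status_code=403, detail="Internal endpoint")
```

Starlette header lookups are case-insensitive, so switching the lookup string from `X-Internal-Service` to `x-internal-service` is a consistency change rather than a behavioural one.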
@@ -350,7 +350,7 @@ async def generate_safety_stock_insights_internal(
This endpoint is called by the demo-session service after cloning data.
It uses the same ML logic as the public endpoint but with optimized defaults.

- Security: Protected by X-Internal-Service header check.
+ Security: Protected by x-internal-service header check.

Args:
tenant_id: The tenant UUID
@@ -365,7 +365,7 @@ async def generate_safety_stock_insights_internal(
}
"""
# Verify internal service header
- if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
+ if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
raise HTTPException(
status_code=403,
@@ -29,7 +29,7 @@ async def trigger_delivery_tracking(
This endpoint is called by the demo session cloning process after POs are seeded
to generate realistic delivery alerts (arriving soon, overdue, etc.).

- Security: Protected by X-Internal-Service header check.
+ Security: Protected by x-internal-service header check.

Args:
tenant_id: Tenant UUID to check deliveries for
@@ -49,7 +49,7 @@ async def trigger_delivery_tracking(
"""
try:
# Verify internal service header
- if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
+ if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
raise HTTPException(
status_code=403,
@@ -566,7 +566,7 @@ async def generate_price_insights_internal(
This endpoint is called by the demo-session service after cloning data.
It uses the same ML logic as the public endpoint but with optimized defaults.

- Security: Protected by X-Internal-Service header check.
+ Security: Protected by x-internal-service header check.

Args:
tenant_id: The tenant UUID
@@ -581,7 +581,7 @@ async def generate_price_insights_internal(
}
"""
# Verify internal service header
- if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
+ if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
raise HTTPException(
status_code=403,
@@ -1,42 +1,45 @@
"""
FastAPI Dependencies for Procurement Service
+ Uses shared authentication infrastructure with UUID validation
"""

- from fastapi import Header, HTTPException, status
+ from fastapi import Depends, HTTPException, status
from uuid import UUID
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession

from .database import get_db
+ from shared.auth.decorators import get_current_tenant_id_dep


async def get_current_tenant_id(
- x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID")
+ tenant_id: Optional[str] = Depends(get_current_tenant_id_dep)
) -> UUID:
"""
- Extract and validate tenant ID from request header.
+ Extract and validate tenant ID from request using shared infrastructure.
+ Adds UUID validation to ensure tenant ID format is correct.

Args:
- x_tenant_id: Tenant ID from X-Tenant-ID header
+ tenant_id: Tenant ID from shared dependency

Returns:
UUID: Validated tenant ID

Raises:
- HTTPException: If tenant ID is missing or invalid
+ HTTPException: If tenant ID is missing or invalid UUID format
"""
- if not x_tenant_id:
+ if not tenant_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
- detail="X-Tenant-ID header is required"
+ detail="x-tenant-id header is required"
)

try:
- return UUID(x_tenant_id)
+ return UUID(tenant_id)
except (ValueError, AttributeError):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
- detail=f"Invalid tenant ID format: {x_tenant_id}"
+ detail=f"Invalid tenant ID format: {tenant_id}"
)
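A short usage sketch of the reworked dependency in a procurement route; the route path and response shape are illustrative only:

```python
from uuid import UUID
from fastapi import APIRouter, Depends

router = APIRouter()

@router.get("/purchase-orders")
async def list_purchase_orders(tenant_id: UUID = Depends(get_current_tenant_id)):
    # tenant_id has already been validated as a UUID by the dependency above
    return {"tenant_id": str(tenant_id), "items": []}
```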
@@ -31,11 +31,11 @@ async def trigger_production_alerts(
- Equipment maintenance alerts
- Batch start delays

- Security: Protected by X-Internal-Service header check.
+ Security: Protected by x-internal-service header check.
"""
try:
# Verify internal service header
- if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
+ if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
raise HTTPException(
status_code=403,
@@ -331,7 +331,7 @@ async def generate_yield_insights_internal(
This endpoint is called by the demo-session service after cloning data.
It uses the same ML logic as the public endpoint but with optimized defaults.

- Security: Protected by X-Internal-Service header check.
+ Security: Protected by x-internal-service header check.

Args:
tenant_id: The tenant UUID
@@ -346,7 +346,7 @@ async def generate_yield_insights_internal(
}
"""
# Verify internal service header
- if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
+ if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
raise HTTPException(
status_code=403,
@@ -204,7 +204,7 @@ class TenantMemberRepository(TenantBaseRepository):
f"{auth_service_url}/api/v1/auth/users/batch",
json={"user_ids": user_ids},
timeout=10.0,
- headers={"X-Internal-Service": "tenant-service"}
+ headers={"x-internal-service": "tenant-service"}
)

if response.status_code == 200:
@@ -226,7 +226,7 @@ class TenantMemberRepository(TenantBaseRepository):
response = await client.get(
f"{auth_service_url}/api/v1/auth/users/{user_id}",
timeout=5.0,
- headers={"X-Internal-Service": "tenant-service"}
+ headers={"x-internal-service": "tenant-service"}
)
if response.status_code == 200:
user_data = response.json()
@@ -243,7 +243,7 @@ class TenantMemberRepository(TenantBaseRepository):
response = await client.get(
f"{auth_service_url}/api/v1/auth/users/{user_id}",
timeout=5.0,
- headers={"X-Internal-Service": "tenant-service"}
+ headers={"x-internal-service": "tenant-service"}
)
if response.status_code == 200:
user_data = response.json()
@@ -216,17 +216,24 @@ class HybridProphetXGBoost:
Get Prophet predictions for given dataframe.

Args:
- prophet_result: Prophet model result from training
+ prophet_result: Prophet model result from training (contains model_path)
df: DataFrame with 'ds' column

Returns:
Array of predictions
"""
- # Get the Prophet model from result
+ # Get the model path from result instead of expecting the model object directly
- prophet_model = prophet_result.get('model')
+ model_path = prophet_result.get('model_path')

- if prophet_model is None:
+ if model_path is None:
- raise ValueError("Prophet model not found in result")
+ raise ValueError("Prophet model path not found in result")

+ # Load the actual Prophet model from the stored path
+ try:
+ import joblib
+ prophet_model = joblib.load(model_path)
+ except Exception as e:
+ raise ValueError(f"Failed to load Prophet model from path {model_path}: {str(e)}")

# Prepare dataframe for prediction
pred_df = df[['ds']].copy()
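For this to work, the training side is expected to persist the fitted Prophet model and record its location under `prophet_result['model_path']`. A hedged sketch of that contract (the helper name and path handling are illustrative; only `joblib.load(model_path)` mirrors the hunk above):

```python
import joblib
from prophet import Prophet

def save_prophet_model(model: Prophet, path: str) -> dict:
    """Persist a fitted Prophet model and return the result dict the hybrid trainer expects."""
    joblib.dump(model, path)
    return {"model_path": path}

# The hybrid model later reloads it exactly as in the hunk above:
# prophet_model = joblib.load(result["model_path"])
```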
@@ -273,7 +280,8 @@ class HybridProphetXGBoost:
'reg_lambda': 1.0,  # L2 regularization
'objective': 'reg:squarederror',
'random_state': 42,
- 'n_jobs': -1
+ 'n_jobs': -1,
+ 'early_stopping_rounds': 10
}

# Initialize model
@@ -285,7 +293,6 @@ class HybridProphetXGBoost:
model.fit,
X_train, y_train,
eval_set=[(X_val, y_val)],
- early_stopping_rounds=10,
verbose=False
)
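These two hunks track an XGBoost API change: recent releases of the scikit-learn wrapper take `early_stopping_rounds` as a constructor argument instead of a `fit()` keyword. A sketch with the same parameter values (the data here is synthetic):

```python
import numpy as np
from xgboost import XGBRegressor

rng = np.random.default_rng(42)
X_train, y_train = rng.random((200, 5)), rng.random(200)
X_val, y_val = rng.random((50, 5)), rng.random(50)

model = XGBRegressor(
    objective="reg:squarederror",
    random_state=42,
    n_jobs=-1,
    early_stopping_rounds=10,  # moved here from fit() in the hunks above
)
model.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)
```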
@@ -303,109 +310,86 @@ class HybridProphetXGBoost:
train_prophet_pred: np.ndarray,
val_prophet_pred: np.ndarray,
prophet_result: Dict[str, Any]
- ) -> Dict[str, float]:
+ ) -> Dict[str, Any]:
"""
- Evaluate hybrid model vs Prophet-only on validation set.
+ Evaluate the overall performance of the hybrid model using threading for metrics.

- Args:
- train_df: Training data
- val_df: Validation data
- train_prophet_pred: Prophet predictions on training set
- val_prophet_pred: Prophet predictions on validation set
- prophet_result: Prophet training result
- Returns:
- Dictionary of metrics
"""
- # Get actual values
+ import asyncio
- train_actual = train_df['y'].values
- val_actual = val_df['y'].values

- # Get XGBoost predictions on residuals
+ # Get XGBoost predictions on training and validation
X_train = train_df[self.feature_columns].values
X_val = val_df[self.feature_columns].values

- # ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
- import asyncio
train_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_train)
val_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_val)

- # Hybrid predictions = Prophet + XGBoost residual correction
+ # Hybrid prediction = Prophet prediction + XGBoost residual prediction
train_hybrid_pred = train_prophet_pred + train_xgb_pred
val_hybrid_pred = val_prophet_pred + val_xgb_pred

- # Calculate metrics for Prophet-only
+ actual_train = train_df['y'].values
- prophet_train_mae = mean_absolute_error(train_actual, train_prophet_pred)
+ actual_val = val_df['y'].values
- prophet_val_mae = mean_absolute_error(val_actual, val_prophet_pred)
- prophet_train_mape = mean_absolute_percentage_error(train_actual, train_prophet_pred) * 100
- prophet_val_mape = mean_absolute_percentage_error(val_actual, val_prophet_pred) * 100

- # Calculate metrics for Hybrid
+ # Basic RMSE calculation
- hybrid_train_mae = mean_absolute_error(train_actual, train_hybrid_pred)
+ train_rmse = float(np.sqrt(np.mean((actual_train - train_hybrid_pred)**2)))
- hybrid_val_mae = mean_absolute_error(val_actual, val_hybrid_pred)
+ val_rmse = float(np.sqrt(np.mean((actual_val - val_hybrid_pred)**2)))
- hybrid_train_mape = mean_absolute_percentage_error(train_actual, train_hybrid_pred) * 100
- hybrid_val_mape = mean_absolute_percentage_error(val_actual, val_hybrid_pred) * 100
+ # MAE
+ train_mae = float(np.mean(np.abs(actual_train - train_hybrid_pred)))
+ val_mae = float(np.mean(np.abs(actual_val - val_hybrid_pred)))

+ # MAPE (with safety for zero sales)
+ train_mape = float(np.mean(np.abs((actual_train - train_hybrid_pred) / np.maximum(actual_train, 1))))
+ val_mape = float(np.mean(np.abs((actual_val - val_hybrid_pred) / np.maximum(actual_val, 1))))

# Calculate improvement
- mae_improvement = ((prophet_val_mae - hybrid_val_mae) / prophet_val_mae) * 100
+ prophet_metrics = prophet_result.get("metrics", {})
- mape_improvement = ((prophet_val_mape - hybrid_val_mape) / prophet_val_mape) * 100
+ prophet_val_mae = prophet_metrics.get("val_mae", val_mae)  # Fallback to hybrid if missing
+ prophet_val_mape = prophet_metrics.get("val_mape", val_mape)

+ improvement_pct = 0.0
+ if prophet_val_mape > 0:
+ improvement_pct = ((prophet_val_mape - val_mape) / prophet_val_mape) * 100

metrics = {
- 'prophet_train_mae': float(prophet_train_mae),
+ "train_rmse": train_rmse,
- 'prophet_val_mae': float(prophet_val_mae),
+ "val_rmse": val_rmse,
- 'prophet_train_mape': float(prophet_train_mape),
+ "train_mae": train_mae,
- 'prophet_val_mape': float(prophet_val_mape),
+ "val_mae": val_mae,
- 'hybrid_train_mae': float(hybrid_train_mae),
+ "train_mape": train_mape,
- 'hybrid_val_mae': float(hybrid_val_mae),
+ "val_mape": val_mape,
- 'hybrid_train_mape': float(hybrid_train_mape),
+ "prophet_val_mape": prophet_val_mape,
- 'hybrid_val_mape': float(hybrid_val_mape),
+ "hybrid_val_mape": val_mape,
- 'mae_improvement_pct': float(mae_improvement),
+ "improvement_percentage": float(improvement_pct),
- 'mape_improvement_pct': float(mape_improvement),
+ "prophet_metrics": prophet_metrics
- 'improvement_percentage': float(mape_improvement)  # Primary metric
}

+ logger.info(
+ "Hybrid model evaluation complete",
+ val_rmse=val_rmse,
+ val_mae=val_mae,
+ val_mape=val_mape,
+ improvement=improvement_pct
+ )

return metrics

def _package_hybrid_model(
self,
prophet_result: Dict[str, Any],
- metrics: Dict[str, float],
+ metrics: Dict[str, Any],
tenant_id: str,
inventory_product_id: str
) -> Dict[str, Any]:
"""
Package hybrid model for storage.

- Args:
- prophet_result: Prophet model result
- metrics: Hybrid model metrics
- tenant_id: Tenant ID
- inventory_product_id: Product ID
- Returns:
- Model package dictionary
"""
return {
'model_type': 'hybrid_prophet_xgboost',
- 'prophet_model': prophet_result.get('model'),
+ 'prophet_model_path': prophet_result.get('model_path'),
'xgboost_model': self.xgb_model,
'feature_columns': self.feature_columns,
- 'prophet_metrics': {
+ 'metrics': metrics,
- 'train_mae': metrics['prophet_train_mae'],
- 'val_mae': metrics['prophet_val_mae'],
- 'train_mape': metrics['prophet_train_mape'],
- 'val_mape': metrics['prophet_val_mape']
- },
- 'hybrid_metrics': {
- 'train_mae': metrics['hybrid_train_mae'],
- 'val_mae': metrics['hybrid_val_mae'],
- 'train_mape': metrics['hybrid_train_mape'],
- 'val_mape': metrics['hybrid_val_mape']
- },
- 'improvement_metrics': {
- 'mae_improvement_pct': metrics['mae_improvement_pct'],
- 'mape_improvement_pct': metrics['mape_improvement_pct']
- },
'tenant_id': tenant_id,
'inventory_product_id': inventory_product_id,
'trained_at': datetime.now(timezone.utc).isoformat()
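The `improvement_percentage` reported by the new metrics block is the hybrid model's relative MAPE gain over Prophet alone; with a Prophet validation MAPE of 0.20 and a hybrid validation MAPE of 0.15 it comes out to 25%. The formula in isolation:

```python
def improvement_percentage(prophet_val_mape: float, hybrid_val_mape: float) -> float:
    """Relative MAPE improvement of the hybrid model over Prophet alone, in percent."""
    if prophet_val_mape <= 0:
        return 0.0
    return (prophet_val_mape - hybrid_val_mape) / prophet_val_mape * 100

print(improvement_percentage(0.20, 0.15))  # 25.0
```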
@@ -426,8 +410,18 @@ class HybridProphetXGBoost:
Returns:
DataFrame with predictions
"""
- # Step 1: Get Prophet predictions
+ # Step 1: Get Prophet model from path and make predictions
- prophet_model = model_data['prophet_model']
+ prophet_model_path = model_data.get('prophet_model_path')
+ if prophet_model_path is None:
+ raise ValueError("Prophet model path not found in model data")

+ # Load the Prophet model from the stored path
+ try:
+ import joblib
+ prophet_model = joblib.load(prophet_model_path)
+ except Exception as e:
+ raise ValueError(f"Failed to load Prophet model from path {prophet_model_path}: {str(e)}")

# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
import asyncio
prophet_forecast = await asyncio.to_thread(prophet_model.predict, future_df)
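Both the evaluation and prediction paths wrap the synchronous `predict()` calls in `asyncio.to_thread` so the event loop stays responsive while Prophet and XGBoost run; the pattern in isolation (the model and dataframe arguments are placeholders):

```python
import asyncio

async def predict_async(model, df):
    # model.predict is CPU-bound and blocking; run it in the default thread pool
    return await asyncio.to_thread(model.predict, df)
```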
@@ -43,86 +43,79 @@ class POIFeatureIntegrator:
force_refresh: bool = False
) -> Optional[Dict[str, Any]]:
"""
- Fetch POI features for tenant location.
+ Fetch POI features for tenant location (optimized for training).

- First checks if POI context exists, if not, triggers detection.
+ First checks if POI context exists. If not, returns None without triggering detection.
+ POI detection should be triggered during tenant registration, not during training.

Args:
tenant_id: Tenant UUID
latitude: Bakery latitude
longitude: Bakery longitude
- force_refresh: Force re-detection
+ force_refresh: Force re-detection (only use if POI context already exists)

Returns:
- Dictionary with POI features or None if detection fails
+ Dictionary with POI features or None if not available
"""
try:
# Try to get existing POI context first
- if not force_refresh:
existing_context = await self.external_client.get_poi_context(tenant_id)

if existing_context:
poi_context = existing_context.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})

- # Check if stale
+ # Check if stale and force_refresh is requested
is_stale = existing_context.get("is_stale", False)
- if not is_stale:
+ if not is_stale or not force_refresh:
logger.info(
"Using existing POI context",
- tenant_id=tenant_id
+ tenant_id=tenant_id,
+ is_stale=is_stale,
+ feature_count=len(ml_features)
)
return ml_features
else:
logger.info(
- "POI context is stale, refreshing",
+ "POI context is stale and force_refresh=True, refreshing",
tenant_id=tenant_id
)
- force_refresh = True
+ # Only refresh if explicitly requested and context exists
- else:
- logger.info(
- "No existing POI context, will detect",
- tenant_id=tenant_id
- )
- # Detect or refresh POIs
- logger.info(
- "Detecting POIs for tenant",
- tenant_id=tenant_id,
- location=(latitude, longitude)
- )
detection_result = await self.external_client.detect_poi_for_tenant(
tenant_id=tenant_id,
latitude=latitude,
longitude=longitude,
- force_refresh=force_refresh
+ force_refresh=True
)

if detection_result:
poi_context = detection_result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})

logger.info(
- "POI detection completed",
+ "POI refresh completed",
tenant_id=tenant_id,
- total_pois=poi_context.get("total_pois_detected", 0),
feature_count=len(ml_features)
)

return ml_features
else:
- logger.error(
+ logger.warning(
- "POI detection failed",
+ "POI refresh failed, returning existing features",
+ tenant_id=tenant_id
+ )
+ return ml_features
+ else:
+ logger.info(
+ "No existing POI context found - POI detection should be triggered during tenant registration",
tenant_id=tenant_id
)
return None

except Exception as e:
- logger.error(
+ logger.warning(
- "Unexpected error fetching POI features",
+ "Error fetching POI features - returning None",
tenant_id=tenant_id,
- error=str(e),
+ error=str(e)
- exc_info=True
)
return None
@@ -29,16 +29,15 @@ class DataClient:
         self.sales_client = get_sales_client(settings, "training")
         self.external_client = get_external_client(settings, "training")

+        # ExternalServiceClient always has get_stored_traffic_data_for_training method
+        self.supports_stored_traffic_data = True

         # Configure timeouts for HTTP clients
         self._configure_timeouts()

         # Initialize circuit breakers for external services
         self._init_circuit_breakers()

-        # Check if the new method is available for stored traffic data
-        if hasattr(self.external_client, 'get_stored_traffic_data_for_training'):
-            self.supports_stored_traffic_data = True

     def _configure_timeouts(self):
         """Configure appropriate timeouts for HTTP clients"""
         timeout = httpx.Timeout(
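The net effect of this hunk is to replace a runtime capability probe with a fixed flag, on the assumption that the shared ExternalServiceClient always exposes get_stored_traffic_data_for_training. Roughly (a sketch, not the actual constructor):

```python
# Before: discover the capability at construction time.
self.supports_stored_traffic_data = hasattr(
    self.external_client, "get_stored_traffic_data_for_training"
)

# After: the method is guaranteed by the client contract, so hard-code the flag.
self.supports_stored_traffic_data = True
```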
@@ -49,14 +48,12 @@ class DataClient:
         )

         # Apply timeout to clients if they have httpx clients
+        # Note: BaseServiceClient manages its own HTTP client internally
         if hasattr(self.sales_client, 'client') and isinstance(self.sales_client.client, httpx.AsyncClient):
             self.sales_client.client.timeout = timeout

         if hasattr(self.external_client, 'client') and isinstance(self.external_client.client, httpx.AsyncClient):
             self.external_client.client.timeout = timeout
-        else:
-            self.supports_stored_traffic_data = False
-            logger.warning("Stored traffic data method not available in external client")

     def _init_circuit_breakers(self):
         """Initialize circuit breakers for external service calls"""
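The httpx.Timeout arguments themselves fall between the two hunks and are not shown here; the sketch below uses illustrative numbers only, to show the shape of such a configuration and how it is applied to a client that exposes an httpx.AsyncClient:

```python
import httpx

# Illustrative values - the project's real numbers are outside the hunks above.
timeout = httpx.Timeout(
    connect=5.0,  # establishing the TCP connection
    read=30.0,    # waiting for response data
    write=10.0,   # sending the request body
    pool=5.0,     # waiting for a free pooled connection
)

client = httpx.AsyncClient(timeout=timeout)
# Or, as in the hunk, retrofit an existing client:
client.timeout = timeout
```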
@@ -404,22 +404,32 @@ class TrainingDataOrchestrator:
         tenant_id: str
     ) -> Dict[str, Any]:
         """
-        Collect POI features for bakery location.
+        Collect POI features for bakery location (non-blocking).

         POI features are static (location-based, not time-varying).
+        This method is non-blocking with a short timeout to prevent training delays.
+        If POI detection hasn't been run yet, training continues without POI features.

+        Returns:
+            Dictionary with POI features or empty dict if unavailable
         """
         try:
             logger.info(
-                "Collecting POI features",
+                "Collecting POI features (non-blocking)",
                 tenant_id=tenant_id,
                 location=(lat, lon)
             )

-            poi_features = await self.poi_feature_integrator.fetch_poi_features(
+            # Set a short timeout to prevent blocking training
+            # POI detection should have been triggered during tenant registration
+            poi_features = await asyncio.wait_for(
+                self.poi_feature_integrator.fetch_poi_features(
                     tenant_id=tenant_id,
                     latitude=lat,
                     longitude=lon,
                     force_refresh=False
+                ),
+                timeout=15.0  # 15 second timeout - POI should be cached from registration
             )

             if poi_features:
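The asyncio.wait_for wrapper is what makes the collection non-blocking: if the POI lookup has not finished within 15 seconds, the awaitable is cancelled and asyncio.TimeoutError is raised, which the next hunk turns into an empty feature dict. A self-contained illustration of the same pattern (smaller numbers so it runs quickly; slow_poi_lookup is a stand-in, not the real integrator):

```python
import asyncio


async def slow_poi_lookup() -> dict:
    # Stand-in for poi_feature_integrator.fetch_poi_features(); deliberately slow.
    await asyncio.sleep(2)
    return {"restaurant_count": 12}


async def collect_poi_features_non_blocking() -> dict:
    try:
        # Same shape as the hunk above: cap the wait so the caller is never blocked.
        return await asyncio.wait_for(slow_poi_lookup(), timeout=0.5)
    except asyncio.TimeoutError:
        # Not fatal - the caller simply continues without POI features.
        return {}


if __name__ == "__main__":
    print(asyncio.run(collect_poi_features_non_blocking()))  # prints {}
```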
@@ -430,18 +440,24 @@ class TrainingDataOrchestrator:
                 )
             else:
                 logger.warning(
-                    "No POI features collected (service may be unavailable)",
+                    "No POI features collected (service may be unavailable or not yet detected)",
                     tenant_id=tenant_id
                 )

             return poi_features or {}

+        except asyncio.TimeoutError:
+            logger.warning(
+                "POI collection timeout (15s) - continuing training without POI features. "
+                "POI detection should be triggered during tenant registration for best results.",
+                tenant_id=tenant_id
+            )
+            return {}
         except Exception as e:
-            logger.error(
+            logger.warning(
-                "Failed to collect POI features, continuing without them",
+                "Failed to collect POI features (non-blocking) - continuing training without them",
                 tenant_id=tenant_id,
-                error=str(e),
+                error=str(e)
-                exc_info=True
             )
             return {}
@@ -71,7 +71,7 @@ class ServiceAuthenticator:
         }

         if tenant_id:
-            headers["X-Tenant-ID"] = str(tenant_id)
+            headers["x-tenant-id"] = str(tenant_id)

         return headers
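HTTP header field names are case-insensitive on the wire, and HTTP/2 intermediaries transmit them in lowercase, so switching to the lowercase spelling is presumably about matching how the receiving services (or a proxy in between) look the header up rather than changing behaviour in httpx itself, whose Headers type already ignores case:

```python
import httpx

headers = httpx.Headers({"x-tenant-id": "tenant-123"})

# Lookups ignore case either way; the lowercase spelling just keeps the
# producing and consuming sides visually consistent.
assert headers["X-Tenant-ID"] == "tenant-123"
assert headers["x-tenant-id"] == "tenant-123"
```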
@@ -351,7 +351,7 @@ class ForecastServiceClient(BaseServiceClient):
         """
         Trigger demand forecasting insights for a tenant (internal service use only).

-        This method calls the internal endpoint which is protected by X-Internal-Service header.
+        This method calls the internal endpoint which is protected by x-internal-service header.
         Used by demo-session service after cloning to generate AI insights from seeded data.

         Args:
@@ -366,7 +366,7 @@ class ForecastServiceClient(BaseServiceClient):
             endpoint=f"forecasting/internal/ml/generate-demand-insights",
             tenant_id=tenant_id,
             data={"tenant_id": tenant_id},
-            headers={"X-Internal-Service": "demo-session"}
+            headers={"x-internal-service": "demo-session"}
         )

         if result:
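This trigger call, and the similar ones in the hunks that follow, all share one shape: an internal POST carrying the tenant id in the body plus the lowercase x-internal-service marker header. A hedged approximation using a plain httpx call (the base URL and the extra tenant header are placeholders; the real request is assembled by a shared helper on BaseServiceClient whose internals are not shown in this diff):

```python
from typing import Optional

import httpx


async def trigger_demand_insights(tenant_id: str) -> Optional[dict]:
    # Placeholder base URL - service discovery and auth live elsewhere in the codebase.
    async with httpx.AsyncClient(base_url="http://forecast-service:8000") as client:
        response = await client.post(
            "/forecasting/internal/ml/generate-demand-insights",
            json={"tenant_id": tenant_id},
            headers={
                "x-internal-service": "demo-session",  # lowercase, as in the diff
                "x-tenant-id": tenant_id,
            },
        )
        return response.json() if response.status_code == 200 else None
```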
@@ -766,7 +766,7 @@ class InventoryServiceClient(BaseServiceClient):
         """
         Trigger inventory alerts for a tenant (internal service use only).

-        This method calls the internal endpoint which is protected by X-Internal-Service header.
+        This method calls the internal endpoint which is protected by x-internal-service header.
         The endpoint should trigger alerts specifically for the given tenant.

         Args:
@@ -783,7 +783,7 @@ class InventoryServiceClient(BaseServiceClient):
             endpoint="inventory/internal/alerts/trigger",
             tenant_id=tenant_id,
             data={},
-            headers={"X-Internal-Service": "demo-session"}
+            headers={"x-internal-service": "demo-session"}
         )

         if result:
@@ -819,7 +819,7 @@ class InventoryServiceClient(BaseServiceClient):
         """
         Trigger safety stock optimization insights for a tenant (internal service use only).

-        This method calls the internal endpoint which is protected by X-Internal-Service header.
+        This method calls the internal endpoint which is protected by x-internal-service header.

         Args:
             tenant_id: Tenant ID to trigger insights for
@@ -833,7 +833,7 @@ class InventoryServiceClient(BaseServiceClient):
             endpoint="inventory/internal/ml/generate-safety-stock-insights",
             tenant_id=tenant_id,
             data={"tenant_id": tenant_id},
-            headers={"X-Internal-Service": "demo-session"}
+            headers={"x-internal-service": "demo-session"}
         )

         if result:
@@ -580,7 +580,7 @@ class ProcurementServiceClient(BaseServiceClient):
         """
         Trigger delivery tracking for a tenant (internal service use only).

-        This method calls the internal endpoint which is protected by X-Internal-Service header.
+        This method calls the internal endpoint which is protected by x-internal-service header.

         Args:
             tenant_id: Tenant ID to trigger delivery tracking for
@@ -596,7 +596,7 @@ class ProcurementServiceClient(BaseServiceClient):
             endpoint="procurement/internal/delivery-tracking/trigger",
             tenant_id=tenant_id,
             data={},
-            headers={"X-Internal-Service": "demo-session"}
+            headers={"x-internal-service": "demo-session"}
         )

         if result:
@@ -632,7 +632,7 @@ class ProcurementServiceClient(BaseServiceClient):
         """
         Trigger price forecasting insights for a tenant (internal service use only).

-        This method calls the internal endpoint which is protected by X-Internal-Service header.
+        This method calls the internal endpoint which is protected by x-internal-service header.

         Args:
             tenant_id: Tenant ID to trigger insights for
@@ -646,7 +646,7 @@ class ProcurementServiceClient(BaseServiceClient):
             endpoint="procurement/internal/ml/generate-price-insights",
             tenant_id=tenant_id,
             data={"tenant_id": tenant_id},
-            headers={"X-Internal-Service": "demo-session"}
+            headers={"x-internal-service": "demo-session"}
         )

         if result:
@@ -630,7 +630,7 @@ class ProductionServiceClient(BaseServiceClient):
         """
         Trigger production alerts for a tenant (internal service use only).

-        This method calls the internal endpoint which is protected by X-Internal-Service header.
+        This method calls the internal endpoint which is protected by x-internal-service header.
         Includes both production alerts and equipment maintenance checks.

         Args:
@@ -647,7 +647,7 @@ class ProductionServiceClient(BaseServiceClient):
             endpoint="production/internal/alerts/trigger",
             tenant_id=tenant_id,
             data={},
-            headers={"X-Internal-Service": "demo-session"}
+            headers={"x-internal-service": "demo-session"}
         )

         if result:
@@ -683,7 +683,7 @@ class ProductionServiceClient(BaseServiceClient):
         """
         Trigger yield improvement insights for a tenant (internal service use only).

-        This method calls the internal endpoint which is protected by X-Internal-Service header.
+        This method calls the internal endpoint which is protected by x-internal-service header.

         Args:
             tenant_id: Tenant ID to trigger insights for
@@ -697,7 +697,7 @@ class ProductionServiceClient(BaseServiceClient):
             endpoint="production/internal/ml/generate-yield-insights",
             tenant_id=tenant_id,
             data={"tenant_id": tenant_id},
-            headers={"X-Internal-Service": "demo-session"}
+            headers={"x-internal-service": "demo-session"}
         )

         if result: