Add improvements 2

This commit is contained in:
Urtzi Alfaro
2026-01-12 22:15:11 +01:00
parent 230bbe6a19
commit b931a5c45e
40 changed files with 1820 additions and 887 deletions

120
DOCKER_MAINTENANCE.md Normal file
View File

@@ -0,0 +1,120 @@
# Docker Maintenance Guide for Local Development
## The Problem
When developing with Tilt and local Kubernetes (Kind), Docker accumulates:
- **Multiple image versions** from each code change (Tilt rebuilds)
- **Unused volumes** from previous cluster runs
- **Build cache** that grows over time
This quickly fills up disk space, causing pods to fail with "No space left on device" errors.
## Quick Fix (When You Hit Disk Issues)
```bash
# Clean up all unused Docker resources
docker system prune -a --volumes -f
```
This removes:
- All unused images
- All unused volumes
- All build cache
**Expected recovery**: 60-100GB
## Regular Maintenance
### Option 1: Use the Cleanup Script (Recommended)
Run the maintenance script weekly:
```bash
./scripts/cleanup-docker.sh
```
Or run it automatically without confirmation:
```bash
./scripts/cleanup-docker.sh --auto
```
### Option 2: Manual Commands
```bash
# Remove images older than 24 hours
docker image prune -af --filter "until=24h"
# Remove unused volumes
docker volume prune -f
# Remove build cache
docker builder prune -af
```
### Option 3: Set Up Automated Cleanup
Add to your crontab (run every Sunday at 2 AM):
```bash
crontab -e
# Add this line:
0 2 * * 0 /Users/urtzialfaro/Documents/bakery-ia/scripts/cleanup-docker.sh --auto >> /tmp/docker-cleanup.log 2>&1
```
## Monitoring Disk Usage
### Check Docker disk usage:
```bash
docker system df
```
### Check Kind node disk usage:
```bash
docker exec bakery-ia-local-control-plane df -h /var
```
### Alert thresholds:
- **< 70%**: Healthy
- **70-85%**: Consider cleanup soon
- **> 85%**: Run cleanup immediately 🚨
- **> 95%**: Critical - pods will fail ❌
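To script this check, a minimal helper along the following lines works (a sketch, not part of the repo; it assumes the Kind node name used throughout this guide and GNU `df` inside the node):
```python
#!/usr/bin/env python3
"""Map the Kind node's /var usage onto the alert thresholds above (sketch)."""
import subprocess

NODE = "bakery-ia-local-control-plane"  # node name used in this guide


def var_usage_percent(node: str = NODE) -> int:
    """Return the used percentage of /var inside the Kind node."""
    out = subprocess.run(
        ["docker", "exec", node, "df", "--output=pcent", "/var"],
        capture_output=True, text=True, check=True,
    ).stdout
    # Output is "Use%\n  42%"; take the last line and strip the percent sign.
    return int(out.splitlines()[-1].strip().rstrip("%"))


if __name__ == "__main__":
    usage = var_usage_percent()
    if usage > 95:
        print(f"❌ {usage}% - critical: pods will fail, clean up now")
    elif usage > 85:
        print(f"🚨 {usage}% - run ./scripts/cleanup-docker.sh immediately")
    elif usage >= 70:
        print(f"⚠️ {usage}% - consider cleanup soon")
    else:
        print(f"✅ {usage}% - healthy")
```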
## Prevention Tips
1. **Run cleanup weekly** to prevent accumulation
2. **Monitor disk usage** before long dev sessions
3. **Delete old Kind clusters** when switching projects:
```bash
kind delete cluster --name bakery-ia-local
```
4. **Increase Docker disk allocation** in Docker Desktop settings if you frequently rebuild many services
## Troubleshooting
### Pods in CrashLoopBackOff after disk issues:
1. Run cleanup (see Quick Fix above)
2. Restart failed pods:
```bash
kubectl get pods -n bakery-ia | grep -E "(CrashLoopBackOff|Error)" | awk '{print $1}' | xargs kubectl delete pod -n bakery-ia
```
### Cleanup didn't free enough space:
If disk usage is still above 90% after cleanup:
```bash
# Nuclear option - rebuild everything
kind delete cluster --name bakery-ia-local
docker system prune -a --volumes -f
# Then recreate cluster with your setup scripts
```
## What Happened Today (2026-01-12)
- **Issue**: Disk was 100% full (113GB/113GB), causing database pods to crash
- **Root cause**: 122 unused Docker images + 16 unused volumes + 6GB build cache
- **Solution**: Ran `docker system prune -a --volumes -f`
- **Result**: Freed 89GB, disk now at 22% usage (24GB/113GB)
- **All services recovered successfully**

View File

@@ -1,185 +0,0 @@
# List of Proposed Improvements
## 1. New Section: Security and Cybersecurity (add after 5.2)
### 5.3. Security Architecture and European Regulatory Compliance
**Authentication and Authorization:**
- JWT with token rotation every 15 minutes
- Role-based access control (RBAC)
- Rate limiting (300 req/min per client)
- Multi-factor authentication (MFA) planned for Q1 2026
**Data Protection:**
- AES-256 encryption at rest
- Mandatory HTTPS with auto-renewing Let's Encrypt certificates
- Multi-tenant isolation at the database level
- SQL injection prevention via Pydantic schemas
- XSS and CORS protection
**Monitoring and Traceability:**
- OpenTelemetry for end-to-end distributed tracing
- SigNoz as the unified observability platform
- Centralized logs with trace correlation
- Full audit trail of access and changes
- Real-time alerts (email and Slack)
**GDPR Compliance:**
- Privacy by design
- Right to be forgotten implemented
- Data export in CSV/JSON
- Consent management with history
- Anonymization of analytics data
**Secure Infrastructure:**
- Kubernetes with pod security policies
- Automatic security updates
- Daily encrypted backups
- Disaster recovery (DR) plan
- PostgreSQL 17 with hardened configurations
## 2. Update Section 4.4 (Competition) - Add Security Advantage
**Competitive Advantage in Security:**
- "Security-First" architecture vs. vulnerable legacy solutions
- GDPR compliance by design (competitors are retrofitting)
- Full observability (OpenTelemetry + SigNoz) vs. black boxes
- Planned security certifications (ISO 27001, ENS)
- Alignment with the NIS2 Directive (mandatory since 2024 for the food supply chain)
## 3. Update Section 8.1 (Funding) - New Grant Lines
### Add subsection: "European Cybersecurity Funding 2026-2027"
**Identified European Programmes (2026-2027):**
1. **Digital Europe Programme - Cybersecurity (€390M total)**
- **UPTAKE Program**: €15M for SMEs on regulatory compliance
- **AI-Based Cybersecurity**: €15M for AI-based security systems
- Co-funding: up to 75% of project costs
- Projects: €3-5M per project, 36-month duration
- **Eligibility**: Bakery-IA qualifies as a technology SME
- **Estimated application**: €200,000 (Q1 2026)
2. **INCIBE EMPRENDE Program (España Digital 2026)**
- Budget: €191M (2023-2026)
- 34 partner entities across Spain
- Ideation, incubation and acceleration in cybersecurity
- Next Generation-EU funds
- **Estimated application**: €50,000 (Q2 2026)
3. **ENISA Emprendedoras Digitales**
- Participative loans, up to €51M mobilizable
- Dedicated lines for digital entrepreneurship
- No personal guarantees, flexible terms
- **Estimated application**: €75,000 (Q2 2026)
**Alignment with EU Priorities:**
- NIS2 Directive (food-sector security)
- Cyber Resilience Act (secure digital products)
- AI Act (transparent and auditable AI)
- GDPR (data protection by design)
## 4. Update Section 5.2 (Technical Architecture)
### Add details from the technical documentation:
**Microservices Architecture (21 independent services):**
- Centralized API Gateway with JWT and Redis cache (95% hit rate)
- Frontend: React 18 + TypeScript (mobile-first PWA)
- 18 PostgreSQL 17 databases (database-per-service pattern)
- Redis 7.4 for caching and RabbitMQ 4.1 for events
- Kubernetes on a VPS (horizontal scalability)
**Notable Technical Innovation:**
- **Enriched Alert System** (3 levels: Alerts / Notifications / Recommendations)
- **Intelligent Prioritization** with 0-100 scoring (4 weighted factors; see the sketch after this list)
- **Temporal Escalation** (+10 at 48h, +20 at 72h, +30 near the deadline)
- **Causal Chaining** (stock shortage → production delay → order at risk)
- **Deduplication** (95% reduction in alert spam)
- **SSE + WebSocket** for real-time updates
- **Prophet ML** with 20+ features (AEMET, Madrid traffic, public holidays)
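To make the prioritization concrete, here is a minimal sketch of a 0-100 score with temporal escalation. The four factor names and their weights are placeholder assumptions (they are not specified here), the 6-hour "near deadline" cutoff is likewise hypothetical, and only the +10/+20/+30 escalation values come from the list above:
```python
from dataclasses import dataclass


@dataclass
class AlertSignals:
    """Normalized inputs in [0, 1]; names and weights are illustrative."""
    severity: float
    business_impact: float
    confidence: float
    recency: float
    age_hours: float          # time since the alert was raised
    hours_to_deadline: float  # time left until the related deadline


# Placeholder weights for the four factors (must sum to 1.0).
WEIGHTS = {"severity": 0.4, "business_impact": 0.3, "confidence": 0.2, "recency": 0.1}


def priority_score(a: AlertSignals) -> int:
    """Combine the weighted factors into a 0-100 score, then escalate over time."""
    base = 100 * sum(WEIGHTS[name] * getattr(a, name) for name in WEIGHTS)
    # Temporal escalation: +10 after 48h, +20 after 72h, +30 near the deadline
    # (one possible reading of the rule above).
    if a.hours_to_deadline <= 6:
        base += 30
    elif a.age_hours >= 72:
        base += 20
    elif a.age_hours >= 48:
        base += 10
    return min(100, round(base))


# Example: a 3-day-old stock-shortage alert threatening an order due tomorrow.
print(priority_score(AlertSignals(0.8, 0.9, 0.7, 0.5, age_hours=74, hours_to_deadline=20)))
```
Capping at 100 keeps scores comparable across alerts regardless of how much escalation has been applied.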
**Enterprise-Grade Observability:**
- **SigNoz**: Unified traces, metrics and logs
- **OpenTelemetry**: Auto-instrumentation of 18 services
- **ClickHouse**: High-performance backend for analytics
- **Alerting**: Multi-channel (email, Slack) via AlertManager
- Monitoring: 18 PostgreSQL DBs, Redis, RabbitMQ, Kubernetes
## 5. Update Executive Summary (Section 0)
### Add bullet:
- **Security and Compliance:** Security-First architecture with GDPR compliance, full observability (OpenTelemetry/SigNoz), and alignment with European regulations (NIS2, Cyber Resilience Act). Eligible for the €390M Digital Europe Programme cybersecurity envelope.
## 6. Update Section 9 (Decalogue) - Add Funding Opportunities
**6. Strategic Alignment with European Priorities 2026-2027:**
- **€390M available** in the Digital Europe Programme for cybersecurity
- **€191M** from the INCIBE EMPRENDE programme (España Digital 2026)
- Bakery-IA qualifies for **3 funding lines simultaneously**:
* UPTAKE (€200K) - Regulatory compliance for SMEs
* INCIBE EMPRENDE (€50K) - Acceleration of cybersecurity startups
* ENISA Digital (€75K) - Participative loan
- **Total potential**: €325,000 in additional non-dilutive funding
- Competitive advantage: Security-First vs. legacy competitors
## 7. New Operating Cost Table - Add Security Line
| **Security and Compliance** | | | |
| SSL certificate (Let's Encrypt) | Free | €0 | Automatic renewal |
| SigNoz Observability (self-hosted) | Included in VPS | €0 | Vs. €500+/month as SaaS |
| Annual GDPR audit | External | €1,200/year | Mandatory compliance |
| Encrypted backups (off-site) | Backblaze B2 | €5/month | €60/year |
## 8. Update Roadmap (Section 9) - Add Security Milestones
**Q1 2026:**
- ✅ MFA (Multi-Factor Authentication) implementation
- ✅ Digital Europe Programme application (UPTAKE)
- ✅ External GDPR audit
**Q2 2026:**
- 📋 ISO 27001 certification (process kick-off)
- 📋 NIS2 compliance implementation
- 📋 INCIBE EMPRENDE application
**Q3 2026:**
- 📋 External penetration testing
- 📋 ENS certification (Spain's Esquema Nacional de Seguridad)
## 9. Update Concrete Request to VUE (Section 9.2)
### Add new point 5:
**5. Connection with European Cybersecurity Programmes:**
- Guidance for the Digital Europe Programme application (UPTAKE: €200K)
- Introduction to INCIBE EMPRENDE and its network of 34 partner entities
- Advice on preparing technical proposals for EU funds
- Contact with CDTI for the NEOTEC programme (R&D + cybersecurity)
## 10. Add Annex 7 - Compliance and Certifications
## ANNEX 7: COMPLIANCE AND CERTIFICATION ROADMAP
**Applicable Regulations:**
- ✅ GDPR (General Data Protection Regulation) - Implemented
- 📋 NIS2 Directive (security of network and information systems) - Q2 2026
- 📋 Cyber Resilience Act (secure digital products) - Q3 2026
- 📋 AI Act (AI transparency and auditing) - Q4 2026
**Planned Certifications:**
- 📋 ISO 27001 (Information Security Management) - 12-18 months
- 📋 ENS Medium level (Esquema Nacional de Seguridad) - 6-9 months
- 📋 SOC 2 Type II (for Enterprise customers) - 18-24 months
**Estimated Compliance Investment (3 years):** €25,000
**Expected ROI:** Access to Enterprise customers (+€150K potential ARR)
## Summary of Quantitative Changes:
- New funding identified: €325,000 (vs. €18,000 originally)
- New programmes: 3 European cybersecurity lines
- New sections: 2 (Security 5.3, Compliance Annex 7)
- Updates: 8 existing sections improved
- Competitive advantage: Security-First emphasized in 4 places

File diff suppressed because it is too large

View File

@@ -7,6 +7,13 @@
# - PostgreSQL pgcrypto extension and audit logging
# - Organized resource dependencies and live-reload capabilities
# - Local registry for faster image builds and deployments
#
# Build Optimization:
# - Services only rebuild when their specific code changes (not all services)
# - Shared folder changes trigger rebuild of ALL services (as they all depend on it)
# - Uses 'only' parameter to watch only relevant files per service
# - Frontend only rebuilds when frontend/ code changes
# - Gateway only rebuilds when gateway/ or shared/ code changes
# =============================================================================
# =============================================================================
@@ -197,16 +204,25 @@ k8s_yaml(kustomize('infrastructure/kubernetes/overlays/dev'))
# =============================================================================
# Helper function for Python services with live updates
# This function ensures services only rebuild when their specific code changes,
# but all services rebuild when shared/ folder changes
def build_python_service(service_name, service_path):
docker_build(
'bakery/' + service_name,
context='.',
dockerfile='./services/' + service_path + '/Dockerfile',
# Only watch files relevant to this specific service + shared code
only=[
'./services/' + service_path,
'./shared',
'./scripts',
],
live_update=[
# Fall back to full image build if Dockerfile or requirements change
fall_back_on([
'./services/' + service_path + '/Dockerfile',
'./services/' + service_path + '/requirements.txt'
'./services/' + service_path + '/requirements.txt',
'./shared/requirements-tracing.txt',
]),
# Sync service code
@@ -290,10 +306,21 @@ docker_build(
'bakery/gateway',
context='.',
dockerfile='./gateway/Dockerfile',
# Only watch gateway-specific files and shared code
only=[
'./gateway',
'./shared',
'./scripts',
],
live_update=[
fall_back_on(['./gateway/Dockerfile', './gateway/requirements.txt']),
fall_back_on([
'./gateway/Dockerfile',
'./gateway/requirements.txt',
'./shared/requirements-tracing.txt',
]),
sync('./gateway', '/app'),
sync('./shared', '/app/shared'),
sync('./scripts', '/app/scripts'),
run('kill -HUP 1', trigger=['./gateway/**/*.py', './shared/**/*.py']),
],
ignore=[
@@ -680,6 +707,13 @@ Documentation:
docs/SECURITY_IMPLEMENTATION_COMPLETE.md
docs/DATABASE_SECURITY_ANALYSIS_REPORT.md
Build Optimization Active:
✅ Services only rebuild when their code changes
✅ Shared folder changes trigger ALL services (as expected)
✅ Reduces unnecessary rebuilds and disk usage
💡 Edit service code: only that service rebuilds
💡 Edit shared/ code: all services rebuild (required)
Useful Commands:
# Work on specific services only
tilt up <service-name> <service-name>

View File

@@ -198,16 +198,27 @@ export const RegisterTenantStep: React.FC<RegisterTenantStepProps> = ({
// Trigger POI detection in the background (non-blocking)
// This replaces the removed POI Detection step
// POI detection will be cached for 90 days and reused during training
const bakeryLocation = wizardContext.state.bakeryLocation;
if (bakeryLocation?.latitude && bakeryLocation?.longitude && tenant.id) {
console.log(`🔍 Triggering background POI detection for tenant ${tenant.id}...`);
// Run POI detection asynchronously without blocking the wizard flow
// This ensures POI data is ready before the training step
poiContextApi.detectPOIs(
tenant.id,
bakeryLocation.latitude,
bakeryLocation.longitude,
false // use_cache = false for initial detection
false // force_refresh = false, will use cache if available
).then((result) => {
console.log(`✅ POI detection completed automatically for tenant ${tenant.id}:`, result.summary);
const source = result.source || 'unknown';
console.log(`✅ POI detection completed for tenant ${tenant.id} (source: ${source})`);
if (result.poi_context) {
const totalPois = result.poi_context.total_pois_detected || 0;
const relevantCategories = result.poi_context.relevant_categories?.length || 0;
console.log(`📍 POI Summary: ${totalPois} POIs detected, ${relevantCategories} relevant categories`);
}
// Phase 3: Handle calendar suggestion if available
if (result.calendar_suggestion) {
@@ -230,8 +241,11 @@ export const RegisterTenantStep: React.FC<RegisterTenantStepProps> = ({
}
}).catch((error) => {
console.warn('⚠️ Background POI detection failed (non-blocking):', error);
// This is non-critical, so we don't block the user
console.warn('Training will continue without POI features if detection is not complete.');
// This is non-critical - training service will continue without POI features
});
} else {
console.warn('⚠️ Cannot trigger POI detection: missing location data or tenant ID');
}
// Update the wizard context with tenant info

View File

@@ -352,6 +352,25 @@ headers = {
- **Caching**: Gateway caches validated service tokens for 5 minutes
- **No Additional HTTP Calls**: Service auth happens locally at gateway
### Unified Header Management System
The gateway uses a **centralized HeaderManager** for consistent header handling across all middleware and proxy layers.
**Key Features:**
- Standardized header names and conventions
- Automatic header sanitization to prevent spoofing
- Unified header injection and forwarding
- Cross-middleware header access via `request.state.injected_headers`
- Consistent logging and error handling
**Standard Headers:**
- `x-user-id`, `x-user-email`, `x-user-role`, `x-user-type`
- `x-service-name`, `x-tenant-id`
- `x-subscription-tier`, `x-subscription-status`
- `x-is-demo`, `x-demo-session-id`, `x-demo-account-type`
- `x-tenant-access-type`, `x-can-view-children`, `x-parent-tenant-id`
- `x-forwarded-by`, `x-request-id`
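For illustration only, a downstream FastAPI service could read this context with a small dependency like the sketch below. The header names match the table above; the `CurrentUser` model, the route, and the `x-forwarded-by` check are assumptions, not part of the gateway code.
```python
from typing import Optional

from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel

app = FastAPI()


class CurrentUser(BaseModel):
    user_id: str
    email: Optional[str] = None
    role: str = "user"
    tenant_id: Optional[str] = None
    is_demo: bool = False


def current_user(
    # FastAPI maps x_user_id -> "x-user-id", etc.
    x_user_id: Optional[str] = Header(None),
    x_user_email: Optional[str] = Header(None),
    x_user_role: str = Header("user"),
    x_tenant_id: Optional[str] = Header(None),
    x_is_demo: str = Header("false"),
    x_forwarded_by: Optional[str] = Header(None),
) -> CurrentUser:
    # Trust the context headers only when the request passed through the
    # gateway, which sanitizes and re-injects them from the verified token.
    if x_forwarded_by != "bakery-gateway" or not x_user_id:
        raise HTTPException(status_code=401, detail="Missing gateway context")
    return CurrentUser(
        user_id=x_user_id,
        email=x_user_email,
        role=x_user_role,
        tenant_id=x_tenant_id,
        is_demo=x_is_demo.lower() == "true",
    )


@app.get("/api/v1/whoami")
def whoami(user: CurrentUser = Depends(current_user)) -> CurrentUser:
    return user
```
Because the gateway strips any incoming `x-user-*`/`x-tenant-*` headers before re-injecting them from the verified JWT, downstream services can treat these values as trusted once the request is confirmed to come through the gateway.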
### Context Header Injection
When a service token is validated, the gateway injects these headers for downstream services:

View File

@@ -0,0 +1,345 @@
"""
Unified Header Management System for API Gateway
Centralized header injection, forwarding, and validation
"""
import structlog
from fastapi import Request
from typing import Dict, Any, Optional, List
logger = structlog.get_logger()
class HeaderManager:
"""
Centralized header management for consistent header handling across gateway
"""
# Standard header names (lowercase for consistency)
STANDARD_HEADERS = {
'user_id': 'x-user-id',
'user_email': 'x-user-email',
'user_role': 'x-user-role',
'user_type': 'x-user-type',
'service_name': 'x-service-name',
'tenant_id': 'x-tenant-id',
'subscription_tier': 'x-subscription-tier',
'subscription_status': 'x-subscription-status',
'is_demo': 'x-is-demo',
'demo_session_id': 'x-demo-session-id',
'demo_account_type': 'x-demo-account-type',
'tenant_access_type': 'x-tenant-access-type',
'can_view_children': 'x-can-view-children',
'parent_tenant_id': 'x-parent-tenant-id',
'forwarded_by': 'x-forwarded-by',
'request_id': 'x-request-id'
}
# Headers that should be sanitized/removed from incoming requests
SANITIZED_HEADERS = [
'x-subscription-',
'x-user-',
'x-tenant-',
'x-demo-',
'x-forwarded-by'
]
# Headers that should be forwarded to downstream services
FORWARDABLE_HEADERS = [
'authorization',
'content-type',
'accept',
'accept-language',
'user-agent',
'x-internal-service' # Required for internal service-to-service ML/alert triggers
]
def __init__(self):
self._initialized = False
def initialize(self):
"""Initialize header manager"""
if not self._initialized:
logger.info("HeaderManager initialized")
self._initialized = True
def sanitize_incoming_headers(self, request: Request) -> None:
"""
Remove sensitive headers from incoming request to prevent spoofing
"""
if not hasattr(request.headers, '_list'):
return
# Filter out headers that start with sanitized prefixes
sanitized_headers = [
(k, v) for k, v in request.headers.raw
if not any(k.decode().lower().startswith(prefix.lower())
for prefix in self.SANITIZED_HEADERS)
]
request.headers.__dict__["_list"] = sanitized_headers
logger.debug("Sanitized incoming headers")
def inject_context_headers(self, request: Request, user_context: Dict[str, Any],
tenant_id: Optional[str] = None) -> Dict[str, str]:
"""
Inject standardized context headers into request
Returns dict of injected headers for reference
"""
injected_headers = {}
# Ensure headers list exists
if not hasattr(request.headers, '_list'):
request.headers.__dict__["_list"] = []
# Store headers in request.state for cross-middleware access
request.state.injected_headers = {}
# User context headers
if user_context.get('user_id'):
header_name = self.STANDARD_HEADERS['user_id']
header_value = str(user_context['user_id'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('email'):
header_name = self.STANDARD_HEADERS['user_email']
header_value = str(user_context['email'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('role'):
header_name = self.STANDARD_HEADERS['user_role']
header_value = str(user_context['role'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# User type (service vs regular user)
if user_context.get('type'):
header_name = self.STANDARD_HEADERS['user_type']
header_value = str(user_context['type'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Service name for service tokens
if user_context.get('service'):
header_name = self.STANDARD_HEADERS['service_name']
header_value = str(user_context['service'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Tenant context
if tenant_id:
header_name = self.STANDARD_HEADERS['tenant_id']
header_value = str(tenant_id)
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Subscription context
if user_context.get('subscription_tier'):
header_name = self.STANDARD_HEADERS['subscription_tier']
header_value = str(user_context['subscription_tier'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('subscription_status'):
header_name = self.STANDARD_HEADERS['subscription_status']
header_value = str(user_context['subscription_status'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Demo session context
is_demo = user_context.get('is_demo', False)
if is_demo:
header_name = self.STANDARD_HEADERS['is_demo']
header_value = "true"
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('demo_session_id'):
header_name = self.STANDARD_HEADERS['demo_session_id']
header_value = str(user_context['demo_session_id'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('demo_account_type'):
header_name = self.STANDARD_HEADERS['demo_account_type']
header_value = str(user_context['demo_account_type'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Hierarchical access context
if tenant_id:
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
can_view_children = getattr(request.state, 'can_view_children', False)
header_name = self.STANDARD_HEADERS['tenant_access_type']
header_value = str(tenant_access_type)
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
header_name = self.STANDARD_HEADERS['can_view_children']
header_value = str(can_view_children).lower()
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Parent tenant ID if hierarchical access
parent_tenant_id = getattr(request.state, 'parent_tenant_id', None)
if parent_tenant_id:
header_name = self.STANDARD_HEADERS['parent_tenant_id']
header_value = str(parent_tenant_id)
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Gateway identification
header_name = self.STANDARD_HEADERS['forwarded_by']
header_value = "bakery-gateway"
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Request ID if available
request_id = getattr(request.state, 'request_id', None)
if request_id:
header_name = self.STANDARD_HEADERS['request_id']
header_value = str(request_id)
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
logger.info("🔧 Injected context headers",
user_id=user_context.get('user_id'),
user_type=user_context.get('type', ''),
service_name=user_context.get('service', ''),
role=user_context.get('role', ''),
tenant_id=tenant_id,
is_demo=is_demo,
demo_session_id=user_context.get('demo_session_id', ''),
path=request.url.path)
return injected_headers
def _add_header(self, request: Request, header_name: str, header_value: str) -> None:
"""
Safely add header to request
"""
try:
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
except Exception as e:
logger.warning(f"Failed to add header {header_name}: {e}")
def get_forwardable_headers(self, request: Request) -> Dict[str, str]:
"""
Get headers that should be forwarded to downstream services
Includes both original request headers and injected context headers
"""
forwardable_headers = {}
# Add forwardable original headers
for header_name in self.FORWARDABLE_HEADERS:
header_value = request.headers.get(header_name)
if header_value:
forwardable_headers[header_name] = header_value
# Add injected context headers from request.state
if hasattr(request.state, 'injected_headers'):
for header_name, header_value in request.state.injected_headers.items():
forwardable_headers[header_name] = header_value
# Add authorization header if present
auth_header = request.headers.get('authorization')
if auth_header:
forwardable_headers['authorization'] = auth_header
return forwardable_headers
def get_all_headers_for_proxy(self, request: Request,
additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
"""
Get complete set of headers for proxying to downstream services
"""
headers = self.get_forwardable_headers(request)
# Add any additional headers
if additional_headers:
headers.update(additional_headers)
# Remove host header as it will be set by httpx
headers.pop('host', None)
return headers
def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool:
"""
Validate that required headers are present
"""
missing_headers = []
for header_name in required_headers:
# Check in injected headers first
if hasattr(request.state, 'injected_headers'):
if header_name in request.state.injected_headers:
continue
# Check in request headers
if request.headers.get(header_name):
continue
missing_headers.append(header_name)
if missing_headers:
logger.warning(f"Missing required headers: {missing_headers}")
return False
return True
def get_header_value(self, request: Request, header_name: str,
default: Optional[str] = None) -> Optional[str]:
"""
Get header value from either injected headers or request headers
"""
# Check injected headers first
if hasattr(request.state, 'injected_headers'):
if header_name in request.state.injected_headers:
return request.state.injected_headers[header_name]
# Check request headers
return request.headers.get(header_name, default)
def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None:
"""
Allow middleware to add headers to the unified header system
This ensures all headers are available for proxying
"""
# Ensure injected_headers exists
if not hasattr(request.state, 'injected_headers'):
request.state.injected_headers = {}
# Add header to injected_headers
request.state.injected_headers[header_name] = header_value
# Also add to actual request headers for compatibility
try:
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
except Exception as e:
logger.warning(f"Failed to add header {header_name} to request headers: {e}")
logger.debug(f"Middleware added header: {header_name} = {header_value}")
# Global instance for easy access
header_manager = HeaderManager()

View File

@@ -16,6 +16,7 @@ from shared.redis_utils import initialize_redis, close_redis, get_redis_client
from shared.service_base import StandardFastAPIService
from app.core.config import settings
from app.core.header_manager import header_manager
from app.middleware.request_id import RequestIDMiddleware
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
@@ -50,6 +51,10 @@ class GatewayService(StandardFastAPIService):
"""Custom startup logic for Gateway"""
global redis_client
# Initialize HeaderManager
header_manager.initialize()
logger.info("HeaderManager initialized")
# Initialize Redis
try:
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)

View File

@@ -14,6 +14,7 @@ import httpx
import json
from app.core.config import settings
from app.core.header_manager import header_manager
from shared.auth.jwt_handler import JWTHandler
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
@@ -60,15 +61,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
if request.method == "OPTIONS":
return await call_next(request)
# SECURITY: Remove any incoming x-subscription-* headers
# These will be re-injected from verified JWT only
sanitized_headers = [
(k, v) for k, v in request.headers.raw
if not k.decode().lower().startswith('x-subscription-')
and not k.decode().lower().startswith('x-user-')
and not k.decode().lower().startswith('x-tenant-')
]
request.headers.__dict__["_list"] = sanitized_headers
# SECURITY: Remove any incoming sensitive headers using HeaderManager
header_manager.sanitize_incoming_headers(request)
# Skip authentication for public routes
if self._is_public_route(request.url.path):
@@ -573,109 +567,13 @@ class AuthMiddleware(BaseHTTPMiddleware):
async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
"""
Inject user and tenant context headers for downstream services
ENHANCED: Added logging to verify header injection
Inject user and tenant context headers for downstream services using unified HeaderManager
"""
# Enhanced logging for debugging
logger.info(
"🔧 Injecting context headers",
user_id=user_context.get("user_id"),
user_type=user_context.get("type", ""),
service_name=user_context.get("service", ""),
role=user_context.get("role", ""),
tenant_id=tenant_id,
is_demo=user_context.get("is_demo", False),
demo_session_id=user_context.get("demo_session_id", ""),
path=request.url.path
)
# Add user context headers
logger.debug(f"DEBUG: Injecting headers for user: {user_context.get('user_id')}, is_demo: {user_context.get('is_demo', False)}")
logger.debug(f"DEBUG: request.headers object id: {id(request.headers)}, _list id: {id(request.headers.__dict__.get('_list', []))}")
# Store headers in request.state for cross-middleware access
request.state.injected_headers = {
"x-user-id": user_context["user_id"],
"x-user-email": user_context["email"],
"x-user-role": user_context.get("role", "user")
}
request.headers.__dict__["_list"].append((
b"x-user-id", user_context["user_id"].encode()
))
request.headers.__dict__["_list"].append((
b"x-user-email", user_context["email"].encode()
))
user_role = user_context.get("role", "user")
request.headers.__dict__["_list"].append((
b"x-user-role", user_role.encode()
))
user_type = user_context.get("type", "")
if user_type:
request.headers.__dict__["_list"].append((
b"x-user-type", user_type.encode()
))
service_name = user_context.get("service", "")
if service_name:
request.headers.__dict__["_list"].append((
b"x-service-name", service_name.encode()
))
# Add tenant context if available
if tenant_id:
request.headers.__dict__["_list"].append((
b"x-tenant-id", tenant_id.encode()
))
# Add subscription tier if available
subscription_tier = user_context.get("subscription_tier", "")
if subscription_tier:
request.headers.__dict__["_list"].append((
b"x-subscription-tier", subscription_tier.encode()
))
# Add is_demo flag for demo sessions
is_demo = user_context.get("is_demo", False)
logger.debug(f"DEBUG: is_demo value: {is_demo}, type: {type(is_demo)}")
if is_demo:
logger.info(f"🎭 Adding demo session headers",
demo_session_id=user_context.get("demo_session_id", ""),
demo_account_type=user_context.get("demo_account_type", ""),
path=request.url.path)
request.headers.__dict__["_list"].append((
b"x-is-demo", b"true"
))
else:
logger.debug(f"DEBUG: Not adding demo headers because is_demo is: {is_demo}")
# Add demo session context headers for backend services
demo_session_id = user_context.get("demo_session_id", "")
if demo_session_id:
request.headers.__dict__["_list"].append((
b"x-demo-session-id", demo_session_id.encode()
))
demo_account_type = user_context.get("demo_account_type", "")
if demo_account_type:
request.headers.__dict__["_list"].append((
b"x-demo-account-type", demo_account_type.encode()
))
# Use unified HeaderManager for consistent header injection
injected_headers = header_manager.inject_context_headers(request, user_context, tenant_id)
# Add hierarchical access headers if tenant context exists
if tenant_id:
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
can_view_children = getattr(request.state, 'can_view_children', False)
request.headers.__dict__["_list"].append((
b"x-tenant-access-type", tenant_access_type.encode()
))
request.headers.__dict__["_list"].append((
b"x-can-view-children", str(can_view_children).encode()
))
# If this is hierarchical access, include parent tenant ID
# Get parent tenant ID from the auth service if available
try:
@@ -689,17 +587,16 @@ class AuthMiddleware(BaseHTTPMiddleware):
hierarchy_data = response.json()
parent_tenant_id = hierarchy_data.get("parent_tenant_id")
if parent_tenant_id:
request.headers.__dict__["_list"].append((
b"x-parent-tenant-id", parent_tenant_id.encode()
))
# Add parent tenant ID using HeaderManager for consistency
header_name = header_manager.STANDARD_HEADERS['parent_tenant_id']
header_value = str(parent_tenant_id)
header_manager.add_header_for_middleware(request, header_name, header_value)
logger.info(f"Added parent tenant ID header: {parent_tenant_id}")
except Exception as e:
logger.warning(f"Failed to get parent tenant ID: {e}")
pass
# Add gateway identification
request.headers.__dict__["_list"].append((
b"x-forwarded-by", b"bakery-gateway"
))
return injected_headers
async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
"""

View File

@@ -45,8 +45,17 @@ class APIRateLimitMiddleware(BaseHTTPMiddleware):
return await call_next(request)
try:
# Get subscription tier
# Get subscription tier from headers (added by AuthMiddleware)
subscription_tier = request.headers.get("x-subscription-tier")
if not subscription_tier:
# Fallback: get from request state if headers not available
subscription_tier = getattr(request.state, "subscription_tier", None)
if not subscription_tier:
# Final fallback: get from tenant service (should rarely happen)
subscription_tier = await self._get_subscription_tier(tenant_id, request)
logger.warning(f"Subscription tier not found in headers or state, fetched from tenant service: {subscription_tier}")
# Get quota limit for tier
quota_limit = self._get_quota_limit(subscription_tier)

View File

@@ -9,6 +9,8 @@ from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from app.core.header_manager import header_manager
logger = structlog.get_logger()
@@ -40,11 +42,9 @@ class RequestIDMiddleware(BaseHTTPMiddleware):
# Bind request ID to structured logger context
logger_ctx = logger.bind(request_id=request_id)
# Inject request ID header for downstream services
# This is done by modifying the headers that will be forwarded
request.headers.__dict__["_list"].append((
b"x-request-id", request_id.encode()
))
# Inject request ID header for downstream services using HeaderManager
# Note: This runs early in middleware chain, so we use add_header_for_middleware
header_manager.add_header_for_middleware(request, "x-request-id", request_id)
# Log request start
logger_ctx.info(

View File

@@ -15,6 +15,7 @@ import asyncio
from datetime import datetime, timezone
from app.core.config import settings
from app.core.header_manager import header_manager
from app.utils.subscription_error_responses import create_upgrade_required_response
logger = structlog.get_logger()
@@ -178,7 +179,10 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
r'/api/v1/subscriptions/.*', # Subscription management itself
r'/api/v1/tenants/[^/]+/members.*', # Basic tenant info
r'/docs.*',
r'/openapi\.json'
r'/openapi\.json',
# Training monitoring endpoints (WebSocket and status checks)
r'/api/v1/tenants/[^/]+/training/jobs/.*/live.*', # WebSocket endpoint
r'/api/v1/tenants/[^/]+/training/jobs/.*/status.*', # Status polling endpoint
]
# Skip OPTIONS requests (CORS preflight)
@@ -275,21 +279,11 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
'current_tier': current_tier
}
# Use the same authentication pattern as gateway routes for fallback
headers = dict(request.headers)
headers.pop("host", None)
# Use unified HeaderManager for consistent header handling
headers = header_manager.get_all_headers_for_proxy(request)
# Extract user_id for logging (fallback path)
user_id = 'unknown'
# Add user context headers if available
if hasattr(request.state, 'user') and request.state.user:
user = request.state.user
user_id = str(user.get('user_id', 'unknown'))
headers["x-user-id"] = user_id
headers["x-user-email"] = str(user.get('email', ''))
headers["x-user-role"] = str(user.get('role', 'user'))
headers["x-user-full-name"] = str(user.get('full_name', ''))
headers["x-tenant-id"] = str(user.get('tenant_id', ''))
user_id = header_manager.get_header_value(request, 'x-user-id', 'unknown')
# Call tenant service fast tier endpoint with caching (fallback for old tokens)
timeout_config = httpx.Timeout(

View File

@@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse
from typing import Dict, Any
from app.core.config import settings
from app.core.header_manager import header_manager
from app.core.service_discovery import ServiceDiscovery
from shared.monitoring.metrics import MetricsCollector
@@ -136,107 +137,32 @@ class AuthProxy:
return AUTH_SERVICE_URL
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
"""Prepare headers for forwarding (remove hop-by-hop headers)"""
# Remove hop-by-hop headers
hop_by_hop_headers = {
'connection', 'keep-alive', 'proxy-authenticate',
'proxy-authorization', 'te', 'trailers', 'upgrade'
}
# Convert headers to dict - get ALL headers including those added by middleware
# Middleware adds headers to _list, so we need to read from there
logger.debug(f"DEBUG: headers type: {type(headers)}, has _list: {hasattr(headers, '_list')}, has raw: {hasattr(headers, 'raw')}")
logger.debug(f"DEBUG: headers.__dict__ keys: {list(headers.__dict__.keys())}")
logger.debug(f"DEBUG: '_list' in headers.__dict__: {'_list' in headers.__dict__}")
if hasattr(headers, '_list'):
logger.debug(f"DEBUG: Entering _list branch")
logger.debug(f"DEBUG: headers object id: {id(headers)}, _list id: {id(headers.__dict__.get('_list', []))}")
# Get headers from the _list where middleware adds them
all_headers_list = headers.__dict__.get('_list', [])
logger.debug(f"DEBUG: _list length: {len(all_headers_list)}")
# Debug: Show first few headers in the list
debug_headers = []
for i, (k, v) in enumerate(all_headers_list):
if i < 5: # Show first 5 headers for debugging
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
debug_headers.append(f"{key}: {value}")
logger.debug(f"DEBUG: First headers in _list: {debug_headers}")
# Convert to dict for easier processing
"""Prepare headers for forwarding using unified HeaderManager"""
# Use unified HeaderManager to get all headers
if request:
all_headers = header_manager.get_all_headers_for_proxy(request)
logger.debug(f"DEBUG: Added headers from HeaderManager: {list(all_headers.keys())}")
else:
# Fallback: convert headers to dict manually
all_headers = {}
for k, v in all_headers_list:
if hasattr(headers, '_list'):
for k, v in headers.__dict__.get('_list', []):
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
# Debug: Show if x-user-id and x-is-demo are in the dict
logger.debug(f"DEBUG: x-user-id in all_headers: {'x-user-id' in all_headers}, x-is-demo in all_headers: {'x-is-demo' in all_headers}")
logger.debug(f"DEBUG: all_headers keys: {list(all_headers.keys())[:10]}...") # Show first 10 keys
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
# Check if headers are missing and try to get them from request.state
if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
logger.debug(f"DEBUG: Found injected_headers in request.state: {request.state.injected_headers}")
# Add missing headers from request.state
if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
logger.debug(f"DEBUG: Added x-user-id from request.state: {all_headers['x-user-id']}")
if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
logger.debug(f"DEBUG: Added x-user-email from request.state: {all_headers['x-user-email']}")
if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
logger.debug(f"DEBUG: Added x-user-role from request.state: {all_headers['x-user-role']}")
# Add is_demo flag if this is a demo session
if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
all_headers['x-is-demo'] = 'true'
logger.debug(f"DEBUG: Added x-is-demo from request.state.is_demo_session")
# Filter out hop-by-hop headers
filtered_headers = {
k: v for k, v in all_headers.items()
if k.lower() not in hop_by_hop_headers
}
elif hasattr(headers, 'raw'):
logger.debug(f"DEBUG: Entering raw branch")
# Filter out hop-by-hop headers
filtered_headers = {
k: v for k, v in all_headers.items()
if k.lower() not in hop_by_hop_headers
}
elif hasattr(headers, 'raw'):
# Fallback to raw headers if _list not available
all_headers = {
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
for k, v in headers.raw
}
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
filtered_headers = {
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
for k, v in headers.raw
if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
}
for k, v in headers.raw:
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
else:
# Handle case where headers is already a dict
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {headers.get('x-user-id', 'MISSING')}, x_is_demo: {headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {headers.get('x-demo-session-id', 'MISSING')}, headers: {list(headers.keys())}")
# Headers is already a dict
all_headers = dict(headers)
filtered_headers = {
k: v for k, v in headers.items()
if k.lower() not in hop_by_hop_headers
}
# Debug logging
logger.info(f"📤 Forwarding headers - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
# Add gateway identifier
filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
filtered_headers['X-Gateway-Version'] = '1.0.0'
return filtered_headers
return all_headers
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
"""Prepare response headers"""

View File

@@ -8,6 +8,7 @@ import httpx
import structlog
from app.core.config import settings
from app.core.header_manager import header_manager
logger = structlog.get_logger()
@@ -29,12 +30,8 @@ async def proxy_demo_service(path: str, request: Request):
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Forward headers (excluding host)
headers = {
key: value
for key, value in request.headers.items()
if key.lower() not in ["host", "content-length"]
}
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
try:
async with httpx.AsyncClient(timeout=30.0) as client:

View File

@@ -5,6 +5,7 @@ from fastapi.responses import JSONResponse
import httpx
import structlog
from app.core.config import settings
from app.core.header_manager import header_manager
logger = structlog.get_logger()
router = APIRouter()
@@ -26,12 +27,8 @@ async def proxy_geocoding(request: Request, path: str):
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Forward headers (excluding host)
headers = {
key: value
for key, value in request.headers.items()
if key.lower() not in ["host", "content-length"]
}
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Make the proxied request
async with httpx.AsyncClient(timeout=30.0) as client:

View File

@@ -8,6 +8,7 @@ from fastapi.responses import JSONResponse
import httpx
import structlog
from app.core.config import settings
from app.core.header_manager import header_manager
logger = structlog.get_logger()
router = APIRouter()
@@ -44,12 +45,8 @@ async def proxy_poi_context(request: Request, path: str):
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Copy headers (exclude host and content-length as they'll be set by httpx)
headers = {
key: value
for key, value in request.headers.items()
if key.lower() not in ["host", "content-length"]
}
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Make the request to the external service
async with httpx.AsyncClient(timeout=60.0) as client:

View File

@@ -8,6 +8,7 @@ import httpx
import logging
from app.core.config import settings
from app.core.header_manager import header_manager
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -45,9 +46,8 @@ async def _proxy_to_pos_service(request: Request, target_path: str):
try:
url = f"{settings.POS_SERVICE_URL}{target_path}"
# Forward headers
headers = dict(request.headers)
headers.pop("host", None)
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Add query parameters
params = dict(request.query_params)

View File

@@ -9,6 +9,7 @@ import logging
from typing import Optional
from app.core.config import settings
from app.core.header_manager import header_manager
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -98,29 +99,13 @@ async def _proxy_request(request: Request, target_path: str, service_url: str):
try:
url = f"{service_url}{target_path}"
# Forward headers and add user/tenant context
headers = dict(request.headers)
headers.pop("host", None)
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Add user context headers if available
if hasattr(request.state, 'user') and request.state.user:
user = request.state.user
headers["x-user-id"] = str(user.get('user_id', ''))
headers["x-user-email"] = str(user.get('email', ''))
headers["x-user-role"] = str(user.get('role', 'user'))
headers["x-user-full-name"] = str(user.get('full_name', ''))
headers["x-tenant-id"] = str(user.get('tenant_id', ''))
# Add subscription context headers
if user.get('subscription_tier'):
headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
if user.get('subscription_status'):
headers["x-subscription-status"] = str(user.get('subscription_status', ''))
logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")
logger.info(f"Forwarding subscription request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, subscription_tier={user.get('subscription_tier', 'not_set')}")
# Debug logging
user_context = getattr(request.state, 'user', None)
if user_context:
logger.info(f"Forwarding subscription request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
else:
logger.warning(f"No user context available when forwarding subscription request to {url}")

View File

@@ -10,6 +10,7 @@ import logging
from typing import Optional
from app.core.config import settings
from app.core.header_manager import header_manager
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -715,36 +716,18 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
try:
url = f"{service_url}{target_path}"
# Forward headers and add user/tenant context
headers = dict(request.headers)
headers.pop("host", None)
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Add tenant ID header if provided
# Add tenant ID header if provided (override if needed)
if tenant_id:
headers["X-Tenant-ID"] = tenant_id
# Add user context headers if available
if hasattr(request.state, 'user') and request.state.user:
user = request.state.user
headers["x-user-id"] = str(user.get('user_id', ''))
headers["x-user-email"] = str(user.get('email', ''))
headers["x-user-role"] = str(user.get('role', 'user'))
headers["x-user-full-name"] = str(user.get('full_name', ''))
headers["x-tenant-id"] = tenant_id or str(user.get('tenant_id', ''))
# Add subscription context headers
if user.get('subscription_tier'):
headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
if user.get('subscription_status'):
headers["x-subscription-status"] = str(user.get('subscription_status', ''))
logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")
headers["x-tenant-id"] = tenant_id
# Debug logging
logger.info(f"Forwarding request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, tenant_id={tenant_id}, subscription_tier={user.get('subscription_tier', 'not_set')}")
user_context = getattr(request.state, 'user', None)
if user_context:
logger.info(f"Forwarding request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, tenant_id={tenant_id}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
else:
# Debug logging when no user context available
logger.warning(f"No user context available when forwarding request to {url}. request.state.user: {getattr(request.state, 'user', 'NOT_SET')}")
# Get request body if present
@@ -782,9 +765,10 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
logger.info(f"Forwarding multipart request with files={list(files.keys()) if files else None}, data={list(data.keys()) if data else None}")
# Remove content-type from headers - httpx will set it with new boundary
headers.pop("content-type", None)
headers.pop("content-length", None)
# For multipart requests, we need to get fresh headers since httpx will set content-type
# Get all headers again to ensure we have the complete set
headers = header_manager.get_all_headers_for_proxy(request)
# httpx will automatically set content-type for multipart, so we don't need to remove it
else:
# For other content types, use body as before
body = await request.body()

View File

@@ -13,6 +13,7 @@ from typing import Dict, Any
import json
from app.core.config import settings
from app.core.header_manager import header_manager
from app.core.service_discovery import ServiceDiscovery
from shared.monitoring.metrics import MetricsCollector
@@ -136,64 +137,28 @@ class UserProxy:
return AUTH_SERVICE_URL
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
"""Prepare headers for forwarding (remove hop-by-hop headers)"""
# Remove hop-by-hop headers
hop_by_hop_headers = {
'connection', 'keep-alive', 'proxy-authenticate',
'proxy-authorization', 'te', 'trailers', 'upgrade'
}
# Convert headers to dict if it's a Headers object
# This ensures we get ALL headers including those added by middleware
if hasattr(headers, '_list'):
# Get headers from the _list where middleware adds them
all_headers_list = headers.__dict__.get('_list', [])
# Convert to dict for easier processing
"""Prepare headers for forwarding using unified HeaderManager"""
# Use unified HeaderManager to get all headers
if request:
all_headers = header_manager.get_all_headers_for_proxy(request)
else:
# Fallback: convert headers to dict manually
all_headers = {}
for k, v in all_headers_list:
if hasattr(headers, '_list'):
for k, v in headers.__dict__.get('_list', []):
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
# Check if headers are missing and try to get them from request.state
if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
# Add missing headers from request.state
if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
# Add is_demo flag if this is a demo session
if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
all_headers['x-is-demo'] = 'true'
# Filter out hop-by-hop headers
filtered_headers = {
k: v for k, v in all_headers.items()
if k.lower() not in hop_by_hop_headers
}
elif hasattr(headers, 'raw'):
# FastAPI/Starlette Headers object - use raw to get all headers
filtered_headers = {
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
for k, v in headers.raw
if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
}
for k, v in headers.raw:
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
else:
# Already a dict
filtered_headers = {
k: v for k, v in headers.items()
if k.lower() not in hop_by_hop_headers
}
# Headers is already a dict
all_headers = dict(headers)
# Add gateway identifier
filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
filtered_headers['X-Gateway-Version'] = '1.0.0'
return filtered_headers
return all_headers
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
"""Prepare response headers"""

82
scripts/cleanup-docker.sh Executable file
View File

@@ -0,0 +1,82 @@
#!/bin/bash
# Docker Cleanup Script for Local Kubernetes Development
# This script helps prevent disk space issues by cleaning up unused Docker resources
set -e
echo "🧹 Docker Cleanup Script for Bakery-IA Local Development"
echo "========================================================="
echo ""
# Check if we should run automatically or ask for confirmation
AUTO_MODE=${1:-""}
# Show current disk usage
echo "📊 Current Docker Disk Usage:"
docker system df
echo ""
# Check Kind node disk usage if cluster is running
if docker ps | grep -q "bakery-ia-local-control-plane"; then
echo "📊 Kind Node Disk Usage:"
docker exec bakery-ia-local-control-plane df -h / /var | grep -E "(Filesystem|overlay|/dev/vdb1)"
echo ""
fi
# Calculate reclaimable space (5th column of the Images row in `docker system df`)
RECLAIMABLE=$(docker system df | grep "Images" | awk '{print $5}')
echo "💾 Estimated reclaimable space: $RECLAIMABLE"
echo ""
# Ask for confirmation unless in auto mode
if [ "$AUTO_MODE" != "--auto" ]; then
read -p "Do you want to proceed with cleanup? (y/n) " -n 1 -r
echo ""
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
echo "❌ Cleanup cancelled"
exit 0
fi
fi
echo "🚀 Starting cleanup..."
echo ""
# Remove unused images (keep images from last 24 hours)
echo "1⃣ Removing unused Docker images..."
docker image prune -af --filter "until=24h" || true
echo ""
# Remove unused volumes
echo "2⃣ Removing unused Docker volumes..."
docker volume prune -f || true
echo ""
# Remove build cache
echo "3⃣ Removing build cache..."
docker builder prune -af || true
echo ""
# Show results
echo "✅ Cleanup completed!"
echo ""
echo "📊 Final Docker Disk Usage:"
docker system df
echo ""
# Check Kind node disk usage if cluster is running
if docker ps | grep -q "bakery-ia-local-control-plane"; then
echo "📊 Kind Node Disk Usage After Cleanup:"
docker exec bakery-ia-local-control-plane df -h / /var | grep -E "(Filesystem|overlay|/dev/vdb1)"
echo ""
# Warn if still above 80%
USAGE=$(docker exec bakery-ia-local-control-plane df -h /var | tail -1 | awk '{print $5}' | sed 's/%//')
if [ "$USAGE" -gt 80 ]; then
echo "⚠️ Warning: Disk usage is still above 80%. Consider:"
echo " - Deleting and recreating the Kind cluster"
echo " - Increasing Docker's disk allocation"
echo " - Running: docker system prune -a --volumes -f"
fi
fi
echo "🎉 All done!"

View File

@@ -1,82 +0,0 @@
# services/forecasting/app/clients/inventory_client.py
"""
Simple client for inventory service integration
Used when product names are not available locally
"""
import aiohttp
import structlog
from typing import Optional, Dict, Any
import os
logger = structlog.get_logger()
class InventoryServiceClient:
"""Simple client for inventory service interactions"""
def __init__(self, base_url: str = None):
self.base_url = base_url or os.getenv("INVENTORY_SERVICE_URL", "http://inventory-service:8000")
self.timeout = aiohttp.ClientTimeout(total=5) # 5 second timeout
async def get_product_name(self, tenant_id: str, inventory_product_id: str) -> Optional[str]:
"""
Get product name from inventory service
Returns None if service is unavailable or product not found
"""
try:
async with aiohttp.ClientSession(timeout=self.timeout) as session:
url = f"{self.base_url}/api/v1/products/{inventory_product_id}"
headers = {"X-Tenant-ID": tenant_id}
async with session.get(url, headers=headers) as response:
if response.status == 200:
data = await response.json()
return data.get("name", f"Product-{inventory_product_id}")
else:
logger.debug("Product not found in inventory service",
inventory_product_id=inventory_product_id,
status=response.status)
return None
except Exception as e:
logger.debug("Failed to get product name from inventory service",
inventory_product_id=inventory_product_id,
error=str(e))
return None
async def get_multiple_product_names(self, tenant_id: str, product_ids: list) -> Dict[str, str]:
"""
Get multiple product names efficiently
Returns a mapping of product_id -> product_name
"""
try:
async with aiohttp.ClientSession(timeout=self.timeout) as session:
url = f"{self.base_url}/api/v1/products/batch"
headers = {"X-Tenant-ID": tenant_id}
payload = {"product_ids": product_ids}
async with session.post(url, json=payload, headers=headers) as response:
if response.status == 200:
data = await response.json()
return {item["id"]: item["name"] for item in data.get("products", [])}
else:
logger.debug("Batch product lookup failed",
product_count=len(product_ids),
status=response.status)
return {}
except Exception as e:
logger.debug("Failed to get product names from inventory service",
product_count=len(product_ids),
error=str(e))
return {}
# Global client instance
_inventory_client = None
def get_inventory_client() -> InventoryServiceClient:
"""Get the global inventory client instance"""
global _inventory_client
if _inventory_client is None:
_inventory_client = InventoryServiceClient()
return _inventory_client
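
For reference, a usage sketch of the client removed above; the method names and the `Product-{id}` fallback come from the file itself, while the wrapper function is hypothetical:

```python
import asyncio

async def resolve_names(tenant_id: str, product_ids: list) -> dict:
    """Resolve display names for forecast rows via the inventory service client."""
    client = get_inventory_client()
    names = await client.get_multiple_product_names(tenant_id, product_ids)
    # Fall back to one-by-one lookups for anything the batch endpoint missed
    for pid in product_ids:
        if pid not in names:
            names[pid] = await client.get_product_name(tenant_id, pid) or f"Product-{pid}"
    return names

# asyncio.run(resolve_names("tenant-uuid", ["p1", "p2"]))
```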

View File

@@ -12,7 +12,6 @@ from datetime import datetime, timedelta
import structlog
from shared.messaging import UnifiedEventPublisher
from app.clients.inventory_client import get_inventory_client
logger = structlog.get_logger()

View File

@@ -30,11 +30,11 @@ async def trigger_inventory_alerts(
- Expiring ingredients
- Overstock situations
- Security: Protected by X-Internal-Service header check.
+ Security: Protected by x-internal-service header check.
"""
try:
# Verify internal service header
- if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
+ if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
raise HTTPException(
status_code=403,
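
The same guard recurs across the internal endpoints in the hunks below; a sketch of how it could be expressed as a reusable FastAPI dependency (the helper and route path are hypothetical, not part of this commit):

```python
from fastapi import APIRouter, Depends, HTTPException, Request

router = APIRouter()
ALLOWED_INTERNAL_CALLERS = {"demo-session", "internal"}

async def require_internal_service(request: Request) -> str:
    """Reject callers that do not present a recognised x-internal-service header."""
    caller = request.headers.get("x-internal-service")
    if caller not in ALLOWED_INTERNAL_CALLERS:
        raise HTTPException(status_code=403, detail="Internal endpoint")
    return caller

@router.post("/internal/alerts/trigger")
async def trigger_alerts(caller: str = Depends(require_internal_service)):
    # Reaching this point means the header check above has already passed.
    return {"triggered_by": caller}
```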

View File

@@ -350,7 +350,7 @@ async def generate_safety_stock_insights_internal(
This endpoint is called by the demo-session service after cloning data.
It uses the same ML logic as the public endpoint but with optimized defaults.
- Security: Protected by X-Internal-Service header check.
+ Security: Protected by x-internal-service header check.
Args:
tenant_id: The tenant UUID
@@ -365,7 +365,7 @@ async def generate_safety_stock_insights_internal(
}
"""
# Verify internal service header
- if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
+ if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
raise HTTPException(
status_code=403,

View File

@@ -29,7 +29,7 @@ async def trigger_delivery_tracking(
This endpoint is called by the demo session cloning process after POs are seeded
to generate realistic delivery alerts (arriving soon, overdue, etc.).
- Security: Protected by X-Internal-Service header check.
+ Security: Protected by x-internal-service header check.
Args:
tenant_id: Tenant UUID to check deliveries for
@@ -49,7 +49,7 @@ async def trigger_delivery_tracking(
"""
try:
# Verify internal service header
- if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
+ if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
raise HTTPException(
status_code=403,

View File

@@ -566,7 +566,7 @@ async def generate_price_insights_internal(
This endpoint is called by the demo-session service after cloning data.
It uses the same ML logic as the public endpoint but with optimized defaults.
- Security: Protected by X-Internal-Service header check.
+ Security: Protected by x-internal-service header check.
Args:
tenant_id: The tenant UUID
@@ -581,7 +581,7 @@ async def generate_price_insights_internal(
}
"""
# Verify internal service header
- if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
+ if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
raise HTTPException(
status_code=403,

View File

@@ -1,42 +1,45 @@
"""
FastAPI Dependencies for Procurement Service
Uses shared authentication infrastructure with UUID validation
"""
- from fastapi import Header, HTTPException, status
+ from fastapi import Depends, HTTPException, status
from uuid import UUID
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from .database import get_db
+ from shared.auth.decorators import get_current_tenant_id_dep
async def get_current_tenant_id(
- x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID")
+ tenant_id: Optional[str] = Depends(get_current_tenant_id_dep)
) -> UUID:
"""
- Extract and validate tenant ID from request header.
+ Extract and validate tenant ID from request using shared infrastructure.
Adds UUID validation to ensure tenant ID format is correct.
Args:
- x_tenant_id: Tenant ID from X-Tenant-ID header
+ tenant_id: Tenant ID from shared dependency
Returns:
UUID: Validated tenant ID
Raises:
- HTTPException: If tenant ID is missing or invalid
+ HTTPException: If tenant ID is missing or invalid UUID format
"""
- if not x_tenant_id:
+ if not tenant_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
- detail="X-Tenant-ID header is required"
+ detail="x-tenant-id header is required"
)
try:
- return UUID(x_tenant_id)
+ return UUID(tenant_id)
except (ValueError, AttributeError):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
- detail=f"Invalid tenant ID format: {x_tenant_id}"
+ detail=f"Invalid tenant ID format: {tenant_id}"
)
)
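
Assuming the shared `get_current_tenant_id_dep` resolves the tenant as described, a route consumes the validated UUID like this (the endpoint shown is illustrative):

```python
from uuid import UUID
from fastapi import APIRouter, Depends

router = APIRouter()

@router.get("/purchase-orders")
async def list_purchase_orders(tenant_id: UUID = Depends(get_current_tenant_id)):
    # tenant_id is guaranteed to be a valid UUID here; a missing or malformed
    # x-tenant-id header has already been rejected with HTTP 400.
    return {"tenant_id": str(tenant_id), "items": []}
```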

View File

@@ -31,11 +31,11 @@ async def trigger_production_alerts(
- Equipment maintenance alerts
- Batch start delays
- Security: Protected by X-Internal-Service header check.
+ Security: Protected by x-internal-service header check.
"""
try:
# Verify internal service header
- if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
+ if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
raise HTTPException(
status_code=403,

View File

@@ -331,7 +331,7 @@ async def generate_yield_insights_internal(
This endpoint is called by the demo-session service after cloning data.
It uses the same ML logic as the public endpoint but with optimized defaults.
- Security: Protected by X-Internal-Service header check.
+ Security: Protected by x-internal-service header check.
Args:
tenant_id: The tenant UUID
@@ -346,7 +346,7 @@ async def generate_yield_insights_internal(
}
"""
# Verify internal service header
- if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
+ if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
raise HTTPException(
status_code=403,

View File

@@ -204,7 +204,7 @@ class TenantMemberRepository(TenantBaseRepository):
f"{auth_service_url}/api/v1/auth/users/batch",
json={"user_ids": user_ids},
timeout=10.0,
headers={"X-Internal-Service": "tenant-service"}
headers={"x-internal-service": "tenant-service"}
)
if response.status_code == 200:
@@ -226,7 +226,7 @@ class TenantMemberRepository(TenantBaseRepository):
response = await client.get(
f"{auth_service_url}/api/v1/auth/users/{user_id}",
timeout=5.0,
headers={"X-Internal-Service": "tenant-service"}
headers={"x-internal-service": "tenant-service"}
)
if response.status_code == 200:
user_data = response.json()
@@ -243,7 +243,7 @@ class TenantMemberRepository(TenantBaseRepository):
response = await client.get(
f"{auth_service_url}/api/v1/auth/users/{user_id}",
timeout=5.0,
headers={"X-Internal-Service": "tenant-service"}
headers={"x-internal-service": "tenant-service"}
)
if response.status_code == 200:
user_data = response.json()

View File

@@ -216,17 +216,24 @@ class HybridProphetXGBoost:
Get Prophet predictions for given dataframe.
Args:
- prophet_result: Prophet model result from training
+ prophet_result: Prophet model result from training (contains model_path)
df: DataFrame with 'ds' column
Returns:
Array of predictions
"""
- # Get the Prophet model from result
- prophet_model = prophet_result.get('model')
+ # Get the model path from result instead of expecting the model object directly
+ model_path = prophet_result.get('model_path')
- if prophet_model is None:
- raise ValueError("Prophet model not found in result")
+ if model_path is None:
+ raise ValueError("Prophet model path not found in result")
+ # Load the actual Prophet model from the stored path
+ try:
+ import joblib
+ prophet_model = joblib.load(model_path)
+ except Exception as e:
+ raise ValueError(f"Failed to load Prophet model from path {model_path}: {str(e)}")
# Prepare dataframe for prediction
pred_df = df[['ds']].copy()
@@ -273,7 +280,8 @@ class HybridProphetXGBoost:
'reg_lambda': 1.0, # L2 regularization
'objective': 'reg:squarederror',
'random_state': 42,
- 'n_jobs': -1
+ 'n_jobs': -1,
+ 'early_stopping_rounds': 10
}
# Initialize model
@@ -285,7 +293,6 @@ class HybridProphetXGBoost:
model.fit,
X_train, y_train,
eval_set=[(X_val, y_val)],
- early_stopping_rounds=10,
verbose=False
)
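
The two hunks above move `early_stopping_rounds` out of `fit()` and into the estimator parameters, which is where recent XGBoost releases expect it. A minimal sketch of the resulting pattern on toy data:

```python
import numpy as np
from xgboost import XGBRegressor

# Toy data; in the service the target is the residual of Prophet predictions.
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 5))
y = X @ rng.normal(size=5) + rng.normal(scale=0.1, size=200)
X_train, X_val, y_train, y_val = X[:160], X[160:], y[:160], y[160:]

model = XGBRegressor(
    n_estimators=500,
    objective="reg:squarederror",
    early_stopping_rounds=10,  # constructor parameter, not a fit() kwarg
    n_jobs=-1,
)
model.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)
print(model.best_iteration)  # boosting round kept after early stopping
```

This matches the hunks above: the `fit()` keyword is removed and the constructor gains the parameter.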
@@ -303,109 +310,86 @@ class HybridProphetXGBoost:
train_prophet_pred: np.ndarray,
val_prophet_pred: np.ndarray,
prophet_result: Dict[str, Any]
) -> Dict[str, float]:
) -> Dict[str, Any]:
"""
Evaluate hybrid model vs Prophet-only on validation set.
Args:
train_df: Training data
val_df: Validation data
train_prophet_pred: Prophet predictions on training set
val_prophet_pred: Prophet predictions on validation set
prophet_result: Prophet training result
Returns:
Dictionary of metrics
Evaluate the overall performance of the hybrid model using threading for metrics.
"""
# Get actual values
train_actual = train_df['y'].values
val_actual = val_df['y'].values
import asyncio
# Get XGBoost predictions on residuals
# Get XGBoost predictions on training and validation
X_train = train_df[self.feature_columns].values
X_val = val_df[self.feature_columns].values
# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
import asyncio
train_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_train)
val_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_val)
# Hybrid predictions = Prophet + XGBoost residual correction
# Hybrid prediction = Prophet prediction + XGBoost residual prediction
train_hybrid_pred = train_prophet_pred + train_xgb_pred
val_hybrid_pred = val_prophet_pred + val_xgb_pred
# Calculate metrics for Prophet-only
prophet_train_mae = mean_absolute_error(train_actual, train_prophet_pred)
prophet_val_mae = mean_absolute_error(val_actual, val_prophet_pred)
prophet_train_mape = mean_absolute_percentage_error(train_actual, train_prophet_pred) * 100
prophet_val_mape = mean_absolute_percentage_error(val_actual, val_prophet_pred) * 100
actual_train = train_df['y'].values
actual_val = val_df['y'].values
# Calculate metrics for Hybrid
hybrid_train_mae = mean_absolute_error(train_actual, train_hybrid_pred)
hybrid_val_mae = mean_absolute_error(val_actual, val_hybrid_pred)
hybrid_train_mape = mean_absolute_percentage_error(train_actual, train_hybrid_pred) * 100
hybrid_val_mape = mean_absolute_percentage_error(val_actual, val_hybrid_pred) * 100
# Basic RMSE calculation
train_rmse = float(np.sqrt(np.mean((actual_train - train_hybrid_pred)**2)))
val_rmse = float(np.sqrt(np.mean((actual_val - val_hybrid_pred)**2)))
# MAE
train_mae = float(np.mean(np.abs(actual_train - train_hybrid_pred)))
val_mae = float(np.mean(np.abs(actual_val - val_hybrid_pred)))
# MAPE (with safety for zero sales)
train_mape = float(np.mean(np.abs((actual_train - train_hybrid_pred) / np.maximum(actual_train, 1))))
val_mape = float(np.mean(np.abs((actual_val - val_hybrid_pred) / np.maximum(actual_val, 1))))
# Calculate improvement
mae_improvement = ((prophet_val_mae - hybrid_val_mae) / prophet_val_mae) * 100
mape_improvement = ((prophet_val_mape - hybrid_val_mape) / prophet_val_mape) * 100
prophet_metrics = prophet_result.get("metrics", {})
prophet_val_mae = prophet_metrics.get("val_mae", val_mae) # Fallback to hybrid if missing
prophet_val_mape = prophet_metrics.get("val_mape", val_mape)
improvement_pct = 0.0
if prophet_val_mape > 0:
improvement_pct = ((prophet_val_mape - val_mape) / prophet_val_mape) * 100
metrics = {
'prophet_train_mae': float(prophet_train_mae),
'prophet_val_mae': float(prophet_val_mae),
'prophet_train_mape': float(prophet_train_mape),
'prophet_val_mape': float(prophet_val_mape),
'hybrid_train_mae': float(hybrid_train_mae),
'hybrid_val_mae': float(hybrid_val_mae),
'hybrid_train_mape': float(hybrid_train_mape),
'hybrid_val_mape': float(hybrid_val_mape),
'mae_improvement_pct': float(mae_improvement),
'mape_improvement_pct': float(mape_improvement),
'improvement_percentage': float(mape_improvement) # Primary metric
"train_rmse": train_rmse,
"val_rmse": val_rmse,
"train_mae": train_mae,
"val_mae": val_mae,
"train_mape": train_mape,
"val_mape": val_mape,
"prophet_val_mape": prophet_val_mape,
"hybrid_val_mape": val_mape,
"improvement_percentage": float(improvement_pct),
"prophet_metrics": prophet_metrics
}
logger.info(
"Hybrid model evaluation complete",
val_rmse=val_rmse,
val_mae=val_mae,
val_mape=val_mape,
improvement=improvement_pct
)
return metrics
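
The evaluation combines the two models additively (hybrid = Prophet + XGBoost residual correction) and reports improvement as the relative drop in MAPE versus Prophet alone. A small worked sketch of those formulas on toy numbers:

```python
import numpy as np

actual       = np.array([100.0, 120.0, 80.0, 95.0])
prophet_pred = np.array([ 90.0, 130.0, 70.0, 100.0])
xgb_residual = np.array([  8.0,  -7.0,  9.0,  -4.0])  # XGBoost predicts Prophet's error

hybrid_pred = prophet_pred + xgb_residual

def mape(y_true, y_pred):
    # Same zero-sales guard as the service code: denominator floored at 1
    return float(np.mean(np.abs((y_true - y_pred) / np.maximum(y_true, 1))))

prophet_mape = mape(actual, prophet_pred)
hybrid_mape = mape(actual, hybrid_pred)
improvement = (prophet_mape - hybrid_mape) / prophet_mape * 100
print(f"Prophet MAPE={prophet_mape:.3f}, hybrid MAPE={hybrid_mape:.3f}, improvement={improvement:.1f}%")
```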
def _package_hybrid_model(
self,
prophet_result: Dict[str, Any],
metrics: Dict[str, float],
metrics: Dict[str, Any],
tenant_id: str,
inventory_product_id: str
) -> Dict[str, Any]:
"""
Package hybrid model for storage.
Args:
prophet_result: Prophet model result
metrics: Hybrid model metrics
tenant_id: Tenant ID
inventory_product_id: Product ID
Returns:
Model package dictionary
"""
return {
'model_type': 'hybrid_prophet_xgboost',
'prophet_model': prophet_result.get('model'),
'prophet_model_path': prophet_result.get('model_path'),
'xgboost_model': self.xgb_model,
'feature_columns': self.feature_columns,
'prophet_metrics': {
'train_mae': metrics['prophet_train_mae'],
'val_mae': metrics['prophet_val_mae'],
'train_mape': metrics['prophet_train_mape'],
'val_mape': metrics['prophet_val_mape']
},
'hybrid_metrics': {
'train_mae': metrics['hybrid_train_mae'],
'val_mae': metrics['hybrid_val_mae'],
'train_mape': metrics['hybrid_train_mape'],
'val_mape': metrics['hybrid_val_mape']
},
'improvement_metrics': {
'mae_improvement_pct': metrics['mae_improvement_pct'],
'mape_improvement_pct': metrics['mape_improvement_pct']
},
'metrics': metrics,
'tenant_id': tenant_id,
'inventory_product_id': inventory_product_id,
'trained_at': datetime.now(timezone.utc).isoformat()
@@ -426,8 +410,18 @@ class HybridProphetXGBoost:
Returns:
DataFrame with predictions
"""
# Step 1: Get Prophet predictions
prophet_model = model_data['prophet_model']
# Step 1: Get Prophet model from path and make predictions
prophet_model_path = model_data.get('prophet_model_path')
if prophet_model_path is None:
raise ValueError("Prophet model path not found in model data")
# Load the Prophet model from the stored path
try:
import joblib
prophet_model = joblib.load(prophet_model_path)
except Exception as e:
raise ValueError(f"Failed to load Prophet model from path {prophet_model_path}: {str(e)}")
# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
import asyncio
prophet_forecast = await asyncio.to_thread(prophet_model.predict, future_df)
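
Both prediction paths now load the persisted Prophet model with joblib and run the blocking `predict()` in a worker thread. The general pattern, assuming `model_path` points at a joblib-serialised model:

```python
import asyncio
import joblib
import pandas as pd

async def predict_from_path(model_path: str, future_df: pd.DataFrame) -> pd.DataFrame:
    """Load a pickled model and run its blocking predict() off the event loop."""
    model = await asyncio.to_thread(joblib.load, model_path)     # disk I/O in a worker thread
    return await asyncio.to_thread(model.predict, future_df)     # CPU-bound predict in a worker thread
```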

View File

@@ -43,86 +43,79 @@ class POIFeatureIntegrator:
force_refresh: bool = False
) -> Optional[Dict[str, Any]]:
"""
Fetch POI features for tenant location.
Fetch POI features for tenant location (optimized for training).
First checks if POI context exists, if not, triggers detection.
First checks if POI context exists. If not, returns None without triggering detection.
POI detection should be triggered during tenant registration, not during training.
Args:
tenant_id: Tenant UUID
latitude: Bakery latitude
longitude: Bakery longitude
force_refresh: Force re-detection
force_refresh: Force re-detection (only use if POI context already exists)
Returns:
Dictionary with POI features or None if detection fails
Dictionary with POI features or None if not available
"""
try:
# Try to get existing POI context first
if not force_refresh:
existing_context = await self.external_client.get_poi_context(tenant_id)
if existing_context:
poi_context = existing_context.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
# Check if stale
# Check if stale and force_refresh is requested
is_stale = existing_context.get("is_stale", False)
if not is_stale:
if not is_stale or not force_refresh:
logger.info(
"Using existing POI context",
tenant_id=tenant_id
tenant_id=tenant_id,
is_stale=is_stale,
feature_count=len(ml_features)
)
return ml_features
else:
logger.info(
"POI context is stale, refreshing",
"POI context is stale and force_refresh=True, refreshing",
tenant_id=tenant_id
)
force_refresh = True
else:
logger.info(
"No existing POI context, will detect",
tenant_id=tenant_id
)
# Detect or refresh POIs
logger.info(
"Detecting POIs for tenant",
tenant_id=tenant_id,
location=(latitude, longitude)
)
# Only refresh if explicitly requested and context exists
detection_result = await self.external_client.detect_poi_for_tenant(
tenant_id=tenant_id,
latitude=latitude,
longitude=longitude,
force_refresh=force_refresh
force_refresh=True
)
if detection_result:
poi_context = detection_result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
"POI detection completed",
"POI refresh completed",
tenant_id=tenant_id,
total_pois=poi_context.get("total_pois_detected", 0),
feature_count=len(ml_features)
)
return ml_features
else:
logger.error(
"POI detection failed",
logger.warning(
"POI refresh failed, returning existing features",
tenant_id=tenant_id
)
return ml_features
else:
logger.info(
"No existing POI context found - POI detection should be triggered during tenant registration",
tenant_id=tenant_id
)
return None
except Exception as e:
logger.error(
"Unexpected error fetching POI features",
logger.warning(
"Error fetching POI features - returning None",
tenant_id=tenant_id,
error=str(e),
exc_info=True
error=str(e)
)
return None
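
Summarised, the new behaviour is: reuse cached POI features whenever they exist, refresh only when the caller explicitly asks and the cache is stale, and never trigger first-time detection from the training path. A condensed sketch of that decision around the external client's `get_poi_context` / `detect_poi_for_tenant` calls (the wrapper itself is hypothetical):

```python
from typing import Any, Dict, Optional

async def poi_features_for_training(client, tenant_id: str, lat: float, lon: float,
                                    force_refresh: bool = False) -> Optional[Dict[str, Any]]:
    context = await client.get_poi_context(tenant_id)
    if not context:
        # First-time detection belongs to tenant registration, not training.
        return None
    features = context.get("poi_context", {}).get("ml_features", {})
    if context.get("is_stale") and force_refresh:
        refreshed = await client.detect_poi_for_tenant(
            tenant_id=tenant_id, latitude=lat, longitude=lon, force_refresh=True
        )
        if refreshed:
            return refreshed.get("poi_context", {}).get("ml_features", {})
    return features
```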

View File

@@ -29,16 +29,15 @@ class DataClient:
self.sales_client = get_sales_client(settings, "training")
self.external_client = get_external_client(settings, "training")
# ExternalServiceClient always has get_stored_traffic_data_for_training method
self.supports_stored_traffic_data = True
# Configure timeouts for HTTP clients
self._configure_timeouts()
# Initialize circuit breakers for external services
self._init_circuit_breakers()
# Check if the new method is available for stored traffic data
if hasattr(self.external_client, 'get_stored_traffic_data_for_training'):
self.supports_stored_traffic_data = True
def _configure_timeouts(self):
"""Configure appropriate timeouts for HTTP clients"""
timeout = httpx.Timeout(
@@ -49,14 +48,12 @@ class DataClient:
)
# Apply timeout to clients if they have httpx clients
# Note: BaseServiceClient manages its own HTTP client internally
if hasattr(self.sales_client, 'client') and isinstance(self.sales_client.client, httpx.AsyncClient):
self.sales_client.client.timeout = timeout
if hasattr(self.external_client, 'client') and isinstance(self.external_client.client, httpx.AsyncClient):
self.external_client.client.timeout = timeout
else:
self.supports_stored_traffic_data = False
logger.warning("Stored traffic data method not available in external client")
def _init_circuit_breakers(self):
"""Initialize circuit breakers for external service calls"""

View File

@@ -404,22 +404,32 @@ class TrainingDataOrchestrator:
tenant_id: str
) -> Dict[str, Any]:
"""
Collect POI features for bakery location.
Collect POI features for bakery location (non-blocking).
POI features are static (location-based, not time-varying).
This method is non-blocking with a short timeout to prevent training delays.
If POI detection hasn't been run yet, training continues without POI features.
Returns:
Dictionary with POI features or empty dict if unavailable
"""
try:
logger.info(
"Collecting POI features",
"Collecting POI features (non-blocking)",
tenant_id=tenant_id,
location=(lat, lon)
)
poi_features = await self.poi_feature_integrator.fetch_poi_features(
# Set a short timeout to prevent blocking training
# POI detection should have been triggered during tenant registration
poi_features = await asyncio.wait_for(
self.poi_feature_integrator.fetch_poi_features(
tenant_id=tenant_id,
latitude=lat,
longitude=lon,
force_refresh=False
),
timeout=15.0 # 15 second timeout - POI should be cached from registration
)
if poi_features:
@@ -430,18 +440,24 @@ class TrainingDataOrchestrator:
)
else:
logger.warning(
"No POI features collected (service may be unavailable)",
"No POI features collected (service may be unavailable or not yet detected)",
tenant_id=tenant_id
)
return poi_features or {}
except asyncio.TimeoutError:
logger.warning(
"POI collection timeout (15s) - continuing training without POI features. "
"POI detection should be triggered during tenant registration for best results.",
tenant_id=tenant_id
)
return {}
except Exception as e:
logger.error(
"Failed to collect POI features, continuing without them",
logger.warning(
"Failed to collect POI features (non-blocking) - continuing training without them",
tenant_id=tenant_id,
error=str(e),
exc_info=True
error=str(e)
)
return {}
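
The orchestrator treats POI collection as best effort: the call is bounded with `asyncio.wait_for` and any failure falls back to an empty dict so training proceeds. The general shape of that guard (hypothetical helper):

```python
import asyncio
from typing import Any, Awaitable, Dict

async def best_effort(coro: Awaitable[Dict[str, Any]], timeout: float = 15.0) -> Dict[str, Any]:
    """Await an optional data source, returning {} instead of failing the pipeline."""
    try:
        return await asyncio.wait_for(coro, timeout=timeout) or {}
    except asyncio.TimeoutError:
        return {}  # POI not ready yet - continue training without it
    except Exception:
        return {}  # any other upstream failure is also non-fatal
```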

View File

@@ -71,7 +71,7 @@ class ServiceAuthenticator:
}
if tenant_id:
headers["X-Tenant-ID"] = str(tenant_id)
headers["x-tenant-id"] = str(tenant_id)
return headers

View File

@@ -351,7 +351,7 @@ class ForecastServiceClient(BaseServiceClient):
"""
Trigger demand forecasting insights for a tenant (internal service use only).
- This method calls the internal endpoint which is protected by X-Internal-Service header.
+ This method calls the internal endpoint which is protected by x-internal-service header.
Used by demo-session service after cloning to generate AI insights from seeded data.
Args:
@@ -366,7 +366,7 @@ class ForecastServiceClient(BaseServiceClient):
endpoint=f"forecasting/internal/ml/generate-demand-insights",
tenant_id=tenant_id,
data={"tenant_id": tenant_id},
headers={"X-Internal-Service": "demo-session"}
headers={"x-internal-service": "demo-session"}
)
if result:

View File

@@ -766,7 +766,7 @@ class InventoryServiceClient(BaseServiceClient):
"""
Trigger inventory alerts for a tenant (internal service use only).
- This method calls the internal endpoint which is protected by X-Internal-Service header.
+ This method calls the internal endpoint which is protected by x-internal-service header.
The endpoint should trigger alerts specifically for the given tenant.
Args:
@@ -783,7 +783,7 @@ class InventoryServiceClient(BaseServiceClient):
endpoint="inventory/internal/alerts/trigger",
tenant_id=tenant_id,
data={},
headers={"X-Internal-Service": "demo-session"}
headers={"x-internal-service": "demo-session"}
)
if result:
@@ -819,7 +819,7 @@ class InventoryServiceClient(BaseServiceClient):
"""
Trigger safety stock optimization insights for a tenant (internal service use only).
- This method calls the internal endpoint which is protected by X-Internal-Service header.
+ This method calls the internal endpoint which is protected by x-internal-service header.
Args:
tenant_id: Tenant ID to trigger insights for
@@ -833,7 +833,7 @@ class InventoryServiceClient(BaseServiceClient):
endpoint="inventory/internal/ml/generate-safety-stock-insights",
tenant_id=tenant_id,
data={"tenant_id": tenant_id},
headers={"X-Internal-Service": "demo-session"}
headers={"x-internal-service": "demo-session"}
)
if result:

View File

@@ -580,7 +580,7 @@ class ProcurementServiceClient(BaseServiceClient):
"""
Trigger delivery tracking for a tenant (internal service use only).
- This method calls the internal endpoint which is protected by X-Internal-Service header.
+ This method calls the internal endpoint which is protected by x-internal-service header.
Args:
tenant_id: Tenant ID to trigger delivery tracking for
@@ -596,7 +596,7 @@ class ProcurementServiceClient(BaseServiceClient):
endpoint="procurement/internal/delivery-tracking/trigger",
tenant_id=tenant_id,
data={},
headers={"X-Internal-Service": "demo-session"}
headers={"x-internal-service": "demo-session"}
)
if result:
@@ -632,7 +632,7 @@ class ProcurementServiceClient(BaseServiceClient):
"""
Trigger price forecasting insights for a tenant (internal service use only).
- This method calls the internal endpoint which is protected by X-Internal-Service header.
+ This method calls the internal endpoint which is protected by x-internal-service header.
Args:
tenant_id: Tenant ID to trigger insights for
@@ -646,7 +646,7 @@ class ProcurementServiceClient(BaseServiceClient):
endpoint="procurement/internal/ml/generate-price-insights",
tenant_id=tenant_id,
data={"tenant_id": tenant_id},
headers={"X-Internal-Service": "demo-session"}
headers={"x-internal-service": "demo-session"}
)
if result:

View File

@@ -630,7 +630,7 @@ class ProductionServiceClient(BaseServiceClient):
"""
Trigger production alerts for a tenant (internal service use only).
- This method calls the internal endpoint which is protected by X-Internal-Service header.
+ This method calls the internal endpoint which is protected by x-internal-service header.
Includes both production alerts and equipment maintenance checks.
Args:
@@ -647,7 +647,7 @@ class ProductionServiceClient(BaseServiceClient):
endpoint="production/internal/alerts/trigger",
tenant_id=tenant_id,
data={},
headers={"X-Internal-Service": "demo-session"}
headers={"x-internal-service": "demo-session"}
)
if result:
@@ -683,7 +683,7 @@ class ProductionServiceClient(BaseServiceClient):
"""
Trigger yield improvement insights for a tenant (internal service use only).
- This method calls the internal endpoint which is protected by X-Internal-Service header.
+ This method calls the internal endpoint which is protected by x-internal-service header.
Args:
tenant_id: Tenant ID to trigger insights for
@@ -697,7 +697,7 @@ class ProductionServiceClient(BaseServiceClient):
endpoint="production/internal/ml/generate-yield-insights",
tenant_id=tenant_id,
data={"tenant_id": tenant_id},
headers={"X-Internal-Service": "demo-session"}
headers={"x-internal-service": "demo-session"}
)
if result:
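
On the calling side, each of these trigger methods reduces to a POST against the internal route with the `x-internal-service` marker. A hedged standalone sketch (URL layout and payloads are illustrative; the real clients go through BaseServiceClient):

```python
from typing import Optional
import httpx

async def trigger_internal(base_url: str, endpoint: str, tenant_id: str,
                           payload: Optional[dict] = None) -> Optional[dict]:
    """POST to an internal endpoint, identifying the caller as the demo-session service."""
    headers = {"x-internal-service": "demo-session", "x-tenant-id": tenant_id}
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(f"{base_url}/{endpoint}", json=payload or {}, headers=headers)
        return resp.json() if resp.status_code == 200 else None
```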