New enterprise feature2

This commit is contained in:
Urtzi Alfaro
2025-11-30 16:29:38 +01:00
parent 972db02f6d
commit 0da0470786
18 changed files with 698 additions and 76 deletions

View File

@@ -532,6 +532,10 @@ k8s_resource('demo-session-service',
resource_deps=['demo-session-migration', 'redis'],
labels=['services'])
k8s_resource('demo-cleanup-worker',
resource_deps=['demo-session-service', 'redis'],
labels=['services', 'workers'])
k8s_resource('distribution-service',
resource_deps=['distribution-migration', 'redis', 'rabbitmq'],
labels=['services'])

View File

@@ -481,6 +481,19 @@ class AuthMiddleware(BaseHTTPMiddleware):
b"x-is-demo", b"true"
))
# Add demo session context headers for backend services
demo_session_id = user_context.get("demo_session_id", "")
if demo_session_id:
request.headers.__dict__["_list"].append((
b"x-demo-session-id", demo_session_id.encode()
))
demo_account_type = user_context.get("demo_account_type", "")
if demo_account_type:
request.headers.__dict__["_list"].append((
b"x-demo-account-type", demo_account_type.encode()
))
# Add hierarchical access headers if tenant context exists
if tenant_id:
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')

View File

@@ -23,34 +23,31 @@ spec:
app: demo-cleanup
spec:
containers:
- name: cleanup
image: bakery/demo-session-service:latest
- name: cleanup-trigger
image: curlimages/curl:latest
command:
- python
- sh
- -c
- |
import asyncio
import httpx
async def cleanup():
async with httpx.AsyncClient() as client:
response = await client.post("http://demo-session-service:8000/api/v1/demo/operations/cleanup")
print(response.json())
asyncio.run(cleanup())
env:
- name: DEMO_SESSION_DATABASE_URL
valueFrom:
secretKeyRef:
name: database-secrets
key: DEMO_SESSION_DATABASE_URL
- name: REDIS_URL
value: "redis://redis-service:6379/0"
- name: LOG_LEVEL
value: "INFO"
echo "Triggering demo session cleanup..."
response=$(curl -s -w "\n%{http_code}" -X POST http://demo-session-service:8000/api/v1/demo/operations/cleanup)
http_code=$(echo "$response" | tail -n 1)
body=$(echo "$response" | sed '$d')
echo "Response: $body"
echo "HTTP Status: $http_code"
if [ "$http_code" -ge 200 ] && [ "$http_code" -lt 300 ]; then
echo "Cleanup job enqueued successfully"
exit 0
else
echo "Failed to enqueue cleanup job"
exit 1
fi
resources:
requests:
memory: "128Mi"
cpu: "50m"
memory: "32Mi"
cpu: "10m"
limits:
memory: "256Mi"
cpu: "200m"
memory: "64Mi"
cpu: "50m"
restartPolicy: OnFailure
activeDeadlineSeconds: 30

View File

@@ -0,0 +1,96 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: demo-cleanup-worker
namespace: bakery-ia
labels:
app: demo-cleanup-worker
component: background-jobs
service: demo-session
spec:
replicas: 2
selector:
matchLabels:
app: demo-cleanup-worker
template:
metadata:
labels:
app: demo-cleanup-worker
component: background-jobs
service: demo-session
spec:
containers:
- name: worker
image: bakery/demo-session-service:latest
imagePullPolicy: IfNotPresent
command:
- python
- -m
- app.jobs.cleanup_worker
env:
- name: DEMO_SESSION_DATABASE_URL
valueFrom:
secretKeyRef:
name: database-secrets
key: DEMO_SESSION_DATABASE_URL
- name: REDIS_PASSWORD
valueFrom:
secretKeyRef:
name: redis-secrets
key: REDIS_PASSWORD
- name: REDIS_URL
value: "rediss://:$(REDIS_PASSWORD)@redis-service:6379/0?ssl_cert_reqs=none"
- name: LOG_LEVEL
value: "INFO"
- name: INTERNAL_API_KEY
valueFrom:
secretKeyRef:
name: demo-internal-api-key
key: INTERNAL_API_KEY
- name: INVENTORY_SERVICE_URL
value: "http://inventory-service:8000"
- name: RECIPES_SERVICE_URL
value: "http://recipes-service:8000"
- name: SALES_SERVICE_URL
value: "http://sales-service:8000"
- name: ORDERS_SERVICE_URL
value: "http://orders-service:8000"
- name: PRODUCTION_SERVICE_URL
value: "http://production-service:8000"
- name: SUPPLIERS_SERVICE_URL
value: "http://suppliers-service:8000"
- name: POS_SERVICE_URL
value: "http://pos-service:8000"
- name: PROCUREMENT_SERVICE_URL
value: "http://procurement-service:8000"
- name: DISTRIBUTION_SERVICE_URL
value: "http://distribution-service:8000"
- name: FORECASTING_SERVICE_URL
value: "http://forecasting-service:8000"
resources:
requests:
memory: "256Mi"
cpu: "100m"
limits:
memory: "512Mi"
cpu: "500m"
livenessProbe:
exec:
command:
- python
- -c
- "import sys; sys.exit(0)"
initialDelaySeconds: 30
periodSeconds: 60
timeoutSeconds: 5
failureThreshold: 3
readinessProbe:
exec:
command:
- python
- -c
- "import sys; sys.exit(0)"
initialDelaySeconds: 10
periodSeconds: 30
timeoutSeconds: 5
restartPolicy: Always

View File

@@ -15,6 +15,7 @@ resources:
- configmaps/postgres-logging-config.yaml
- secrets/postgres-tls-secret.yaml
- secrets/redis-tls-secret.yaml
- secrets/demo-internal-api-key-secret.yaml
# Additional configs
- configs/postgres-init-config.yaml
@@ -127,6 +128,9 @@ resources:
- components/demo-session/service.yaml
- components/demo-session/deployment.yaml
# Demo cleanup worker (background job processor)
- deployments/demo-cleanup-worker.yaml
# Microservices
- components/auth/auth-service.yaml
- components/tenant/tenant-service.yaml

View File

@@ -5,6 +5,7 @@ Demo Operations API - Business operations for demo session management
from fastapi import APIRouter, Depends, HTTPException, Path
import structlog
import jwt
from datetime import datetime, timezone
from app.api.schemas import DemoSessionResponse, DemoSessionStats
from app.services import DemoSessionManager, DemoCleanupService
@@ -83,10 +84,111 @@ async def run_cleanup(
db: AsyncSession = Depends(get_db),
redis: DemoRedisWrapper = Depends(get_redis)
):
"""Manually trigger session cleanup (BUSINESS OPERATION - Internal endpoint for CronJob)"""
cleanup_service = DemoCleanupService(db, redis)
stats = await cleanup_service.cleanup_expired_sessions()
return stats
"""
Trigger session cleanup via background worker (async via Redis queue)
Returns immediately after enqueuing work - does not block
"""
from datetime import timedelta
from sqlalchemy import select
from app.models.demo_session import DemoSession, DemoSessionStatus
import uuid
import json
logger.info("Starting demo session cleanup enqueue")
now = datetime.now(timezone.utc)
stuck_threshold = now - timedelta(minutes=5)
# Find expired sessions
result = await db.execute(
select(DemoSession).where(
DemoSession.status.in_([
DemoSessionStatus.PENDING,
DemoSessionStatus.READY,
DemoSessionStatus.PARTIAL,
DemoSessionStatus.FAILED,
DemoSessionStatus.ACTIVE
]),
DemoSession.expires_at < now
)
)
expired_sessions = result.scalars().all()
# Find stuck sessions
stuck_result = await db.execute(
select(DemoSession).where(
DemoSession.status == DemoSessionStatus.PENDING,
DemoSession.created_at < stuck_threshold
)
)
stuck_sessions = stuck_result.scalars().all()
all_sessions = list(expired_sessions) + list(stuck_sessions)
if not all_sessions:
return {
"status": "no_sessions",
"message": "No sessions to cleanup",
"total_expired": 0,
"total_stuck": 0
}
# Create cleanup job
job_id = str(uuid.uuid4())
session_ids = [s.session_id for s in all_sessions]
job_data = {
"job_id": job_id,
"session_ids": session_ids,
"created_at": now.isoformat(),
"retry_count": 0
}
# Enqueue job
client = await redis.get_client()
await client.lpush("cleanup:queue", json.dumps(job_data))
logger.info(
"Cleanup job enqueued",
job_id=job_id,
session_count=len(session_ids),
expired_count=len(expired_sessions),
stuck_count=len(stuck_sessions)
)
return {
"status": "enqueued",
"job_id": job_id,
"session_count": len(session_ids),
"total_expired": len(expired_sessions),
"total_stuck": len(stuck_sessions),
"message": f"Cleanup job enqueued for {len(session_ids)} sessions"
}
@router.get(
route_builder.build_operations_route("cleanup/{job_id}", include_tenant_prefix=False),
response_model=dict
)
async def get_cleanup_status(
job_id: str,
redis: DemoRedisWrapper = Depends(get_redis)
):
"""Get status of cleanup job"""
import json
client = await redis.get_client()
status_key = f"cleanup:job:{job_id}:status"
status_data = await client.get(status_key)
if not status_data:
return {
"status": "not_found",
"message": "Job not found or expired (jobs expire after 1 hour)"
}
return json.loads(status_data)
@router.post(

View File

@@ -0,0 +1,7 @@
"""
Background Jobs Package
"""
from .cleanup_worker import CleanupWorker, run_cleanup_worker
__all__ = ["CleanupWorker", "run_cleanup_worker"]

View File

@@ -0,0 +1,272 @@
"""
Background Cleanup Worker
Processes demo session cleanup jobs from Redis queue
"""
import asyncio
import structlog
from datetime import datetime, timezone, timedelta
from typing import Dict, Any
import json
import uuid
from contextlib import asynccontextmanager
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import DatabaseManager
from app.core.redis_wrapper import DemoRedisWrapper
from app.services.data_cloner import DemoDataCloner
from app.models.demo_session import DemoSession, DemoSessionStatus
logger = structlog.get_logger()
@asynccontextmanager
async def get_db_session():
"""Get database session context manager"""
db_manager = DatabaseManager()
db_manager.initialize()
async with db_manager.session_factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
class CleanupWorker:
"""Background worker for processing cleanup jobs"""
def __init__(self, redis: DemoRedisWrapper):
self.redis = redis
self.queue_key = "cleanup:queue"
self.processing_key = "cleanup:processing"
self.running = False
async def start(self):
"""Start the worker (runs indefinitely)"""
self.running = True
logger.info("Cleanup worker started")
while self.running:
try:
await self._process_next_job()
except Exception as e:
logger.error("Worker error", error=str(e), exc_info=True)
await asyncio.sleep(5) # Back off on error
async def stop(self):
"""Stop the worker gracefully"""
self.running = False
logger.info("Cleanup worker stopped")
async def _process_next_job(self):
"""Process next job from queue"""
client = await self.redis.get_client()
# Blocking pop from queue (5 second timeout)
result = await client.brpoplpush(
self.queue_key,
self.processing_key,
timeout=5
)
if not result:
return # No job available
job_data = json.loads(result)
job_id = job_data["job_id"]
session_ids = job_data["session_ids"]
logger.info(
"Processing cleanup job",
job_id=job_id,
session_count=len(session_ids)
)
try:
# Process cleanup
stats = await self._cleanup_sessions(session_ids)
# Mark job as complete
await self._mark_job_complete(job_id, stats)
# Remove from processing queue
await client.lrem(self.processing_key, 1, result)
logger.info("Job completed", job_id=job_id, stats=stats)
except Exception as e:
logger.error("Job failed", job_id=job_id, error=str(e), exc_info=True)
# Check retry count
retry_count = job_data.get("retry_count", 0)
if retry_count < 3:
# Retry - put back in queue
job_data["retry_count"] = retry_count + 1
await client.lpush(self.queue_key, json.dumps(job_data))
logger.info("Job requeued for retry", job_id=job_id, retry_count=retry_count + 1)
else:
# Max retries reached - mark as failed
await self._mark_job_failed(job_id, str(e))
logger.error("Job failed after max retries", job_id=job_id)
# Remove from processing queue
await client.lrem(self.processing_key, 1, result)
async def _cleanup_sessions(self, session_ids: list) -> Dict[str, Any]:
"""Execute cleanup for list of sessions with parallelization"""
async with get_db_session() as db:
redis = DemoRedisWrapper()
data_cloner = DemoDataCloner(db, redis)
try:
# Get sessions to cleanup
result = await db.execute(
select(DemoSession).where(
DemoSession.session_id.in_(session_ids)
)
)
sessions = result.scalars().all()
stats = {
"cleaned_up": 0,
"failed": 0,
"errors": []
}
# Process each session
for session in sessions:
try:
# Mark session as expired
session.status = DemoSessionStatus.EXPIRED
await db.commit()
# Check if this is an enterprise demo with children
child_tenant_ids = []
if session.demo_account_type == "enterprise" and session.session_metadata:
child_tenant_ids = session.session_metadata.get("child_tenant_ids", [])
# Delete child tenants in parallel (for enterprise demos)
if child_tenant_ids:
logger.info(
"Cleaning up enterprise demo children",
session_id=session.session_id,
child_count=len(child_tenant_ids)
)
child_tasks = [
data_cloner.delete_session_data(
str(child_id),
session.session_id
)
for child_id in child_tenant_ids
]
child_results = await asyncio.gather(*child_tasks, return_exceptions=True)
# Log any child deletion failures
for child_id, result in zip(child_tenant_ids, child_results):
if isinstance(result, Exception):
logger.error(
"Failed to delete child tenant",
child_id=child_id,
error=str(result)
)
# Delete parent/main session data
await data_cloner.delete_session_data(
str(session.virtual_tenant_id),
session.session_id
)
stats["cleaned_up"] += 1
logger.info(
"Session cleaned up",
session_id=session.session_id,
is_enterprise=(session.demo_account_type == "enterprise"),
children_deleted=len(child_tenant_ids)
)
except Exception as e:
stats["failed"] += 1
stats["errors"].append({
"session_id": session.session_id,
"error": str(e)
})
logger.error(
"Failed to cleanup session",
session_id=session.session_id,
error=str(e),
exc_info=True
)
return stats
finally:
# Always close HTTP client
await data_cloner.close()
async def _mark_job_complete(self, job_id: str, stats: Dict[str, Any]):
"""Mark job as complete in Redis"""
client = await self.redis.get_client()
status_key = f"cleanup:job:{job_id}:status"
await client.setex(
status_key,
3600, # Keep status for 1 hour
json.dumps({
"status": "completed",
"stats": stats,
"completed_at": datetime.now(timezone.utc).isoformat()
})
)
async def _mark_job_failed(self, job_id: str, error: str):
"""Mark job as failed in Redis"""
client = await self.redis.get_client()
status_key = f"cleanup:job:{job_id}:status"
await client.setex(
status_key,
3600,
json.dumps({
"status": "failed",
"error": error,
"failed_at": datetime.now(timezone.utc).isoformat()
})
)
async def run_cleanup_worker():
"""Entry point for worker process"""
# Initialize Redis client
import os
from shared.redis_utils import initialize_redis
redis_url = os.getenv("REDIS_URL", "redis://redis-service:6379/0")
try:
# Initialize Redis with connection pool settings
await initialize_redis(redis_url, db=0, max_connections=10)
logger.info("Redis initialized successfully", redis_url=redis_url.split('@')[-1])
except Exception as e:
logger.error("Failed to initialize Redis", error=str(e))
raise
redis = DemoRedisWrapper()
worker = CleanupWorker(redis)
try:
await worker.start()
except KeyboardInterrupt:
logger.info("Received interrupt signal")
await worker.stop()
except Exception as e:
logger.error("Worker crashed", error=str(e), exc_info=True)
raise
if __name__ == "__main__":
asyncio.run(run_cleanup_worker())

View File

@@ -4,11 +4,12 @@ Clones base demo data to session-specific virtual tenants
"""
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
import httpx
import structlog
import uuid
import os
import asyncio
from app.core.redis_wrapper import DemoRedisWrapper
from app.core import settings
@@ -22,6 +23,26 @@ class DemoDataCloner:
def __init__(self, db: AsyncSession, redis: DemoRedisWrapper):
self.db = db
self.redis = redis
self._http_client: Optional[httpx.AsyncClient] = None
async def get_http_client(self) -> httpx.AsyncClient:
"""Get or create shared HTTP client with connection pooling"""
if self._http_client is None:
self._http_client = httpx.AsyncClient(
timeout=httpx.Timeout(30.0, connect_timeout=10.0),
limits=httpx.Limits(
max_connections=20,
max_keepalive_connections=10,
keepalive_expiry=30.0
)
)
return self._http_client
async def close(self):
"""Close HTTP client on cleanup"""
if self._http_client:
await self._http_client.aclose()
self._http_client = None
async def clone_tenant_data(
self,
@@ -214,7 +235,7 @@ class DemoDataCloner:
async def _fetch_inventory_data(self, tenant_id: str) -> Dict[str, Any]:
"""Fetch inventory data for caching"""
async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(timeout=httpx.Timeout(15.0, connect_timeout=5.0)) as client:
response = await client.get(
f"{settings.INVENTORY_SERVICE_URL}/api/inventory/summary",
headers={"X-Tenant-Id": tenant_id}
@@ -223,7 +244,7 @@ class DemoDataCloner:
async def _fetch_pos_data(self, tenant_id: str) -> Dict[str, Any]:
"""Fetch POS data for caching"""
async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(timeout=httpx.Timeout(15.0, connect_timeout=5.0)) as client:
response = await client.get(
f"{settings.POS_SERVICE_URL}/api/pos/current-session",
headers={"X-Tenant-Id": tenant_id}
@@ -261,7 +282,7 @@ class DemoDataCloner:
session_id: str
):
"""
Delete all data for a session
Delete all data for a session using parallel deletion for performance
Args:
virtual_tenant_id: Virtual tenant ID to delete
@@ -273,29 +294,40 @@ class DemoDataCloner:
session_id=session_id
)
# Delete from each service
# Note: Services are deleted in reverse dependency order to avoid foreign key issues
# Get shared HTTP client for all deletions
client = await self.get_http_client()
# Services list - all can be deleted in parallel as deletion endpoints
# handle their own internal ordering if needed
services = [
"forecasting", # No dependencies
"sales", # Depends on inventory, recipes
"orders", # Depends on customers (within same service)
"production", # Depends on recipes, equipment
"inventory", # Core data (ingredients, products)
"recipes", # Core data
"suppliers", # Core data
"pos", # Point of sale data
"distribution", # Distribution routes
"procurement" # Procurement and purchase orders
"forecasting",
"sales",
"orders",
"production",
"inventory",
"recipes",
"suppliers",
"pos",
"distribution",
"procurement"
]
for service_name in services:
try:
await self._delete_service_data(service_name, virtual_tenant_id)
except Exception as e:
# Create deletion tasks for all services
deletion_tasks = [
self._delete_service_data(service_name, virtual_tenant_id, client)
for service_name in services
]
# Execute all deletions in parallel with exception handling
results = await asyncio.gather(*deletion_tasks, return_exceptions=True)
# Log any failures
for service_name, result in zip(services, results):
if isinstance(result, Exception):
logger.error(
"Failed to delete service data",
service=service_name,
error=str(e)
error=str(result)
)
# Delete from Redis
@@ -303,16 +335,20 @@ class DemoDataCloner:
logger.info("Session data deleted", virtual_tenant_id=virtual_tenant_id)
async def _delete_service_data(self, service_name: str, virtual_tenant_id: str):
"""Delete data from a specific service"""
async def _delete_service_data(
self,
service_name: str,
virtual_tenant_id: str,
client: httpx.AsyncClient
):
"""Delete data from a specific service using provided HTTP client"""
service_url = self._get_service_url(service_name)
# Get internal API key from settings
from app.core.config import settings
internal_api_key = settings.INTERNAL_API_KEY
async with httpx.AsyncClient(timeout=30.0) as client:
await client.delete(
f"{service_url}/internal/demo/tenant/{virtual_tenant_id}",
headers={"X-Internal-API-Key": internal_api_key}
)
await client.delete(
f"{service_url}/internal/demo/tenant/{virtual_tenant_id}",
headers={"X-Internal-API-Key": internal_api_key}
)

View File

@@ -28,10 +28,9 @@ from shared.clients import (
get_production_client,
get_sales_client,
get_inventory_client,
get_procurement_client
get_procurement_client,
get_distribution_client
)
# TODO: Add distribution client when available
# from shared.clients import get_distribution_client
def get_enterprise_dashboard_service() -> EnterpriseDashboardService:
from app.core.config import settings
@@ -40,7 +39,7 @@ def get_enterprise_dashboard_service() -> EnterpriseDashboardService:
production_client = get_production_client(settings)
sales_client = get_sales_client(settings)
inventory_client = get_inventory_client(settings)
distribution_client = None # TODO: Add when distribution service is ready
distribution_client = get_distribution_client(settings)
procurement_client = get_procurement_client(settings)
return EnterpriseDashboardService(

View File

@@ -110,7 +110,8 @@ from shared.clients import (
get_production_client,
get_sales_client,
get_inventory_client,
get_procurement_client
get_procurement_client,
get_distribution_client
)
def get_enterprise_dashboard_service() -> EnterpriseDashboardService:
@@ -119,7 +120,7 @@ def get_enterprise_dashboard_service() -> EnterpriseDashboardService:
production_client = get_production_client(settings)
sales_client = get_sales_client(settings)
inventory_client = get_inventory_client(settings)
distribution_client = None # TODO: Add when distribution service is ready
distribution_client = get_distribution_client(settings)
procurement_client = get_procurement_client(settings)
return EnterpriseDashboardService(

View File

@@ -138,13 +138,8 @@ class EnterpriseDashboardService:
async def _get_production_volume(self, parent_tenant_id: str) -> float:
"""Get total production volume for the parent tenant (central production)"""
try:
start_date = date.today() - timedelta(days=30)
end_date = date.today()
production_summary = await self.production_client.get_production_summary(
tenant_id=parent_tenant_id,
start_date=start_date,
end_date=end_date
production_summary = await self.production_client.get_dashboard_summary(
tenant_id=parent_tenant_id
)
# Return total production value
@@ -382,6 +377,16 @@ class EnterpriseDashboardService:
total_demand = 0
daily_summary = {}
if not forecast_data:
logger.warning("No forecast data returned", parent_tenant_id=parent_tenant_id)
return {
'parent_tenant_id': parent_tenant_id,
'days_forecast': days_ahead,
'total_predicted_demand': 0,
'daily_summary': {},
'last_updated': datetime.utcnow().isoformat()
}
for forecast_date_str, products in forecast_data.get('aggregated_forecasts', {}).items():
day_total = sum(item.get('predicted_demand', 0) for item in products.values())
total_demand += day_total
@@ -500,10 +505,8 @@ class EnterpriseDashboardService:
async def _get_tenant_production(self, tenant_id: str, start_date: date, end_date: date) -> float:
"""Helper to get production for a specific tenant"""
try:
production_data = await self.production_client.get_production_summary(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date
production_data = await self.production_client.get_dashboard_summary(
tenant_id=tenant_id
)
return float(production_data.get('total_value', 0))
except Exception as e:

View File

@@ -171,7 +171,8 @@ async def clone_demo_data(
business_model=demo_account_type,
is_active=True,
timezone="Europe/Madrid",
owner_id=demo_owner_uuid # Required field - matches seed_demo_users.py
owner_id=demo_owner_uuid, # Required field - matches seed_demo_users.py
tenant_type="parent" if demo_account_type in ["enterprise", "enterprise_parent"] else "standalone"
)
db.add(tenant)

View File

@@ -354,7 +354,9 @@ def extract_user_from_headers(request: Request) -> Optional[Dict[str, Any]]:
"permissions": request.headers.get("X-User-Permissions", "").split(",") if request.headers.get("X-User-Permissions") else [],
"full_name": request.headers.get("x-user-full-name", ""),
"subscription_tier": request.headers.get("x-subscription-tier", ""),
"is_demo": request.headers.get("x-is-demo", "").lower() == "true"
"is_demo": request.headers.get("x-is-demo", "").lower() == "true",
"demo_session_id": request.headers.get("x-demo-session-id", ""),
"demo_account_type": request.headers.get("x-demo-account-type", "")
}
# ✅ ADD THIS: Handle service tokens properly

View File

@@ -19,6 +19,7 @@ from .tenant_client import TenantServiceClient
from .ai_insights_client import AIInsightsClient
from .alerts_client import AlertsServiceClient
from .procurement_client import ProcurementServiceClient
from .distribution_client import DistributionServiceClient
# Import config
from shared.config.base import BaseServiceSettings
@@ -146,6 +147,16 @@ def get_procurement_client(config: BaseServiceSettings = None, service_name: str
_client_cache[cache_key] = ProcurementServiceClient(config, service_name)
return _client_cache[cache_key]
def get_distribution_client(config: BaseServiceSettings = None, service_name: str = "unknown") -> DistributionServiceClient:
"""Get or create a distribution service client"""
if config is None:
from app.core.config import settings as config
cache_key = f"distribution_{service_name}"
if cache_key not in _client_cache:
_client_cache[cache_key] = DistributionServiceClient(config, service_name)
return _client_cache[cache_key]
class ServiceClients:
"""Convenient wrapper for all service clients"""
@@ -257,6 +268,7 @@ __all__ = [
'SuppliersServiceClient',
'AlertsServiceClient',
'TenantServiceClient',
'DistributionServiceClient',
'ServiceClients',
'get_training_client',
'get_sales_client',
@@ -270,6 +282,7 @@ __all__ = [
'get_alerts_client',
'get_tenant_client',
'get_procurement_client',
'get_distribution_client',
'get_service_clients',
'create_forecast_client'
]

View File

@@ -386,6 +386,46 @@ class ForecastServiceClient(BaseServiceClient):
forecast_days=forecast_days
)
async def get_aggregated_forecast(
self,
parent_tenant_id: str,
start_date: date,
end_date: date,
product_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Get aggregated forecast for enterprise tenant and all children.
This method calls the enterprise forecasting aggregation endpoint which
combines demand forecasts across the parent tenant and all child tenants
in the network. Used for centralized production planning.
Args:
parent_tenant_id: The parent tenant (central bakery) UUID
start_date: Start date for forecast range
end_date: End date for forecast range
product_id: Optional product ID to filter forecasts
Returns:
Aggregated forecast data including:
- total_demand: Sum of all child demands
- child_contributions: Per-child demand breakdown
- forecast_date_range: Date range for the forecast
- cached: Whether data was served from Redis cache
"""
params = {
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat()
}
if product_id:
params["product_id"] = product_id
return await self.get(
"forecasting/enterprise/aggregated",
tenant_id=parent_tenant_id,
params=params
)
async def create_forecast(
self,
tenant_id: str,

View File

@@ -6,6 +6,7 @@ Handles all API calls to the sales service
import httpx
import structlog
from datetime import date
from typing import Dict, Any, Optional, List, Union
from .base_service_client import BaseServiceClient
from shared.config.base import BaseServiceSettings
@@ -183,6 +184,37 @@ class SalesServiceClient(BaseServiceClient):
tenant_id=tenant_id)
return None
async def get_sales_summary(
self,
tenant_id: str,
start_date: date,
end_date: date
) -> Dict[str, Any]:
"""
Get sales summary/analytics for a tenant.
This method calls the sales analytics summary endpoint which provides
aggregated sales metrics over a date range.
Args:
tenant_id: The tenant UUID
start_date: Start date for summary range
end_date: End date for summary range
Returns:
Sales summary data including metrics like total sales, revenue, etc.
"""
params = {
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat()
}
return await self.get(
"sales/analytics/summary",
tenant_id=tenant_id,
params=params
)
# ================================================================
# DATA IMPORT
# ================================================================

View File

@@ -310,7 +310,7 @@ class TenantServiceClient(BaseServiceClient):
List of child tenant dictionaries
"""
try:
result = await self.get(f"tenants/{parent_tenant_id}/children", tenant_id=parent_tenant_id)
result = await self.get("children", tenant_id=parent_tenant_id)
if result:
logger.info("Retrieved child tenants",
parent_tenant_id=parent_tenant_id,