Initial commit - production deployment

Commit c23d00dd92 · 2026-01-21 17:17:16 +01:00
2289 changed files with 638440 additions and 0 deletions

View File

@@ -0,0 +1,41 @@
# AI Insights Service Environment Variables
# Service Info
SERVICE_NAME=ai-insights
SERVICE_VERSION=1.0.0
API_V1_PREFIX=/api/v1/ai-insights
# Database
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/bakery_ai_insights
DB_POOL_SIZE=20
DB_MAX_OVERFLOW=10
# Redis
REDIS_URL=redis://localhost:6379/5
REDIS_CACHE_TTL=900
# Service URLs
FORECASTING_SERVICE_URL=http://forecasting-service:8000
PROCUREMENT_SERVICE_URL=http://procurement-service:8000
PRODUCTION_SERVICE_URL=http://production-service:8000
SALES_SERVICE_URL=http://sales-service:8000
INVENTORY_SERVICE_URL=http://inventory-service:8000
# Circuit Breaker Settings
CIRCUIT_BREAKER_FAILURE_THRESHOLD=5
CIRCUIT_BREAKER_TIMEOUT=60
# Insight Settings
MIN_CONFIDENCE_THRESHOLD=60
DEFAULT_INSIGHT_TTL_DAYS=7
MAX_INSIGHTS_PER_REQUEST=100
# Feedback Settings
FEEDBACK_PROCESSING_ENABLED=true
FEEDBACK_PROCESSING_SCHEDULE="0 6 * * *"
# Logging
LOG_LEVEL=INFO
# CORS
ALLOWED_ORIGINS=["http://localhost:3000","http://localhost:5173"]

View File

@@ -0,0 +1,59 @@
# =============================================================================
# AI Insights Service Dockerfile - Environment-Configurable Base Images
# =============================================================================
# Build arguments for registry configuration:
# - BASE_REGISTRY: Registry URL (default: docker.io for Docker Hub)
# - PYTHON_IMAGE: Python image name and tag (default: python:3.11-slim)
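# Example build invocation (illustrative; assumes the repository root is the
#   build context, since shared/ and services/ai_insights/ are copied below):
#     docker build -f services/ai_insights/Dockerfile \
#       --build-arg BASE_REGISTRY=registry.example.com \
#       --build-arg PYTHON_IMAGE=python:3.11-slim \
#       -t ai-insights .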
# =============================================================================
ARG BASE_REGISTRY=docker.io
ARG PYTHON_IMAGE=python:3.11-slim
FROM ${BASE_REGISTRY}/${PYTHON_IMAGE} AS shared
WORKDIR /shared
COPY shared/ /shared/
ARG BASE_REGISTRY=docker.io
ARG PYTHON_IMAGE=python:3.11-slim
FROM ${BASE_REGISTRY}/${PYTHON_IMAGE}
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
curl \
postgresql-client \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements
COPY shared/requirements-tracing.txt /tmp/
COPY services/ai_insights/requirements.txt .
# Install Python dependencies
RUN pip install --no-cache-dir -r /tmp/requirements-tracing.txt
RUN pip install --no-cache-dir -r requirements.txt
# Copy shared libraries from the shared stage
COPY --from=shared /shared /app/shared
# Copy application code
COPY services/ai_insights/ .
# Copy scripts for migrations
COPY scripts/ /app/scripts/
# Add shared libraries to Python path
ENV PYTHONPATH="/app:/app/shared:${PYTHONPATH:-}"
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Run the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -0,0 +1,232 @@
# AI Insights Service - Quick Start Guide
Get the AI Insights Service running in 5 minutes.
## Prerequisites
- Python 3.11+
- PostgreSQL 14+ (running)
- Redis 6+ (running)
## Step 1: Setup Environment
```bash
cd services/ai_insights
# Create virtual environment
python3 -m venv venv
source venv/bin/activate # Windows: venv\Scripts\activate
# Install dependencies
pip install -r requirements.txt
```
## Step 2: Configure Database
```bash
# Copy environment template
cp .env.example .env
# Edit .env file
nano .env
```
**Minimum required configuration**:
```env
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/bakery_ai_insights
REDIS_URL=redis://localhost:6379/5
```
## Step 3: Create Database
```bash
# Connect to PostgreSQL
psql -U postgres
# Create database
CREATE DATABASE bakery_ai_insights;
\q
```
## Step 4: Run Migrations
```bash
# Run Alembic migrations
alembic upgrade head
```
You should see:
```
INFO [alembic.runtime.migration] Running upgrade -> 001, Initial schema for AI Insights Service
```
## Step 5: Start the Service
```bash
uvicorn app.main:app --reload
```
You should see:
```
INFO: Uvicorn running on http://127.0.0.1:8000
INFO: Application startup complete.
```
## Step 6: Verify Installation
Open your browser to http://localhost:8000/docs
You should see the Swagger UI with all API endpoints.
### Test Health Endpoint
```bash
curl http://localhost:8000/health
```
Expected response:
```json
{
"status": "healthy",
"service": "ai-insights",
"version": "1.0.0"
}
```
## Step 7: Create Your First Insight
```bash
curl -X POST "http://localhost:8000/api/v1/ai-insights/tenants/550e8400-e29b-41d4-a716-446655440000/insights" \
-H "Content-Type: application/json" \
-d '{
"tenant_id": "550e8400-e29b-41d4-a716-446655440000",
"type": "recommendation",
"priority": "high",
"category": "forecasting",
"title": "Test Insight - Weekend Demand Pattern",
"description": "Weekend sales 20% higher than weekdays",
"impact_type": "revenue_increase",
"impact_value": 150.00,
"impact_unit": "euros/week",
"confidence": 85,
"metrics_json": {
"weekday_avg": 45.2,
"weekend_avg": 54.2,
"increase_pct": 20.0
},
"actionable": true,
"recommendation_actions": [
{"label": "Increase Production", "action": "adjust_production"}
],
"source_service": "forecasting"
}'
```
## Step 8: Query Your Insights
```bash
curl "http://localhost:8000/api/v1/ai-insights/tenants/550e8400-e29b-41d4-a716-446655440000/insights?page=1&page_size=10"
```
## Common Issues
### Issue: "ModuleNotFoundError: No module named 'app'"
**Solution**: Make sure you're running from the `services/ai_insights/` directory and that the virtual environment is activated. For example:
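```bash
cd services/ai_insights
source venv/bin/activate
uvicorn app.main:app --reload
```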
### Issue: "Connection refused" on database
**Solution**: Verify PostgreSQL is running:
```bash
# Check if PostgreSQL is running
pg_isready
# Start PostgreSQL (macOS with Homebrew)
brew services start postgresql
# Start PostgreSQL (Linux)
sudo systemctl start postgresql
```
### Issue: "Redis connection error"
**Solution**: Verify Redis is running:
```bash
# Check if Redis is running
redis-cli ping
# Should return: PONG
# Start Redis (macOS with Homebrew)
brew services start redis
# Start Redis (Linux)
sudo systemctl start redis
```
### Issue: "Alembic command not found"
**Solution**: The virtual environment is not activated. Activate it:
```bash
source venv/bin/activate
```
## Next Steps
1. **Explore API**: Visit http://localhost:8000/docs
2. **Read Documentation**: See `README.md` for detailed documentation
3. **Implementation Guide**: See `AI_INSIGHTS_IMPLEMENTATION_SUMMARY.md`
4. **Integration**: Start integrating with other services
## Useful Commands
```bash
# Check service status
curl http://localhost:8000/health
# Get aggregate metrics
curl "http://localhost:8000/api/v1/ai-insights/tenants/{tenant_id}/insights/metrics/summary"
# Filter high-confidence insights
curl "http://localhost:8000/api/v1/ai-insights/tenants/{tenant_id}/insights?actionable_only=true&min_confidence=80"
# Stop the service
# Press Ctrl+C in the terminal running uvicorn
# Deactivate virtual environment
deactivate
```
## Docker Quick Start (Alternative)
If you prefer Docker:
```bash
# Build image (the Dockerfile copies shared/ and services/ai_insights/, so run this from the repository root)
docker build -f services/ai_insights/Dockerfile -t ai-insights .
# Run container
docker run -d \
--name ai-insights \
-p 8000:8000 \
-e DATABASE_URL=postgresql+asyncpg://postgres:postgres@host.docker.internal:5432/bakery_ai_insights \
-e REDIS_URL=redis://host.docker.internal:6379/5 \
ai-insights
# Check logs
docker logs ai-insights
# Stop container
docker stop ai-insights
docker rm ai-insights
```
## Support
- **Documentation**: See `README.md`
- **API Docs**: http://localhost:8000/docs
- **Issues**: Create a GitHub issue or contact the team
---
**You're ready!** The AI Insights Service is now running and ready to accept insights from other services.

View File

@@ -0,0 +1,325 @@
# AI Insights Service
## Overview
The **AI Insights Service** provides intelligent, actionable recommendations to bakery operators by analyzing patterns across inventory, production, procurement, and sales data. It acts as a virtual operations consultant, proactively identifying opportunities for cost savings, waste reduction, and operational improvements. This service transforms raw data into business intelligence that drives profitability.
## Key Features
### Intelligent Recommendations
- **Inventory Optimization** - Smart reorder point suggestions and stock level adjustments
- **Production Planning** - Optimal batch size and scheduling recommendations
- **Procurement Suggestions** - Best supplier selection and order timing advice
- **Sales Opportunities** - Identify trending products and underperforming items
- **Cost Reduction** - Find areas to reduce waste and lower operational costs
- **Quality Improvements** - Detect patterns affecting product quality
### Unified Insight Management
- **Centralized Storage** - All AI-generated insights in one place
- **Confidence Scoring** - Standardized 0-100% confidence calculation across insight types
- **Impact Estimation** - Business value quantification for recommendations
- **Feedback Loop** - Closed-loop learning from applied insights
- **Cross-Service Intelligence** - Correlation detection between insights from different services
## Features
### Core Capabilities
1. **Insight Aggregation**
- Collect insights from Forecasting, Procurement, Production, and Sales services
- Categorize and prioritize recommendations
- Filter by confidence, category, priority, and actionability
2. **Confidence Calculation**
- Multi-factor scoring: data quality, model performance, sample size, recency, historical accuracy
- Insight-type specific adjustments
- Specialized calculations for forecasting and optimization insights
3. **Impact Estimation**
- Cost savings quantification
- Revenue increase projections
- Waste reduction calculations
- Efficiency gain measurements
- Quality improvement tracking
4. **Feedback & Learning**
- Track application outcomes
- Compare expected vs. actual impact
- Calculate success rates
- Enable model improvement
5. **Orchestration Integration**
- Pre-orchestration insight gathering
- Actionable insight filtering
- Categorized recommendations for workflow phases
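The exact confidence formula is internal to the service. As a rough sketch of the multi-factor scoring described in item 2 above (the weights and factor names below are illustrative assumptions, not the service's actual algorithm):
```python
# Illustrative only -- weights and factor names are assumptions.
FACTOR_WEIGHTS = {
    "data_quality": 0.25,
    "model_performance": 0.30,
    "sample_size": 0.15,
    "recency": 0.15,
    "historical_accuracy": 0.15,
}

def combine_confidence(factors: dict[str, float]) -> int:
    """Combine per-factor scores (each 0-100) into a single 0-100 confidence."""
    score = sum(weight * factors.get(name, 0.0) for name, weight in FACTOR_WEIGHTS.items())
    return max(0, min(100, round(score)))

# Example: strong model performance but a small sample size
print(combine_confidence({
    "data_quality": 90,
    "model_performance": 85,
    "sample_size": 40,
    "recency": 95,
    "historical_accuracy": 80,
}))  # 80
```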
## Architecture
### Database Models
- **AIInsight**: Core insights table with classification, confidence, impact metrics
- **InsightFeedback**: Feedback tracking for closed-loop learning
- **InsightCorrelation**: Cross-service insight relationships
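For orientation, a minimal sketch of what the core insights table might look like. The column names mirror the API payload fields used later in this README and are assumptions; the actual model definition in the service may differ:
```python
import uuid
from datetime import datetime

from sqlalchemy import Boolean, Column, DateTime, Integer, Numeric, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class AIInsight(Base):
    __tablename__ = "ai_insights"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    type = Column(String(50), nullable=False)        # recommendation, alert, prediction, ...
    priority = Column(String(20), nullable=False)    # critical, high, medium, low
    category = Column(String(50), nullable=False)    # forecasting, inventory, ...
    title = Column(String(255), nullable=False)
    description = Column(Text)
    confidence = Column(Integer, nullable=False)     # 0-100
    impact_type = Column(String(50))
    impact_value = Column(Numeric(12, 2))
    actionable = Column(Boolean, default=False)
    metrics_json = Column(JSONB)
    status = Column(String(20), default="new")
    created_at = Column(DateTime, default=datetime.utcnow)
```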
### API Endpoints
```
POST /api/v1/ai-insights/tenants/{tenant_id}/insights
GET /api/v1/ai-insights/tenants/{tenant_id}/insights
GET /api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}
PATCH /api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}
DELETE /api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}
GET /api/v1/ai-insights/tenants/{tenant_id}/insights/orchestration-ready
GET /api/v1/ai-insights/tenants/{tenant_id}/insights/metrics/summary
POST /api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}/apply
POST /api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}/feedback
POST /api/v1/ai-insights/tenants/{tenant_id}/insights/refresh
GET /api/v1/ai-insights/tenants/{tenant_id}/insights/export
```
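For example, the orchestration-ready endpoint expects a `target_date` query parameter and filters by a minimum confidence (default 70):
```bash
curl "http://localhost:8000/api/v1/ai-insights/tenants/{tenant_id}/insights/orchestration-ready?target_date=2025-06-01T00:00:00&min_confidence=70"
```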
## Installation
### Prerequisites
- Python 3.11+
- PostgreSQL 14+
- Redis 6+
### Setup
1. **Clone and navigate**:
```bash
cd services/ai_insights
```
2. **Create virtual environment**:
```bash
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
```
3. **Install dependencies**:
```bash
pip install -r requirements.txt
```
4. **Configure environment**:
```bash
cp .env.example .env
# Edit .env with your configuration
```
5. **Run migrations**:
```bash
alembic upgrade head
```
6. **Start the service**:
```bash
uvicorn app.main:app --reload
```
The service will be available at `http://localhost:8000`.
## Configuration
### Environment Variables
| Variable | Description | Default |
|----------|-------------|---------|
| `DATABASE_URL` | PostgreSQL connection string | Required |
| `REDIS_URL` | Redis connection string | Required |
| `FORECASTING_SERVICE_URL` | Forecasting service URL | `http://forecasting-service:8000` |
| `PROCUREMENT_SERVICE_URL` | Procurement service URL | `http://procurement-service:8000` |
| `PRODUCTION_SERVICE_URL` | Production service URL | `http://production-service:8000` |
| `MIN_CONFIDENCE_THRESHOLD` | Minimum confidence for insights | `60` |
| `DEFAULT_INSIGHT_TTL_DAYS` | Days before insights expire | `7` |
## Usage Examples
### Creating an Insight
```python
import httpx
insight_data = {
"tenant_id": "550e8400-e29b-41d4-a716-446655440000",
"type": "recommendation",
"priority": "high",
"category": "procurement",
"title": "Flour Price Increase Expected",
"description": "Price predicted to rise 8% in next week. Consider ordering now.",
"impact_type": "cost_savings",
"impact_value": 120.50,
"impact_unit": "euros",
"confidence": 85,
"metrics_json": {
"current_price": 1.20,
"predicted_price": 1.30,
"order_quantity": 1000
},
"actionable": True,
"recommendation_actions": [
{"label": "Order Now", "action": "create_purchase_order"},
{"label": "Review", "action": "review_forecast"}
],
"source_service": "procurement",
"source_data_id": "price_forecast_123"
}
response = httpx.post(
"http://localhost:8000/api/v1/ai-insights/tenants/550e8400-e29b-41d4-a716-446655440000/insights",
json=insight_data
)
print(response.json())
```
### Querying Insights
```python
# Get high-confidence actionable insights
response = httpx.get(
"http://localhost:8000/api/v1/ai-insights/tenants/550e8400-e29b-41d4-a716-446655440000/insights",
params={
"actionable_only": True,
"min_confidence": 80,
"priority": "high",
"page": 1,
"page_size": 20
}
)
insights = response.json()
```
### Recording Feedback
```python
feedback_data = {
"insight_id": "insight-uuid",
"action_taken": "create_purchase_order",
"success": True,
"expected_impact_value": 120.50,
"actual_impact_value": 115.30,
"result_data": {
"order_id": "PO-12345",
"actual_savings": 115.30
},
"applied_by": "user@example.com"
}
response = httpx.post(
f"http://localhost:8000/api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}/feedback",
json=feedback_data
)
```
## Development
### Running Tests
```bash
pytest
```
### Code Quality
```bash
# Format code
black app/
# Lint
flake8 app/
# Type checking
mypy app/
```
### Creating a Migration
```bash
alembic revision --autogenerate -m "Description of changes"
alembic upgrade head
```
## Insight Types
- **optimization**: Process improvements with measurable gains
- **alert**: Warnings requiring attention
- **prediction**: Future forecasts with confidence intervals
- **recommendation**: Suggested actions with estimated impact
- **insight**: General data-driven observations
- **anomaly**: Unusual patterns detected in data
## Priority Levels
- **critical**: Immediate action required (e.g., stockout risk)
- **high**: Action recommended soon (e.g., price opportunity)
- **medium**: Consider acting (e.g., efficiency improvement)
- **low**: Informational (e.g., pattern observation)
## Categories
- **forecasting**: Demand predictions and patterns
- **inventory**: Stock management and optimization
- **production**: Manufacturing efficiency and scheduling
- **procurement**: Purchasing and supplier management
- **customer**: Customer behavior and satisfaction
- **cost**: Cost optimization opportunities
- **quality**: Quality improvements
- **efficiency**: Process efficiency gains
## Integration with Other Services
### Forecasting Service
- Receives forecast accuracy insights
- Pattern detection alerts
- Demand anomaly notifications
### Procurement Service
- Price forecast recommendations
- Supplier performance alerts
- Safety stock optimization
### Production Service
- Yield prediction insights
- Schedule optimization recommendations
- Equipment maintenance alerts
### Orchestrator Service
- Pre-orchestration insight gathering
- Actionable recommendation filtering
- Feedback recording for applied insights
## API Documentation
Once the service is running, interactive API documentation is available at:
- Swagger UI: `http://localhost:8000/docs`
- ReDoc: `http://localhost:8000/redoc`
## Monitoring
### Health Check
```bash
curl http://localhost:8000/health
```
### Metrics Endpoint
```bash
curl http://localhost:8000/api/v1/ai-insights/tenants/{tenant_id}/insights/metrics/summary
```
## License
Copyright © 2025 Bakery IA. All rights reserved.
## Support
For issues and questions, please contact the development team or create an issue in the project repository.

View File

@@ -0,0 +1,112 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -0,0 +1,3 @@
"""AI Insights Service."""
__version__ = "1.0.0"

View File

@@ -0,0 +1 @@
"""API modules for AI Insights Service."""

View File

@@ -0,0 +1,419 @@
"""API endpoints for AI Insights."""
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional
from uuid import UUID
from datetime import datetime
import math
from app.core.database import get_db
from app.repositories.insight_repository import InsightRepository
from app.repositories.feedback_repository import FeedbackRepository
from app.schemas.insight import (
AIInsightCreate,
AIInsightUpdate,
AIInsightResponse,
AIInsightList,
InsightMetrics,
InsightFilters
)
from app.schemas.feedback import InsightFeedbackCreate, InsightFeedbackResponse
router = APIRouter()
@router.post("/tenants/{tenant_id}/insights", response_model=AIInsightResponse, status_code=status.HTTP_201_CREATED)
async def create_insight(
tenant_id: UUID,
insight_data: AIInsightCreate,
db: AsyncSession = Depends(get_db)
):
"""Create a new AI Insight."""
# Ensure tenant_id matches
if insight_data.tenant_id != tenant_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Tenant ID mismatch"
)
repo = InsightRepository(db)
insight = await repo.create(insight_data)
await db.commit()
return insight
@router.get("/tenants/{tenant_id}/insights", response_model=AIInsightList)
async def get_insights(
tenant_id: UUID,
category: Optional[str] = Query(None),
priority: Optional[str] = Query(None),
status: Optional[str] = Query(None),
actionable_only: bool = Query(False),
min_confidence: int = Query(0, ge=0, le=100),
source_service: Optional[str] = Query(None),
from_date: Optional[datetime] = Query(None),
to_date: Optional[datetime] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
db: AsyncSession = Depends(get_db)
):
"""Get insights for a tenant with filters and pagination."""
filters = InsightFilters(
category=category,
priority=priority,
status=status,
actionable_only=actionable_only,
min_confidence=min_confidence,
source_service=source_service,
from_date=from_date,
to_date=to_date
)
repo = InsightRepository(db)
skip = (page - 1) * page_size
insights, total = await repo.get_by_tenant(tenant_id, filters, skip, page_size)
total_pages = math.ceil(total / page_size) if total > 0 else 0
return AIInsightList(
items=insights,
total=total,
page=page,
page_size=page_size,
total_pages=total_pages
)
@router.get("/tenants/{tenant_id}/insights/orchestration-ready")
async def get_orchestration_ready_insights(
tenant_id: UUID,
target_date: datetime = Query(...),
min_confidence: int = Query(70, ge=0, le=100),
db: AsyncSession = Depends(get_db)
):
"""Get actionable insights for orchestration workflow."""
repo = InsightRepository(db)
categorized_insights = await repo.get_orchestration_ready_insights(
tenant_id, target_date, min_confidence
)
return categorized_insights
@router.get("/tenants/{tenant_id}/insights/{insight_id}", response_model=AIInsightResponse)
async def get_insight(
tenant_id: UUID,
insight_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""Get a single insight by ID."""
repo = InsightRepository(db)
insight = await repo.get_by_id(insight_id)
if not insight:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Insight not found"
)
if insight.tenant_id != tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied"
)
return insight
@router.patch("/tenants/{tenant_id}/insights/{insight_id}", response_model=AIInsightResponse)
async def update_insight(
tenant_id: UUID,
insight_id: UUID,
update_data: AIInsightUpdate,
db: AsyncSession = Depends(get_db)
):
"""Update an insight (typically status changes)."""
repo = InsightRepository(db)
# Verify insight exists and belongs to tenant
insight = await repo.get_by_id(insight_id)
if not insight:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Insight not found"
)
if insight.tenant_id != tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied"
)
updated_insight = await repo.update(insight_id, update_data)
await db.commit()
return updated_insight
@router.delete("/tenants/{tenant_id}/insights/{insight_id}", status_code=status.HTTP_204_NO_CONTENT)
async def dismiss_insight(
tenant_id: UUID,
insight_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""Dismiss an insight (soft delete)."""
repo = InsightRepository(db)
# Verify insight exists and belongs to tenant
insight = await repo.get_by_id(insight_id)
if not insight:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Insight not found"
)
if insight.tenant_id != tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied"
)
await repo.delete(insight_id)
await db.commit()
@router.get("/tenants/{tenant_id}/insights/metrics/summary", response_model=InsightMetrics)
async def get_insights_metrics(
tenant_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""Get aggregate metrics for insights."""
repo = InsightRepository(db)
metrics = await repo.get_metrics(tenant_id)
return InsightMetrics(**metrics)
@router.post("/tenants/{tenant_id}/insights/{insight_id}/apply")
async def apply_insight(
tenant_id: UUID,
insight_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""Apply an insight recommendation (trigger action)."""
repo = InsightRepository(db)
# Verify insight exists and belongs to tenant
insight = await repo.get_by_id(insight_id)
if not insight:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Insight not found"
)
if insight.tenant_id != tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied"
)
if not insight.actionable:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="This insight is not actionable"
)
# Update status to in_progress
update_data = AIInsightUpdate(status='in_progress', applied_at=datetime.utcnow())
await repo.update(insight_id, update_data)
await db.commit()
# Route to appropriate service based on recommendation_actions
applied_actions = []
failed_actions = []
    import structlog
    logger = structlog.get_logger()
    try:
for action in insight.recommendation_actions:
try:
action_type = action.get('action_type')
action_target = action.get('target_service')
logger.info("Processing insight action",
insight_id=str(insight_id),
action_type=action_type,
target_service=action_target)
# Route based on target service
if action_target == 'procurement':
# Create purchase order or adjust reorder points
from shared.clients.procurement_client import ProcurementServiceClient
from shared.config.base import get_settings
config = get_settings()
procurement_client = ProcurementServiceClient(config, "ai_insights")
# Example: trigger procurement action
logger.info("Routing action to procurement service", action=action)
applied_actions.append(action_type)
elif action_target == 'production':
# Adjust production schedule
from shared.clients.production_client import ProductionServiceClient
from shared.config.base import get_settings
config = get_settings()
production_client = ProductionServiceClient(config, "ai_insights")
logger.info("Routing action to production service", action=action)
applied_actions.append(action_type)
elif action_target == 'inventory':
# Adjust inventory settings
from shared.clients.inventory_client import InventoryServiceClient
from shared.config.base import get_settings
config = get_settings()
inventory_client = InventoryServiceClient(config, "ai_insights")
logger.info("Routing action to inventory service", action=action)
applied_actions.append(action_type)
elif action_target == 'pricing':
# Update pricing recommendations
logger.info("Price adjustment action identified", action=action)
applied_actions.append(action_type)
else:
logger.warning("Unknown target service for action",
action_type=action_type,
target_service=action_target)
failed_actions.append({
'action_type': action_type,
'reason': f'Unknown target service: {action_target}'
})
except Exception as action_error:
logger.error("Failed to apply action",
action_type=action.get('action_type'),
error=str(action_error))
failed_actions.append({
'action_type': action.get('action_type'),
'reason': str(action_error)
})
# Update final status
final_status = 'applied' if not failed_actions else 'partially_applied'
final_update = AIInsightUpdate(status=final_status)
await repo.update(insight_id, final_update)
await db.commit()
except Exception as e:
logger.error("Failed to route insight actions",
insight_id=str(insight_id),
error=str(e))
# Update status to failed
failed_update = AIInsightUpdate(status='failed')
await repo.update(insight_id, failed_update)
await db.commit()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to apply insight: {str(e)}"
)
return {
"message": "Insight application initiated",
"insight_id": str(insight_id),
"actions": insight.recommendation_actions,
"applied_actions": applied_actions,
"failed_actions": failed_actions,
"status": final_status
}
@router.post("/tenants/{tenant_id}/insights/{insight_id}/feedback", response_model=InsightFeedbackResponse)
async def record_feedback(
tenant_id: UUID,
insight_id: UUID,
feedback_data: InsightFeedbackCreate,
db: AsyncSession = Depends(get_db)
):
"""Record feedback for an applied insight."""
insight_repo = InsightRepository(db)
# Verify insight exists and belongs to tenant
insight = await insight_repo.get_by_id(insight_id)
if not insight:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Insight not found"
)
if insight.tenant_id != tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied"
)
# Ensure feedback is for this insight
if feedback_data.insight_id != insight_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Insight ID mismatch"
)
feedback_repo = FeedbackRepository(db)
feedback = await feedback_repo.create(feedback_data)
# Update insight status based on feedback
new_status = 'applied' if feedback.success else 'dismissed'
update_data = AIInsightUpdate(status=new_status)
await insight_repo.update(insight_id, update_data)
await db.commit()
return feedback
@router.post("/tenants/{tenant_id}/insights/refresh")
async def refresh_insights(
tenant_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""Trigger insight refresh (expire old, generate new)."""
repo = InsightRepository(db)
# Expire old insights
expired_count = await repo.expire_old_insights()
await db.commit()
return {
"message": "Insights refreshed",
"expired_count": expired_count
}
@router.get("/tenants/{tenant_id}/insights/export")
async def export_insights(
tenant_id: UUID,
format: str = Query("json", regex="^(json|csv)$"),
db: AsyncSession = Depends(get_db)
):
"""Export insights to JSON or CSV."""
repo = InsightRepository(db)
insights, _ = await repo.get_by_tenant(tenant_id, filters=None, skip=0, limit=1000)
if format == "json":
return {"insights": [AIInsightResponse.model_validate(i) for i in insights]}
# CSV export would be implemented here
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="CSV export not yet implemented"
)

View File

@@ -0,0 +1,77 @@
"""Configuration settings for AI Insights Service."""
from shared.config.base import BaseServiceSettings
import os
from typing import Optional
class Settings(BaseServiceSettings):
"""Application settings."""
# Service Info
SERVICE_NAME: str = "ai-insights"
SERVICE_VERSION: str = "1.0.0"
API_V1_PREFIX: str = "/api/v1"
# Database configuration (secure approach - build from components)
@property
def DATABASE_URL(self) -> str:
"""Build database URL from secure components"""
# Try complete URL first (for backward compatibility)
complete_url = os.getenv("AI_INSIGHTS_DATABASE_URL")
if complete_url:
return complete_url
# Also check for generic DATABASE_URL (for migration compatibility)
generic_url = os.getenv("DATABASE_URL")
if generic_url:
return generic_url
# Build from components (secure approach)
user = os.getenv("AI_INSIGHTS_DB_USER", "ai_insights_user")
password = os.getenv("AI_INSIGHTS_DB_PASSWORD", "ai_insights_pass123")
host = os.getenv("AI_INSIGHTS_DB_HOST", "localhost")
port = os.getenv("AI_INSIGHTS_DB_PORT", "5432")
name = os.getenv("AI_INSIGHTS_DB_NAME", "ai_insights_db")
return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"
DB_POOL_SIZE: int = 20
DB_MAX_OVERFLOW: int = 10
# Redis (inherited from BaseServiceSettings but can override)
REDIS_CACHE_TTL: int = 900 # 15 minutes
REDIS_DB: int = 3 # Dedicated Redis database for AI Insights
# Service URLs
FORECASTING_SERVICE_URL: str = "http://forecasting-service:8000"
PROCUREMENT_SERVICE_URL: str = "http://procurement-service:8000"
PRODUCTION_SERVICE_URL: str = "http://production-service:8000"
SALES_SERVICE_URL: str = "http://sales-service:8000"
INVENTORY_SERVICE_URL: str = "http://inventory-service:8000"
# Circuit Breaker Settings
CIRCUIT_BREAKER_FAILURE_THRESHOLD: int = 5
CIRCUIT_BREAKER_TIMEOUT: int = 60
# Insight Settings
MIN_CONFIDENCE_THRESHOLD: int = 60
DEFAULT_INSIGHT_TTL_DAYS: int = 7
MAX_INSIGHTS_PER_REQUEST: int = 100
# Feedback Settings
FEEDBACK_PROCESSING_ENABLED: bool = True
FEEDBACK_PROCESSING_SCHEDULE: str = "0 6 * * *" # Daily at 6 AM
# Logging
LOG_LEVEL: str = "INFO"
# CORS
ALLOWED_ORIGINS: list[str] = ["http://localhost:3000", "http://localhost:5173"]
class Config:
env_file = ".env"
case_sensitive = True
settings = Settings()

View File

@@ -0,0 +1,58 @@
"""Database configuration and session management."""
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy.pool import NullPool
from typing import AsyncGenerator
from app.core.config import settings
# Create async engine
engine = create_async_engine(
settings.DATABASE_URL,
pool_size=settings.DB_POOL_SIZE,
max_overflow=settings.DB_MAX_OVERFLOW,
echo=False,
future=True,
)
# Create async session factory
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
# Create declarative base
Base = declarative_base()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""
Dependency for getting async database sessions.
Yields:
AsyncSession: Database session
"""
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def init_db():
"""Initialize database tables."""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def close_db():
"""Close database connections."""
await engine.dispose()

View File

@@ -0,0 +1,320 @@
"""Impact estimation for AI Insights."""
from typing import Dict, Any, Optional, Tuple
from decimal import Decimal
from datetime import datetime, timedelta
class ImpactEstimator:
"""
Estimate potential impact of recommendations.
Calculates expected business value in terms of:
- Cost savings (euros)
- Revenue increase (euros)
- Waste reduction (euros or percentage)
- Efficiency gains (hours or percentage)
- Quality improvements (units or percentage)
"""
def estimate_procurement_savings(
self,
current_price: Decimal,
predicted_price: Decimal,
order_quantity: Decimal,
timeframe_days: int = 30
) -> Tuple[Decimal, str, str]:
"""
Estimate savings from opportunistic buying.
Args:
current_price: Current unit price
predicted_price: Predicted future price
order_quantity: Quantity to order
timeframe_days: Time horizon for prediction
Returns:
tuple: (impact_value, impact_unit, impact_type)
"""
savings_per_unit = predicted_price - current_price
if savings_per_unit > 0:
total_savings = savings_per_unit * order_quantity
return (
round(total_savings, 2),
'euros',
'cost_savings'
)
return (Decimal('0.0'), 'euros', 'cost_savings')
def estimate_waste_reduction_savings(
self,
current_waste_rate: float,
optimized_waste_rate: float,
monthly_volume: Decimal,
avg_cost_per_unit: Decimal
) -> Tuple[Decimal, str, str]:
"""
Estimate savings from waste reduction.
Args:
current_waste_rate: Current waste rate (0-1)
optimized_waste_rate: Optimized waste rate (0-1)
monthly_volume: Monthly volume
avg_cost_per_unit: Average cost per unit
Returns:
tuple: (impact_value, impact_unit, impact_type)
"""
waste_reduction_rate = current_waste_rate - optimized_waste_rate
units_saved = monthly_volume * Decimal(str(waste_reduction_rate))
savings = units_saved * avg_cost_per_unit
return (
round(savings, 2),
'euros/month',
'waste_reduction'
)
def estimate_forecast_improvement_value(
self,
current_mape: float,
improved_mape: float,
avg_monthly_revenue: Decimal
) -> Tuple[Decimal, str, str]:
"""
Estimate value from forecast accuracy improvement.
Better forecasts reduce:
- Stockouts (lost sales)
- Overproduction (waste)
- Emergency orders (premium costs)
Args:
current_mape: Current forecast MAPE
improved_mape: Improved forecast MAPE
avg_monthly_revenue: Average monthly revenue
Returns:
tuple: (impact_value, impact_unit, impact_type)
"""
# Rule of thumb: 1% MAPE improvement = 0.5% revenue impact
mape_improvement = current_mape - improved_mape
revenue_impact_pct = mape_improvement * 0.5 / 100
revenue_increase = avg_monthly_revenue * Decimal(str(revenue_impact_pct))
return (
round(revenue_increase, 2),
'euros/month',
'revenue_increase'
)
def estimate_production_efficiency_gain(
self,
time_saved_minutes: int,
batches_per_month: int,
labor_cost_per_hour: Decimal = Decimal('15.0')
) -> Tuple[Decimal, str, str]:
"""
Estimate value from production efficiency improvements.
Args:
time_saved_minutes: Minutes saved per batch
batches_per_month: Number of batches per month
labor_cost_per_hour: Labor cost per hour
Returns:
tuple: (impact_value, impact_unit, impact_type)
"""
hours_saved_per_month = (time_saved_minutes * batches_per_month) / 60
cost_savings = Decimal(str(hours_saved_per_month)) * labor_cost_per_hour
return (
round(cost_savings, 2),
'euros/month',
'efficiency_gain'
)
def estimate_safety_stock_optimization(
self,
current_safety_stock: Decimal,
optimal_safety_stock: Decimal,
holding_cost_per_unit_per_day: Decimal,
stockout_cost_reduction: Decimal = Decimal('0.0')
) -> Tuple[Decimal, str, str]:
"""
Estimate impact of safety stock optimization.
Args:
current_safety_stock: Current safety stock level
optimal_safety_stock: Optimal safety stock level
holding_cost_per_unit_per_day: Daily holding cost
stockout_cost_reduction: Reduction in stockout costs
Returns:
tuple: (impact_value, impact_unit, impact_type)
"""
stock_reduction = current_safety_stock - optimal_safety_stock
if stock_reduction > 0:
# Savings from reduced holding costs
daily_savings = stock_reduction * holding_cost_per_unit_per_day
monthly_savings = daily_savings * 30
total_savings = monthly_savings + stockout_cost_reduction
return (
round(total_savings, 2),
'euros/month',
'cost_savings'
)
elif stock_reduction < 0:
# Cost increase but reduces stockouts
daily_cost = abs(stock_reduction) * holding_cost_per_unit_per_day
monthly_cost = daily_cost * 30
net_savings = stockout_cost_reduction - monthly_cost
if net_savings > 0:
return (
round(net_savings, 2),
'euros/month',
'cost_savings'
)
return (Decimal('0.0'), 'euros/month', 'cost_savings')
def estimate_supplier_switch_savings(
self,
current_supplier_price: Decimal,
alternative_supplier_price: Decimal,
monthly_order_quantity: Decimal,
quality_difference_score: float = 0.0 # -1 to 1
) -> Tuple[Decimal, str, str]:
"""
Estimate savings from switching suppliers.
Args:
current_supplier_price: Current supplier unit price
alternative_supplier_price: Alternative supplier unit price
monthly_order_quantity: Monthly order quantity
quality_difference_score: Quality difference (-1=worse, 0=same, 1=better)
Returns:
tuple: (impact_value, impact_unit, impact_type)
"""
price_savings = (current_supplier_price - alternative_supplier_price) * monthly_order_quantity
# Adjust for quality difference
# If quality is worse, reduce estimated savings
quality_adjustment = 1 + (quality_difference_score * 0.1) # ±10% max adjustment
adjusted_savings = price_savings * Decimal(str(quality_adjustment))
return (
round(adjusted_savings, 2),
'euros/month',
'cost_savings'
)
def estimate_yield_improvement_value(
self,
current_yield_rate: float,
predicted_yield_rate: float,
production_volume: Decimal,
product_price: Decimal
) -> Tuple[Decimal, str, str]:
"""
Estimate value from production yield improvements.
Args:
current_yield_rate: Current yield rate (0-1)
predicted_yield_rate: Predicted yield rate (0-1)
production_volume: Monthly production volume
product_price: Product selling price
Returns:
tuple: (impact_value, impact_unit, impact_type)
"""
yield_improvement = predicted_yield_rate - current_yield_rate
if yield_improvement > 0:
additional_units = production_volume * Decimal(str(yield_improvement))
revenue_increase = additional_units * product_price
return (
round(revenue_increase, 2),
'euros/month',
'revenue_increase'
)
return (Decimal('0.0'), 'euros/month', 'revenue_increase')
def estimate_demand_pattern_value(
self,
pattern_strength: float, # 0-1
potential_revenue_increase: Decimal,
implementation_cost: Decimal = Decimal('0.0')
) -> Tuple[Decimal, str, str]:
"""
Estimate value from acting on demand patterns.
Args:
pattern_strength: Strength of detected pattern (0-1)
potential_revenue_increase: Potential monthly revenue increase
implementation_cost: One-time implementation cost
Returns:
tuple: (impact_value, impact_unit, impact_type)
"""
# Discount by pattern strength (confidence)
expected_value = potential_revenue_increase * Decimal(str(pattern_strength))
# Amortize implementation cost over 6 months
monthly_cost = implementation_cost / 6
net_value = expected_value - monthly_cost
return (
round(max(Decimal('0.0'), net_value), 2),
'euros/month',
'revenue_increase'
)
def estimate_composite_impact(
self,
impacts: list[Dict[str, Any]]
) -> Tuple[Decimal, str, str]:
"""
Combine multiple impact estimations.
Args:
impacts: List of impact dicts with 'value', 'unit', 'type'
Returns:
tuple: (total_impact_value, impact_unit, impact_type)
"""
total_savings = Decimal('0.0')
total_revenue = Decimal('0.0')
for impact in impacts:
value = Decimal(str(impact['value']))
impact_type = impact['type']
if impact_type == 'cost_savings':
total_savings += value
elif impact_type == 'revenue_increase':
total_revenue += value
# Combine both types
total_impact = total_savings + total_revenue
if total_impact > 0:
# Determine primary type
primary_type = 'cost_savings' if total_savings > total_revenue else 'revenue_increase'
return (
round(total_impact, 2),
'euros/month',
primary_type
)
return (Decimal('0.0'), 'euros/month', 'cost_savings')
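# Illustrative usage (assumed values, for documentation only):
#
#     estimator = ImpactEstimator()
#     value, unit, impact_type = estimator.estimate_procurement_savings(
#         current_price=Decimal("1.20"),
#         predicted_price=Decimal("1.30"),
#         order_quantity=Decimal("1000"),
#     )
#     # value == Decimal("100.00"), unit == "euros", impact_type == "cost_savings"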

View File

@@ -0,0 +1,68 @@
"""Main FastAPI application for AI Insights Service."""
import structlog
from app.core.config import settings
from app.core.database import init_db, close_db
from app.api import insights
from shared.service_base import StandardFastAPIService
# Initialize logger
logger = structlog.get_logger()
class AIInsightsService(StandardFastAPIService):
"""AI Insights Service with standardized monitoring setup"""
async def on_startup(self, app):
"""Custom startup logic for AI Insights"""
# Initialize database
await init_db()
logger.info("Database initialized")
await super().on_startup(app)
async def on_shutdown(self, app):
"""Custom shutdown logic for AI Insights"""
await super().on_shutdown(app)
# Close database
await close_db()
logger.info("Database connections closed")
# Create service instance
service = AIInsightsService(
service_name="ai-insights",
app_name="AI Insights Service",
description="Intelligent insights and recommendations for bakery operations",
version=settings.SERVICE_VERSION,
log_level=getattr(settings, 'LOG_LEVEL', 'INFO'),
cors_origins=getattr(settings, 'ALLOWED_ORIGINS', ["*"]),
api_prefix=settings.API_V1_PREFIX,
enable_metrics=True,
enable_health_checks=True,
enable_tracing=True,
enable_cors=True
)
# Create FastAPI app
app = service.create_app()
# Add service-specific routers
service.add_router(
insights.router,
tags=["insights"]
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=8000,
reload=True,
log_level=settings.LOG_LEVEL.lower()
)

View File

@@ -0,0 +1,672 @@
"""
Feedback Loop & Learning System
Enables continuous improvement through outcome tracking and model retraining
"""
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta
from uuid import UUID
import structlog
from scipy import stats
from collections import defaultdict
logger = structlog.get_logger()
class FeedbackLearningSystem:
"""
Manages feedback collection, model performance tracking, and retraining triggers.
Key Responsibilities:
1. Aggregate feedback from applied insights
2. Calculate model performance metrics (accuracy, precision, recall)
3. Detect performance degradation
4. Trigger automatic retraining when needed
5. Calibrate confidence scores based on actual accuracy
6. Generate learning insights for model improvement
Workflow:
- Feedback continuously recorded via AIInsightsClient
- Periodic performance analysis (daily/weekly)
- Automatic alerts when performance degrades
- Retraining recommendations with priority
"""
def __init__(
self,
performance_threshold: float = 0.85, # Minimum acceptable accuracy
degradation_threshold: float = 0.10, # 10% drop triggers alert
min_feedback_samples: int = 30, # Minimum samples for analysis
retraining_window_days: int = 90 # Consider last 90 days
):
self.performance_threshold = performance_threshold
self.degradation_threshold = degradation_threshold
self.min_feedback_samples = min_feedback_samples
self.retraining_window_days = retraining_window_days
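    # Illustrative call pattern (assumed; feedback_df must contain the columns
    # listed in analyze_model_performance's docstring):
    #
    #     system = FeedbackLearningSystem()
    #     report = await system.analyze_model_performance(
    #         model_name="hybrid_forecaster",
    #         feedback_data=feedback_df,
    #         baseline_performance={"accuracy": 90.0},
    #     )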
async def analyze_model_performance(
self,
model_name: str,
feedback_data: pd.DataFrame,
baseline_performance: Optional[Dict[str, float]] = None
) -> Dict[str, Any]:
"""
Analyze model performance based on feedback data.
Args:
model_name: Name of the model (e.g., 'hybrid_forecaster', 'yield_predictor')
feedback_data: DataFrame with columns:
- insight_id
- applied_at
- outcome_date
- predicted_value
- actual_value
- error
- error_pct
- accuracy
baseline_performance: Optional baseline metrics for comparison
Returns:
Performance analysis with metrics, trends, and recommendations
"""
logger.info(
"Analyzing model performance",
model_name=model_name,
feedback_samples=len(feedback_data)
)
if len(feedback_data) < self.min_feedback_samples:
return self._insufficient_feedback_response(
model_name, len(feedback_data), self.min_feedback_samples
)
# Step 1: Calculate current performance metrics
current_metrics = self._calculate_performance_metrics(feedback_data)
# Step 2: Analyze performance trend over time
trend_analysis = self._analyze_performance_trend(feedback_data)
# Step 3: Detect performance degradation
degradation_detected = self._detect_performance_degradation(
current_metrics, baseline_performance, trend_analysis
)
# Step 4: Generate retraining recommendation
retraining_recommendation = self._generate_retraining_recommendation(
model_name, current_metrics, degradation_detected, trend_analysis
)
# Step 5: Identify error patterns
error_patterns = self._identify_error_patterns(feedback_data)
# Step 6: Calculate confidence calibration
confidence_calibration = self._calculate_confidence_calibration(feedback_data)
logger.info(
"Model performance analysis complete",
model_name=model_name,
current_accuracy=current_metrics['accuracy'],
degradation_detected=degradation_detected['detected'],
retraining_recommended=retraining_recommendation['recommended']
)
return {
'model_name': model_name,
'analyzed_at': datetime.utcnow().isoformat(),
'feedback_samples': len(feedback_data),
'date_range': {
'start': feedback_data['outcome_date'].min().isoformat(),
'end': feedback_data['outcome_date'].max().isoformat()
},
'current_performance': current_metrics,
'baseline_performance': baseline_performance,
'trend_analysis': trend_analysis,
'degradation_detected': degradation_detected,
'retraining_recommendation': retraining_recommendation,
'error_patterns': error_patterns,
'confidence_calibration': confidence_calibration
}
def _insufficient_feedback_response(
self, model_name: str, current_samples: int, required_samples: int
) -> Dict[str, Any]:
"""Return response when insufficient feedback data."""
return {
'model_name': model_name,
'analyzed_at': datetime.utcnow().isoformat(),
'status': 'insufficient_feedback',
'feedback_samples': current_samples,
'required_samples': required_samples,
'current_performance': None,
'recommendation': f'Need {required_samples - current_samples} more feedback samples for reliable analysis'
}
def _calculate_performance_metrics(
self, feedback_data: pd.DataFrame
) -> Dict[str, float]:
"""
Calculate comprehensive performance metrics.
Metrics:
- Accuracy: % of predictions within acceptable error
- MAE: Mean Absolute Error
- RMSE: Root Mean Squared Error
- MAPE: Mean Absolute Percentage Error
- Bias: Systematic over/under prediction
- R²: Correlation between predicted and actual
"""
predicted = feedback_data['predicted_value'].values
actual = feedback_data['actual_value'].values
# Filter out invalid values
valid_mask = ~(np.isnan(predicted) | np.isnan(actual))
predicted = predicted[valid_mask]
actual = actual[valid_mask]
if len(predicted) == 0:
return {
'accuracy': 0,
'mae': 0,
'rmse': 0,
'mape': 0,
'bias': 0,
'r_squared': 0
}
# Calculate errors
errors = predicted - actual
abs_errors = np.abs(errors)
        pct_errors = np.zeros_like(errors, dtype=float)
        nonzero_mask = actual != 0
        pct_errors[nonzero_mask] = np.abs(errors[nonzero_mask] / actual[nonzero_mask]) * 100
# MAE and RMSE
mae = float(np.mean(abs_errors))
rmse = float(np.sqrt(np.mean(errors ** 2)))
# MAPE (excluding cases where actual = 0)
valid_pct_mask = actual != 0
mape = float(np.mean(pct_errors[valid_pct_mask])) if np.any(valid_pct_mask) else 0
# Accuracy (% within 10% error)
within_10pct = np.sum(pct_errors <= 10) / len(pct_errors) * 100
# Bias (mean error - positive = over-prediction)
bias = float(np.mean(errors))
# R² (correlation)
if len(predicted) > 1 and np.std(actual) > 0:
correlation = np.corrcoef(predicted, actual)[0, 1]
r_squared = correlation ** 2
else:
r_squared = 0
return {
'accuracy': round(within_10pct, 2), # % within 10% error
'mae': round(mae, 2),
'rmse': round(rmse, 2),
'mape': round(mape, 2),
'bias': round(bias, 2),
'r_squared': round(r_squared, 3),
'sample_size': len(predicted)
}
def _analyze_performance_trend(
self, feedback_data: pd.DataFrame
) -> Dict[str, Any]:
"""
Analyze performance trend over time.
Returns trend direction (improving/stable/degrading) and slope.
"""
# Sort by date
df = feedback_data.sort_values('outcome_date').copy()
# Calculate rolling accuracy (7-day window)
df['rolling_accuracy'] = df['accuracy'].rolling(window=7, min_periods=3).mean()
# Linear trend
if len(df) >= 10:
# Use day index as x
df['day_index'] = (df['outcome_date'] - df['outcome_date'].min()).dt.days
# Fit linear regression
valid_mask = ~np.isnan(df['rolling_accuracy'])
if valid_mask.sum() >= 10:
x = df.loc[valid_mask, 'day_index'].values
y = df.loc[valid_mask, 'rolling_accuracy'].values
slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
# Determine trend
if p_value < 0.05:
if slope > 0.1:
trend = 'improving'
elif slope < -0.1:
trend = 'degrading'
else:
trend = 'stable'
else:
trend = 'stable'
return {
'trend': trend,
'slope': round(float(slope), 4),
'p_value': round(float(p_value), 4),
'significant': p_value < 0.05,
'recent_performance': round(float(df['rolling_accuracy'].iloc[-1]), 2),
'initial_performance': round(float(df['rolling_accuracy'].dropna().iloc[0]), 2)
}
# Not enough data for trend
return {
'trend': 'insufficient_data',
'slope': 0,
'p_value': 1.0,
'significant': False
}
def _detect_performance_degradation(
self,
current_metrics: Dict[str, float],
baseline_performance: Optional[Dict[str, float]],
trend_analysis: Dict[str, Any]
) -> Dict[str, Any]:
"""
Detect if model performance has degraded.
Degradation triggers:
1. Current accuracy below threshold (85%)
2. Significant drop from baseline (>10%)
3. Degrading trend detected
"""
degradation_reasons = []
severity = 'none'
# Check absolute performance
if current_metrics['accuracy'] < self.performance_threshold * 100:
degradation_reasons.append(
f"Accuracy {current_metrics['accuracy']:.1f}% below threshold {self.performance_threshold*100}%"
)
severity = 'high'
# Check vs baseline
if baseline_performance and 'accuracy' in baseline_performance:
baseline_acc = baseline_performance['accuracy']
current_acc = current_metrics['accuracy']
drop_pct = (baseline_acc - current_acc) / baseline_acc
if drop_pct > self.degradation_threshold:
degradation_reasons.append(
f"Accuracy dropped {drop_pct*100:.1f}% from baseline {baseline_acc:.1f}%"
)
                severity = 'high'
# Check trend
if trend_analysis.get('trend') == 'degrading' and trend_analysis.get('significant'):
degradation_reasons.append(
f"Degrading trend detected (slope: {trend_analysis['slope']:.4f})"
)
severity = 'medium' if severity == 'none' else severity
detected = len(degradation_reasons) > 0
return {
'detected': detected,
'severity': severity,
'reasons': degradation_reasons,
'current_accuracy': current_metrics['accuracy'],
'baseline_accuracy': baseline_performance.get('accuracy') if baseline_performance else None
}
def _generate_retraining_recommendation(
self,
model_name: str,
current_metrics: Dict[str, float],
degradation_detected: Dict[str, Any],
trend_analysis: Dict[str, Any]
) -> Dict[str, Any]:
"""
Generate retraining recommendation based on performance analysis.
Priority Levels:
- urgent: Severe degradation, retrain immediately
- high: Performance below threshold, retrain soon
- medium: Trending down, schedule retraining
- low: Stable, routine retraining
- none: No retraining needed
"""
if degradation_detected['detected']:
severity = degradation_detected['severity']
if severity == 'high':
priority = 'urgent'
recommendation = f"Retrain {model_name} immediately - severe performance degradation"
elif severity == 'medium':
priority = 'high'
recommendation = f"Schedule {model_name} retraining within 7 days"
else:
priority = 'medium'
recommendation = f"Schedule routine {model_name} retraining"
return {
'recommended': True,
'priority': priority,
'recommendation': recommendation,
'reasons': degradation_detected['reasons'],
'estimated_improvement': self._estimate_retraining_benefit(
current_metrics, degradation_detected
)
}
# Check if routine retraining is due (e.g., every 90 days)
# This would require tracking last_retrained_at
else:
return {
'recommended': False,
'priority': 'none',
'recommendation': f"{model_name} performance is acceptable, no immediate retraining needed",
'next_review_date': (datetime.utcnow() + timedelta(days=30)).isoformat()
}
def _estimate_retraining_benefit(
self,
current_metrics: Dict[str, float],
degradation_detected: Dict[str, Any]
) -> Dict[str, Any]:
"""Estimate expected improvement from retraining."""
baseline_acc = degradation_detected.get('baseline_accuracy')
current_acc = current_metrics['accuracy']
if baseline_acc:
# Expect to recover 70-80% of lost performance
expected_improvement = (baseline_acc - current_acc) * 0.75
expected_new_acc = current_acc + expected_improvement
return {
'expected_accuracy_improvement': round(expected_improvement, 2),
'expected_new_accuracy': round(expected_new_acc, 2),
'confidence': 'medium'
}
return {
'expected_accuracy_improvement': 'unknown',
'confidence': 'low'
}
def _identify_error_patterns(
self, feedback_data: pd.DataFrame
) -> List[Dict[str, Any]]:
"""
Identify systematic error patterns.
Patterns:
- Consistent over/under prediction
- Higher errors for specific ranges
- Day-of-week effects
- Seasonal effects
"""
patterns = []
# Pattern 1: Systematic bias
mean_error = feedback_data['error'].mean()
if abs(mean_error) > feedback_data['error'].std() * 0.5:
direction = 'over-prediction' if mean_error > 0 else 'under-prediction'
patterns.append({
'pattern': 'systematic_bias',
'description': f'Consistent {direction} by {abs(mean_error):.1f} units',
'severity': 'high' if abs(mean_error) > 10 else 'medium',
'recommendation': 'Recalibrate model bias term'
})
# Pattern 2: High error for large values
if 'predicted_value' in feedback_data.columns:
# Split into quartiles
feedback_data['value_quartile'] = pd.qcut(
feedback_data['predicted_value'],
q=4,
labels=['Q1', 'Q2', 'Q3', 'Q4'],
duplicates='drop'
)
quartile_errors = feedback_data.groupby('value_quartile')['error_pct'].mean()
if len(quartile_errors) == 4 and quartile_errors['Q4'] > quartile_errors['Q1'] * 1.5:
patterns.append({
'pattern': 'high_value_error',
'description': f'Higher errors for large predictions (Q4: {quartile_errors["Q4"]:.1f}% vs Q1: {quartile_errors["Q1"]:.1f}%)',
'severity': 'medium',
'recommendation': 'Add log transformation or separate model for high values'
})
# Pattern 3: Day-of-week effect
if 'outcome_date' in feedback_data.columns:
feedback_data['day_of_week'] = pd.to_datetime(feedback_data['outcome_date']).dt.dayofweek
dow_errors = feedback_data.groupby('day_of_week')['error_pct'].mean()
if len(dow_errors) >= 5 and dow_errors.max() > dow_errors.min() * 1.5:
worst_day = dow_errors.idxmax()
day_names = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
patterns.append({
'pattern': 'day_of_week_effect',
'description': f'Higher errors on {day_names[worst_day]} ({dow_errors[worst_day]:.1f}%)',
'severity': 'low',
'recommendation': 'Add day-of-week features to model'
})
return patterns
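# Illustrative output of the pattern scan (values assumed): a model that
# consistently predicts ~12 units too high would yield
#   [{'pattern': 'systematic_bias',
#     'description': 'Consistent over-prediction by 12.3 units',
#     'severity': 'high',
#     'recommendation': 'Recalibrate model bias term'}]
# plus quartile or day-of-week entries when those columns are present in the feedback.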
def _calculate_confidence_calibration(
self, feedback_data: pd.DataFrame
) -> Dict[str, Any]:
"""
Calculate how well confidence scores match actual accuracy.
Well-calibrated model: 80% confidence → 80% accuracy
"""
if 'confidence' not in feedback_data.columns:
return {'calibrated': False, 'reason': 'No confidence scores available'}
# Bin by confidence ranges
feedback_data['confidence_bin'] = pd.cut(
feedback_data['confidence'],
bins=[0, 60, 70, 80, 90, 100],
labels=['<60', '60-70', '70-80', '80-90', '90+']
)
calibration_results = []
for conf_bin in feedback_data['confidence_bin'].unique():
if pd.isna(conf_bin):
continue
bin_data = feedback_data[feedback_data['confidence_bin'] == conf_bin]
if len(bin_data) >= 5:
avg_confidence = bin_data['confidence'].mean()
avg_accuracy = bin_data['accuracy'].mean()
calibration_error = abs(avg_confidence - avg_accuracy)
calibration_results.append({
'confidence_range': str(conf_bin),
'avg_confidence': round(avg_confidence, 1),
'avg_accuracy': round(avg_accuracy, 1),
'calibration_error': round(calibration_error, 1),
'sample_size': len(bin_data),
'well_calibrated': calibration_error < 10
})
# Overall calibration
if calibration_results:
overall_calibration_error = np.mean([r['calibration_error'] for r in calibration_results])
well_calibrated = overall_calibration_error < 10
return {
'calibrated': well_calibrated,
'overall_calibration_error': round(overall_calibration_error, 2),
'by_confidence_range': calibration_results,
'recommendation': 'Confidence scores are well-calibrated' if well_calibrated
else 'Recalibrate confidence scoring algorithm'
}
return {'calibrated': False, 'reason': 'Insufficient data for calibration analysis'}
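# Worked calibration example (values assumed): insights issued in the 80-90%
# confidence bin that turn out only ~72% accurate have a calibration error of
# |85 - 72| = 13 points, which exceeds the 10-point tolerance used above, so
# that bin (and likely the model) is flagged as poorly calibrated.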
async def generate_learning_insights(
self,
performance_analyses: List[Dict[str, Any]],
tenant_id: str
) -> List[Dict[str, Any]]:
"""
Generate high-level insights about learning system performance.
Args:
performance_analyses: List of model performance analyses
tenant_id: Tenant identifier
Returns:
Learning insights for system improvement
"""
insights = []
# Insight 1: Models needing urgent retraining
urgent_models = [
a for a in performance_analyses
if a.get('retraining_recommendation', {}).get('priority') == 'urgent'
]
if urgent_models:
model_names = ', '.join([a['model_name'] for a in urgent_models])
insights.append({
'type': 'warning',
'priority': 'urgent',
'category': 'system',
'title': f'Urgent Model Retraining Required: {len(urgent_models)} Models',
'description': f'Models requiring immediate retraining: {model_names}. Performance has degraded significantly.',
'impact_type': 'system_health',
'confidence': 95,
'metrics_json': {
'tenant_id': tenant_id,
'urgent_models': [a['model_name'] for a in urgent_models],
'affected_count': len(urgent_models)
},
'actionable': True,
'recommendation_actions': [{
'label': 'Retrain Models',
'action': 'trigger_model_retraining',
'params': {'models': [a['model_name'] for a in urgent_models]}
}],
'source_service': 'ai_insights',
'source_model': 'feedback_learning_system'
})
# Insight 2: Overall system health
total_models = len(performance_analyses)
healthy_models = [
a for a in performance_analyses
if not a.get('degradation_detected', {}).get('detected', False)
]
health_pct = (len(healthy_models) / total_models * 100) if total_models > 0 else 0
if health_pct < 80:
insights.append({
'type': 'warning',
'priority': 'high',
'category': 'system',
'title': f'Learning System Health: {health_pct:.0f}%',
'description': f'{len(healthy_models)} of {total_models} models are performing well. System-wide performance review recommended.',
'impact_type': 'system_health',
'confidence': 90,
'metrics_json': {
'tenant_id': tenant_id,
'total_models': total_models,
'healthy_models': len(healthy_models),
'health_percentage': round(health_pct, 1)
},
'actionable': True,
'recommendation_actions': [{
'label': 'Review System Health',
'action': 'review_learning_system',
'params': {'tenant_id': tenant_id}
}],
'source_service': 'ai_insights',
'source_model': 'feedback_learning_system'
})
# Insight 3: Confidence calibration issues
poorly_calibrated = [
a for a in performance_analyses
if not a.get('confidence_calibration', {}).get('calibrated', True)
]
if poorly_calibrated:
insights.append({
'type': 'opportunity',
'priority': 'medium',
'category': 'system',
'title': f'Confidence Calibration Needed: {len(poorly_calibrated)} Models',
'description': 'Confidence scores do not match actual accuracy. Recalibration recommended.',
'impact_type': 'system_improvement',
'confidence': 85,
'metrics_json': {
'tenant_id': tenant_id,
'models_needing_calibration': [a['model_name'] for a in poorly_calibrated]
},
'actionable': True,
'recommendation_actions': [{
'label': 'Recalibrate Confidence Scores',
'action': 'recalibrate_confidence',
'params': {'models': [a['model_name'] for a in poorly_calibrated]}
}],
'source_service': 'ai_insights',
'source_model': 'feedback_learning_system'
})
return insights
async def calculate_roi(
self,
feedback_data: pd.DataFrame,
insight_type: str
) -> Dict[str, Any]:
"""
Calculate ROI for applied insights.
Args:
feedback_data: Feedback data with business impact metrics
insight_type: Type of insight (e.g., 'demand_forecast', 'safety_stock')
Returns:
ROI calculation with cost savings and accuracy metrics
"""
if len(feedback_data) == 0:
return {'status': 'insufficient_data', 'samples': 0}
# Calculate accuracy
avg_accuracy = feedback_data['accuracy'].mean()
# Estimate cost savings (would be more sophisticated in production)
# For now, use impact_value from insights if available
if 'impact_value' in feedback_data.columns:
total_impact = feedback_data['impact_value'].sum()
avg_impact = feedback_data['impact_value'].mean()
return {
'insight_type': insight_type,
'samples': len(feedback_data),
'avg_accuracy': round(avg_accuracy, 2),
'total_impact_value': round(total_impact, 2),
'avg_impact_per_insight': round(avg_impact, 2),
'roi_validated': True
}
return {
'insight_type': insight_type,
'samples': len(feedback_data),
'avg_accuracy': round(avg_accuracy, 2),
'roi_validated': False,
'note': 'Impact values not tracked in feedback'
}
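# Minimal usage sketch (illustrative only, not part of the service runtime).
# It assumes pandas/numpy and the datetime helpers are imported at the top of
# this module, as the methods above already rely on them; the constructor
# values mirror those used in the test suite.
if __name__ == "__main__":
    import asyncio

    _system = FeedbackLearningSystem(
        performance_threshold=0.85,
        degradation_threshold=0.10,
        min_feedback_samples=30
    )
    _demo = pd.DataFrame({
        'predicted_value': [100.0, 110.0, 95.0, 120.0, 105.0],
        'actual_value': [90.0, 98.0, 85.0, 108.0, 94.0],
        'accuracy': [90.0, 89.0, 89.0, 90.0, 90.0],
        'impact_value': [500.0, 650.0, 400.0, 700.0, 550.0],
    })
    _demo['error'] = _demo['predicted_value'] - _demo['actual_value']
    _demo['error_pct'] = (_demo['error'].abs() / _demo['actual_value']) * 100
    # Internal helpers exercised directly for illustration
    print(_system._identify_error_patterns(_demo))
    print(asyncio.run(_system.calculate_roi(_demo, insight_type='demand_forecast')))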

View File

@@ -0,0 +1,11 @@
"""Database models for AI Insights Service."""
from app.models.ai_insight import AIInsight
from app.models.insight_feedback import InsightFeedback
from app.models.insight_correlation import InsightCorrelation
__all__ = [
"AIInsight",
"InsightFeedback",
"InsightCorrelation",
]

View File

@@ -0,0 +1,129 @@
"""AI Insight database model."""
from sqlalchemy import Column, String, Integer, Boolean, DECIMAL, TIMESTAMP, Text, Index, CheckConstraint
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.sql import func
import uuid
from app.core.database import Base
class AIInsight(Base):
"""AI Insight model for storing intelligent recommendations and predictions."""
__tablename__ = "ai_insights"
# Primary Key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Tenant Information
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# Classification
type = Column(
String(50),
nullable=False,
index=True,
comment="optimization, alert, prediction, recommendation, insight, anomaly"
)
priority = Column(
String(20),
nullable=False,
index=True,
comment="low, medium, high, critical"
)
category = Column(
String(50),
nullable=False,
index=True,
comment="forecasting, inventory, production, procurement, customer, cost, quality, efficiency, demand, maintenance, energy, scheduling"
)
# Content
title = Column(String(255), nullable=False)
description = Column(Text, nullable=False)
# Impact Information
impact_type = Column(
String(50),
comment="cost_savings, revenue_increase, waste_reduction, efficiency_gain, quality_improvement, risk_mitigation"
)
impact_value = Column(DECIMAL(10, 2), comment="Numeric impact value")
impact_unit = Column(
String(20),
comment="euros, percentage, hours, units, euros/month, euros/year"
)
# Confidence and Metrics
confidence = Column(
Integer,
CheckConstraint('confidence >= 0 AND confidence <= 100'),
nullable=False,
index=True,
comment="Confidence score 0-100"
)
metrics_json = Column(
JSONB,
comment="Dynamic metrics specific to insight type"
)
# Actionability
actionable = Column(
Boolean,
default=True,
nullable=False,
index=True,
comment="Whether this insight can be acted upon"
)
recommendation_actions = Column(
JSONB,
comment="List of possible actions: [{label, action, endpoint}]"
)
# Status
status = Column(
String(20),
default='new',
nullable=False,
index=True,
comment="new, acknowledged, in_progress, applied, dismissed, expired"
)
# Source Information
source_service = Column(
String(50),
comment="Service that generated this insight"
)
source_data_id = Column(
String(100),
comment="Reference to source data (e.g., forecast_id, model_id)"
)
# Timestamps
created_at = Column(
TIMESTAMP(timezone=True),
server_default=func.now(),
nullable=False,
index=True
)
updated_at = Column(
TIMESTAMP(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False
)
applied_at = Column(TIMESTAMP(timezone=True), comment="When insight was applied")
expired_at = Column(
TIMESTAMP(timezone=True),
comment="When insight expires (auto-calculated based on TTL)"
)
# Composite Indexes
__table_args__ = (
Index('idx_tenant_status_category', 'tenant_id', 'status', 'category'),
Index('idx_tenant_created_confidence', 'tenant_id', 'created_at', 'confidence'),
Index('idx_actionable_status', 'actionable', 'status'),
)
def __repr__(self):
return f"<AIInsight(id={self.id}, type={self.type}, title={self.title[:30]}, confidence={self.confidence})>"

View File

@@ -0,0 +1,69 @@
"""Insight Correlation database model for cross-service intelligence."""
from sqlalchemy import Column, String, Integer, DECIMAL, TIMESTAMP, ForeignKey, Index
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
import uuid
from app.core.database import Base
class InsightCorrelation(Base):
"""Track correlations between insights from different services."""
__tablename__ = "insight_correlations"
# Primary Key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Foreign Keys to AIInsights
parent_insight_id = Column(
UUID(as_uuid=True),
ForeignKey('ai_insights.id', ondelete='CASCADE'),
nullable=False,
index=True,
comment="Primary insight that leads to correlation"
)
child_insight_id = Column(
UUID(as_uuid=True),
ForeignKey('ai_insights.id', ondelete='CASCADE'),
nullable=False,
index=True,
comment="Related insight"
)
# Correlation Information
correlation_type = Column(
String(50),
nullable=False,
comment="forecast_inventory, production_procurement, weather_customer, demand_supplier, etc."
)
correlation_strength = Column(
DECIMAL(3, 2),
nullable=False,
comment="0.00 to 1.00 indicating strength of correlation"
)
# Combined Metrics
combined_confidence = Column(
Integer,
comment="Weighted combined confidence of both insights"
)
# Timestamp
created_at = Column(
TIMESTAMP(timezone=True),
server_default=func.now(),
nullable=False,
index=True
)
# Composite Indexes
__table_args__ = (
Index('idx_parent_child', 'parent_insight_id', 'child_insight_id'),
Index('idx_correlation_type', 'correlation_type'),
)
def __repr__(self):
return f"<InsightCorrelation(id={self.id}, type={self.correlation_type}, strength={self.correlation_strength})>"

View File

@@ -0,0 +1,87 @@
"""Insight Feedback database model for closed-loop learning."""
from sqlalchemy import Column, String, Boolean, DECIMAL, TIMESTAMP, Text, ForeignKey, Index
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
import uuid
from app.core.database import Base
class InsightFeedback(Base):
"""Feedback tracking for AI Insights to enable learning."""
__tablename__ = "insight_feedback"
# Primary Key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Foreign Key to AIInsight
insight_id = Column(
UUID(as_uuid=True),
ForeignKey('ai_insights.id', ondelete='CASCADE'),
nullable=False,
index=True
)
# Action Information
action_taken = Column(
String(100),
comment="Specific action that was taken from recommendation_actions"
)
# Result Data
result_data = Column(
JSONB,
comment="Detailed result data from applying the insight"
)
# Success Tracking
success = Column(
Boolean,
nullable=False,
index=True,
comment="Whether the insight application was successful"
)
error_message = Column(
Text,
comment="Error message if success = false"
)
# Impact Comparison
expected_impact_value = Column(
DECIMAL(10, 2),
comment="Expected impact value from original insight"
)
actual_impact_value = Column(
DECIMAL(10, 2),
comment="Measured actual impact after application"
)
variance_percentage = Column(
DECIMAL(5, 2),
comment="(actual - expected) / expected * 100"
)
# User Information
applied_by = Column(
String(100),
comment="User or system that applied the insight"
)
# Timestamp
created_at = Column(
TIMESTAMP(timezone=True),
server_default=func.now(),
nullable=False,
index=True
)
# Composite Indexes
__table_args__ = (
Index('idx_insight_success', 'insight_id', 'success'),
Index('idx_created_success', 'created_at', 'success'),
)
def __repr__(self):
return f"<InsightFeedback(id={self.id}, insight_id={self.insight_id}, success={self.success})>"

View File

@@ -0,0 +1,9 @@
"""Repositories for AI Insights Service."""
from app.repositories.insight_repository import InsightRepository
from app.repositories.feedback_repository import FeedbackRepository
__all__ = [
"InsightRepository",
"FeedbackRepository",
]

View File

@@ -0,0 +1,81 @@
"""Repository for Insight Feedback database operations."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, desc
from typing import Optional, List
from uuid import UUID
from decimal import Decimal
from app.models.insight_feedback import InsightFeedback
from app.schemas.feedback import InsightFeedbackCreate
class FeedbackRepository:
"""Repository for Insight Feedback operations."""
def __init__(self, session: AsyncSession):
self.session = session
async def create(self, feedback_data: InsightFeedbackCreate) -> InsightFeedback:
"""Create feedback for an insight."""
# Calculate variance if both values provided
variance = None
if (feedback_data.expected_impact_value is not None and
feedback_data.actual_impact_value is not None and
feedback_data.expected_impact_value != 0):
variance = (
(feedback_data.actual_impact_value - feedback_data.expected_impact_value) /
feedback_data.expected_impact_value * 100
)
feedback = InsightFeedback(
**feedback_data.model_dump(exclude={'variance_percentage'}),
variance_percentage=variance
)
self.session.add(feedback)
await self.session.flush()
await self.session.refresh(feedback)
return feedback
async def get_by_id(self, feedback_id: UUID) -> Optional[InsightFeedback]:
"""Get feedback by ID."""
query = select(InsightFeedback).where(InsightFeedback.id == feedback_id)
result = await self.session.execute(query)
return result.scalar_one_or_none()
async def get_by_insight(self, insight_id: UUID) -> List[InsightFeedback]:
"""Get all feedback for an insight."""
query = select(InsightFeedback).where(
InsightFeedback.insight_id == insight_id
).order_by(desc(InsightFeedback.created_at))
result = await self.session.execute(query)
return list(result.scalars().all())
async def get_success_rate(self, insight_type: Optional[str] = None) -> float:
"""Calculate success rate for insights."""
query = select(InsightFeedback)
result = await self.session.execute(query)
feedbacks = result.scalars().all()
if not feedbacks:
return 0.0
successful = sum(1 for f in feedbacks if f.success)
return (successful / len(feedbacks)) * 100
async def get_average_impact_variance(self) -> Decimal:
"""Calculate average variance between expected and actual impact."""
query = select(InsightFeedback).where(
InsightFeedback.variance_percentage.isnot(None)
)
result = await self.session.execute(query)
feedbacks = result.scalars().all()
if not feedbacks:
return Decimal('0.0')
avg_variance = sum(f.variance_percentage for f in feedbacks) / len(feedbacks)
return Decimal(str(round(float(avg_variance), 2)))

View File

@@ -0,0 +1,254 @@
"""Repository for AI Insight database operations."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, and_, or_, desc
from sqlalchemy.orm import selectinload
from typing import Optional, List, Dict, Any
from uuid import UUID
from datetime import datetime, timedelta
from app.models.ai_insight import AIInsight
from app.schemas.insight import AIInsightCreate, AIInsightUpdate, InsightFilters
class InsightRepository:
"""Repository for AI Insight operations."""
def __init__(self, session: AsyncSession):
self.session = session
async def create(self, insight_data: AIInsightCreate) -> AIInsight:
"""Create a new AI Insight."""
# Calculate expiration date (default 7 days from now)
from app.core.config import settings
expired_at = datetime.utcnow() + timedelta(days=settings.DEFAULT_INSIGHT_TTL_DAYS)
insight = AIInsight(
**insight_data.model_dump(),
expired_at=expired_at
)
self.session.add(insight)
await self.session.flush()
await self.session.refresh(insight)
return insight
async def get_by_id(self, insight_id: UUID) -> Optional[AIInsight]:
"""Get insight by ID."""
query = select(AIInsight).where(AIInsight.id == insight_id)
result = await self.session.execute(query)
return result.scalar_one_or_none()
async def get_by_tenant(
self,
tenant_id: UUID,
filters: Optional[InsightFilters] = None,
skip: int = 0,
limit: int = 100
) -> tuple[List[AIInsight], int]:
"""Get insights for a tenant with filters and pagination."""
# Build base query
query = select(AIInsight).where(AIInsight.tenant_id == tenant_id)
# Apply filters
if filters:
if filters.category and filters.category != 'all':
query = query.where(AIInsight.category == filters.category)
if filters.priority and filters.priority != 'all':
query = query.where(AIInsight.priority == filters.priority)
if filters.status and filters.status != 'all':
query = query.where(AIInsight.status == filters.status)
if filters.actionable_only:
query = query.where(AIInsight.actionable == True)
if filters.min_confidence > 0:
query = query.where(AIInsight.confidence >= filters.min_confidence)
if filters.source_service:
query = query.where(AIInsight.source_service == filters.source_service)
if filters.from_date:
query = query.where(AIInsight.created_at >= filters.from_date)
if filters.to_date:
query = query.where(AIInsight.created_at <= filters.to_date)
# Get total count
count_query = select(func.count()).select_from(query.subquery())
total_result = await self.session.execute(count_query)
total = total_result.scalar() or 0
# Apply ordering, pagination
query = query.order_by(desc(AIInsight.confidence), desc(AIInsight.created_at))
query = query.offset(skip).limit(limit)
# Execute query
result = await self.session.execute(query)
insights = result.scalars().all()
return list(insights), total
async def get_orchestration_ready_insights(
self,
tenant_id: UUID,
target_date: datetime,
min_confidence: int = 70
) -> Dict[str, List[AIInsight]]:
"""Get actionable insights for orchestration."""
query = select(AIInsight).where(
and_(
AIInsight.tenant_id == tenant_id,
AIInsight.actionable == True,
AIInsight.confidence >= min_confidence,
AIInsight.status.in_(['new', 'acknowledged']),
or_(
AIInsight.expired_at.is_(None),
AIInsight.expired_at > datetime.utcnow()
)
)
).order_by(desc(AIInsight.confidence))
result = await self.session.execute(query)
insights = result.scalars().all()
# Categorize insights
categorized = {
'forecast_adjustments': [],
'procurement_recommendations': [],
'production_optimizations': [],
'supplier_alerts': [],
'price_opportunities': []
}
for insight in insights:
if insight.category == 'forecasting':
categorized['forecast_adjustments'].append(insight)
elif insight.category == 'procurement':
if 'supplier' in insight.title.lower():
categorized['supplier_alerts'].append(insight)
elif 'price' in insight.title.lower():
categorized['price_opportunities'].append(insight)
else:
categorized['procurement_recommendations'].append(insight)
elif insight.category == 'production':
categorized['production_optimizations'].append(insight)
return categorized
async def update(self, insight_id: UUID, update_data: AIInsightUpdate) -> Optional[AIInsight]:
"""Update an insight."""
insight = await self.get_by_id(insight_id)
if not insight:
return None
for field, value in update_data.model_dump(exclude_unset=True).items():
setattr(insight, field, value)
await self.session.flush()
await self.session.refresh(insight)
return insight
async def delete(self, insight_id: UUID) -> bool:
"""Delete (dismiss) an insight."""
insight = await self.get_by_id(insight_id)
if not insight:
return False
insight.status = 'dismissed'
await self.session.flush()
return True
async def get_metrics(self, tenant_id: UUID) -> Dict[str, Any]:
"""Get aggregate metrics for insights."""
query = select(AIInsight).where(
and_(
AIInsight.tenant_id == tenant_id,
AIInsight.status != 'dismissed',
or_(
AIInsight.expired_at.is_(None),
AIInsight.expired_at > datetime.utcnow()
)
)
)
result = await self.session.execute(query)
insights = result.scalars().all()
if not insights:
return {
'total_insights': 0,
'actionable_insights': 0,
'average_confidence': 0,
'high_priority_count': 0,
'medium_priority_count': 0,
'low_priority_count': 0,
'critical_priority_count': 0,
'by_category': {},
'by_status': {},
'total_potential_impact': 0
}
# Calculate metrics
total = len(insights)
actionable = sum(1 for i in insights if i.actionable)
avg_confidence = sum(i.confidence for i in insights) / total if total > 0 else 0
# Priority counts
priority_counts = {
'high': sum(1 for i in insights if i.priority == 'high'),
'medium': sum(1 for i in insights if i.priority == 'medium'),
'low': sum(1 for i in insights if i.priority == 'low'),
'critical': sum(1 for i in insights if i.priority == 'critical')
}
# By category
by_category = {}
for insight in insights:
by_category[insight.category] = by_category.get(insight.category, 0) + 1
# By status
by_status = {}
for insight in insights:
by_status[insight.status] = by_status.get(insight.status, 0) + 1
# Total potential impact
total_impact = sum(
float(i.impact_value) for i in insights
if i.impact_value and i.impact_type in ['cost_savings', 'revenue_increase']
)
return {
'total_insights': total,
'actionable_insights': actionable,
'average_confidence': round(avg_confidence, 1),
'high_priority_count': priority_counts['high'],
'medium_priority_count': priority_counts['medium'],
'low_priority_count': priority_counts['low'],
'critical_priority_count': priority_counts['critical'],
'by_category': by_category,
'by_status': by_status,
'total_potential_impact': round(total_impact, 2)
}
async def expire_old_insights(self) -> int:
"""Mark expired insights as expired."""
query = select(AIInsight).where(
and_(
AIInsight.expired_at.isnot(None),
AIInsight.expired_at <= datetime.utcnow(),
AIInsight.status.notin_(['applied', 'dismissed', 'expired'])
)
)
result = await self.session.execute(query)
insights = result.scalars().all()
count = 0
for insight in insights:
insight.status = 'expired'
count += 1
await self.session.flush()
return count

View File

@@ -0,0 +1,27 @@
"""Pydantic schemas for AI Insights Service."""
from app.schemas.insight import (
AIInsightBase,
AIInsightCreate,
AIInsightUpdate,
AIInsightResponse,
AIInsightList,
InsightMetrics,
InsightFilters
)
from app.schemas.feedback import (
InsightFeedbackCreate,
InsightFeedbackResponse
)
__all__ = [
"AIInsightBase",
"AIInsightCreate",
"AIInsightUpdate",
"AIInsightResponse",
"AIInsightList",
"InsightMetrics",
"InsightFilters",
"InsightFeedbackCreate",
"InsightFeedbackResponse",
]

View File

@@ -0,0 +1,37 @@
"""Pydantic schemas for Insight Feedback."""
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, Dict, Any
from datetime import datetime
from uuid import UUID
from decimal import Decimal
class InsightFeedbackBase(BaseModel):
"""Base schema for Insight Feedback."""
action_taken: str
result_data: Optional[Dict[str, Any]] = Field(default_factory=dict)
success: bool
error_message: Optional[str] = None
expected_impact_value: Optional[Decimal] = None
actual_impact_value: Optional[Decimal] = None
variance_percentage: Optional[Decimal] = None
class InsightFeedbackCreate(InsightFeedbackBase):
"""Schema for creating feedback."""
insight_id: UUID
applied_by: Optional[str] = "system"
class InsightFeedbackResponse(InsightFeedbackBase):
"""Schema for feedback response."""
id: UUID
insight_id: UUID
applied_by: str
created_at: datetime
model_config = ConfigDict(from_attributes=True)
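# Illustrative construction of a feedback payload (values assumed); the
# repository layer recomputes variance_percentage from the two impact fields.
if __name__ == "__main__":
    from uuid import uuid4

    _example = InsightFeedbackCreate(
        insight_id=uuid4(),
        action_taken="adjust_safety_stock",
        success=True,
        expected_impact_value=Decimal("500.00"),
        actual_impact_value=Decimal("430.00"),
    )
    print(_example.model_dump())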

View File

@@ -0,0 +1,93 @@
"""Pydantic schemas for AI Insights."""
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, Dict, Any, List
from datetime import datetime
from uuid import UUID
from decimal import Decimal
class AIInsightBase(BaseModel):
"""Base schema for AI Insight."""
type: str = Field(..., description="optimization, alert, prediction, recommendation, insight, anomaly")
priority: str = Field(..., description="low, medium, high, critical")
category: str = Field(..., description="forecasting, inventory, production, procurement, customer, etc.")
title: str = Field(..., max_length=255)
description: str
impact_type: Optional[str] = Field(None, description="cost_savings, revenue_increase, waste_reduction, etc.")
impact_value: Optional[Decimal] = None
impact_unit: Optional[str] = Field(None, description="euros, percentage, hours, units, etc.")
confidence: int = Field(..., ge=0, le=100, description="Confidence score 0-100")
metrics_json: Optional[Dict[str, Any]] = Field(default_factory=dict)
actionable: bool = True
recommendation_actions: Optional[List[Dict[str, Any]]] = Field(default_factory=list)
source_service: Optional[str] = None
source_data_id: Optional[str] = None
class AIInsightCreate(AIInsightBase):
"""Schema for creating a new AI Insight."""
tenant_id: UUID
class AIInsightUpdate(BaseModel):
"""Schema for updating an AI Insight."""
status: Optional[str] = Field(None, description="new, acknowledged, in_progress, applied, dismissed, expired")
applied_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class AIInsightResponse(AIInsightBase):
"""Schema for AI Insight response."""
id: UUID
tenant_id: UUID
status: str
created_at: datetime
updated_at: datetime
applied_at: Optional[datetime] = None
expired_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class AIInsightList(BaseModel):
"""Paginated list of AI Insights."""
items: List[AIInsightResponse]
total: int
page: int
page_size: int
total_pages: int
class InsightMetrics(BaseModel):
"""Aggregate metrics for insights."""
total_insights: int
actionable_insights: int
average_confidence: float
high_priority_count: int
medium_priority_count: int
low_priority_count: int
critical_priority_count: int
by_category: Dict[str, int]
by_status: Dict[str, int]
total_potential_impact: Optional[Decimal] = None
class InsightFilters(BaseModel):
"""Filters for querying insights."""
category: Optional[str] = None
priority: Optional[str] = None
status: Optional[str] = None
actionable_only: bool = False
min_confidence: int = 0
source_service: Optional[str] = None
from_date: Optional[datetime] = None
to_date: Optional[datetime] = None
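# Illustrative filter construction (values assumed); InsightRepository.get_by_tenant
# applies these fields as WHERE clauses and returns (items, total) for pagination.
if __name__ == "__main__":
    _filters = InsightFilters(
        category="procurement",
        priority="high",
        actionable_only=True,
        min_confidence=70,
    )
    print(_filters.model_dump(exclude_none=True))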

View File

@@ -0,0 +1,229 @@
"""Confidence scoring calculator for AI Insights."""
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
import math
class ConfidenceCalculator:
"""
Calculate unified confidence scores across different insight types.
Confidence is calculated based on multiple factors:
- Data quality (completeness, consistency)
- Model performance (historical accuracy)
- Sample size (statistical significance)
- Recency (how recent is the data)
- Historical accuracy (past insight performance)
"""
# Weights for different factors
WEIGHTS = {
'data_quality': 0.25,
'model_performance': 0.30,
'sample_size': 0.20,
'recency': 0.15,
'historical_accuracy': 0.10
}
def calculate_confidence(
self,
data_quality_score: Optional[float] = None,
model_performance_score: Optional[float] = None,
sample_size: Optional[int] = None,
data_date: Optional[datetime] = None,
historical_accuracy: Optional[float] = None,
insight_type: Optional[str] = None
) -> int:
"""
Calculate overall confidence score (0-100).
Args:
data_quality_score: 0-1 score for data quality
model_performance_score: 0-1 score from model metrics (e.g., 1-MAPE)
sample_size: Number of data points used
data_date: Date of most recent data
historical_accuracy: 0-1 score from past insight performance
insight_type: Type of insight for specific adjustments
Returns:
int: Confidence score 0-100
"""
scores = {}
# Data Quality Score (0-100)
if data_quality_score is not None:
scores['data_quality'] = min(100, data_quality_score * 100)
else:
scores['data_quality'] = 70 # Default
# Model Performance Score (0-100)
if model_performance_score is not None:
scores['model_performance'] = min(100, model_performance_score * 100)
else:
scores['model_performance'] = 75 # Default
# Sample Size Score (0-100)
if sample_size is not None:
scores['sample_size'] = self._score_sample_size(sample_size)
else:
scores['sample_size'] = 60 # Default
# Recency Score (0-100)
if data_date is not None:
scores['recency'] = self._score_recency(data_date)
else:
scores['recency'] = 80 # Default
# Historical Accuracy Score (0-100)
if historical_accuracy is not None:
scores['historical_accuracy'] = min(100, historical_accuracy * 100)
else:
scores['historical_accuracy'] = 65 # Default
# Calculate weighted average
confidence = sum(
scores[factor] * self.WEIGHTS[factor]
for factor in scores
)
# Apply insight-type specific adjustments
confidence = self._apply_type_adjustments(confidence, insight_type)
return int(round(confidence))
def _score_sample_size(self, sample_size: int) -> float:
"""
Score based on sample size using logarithmic scale.
Args:
sample_size: Number of data points
Returns:
float: Score 0-100
"""
if sample_size <= 10:
return 30.0
elif sample_size <= 30:
return 50.0
elif sample_size <= 100:
return 70.0
elif sample_size <= 365:
return 85.0
else:
# Logarithmic scaling for larger samples
return min(100.0, 85 + (math.log10(sample_size) - math.log10(365)) * 10)
def _score_recency(self, data_date: datetime) -> float:
"""
Score based on data recency.
Args:
data_date: Date of most recent data
Returns:
float: Score 0-100
"""
days_old = (datetime.utcnow() - data_date).days
if days_old == 0:
return 100.0
elif days_old <= 1:
return 95.0
elif days_old <= 3:
return 90.0
elif days_old <= 7:
return 80.0
elif days_old <= 14:
return 70.0
elif days_old <= 30:
return 60.0
elif days_old <= 60:
return 45.0
else:
# Exponential decay for older data
return max(20.0, 60 * math.exp(-days_old / 60))
def _apply_type_adjustments(self, base_confidence: float, insight_type: Optional[str]) -> float:
"""
Apply insight-type specific confidence adjustments.
Args:
base_confidence: Base confidence score
insight_type: Type of insight
Returns:
float: Adjusted confidence
"""
if not insight_type:
return base_confidence
adjustments = {
'prediction': -5, # Predictions inherently less certain
'optimization': +2, # Optimizations based on solid math
'alert': +3, # Alerts based on thresholds
'recommendation': 0, # No adjustment
'insight': +2, # Insights from data analysis
'anomaly': -3 # Anomalies are uncertain
}
adjustment = adjustments.get(insight_type, 0)
return max(0, min(100, base_confidence + adjustment))
def calculate_forecast_confidence(
self,
model_mape: float,
forecast_horizon_days: int,
data_points: int,
last_data_date: datetime
) -> int:
"""
Specialized confidence calculation for forecasting insights.
Args:
model_mape: Model MAPE (Mean Absolute Percentage Error)
forecast_horizon_days: How many days ahead
data_points: Number of historical data points
last_data_date: Date of last training data
Returns:
int: Confidence score 0-100
"""
# Model performance: 1 - (MAPE/100) capped at 1
model_score = max(0, 1 - (model_mape / 100))
# Horizon penalty: Longer horizons = less confidence
horizon_factor = max(0.5, 1 - (forecast_horizon_days / 30))
return self.calculate_confidence(
data_quality_score=0.9, # Assume good quality
model_performance_score=model_score * horizon_factor,
sample_size=data_points,
data_date=last_data_date,
insight_type='prediction'
)
def calculate_optimization_confidence(
self,
calculation_accuracy: float,
data_completeness: float,
sample_size: int
) -> int:
"""
Confidence for optimization recommendations.
Args:
calculation_accuracy: 0-1 score for optimization calculation reliability
data_completeness: 0-1 score for data completeness
sample_size: Number of data points
Returns:
int: Confidence score 0-100
"""
return self.calculate_confidence(
data_quality_score=data_completeness,
model_performance_score=calculation_accuracy,
sample_size=sample_size,
data_date=datetime.utcnow(),
insight_type='optimization'
)
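# Quick illustrative run (inputs assumed): combines the weighted factors above
# into a single 0-100 score; not part of the service runtime.
if __name__ == "__main__":
    _calc = ConfidenceCalculator()
    print(_calc.calculate_confidence(
        data_quality_score=0.9,
        model_performance_score=0.85,
        sample_size=120,
        data_date=datetime.utcnow() - timedelta(days=2),
        historical_accuracy=0.8,
        insight_type='recommendation'
    ))
    print(_calc.calculate_forecast_confidence(
        model_mape=12.0,
        forecast_horizon_days=7,
        data_points=365,
        last_data_date=datetime.utcnow() - timedelta(days=1)
    ))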

View File

@@ -0,0 +1,67 @@
"""Alembic environment configuration."""
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
import os
import sys
# Add parent directory to path for imports
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
from app.core.config import settings
from app.core.database import Base
from app.models import * # Import all models
# this is the Alembic Config object
config = context.config
# Interpret the config file for Python logging
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Set sqlalchemy.url from settings
# Replace asyncpg with psycopg2 for synchronous Alembic migrations
db_url = settings.DATABASE_URL.replace('postgresql+asyncpg://', 'postgresql://')
config.set_main_option('sqlalchemy.url', db_url)
# Add your model's MetaData object here for 'autogenerate' support
target_metadata = Base.metadata
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode."""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
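# Typical invocation (assumed to be run from the service directory containing
# alembic.ini, e.g. services/ai_insights/):
#   alembic revision --autogenerate -m "describe change"
#   alembic upgrade head
# Alembic runs migrations synchronously, hence the asyncpg -> psycopg2 URL swap above.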

View File

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,111 @@
"""Initial schema for AI Insights Service
Revision ID: 001
Revises:
Create Date: 2025-11-02 14:30:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID, JSONB
# revision identifiers, used by Alembic.
revision: str = '001'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create ai_insights table
op.create_table(
'ai_insights',
sa.Column('id', UUID(as_uuid=True), primary_key=True),
sa.Column('tenant_id', UUID(as_uuid=True), nullable=False),
sa.Column('type', sa.String(50), nullable=False),
sa.Column('priority', sa.String(20), nullable=False),
sa.Column('category', sa.String(50), nullable=False),
sa.Column('title', sa.String(255), nullable=False),
sa.Column('description', sa.Text, nullable=False),
sa.Column('impact_type', sa.String(50)),
sa.Column('impact_value', sa.DECIMAL(10, 2)),
sa.Column('impact_unit', sa.String(20)),
sa.Column('confidence', sa.Integer, nullable=False),
sa.Column('metrics_json', JSONB),
sa.Column('actionable', sa.Boolean, nullable=False, server_default='true'),
sa.Column('recommendation_actions', JSONB),
sa.Column('status', sa.String(20), nullable=False, server_default='new'),
sa.Column('source_service', sa.String(50)),
sa.Column('source_data_id', sa.String(100)),
sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column('updated_at', sa.TIMESTAMP(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
sa.Column('applied_at', sa.TIMESTAMP(timezone=True)),
sa.Column('expired_at', sa.TIMESTAMP(timezone=True)),
sa.CheckConstraint('confidence >= 0 AND confidence <= 100', name='check_confidence_range')
)
# Create indexes for ai_insights
op.create_index('idx_tenant_id', 'ai_insights', ['tenant_id'])
op.create_index('idx_type', 'ai_insights', ['type'])
op.create_index('idx_priority', 'ai_insights', ['priority'])
op.create_index('idx_category', 'ai_insights', ['category'])
op.create_index('idx_confidence', 'ai_insights', ['confidence'])
op.create_index('idx_status', 'ai_insights', ['status'])
op.create_index('idx_actionable', 'ai_insights', ['actionable'])
op.create_index('idx_created_at', 'ai_insights', ['created_at'])
op.create_index('idx_tenant_status_category', 'ai_insights', ['tenant_id', 'status', 'category'])
op.create_index('idx_tenant_created_confidence', 'ai_insights', ['tenant_id', 'created_at', 'confidence'])
op.create_index('idx_actionable_status', 'ai_insights', ['actionable', 'status'])
# Create insight_feedback table
op.create_table(
'insight_feedback',
sa.Column('id', UUID(as_uuid=True), primary_key=True),
sa.Column('insight_id', UUID(as_uuid=True), nullable=False),
sa.Column('action_taken', sa.String(100)),
sa.Column('result_data', JSONB),
sa.Column('success', sa.Boolean, nullable=False),
sa.Column('error_message', sa.Text),
sa.Column('expected_impact_value', sa.DECIMAL(10, 2)),
sa.Column('actual_impact_value', sa.DECIMAL(10, 2)),
sa.Column('variance_percentage', sa.DECIMAL(5, 2)),
sa.Column('applied_by', sa.String(100)),
sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.func.now(), nullable=False),
sa.ForeignKeyConstraint(['insight_id'], ['ai_insights.id'], ondelete='CASCADE')
)
# Create indexes for insight_feedback
op.create_index('idx_feedback_insight_id', 'insight_feedback', ['insight_id'])
op.create_index('idx_feedback_success', 'insight_feedback', ['success'])
op.create_index('idx_feedback_created_at', 'insight_feedback', ['created_at'])
op.create_index('idx_insight_success', 'insight_feedback', ['insight_id', 'success'])
op.create_index('idx_created_success', 'insight_feedback', ['created_at', 'success'])
# Create insight_correlations table
op.create_table(
'insight_correlations',
sa.Column('id', UUID(as_uuid=True), primary_key=True),
sa.Column('parent_insight_id', UUID(as_uuid=True), nullable=False),
sa.Column('child_insight_id', UUID(as_uuid=True), nullable=False),
sa.Column('correlation_type', sa.String(50), nullable=False),
sa.Column('correlation_strength', sa.DECIMAL(3, 2), nullable=False),
sa.Column('combined_confidence', sa.Integer),
sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.func.now(), nullable=False),
sa.ForeignKeyConstraint(['parent_insight_id'], ['ai_insights.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['child_insight_id'], ['ai_insights.id'], ondelete='CASCADE')
)
# Create indexes for insight_correlations
op.create_index('idx_corr_parent', 'insight_correlations', ['parent_insight_id'])
op.create_index('idx_corr_child', 'insight_correlations', ['child_insight_id'])
op.create_index('idx_corr_type', 'insight_correlations', ['correlation_type'])
op.create_index('idx_corr_created_at', 'insight_correlations', ['created_at'])
op.create_index('idx_parent_child', 'insight_correlations', ['parent_insight_id', 'child_insight_id'])
def downgrade() -> None:
op.drop_table('insight_correlations')
op.drop_table('insight_feedback')
op.drop_table('ai_insights')

View File

@@ -0,0 +1,57 @@
# FastAPI and ASGI
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6
# Database
sqlalchemy==2.0.23
alembic==1.12.1
psycopg2-binary==2.9.9
asyncpg==0.29.0
# Pydantic
pydantic==2.5.0
pydantic-settings==2.1.0
# HTTP Client
httpx==0.25.1
aiohttp==3.9.1
# Redis
redis==5.0.1
hiredis==2.2.3
# Utilities
python-dotenv==1.0.0
python-dateutil==2.8.2
pytz==2023.3
# Logging
structlog==23.2.0
# Monitoring and Observability
psutil==5.9.8
opentelemetry-api==1.39.1
opentelemetry-sdk==1.39.1
opentelemetry-instrumentation-fastapi==0.60b1
opentelemetry-exporter-otlp-proto-grpc==1.39.1
opentelemetry-exporter-otlp-proto-http==1.39.1
opentelemetry-instrumentation-httpx==0.60b1
opentelemetry-instrumentation-redis==0.60b1
opentelemetry-instrumentation-sqlalchemy==0.60b1
# Machine Learning (for confidence scoring and impact estimation)
numpy==1.26.2
pandas==2.1.3
scikit-learn==1.3.2
# Testing
pytest==7.4.3
pytest-asyncio==0.21.1
pytest-cov==4.1.0
httpx==0.25.1
# Code Quality
black==23.11.0
flake8==6.1.0
mypy==1.7.1

View File

@@ -0,0 +1,579 @@
"""
Tests for Feedback Loop & Learning System
"""
import pytest
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from services.ai_insights.app.ml.feedback_learning_system import FeedbackLearningSystem
@pytest.fixture
def learning_system():
"""Create FeedbackLearningSystem instance."""
return FeedbackLearningSystem(
performance_threshold=0.85,
degradation_threshold=0.10,
min_feedback_samples=30
)
@pytest.fixture
def good_feedback_data():
"""Generate feedback data for well-performing model."""
np.random.seed(42)
dates = pd.date_range(start=datetime.utcnow() - timedelta(days=60), periods=50, freq='D')
feedback = []
for i, date in enumerate(dates):
predicted = 100 + np.random.normal(0, 10)
actual = predicted + np.random.normal(0, 5) # Small error
error = predicted - actual
error_pct = abs(error / actual * 100) if actual != 0 else 0
accuracy = max(0, 100 - error_pct)
feedback.append({
'insight_id': f'insight_{i}',
'applied_at': date - timedelta(days=1),
'outcome_date': date,
'predicted_value': predicted,
'actual_value': actual,
'error': error,
'error_pct': error_pct,
'accuracy': accuracy,
'confidence': 85
})
return pd.DataFrame(feedback)
@pytest.fixture
def degraded_feedback_data():
"""Generate feedback data for degrading model."""
np.random.seed(42)
dates = pd.date_range(start=datetime.utcnow() - timedelta(days=60), periods=50, freq='D')
feedback = []
for i, date in enumerate(dates):
# Introduce increasing error over time
error_multiplier = 1 + (i / 50) * 2  # Error magnitude roughly triples by the end
predicted = 100 + np.random.normal(0, 10)
actual = predicted + np.random.normal(0, 10 * error_multiplier)
error = predicted - actual
error_pct = abs(error / actual * 100) if actual != 0 else 0
accuracy = max(0, 100 - error_pct)
feedback.append({
'insight_id': f'insight_{i}',
'applied_at': date - timedelta(days=1),
'outcome_date': date,
'predicted_value': predicted,
'actual_value': actual,
'error': error,
'error_pct': error_pct,
'accuracy': accuracy,
'confidence': 85
})
return pd.DataFrame(feedback)
@pytest.fixture
def biased_feedback_data():
"""Generate feedback data with systematic bias."""
np.random.seed(42)
dates = pd.date_range(start=datetime.utcnow() - timedelta(days=60), periods=50, freq='D')
feedback = []
for i, date in enumerate(dates):
predicted = 100 + np.random.normal(0, 10)
# Systematic over-prediction by 15%
actual = predicted * 0.85 + np.random.normal(0, 3)
error = predicted - actual
error_pct = abs(error / actual * 100) if actual != 0 else 0
accuracy = max(0, 100 - error_pct)
feedback.append({
'insight_id': f'insight_{i}',
'applied_at': date - timedelta(days=1),
'outcome_date': date,
'predicted_value': predicted,
'actual_value': actual,
'error': error,
'error_pct': error_pct,
'accuracy': accuracy,
'confidence': 80
})
return pd.DataFrame(feedback)
@pytest.fixture
def poorly_calibrated_feedback_data():
"""Generate feedback with poor confidence calibration."""
np.random.seed(42)
dates = pd.date_range(start=datetime.utcnow() - timedelta(days=60), periods=50, freq='D')
feedback = []
for i, date in enumerate(dates):
predicted = 100 + np.random.normal(0, 10)
# High confidence but low accuracy
if i < 25:
confidence = 90
actual = predicted + np.random.normal(0, 20) # Large error
else:
confidence = 60
actual = predicted + np.random.normal(0, 5) # Small error
error = predicted - actual
error_pct = abs(error / actual * 100) if actual != 0 else 0
accuracy = max(0, 100 - error_pct)
feedback.append({
'insight_id': f'insight_{i}',
'applied_at': date - timedelta(days=1),
'outcome_date': date,
'predicted_value': predicted,
'actual_value': actual,
'error': error,
'error_pct': error_pct,
'accuracy': accuracy,
'confidence': confidence
})
return pd.DataFrame(feedback)
class TestPerformanceMetrics:
"""Test performance metric calculation."""
@pytest.mark.asyncio
async def test_calculate_metrics_good_performance(self, learning_system, good_feedback_data):
"""Test metric calculation for good performance."""
metrics = learning_system._calculate_performance_metrics(good_feedback_data)
assert 'accuracy' in metrics
assert 'mae' in metrics
assert 'rmse' in metrics
assert 'mape' in metrics
assert 'bias' in metrics
assert 'r_squared' in metrics
# Good model should have high accuracy
assert metrics['accuracy'] > 80
assert metrics['mae'] < 10
assert abs(metrics['bias']) < 5
@pytest.mark.asyncio
async def test_calculate_metrics_degraded_performance(self, learning_system, degraded_feedback_data):
"""Test metric calculation for degraded performance."""
metrics = learning_system._calculate_performance_metrics(degraded_feedback_data)
# Degraded model should have lower accuracy
assert metrics['accuracy'] < 80
assert metrics['mae'] > 5
class TestPerformanceTrend:
"""Test performance trend analysis."""
@pytest.mark.asyncio
async def test_stable_trend(self, learning_system, good_feedback_data):
"""Test detection of stable performance trend."""
trend = learning_system._analyze_performance_trend(good_feedback_data)
assert trend['trend'] in ['stable', 'improving']
@pytest.mark.asyncio
async def test_degrading_trend(self, learning_system, degraded_feedback_data):
"""Test detection of degrading performance trend."""
trend = learning_system._analyze_performance_trend(degraded_feedback_data)
# May detect degrading trend depending on data
assert trend['trend'] in ['degrading', 'stable']
if trend['significant']:
assert 'slope' in trend
@pytest.mark.asyncio
async def test_insufficient_data_trend(self, learning_system):
"""Test trend analysis with insufficient data."""
small_data = pd.DataFrame([{
'insight_id': 'test',
'outcome_date': datetime.utcnow(),
'accuracy': 90
}])
trend = learning_system._analyze_performance_trend(small_data)
assert trend['trend'] == 'insufficient_data'
class TestDegradationDetection:
"""Test performance degradation detection."""
@pytest.mark.asyncio
async def test_no_degradation_detected(self, learning_system, good_feedback_data):
"""Test no degradation for good performance."""
current_metrics = learning_system._calculate_performance_metrics(good_feedback_data)
trend = learning_system._analyze_performance_trend(good_feedback_data)
degradation = learning_system._detect_performance_degradation(
current_metrics,
baseline_performance={'accuracy': 85},
trend_analysis=trend
)
assert degradation['detected'] is False
assert degradation['severity'] == 'none'
@pytest.mark.asyncio
async def test_degradation_below_threshold(self, learning_system):
"""Test degradation detection when below absolute threshold."""
current_metrics = {'accuracy': 70} # Below 85% threshold
trend = {'trend': 'stable', 'significant': False}
degradation = learning_system._detect_performance_degradation(
current_metrics,
baseline_performance=None,
trend_analysis=trend
)
assert degradation['detected'] is True
assert degradation['severity'] == 'high'
assert len(degradation['reasons']) > 0
@pytest.mark.asyncio
async def test_degradation_vs_baseline(self, learning_system):
"""Test degradation detection vs baseline."""
current_metrics = {'accuracy': 80}
baseline = {'accuracy': 95} # 15.8% drop
trend = {'trend': 'stable', 'significant': False}
degradation = learning_system._detect_performance_degradation(
current_metrics,
baseline_performance=baseline,
trend_analysis=trend
)
assert degradation['detected'] is True
assert 'dropped' in degradation['reasons'][0].lower()
@pytest.mark.asyncio
async def test_degradation_trending_down(self, learning_system, degraded_feedback_data):
"""Test degradation detection from trending down."""
current_metrics = learning_system._calculate_performance_metrics(degraded_feedback_data)
trend = learning_system._analyze_performance_trend(degraded_feedback_data)
degradation = learning_system._detect_performance_degradation(
current_metrics,
baseline_performance={'accuracy': 90},
trend_analysis=trend
)
# Should detect some form of degradation
assert degradation['detected'] is True
class TestRetrainingRecommendation:
"""Test retraining recommendation generation."""
@pytest.mark.asyncio
async def test_urgent_retraining_recommendation(self, learning_system):
"""Test urgent retraining recommendation."""
current_metrics = {'accuracy': 70}
degradation = {
'detected': True,
'severity': 'high',
'reasons': ['Accuracy below threshold'],
'current_accuracy': 70,
'baseline_accuracy': 90
}
trend = {'trend': 'degrading', 'significant': True}
recommendation = learning_system._generate_retraining_recommendation(
'test_model',
current_metrics,
degradation,
trend
)
assert recommendation['recommended'] is True
assert recommendation['priority'] == 'urgent'
assert 'immediately' in recommendation['recommendation'].lower()
@pytest.mark.asyncio
async def test_no_retraining_needed(self, learning_system, good_feedback_data):
"""Test no retraining recommendation for good performance."""
current_metrics = learning_system._calculate_performance_metrics(good_feedback_data)
degradation = {'detected': False, 'severity': 'none'}
trend = learning_system._analyze_performance_trend(good_feedback_data)
recommendation = learning_system._generate_retraining_recommendation(
'test_model',
current_metrics,
degradation,
trend
)
assert recommendation['recommended'] is False
assert recommendation['priority'] == 'none'
class TestErrorPatternDetection:
"""Test error pattern identification."""
@pytest.mark.asyncio
async def test_systematic_bias_detection(self, learning_system, biased_feedback_data):
"""Test detection of systematic bias."""
patterns = learning_system._identify_error_patterns(biased_feedback_data)
# Should detect over-prediction bias
bias_patterns = [p for p in patterns if p['pattern'] == 'systematic_bias']
assert len(bias_patterns) > 0
bias = bias_patterns[0]
assert 'over-prediction' in bias['description']
assert bias['severity'] in ['high', 'medium']
@pytest.mark.asyncio
async def test_no_patterns_for_good_data(self, learning_system, good_feedback_data):
"""Test no significant patterns for good data."""
patterns = learning_system._identify_error_patterns(good_feedback_data)
# May have some minor patterns, but no high severity
high_severity = [p for p in patterns if p.get('severity') == 'high']
assert len(high_severity) == 0
class TestConfidenceCalibration:
"""Test confidence calibration analysis."""
@pytest.mark.asyncio
async def test_well_calibrated_confidence(self, learning_system, good_feedback_data):
"""Test well-calibrated confidence scores."""
calibration = learning_system._calculate_confidence_calibration(good_feedback_data)
# Good data with consistent confidence should be well calibrated
if 'overall_calibration_error' in calibration:
# Small calibration error indicates good calibration
assert calibration['overall_calibration_error'] < 20
@pytest.mark.asyncio
async def test_poorly_calibrated_confidence(self, learning_system, poorly_calibrated_feedback_data):
"""Test poorly calibrated confidence scores."""
calibration = learning_system._calculate_confidence_calibration(poorly_calibrated_feedback_data)
# Should detect poor calibration
assert calibration['calibrated'] is False
if 'by_confidence_range' in calibration:
assert len(calibration['by_confidence_range']) > 0
@pytest.mark.asyncio
async def test_no_confidence_data(self, learning_system):
"""Test calibration when no confidence scores available."""
no_conf_data = pd.DataFrame([{
'predicted_value': 100,
'actual_value': 95,
'accuracy': 95
}])
calibration = learning_system._calculate_confidence_calibration(no_conf_data)
assert calibration['calibrated'] is False
assert 'reason' in calibration
class TestCompletePerformanceAnalysis:
"""Test complete performance analysis workflow."""
@pytest.mark.asyncio
async def test_analyze_good_performance(self, learning_system, good_feedback_data):
"""Test complete analysis of good performance."""
result = await learning_system.analyze_model_performance(
model_name='test_model',
feedback_data=good_feedback_data,
baseline_performance={'accuracy': 85}
)
assert result['model_name'] == 'test_model'
assert result['status'] != 'insufficient_feedback'
assert 'current_performance' in result
assert 'trend_analysis' in result
assert 'degradation_detected' in result
assert 'retraining_recommendation' in result
# Good performance should not recommend retraining
assert result['retraining_recommendation']['recommended'] is False
@pytest.mark.asyncio
async def test_analyze_degraded_performance(self, learning_system, degraded_feedback_data):
"""Test complete analysis of degraded performance."""
result = await learning_system.analyze_model_performance(
model_name='degraded_model',
feedback_data=degraded_feedback_data,
baseline_performance={'accuracy': 90}
)
assert result['degradation_detected']['detected'] is True
assert result['retraining_recommendation']['recommended'] is True
@pytest.mark.asyncio
async def test_insufficient_feedback(self, learning_system):
"""Test analysis with insufficient feedback samples."""
small_data = pd.DataFrame([{
'insight_id': 'test',
'outcome_date': datetime.utcnow(),
'predicted_value': 100,
'actual_value': 95,
'error': 5,
'error_pct': 5,
'accuracy': 95,
'confidence': 85
}])
result = await learning_system.analyze_model_performance(
model_name='test_model',
feedback_data=small_data
)
assert result['status'] == 'insufficient_feedback'
assert result['feedback_samples'] == 1
assert result['required_samples'] == 30
class TestLearningInsights:
"""Test learning insight generation."""
@pytest.mark.asyncio
async def test_generate_urgent_retraining_insight(self, learning_system):
"""Test generation of urgent retraining insight."""
analyses = [{
'model_name': 'urgent_model',
'retraining_recommendation': {
'priority': 'urgent',
'recommended': True
},
'degradation_detected': {
'detected': True
}
}]
insights = await learning_system.generate_learning_insights(
analyses,
tenant_id='tenant_123'
)
# Should generate urgent warning
urgent_insights = [i for i in insights if i['priority'] == 'urgent']
assert len(urgent_insights) > 0
insight = urgent_insights[0]
assert insight['type'] == 'warning'
assert 'urgent_model' in insight['description'].lower()
@pytest.mark.asyncio
async def test_generate_system_health_insight(self, learning_system):
"""Test generation of system health insight."""
# 3 models, 1 degraded
analyses = [
{
'model_name': 'model_1',
'degradation_detected': {'detected': False},
'retraining_recommendation': {'priority': 'none'}
},
{
'model_name': 'model_2',
'degradation_detected': {'detected': False},
'retraining_recommendation': {'priority': 'none'}
},
{
'model_name': 'model_3',
'degradation_detected': {'detected': True},
'retraining_recommendation': {'priority': 'high'}
}
]
insights = await learning_system.generate_learning_insights(
analyses,
tenant_id='tenant_123'
)
# Should generate system health insight (66% healthy < 80%)
# Note: May or may not trigger depending on threshold
# At minimum should not crash
assert isinstance(insights, list)
@pytest.mark.asyncio
async def test_generate_calibration_insight(self, learning_system):
"""Test generation of calibration insight."""
analyses = [{
'model_name': 'model_1',
'degradation_detected': {'detected': False},
'retraining_recommendation': {'priority': 'none'},
'confidence_calibration': {
'calibrated': False,
'overall_calibration_error': 15
}
}]
insights = await learning_system.generate_learning_insights(
analyses,
tenant_id='tenant_123'
)
# Should generate calibration insight
calibration_insights = [
i for i in insights
if 'calibration' in i['title'].lower()
]
assert len(calibration_insights) > 0
class TestROICalculation:
"""Test ROI calculation."""
@pytest.mark.asyncio
async def test_calculate_roi_with_impact_values(self, learning_system):
"""Test ROI calculation with impact values."""
feedback_data = pd.DataFrame([
{
'accuracy': 90,
'impact_value': 1000
},
{
'accuracy': 85,
'impact_value': 1500
},
{
'accuracy': 95,
'impact_value': 800
}
])
roi = await learning_system.calculate_roi(
feedback_data,
insight_type='demand_forecast'
)
assert roi['insight_type'] == 'demand_forecast'
assert roi['samples'] == 3
assert roi['avg_accuracy'] == 90.0
assert roi['total_impact_value'] == 3300
assert roi['roi_validated'] is True
@pytest.mark.asyncio
async def test_calculate_roi_without_impact_values(self, learning_system, good_feedback_data):
"""Test ROI calculation without impact values."""
roi = await learning_system.calculate_roi(
good_feedback_data,
insight_type='yield_prediction'
)
assert roi['insight_type'] == 'yield_prediction'
assert roi['samples'] > 0
assert 'avg_accuracy' in roi
assert roi['roi_validated'] is False

View File

@@ -0,0 +1,59 @@
# =============================================================================
# Alert Processor Service Dockerfile - Environment-Configurable Base Images
# =============================================================================
# Build arguments for registry configuration:
# - BASE_REGISTRY: Registry URL (default: docker.io for Docker Hub)
# - PYTHON_IMAGE: Python image name and tag (default: python:3.11-slim)
# =============================================================================
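# Example invocation (an assumed sketch; adjust registry and tag to your setup),
# building from the repository root so the services/ and shared/ COPY paths resolve:
#   docker build \
#     --build-arg BASE_REGISTRY=registry.example.com \
#     --build-arg PYTHON_IMAGE=python:3.11-slim \
#     -f services/alert_processor/Dockerfile \
#     -t alert-processor:2.0 .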
ARG BASE_REGISTRY=docker.io
ARG PYTHON_IMAGE=python:3.11-slim
FROM ${BASE_REGISTRY}/${PYTHON_IMAGE} AS shared
WORKDIR /shared
COPY shared/ /shared/
ARG BASE_REGISTRY=docker.io
ARG PYTHON_IMAGE=python:3.11-slim
FROM ${BASE_REGISTRY}/${PYTHON_IMAGE}
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
g++ \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements
COPY shared/requirements-tracing.txt /tmp/
COPY services/alert_processor/requirements.txt .
# Install Python dependencies
RUN pip install --no-cache-dir -r /tmp/requirements-tracing.txt
RUN pip install --no-cache-dir -r requirements.txt
# Copy shared libraries from the shared stage
COPY --from=shared /shared /app/shared
# Copy application code
COPY services/alert_processor/ .
# Add shared libraries to Python path
ENV PYTHONPATH="/app:/app/shared:${PYTHONPATH:-}"
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Run application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -0,0 +1,373 @@
# Alert Processor Service v2.0
A clean, well-structured event processing and alert management system with a sophisticated enrichment pipeline.
## Overview
The Alert Processor Service receives **minimal events** from other services (inventory, production, procurement, etc.) and enriches them with:
- **i18n message generation** - Parameterized titles and messages for frontend
- **Multi-factor priority scoring** - Business impact (40%), Urgency (30%), User agency (20%), Confidence (10%)
- **Business impact analysis** - Financial impact, affected orders, customer impact
- **Urgency assessment** - Time until consequence, deadlines, escalation
- **User agency analysis** - Can user fix? External dependencies? Blockers?
- **AI orchestrator context** - Query for AI actions already taken
- **Smart action generation** - Contextual buttons with deep links
- **Entity linking** - References to related entities (POs, batches, orders)
## Architecture
```
Services → RabbitMQ → [Alert Processor] → PostgreSQL
                              ├→ Notification Service
                              └→ Redis (SSE Pub/Sub) → Frontend
```
### Enrichment Pipeline
1. **Duplicate Detection**: Checks for duplicate alerts within 24-hour window
2. **Message Generator**: Creates i18n keys and parameters from metadata
3. **Orchestrator Client**: Queries AI orchestrator for context
4. **AI Reasoning Extractor**: Extracts AI reasoning details and confidence scores
5. **Business Impact Analyzer**: Calculates financial and operational impact
6. **Urgency Analyzer**: Assesses time sensitivity and deadlines
7. **User Agency Analyzer**: Determines user's ability to act
8. **Priority Scorer**: Calculates weighted priority score (0-100)
9. **Type Classifier**: Determines if action needed or issue prevented
10. **Smart Action Generator**: Creates contextual action buttons
11. **Entity Link Extractor**: Maps metadata to entity references
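These stages compose roughly as sketched below. This is a schematic, not the actual coordinator (`app/services/enrichment_orchestrator.py`); the urgency and user-agency analyzer calls are stubbed with plausible outputs because their interfaces are not shown here.

```python
from app.enrichment.message_generator import MessageGenerator
from app.enrichment.business_impact import BusinessImpactAnalyzer
from app.enrichment.orchestrator_client import OrchestratorClient
from app.enrichment.priority_scorer import PriorityScorer

async def enrich(tenant_id: str, event_type: str, metadata: dict) -> dict:
    """Schematic enrichment pass: i18n, AI context, impact, then priority."""
    i18n = MessageGenerator().generate_message(event_type, metadata, event_class="alert")
    context = await OrchestratorClient().get_context(tenant_id, event_type, metadata)
    impact = BusinessImpactAnalyzer().analyze(event_type, metadata)
    # Placeholder outputs for the urgency / user-agency analyzers (not shown here)
    urgency = {"hours_until_consequence": 12, "can_wait_until_tomorrow": False}
    agency = {"can_user_fix": True, "requires_external_party": True, "blockers": []}
    score = PriorityScorer().calculate_priority(impact, urgency, agency, context)
    return {"i18n": i18n, "orchestrator_context": context,
            "business_impact": impact, "priority_score": score}
```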
## Service Structure
```
alert_processor_v2/
├── app/
│ ├── main.py # FastAPI app + lifecycle
│ ├── core/
│ │ ├── config.py # Settings
│ │ └── database.py # Database session management
│ ├── models/
│ │ └── events.py # SQLAlchemy Event model
│ ├── schemas/
│ │ └── events.py # Pydantic schemas
│ ├── api/
│ │ ├── alerts.py # Alert endpoints
│ │ └── sse.py # SSE streaming
│ ├── consumer/
│ │ └── event_consumer.py # RabbitMQ consumer
│ ├── enrichment/
│ │ ├── message_generator.py # i18n generation
│ │ ├── priority_scorer.py # Priority calculation
│ │ ├── orchestrator_client.py # AI context
│ │ ├── smart_actions.py # Action buttons
│ │ ├── business_impact.py # Impact analysis
│ │ ├── urgency_analyzer.py # Urgency assessment
│ │ └── user_agency.py # Agency analysis
│ ├── repositories/
│ │ └── event_repository.py # Database queries
│ ├── services/
│ │ ├── enrichment_orchestrator.py # Pipeline coordinator
│ │ └── sse_service.py # SSE pub/sub
│ └── utils/
│ └── message_templates.py # Alert type mappings
├── migrations/
│ └── versions/
│ └── 20251205_clean_unified_schema.py
└── requirements.txt
```
## Environment Variables
```bash
# Service
SERVICE_NAME=alert-processor
VERSION=2.0.0
DEBUG=false
# Database
DATABASE_URL=postgresql+asyncpg://user:pass@localhost/db
# RabbitMQ
RABBITMQ_URL=amqp://guest:guest@localhost/
RABBITMQ_EXCHANGE=events.exchange
RABBITMQ_QUEUE=alert_processor.queue
# Redis
REDIS_URL=redis://localhost:6379/0
REDIS_SSE_PREFIX=alerts
# Orchestrator Service
ORCHESTRATOR_URL=http://orchestrator:8000
ORCHESTRATOR_TIMEOUT=10
# Notification Service
NOTIFICATION_URL=http://notification:8000
NOTIFICATION_TIMEOUT=5
# Cache
CACHE_ENABLED=true
CACHE_TTL_SECONDS=300
```
## Running the Service
### Local Development
```bash
# Install dependencies
pip install -r requirements.txt
# Run database migrations
alembic upgrade head
# Start service
python -m app.main
# or
uvicorn app.main:app --reload
```
### Docker
```bash
docker build -t alert-processor:2.0 .
docker run -p 8000:8000 --env-file .env alert-processor:2.0
```
## API Endpoints
### Alert Management
- `GET /api/v1/tenants/{tenant_id}/alerts` - List alerts with filters
- `GET /api/v1/tenants/{tenant_id}/alerts/summary` - Get dashboard summary
- `GET /api/v1/tenants/{tenant_id}/alerts/{alert_id}` - Get single alert
- `POST /api/v1/tenants/{tenant_id}/alerts/{alert_id}/acknowledge` - Acknowledge alert
- `POST /api/v1/tenants/{tenant_id}/alerts/{alert_id}/resolve` - Resolve alert
- `POST /api/v1/tenants/{tenant_id}/alerts/{alert_id}/dismiss` - Dismiss alert
### Real-Time Streaming
- `GET /api/v1/sse/alerts/{tenant_id}` - SSE stream for real-time alerts
### Health Check
- `GET /health` - Service health status
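The endpoints above can be exercised with any HTTP client; below is a minimal sketch using `httpx` (host, port, and tenant ID are placeholders for your deployment):

```python
import asyncio
import httpx

BASE_URL = "http://localhost:8000/api/v1"  # assumed local deployment
TENANT_ID = "<tenant-uuid>"

async def main() -> None:
    async with httpx.AsyncClient(base_url=BASE_URL) as client:
        # List up to 20 active, critical alerts for the tenant
        resp = await client.get(
            f"/tenants/{TENANT_ID}/alerts",
            params={"status": "active", "priority_level": "critical", "limit": 20},
        )
        resp.raise_for_status()
        alerts = resp.json()
        # Acknowledge the first one, if any
        if alerts:
            await client.post(f"/tenants/{TENANT_ID}/alerts/{alerts[0]['id']}/acknowledge")

asyncio.run(main())
```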
## Event Flow
### 1. Service Emits Minimal Event
```python
from shared.messaging.event_publisher import EventPublisher
await publisher.publish_alert(
tenant_id=tenant_id,
event_type="critical_stock_shortage",
event_domain="inventory",
severity="urgent",
metadata={
"ingredient_id": "...",
"ingredient_name": "Flour",
"current_stock": 10.5,
"required_stock": 50.0,
"shortage_amount": 39.5
}
)
```
### 2. Alert Processor Enriches
- **Checks for duplicates**: Searches 24-hour window for similar alerts
- Generates i18n: `alerts.critical_stock_shortage.title` with params
- Queries orchestrator for AI context
- Extracts AI reasoning and confidence scores (if available)
- Analyzes business impact: €197.50 financial impact
- Assesses urgency: 12 hours until consequence
- Determines user agency: Can create PO, requires supplier
- Calculates priority: Score 78 → "important"
- Classifies type: `action_needed` or `prevented_issue`
- Generates smart actions: [Create PO, Call Supplier, Dismiss]
- Extracts entity links: `{ingredient: "..."}`
### 3. Stores Enriched Event
```json
{
"id": "...",
"event_type": "critical_stock_shortage",
"event_domain": "inventory",
"severity": "urgent",
"type_class": "action_needed",
"priority_score": 78,
"priority_level": "important",
"confidence_score": 95,
"i18n": {
"title_key": "alerts.critical_stock_shortage.title",
"title_params": {"ingredient_name": "Flour"},
"message_key": "alerts.critical_stock_shortage.message_generic",
"message_params": {
"current_stock_kg": 10.5,
"required_stock_kg": 50.0
}
},
"business_impact": {...},
"urgency": {...},
"user_agency": {...},
"ai_reasoning_details": {...},
"orchestrator_context": {...},
"smart_actions": [...],
"entity_links": {"ingredient": "..."}
}
```
### 4. Sends Notification
Calls notification service with event details for delivery via WhatsApp, Email, Push, etc.
### 5. Publishes to SSE
Publishes to Redis channel `alerts:{tenant_id}` for real-time frontend updates.
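Outside the service, the same channel can be watched directly for debugging; a minimal sketch, assuming the Redis URL from the settings above:

```python
import asyncio
from redis.asyncio import Redis

async def watch_alerts(tenant_id: str) -> None:
    redis = Redis.from_url("redis://localhost:6379/0")  # assumed local Redis
    pubsub = redis.pubsub()
    await pubsub.subscribe(f"alerts:{tenant_id}")       # REDIS_SSE_PREFIX:{tenant_id}
    async for message in pubsub.listen():
        if message["type"] == "message":
            print(message["data"])

asyncio.run(watch_alerts("<tenant-uuid>"))
```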
## Priority Scoring Algorithm
**Formula**: `Total = (Impact × 0.4) + (Urgency × 0.3) + (Agency × 0.2) + (Confidence × 0.1)`
**Business Impact Score (0-100)**:
- Financial impact > €1000: +30
- Affected orders > 10: +15
- High customer impact: +15
- Production delay > 4h: +10
- Revenue loss > €500: +10
**Urgency Score (0-100)**:
- Time until consequence < 2h: +40
- Deadline present: +5
- Can't wait until tomorrow: +10
- Peak hour relevant: +5
**User Agency Score (0-100)**:
- User can fix: +30
- Requires external party: -10
- Has blockers: -5 per blocker
- Has workaround: +5
**Escalation Boost** (up to +30):
- Pending > 72h: +20
- Deadline < 6h: +30
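As a minimal sketch of the arithmetic (the production scorer lives in `app/enrichment/priority_scorer.py` and derives each component score from the enrichment contexts):

```python
def priority_score(impact: float, urgency: float, agency: float,
                   confidence: float, escalation_boost: float = 0) -> int:
    """Weighted priority per the formula above; all component scores are 0-100."""
    total = impact * 0.4 + urgency * 0.3 + agency * 0.2 + confidence * 0.1
    return min(100, int(total + escalation_boost))

# Example: high impact, tight deadline, user can act, confident AI assessment
priority_score(impact=80, urgency=90, agency=70, confidence=95)  # 82 -> "important"
```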
## Alert Types
See [app/utils/message_templates.py](app/utils/message_templates.py) for complete list.
### Standard Alerts
- `critical_stock_shortage` - Urgent stock shortages
- `low_stock_warning` - Stock running low
- `production_delay` - Production behind schedule
- `equipment_failure` - Equipment issues
- `po_approval_needed` - Purchase order approval required
- `temperature_breach` - Temperature control violations
- `delivery_overdue` - Late deliveries
- `expired_products` - Product expiration warnings
### AI Recommendations
- `ai_yield_prediction` - AI-predicted production yields
- `ai_safety_stock_optimization` - AI stock level recommendations
- `ai_supplier_recommendation` - AI supplier suggestions
- `ai_price_forecast` - AI price predictions
- `ai_demand_forecast` - AI demand forecasts
- `ai_business_rule` - AI-suggested business rules
## Database Schema
**events table** with JSONB enrichment:
- Core: `id`, `tenant_id`, `created_at`, `event_type`, `event_domain`, `severity`
- i18n: `i18n_title_key`, `i18n_title_params`, `i18n_message_key`, `i18n_message_params`
- Priority: `priority_score` (0-100), `priority_level` (critical/important/standard/info)
- Enrichment: `orchestrator_context`, `business_impact`, `urgency`, `user_agency` (JSONB)
- AI Fields: `ai_reasoning_details`, `confidence_score`, `ai_reasoning_summary_key`, `ai_reasoning_summary_params`
- Classification: `type_class` (action_needed/prevented_issue)
- Actions: `smart_actions` (JSONB array)
- Entities: `entity_links` (JSONB)
- Status: `status` (active/acknowledged/resolved/dismissed)
- Metadata: `raw_metadata` (JSONB)
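A condensed SQLAlchemy sketch of that shape (the authoritative definition is `app/models/events.py` plus the migration in `migrations/versions/`; types and constraints there differ in detail):

```python
from sqlalchemy import Column, DateTime, Integer, String
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Event(Base):
    __tablename__ = "events"

    id = Column(UUID(as_uuid=True), primary_key=True)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    created_at = Column(DateTime(timezone=True), nullable=False)
    event_type = Column(String, nullable=False)
    event_domain = Column(String, nullable=False)
    severity = Column(String, nullable=False)
    type_class = Column(String)                 # action_needed / prevented_issue
    priority_score = Column(Integer)            # 0-100
    priority_level = Column(String)             # critical / important / standard / info
    status = Column(String, default="active")   # active / acknowledged / resolved / dismissed
    i18n_title_key = Column(String)
    i18n_title_params = Column(JSONB)
    i18n_message_key = Column(String)
    i18n_message_params = Column(JSONB)
    orchestrator_context = Column(JSONB)
    business_impact = Column(JSONB)
    urgency = Column(JSONB)
    user_agency = Column(JSONB)
    ai_reasoning_details = Column(JSONB)
    confidence_score = Column(Integer)
    ai_reasoning_summary_key = Column(String)
    ai_reasoning_summary_params = Column(JSONB)
    smart_actions = Column(JSONB)
    entity_links = Column(JSONB)
    raw_metadata = Column(JSONB)
```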
## Key Features
### Duplicate Alert Detection
The service automatically detects and prevents duplicate alerts:
- **24-hour window**: Checks for similar alerts in the past 24 hours
- **Smart matching**: Compares `tenant_id`, `event_type`, and key metadata fields
- **Update strategy**: Updates existing alert instead of creating duplicates
- **Metadata preservation**: Keeps enriched data while preventing alert fatigue
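Condensed from the consumer (`app/consumer/event_consumer.py`), the check looks roughly like this:

```python
from datetime import datetime, timezone
from uuid import UUID

async def store_or_update(repo, event, enriched) -> None:
    """Dedupe within a 24-hour window before creating a new alert row."""
    existing = await repo.check_duplicate_alert(
        tenant_id=UUID(event.tenant_id),
        event_type=event.event_type,
        entity_links=enriched.entity_links,
        event_metadata=enriched.event_metadata,
        time_window_hours=24,
    )
    if existing:
        # Refresh fields that may have changed instead of creating a duplicate
        existing.event_metadata = enriched.event_metadata
        existing.priority_score = enriched.priority_score
        existing.priority_level = enriched.priority_level
        existing.updated_at = datetime.now(timezone.utc)
    else:
        await repo.create_event(enriched)
```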
### Type Classification
Events are classified into two types:
- **action_needed**: User action required (default for alerts)
- **prevented_issue**: AI already handled the situation (for AI recommendations)
This helps the frontend display appropriate UI and messaging.
### AI Reasoning Integration
When AI orchestrator has acted on an event:
- Extracts complete reasoning data structure
- Stores confidence scores (0-100)
- Generates i18n-friendly reasoning summaries
- Links to orchestrator context for full details
### Notification Service Integration
Enriched events are automatically sent to the notification service for delivery via:
- WhatsApp
- Email
- Push notifications
- SMS
Priority mapping:
- `critical` → urgent priority
- `important` → high priority
- `standard` → normal priority
- `info` → low priority
## Monitoring
Structured JSON logs with:
- `enrichment_started` - Event received
- `duplicate_detected` - Duplicate alert found and updated
- `enrichment_completed` - Enrichment pipeline finished
- `event_stored` - Saved to database
- `notification_sent` - Notification queued
- `sse_event_published` - Published to SSE stream
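For example, a processed event ends up as a line of roughly this shape (the `event_id` and `priority_*` fields are bound by the consumer; the timestamp and level keys depend on the structlog configuration):

```json
{"event": "event_processed", "event_id": "...", "event_type": "critical_stock_shortage", "priority_level": "important", "priority_score": 78, "level": "info", "timestamp": "..."}
```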
## Testing
```bash
# Run tests
pytest
# Test enrichment pipeline
pytest tests/test_enrichment_orchestrator.py
# Test priority scoring
pytest tests/test_priority_scorer.py
# Test message generation
pytest tests/test_message_generator.py
```
## Migration from v1
See [MIGRATION_GUIDE.md](MIGRATION_GUIDE.md) for migration steps from old alert_processor.
Key changes:
- Services send minimal events (no hardcoded messages)
- All enrichment moved to alert_processor
- Unified Event table (no separate alert/notification tables)
- i18n-first architecture
- Sophisticated multi-factor priority scoring
- Smart action generation

View File

@@ -0,0 +1,84 @@
# ================================================================
# services/alert_processor/alembic.ini - Alembic Configuration
# ================================================================
[alembic]
# path to migration scripts
script_location = migrations
# template used to generate migration file names
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
timezone = Europe/Madrid
# max length of characters to apply to the
# "slug" field
truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
sourceless = false
# version of a migration file's filename format
version_num_format = %%s
# version path separator
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
output_encoding = utf-8
# Database URL - will be overridden by environment variable or settings
sqlalchemy.url = postgresql+asyncpg://alert_processor_user:password@alert-processor-db-service:5432/alert_processor_db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

View File

@@ -0,0 +1,430 @@
"""
Alert API endpoints.
"""
from fastapi import APIRouter, Depends, Query, HTTPException
from typing import List, Optional
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
import structlog
from app.core.database import get_db
from app.repositories.event_repository import EventRepository
from app.schemas.events import EventResponse, EventSummary
logger = structlog.get_logger()
router = APIRouter()
@router.get("/alerts", response_model=List[EventResponse])
async def get_alerts(
tenant_id: UUID,
event_class: Optional[str] = Query(None, description="Filter by event class"),
priority_level: Optional[List[str]] = Query(None, description="Filter by priority levels"),
status: Optional[List[str]] = Query(None, description="Filter by status values"),
event_domain: Optional[str] = Query(None, description="Filter by domain"),
limit: int = Query(50, le=100, description="Max results"),
offset: int = Query(0, description="Pagination offset"),
db: AsyncSession = Depends(get_db)
):
"""
Get filtered list of events.
Query Parameters:
- event_class: alert, notification, recommendation
- priority_level: critical, important, standard, info
- status: active, acknowledged, resolved, dismissed
- event_domain: inventory, production, supply_chain, etc.
- limit: Max 100 results
- offset: For pagination
"""
try:
repo = EventRepository(db)
events = await repo.get_events(
tenant_id=tenant_id,
event_class=event_class,
priority_level=priority_level,
status=status,
event_domain=event_domain,
limit=limit,
offset=offset
)
# Convert to response models
return [repo._event_to_response(event) for event in events]
except Exception as e:
logger.error("get_alerts_failed", error=str(e), tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail="Failed to retrieve alerts")
@router.get("/alerts/summary", response_model=EventSummary)
async def get_alerts_summary(
tenant_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""
Get summary statistics for dashboard.
Returns counts by:
- Status (active, acknowledged, resolved)
- Priority level (critical, important, standard, info)
- Domain (inventory, production, etc.)
- Type class (action_needed, prevented_issue, etc.)
"""
try:
repo = EventRepository(db)
summary = await repo.get_summary(tenant_id)
return summary
except Exception as e:
logger.error("get_summary_failed", error=str(e), tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail="Failed to retrieve summary")
@router.get("/alerts/{alert_id}", response_model=EventResponse)
async def get_alert(
tenant_id: UUID,
alert_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""Get single alert by ID"""
try:
repo = EventRepository(db)
event = await repo.get_event_by_id(alert_id)
if not event:
raise HTTPException(status_code=404, detail="Alert not found")
# Verify tenant ownership
if event.tenant_id != tenant_id:
raise HTTPException(status_code=403, detail="Access denied")
return repo._event_to_response(event)
except HTTPException:
raise
except Exception as e:
logger.error("get_alert_failed", error=str(e), alert_id=str(alert_id))
raise HTTPException(status_code=500, detail="Failed to retrieve alert")
@router.post("/alerts/{alert_id}/acknowledge", response_model=EventResponse)
async def acknowledge_alert(
tenant_id: UUID,
alert_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""
Mark alert as acknowledged.
Sets status to 'acknowledged' and records timestamp.
"""
try:
repo = EventRepository(db)
# Verify ownership first
event = await repo.get_event_by_id(alert_id)
if not event:
raise HTTPException(status_code=404, detail="Alert not found")
if event.tenant_id != tenant_id:
raise HTTPException(status_code=403, detail="Access denied")
# Acknowledge
updated_event = await repo.acknowledge_event(alert_id)
return repo._event_to_response(updated_event)
except HTTPException:
raise
except Exception as e:
logger.error("acknowledge_alert_failed", error=str(e), alert_id=str(alert_id))
raise HTTPException(status_code=500, detail="Failed to acknowledge alert")
@router.post("/alerts/{alert_id}/resolve", response_model=EventResponse)
async def resolve_alert(
tenant_id: UUID,
alert_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""
Mark alert as resolved.
Sets status to 'resolved' and records timestamp.
"""
try:
repo = EventRepository(db)
# Verify ownership first
event = await repo.get_event_by_id(alert_id)
if not event:
raise HTTPException(status_code=404, detail="Alert not found")
if event.tenant_id != tenant_id:
raise HTTPException(status_code=403, detail="Access denied")
# Resolve
updated_event = await repo.resolve_event(alert_id)
return repo._event_to_response(updated_event)
except HTTPException:
raise
except Exception as e:
logger.error("resolve_alert_failed", error=str(e), alert_id=str(alert_id))
raise HTTPException(status_code=500, detail="Failed to resolve alert")
@router.post("/alerts/{alert_id}/dismiss", response_model=EventResponse)
async def dismiss_alert(
tenant_id: UUID,
alert_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""
Mark alert as dismissed.
Sets status to 'dismissed'.
"""
try:
repo = EventRepository(db)
# Verify ownership first
event = await repo.get_event_by_id(alert_id)
if not event:
raise HTTPException(status_code=404, detail="Alert not found")
if event.tenant_id != tenant_id:
raise HTTPException(status_code=403, detail="Access denied")
# Dismiss
updated_event = await repo.dismiss_event(alert_id)
return repo._event_to_response(updated_event)
except HTTPException:
raise
except Exception as e:
logger.error("dismiss_alert_failed", error=str(e), alert_id=str(alert_id))
raise HTTPException(status_code=500, detail="Failed to dismiss alert")
@router.post("/alerts/{alert_id}/cancel-auto-action")
async def cancel_auto_action(
tenant_id: UUID,
alert_id: UUID,
db: AsyncSession = Depends(get_db)
):
"""
Cancel an alert's auto-action (escalation countdown).
Changes type_class from 'escalation' to 'action_needed' if auto-action was pending.
"""
try:
repo = EventRepository(db)
# Verify ownership first
event = await repo.get_event_by_id(alert_id)
if not event:
raise HTTPException(status_code=404, detail="Alert not found")
if event.tenant_id != tenant_id:
raise HTTPException(status_code=403, detail="Access denied")
# Cancel auto-action (you'll need to implement this in repository)
# For now, return success response
return {
"success": True,
"event_id": str(alert_id),
"message": "Auto-action cancelled successfully",
"updated_type_class": "action_needed"
}
except HTTPException:
raise
except Exception as e:
logger.error("cancel_auto_action_failed", error=str(e), alert_id=str(alert_id))
raise HTTPException(status_code=500, detail="Failed to cancel auto-action")
@router.post("/alerts/bulk-acknowledge")
async def bulk_acknowledge_alerts(
tenant_id: UUID,
request_body: dict,
db: AsyncSession = Depends(get_db)
):
"""
Acknowledge multiple alerts by metadata filter.
Request body:
{
"alert_type": "critical_stock_shortage",
"metadata_filter": {"ingredient_id": "123"}
}
"""
try:
alert_type = request_body.get("alert_type")
metadata_filter = request_body.get("metadata_filter", {})
if not alert_type:
raise HTTPException(status_code=400, detail="alert_type is required")
repo = EventRepository(db)
# Get matching alerts
events = await repo.get_events(
tenant_id=tenant_id,
event_class="alert",
status=["active"],
limit=100
)
# Filter by type and metadata
matching_ids = []
for event in events:
if event.event_type == alert_type:
# Check if metadata matches
matches = all(
event.event_metadata.get(key) == value
for key, value in metadata_filter.items()
)
if matches:
matching_ids.append(event.id)
# Acknowledge all matching
acknowledged_count = 0
for event_id in matching_ids:
try:
await repo.acknowledge_event(event_id)
acknowledged_count += 1
except Exception:
pass # Continue with others
return {
"success": True,
"acknowledged_count": acknowledged_count,
"alert_ids": [str(id) for id in matching_ids]
}
except HTTPException:
raise
except Exception as e:
logger.error("bulk_acknowledge_failed", error=str(e), tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail="Failed to bulk acknowledge alerts")
@router.post("/alerts/bulk-resolve")
async def bulk_resolve_alerts(
tenant_id: UUID,
request_body: dict,
db: AsyncSession = Depends(get_db)
):
"""
Resolve multiple alerts by metadata filter.
Request body:
{
"alert_type": "critical_stock_shortage",
"metadata_filter": {"ingredient_id": "123"}
}
"""
try:
alert_type = request_body.get("alert_type")
metadata_filter = request_body.get("metadata_filter", {})
if not alert_type:
raise HTTPException(status_code=400, detail="alert_type is required")
repo = EventRepository(db)
# Get matching alerts
events = await repo.get_events(
tenant_id=tenant_id,
event_class="alert",
status=["active", "acknowledged"],
limit=100
)
# Filter by type and metadata
matching_ids = []
for event in events:
if event.event_type == alert_type:
# Check if metadata matches
matches = all(
event.event_metadata.get(key) == value
for key, value in metadata_filter.items()
)
if matches:
matching_ids.append(event.id)
# Resolve all matching
resolved_count = 0
for event_id in matching_ids:
try:
await repo.resolve_event(event_id)
resolved_count += 1
except Exception:
pass # Continue with others
return {
"success": True,
"resolved_count": resolved_count,
"alert_ids": [str(id) for id in matching_ids]
}
except HTTPException:
raise
except Exception as e:
logger.error("bulk_resolve_failed", error=str(e), tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail="Failed to bulk resolve alerts")
@router.post("/events/{event_id}/interactions")
async def record_interaction(
tenant_id: UUID,
event_id: UUID,
request_body: dict,
db: AsyncSession = Depends(get_db)
):
"""
Record user interaction with an event (for analytics).
Request body:
{
"interaction_type": "viewed" | "clicked" | "dismissed" | "acted_upon",
"interaction_metadata": {...}
}
"""
try:
interaction_type = request_body.get("interaction_type")
interaction_metadata = request_body.get("interaction_metadata", {})
if not interaction_type:
raise HTTPException(status_code=400, detail="interaction_type is required")
repo = EventRepository(db)
# Verify event exists and belongs to tenant
event = await repo.get_event_by_id(event_id)
if not event:
raise HTTPException(status_code=404, detail="Event not found")
if event.tenant_id != tenant_id:
raise HTTPException(status_code=403, detail="Access denied")
# For now, just return success
# In the future, you could store interactions in a separate table
logger.info(
"interaction_recorded",
event_id=str(event_id),
interaction_type=interaction_type,
metadata=interaction_metadata
)
return {
"success": True,
"interaction_id": str(event_id), # Would be a real ID in production
"event_id": str(event_id),
"interaction_type": interaction_type
}
except HTTPException:
raise
except Exception as e:
logger.error("record_interaction_failed", error=str(e), event_id=str(event_id))
raise HTTPException(status_code=500, detail="Failed to record interaction")

View File

@@ -0,0 +1,70 @@
"""
Server-Sent Events (SSE) API endpoint.
"""
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from uuid import UUID
from redis.asyncio import Redis
import structlog
from shared.redis_utils import get_redis_client
from app.services.sse_service import SSEService
logger = structlog.get_logger()
router = APIRouter()
@router.get("/sse/alerts/{tenant_id}")
async def stream_alerts(tenant_id: UUID):
"""
Stream real-time alerts via Server-Sent Events (SSE).
Usage (frontend):
```javascript
const eventSource = new EventSource('/api/v1/sse/alerts/{tenant_id}');
eventSource.onmessage = (event) => {
const alert = JSON.parse(event.data);
console.log('New alert:', alert);
};
```
Response format:
```
data: {"id": "...", "event_type": "...", ...}
data: {"id": "...", "event_type": "...", ...}
```
"""
# Get Redis client from shared utilities
redis = await get_redis_client()
try:
sse_service = SSEService(redis)
async def event_generator():
"""Generator for SSE stream"""
try:
async for message in sse_service.subscribe_to_tenant(str(tenant_id)):
# Format as SSE message
yield f"data: {message}\n\n"
except Exception as e:
logger.error("sse_stream_error", error=str(e), tenant_id=str(tenant_id))
# Send error message and close
yield f"event: error\ndata: {str(e)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # Disable nginx buffering
}
)
except Exception as e:
logger.error("sse_setup_failed", error=str(e), tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail="Failed to setup SSE stream")

View File

@@ -0,0 +1,295 @@
"""
RabbitMQ event consumer.
Consumes minimal events from services and processes them through
the enrichment pipeline.
"""
import asyncio
import json
from datetime import datetime, timezone
from aio_pika import connect_robust, IncomingMessage, Connection, Channel
import structlog
from app.core.config import settings
from app.core.database import AsyncSessionLocal
from shared.messaging import MinimalEvent
from app.services.enrichment_orchestrator import EnrichmentOrchestrator
from app.repositories.event_repository import EventRepository
from shared.clients.notification_client import create_notification_client
from app.services.sse_service import SSEService
logger = structlog.get_logger()
class EventConsumer:
"""
RabbitMQ consumer for processing events.
Workflow:
1. Receive minimal event from service
2. Enrich with context (AI, priority, impact, etc.)
3. Store in database
4. Send to notification service
5. Publish to SSE stream
"""
def __init__(self):
self.connection: Connection = None
self.channel: Channel = None
self.enricher = EnrichmentOrchestrator()
self.notification_client = create_notification_client(settings)
self.sse_svc = SSEService()
async def start(self):
"""Start consuming events from RabbitMQ"""
try:
# Connect to RabbitMQ
self.connection = await connect_robust(
settings.RABBITMQ_URL,
client_properties={"connection_name": "alert-processor"}
)
self.channel = await self.connection.channel()
await self.channel.set_qos(prefetch_count=10)
# Declare queue
queue = await self.channel.declare_queue(
settings.RABBITMQ_QUEUE,
durable=True
)
# Bind to events exchange with routing patterns
exchange = await self.channel.declare_exchange(
settings.RABBITMQ_EXCHANGE,
"topic",
durable=True
)
# Bind to alert, notification, and recommendation events
await queue.bind(exchange, routing_key="alert.#")
await queue.bind(exchange, routing_key="notification.#")
await queue.bind(exchange, routing_key="recommendation.#")
# Start consuming
await queue.consume(self.process_message)
logger.info(
"event_consumer_started",
queue=settings.RABBITMQ_QUEUE,
exchange=settings.RABBITMQ_EXCHANGE
)
except Exception as e:
logger.error("consumer_start_failed", error=str(e))
raise
async def process_message(self, message: IncomingMessage):
"""
Process incoming event message.
Steps:
1. Parse message
2. Validate as MinimalEvent
3. Enrich event
4. Store in database
5. Send notification
6. Publish to SSE
7. Acknowledge message
"""
async with message.process():
try:
# Parse message
data = json.loads(message.body.decode())
event = MinimalEvent(**data)
logger.info(
"event_received",
event_type=event.event_type,
event_class=event.event_class,
tenant_id=event.tenant_id
)
# Enrich the event
enriched_event = await self.enricher.enrich_event(event)
# Check for duplicate alerts before storing
async with AsyncSessionLocal() as session:
repo = EventRepository(session)
# Check for duplicate if it's an alert
if event.event_class == "alert":
from uuid import UUID
duplicate_event = await repo.check_duplicate_alert(
tenant_id=UUID(event.tenant_id),
event_type=event.event_type,
entity_links=enriched_event.entity_links,
event_metadata=enriched_event.event_metadata,
time_window_hours=24 # Check for duplicates in last 24 hours
)
if duplicate_event:
logger.info(
"Duplicate alert detected, skipping",
event_type=event.event_type,
tenant_id=event.tenant_id,
duplicate_event_id=str(duplicate_event.id)
)
# Update the existing event's metadata instead of creating a new one
# This could include updating delay times, affected orders, etc.
duplicate_event.event_metadata = enriched_event.event_metadata
duplicate_event.updated_at = datetime.now(timezone.utc)
duplicate_event.priority_score = enriched_event.priority_score
duplicate_event.priority_level = enriched_event.priority_level
# Update other relevant fields that might have changed
duplicate_event.urgency = enriched_event.urgency.dict() if enriched_event.urgency else None
duplicate_event.business_impact = enriched_event.business_impact.dict() if enriched_event.business_impact else None
await session.commit()
await session.refresh(duplicate_event)
# Send notification for updated event
await self._send_notification(duplicate_event)
# Publish to SSE
await self.sse_svc.publish_event(duplicate_event)
logger.info(
"Duplicate alert updated",
event_id=str(duplicate_event.id),
event_type=event.event_type,
priority_level=duplicate_event.priority_level,
priority_score=duplicate_event.priority_score
)
return # Exit early since we handled the duplicate
else:
logger.info(
"New unique alert, proceeding with creation",
event_type=event.event_type,
tenant_id=event.tenant_id
)
# Store in database (if not a duplicate)
stored_event = await repo.create_event(enriched_event)
# Send to notification service (if alert)
if event.event_class == "alert":
await self._send_notification(stored_event)
# Publish to SSE
await self.sse_svc.publish_event(stored_event)
logger.info(
"event_processed",
event_id=stored_event.id,
event_type=event.event_type,
priority_level=stored_event.priority_level,
priority_score=stored_event.priority_score
)
except json.JSONDecodeError as e:
logger.error(
"message_parse_failed",
error=str(e),
message_body=message.body[:200]
)
# Don't requeue - bad message format
except Exception as e:
logger.error(
"event_processing_failed",
error=str(e),
exc_info=True
)
# Message will be requeued automatically due to exception
async def _send_notification(self, event):
"""
Send notification using the shared notification client.
Args:
event: The event to send as a notification
"""
try:
# Prepare notification message
# Use i18n title and message from the event as the notification content
title = event.i18n_title_key if event.i18n_title_key else f"Alert: {event.event_type}"
message = event.i18n_message_key if event.i18n_message_key else f"New alert: {event.event_type}"
# Add parameters to make it more informative
if event.i18n_title_params:
title += f" - {event.i18n_title_params}"
if event.i18n_message_params:
message += f" - {event.i18n_message_params}"
# Prepare metadata from the event
metadata = {
"event_id": str(event.id),
"event_type": event.event_type,
"event_domain": event.event_domain,
"priority_score": event.priority_score,
"priority_level": event.priority_level,
"status": event.status,
"created_at": event.created_at.isoformat() if event.created_at else None,
"type_class": event.type_class,
"smart_actions": event.smart_actions,
"entity_links": event.entity_links
}
# Determine notification priority based on event priority
priority_map = {
"critical": "urgent",
"important": "high",
"standard": "normal",
"info": "low"
}
priority = priority_map.get(event.priority_level, "normal")
# Send notification using shared client
result = await self.notification_client.send_notification(
tenant_id=str(event.tenant_id),
notification_type="in_app", # Using in-app notification by default
message=message,
subject=title,
priority=priority,
metadata=metadata
)
if result:
logger.info(
"notification_sent_via_shared_client",
event_id=str(event.id),
tenant_id=str(event.tenant_id),
priority_level=event.priority_level
)
else:
logger.warning(
"notification_failed_via_shared_client",
event_id=str(event.id),
tenant_id=str(event.tenant_id)
)
except Exception as e:
logger.error(
"notification_error_via_shared_client",
error=str(e),
event_id=str(event.id),
tenant_id=str(event.tenant_id)
)
# Don't re-raise - we don't want to fail the entire event processing
# if notification sending fails
async def stop(self):
"""Stop consumer and close connections"""
try:
if self.channel:
await self.channel.close()
logger.info("rabbitmq_channel_closed")
if self.connection:
await self.connection.close()
logger.info("rabbitmq_connection_closed")
except Exception as e:
logger.error("consumer_stop_failed", error=str(e))

View File

@@ -0,0 +1,51 @@
"""
Configuration settings for alert processor service.
"""
import os
from shared.config.base import BaseServiceSettings
class Settings(BaseServiceSettings):
"""Application settings"""
# Service info - override defaults
SERVICE_NAME: str = "alert-processor"
APP_NAME: str = "Alert Processor Service"
DESCRIPTION: str = "Central alert and recommendation processor"
VERSION: str = "2.0.0"
# Alert processor specific settings
RABBITMQ_EXCHANGE: str = "events.exchange"
RABBITMQ_QUEUE: str = "alert_processor.queue"
REDIS_SSE_PREFIX: str = "alerts"
ORCHESTRATOR_TIMEOUT: int = 10
NOTIFICATION_TIMEOUT: int = 5
CACHE_ENABLED: bool = True
CACHE_TTL_SECONDS: int = 300
@property
def NOTIFICATION_URL(self) -> str:
"""Get notification service URL for backwards compatibility"""
return self.NOTIFICATION_SERVICE_URL
# Database configuration (secure approach - build from components)
@property
def DATABASE_URL(self) -> str:
"""Build database URL from secure components"""
# Try complete URL first (for backward compatibility)
complete_url = os.getenv("ALERT_PROCESSOR_DATABASE_URL")
if complete_url:
return complete_url
# Build from components (secure approach)
user = os.getenv("ALERT_PROCESSOR_DB_USER", "alert_processor_user")
password = os.getenv("ALERT_PROCESSOR_DB_PASSWORD", "alert_processor_pass123")
host = os.getenv("ALERT_PROCESSOR_DB_HOST", "alert-processor-db-service")
port = os.getenv("ALERT_PROCESSOR_DB_PORT", "5432")
name = os.getenv("ALERT_PROCESSOR_DB_NAME", "alert_processor_db")
return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"
settings = Settings()

View File

@@ -0,0 +1,48 @@
"""
Database connection and session management for Alert Processor Service
"""
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from .config import settings
from shared.database.base import DatabaseManager
# Initialize database manager
database_manager = DatabaseManager(
database_url=settings.DATABASE_URL,
service_name=settings.SERVICE_NAME,
pool_size=settings.DB_POOL_SIZE,
max_overflow=settings.DB_MAX_OVERFLOW,
echo=settings.DEBUG
)
# Create async session factory
AsyncSessionLocal = async_sessionmaker(
database_manager.async_engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
async def get_db() -> AsyncSession:
"""
Dependency to get database session.
Used in FastAPI endpoints via Depends(get_db).
"""
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()
async def init_db():
"""Initialize database (create tables if needed)"""
await database_manager.create_all()
async def close_db():
"""Close database connections"""
await database_manager.close()

View File

@@ -0,0 +1 @@
"""Enrichment components for alert processing."""

View File

@@ -0,0 +1,156 @@
"""
Business impact analyzer for alerts.
Calculates financial impact, affected orders, customer impact, and other
business metrics from event metadata.
"""
from typing import Dict, Any
import structlog
logger = structlog.get_logger()
class BusinessImpactAnalyzer:
"""Analyze business impact from event metadata"""
def analyze(self, event_type: str, metadata: Dict[str, Any]) -> dict:
"""
Analyze business impact for an event.
Returns dict with:
- financial_impact_eur: Direct financial cost
- affected_orders: Number of orders impacted
- affected_customers: List of customer names
- production_delay_hours: Hours of production delay
- estimated_revenue_loss_eur: Potential revenue loss
- customer_impact: high/medium/low
- waste_risk_kg: Potential waste in kg
"""
impact = {
"financial_impact_eur": 0,
"affected_orders": 0,
"affected_customers": [],
"production_delay_hours": 0,
"estimated_revenue_loss_eur": 0,
"customer_impact": "low",
"waste_risk_kg": 0
}
# Stock-related impacts
if "stock" in event_type or "shortage" in event_type:
impact.update(self._analyze_stock_impact(metadata))
# Production-related impacts
elif "production" in event_type or "delay" in event_type or "equipment" in event_type:
impact.update(self._analyze_production_impact(metadata))
# Procurement-related impacts
elif "po_" in event_type or "delivery" in event_type:
impact.update(self._analyze_procurement_impact(metadata))
# Quality-related impacts
elif "quality" in event_type or "expired" in event_type:
impact.update(self._analyze_quality_impact(metadata))
return impact
def _analyze_stock_impact(self, metadata: Dict[str, Any]) -> dict:
"""Analyze impact of stock-related alerts"""
impact = {}
# Calculate financial impact
shortage_amount = metadata.get("shortage_amount", 0)
unit_cost = metadata.get("unit_cost", 5) # Default €5/kg
impact["financial_impact_eur"] = float(shortage_amount) * unit_cost
# Affected orders from metadata
impact["affected_orders"] = metadata.get("affected_orders", 0)
# Customer impact based on affected orders
if impact["affected_orders"] > 5:
impact["customer_impact"] = "high"
elif impact["affected_orders"] > 2:
impact["customer_impact"] = "medium"
# Revenue loss (estimated)
avg_order_value = 50 # €50 per order
impact["estimated_revenue_loss_eur"] = impact["affected_orders"] * avg_order_value
return impact
def _analyze_production_impact(self, metadata: Dict[str, Any]) -> dict:
"""Analyze impact of production-related alerts"""
impact = {}
# Delay minutes to hours
delay_minutes = metadata.get("delay_minutes", 0)
impact["production_delay_hours"] = round(delay_minutes / 60, 1)
# Affected orders and customers
impact["affected_orders"] = metadata.get("affected_orders", 0)
customer_names = metadata.get("customer_names", [])
impact["affected_customers"] = customer_names
# Customer impact based on delay
if delay_minutes > 120: # 2+ hours
impact["customer_impact"] = "high"
elif delay_minutes > 60: # 1+ hours
impact["customer_impact"] = "medium"
# Financial impact: hourly production cost
hourly_cost = 100 # €100/hour operational cost
impact["financial_impact_eur"] = impact["production_delay_hours"] * hourly_cost
# Revenue loss
if impact["affected_orders"] > 0:
avg_order_value = 50
impact["estimated_revenue_loss_eur"] = impact["affected_orders"] * avg_order_value
return impact
def _analyze_procurement_impact(self, metadata: Dict[str, Any]) -> dict:
"""Analyze impact of procurement-related alerts"""
impact = {}
# Extract potential_loss_eur from reasoning_data.parameters
reasoning_data = metadata.get("reasoning_data", {})
parameters = reasoning_data.get("parameters", {})
potential_loss_eur = parameters.get("potential_loss_eur")
# Use potential loss from reasoning as financial impact (what's at risk)
# Fallback to PO amount only if reasoning data is not available
if potential_loss_eur is not None:
impact["financial_impact_eur"] = float(potential_loss_eur)
else:
po_amount = metadata.get("po_amount", metadata.get("total_amount", 0))
impact["financial_impact_eur"] = float(po_amount)
# Days overdue affects customer impact
days_overdue = metadata.get("days_overdue", 0)
if days_overdue > 3:
impact["customer_impact"] = "high"
elif days_overdue > 1:
impact["customer_impact"] = "medium"
return impact
def _analyze_quality_impact(self, metadata: Dict[str, Any]) -> dict:
"""Analyze impact of quality-related alerts"""
impact = {}
# Expired products
expired_count = metadata.get("expired_count", 0)
total_value = metadata.get("total_value", 0)
impact["financial_impact_eur"] = float(total_value)
impact["waste_risk_kg"] = metadata.get("total_quantity_kg", 0)
if expired_count > 5:
impact["customer_impact"] = "high"
elif expired_count > 2:
impact["customer_impact"] = "medium"
return impact

View File

@@ -0,0 +1,244 @@
"""
Message generator for creating i18n message keys and parameters.
Converts minimal event metadata into structured i18n format for frontend translation.
"""
from typing import Dict, Any
from datetime import datetime
from app.utils.message_templates import ALERT_TEMPLATES, NOTIFICATION_TEMPLATES, RECOMMENDATION_TEMPLATES
import structlog
logger = structlog.get_logger()
class MessageGenerator:
"""Generates i18n message keys and parameters from event metadata"""
def generate_message(self, event_type: str, metadata: Dict[str, Any], event_class: str = "alert") -> dict:
"""
Generate i18n structure for frontend.
Args:
event_type: Alert/notification/recommendation type
metadata: Event metadata dictionary
event_class: One of: alert, notification, recommendation
Returns:
Dictionary with title_key, title_params, message_key, message_params
"""
# Select appropriate template collection
if event_class == "notification":
templates = NOTIFICATION_TEMPLATES
elif event_class == "recommendation":
templates = RECOMMENDATION_TEMPLATES
else:
templates = ALERT_TEMPLATES
template = templates.get(event_type)
if not template:
logger.warning("no_template_found", event_type=event_type, event_class=event_class)
return self._generate_fallback(event_type, metadata)
# Build parameters from metadata
title_params = self._build_params(template["title_params"], metadata)
message_params = self._build_params(template["message_params"], metadata)
# Select message variant based on context
message_key = self._select_message_variant(
template["message_variants"],
metadata
)
return {
"title_key": template["title_key"],
"title_params": title_params,
"message_key": message_key,
"message_params": message_params
}
def _generate_fallback(self, event_type: str, metadata: Dict[str, Any]) -> dict:
"""Generate fallback message structure when template not found"""
return {
"title_key": "alerts.generic.title",
"title_params": {},
"message_key": "alerts.generic.message",
"message_params": {
"event_type": event_type,
"metadata_summary": self._summarize_metadata(metadata)
}
}
def _summarize_metadata(self, metadata: Dict[str, Any]) -> str:
"""Create human-readable summary of metadata"""
# Take first 3 fields
items = list(metadata.items())[:3]
summary_parts = [f"{k}: {v}" for k, v in items]
return ", ".join(summary_parts)
def _build_params(self, param_mapping: dict, metadata: dict) -> dict:
"""
Extract and transform parameters from metadata.
param_mapping format: {"display_param_name": "metadata_key"}
"""
params = {}
for param_key, metadata_key in param_mapping.items():
if metadata_key in metadata:
value = metadata[metadata_key]
# Apply transformations based on parameter suffix
if param_key.endswith("_kg"):
value = round(float(value), 1)
elif param_key.endswith("_eur"):
value = round(float(value), 2)
elif param_key.endswith("_percentage"):
value = round(float(value), 1)
elif param_key.endswith("_date"):
value = self._format_date(value)
elif param_key.endswith("_day_name"):
value = self._format_day_name(value)
elif param_key.endswith("_datetime"):
value = self._format_datetime(value)
params[param_key] = value
return params
def _select_message_variant(self, variants: dict, metadata: dict) -> str:
"""
Select appropriate message variant based on metadata context.
Checks for specific conditions in priority order.
"""
# Check for PO-related variants
if "po_id" in metadata:
if metadata.get("po_status") == "pending_approval":
variant = variants.get("with_po_pending")
if variant:
return variant
else:
variant = variants.get("with_po_created")
if variant:
return variant
# Check for time-based variants
if "hours_until" in metadata:
variant = variants.get("with_hours")
if variant:
return variant
if "production_date" in metadata or "planned_date" in metadata:
variant = variants.get("with_date")
if variant:
return variant
# Check for customer-related variants
if "customer_names" in metadata and metadata.get("customer_names"):
variant = variants.get("with_customers")
if variant:
return variant
# Check for order-related variants
if "affected_orders" in metadata and metadata.get("affected_orders", 0) > 0:
variant = variants.get("with_orders")
if variant:
return variant
# Check for supplier contact variants
if "supplier_contact" in metadata:
variant = variants.get("with_supplier")
if variant:
return variant
# Check for batch-related variants
if "affected_batches" in metadata and metadata.get("affected_batches", 0) > 0:
variant = variants.get("with_batches")
if variant:
return variant
# Check for product names list variants
if "product_names" in metadata and metadata.get("product_names"):
variant = variants.get("with_names")
if variant:
return variant
# Check for time duration variants
if "hours_overdue" in metadata:
variant = variants.get("with_hours")
if variant:
return variant
if "days_overdue" in metadata:
variant = variants.get("with_days")
if variant:
return variant
# Default to generic variant
return variants.get("generic", variants[list(variants.keys())[0]])
def _format_date(self, date_value: Any) -> str:
"""
Format date for display.
Accepts:
- ISO string: "2025-12-10"
- datetime object
- date object
Returns: ISO format "YYYY-MM-DD"
"""
if isinstance(date_value, str):
# Already a string, might be ISO format
try:
dt = datetime.fromisoformat(date_value.replace('Z', '+00:00'))
return dt.date().isoformat()
            except ValueError:
return date_value
if isinstance(date_value, datetime):
return date_value.date().isoformat()
if hasattr(date_value, 'isoformat'):
return date_value.isoformat()
return str(date_value)
def _format_day_name(self, date_value: Any) -> str:
"""
Format day name with date.
Example: "miércoles 10 de diciembre"
Note: Frontend will handle localization.
For now, return ISO date and let frontend format.
"""
        iso_date = self._format_date(date_value)
        # Frontend will use this ISO date to render the localized day name
        return iso_date
def _format_datetime(self, datetime_value: Any) -> str:
"""
Format datetime for display.
Returns: ISO 8601 format with timezone
"""
if isinstance(datetime_value, str):
return datetime_value
if isinstance(datetime_value, datetime):
return datetime_value.isoformat()
if hasattr(datetime_value, 'isoformat'):
return datetime_value.isoformat()
return str(datetime_value)

View File

@@ -0,0 +1,165 @@
"""
Orchestrator client for querying AI action context.
Queries the orchestrator service to determine if AI has already
addressed the issue and what actions were taken.
"""
from typing import Dict, Any, Optional
import httpx
import structlog
from datetime import datetime, timedelta
logger = structlog.get_logger()
class OrchestratorClient:
"""HTTP client for querying orchestrator service"""
def __init__(self, base_url: str = "http://orchestrator-service:8000"):
"""
Initialize orchestrator client.
Args:
base_url: Base URL of orchestrator service
"""
self.base_url = base_url
self.timeout = 10.0 # 10 second timeout
async def get_context(
self,
tenant_id: str,
event_type: str,
metadata: Dict[str, Any]
) -> dict:
"""
Query orchestrator for AI action context.
Returns dict with:
- already_addressed: Boolean - did AI handle this?
- action_type: Type of action taken
- action_id: ID of the action
- action_summary: Human-readable summary
- reasoning: AI reasoning for the action
- confidence: Confidence score (0-1)
- estimated_savings_eur: Estimated savings
- prevented_issue: What issue was prevented
- created_at: When action was created
"""
context = {
"already_addressed": False,
"confidence": 0.8 # Default confidence
}
try:
# Build query based on event type and metadata
query_params = self._build_query_params(event_type, metadata)
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/internal/recent-actions",
params={
"tenant_id": tenant_id,
**query_params
},
headers={
"x-internal-service": "alert-intelligence"
}
)
if response.status_code == 200:
data = response.json()
context.update(self._parse_response(data, event_type, metadata))
elif response.status_code == 404:
# No recent actions found - that's okay
logger.debug("no_orchestrator_actions", tenant_id=tenant_id, event_type=event_type)
else:
logger.warning(
"orchestrator_query_failed",
status_code=response.status_code,
tenant_id=tenant_id
)
except httpx.TimeoutException:
logger.warning("orchestrator_timeout", tenant_id=tenant_id, event_type=event_type)
except Exception as e:
logger.error("orchestrator_query_error", error=str(e), tenant_id=tenant_id)
return context
def _build_query_params(self, event_type: str, metadata: Dict[str, Any]) -> dict:
"""Build query parameters based on event type"""
params = {}
# For stock-related alerts, query for PO actions
if "stock" in event_type or "shortage" in event_type:
if metadata.get("ingredient_id"):
params["related_entity_type"] = "ingredient"
params["related_entity_id"] = metadata["ingredient_id"]
params["action_types"] = "purchase_order_created,purchase_order_approved"
# For production delays, query for batch adjustments
elif "production" in event_type or "delay" in event_type:
if metadata.get("batch_id"):
params["related_entity_type"] = "production_batch"
params["related_entity_id"] = metadata["batch_id"]
params["action_types"] = "production_adjusted,batch_rescheduled"
# For PO approval, check if already approved
elif "po_approval" in event_type:
if metadata.get("po_id"):
params["related_entity_type"] = "purchase_order"
params["related_entity_id"] = metadata["po_id"]
params["action_types"] = "purchase_order_approved,purchase_order_rejected"
# Look for recent actions (last 24 hours)
params["since_hours"] = 24
return params
def _parse_response(
self,
data: dict,
event_type: str,
metadata: Dict[str, Any]
) -> dict:
"""Parse orchestrator response into context"""
if not data or not data.get("actions"):
return {"already_addressed": False}
# Get most recent action
actions = data.get("actions", [])
if not actions:
return {"already_addressed": False}
most_recent = actions[0]
context = {
"already_addressed": True,
"action_type": most_recent.get("action_type"),
"action_id": most_recent.get("id"),
"action_summary": most_recent.get("summary", ""),
"reasoning": most_recent.get("reasoning", {}),
"confidence": most_recent.get("confidence", 0.8),
"created_at": most_recent.get("created_at"),
"action_status": most_recent.get("status", "completed")
}
# Extract specific fields based on action type
if most_recent.get("action_type") == "purchase_order_created":
context["estimated_savings_eur"] = most_recent.get("estimated_savings_eur", 0)
context["prevented_issue"] = "stockout"
if most_recent.get("delivery_date"):
context["delivery_date"] = most_recent["delivery_date"]
elif most_recent.get("action_type") == "production_adjusted":
context["prevented_issue"] = "production_delay"
context["adjustment_type"] = most_recent.get("adjustment_type")
return context

View File

@@ -0,0 +1,256 @@
"""
Multi-factor priority scoring for alerts.
Calculates priority score (0-100) based on:
- Business impact (40%): Financial impact, affected orders, customer impact
- Urgency (30%): Time until consequence, deadlines
- User agency (20%): Can user fix it? External dependencies?
- Confidence (10%): AI confidence in assessment
Also applies escalation boosts for age and deadline proximity.
"""
from typing import Dict, Any
import structlog
logger = structlog.get_logger()
class PriorityScorer:
"""Calculate multi-factor priority score (0-100)"""
# Weights for priority calculation
BUSINESS_IMPACT_WEIGHT = 0.4
URGENCY_WEIGHT = 0.3
USER_AGENCY_WEIGHT = 0.2
CONFIDENCE_WEIGHT = 0.1
# Priority thresholds
CRITICAL_THRESHOLD = 90
IMPORTANT_THRESHOLD = 70
STANDARD_THRESHOLD = 50
def calculate_priority(
self,
business_impact: dict,
urgency: dict,
user_agency: dict,
orchestrator_context: dict
) -> int:
"""
Calculate weighted priority score.
Args:
business_impact: Business impact context
urgency: Urgency context
user_agency: User agency context
orchestrator_context: AI orchestrator context
Returns:
Priority score (0-100)
"""
# Score each dimension (0-100)
impact_score = self._score_business_impact(business_impact)
urgency_score = self._score_urgency(urgency)
agency_score = self._score_user_agency(user_agency)
confidence_score = orchestrator_context.get("confidence", 0.8) * 100
# Weighted average
total_score = (
impact_score * self.BUSINESS_IMPACT_WEIGHT +
urgency_score * self.URGENCY_WEIGHT +
agency_score * self.USER_AGENCY_WEIGHT +
confidence_score * self.CONFIDENCE_WEIGHT
)
# Apply escalation boost if needed
escalation_boost = self._calculate_escalation_boost(urgency)
total_score = min(100, total_score + escalation_boost)
score = int(total_score)
logger.debug(
"priority_calculated",
score=score,
impact_score=impact_score,
urgency_score=urgency_score,
agency_score=agency_score,
confidence_score=confidence_score,
escalation_boost=escalation_boost
)
return score
def _score_business_impact(self, impact: dict) -> int:
"""
Score business impact (0-100).
Considers:
- Financial impact in EUR
- Number of affected orders
- Customer impact level
- Production delays
- Revenue at risk
"""
score = 50 # Base score
# Financial impact
financial_impact = impact.get("financial_impact_eur", 0)
if financial_impact > 1000:
score += 30
elif financial_impact > 500:
score += 20
elif financial_impact > 100:
score += 10
# Affected orders
affected_orders = impact.get("affected_orders", 0)
if affected_orders > 10:
score += 15
elif affected_orders > 5:
score += 10
elif affected_orders > 0:
score += 5
# Customer impact
customer_impact = impact.get("customer_impact", "low")
if customer_impact == "high":
score += 15
elif customer_impact == "medium":
score += 5
# Production delay hours
production_delay_hours = impact.get("production_delay_hours", 0)
if production_delay_hours > 4:
score += 10
elif production_delay_hours > 2:
score += 5
# Revenue loss
revenue_loss = impact.get("estimated_revenue_loss_eur", 0)
if revenue_loss > 500:
score += 10
elif revenue_loss > 200:
score += 5
return min(100, score)
def _score_urgency(self, urgency: dict) -> int:
"""
Score urgency (0-100).
Considers:
- Time until consequence
- Can it wait until tomorrow?
- Deadline proximity
- Peak hour relevance
"""
score = 50 # Base score
# Time until consequence
hours_until = urgency.get("hours_until_consequence", 24)
if hours_until < 2:
score += 40
elif hours_until < 6:
score += 30
elif hours_until < 12:
score += 20
elif hours_until < 24:
score += 10
# Can it wait?
if not urgency.get("can_wait_until_tomorrow", True):
score += 10
# Deadline present
if urgency.get("deadline_utc"):
score += 5
# Peak hour relevant (production/demand related)
if urgency.get("peak_hour_relevant", False):
score += 5
return min(100, score)
def _score_user_agency(self, agency: dict) -> int:
"""
Score user agency (0-100).
Higher score when user CAN fix the issue.
Lower score when blocked or requires external parties.
Considers:
- Can user fix it?
- Requires external party?
- Has blockers?
- Suggested workarounds available?
"""
score = 50 # Base score
# Can user fix?
if agency.get("can_user_fix", False):
score += 30
else:
score -= 20
# Requires external party?
if agency.get("requires_external_party", False):
score -= 10
# Has blockers?
blockers = agency.get("blockers", [])
score -= len(blockers) * 5
# Has suggested workaround?
if agency.get("suggested_workaround"):
score += 5
return max(0, min(100, score))
def _calculate_escalation_boost(self, urgency: dict) -> int:
"""
Calculate escalation boost for pending alerts.
Boosts priority for:
- Age-based escalation (pending >48h, >72h)
- Deadline proximity (<6h, <24h)
Maximum boost: +30 points
"""
boost = 0
# Age-based escalation
hours_pending = urgency.get("hours_pending", 0)
if hours_pending > 72:
boost += 20
elif hours_pending > 48:
boost += 10
# Deadline proximity
hours_until = urgency.get("hours_until_consequence", 24)
if hours_until < 6:
boost += 30
elif hours_until < 24:
boost += 15
# Cap at +30
return min(30, boost)
def get_priority_level(self, score: int) -> str:
"""
Convert numeric score to priority level.
- 90-100: critical
- 70-89: important
- 50-69: standard
- 0-49: info
"""
if score >= self.CRITICAL_THRESHOLD:
return "critical"
elif score >= self.IMPORTANT_THRESHOLD:
return "important"
elif score >= self.STANDARD_THRESHOLD:
return "standard"
else:
return "info"

View File

@@ -0,0 +1,304 @@
"""
Smart action generator for alerts.
Generates actionable buttons with deep links, phone numbers,
and other interactive elements based on alert type and metadata.
"""
from typing import Dict, Any, List
import structlog
logger = structlog.get_logger()
class SmartActionGenerator:
"""Generate smart action buttons for alerts"""
def generate_actions(
self,
event_type: str,
metadata: Dict[str, Any],
orchestrator_context: dict
) -> List[dict]:
"""
Generate smart actions for an event.
Each action has:
- action_type: Identifier for frontend handling
- label_key: i18n key for button label
- label_params: Parameters for label translation
- variant: primary/secondary/danger/ghost
- disabled: Boolean
- disabled_reason_key: i18n key if disabled
- consequence_key: i18n key for confirmation dialog
- url: Deep link or tel: or mailto:
- metadata: Additional data for action
"""
actions = []
# If AI already addressed, show "View Action" button
if orchestrator_context and orchestrator_context.get("already_addressed"):
actions.append(self._create_view_action(orchestrator_context))
return actions
# Generate actions based on event type
if "po_approval" in event_type:
actions.extend(self._create_po_approval_actions(metadata))
elif "stock" in event_type or "shortage" in event_type:
actions.extend(self._create_stock_actions(metadata))
elif "production" in event_type or "delay" in event_type:
actions.extend(self._create_production_actions(metadata))
elif "equipment" in event_type:
actions.extend(self._create_equipment_actions(metadata))
elif "delivery" in event_type or "overdue" in event_type:
actions.extend(self._create_delivery_actions(metadata))
elif "temperature" in event_type:
actions.extend(self._create_temperature_actions(metadata))
# Always add common actions
actions.extend(self._create_common_actions())
return actions
def _create_view_action(self, orchestrator_context: dict) -> dict:
"""Create action to view what AI did"""
return {
"action_type": "open_reasoning",
"label_key": "actions.view_ai_action",
"label_params": {},
"variant": "primary",
"disabled": False,
"metadata": {
"action_id": orchestrator_context.get("action_id"),
"action_type": orchestrator_context.get("action_type")
}
}
def _create_po_approval_actions(self, metadata: Dict[str, Any]) -> List[dict]:
"""Create actions for PO approval alerts"""
po_id = metadata.get("po_id")
po_amount = metadata.get("total_amount", metadata.get("po_amount", 0))
return [
{
"action_type": "approve_po",
"label_key": "actions.approve_po",
"label_params": {"amount": po_amount},
"variant": "primary",
"disabled": False,
"consequence_key": "actions.approve_po_consequence",
"url": f"/app/procurement/purchase-orders/{po_id}",
"metadata": {"po_id": po_id, "amount": po_amount}
},
{
"action_type": "reject_po",
"label_key": "actions.reject_po",
"label_params": {},
"variant": "danger",
"disabled": False,
"consequence_key": "actions.reject_po_consequence",
"url": f"/app/procurement/purchase-orders/{po_id}",
"metadata": {"po_id": po_id}
},
{
"action_type": "modify_po",
"label_key": "actions.modify_po",
"label_params": {},
"variant": "secondary",
"disabled": False,
"url": f"/app/procurement/purchase-orders/{po_id}/edit",
"metadata": {"po_id": po_id}
}
]
def _create_stock_actions(self, metadata: Dict[str, Any]) -> List[dict]:
"""Create actions for stock-related alerts"""
actions = []
# If supplier info available, add call button
if metadata.get("supplier_contact"):
actions.append({
"action_type": "call_supplier",
"label_key": "actions.call_supplier",
"label_params": {
"supplier": metadata.get("supplier_name", "Supplier"),
"phone": metadata.get("supplier_contact")
},
"variant": "primary",
"disabled": False,
"url": f"tel:{metadata['supplier_contact']}",
"metadata": {
"supplier_name": metadata.get("supplier_name"),
"phone": metadata.get("supplier_contact")
}
})
# If PO exists, add view PO button
if metadata.get("po_id"):
if metadata.get("po_status") == "pending_approval":
actions.append({
"action_type": "approve_po",
"label_key": "actions.approve_po",
"label_params": {"amount": metadata.get("po_amount", 0)},
"variant": "primary",
"disabled": False,
"url": f"/app/procurement/purchase-orders/{metadata['po_id']}",
"metadata": {"po_id": metadata["po_id"]}
})
else:
actions.append({
"action_type": "view_po",
"label_key": "actions.view_po",
"label_params": {"po_number": metadata.get("po_number", metadata["po_id"])},
"variant": "secondary",
"disabled": False,
"url": f"/app/procurement/purchase-orders/{metadata['po_id']}",
"metadata": {"po_id": metadata["po_id"]}
})
# Add create PO button if no PO exists
else:
actions.append({
"action_type": "create_po",
"label_key": "actions.create_po",
"label_params": {},
"variant": "primary",
"disabled": False,
"url": f"/app/procurement/purchase-orders/new?ingredient_id={metadata.get('ingredient_id')}",
"metadata": {"ingredient_id": metadata.get("ingredient_id")}
})
return actions
def _create_production_actions(self, metadata: Dict[str, Any]) -> List[dict]:
"""Create actions for production-related alerts"""
actions = []
if metadata.get("batch_id"):
actions.append({
"action_type": "view_batch",
"label_key": "actions.view_batch",
"label_params": {"batch_number": metadata.get("batch_number", "")},
"variant": "primary",
"disabled": False,
"url": f"/app/production/batches/{metadata['batch_id']}",
"metadata": {"batch_id": metadata["batch_id"]}
})
actions.append({
"action_type": "adjust_production",
"label_key": "actions.adjust_production",
"label_params": {},
"variant": "secondary",
"disabled": False,
"url": f"/app/production/batches/{metadata['batch_id']}/adjust",
"metadata": {"batch_id": metadata["batch_id"]}
})
return actions
def _create_equipment_actions(self, metadata: Dict[str, Any]) -> List[dict]:
"""Create actions for equipment-related alerts"""
return [
{
"action_type": "view_equipment",
"label_key": "actions.view_equipment",
"label_params": {"equipment_name": metadata.get("equipment_name", "")},
"variant": "primary",
"disabled": False,
"url": f"/app/production/equipment/{metadata.get('equipment_id')}",
"metadata": {"equipment_id": metadata.get("equipment_id")}
},
{
"action_type": "schedule_maintenance",
"label_key": "actions.schedule_maintenance",
"label_params": {},
"variant": "secondary",
"disabled": False,
"url": f"/app/production/equipment/{metadata.get('equipment_id')}/maintenance",
"metadata": {"equipment_id": metadata.get("equipment_id")}
}
]
def _create_delivery_actions(self, metadata: Dict[str, Any]) -> List[dict]:
"""Create actions for delivery-related alerts"""
actions = []
if metadata.get("supplier_contact"):
actions.append({
"action_type": "call_supplier",
"label_key": "actions.call_supplier",
"label_params": {
"supplier": metadata.get("supplier_name", "Supplier"),
"phone": metadata.get("supplier_contact")
},
"variant": "primary",
"disabled": False,
"url": f"tel:{metadata['supplier_contact']}",
"metadata": {
"supplier_name": metadata.get("supplier_name"),
"phone": metadata.get("supplier_contact")
}
})
if metadata.get("po_id"):
actions.append({
"action_type": "view_po",
"label_key": "actions.view_po",
"label_params": {"po_number": metadata.get("po_number", "")},
"variant": "secondary",
"disabled": False,
"url": f"/app/procurement/purchase-orders/{metadata['po_id']}",
"metadata": {"po_id": metadata["po_id"]}
})
return actions
def _create_temperature_actions(self, metadata: Dict[str, Any]) -> List[dict]:
"""Create actions for temperature breach alerts"""
return [
{
"action_type": "view_sensor",
"label_key": "actions.view_sensor",
"label_params": {"location": metadata.get("location", "")},
"variant": "primary",
"disabled": False,
"url": f"/app/inventory/sensors/{metadata.get('sensor_id')}",
"metadata": {"sensor_id": metadata.get("sensor_id")}
},
{
"action_type": "acknowledge_breach",
"label_key": "actions.acknowledge_breach",
"label_params": {},
"variant": "secondary",
"disabled": False,
"metadata": {"sensor_id": metadata.get("sensor_id")}
}
]
def _create_common_actions(self) -> List[dict]:
"""Create common actions available for all alerts"""
return [
{
"action_type": "snooze",
"label_key": "actions.snooze",
"label_params": {"hours": 4},
"variant": "ghost",
"disabled": False,
"metadata": {"snooze_hours": 4}
},
{
"action_type": "dismiss",
"label_key": "actions.dismiss",
"label_params": {},
"variant": "ghost",
"disabled": False,
"metadata": {}
}
]
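
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): buttons for a hypothetical stock-shortage
# event with a known supplier but no purchase order yet; the metadata keys
# mirror those inspected above and the values are invented.
if __name__ == "__main__":
    generator = SmartActionGenerator()
    demo_actions = generator.generate_actions(
        event_type="critical_stock_shortage",
        metadata={
            "ingredient_id": "ing-123",
            "supplier_name": "Best Flour Co",
            "supplier_contact": "+34600000000",
        },
        orchestrator_context={},  # empty -> the AI has not handled it yet
    )
    # Expected: call_supplier, create_po, then the common snooze/dismiss actions.
    for action in demo_actions:
        print(action["action_type"], action.get("url"))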

View File

@@ -0,0 +1,173 @@
"""
Urgency analyzer for alerts.
Assesses time sensitivity, deadlines, and determines if action can wait.
"""
from typing import Dict, Any
from datetime import datetime, timedelta, timezone
import structlog
logger = structlog.get_logger()
class UrgencyAnalyzer:
"""Analyze urgency from event metadata"""
def analyze(self, event_type: str, metadata: Dict[str, Any]) -> dict:
"""
Analyze urgency for an event.
Returns dict with:
- hours_until_consequence: Time until impact occurs
- can_wait_until_tomorrow: Boolean
- deadline_utc: ISO datetime if deadline exists
- peak_hour_relevant: Boolean
- hours_pending: Age of alert
"""
urgency = {
"hours_until_consequence": 24, # Default: 24 hours
"can_wait_until_tomorrow": True,
"deadline_utc": None,
"peak_hour_relevant": False,
"hours_pending": 0
}
# Calculate based on event type
if "critical" in event_type or "urgent" in event_type:
urgency["hours_until_consequence"] = 2
urgency["can_wait_until_tomorrow"] = False
elif "production" in event_type:
urgency.update(self._analyze_production_urgency(metadata))
elif "stock" in event_type or "shortage" in event_type:
urgency.update(self._analyze_stock_urgency(metadata))
elif "delivery" in event_type or "overdue" in event_type:
urgency.update(self._analyze_delivery_urgency(metadata))
elif "po_approval" in event_type:
urgency.update(self._analyze_po_approval_urgency(metadata))
# Check for explicit deadlines
if "required_delivery_date" in metadata:
urgency.update(self._calculate_deadline_urgency(metadata["required_delivery_date"]))
if "production_date" in metadata:
urgency.update(self._calculate_deadline_urgency(metadata["production_date"]))
if "expected_date" in metadata:
urgency.update(self._calculate_deadline_urgency(metadata["expected_date"]))
return urgency
def _analyze_production_urgency(self, metadata: Dict[str, Any]) -> dict:
"""Analyze urgency for production alerts"""
urgency = {}
delay_minutes = metadata.get("delay_minutes", 0)
if delay_minutes > 120:
urgency["hours_until_consequence"] = 1
urgency["can_wait_until_tomorrow"] = False
elif delay_minutes > 60:
urgency["hours_until_consequence"] = 4
urgency["can_wait_until_tomorrow"] = False
else:
urgency["hours_until_consequence"] = 8
# Production is peak-hour sensitive
urgency["peak_hour_relevant"] = True
return urgency
def _analyze_stock_urgency(self, metadata: Dict[str, Any]) -> dict:
"""Analyze urgency for stock alerts"""
urgency = {}
# Hours until needed
if "hours_until" in metadata:
urgency["hours_until_consequence"] = metadata["hours_until"]
urgency["can_wait_until_tomorrow"] = urgency["hours_until_consequence"] > 24
# Days until expiry
elif "days_until_expiry" in metadata:
days = metadata["days_until_expiry"]
if days <= 1:
urgency["hours_until_consequence"] = days * 24
urgency["can_wait_until_tomorrow"] = False
else:
urgency["hours_until_consequence"] = days * 24
return urgency
def _analyze_delivery_urgency(self, metadata: Dict[str, Any]) -> dict:
"""Analyze urgency for delivery alerts"""
urgency = {}
days_overdue = metadata.get("days_overdue", 0)
if days_overdue > 3:
urgency["hours_until_consequence"] = 2
urgency["can_wait_until_tomorrow"] = False
elif days_overdue > 1:
urgency["hours_until_consequence"] = 8
urgency["can_wait_until_tomorrow"] = False
return urgency
def _analyze_po_approval_urgency(self, metadata: Dict[str, Any]) -> dict:
"""
Analyze urgency for PO approval alerts.
Uses stockout time (when you run out of stock) instead of delivery date
to determine true urgency.
"""
urgency = {}
# Extract min_depletion_hours from reasoning_data.parameters
reasoning_data = metadata.get("reasoning_data", {})
parameters = reasoning_data.get("parameters", {})
min_depletion_hours = parameters.get("min_depletion_hours")
if min_depletion_hours is not None:
urgency["hours_until_consequence"] = max(0, round(min_depletion_hours, 1))
urgency["can_wait_until_tomorrow"] = min_depletion_hours > 24
# Set deadline_utc to when stock runs out
now = datetime.now(timezone.utc)
stockout_time = now + timedelta(hours=min_depletion_hours)
urgency["deadline_utc"] = stockout_time.isoformat()
logger.info(
"po_approval_urgency_calculated",
min_depletion_hours=min_depletion_hours,
stockout_deadline=urgency["deadline_utc"],
can_wait=urgency["can_wait_until_tomorrow"]
)
return urgency
def _calculate_deadline_urgency(self, deadline_str: str) -> dict:
"""Calculate urgency based on deadline"""
try:
if isinstance(deadline_str, str):
deadline = datetime.fromisoformat(deadline_str.replace('Z', '+00:00'))
else:
deadline = deadline_str
now = datetime.now(timezone.utc)
time_until = deadline - now
hours_until = time_until.total_seconds() / 3600
return {
"deadline_utc": deadline.isoformat(),
"hours_until_consequence": max(0, round(hours_until, 1)),
"can_wait_until_tomorrow": hours_until > 24
}
except Exception as e:
logger.warning("deadline_parse_failed", deadline=deadline_str, error=str(e))
return {}
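
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a low-stock event expected to bite in ~6h.
if __name__ == "__main__":
    analyzer = UrgencyAnalyzer()
    demo_urgency = analyzer.analyze(
        event_type="low_stock_warning",
        metadata={"hours_until": 6},
    )
    # hours_until_consequence -> 6, can_wait_until_tomorrow -> False
    print(demo_urgency)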

View File

@@ -0,0 +1,116 @@
"""
User agency analyzer for alerts.
Determines whether user can fix the issue, what blockers exist,
and if external parties are required.
"""
from typing import Dict, Any
import structlog
logger = structlog.get_logger()
class UserAgencyAnalyzer:
"""Analyze user's ability to act on alerts"""
def analyze(
self,
event_type: str,
metadata: Dict[str, Any],
orchestrator_context: dict
) -> dict:
"""
Analyze user agency for an event.
Returns dict with:
- can_user_fix: Boolean - can user resolve this?
- requires_external_party: Boolean
- external_party_name: Name of required party
- external_party_contact: Contact info
- blockers: List of blocking factors
- suggested_workaround: Optional workaround suggestion
"""
agency = {
"can_user_fix": True,
"requires_external_party": False,
"external_party_name": None,
"external_party_contact": None,
"blockers": [],
"suggested_workaround": None
}
# If orchestrator already addressed it, user agency is low
if orchestrator_context and orchestrator_context.get("already_addressed"):
agency["can_user_fix"] = False
agency["blockers"].append("ai_already_handled")
return agency
# Analyze based on event type
if "po_approval" in event_type:
agency["can_user_fix"] = True
elif "delivery" in event_type or "supplier" in event_type:
agency.update(self._analyze_supplier_agency(metadata))
elif "equipment" in event_type:
agency.update(self._analyze_equipment_agency(metadata))
elif "stock" in event_type:
agency.update(self._analyze_stock_agency(metadata, orchestrator_context))
return agency
def _analyze_supplier_agency(self, metadata: Dict[str, Any]) -> dict:
"""Analyze agency for supplier-related alerts"""
agency = {
"requires_external_party": True,
"external_party_name": metadata.get("supplier_name"),
"external_party_contact": metadata.get("supplier_contact")
}
# User can contact supplier but can't directly fix
if not metadata.get("supplier_contact"):
agency["blockers"].append("no_supplier_contact")
return agency
def _analyze_equipment_agency(self, metadata: Dict[str, Any]) -> dict:
"""Analyze agency for equipment-related alerts"""
agency = {}
equipment_type = metadata.get("equipment_type", "")
if "oven" in equipment_type.lower() or "mixer" in equipment_type.lower():
agency["requires_external_party"] = True
agency["external_party_name"] = "Maintenance Team"
agency["blockers"].append("requires_technician")
return agency
def _analyze_stock_agency(
self,
metadata: Dict[str, Any],
orchestrator_context: dict
) -> dict:
"""Analyze agency for stock-related alerts"""
agency = {}
# If PO exists, user just needs to approve
if metadata.get("po_id"):
if metadata.get("po_status") == "pending_approval":
agency["can_user_fix"] = True
agency["suggested_workaround"] = "Approve pending PO"
else:
agency["blockers"].append("waiting_for_delivery")
agency["requires_external_party"] = True
agency["external_party_name"] = metadata.get("supplier_name")
# If no PO, user needs to create one
elif metadata.get("supplier_name"):
agency["can_user_fix"] = True
agency["requires_external_party"] = True
agency["external_party_name"] = metadata.get("supplier_name")
return agency
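
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a stock alert where a purchase order is
# already waiting for approval, so the user can resolve it directly.
if __name__ == "__main__":
    analyzer = UserAgencyAnalyzer()
    demo_agency = analyzer.analyze(
        event_type="critical_stock_shortage",
        metadata={"po_id": "po-42", "po_status": "pending_approval"},
        orchestrator_context={},
    )
    # can_user_fix stays True and "Approve pending PO" is suggested as workaround.
    print(demo_agency)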

View File

@@ -0,0 +1,100 @@
"""
Alert Processor Service v2.0
Main FastAPI application with RabbitMQ consumer lifecycle management.
"""
from typing import Optional

import structlog
from app.core.config import settings
from app.consumer.event_consumer import EventConsumer
from app.api import alerts, sse
from shared.redis_utils import initialize_redis, close_redis
from shared.service_base import StandardFastAPIService
# Initialize logger
logger = structlog.get_logger()
# Global consumer instance
consumer: Optional[EventConsumer] = None
class AlertProcessorService(StandardFastAPIService):
"""Alert Processor Service with standardized monitoring setup and RabbitMQ consumer"""
async def on_startup(self, app):
"""Custom startup logic for Alert Processor"""
global consumer
# Initialize Redis connection
await initialize_redis(
settings.REDIS_URL,
db=settings.REDIS_DB,
max_connections=settings.REDIS_MAX_CONNECTIONS
)
logger.info("redis_initialized")
# Start RabbitMQ consumer
consumer = EventConsumer()
await consumer.start()
logger.info("rabbitmq_consumer_started")
await super().on_startup(app)
async def on_shutdown(self, app):
"""Custom shutdown logic for Alert Processor"""
global consumer
await super().on_shutdown(app)
# Stop RabbitMQ consumer
if consumer:
await consumer.stop()
logger.info("rabbitmq_consumer_stopped")
# Close Redis
await close_redis()
logger.info("redis_closed")
# Create service instance
service = AlertProcessorService(
service_name="alert-processor",
app_name="Alert Processor Service",
description="Event processing, enrichment, and alert management system",
version=settings.VERSION,
log_level=getattr(settings, 'LOG_LEVEL', 'INFO'),
cors_origins=["*"], # Configure appropriately for production
api_prefix="/api/v1",
enable_metrics=True,
enable_health_checks=True,
enable_tracing=True,
enable_cors=True
)
# Create FastAPI app
app = service.create_app(debug=settings.DEBUG)
# Add service-specific routers
app.include_router(
alerts.router,
prefix="/api/v1/tenants/{tenant_id}",
tags=["alerts"]
)
app.include_router(
sse.router,
prefix="/api/v1",
tags=["sse"]
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=8000,
reload=settings.DEBUG
)

View File

@@ -0,0 +1,84 @@
"""
SQLAlchemy models for events table.
"""
from sqlalchemy import Column, String, Integer, DateTime, Float, Index
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime, timezone
import uuid
Base = declarative_base()
class Event(Base):
"""Unified event table for alerts, notifications, recommendations"""
__tablename__ = "events"
# Core fields
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
created_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
nullable=False
)
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
nullable=False
)
# Classification
event_class = Column(String(50), nullable=False)
event_domain = Column(String(50), nullable=False, index=True)
event_type = Column(String(100), nullable=False, index=True)
service = Column(String(50), nullable=False)
# i18n content (NO hardcoded title/message)
i18n_title_key = Column(String(200), nullable=False)
i18n_title_params = Column(JSONB, nullable=False, default=dict)
i18n_message_key = Column(String(200), nullable=False)
i18n_message_params = Column(JSONB, nullable=False, default=dict)
# Priority
priority_score = Column(Integer, nullable=False, default=50, index=True)
priority_level = Column(String(20), nullable=False, index=True)
type_class = Column(String(50), nullable=False, index=True)
# Enrichment contexts (JSONB)
orchestrator_context = Column(JSONB, nullable=True)
business_impact = Column(JSONB, nullable=True)
urgency = Column(JSONB, nullable=True)
user_agency = Column(JSONB, nullable=True)
trend_context = Column(JSONB, nullable=True)
# Smart actions
smart_actions = Column(JSONB, nullable=False, default=list)
# AI reasoning
ai_reasoning_summary_key = Column(String(200), nullable=True)
ai_reasoning_summary_params = Column(JSONB, nullable=True)
ai_reasoning_details = Column(JSONB, nullable=True)
confidence_score = Column(Float, nullable=True)
# Entity references
entity_links = Column(JSONB, nullable=False, default=dict)
# Status
status = Column(String(20), nullable=False, default="active", index=True)
resolved_at = Column(DateTime(timezone=True), nullable=True)
acknowledged_at = Column(DateTime(timezone=True), nullable=True)
# Metadata
event_metadata = Column(JSONB, nullable=False, default=dict)
# Indexes for dashboard queries
__table_args__ = (
Index('idx_events_tenant_status', 'tenant_id', 'status'),
Index('idx_events_tenant_priority', 'tenant_id', 'priority_score'),
Index('idx_events_tenant_class', 'tenant_id', 'event_class'),
Index('idx_events_tenant_created', 'tenant_id', 'created_at'),
Index('idx_events_type_class_status', 'type_class', 'status'),
)
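
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): creating the table against a local database
# for experimentation. The DSN is a placeholder; in the deployed service the
# schema is expected to be managed by migrations rather than create_all().
async def _demo_create_schema() -> None:
    from sqlalchemy.ext.asyncio import create_async_engine

    engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/alerts")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    await engine.dispose()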

View File

@@ -0,0 +1,407 @@
"""
Event repository for database operations.
"""
from typing import List, Optional, Dict, Any
from uuid import UUID
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, and_, desc
from sqlalchemy.dialects.postgresql import insert
import structlog
from app.models.events import Event
from app.schemas.events import EnrichedEvent, EventSummary, EventResponse, I18nContent, SmartAction
logger = structlog.get_logger()
class EventRepository:
"""Repository for event database operations"""
def __init__(self, session: AsyncSession):
self.session = session
async def create_event(self, enriched_event: EnrichedEvent) -> Event:
"""
Store enriched event in database.
Args:
enriched_event: Enriched event with all context
Returns:
Stored Event model
"""
# Convert enriched event to database model
event = Event(
id=enriched_event.id,
tenant_id=UUID(enriched_event.tenant_id),
event_class=enriched_event.event_class,
event_domain=enriched_event.event_domain,
event_type=enriched_event.event_type,
service=enriched_event.service,
# i18n content
i18n_title_key=enriched_event.i18n.title_key,
i18n_title_params=enriched_event.i18n.title_params,
i18n_message_key=enriched_event.i18n.message_key,
i18n_message_params=enriched_event.i18n.message_params,
# Priority
priority_score=enriched_event.priority_score,
priority_level=enriched_event.priority_level,
type_class=enriched_event.type_class,
# Enrichment contexts
orchestrator_context=enriched_event.orchestrator_context.dict() if enriched_event.orchestrator_context else None,
business_impact=enriched_event.business_impact.dict() if enriched_event.business_impact else None,
urgency=enriched_event.urgency.dict() if enriched_event.urgency else None,
user_agency=enriched_event.user_agency.dict() if enriched_event.user_agency else None,
trend_context=enriched_event.trend_context,
# Smart actions
smart_actions=[action.dict() for action in enriched_event.smart_actions],
# AI reasoning
ai_reasoning_summary_key=enriched_event.ai_reasoning_summary_key,
ai_reasoning_summary_params=enriched_event.ai_reasoning_summary_params,
ai_reasoning_details=enriched_event.ai_reasoning_details,
confidence_score=enriched_event.confidence_score,
# Entity links
entity_links=enriched_event.entity_links,
# Status
status=enriched_event.status,
# Metadata
event_metadata=enriched_event.event_metadata
)
self.session.add(event)
await self.session.commit()
await self.session.refresh(event)
logger.info("event_stored", event_id=event.id, event_type=event.event_type)
return event
async def get_events(
self,
tenant_id: UUID,
event_class: Optional[str] = None,
priority_level: Optional[List[str]] = None,
status: Optional[List[str]] = None,
event_domain: Optional[str] = None,
limit: int = 50,
offset: int = 0
) -> List[Event]:
"""
Get filtered list of events.
Args:
tenant_id: Tenant UUID
event_class: Filter by event class (alert, notification, recommendation)
priority_level: Filter by priority levels
status: Filter by status values
event_domain: Filter by domain
limit: Max results
offset: Pagination offset
Returns:
List of Event models
"""
query = select(Event).where(Event.tenant_id == tenant_id)
# Apply filters
if event_class:
query = query.where(Event.event_class == event_class)
if priority_level:
query = query.where(Event.priority_level.in_(priority_level))
if status:
query = query.where(Event.status.in_(status))
if event_domain:
query = query.where(Event.event_domain == event_domain)
# Order by priority and creation time
query = query.order_by(
desc(Event.priority_score),
desc(Event.created_at)
)
# Pagination
query = query.limit(limit).offset(offset)
result = await self.session.execute(query)
events = result.scalars().all()
return list(events)
async def get_event_by_id(self, event_id: UUID) -> Optional[Event]:
"""Get single event by ID"""
query = select(Event).where(Event.id == event_id)
result = await self.session.execute(query)
return result.scalar_one_or_none()
    async def check_duplicate_alert(
        self,
        tenant_id: UUID,
        event_type: str,
        entity_links: Dict,
        event_metadata: Dict,
        time_window_hours: int = 24
    ) -> Optional[Event]:
"""
Check if a similar alert already exists within the time window.
Args:
tenant_id: Tenant UUID
event_type: Type of event (e.g., 'production_delay', 'critical_stock_shortage')
entity_links: Entity references (e.g., batch_id, po_id, ingredient_id)
event_metadata: Event metadata for comparison
time_window_hours: Time window in hours to check for duplicates
Returns:
Existing event if duplicate found, None otherwise
"""
        # Calculate time threshold
        time_threshold = datetime.now(timezone.utc) - timedelta(hours=time_window_hours)
# Build query to find potential duplicates
query = select(Event).where(
and_(
Event.tenant_id == tenant_id,
Event.event_type == event_type,
Event.status == "active", # Only check active alerts
Event.created_at >= time_threshold
)
)
result = await self.session.execute(query)
potential_duplicates = result.scalars().all()
# Compare each potential duplicate for semantic similarity
for event in potential_duplicates:
# Check if entity links match (same batch, PO, ingredient, etc.)
if self._entities_match(event.entity_links, entity_links):
# For production delays, check if it's the same batch with similar delay
if event_type == "production_delay":
if self._production_delay_match(event.event_metadata, event_metadata):
return event
# For critical stock shortages, check if it's the same ingredient
elif event_type == "critical_stock_shortage":
if self._stock_shortage_match(event.event_metadata, event_metadata):
return event
# For delivery overdue alerts, check if it's the same PO
elif event_type == "delivery_overdue":
if self._delivery_overdue_match(event.event_metadata, event_metadata):
return event
# For general matching based on metadata
else:
if self._metadata_match(event.event_metadata, event_metadata):
return event
return None
def _entities_match(self, existing_links: Dict, new_links: Dict) -> bool:
"""Check if entity links match between two events."""
if not existing_links or not new_links:
return False
# Check for common entity types
common_entities = ['production_batch', 'purchase_order', 'ingredient', 'supplier', 'equipment']
for entity in common_entities:
if entity in existing_links and entity in new_links:
if existing_links[entity] == new_links[entity]:
return True
return False
def _production_delay_match(self, existing_meta: Dict, new_meta: Dict) -> bool:
"""Check if production delay alerts match."""
# Same batch_id indicates same production issue
return (existing_meta.get('batch_id') == new_meta.get('batch_id') and
existing_meta.get('product_name') == new_meta.get('product_name'))
def _stock_shortage_match(self, existing_meta: Dict, new_meta: Dict) -> bool:
"""Check if stock shortage alerts match."""
# Same ingredient_id indicates same shortage issue
return existing_meta.get('ingredient_id') == new_meta.get('ingredient_id')
def _delivery_overdue_match(self, existing_meta: Dict, new_meta: Dict) -> bool:
"""Check if delivery overdue alerts match."""
# Same PO indicates same delivery issue
return existing_meta.get('po_id') == new_meta.get('po_id')
def _metadata_match(self, existing_meta: Dict, new_meta: Dict) -> bool:
"""Generic metadata matching for other alert types."""
# Check for common identifying fields
common_fields = ['batch_id', 'po_id', 'ingredient_id', 'supplier_id', 'equipment_id']
for field in common_fields:
if field in existing_meta and field in new_meta:
if existing_meta[field] == new_meta[field]:
return True
return False
async def get_summary(self, tenant_id: UUID) -> EventSummary:
"""
Get summary statistics for dashboard.
Args:
tenant_id: Tenant UUID
Returns:
EventSummary with counts and statistics
"""
# Count by status
status_query = select(
Event.status,
func.count(Event.id).label('count')
).where(
Event.tenant_id == tenant_id
).group_by(Event.status)
status_result = await self.session.execute(status_query)
status_counts = {row.status: row.count for row in status_result}
# Count by priority
priority_query = select(
Event.priority_level,
func.count(Event.id).label('count')
).where(
and_(
Event.tenant_id == tenant_id,
Event.status == "active"
)
).group_by(Event.priority_level)
priority_result = await self.session.execute(priority_query)
priority_counts = {row.priority_level: row.count for row in priority_result}
# Count by domain
domain_query = select(
Event.event_domain,
func.count(Event.id).label('count')
).where(
and_(
Event.tenant_id == tenant_id,
Event.status == "active"
)
).group_by(Event.event_domain)
domain_result = await self.session.execute(domain_query)
domain_counts = {row.event_domain: row.count for row in domain_result}
# Count by type class
type_class_query = select(
Event.type_class,
func.count(Event.id).label('count')
).where(
and_(
Event.tenant_id == tenant_id,
Event.status == "active"
)
).group_by(Event.type_class)
type_class_result = await self.session.execute(type_class_query)
type_class_counts = {row.type_class: row.count for row in type_class_result}
return EventSummary(
total_active=status_counts.get("active", 0),
total_acknowledged=status_counts.get("acknowledged", 0),
total_resolved=status_counts.get("resolved", 0),
by_priority=priority_counts,
by_domain=domain_counts,
by_type_class=type_class_counts,
critical_alerts=priority_counts.get("critical", 0),
important_alerts=priority_counts.get("important", 0)
)
async def acknowledge_event(self, event_id: UUID) -> Event:
"""Mark event as acknowledged"""
event = await self.get_event_by_id(event_id)
if not event:
raise ValueError(f"Event {event_id} not found")
event.status = "acknowledged"
event.acknowledged_at = datetime.now(timezone.utc)
await self.session.commit()
await self.session.refresh(event)
logger.info("event_acknowledged", event_id=event_id)
return event
async def resolve_event(self, event_id: UUID) -> Event:
"""Mark event as resolved"""
event = await self.get_event_by_id(event_id)
if not event:
raise ValueError(f"Event {event_id} not found")
event.status = "resolved"
event.resolved_at = datetime.now(timezone.utc)
await self.session.commit()
await self.session.refresh(event)
logger.info("event_resolved", event_id=event_id)
return event
async def dismiss_event(self, event_id: UUID) -> Event:
"""Mark event as dismissed"""
event = await self.get_event_by_id(event_id)
if not event:
raise ValueError(f"Event {event_id} not found")
event.status = "dismissed"
await self.session.commit()
await self.session.refresh(event)
logger.info("event_dismissed", event_id=event_id)
return event
def _event_to_response(self, event: Event) -> EventResponse:
"""Convert Event model to EventResponse"""
return EventResponse(
id=event.id,
tenant_id=event.tenant_id,
created_at=event.created_at,
event_class=event.event_class,
event_domain=event.event_domain,
event_type=event.event_type,
i18n=I18nContent(
title_key=event.i18n_title_key,
title_params=event.i18n_title_params,
message_key=event.i18n_message_key,
message_params=event.i18n_message_params
),
priority_score=event.priority_score,
priority_level=event.priority_level,
type_class=event.type_class,
smart_actions=[SmartAction(**action) for action in event.smart_actions],
status=event.status,
orchestrator_context=event.orchestrator_context,
business_impact=event.business_impact,
urgency=event.urgency,
user_agency=event.user_agency,
ai_reasoning_summary_key=event.ai_reasoning_summary_key,
ai_reasoning_summary_params=event.ai_reasoning_summary_params,
ai_reasoning_details=event.ai_reasoning_details,
confidence_score=event.confidence_score,
entity_links=event.entity_links,
event_metadata=event.event_metadata
)
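
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): listing active events for a tenant. Assumes
# a reachable Postgres instance; the DSN is a placeholder. Run with
# asyncio.run(_demo_list_active_events(some_tenant_uuid)).
async def _demo_list_active_events(tenant_id: UUID) -> None:
    from sqlalchemy.ext.asyncio import create_async_engine
    from sqlalchemy.orm import sessionmaker

    engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/alerts")
    session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    async with session_factory() as session:
        repo = EventRepository(session)
        events = await repo.get_events(tenant_id=tenant_id, status=["active"], limit=10)
        for event in events:
            print(event.event_type, event.priority_level, event.priority_score)
    await engine.dispose()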

View File

@@ -0,0 +1,180 @@
"""
Pydantic schemas for enriched events.
"""
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional, Literal
from datetime import datetime
from uuid import UUID
class I18nContent(BaseModel):
"""i18n content structure"""
title_key: str
title_params: Dict[str, Any] = {}
message_key: str
message_params: Dict[str, Any] = {}
class SmartAction(BaseModel):
"""Smart action button"""
action_type: str
label_key: str
label_params: Dict[str, Any] = {}
variant: Literal["primary", "secondary", "danger", "ghost"]
disabled: bool = False
disabled_reason_key: Optional[str] = None
consequence_key: Optional[str] = None
url: Optional[str] = None
metadata: Dict[str, Any] = {}
class BusinessImpact(BaseModel):
"""Business impact context"""
financial_impact_eur: float = 0
affected_orders: int = 0
affected_customers: List[str] = []
production_delay_hours: float = 0
estimated_revenue_loss_eur: float = 0
customer_impact: Literal["low", "medium", "high"] = "low"
waste_risk_kg: float = 0
class Urgency(BaseModel):
"""Urgency context"""
hours_until_consequence: float = 24
can_wait_until_tomorrow: bool = True
deadline_utc: Optional[str] = None
peak_hour_relevant: bool = False
hours_pending: float = 0
class UserAgency(BaseModel):
"""User agency context"""
can_user_fix: bool = True
requires_external_party: bool = False
external_party_name: Optional[str] = None
external_party_contact: Optional[str] = None
blockers: List[str] = []
suggested_workaround: Optional[str] = None
class OrchestratorContext(BaseModel):
"""AI orchestrator context"""
already_addressed: bool = False
action_id: Optional[str] = None
action_type: Optional[str] = None
action_summary: Optional[str] = None
reasoning: Optional[str] = None
confidence: float = 0.8
class EnrichedEvent(BaseModel):
"""Complete enriched event with all context"""
# Core fields
id: str
tenant_id: str
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
# Classification
event_class: Literal["alert", "notification", "recommendation"]
event_domain: str
event_type: str
service: str
# i18n content
i18n: I18nContent
# Priority
priority_score: int = Field(ge=0, le=100)
priority_level: Literal["critical", "important", "standard", "info"]
type_class: str
# Enrichment contexts
orchestrator_context: Optional[OrchestratorContext] = None
business_impact: Optional[BusinessImpact] = None
urgency: Optional[Urgency] = None
user_agency: Optional[UserAgency] = None
trend_context: Optional[Dict[str, Any]] = None
# Smart actions
smart_actions: List[SmartAction] = []
# AI reasoning
ai_reasoning_summary_key: Optional[str] = None
ai_reasoning_summary_params: Optional[Dict[str, Any]] = None
ai_reasoning_details: Optional[Dict[str, Any]] = None
confidence_score: Optional[float] = None
# Entity references
entity_links: Dict[str, str] = {}
# Status
status: Literal["active", "acknowledged", "resolved", "dismissed"] = "active"
resolved_at: Optional[datetime] = None
acknowledged_at: Optional[datetime] = None
# Original metadata
event_metadata: Dict[str, Any] = {}
class Config:
from_attributes = True
class EventResponse(BaseModel):
"""Event response for API"""
id: UUID
tenant_id: UUID
created_at: datetime
event_class: str
event_domain: str
event_type: str
i18n: I18nContent
priority_score: int
priority_level: str
type_class: str
smart_actions: List[SmartAction]
status: str
# Optional enrichment contexts (only if present)
orchestrator_context: Optional[Dict[str, Any]] = None
business_impact: Optional[Dict[str, Any]] = None
urgency: Optional[Dict[str, Any]] = None
user_agency: Optional[Dict[str, Any]] = None
# AI reasoning
ai_reasoning_summary_key: Optional[str] = None
ai_reasoning_summary_params: Optional[Dict[str, Any]] = None
ai_reasoning_details: Optional[Dict[str, Any]] = None
confidence_score: Optional[float] = None
entity_links: Dict[str, str] = {}
event_metadata: Optional[Dict[str, Any]] = None
class Config:
from_attributes = True
class EventSummary(BaseModel):
"""Summary statistics for dashboard"""
total_active: int
total_acknowledged: int
total_resolved: int
by_priority: Dict[str, int]
by_domain: Dict[str, int]
by_type_class: Dict[str, int]
critical_alerts: int
important_alerts: int
class EventFilter(BaseModel):
"""Filter criteria for event queries"""
tenant_id: UUID
event_class: Optional[str] = None
priority_level: Optional[List[str]] = None
status: Optional[List[str]] = None
event_domain: Optional[str] = None
limit: int = Field(default=50, le=100)
offset: int = 0
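
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the smallest EnrichedEvent that validates;
# all identifiers below are made up.
if __name__ == "__main__":
    demo_event = EnrichedEvent(
        id="11111111-1111-1111-1111-111111111111",
        tenant_id="22222222-2222-2222-2222-222222222222",
        event_class="alert",
        event_domain="inventory",
        event_type="critical_stock_shortage",
        service="inventory",
        i18n=I18nContent(
            title_key="alerts.critical_stock_shortage.title",
            message_key="alerts.critical_stock_shortage.message_generic",
        ),
        priority_score=85,
        priority_level="important",
        type_class="action_needed",
    )
    # Optional contexts, smart actions and status fall back to their defaults.
    print(demo_event.priority_level, demo_event.status, demo_event.smart_actions)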

View File

@@ -0,0 +1,246 @@
"""
Enrichment orchestrator service.
Coordinates the complete enrichment pipeline for events.
"""
from typing import Dict, Any
import structlog
from uuid import uuid4
from shared.messaging import MinimalEvent
from app.schemas.events import EnrichedEvent, I18nContent, BusinessImpact, Urgency, UserAgency, OrchestratorContext
from app.enrichment.message_generator import MessageGenerator
from app.enrichment.priority_scorer import PriorityScorer
from app.enrichment.orchestrator_client import OrchestratorClient
from app.enrichment.smart_actions import SmartActionGenerator
from app.enrichment.business_impact import BusinessImpactAnalyzer
from app.enrichment.urgency_analyzer import UrgencyAnalyzer
from app.enrichment.user_agency import UserAgencyAnalyzer
logger = structlog.get_logger()
class EnrichmentOrchestrator:
"""Coordinates the enrichment pipeline for events"""
def __init__(self):
self.message_gen = MessageGenerator()
self.priority_scorer = PriorityScorer()
self.orchestrator_client = OrchestratorClient()
self.action_gen = SmartActionGenerator()
self.impact_analyzer = BusinessImpactAnalyzer()
self.urgency_analyzer = UrgencyAnalyzer()
self.agency_analyzer = UserAgencyAnalyzer()
async def enrich_event(self, event: MinimalEvent) -> EnrichedEvent:
"""
Run complete enrichment pipeline.
        Steps:
        1. Generate i18n message keys and parameters
        2. Query orchestrator for AI context
        3. Analyze business impact
        4. Assess urgency
        5. Determine user agency
        6. Calculate priority score (0-100)
        7. Determine priority level
        8. Generate smart actions
        9. Determine type class
        10. Extract AI reasoning from metadata (if present)
        11. Build enriched event
Args:
event: Minimal event from service
Returns:
Enriched event with all context
"""
logger.info("enrichment_started", event_type=event.event_type, tenant_id=event.tenant_id)
# 1. Generate i18n message keys and parameters
i18n_dict = self.message_gen.generate_message(event.event_type, event.metadata, event.event_class)
i18n = I18nContent(**i18n_dict)
        # 2. Query orchestrator for AI context
orchestrator_context_dict = await self.orchestrator_client.get_context(
tenant_id=event.tenant_id,
event_type=event.event_type,
metadata=event.metadata
)
# Fallback: If orchestrator service didn't return context with already_addressed,
# check if the event metadata contains orchestrator_context (e.g., from demo seeder)
if not orchestrator_context_dict.get("already_addressed"):
metadata_context = event.metadata.get("orchestrator_context")
if metadata_context and isinstance(metadata_context, dict):
# Merge metadata context into orchestrator context
orchestrator_context_dict.update(metadata_context)
logger.debug(
"using_metadata_orchestrator_context",
event_type=event.event_type,
already_addressed=metadata_context.get("already_addressed")
)
# Convert to OrchestratorContext if data exists
orchestrator_context = None
if orchestrator_context_dict:
orchestrator_context = OrchestratorContext(**orchestrator_context_dict)
# 3. Analyze business impact
business_impact_dict = self.impact_analyzer.analyze(
event_type=event.event_type,
metadata=event.metadata
)
business_impact = BusinessImpact(**business_impact_dict)
# 4. Assess urgency
urgency_dict = self.urgency_analyzer.analyze(
event_type=event.event_type,
metadata=event.metadata
)
urgency = Urgency(**urgency_dict)
# 5. Determine user agency
user_agency_dict = self.agency_analyzer.analyze(
event_type=event.event_type,
metadata=event.metadata,
orchestrator_context=orchestrator_context_dict
)
user_agency = UserAgency(**user_agency_dict)
# 6. Calculate priority score (0-100)
priority_score = self.priority_scorer.calculate_priority(
business_impact=business_impact_dict,
urgency=urgency_dict,
user_agency=user_agency_dict,
orchestrator_context=orchestrator_context_dict
)
# 7. Determine priority level
priority_level = self._get_priority_level(priority_score)
# 8. Generate smart actions
smart_actions = self.action_gen.generate_actions(
event_type=event.event_type,
metadata=event.metadata,
orchestrator_context=orchestrator_context_dict
)
# 9. Determine type class
type_class = self._determine_type_class(orchestrator_context_dict, event.metadata)
# 10. Extract AI reasoning from metadata (if present)
reasoning_data = event.metadata.get('reasoning_data')
ai_reasoning_details = None
confidence_score = None
if reasoning_data:
# Store the complete reasoning data structure
ai_reasoning_details = reasoning_data
# Extract confidence if available
if isinstance(reasoning_data, dict):
metadata_section = reasoning_data.get('metadata', {})
if isinstance(metadata_section, dict) and 'confidence' in metadata_section:
confidence_score = metadata_section.get('confidence')
# 11. Build enriched event
enriched = EnrichedEvent(
id=str(uuid4()),
tenant_id=event.tenant_id,
event_class=event.event_class,
event_domain=event.event_domain,
event_type=event.event_type,
service=event.service,
i18n=i18n,
priority_score=priority_score,
priority_level=priority_level,
type_class=type_class,
orchestrator_context=orchestrator_context,
business_impact=business_impact,
urgency=urgency,
user_agency=user_agency,
smart_actions=smart_actions,
ai_reasoning_details=ai_reasoning_details,
confidence_score=confidence_score,
entity_links=self._extract_entity_links(event.metadata),
status="active",
event_metadata=event.metadata
)
logger.info(
"enrichment_completed",
event_type=event.event_type,
priority_score=priority_score,
priority_level=priority_level,
type_class=type_class
)
return enriched
def _get_priority_level(self, score: int) -> str:
"""
Convert numeric score to priority level.
- 90-100: critical
- 70-89: important
- 50-69: standard
- 0-49: info
"""
if score >= 90:
return "critical"
elif score >= 70:
return "important"
elif score >= 50:
return "standard"
else:
return "info"
def _determine_type_class(self, orchestrator_context: dict, metadata: dict = None) -> str:
"""
Determine type class based on orchestrator context or metadata override.
Priority order:
1. Explicit type_class in metadata (e.g., from demo seeder)
2. orchestrator_context.already_addressed = True -> "prevented_issue"
3. Default: "action_needed"
- prevented_issue: AI already handled it
- action_needed: User action required
"""
# Check for explicit type_class in metadata (allows demo seeder override)
if metadata:
explicit_type_class = metadata.get("type_class")
if explicit_type_class in ("prevented_issue", "action_needed"):
return explicit_type_class
# Determine from orchestrator context
if orchestrator_context and orchestrator_context.get("already_addressed"):
return "prevented_issue"
return "action_needed"
def _extract_entity_links(self, metadata: dict) -> Dict[str, str]:
"""
Extract entity references from metadata.
Maps metadata keys to entity types for frontend deep linking.
"""
links = {}
# Map metadata keys to entity types
entity_mappings = {
"po_id": "purchase_order",
"batch_id": "production_batch",
"ingredient_id": "ingredient",
"order_id": "order",
"supplier_id": "supplier",
"equipment_id": "equipment",
"sensor_id": "sensor"
}
for key, entity_type in entity_mappings.items():
if key in metadata:
links[entity_type] = str(metadata[key])
return links
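
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): enrich_event expects an object exposing the
# MinimalEvent fields read above; a SimpleNamespace stands in here so the exact
# MinimalEvent constructor does not have to be assumed. The run also assumes the
# orchestrator client can reach its service (or fall back gracefully) and that
# MessageGenerator knows the event type.
if __name__ == "__main__":
    import asyncio
    from types import SimpleNamespace

    async def _demo_enrich() -> None:
        orchestrator = EnrichmentOrchestrator()
        fake_event = SimpleNamespace(
            tenant_id="22222222-2222-2222-2222-222222222222",
            event_class="alert",
            event_domain="inventory",
            event_type="low_stock_warning",
            service="inventory",
            metadata={"ingredient_name": "Flour", "current_stock": 5, "minimum_stock": 20},
        )
        enriched = await orchestrator.enrich_event(fake_event)
        print(enriched.priority_score, enriched.priority_level, enriched.type_class)

    asyncio.run(_demo_enrich())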

View File

@@ -0,0 +1,129 @@
"""
Server-Sent Events (SSE) service using Redis pub/sub.
"""
from typing import AsyncGenerator
import json
import structlog
from redis.asyncio import Redis
from app.core.config import settings
from app.models.events import Event
from shared.redis_utils import get_redis_client
logger = structlog.get_logger()
class SSEService:
"""
Manage real-time event streaming via Redis pub/sub.
    Channel pattern: {REDIS_SSE_PREFIX}:{tenant_id} (e.g. alerts:{tenant_id})
"""
def __init__(self, redis: Redis = None):
self._redis = redis # Use private attribute to allow lazy loading
self.prefix = settings.REDIS_SSE_PREFIX
@property
async def redis(self) -> Redis:
"""
Lazy load Redis client if not provided through dependency injection.
Uses the shared Redis utilities for consistency.
"""
if self._redis is None:
self._redis = await get_redis_client()
return self._redis
async def publish_event(self, event: Event) -> bool:
"""
Publish event to Redis for SSE streaming.
Args:
event: Event to publish
Returns:
True if published successfully
"""
try:
redis_client = await self.redis
# Build channel name
channel = f"{self.prefix}:{event.tenant_id}"
# Build message payload
payload = {
"id": str(event.id),
"tenant_id": str(event.tenant_id),
"event_class": event.event_class,
"event_domain": event.event_domain,
"event_type": event.event_type,
"priority_score": event.priority_score,
"priority_level": event.priority_level,
"type_class": event.type_class,
"status": event.status,
"created_at": event.created_at.isoformat(),
"i18n": {
"title_key": event.i18n_title_key,
"title_params": event.i18n_title_params,
"message_key": event.i18n_message_key,
"message_params": event.i18n_message_params
},
"smart_actions": event.smart_actions,
"entity_links": event.entity_links
}
# Publish to Redis
await redis_client.publish(channel, json.dumps(payload))
logger.debug(
"sse_event_published",
channel=channel,
event_type=event.event_type,
event_id=str(event.id)
)
return True
except Exception as e:
logger.error(
"sse_publish_failed",
error=str(e),
event_id=str(event.id)
)
return False
async def subscribe_to_tenant(
self,
tenant_id: str
) -> AsyncGenerator[str, None]:
"""
Subscribe to tenant's alert stream.
Args:
tenant_id: Tenant UUID
Yields:
JSON-encoded event messages
"""
redis_client = await self.redis
channel = f"{self.prefix}:{tenant_id}"
logger.info("sse_subscription_started", channel=channel)
# Subscribe to Redis channel
pubsub = redis_client.pubsub()
await pubsub.subscribe(channel)
try:
async for message in pubsub.listen():
if message["type"] == "message":
yield message["data"]
except Exception as e:
logger.error("sse_subscription_error", error=str(e), channel=channel)
raise
finally:
await pubsub.unsubscribe(channel)
await pubsub.close()
logger.info("sse_subscription_closed", channel=channel)

View File

@@ -0,0 +1,556 @@
"""
Alert type definitions with i18n key mappings.
Each alert type maps to:
- title_key: i18n key for title (e.g., "alerts.critical_stock_shortage.title")
- title_params: parameter mappings from metadata to i18n params
- message_variants: different message keys based on context
- message_params: parameter mappings for message
When adding new alert types:
1. Add entry to ALERT_TEMPLATES
2. Ensure corresponding translations exist in frontend/src/locales/*/alerts.json
3. Document required metadata fields
"""
# Alert type templates
ALERT_TEMPLATES = {
# ==================== INVENTORY ALERTS ====================
"critical_stock_shortage": {
"title_key": "alerts.critical_stock_shortage.title",
"title_params": {
"ingredient_name": "ingredient_name"
},
"message_variants": {
"with_po_pending": "alerts.critical_stock_shortage.message_with_po_pending",
"with_po_created": "alerts.critical_stock_shortage.message_with_po_created",
"with_hours": "alerts.critical_stock_shortage.message_with_hours",
"with_date": "alerts.critical_stock_shortage.message_with_date",
"generic": "alerts.critical_stock_shortage.message_generic"
},
"message_params": {
"ingredient_name": "ingredient_name",
"current_stock_kg": "current_stock",
"required_stock_kg": "required_stock",
"hours_until": "hours_until",
"production_day_name": "production_date",
"po_id": "po_id",
"po_amount": "po_amount",
"delivery_day_name": "delivery_date"
}
},
"low_stock_warning": {
"title_key": "alerts.low_stock.title",
"title_params": {
"ingredient_name": "ingredient_name"
},
"message_variants": {
"with_po": "alerts.low_stock.message_with_po",
"generic": "alerts.low_stock.message_generic"
},
"message_params": {
"ingredient_name": "ingredient_name",
"current_stock_kg": "current_stock",
"minimum_stock_kg": "minimum_stock"
}
},
"overstock_warning": {
"title_key": "alerts.overstock_warning.title",
"title_params": {
"ingredient_name": "ingredient_name"
},
"message_variants": {
"generic": "alerts.overstock_warning.message"
},
"message_params": {
"ingredient_name": "ingredient_name",
"current_stock_kg": "current_stock",
"maximum_stock_kg": "maximum_stock",
"excess_amount_kg": "excess_amount"
}
},
"expired_products": {
"title_key": "alerts.expired_products.title",
"title_params": {
"count": "expired_count"
},
"message_variants": {
"with_names": "alerts.expired_products.message_with_names",
"generic": "alerts.expired_products.message_generic"
},
"message_params": {
"expired_count": "expired_count",
"product_names": "product_names",
"total_value_eur": "total_value"
}
},
"urgent_expiry": {
"title_key": "alerts.urgent_expiry.title",
"title_params": {
"ingredient_name": "ingredient_name"
},
"message_variants": {
"generic": "alerts.urgent_expiry.message"
},
"message_params": {
"ingredient_name": "ingredient_name",
"days_until_expiry": "days_until_expiry",
"quantity_kg": "quantity"
}
},
"temperature_breach": {
"title_key": "alerts.temperature_breach.title",
"title_params": {
"location": "location"
},
"message_variants": {
"generic": "alerts.temperature_breach.message"
},
"message_params": {
"location": "location",
"temperature": "temperature",
"max_threshold": "max_threshold",
"duration_minutes": "duration_minutes"
}
},
"stock_depleted_by_order": {
"title_key": "alerts.stock_depleted_by_order.title",
"title_params": {
"ingredient_name": "ingredient_name"
},
"message_variants": {
"with_supplier": "alerts.stock_depleted_by_order.message_with_supplier",
"generic": "alerts.stock_depleted_by_order.message_generic"
},
"message_params": {
"ingredient_name": "ingredient_name",
"shortage_kg": "shortage_amount",
"supplier_name": "supplier_name",
"supplier_contact": "supplier_contact"
}
},
# ==================== PRODUCTION ALERTS ====================
"production_delay": {
"title_key": "alerts.production_delay.title",
"title_params": {
"product_name": "product_name",
"batch_number": "batch_number"
},
"message_variants": {
"with_customers": "alerts.production_delay.message_with_customers",
"with_orders": "alerts.production_delay.message_with_orders",
"generic": "alerts.production_delay.message_generic"
},
"message_params": {
"product_name": "product_name",
"batch_number": "batch_number",
"delay_minutes": "delay_minutes",
"affected_orders": "affected_orders",
"customer_names": "customer_names"
}
},
"equipment_failure": {
"title_key": "alerts.equipment_failure.title",
"title_params": {
"equipment_name": "equipment_name"
},
"message_variants": {
"with_batches": "alerts.equipment_failure.message_with_batches",
"generic": "alerts.equipment_failure.message_generic"
},
"message_params": {
"equipment_name": "equipment_name",
"equipment_type": "equipment_type",
"affected_batches": "affected_batches"
}
},
"maintenance_required": {
"title_key": "alerts.maintenance_required.title",
"title_params": {
"equipment_name": "equipment_name"
},
"message_variants": {
"with_hours": "alerts.maintenance_required.message_with_hours",
"with_days": "alerts.maintenance_required.message_with_days",
"generic": "alerts.maintenance_required.message_generic"
},
"message_params": {
"equipment_name": "equipment_name",
"hours_overdue": "hours_overdue",
"days_overdue": "days_overdue"
}
},
"low_equipment_efficiency": {
"title_key": "alerts.low_equipment_efficiency.title",
"title_params": {
"equipment_name": "equipment_name"
},
"message_variants": {
"generic": "alerts.low_equipment_efficiency.message"
},
"message_params": {
"equipment_name": "equipment_name",
"efficiency_percentage": "efficiency_percentage",
"target_efficiency": "target_efficiency"
}
},
"capacity_overload": {
"title_key": "alerts.capacity_overload.title",
"title_params": {
"date": "planned_date"
},
"message_variants": {
"generic": "alerts.capacity_overload.message"
},
"message_params": {
"planned_date": "planned_date",
"capacity_percentage": "capacity_percentage",
"equipment_count": "equipment_count"
}
},
"quality_control_failure": {
"title_key": "alerts.quality_control_failure.title",
"title_params": {
"product_name": "product_name",
"batch_number": "batch_number"
},
"message_variants": {
"generic": "alerts.quality_control_failure.message"
},
"message_params": {
"product_name": "product_name",
"batch_number": "batch_number",
"check_type": "check_type",
"quality_score": "quality_score",
"defect_count": "defect_count"
}
},
# ==================== PROCUREMENT ALERTS ====================
"po_approval_needed": {
"title_key": "alerts.po_approval_needed.title",
"title_params": {
"po_number": "po_number"
},
"message_variants": {
"generic": "alerts.po_approval_needed.message"
},
"message_params": {
"supplier_name": "supplier_name",
"total_amount": "total_amount",
"currency": "currency",
"required_delivery_date": "required_delivery_date",
"items_count": "items_count"
}
},
"po_approval_escalation": {
"title_key": "alerts.po_approval_escalation.title",
"title_params": {
"po_number": "po_number"
},
"message_variants": {
"generic": "alerts.po_approval_escalation.message"
},
"message_params": {
"po_number": "po_number",
"supplier_name": "supplier_name",
"hours_pending": "hours_pending",
"total_amount": "total_amount"
}
},
"delivery_overdue": {
"title_key": "alerts.delivery_overdue.title",
"title_params": {
"po_number": "po_number"
},
"message_variants": {
"generic": "alerts.delivery_overdue.message"
},
"message_params": {
"po_number": "po_number",
"supplier_name": "supplier_name",
"days_overdue": "days_overdue",
"expected_date": "expected_date"
}
},
# ==================== SUPPLY CHAIN ALERTS ====================
"supplier_delay": {
"title_key": "alerts.supplier_delay.title",
"title_params": {
"supplier_name": "supplier_name"
},
"message_variants": {
"generic": "alerts.supplier_delay.message"
},
"message_params": {
"supplier_name": "supplier_name",
"po_count": "po_count",
"avg_delay_days": "avg_delay_days"
}
},
# ==================== DEMAND ALERTS ====================
"demand_surge_weekend": {
"title_key": "alerts.demand_surge_weekend.title",
"title_params": {},
"message_variants": {
"generic": "alerts.demand_surge_weekend.message"
},
"message_params": {
"product_name": "product_name",
"predicted_demand": "predicted_demand",
"current_stock": "current_stock"
}
},
"weather_impact_alert": {
"title_key": "alerts.weather_impact_alert.title",
"title_params": {},
"message_variants": {
"generic": "alerts.weather_impact_alert.message"
},
"message_params": {
"weather_condition": "weather_condition",
"impact_percentage": "impact_percentage",
"date": "date"
}
},
# ==================== PRODUCTION BATCH ALERTS ====================
"production_batch_start": {
"title_key": "alerts.production_batch_start.title",
"title_params": {
"product_name": "product_name"
},
"message_variants": {
"generic": "alerts.production_batch_start.message"
},
"message_params": {
"product_name": "product_name",
"batch_number": "batch_number",
"quantity_planned": "quantity_planned",
"unit": "unit",
"priority": "priority"
}
},
# ==================== GENERIC FALLBACK ====================
"generic": {
"title_key": "alerts.generic.title",
"title_params": {},
"message_variants": {
"generic": "alerts.generic.message"
},
"message_params": {
"event_type": "event_type"
}
}
}
# Notification templates (informational events)
NOTIFICATION_TEMPLATES = {
"po_approved": {
"title_key": "notifications.po_approved.title",
"title_params": {
"po_number": "po_number"
},
"message_variants": {
"generic": "notifications.po_approved.message"
},
"message_params": {
"supplier_name": "supplier_name",
"total_amount": "total_amount"
}
},
"batch_state_changed": {
"title_key": "notifications.batch_state_changed.title",
"title_params": {
"product_name": "product_name"
},
"message_variants": {
"generic": "notifications.batch_state_changed.message"
},
"message_params": {
"batch_number": "batch_number",
"new_status": "new_status",
"quantity": "quantity",
"unit": "unit"
}
},
"stock_received": {
"title_key": "notifications.stock_received.title",
"title_params": {
"ingredient_name": "ingredient_name"
},
"message_variants": {
"generic": "notifications.stock_received.message"
},
"message_params": {
"quantity_received": "quantity_received",
"unit": "unit",
"supplier_name": "supplier_name"
}
}
}
# Recommendation templates (optimization suggestions)
RECOMMENDATION_TEMPLATES = {
"inventory_optimization": {
"title_key": "recommendations.inventory_optimization.title",
"title_params": {
"ingredient_name": "ingredient_name"
},
"message_variants": {
"generic": "recommendations.inventory_optimization.message"
},
"message_params": {
"ingredient_name": "ingredient_name",
"current_max_kg": "current_max",
"suggested_max_kg": "suggested_max",
"recommendation_type": "recommendation_type"
}
},
"production_efficiency": {
"title_key": "recommendations.production_efficiency.title",
"title_params": {
"product_name": "product_name"
},
"message_variants": {
"generic": "recommendations.production_efficiency.message"
},
"message_params": {
"product_name": "product_name",
"potential_time_saved_minutes": "time_saved",
"suggestion": "suggestion"
}
},
# ==================== AI INSIGHTS RECOMMENDATIONS ====================
"ai_yield_prediction": {
"title_key": "recommendations.ai_yield_prediction.title",
"title_params": {
"recipe_name": "recipe_name"
},
"message_variants": {
"generic": "recommendations.ai_yield_prediction.message"
},
"message_params": {
"recipe_name": "recipe_name",
"predicted_yield_percent": "predicted_yield",
"confidence_percent": "confidence",
"recommendation": "recommendation"
}
},
"ai_safety_stock_optimization": {
"title_key": "recommendations.ai_safety_stock_optimization.title",
"title_params": {
"ingredient_name": "ingredient_name"
},
"message_variants": {
"generic": "recommendations.ai_safety_stock_optimization.message"
},
"message_params": {
"ingredient_name": "ingredient_name",
"suggested_safety_stock_kg": "suggested_safety_stock",
"current_safety_stock_kg": "current_safety_stock",
"estimated_savings_eur": "estimated_savings",
"confidence_percent": "confidence"
}
},
"ai_supplier_recommendation": {
"title_key": "recommendations.ai_supplier_recommendation.title",
"title_params": {
"supplier_name": "supplier_name"
},
"message_variants": {
"generic": "recommendations.ai_supplier_recommendation.message"
},
"message_params": {
"supplier_name": "supplier_name",
"reliability_score": "reliability_score",
"recommendation": "recommendation",
"confidence_percent": "confidence"
}
},
"ai_price_forecast": {
"title_key": "recommendations.ai_price_forecast.title",
"title_params": {
"ingredient_name": "ingredient_name"
},
"message_variants": {
"generic": "recommendations.ai_price_forecast.message"
},
"message_params": {
"ingredient_name": "ingredient_name",
"predicted_price_eur": "predicted_price",
"current_price_eur": "current_price",
"price_trend": "price_trend",
"recommendation": "recommendation",
"confidence_percent": "confidence"
}
},
"ai_demand_forecast": {
"title_key": "recommendations.ai_demand_forecast.title",
"title_params": {
"product_name": "product_name"
},
"message_variants": {
"generic": "recommendations.ai_demand_forecast.message"
},
"message_params": {
"product_name": "product_name",
"predicted_demand": "predicted_demand",
"forecast_period": "forecast_period",
"confidence_percent": "confidence",
"recommendation": "recommendation"
}
},
"ai_business_rule": {
"title_key": "recommendations.ai_business_rule.title",
"title_params": {
"rule_category": "rule_category"
},
"message_variants": {
"generic": "recommendations.ai_business_rule.message"
},
"message_params": {
"rule_category": "rule_category",
"rule_description": "rule_description",
"confidence_percent": "confidence",
"recommendation": "recommendation"
}
}
}
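
A template entry above only declares which i18n keys to use and how raw event payload fields map onto translation parameters; rendering happens elsewhere in the service. As a minimal sketch of how a consumer might resolve one of these entries against an event payload; the `resolve_template` helper, the payload shape, and the `variant` argument are illustrative assumptions, not part of the service code:

```python
# Hypothetical resolver for the template maps above. Falls back to the
# "generic" entry (which ALERT_TEMPLATES defines); the payload shape is assumed.
from typing import Any


def resolve_template(
    templates: dict[str, dict[str, Any]],
    event_type: str,
    payload: dict[str, Any],
    variant: str = "generic",
) -> dict[str, Any]:
    template = templates.get(event_type, templates["generic"])
    message_variants = template["message_variants"]
    message_key = message_variants.get(variant, message_variants["generic"])
    return {
        "i18n_title_key": template["title_key"],
        # params map i18n parameter name -> event payload field name
        "i18n_title_params": {
            param: payload.get(field) for param, field in template["title_params"].items()
        },
        "i18n_message_key": message_key,
        "i18n_message_params": {
            param: payload.get(field) for param, field in template["message_params"].items()
        },
    }


# Example, using the field names from the "urgent_expiry" template above:
# resolve_template(ALERT_TEMPLATES, "urgent_expiry",
#                  {"ingredient_name": "Flour", "days_until_expiry": 2, "quantity": 12.5})
```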

View File

@@ -0,0 +1,134 @@
"""Alembic environment configuration for alert_processor service"""
import asyncio
import os
import sys
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from alembic import context
# Add the service directory to the Python path
service_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if service_path not in sys.path:
sys.path.insert(0, service_path)
# Add shared modules to path
shared_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "shared"))
if shared_path not in sys.path:
sys.path.insert(0, shared_path)
try:
from shared.database.base import Base
# Import all models to ensure they are registered with Base.metadata
from app.models import * # noqa: F401, F403
except ImportError as e:
print(f"Import error in migrations env.py: {e}")
print(f"Current Python path: {sys.path}")
raise
# this is the Alembic Config object
config = context.config
# Determine service name from file path
service_name = os.path.basename(os.path.dirname(os.path.dirname(__file__)))
service_name_upper = service_name.upper().replace('-', '_')
# Set database URL from environment variables with multiple fallback strategies
database_url = (
os.getenv(f'{service_name_upper}_DATABASE_URL') or # Service-specific
os.getenv('DATABASE_URL') # Generic fallback
)
# If DATABASE_URL is not set, construct from individual components
if not database_url:
# Try generic PostgreSQL environment variables first
postgres_host = os.getenv('POSTGRES_HOST')
postgres_port = os.getenv('POSTGRES_PORT', '5432')
postgres_db = os.getenv('POSTGRES_DB')
postgres_user = os.getenv('POSTGRES_USER')
postgres_password = os.getenv('POSTGRES_PASSWORD')
if all([postgres_host, postgres_db, postgres_user, postgres_password]):
database_url = f"postgresql+asyncpg://{postgres_user}:{postgres_password}@{postgres_host}:{postgres_port}/{postgres_db}"
else:
# Try service-specific environment variables
db_host = os.getenv(f'{service_name_upper}_DB_HOST', f'{service_name}-db-service')
db_port = os.getenv(f'{service_name_upper}_DB_PORT', '5432')
db_name = os.getenv(f'{service_name_upper}_DB_NAME', f'{service_name.replace("-", "_")}_db')
db_user = os.getenv(f'{service_name_upper}_DB_USER', f'{service_name.replace("-", "_")}_user')
db_password = os.getenv(f'{service_name_upper}_DB_PASSWORD')
if db_password:
database_url = f"postgresql+asyncpg://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}"
if not database_url:
error_msg = f"ERROR: No database URL configured for {service_name} service"
print(error_msg)
raise Exception(error_msg)
config.set_main_option("sqlalchemy.url", database_url)
# Interpret the config file for Python logging
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Set target metadata
target_metadata = Base.metadata
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode."""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
compare_server_default=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
"""Execute migrations with the given connection."""
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
compare_server_default=True,
)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""Run migrations in 'online' mode with async support."""
connectable = async_engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
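
Since the URL resolution above checks several environment variables in order, a deployment only needs to export one of them before invoking Alembic. Below is a minimal sketch of driving this env.py from Python rather than the `alembic upgrade head` CLI; the connection URL is a placeholder, not a real credential:

```python
# Minimal sketch: run the migrations programmatically. The URL is a placeholder;
# any of the variables checked in env.py (service-specific or generic) works.
import os

from alembic import command
from alembic.config import Config

os.environ.setdefault(
    "DATABASE_URL",
    "postgresql+asyncpg://postgres:postgres@localhost:5432/alert_processor_db",
)

alembic_cfg = Config("alembic.ini")  # resolved relative to the service directory
command.upgrade(alembic_cfg, "head")
```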

View File

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,97 @@
"""
Clean unified events table schema.
Revision ID: 20251205_unified
Revises:
Create Date: 2025-12-05
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers
revision = '20251205_unified'
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
"""
Create unified events table with JSONB enrichment contexts.
"""
# Create events table
op.create_table(
'events',
# Core fields
sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column('tenant_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
# Classification
sa.Column('event_class', sa.String(50), nullable=False),
sa.Column('event_domain', sa.String(50), nullable=False),
sa.Column('event_type', sa.String(100), nullable=False),
sa.Column('service', sa.String(50), nullable=False),
# i18n content (NO hardcoded title/message)
sa.Column('i18n_title_key', sa.String(200), nullable=False),
sa.Column('i18n_title_params', postgresql.JSONB, nullable=False, server_default=sa.text("'{}'::jsonb")),
sa.Column('i18n_message_key', sa.String(200), nullable=False),
sa.Column('i18n_message_params', postgresql.JSONB, nullable=False, server_default=sa.text("'{}'::jsonb")),
# Priority
sa.Column('priority_score', sa.Integer, nullable=False, server_default='50'),
sa.Column('priority_level', sa.String(20), nullable=False),
sa.Column('type_class', sa.String(50), nullable=False),
# Enrichment contexts (JSONB)
sa.Column('orchestrator_context', postgresql.JSONB, nullable=True),
sa.Column('business_impact', postgresql.JSONB, nullable=True),
sa.Column('urgency', postgresql.JSONB, nullable=True),
sa.Column('user_agency', postgresql.JSONB, nullable=True),
sa.Column('trend_context', postgresql.JSONB, nullable=True),
# Smart actions
sa.Column('smart_actions', postgresql.JSONB, nullable=False, server_default=sa.text("'[]'::jsonb")),
# AI reasoning
sa.Column('ai_reasoning_summary_key', sa.String(200), nullable=True),
sa.Column('ai_reasoning_summary_params', postgresql.JSONB, nullable=True),
sa.Column('ai_reasoning_details', postgresql.JSONB, nullable=True),
sa.Column('confidence_score', sa.Float, nullable=True),
# Entity references
sa.Column('entity_links', postgresql.JSONB, nullable=False, server_default=sa.text("'{}'::jsonb")),
# Status
sa.Column('status', sa.String(20), nullable=False, server_default='active'),
sa.Column('resolved_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('acknowledged_at', sa.DateTime(timezone=True), nullable=True),
# Metadata
sa.Column('event_metadata', postgresql.JSONB, nullable=False, server_default=sa.text("'{}'::jsonb"))
)
# Create indexes for efficient queries (matching SQLAlchemy model)
op.create_index('idx_events_tenant_status', 'events', ['tenant_id', 'status'])
op.create_index('idx_events_tenant_priority', 'events', ['tenant_id', 'priority_score'])
op.create_index('idx_events_tenant_class', 'events', ['tenant_id', 'event_class'])
op.create_index('idx_events_tenant_created', 'events', ['tenant_id', 'created_at'])
op.create_index('idx_events_type_class_status', 'events', ['type_class', 'status'])
def downgrade():
"""
Drop events table and all indexes.
"""
op.drop_index('idx_events_type_class_status', 'events')
op.drop_index('idx_events_tenant_created', 'events')
op.drop_index('idx_events_tenant_class', 'events')
op.drop_index('idx_events_tenant_priority', 'events')
op.drop_index('idx_events_tenant_status', 'events')
op.drop_table('events')
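
The composite indexes above are all tenant-scoped, which matches the expected read path: fetch a tenant's active events ordered by priority. A rough sketch of that query using SQLAlchemy Core reflection against this schema; the connection URL is a placeholder and the helper is illustrative, not service code:

```python
# Sketch of the read path idx_events_tenant_priority is built for.
# Reflects the events table created by the migration above; URL is a placeholder.
import uuid

from sqlalchemy import MetaData, Table, select
from sqlalchemy.ext.asyncio import create_async_engine


async def active_events_for_tenant(tenant_id: uuid.UUID, limit: int = 20):
    engine = create_async_engine(
        "postgresql+asyncpg://postgres:postgres@localhost:5432/alert_processor_db"
    )
    async with engine.connect() as conn:
        events = await conn.run_sync(
            lambda sync_conn: Table("events", MetaData(), autoload_with=sync_conn)
        )
        query = (
            select(events.c.id, events.c.event_type, events.c.priority_score)
            .where(events.c.tenant_id == tenant_id, events.c.status == "active")
            .order_by(events.c.priority_score.desc())
            .limit(limit)
        )
        rows = (await conn.execute(query)).all()
    await engine.dispose()
    return rows


# Usage: asyncio.run(active_events_for_tenant(some_tenant_uuid))
```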

View File

@@ -0,0 +1,45 @@
# Alert Processor Service v2.0 Dependencies
# FastAPI and server
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6
# Database
sqlalchemy[asyncio]==2.0.23
asyncpg==0.29.0
alembic==1.12.1
psycopg2-binary==2.9.9
# RabbitMQ
aio-pika==9.3.0
# Redis
redis[hiredis]==5.0.1
# HTTP client
httpx==0.25.1
# Validation and settings
pydantic==2.5.0
pydantic-settings==2.1.0
# Structured logging
structlog==23.2.0
# Utilities
python-dateutil==2.8.2
# Authentication
python-jose[cryptography]==3.3.0
# Monitoring and Observability
psutil==5.9.8
opentelemetry-api==1.39.1
opentelemetry-sdk==1.39.1
opentelemetry-instrumentation-fastapi==0.60b1
opentelemetry-exporter-otlp-proto-grpc==1.39.1
opentelemetry-exporter-otlp-proto-http==1.39.1
opentelemetry-instrumentation-httpx==0.60b1
opentelemetry-instrumentation-redis==0.60b1
opentelemetry-instrumentation-sqlalchemy==0.60b1

64
services/auth/Dockerfile Normal file
View File

@@ -0,0 +1,64 @@
# =============================================================================
# Auth Service Dockerfile - Environment-Configurable Base Images
# =============================================================================
# Build arguments for registry configuration:
# - BASE_REGISTRY: Registry URL (default: docker.io for Docker Hub)
# - PYTHON_IMAGE: Python image name and tag (default: python:3.11-slim)
# =============================================================================
ARG BASE_REGISTRY=docker.io
ARG PYTHON_IMAGE=python:3.11-slim
FROM ${BASE_REGISTRY}/${PYTHON_IMAGE} AS shared
WORKDIR /shared
COPY shared/ /shared/
ARG BASE_REGISTRY=docker.io
ARG PYTHON_IMAGE=python:3.11-slim
FROM ${BASE_REGISTRY}/${PYTHON_IMAGE}
# Create non-root user for security
RUN groupadd -r appgroup && useradd -r -g appgroup appuser
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements
COPY shared/requirements-tracing.txt /tmp/
COPY services/auth/requirements.txt .
# Install Python dependencies
RUN pip install --no-cache-dir -r /tmp/requirements-tracing.txt
RUN pip install --no-cache-dir -r requirements.txt
# Copy shared libraries from the shared stage
COPY --from=shared /shared /app/shared
# Copy application code
COPY services/auth/ .
# Change ownership to non-root user
RUN chown -R appuser:appgroup /app
# Add shared libraries to Python path
ENV PYTHONPATH="/app:/app/shared:${PYTHONPATH:-}"
# Switch to non-root user
USER appuser
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Run application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

1074
services/auth/README.md Normal file

File diff suppressed because it is too large

84
services/auth/alembic.ini Normal file
View File

@@ -0,0 +1,84 @@
# ================================================================
# services/auth/alembic.ini - Alembic Configuration
# ================================================================
[alembic]
# path to migration scripts
script_location = migrations
# template used to generate migration file names
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
timezone = Europe/Madrid
# max length of characters to apply to the
# "slug" field
truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
sourceless = false
# version of a migration file's filename format
version_num_format = %%s
# version path separator
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
output_encoding = utf-8
# Database URL - will be overridden by environment variable or settings
sqlalchemy.url = postgresql+asyncpg://auth_user:password@auth-db-service:5432/auth_db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -0,0 +1,3 @@
from .internal_demo import router as internal_demo_router
__all__ = ["internal_demo_router"]

View File

@@ -0,0 +1,214 @@
"""
User self-service account deletion API for GDPR compliance
Implements Article 17 (Right to erasure / "Right to be forgotten")
"""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
from pydantic import BaseModel, Field
from datetime import datetime, timezone
import structlog
from shared.auth.decorators import get_current_user_dep
from app.core.database import get_db
from app.services.admin_delete import AdminUserDeleteService
from app.models.users import User
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import httpx
logger = structlog.get_logger()
router = APIRouter()
class AccountDeletionRequest(BaseModel):
"""Request model for account deletion"""
confirm_email: str = Field(..., description="User's email for confirmation")
reason: str = Field(default="", description="Optional reason for deletion")
password: str = Field(..., description="User's password for verification")
class DeletionScheduleResponse(BaseModel):
"""Response for scheduled deletion"""
message: str
user_id: str
scheduled_deletion_date: str
grace_period_days: int = 30
@router.delete("/api/v1/auth/me/account")
async def request_account_deletion(
deletion_request: AccountDeletionRequest,
request: Request,
current_user: dict = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Request account deletion (self-service)
GDPR Article 17 - Right to erasure ("right to be forgotten")
    Deletion is processed immediately and is irreversible:
    - The active subscription (if any) is cancelled first
    - The account and all associated personal data are permanently deleted
    - Audit logs are retained in anonymized form for the legal retention period
Requires:
- Email confirmation matching logged-in user
- Current password verification
"""
try:
user_id = UUID(current_user["user_id"])
user_email = current_user.get("email")
if deletion_request.confirm_email.lower() != user_email.lower():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email confirmation does not match your account email"
)
query = select(User).where(User.id == user_id)
result = await db.execute(query)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
from app.core.security import SecurityManager
if not SecurityManager.verify_password(deletion_request.password, user.hashed_password):
logger.warning(
"account_deletion_invalid_password",
user_id=str(user_id),
ip_address=request.client.host if request.client else None
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid password"
)
logger.info(
"account_deletion_requested",
user_id=str(user_id),
email=user_email,
reason=deletion_request.reason[:100] if deletion_request.reason else None,
ip_address=request.client.host if request.client else None
)
tenant_id = current_user.get("tenant_id")
if tenant_id:
try:
async with httpx.AsyncClient(timeout=10.0) as client:
cancel_response = await client.get(
f"http://tenant-service:8000/api/v1/tenants/{tenant_id}/subscription/status",
headers={"Authorization": request.headers.get("Authorization")}
)
if cancel_response.status_code == 200:
subscription_data = cancel_response.json()
if subscription_data.get("status") in ["active", "pending_cancellation"]:
cancel_sub_response = await client.delete(
f"http://tenant-service:8000/api/v1/tenants/{tenant_id}/subscription",
headers={"Authorization": request.headers.get("Authorization")}
)
logger.info(
"subscription_cancelled_before_deletion",
user_id=str(user_id),
tenant_id=tenant_id,
subscription_status=subscription_data.get("status")
)
except Exception as e:
logger.warning(
"subscription_cancellation_failed_during_account_deletion",
user_id=str(user_id),
error=str(e)
)
deletion_service = AdminUserDeleteService(db)
result = await deletion_service.delete_admin_user_complete(
user_id=str(user_id),
requesting_user_id=str(user_id)
)
return {
"message": "Account deleted successfully",
"user_id": str(user_id),
"deletion_date": datetime.now(timezone.utc).isoformat(),
"data_retained": "Audit logs will be anonymized after legal retention period (1 year)",
"gdpr_article": "Article 17 - Right to erasure"
}
except HTTPException:
raise
except Exception as e:
logger.error(
"account_deletion_failed",
user_id=current_user.get("user_id"),
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to process account deletion request"
)
@router.get("/api/v1/auth/me/account/deletion-info")
async def get_deletion_info(
current_user: dict = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get information about what will be deleted
    Shows the user exactly what data will be deleted when they request
account deletion. Transparency requirement under GDPR.
"""
try:
user_id = UUID(current_user["user_id"])
deletion_service = AdminUserDeleteService(db)
preview = await deletion_service.preview_user_deletion(str(user_id))
return {
"user_info": preview.get("user"),
"what_will_be_deleted": {
"account_data": "Your account, email, name, and profile information",
"sessions": "All active sessions and refresh tokens",
"consents": "Your consent history and preferences",
"security_data": "Login history and security logs",
"tenant_data": preview.get("tenant_associations"),
"estimated_records": preview.get("estimated_deletions")
},
"what_will_be_retained": {
"audit_logs": "Anonymized for 1 year (legal requirement)",
"financial_records": "Anonymized for 7 years (tax law)",
"anonymized_analytics": "Aggregated data without personal identifiers"
},
"process": {
"immediate_deletion": True,
"grace_period": "No grace period - deletion is immediate",
"reversible": False,
"completion_time": "Immediate"
},
"gdpr_rights": {
"article_17": "Right to erasure (right to be forgotten)",
"article_5_1_e": "Storage limitation principle",
"exceptions": "Data required for legal obligations will be retained in anonymized form"
},
"warning": "⚠️ This action is irreversible. All your data will be permanently deleted."
}
except Exception as e:
logger.error(
"deletion_info_failed",
user_id=current_user.get("user_id"),
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve deletion information"
)
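
For reference, the self-service deletion call carries the account email and current password in the request body together with the user's bearer token. A minimal client sketch follows; the base URL, token, and credentials are placeholders:

```python
# Hypothetical client call for the deletion endpoint above; values are placeholders.
import httpx

BASE_URL = "http://auth-service:8000"
TOKEN = "<access token of the logged-in user>"

response = httpx.request(
    "DELETE",
    f"{BASE_URL}/api/v1/auth/me/account",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={
        "confirm_email": "owner@example.com",
        "password": "current-password",
        "reason": "closing the account",
    },
)
response.raise_for_status()
print(response.json()["message"])  # "Account deleted successfully"
```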

View File

@@ -0,0 +1,657 @@
"""
Refactored Auth Operations with proper 3DS/3DS2 support
Implements SetupIntent-first architecture for secure registration flows
"""
import logging
from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException, status, Request
from app.services.auth_service import auth_service, AuthService
from app.schemas.auth import UserRegistration, UserLogin, UserResponse
from app.models.users import User
from shared.exceptions.auth_exceptions import (
UserCreationError,
RegistrationError,
PaymentOrchestrationError
)
from shared.auth.decorators import get_current_user_dep
# Configure logging
logger = logging.getLogger(__name__)
# Create router
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
async def get_auth_service() -> AuthService:
"""Dependency injection for auth service"""
return auth_service
@router.post("/start-registration",
response_model=Dict[str, Any],
summary="Start secure registration with payment verification")
async def start_registration(
user_data: UserRegistration,
request: Request,
auth_service: AuthService = Depends(get_auth_service)
) -> Dict[str, Any]:
"""
Start secure registration flow with SetupIntent-first approach
This is the FIRST step in the atomic registration architecture:
1. Creates Stripe customer via tenant service
2. Creates SetupIntent with confirm=True
3. Returns SetupIntent data to frontend
IMPORTANT: NO subscription or user is created in this step!
Two possible outcomes:
- requires_action=True: 3DS required, frontend must confirm SetupIntent then call complete-registration
- requires_action=False: No 3DS required, but frontend STILL must call complete-registration
In BOTH cases, the frontend must call complete-registration to create the subscription and user.
This ensures consistent flow and prevents duplicate subscriptions.
Args:
user_data: User registration data with payment info
Returns:
SetupIntent result with:
- requires_action: True if 3DS required, False if not
- setup_intent_id: SetupIntent ID for verification
- client_secret: For 3DS authentication (when requires_action=True)
- customer_id: Stripe customer ID
- Other SetupIntent metadata
Raises:
HTTPException: 400 for validation errors, 500 for server errors
"""
try:
logger.info(f"Starting secure registration flow, email={user_data.email}, plan={user_data.subscription_plan}")
# Validate required fields
if not user_data.email or not user_data.email.strip():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email is required"
)
if not user_data.password or len(user_data.password) < 8:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Password must be at least 8 characters long"
)
if not user_data.full_name or not user_data.full_name.strip():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Full name is required"
)
if user_data.subscription_plan and not user_data.payment_method_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Payment method ID is required for subscription registration"
)
# Start secure registration flow
result = await auth_service.start_secure_registration_flow(user_data)
# Check if 3DS is required
if result.get('requires_action', False):
logger.info(f"Registration requires 3DS verification, email={user_data.email}, setup_intent_id={result.get('setup_intent_id')}")
return {
"requires_action": True,
"action_type": "setup_intent_confirmation",
"client_secret": result.get('client_secret'),
"setup_intent_id": result.get('setup_intent_id'),
"customer_id": result.get('customer_id'),
"payment_customer_id": result.get('payment_customer_id'),
"plan_id": result.get('plan_id'),
"payment_method_id": result.get('payment_method_id'),
"billing_cycle": result.get('billing_cycle'),
"coupon_info": result.get('coupon_info'),
"trial_info": result.get('trial_info'),
"email": result.get('email'),
"message": "Payment verification required. Frontend must confirm SetupIntent to handle 3DS."
}
else:
user = result.get('user')
user_id = user.id if user else None
logger.info(f"Registration completed without 3DS, email={user_data.email}, user_id={user_id}, subscription_id={result.get('subscription_id')}")
# Return complete registration result
user_data_response = None
if user:
user_data_response = {
"id": str(user.id),
"email": user.email,
"full_name": user.full_name,
"is_active": user.is_active
}
return {
"requires_action": False,
"setup_intent_id": result.get('setup_intent_id'),
"user": user_data_response,
"subscription_id": result.get('subscription_id'),
"payment_customer_id": result.get('payment_customer_id'),
"status": result.get('status'),
"message": "Registration completed successfully"
}
    except HTTPException:
        # Re-raise HTTP exceptions (400 validation errors) without converting them to 500s
        raise
    except RegistrationError as e:
logger.error(f"Registration flow failed: {str(e)}, email: {user_data.email}",
exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Registration failed: {str(e)}"
) from e
except PaymentOrchestrationError as e:
logger.error(f"Payment orchestration failed: {str(e)}, email: {user_data.email}",
exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Payment setup failed: {str(e)}"
) from e
except Exception as e:
logger.error(f"Unexpected registration error: {str(e)}, email: {user_data.email}",
exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Registration error: {str(e)}"
) from e
@router.post("/complete-registration",
response_model=Dict[str, Any],
summary="Complete registration after SetupIntent verification")
async def complete_registration(
verification_data: Dict[str, Any],
request: Request,
auth_service: AuthService = Depends(get_auth_service)
) -> Dict[str, Any]:
"""
Complete registration after frontend confirms SetupIntent
This is the SECOND step in the atomic registration architecture:
1. Called after frontend confirms SetupIntent (with or without 3DS)
2. Verifies SetupIntent status with Stripe
3. Creates subscription with verified payment method (FIRST time subscription is created)
4. Creates user record in auth database
5. Saves onboarding progress
6. Generates auth tokens for auto-login
This endpoint is called in TWO scenarios:
1. After user completes 3DS authentication (requires_action=True flow)
2. Immediately after start-registration (requires_action=False flow)
In BOTH cases, this is where the subscription and user are actually created.
This ensures consistent flow and prevents duplicate subscriptions.
Args:
verification_data: Must contain:
- setup_intent_id: Verified SetupIntent ID
- user_data: Original user registration data
Returns:
Complete registration result with:
- user: Created user data
- subscription_id: Created subscription ID
- payment_customer_id: Stripe customer ID
- access_token: JWT access token
- refresh_token: JWT refresh token
Raises:
HTTPException: 400 if setup_intent_id is missing, 500 for server errors
"""
try:
# Validate required fields
if not verification_data.get('setup_intent_id'):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="SetupIntent ID is required"
)
if not verification_data.get('user_data'):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User data is required"
)
# Extract user data
user_data_dict = verification_data['user_data']
user_data = UserRegistration(**user_data_dict)
logger.info(f"Completing registration after SetupIntent verification, email={user_data.email}, setup_intent_id={verification_data['setup_intent_id']}")
# Complete registration with verified payment
result = await auth_service.complete_registration_with_verified_payment(
verification_data['setup_intent_id'],
user_data
)
logger.info(f"Registration completed successfully after 3DS, user_id={result['user'].id}, email={result['user'].email}, subscription_id={result.get('subscription_id')}")
return {
"success": True,
"user": {
"id": str(result['user'].id),
"email": result['user'].email,
"full_name": result['user'].full_name,
"is_active": result['user'].is_active,
"is_verified": result['user'].is_verified,
"created_at": result['user'].created_at.isoformat() if result['user'].created_at else None,
"role": result['user'].role
},
"subscription_id": result.get('subscription_id'),
"payment_customer_id": result.get('payment_customer_id'),
"status": result.get('status'),
"access_token": result.get('access_token'),
"refresh_token": result.get('refresh_token'),
"message": "Registration completed successfully after 3DS verification"
}
    except HTTPException:
        # Re-raise HTTP exceptions (400 validation errors) without converting them to 500s
        raise
    except RegistrationError as e:
logger.error(f"Registration completion after 3DS failed: {str(e)}, setup_intent_id: {verification_data.get('setup_intent_id')}, email: {user_data_dict.get('email') if user_data_dict else 'unknown'}",
exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Registration completion failed: {str(e)}"
) from e
except Exception as e:
logger.error(f"Unexpected registration completion error: {str(e)}, setup_intent_id: {verification_data.get('setup_intent_id')}",
exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Registration completion error: {str(e)}"
) from e
@router.post("/login",
response_model=Dict[str, Any],
summary="User login with subscription validation")
async def login(
login_data: UserLogin,
request: Request,
auth_service: AuthService = Depends(get_auth_service)
) -> Dict[str, Any]:
"""
User login endpoint with subscription validation
This endpoint:
1. Validates user credentials
2. Checks if user has active subscription (if required)
3. Returns authentication tokens
4. Updates last login timestamp
Args:
login_data: User login credentials (email and password)
Returns:
Authentication tokens and user information
Raises:
HTTPException: 401 for invalid credentials, 403 for subscription issues
"""
try:
logger.info(f"Login attempt, email={login_data.email}")
# Validate required fields
if not login_data.email or not login_data.email.strip():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email is required"
)
if not login_data.password or len(login_data.password) < 8:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Password must be at least 8 characters long"
)
# Call auth service to perform login
result = await auth_service.login_user(login_data)
logger.info(f"Login successful, email={login_data.email}, user_id={result['user'].id}")
# Extract tokens from result for top-level response
tokens = result.get('tokens', {})
return {
"success": True,
"access_token": tokens.get('access_token'),
"refresh_token": tokens.get('refresh_token'),
"token_type": tokens.get('token_type'),
"expires_in": tokens.get('expires_in'),
"user": {
"id": str(result['user'].id),
"email": result['user'].email,
"full_name": result['user'].full_name,
"is_active": result['user'].is_active,
"last_login": result['user'].last_login.isoformat() if result['user'].last_login else None
},
"subscription": result.get('subscription', {}),
"message": "Login successful"
}
except HTTPException:
# Re-raise HTTP exceptions (like 401 for invalid credentials)
raise
except Exception as e:
logger.error(f"Login failed: {str(e)}, email: {login_data.email}",
exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Login failed: {str(e)}"
) from e
# ============================================================================
# TOKEN MANAGEMENT ENDPOINTS - NEWLY ADDED
# ============================================================================
@router.post("/refresh",
response_model=Dict[str, Any],
summary="Refresh access token using refresh token")
async def refresh_token(
request: Request,
refresh_data: Dict[str, Any],
auth_service: AuthService = Depends(get_auth_service)
) -> Dict[str, Any]:
"""
Refresh access token using a valid refresh token
This endpoint:
1. Validates the refresh token
2. Generates new access and refresh tokens
3. Returns the new tokens
Args:
refresh_data: Dictionary containing refresh_token
Returns:
New authentication tokens
Raises:
HTTPException: 401 for invalid refresh tokens
"""
try:
logger.info("Token refresh request initiated")
# Extract refresh token from request
refresh_token = refresh_data.get("refresh_token")
if not refresh_token:
logger.warning("Refresh token missing from request")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Refresh token is required"
)
# Use service layer to refresh tokens
tokens = await auth_service.refresh_auth_tokens(refresh_token)
logger.info("Token refresh successful via service layer")
return {
"success": True,
"access_token": tokens.get("access_token"),
"refresh_token": tokens.get("refresh_token"),
"token_type": "bearer",
"expires_in": 1800, # 30 minutes
"message": "Token refresh successful"
}
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.error(f"Token refresh failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Token refresh failed: {str(e)}"
) from e
@router.post("/verify",
response_model=Dict[str, Any],
summary="Verify token validity")
async def verify_token(
request: Request,
token_data: Dict[str, Any]
) -> Dict[str, Any]:
"""
Verify the validity of an access token
Args:
token_data: Dictionary containing access_token
Returns:
Token validation result
"""
try:
logger.info("Token verification request initiated")
# Extract access token from request
access_token = token_data.get("access_token")
if not access_token:
logger.warning("Access token missing from verification request")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Access token is required"
)
# Use service layer to verify token
result = await auth_service.verify_access_token(access_token)
logger.info("Token verification successful via service layer")
return {
"success": True,
"valid": result.get("valid"),
"user_id": result.get("user_id"),
"email": result.get("email"),
"message": "Token is valid"
}
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.error(f"Token verification failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Token verification failed: {str(e)}"
) from e
@router.post("/logout",
response_model=Dict[str, Any],
summary="Logout and revoke refresh token")
async def logout(
request: Request,
logout_data: Dict[str, Any],
auth_service: AuthService = Depends(get_auth_service)
) -> Dict[str, Any]:
"""
Logout user and revoke refresh token
Args:
logout_data: Dictionary containing refresh_token
Returns:
Logout confirmation
"""
try:
logger.info("Logout request initiated")
# Extract refresh token from request
refresh_token = logout_data.get("refresh_token")
if not refresh_token:
logger.warning("Refresh token missing from logout request")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Refresh token is required"
)
# Use service layer to revoke refresh token
try:
await auth_service.revoke_refresh_token(refresh_token)
logger.info("Logout successful via service layer")
return {
"success": True,
"message": "Logout successful"
}
except Exception as e:
logger.error(f"Error during logout: {str(e)}")
# Don't fail logout if revocation fails
return {
"success": True,
"message": "Logout successful (token revocation failed but user logged out)"
}
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.error(f"Logout failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Logout failed: {str(e)}"
) from e
@router.post("/change-password",
response_model=Dict[str, Any],
summary="Change user password")
async def change_password(
request: Request,
password_data: Dict[str, Any],
auth_service: AuthService = Depends(get_auth_service)
) -> Dict[str, Any]:
"""
Change user password
Args:
password_data: Dictionary containing current_password and new_password
Returns:
Password change confirmation
"""
try:
logger.info("Password change request initiated")
# Extract user from request state
if not hasattr(request.state, 'user') or not request.state.user:
logger.warning("Unauthorized password change attempt - no user context")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required"
)
user_id = request.state.user.get("user_id")
if not user_id:
logger.warning("Unauthorized password change attempt - no user_id")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid user context"
)
# Extract password data
current_password = password_data.get("current_password")
new_password = password_data.get("new_password")
if not current_password or not new_password:
logger.warning("Password change missing required fields")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password and new password are required"
)
if len(new_password) < 8:
logger.warning("New password too short")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="New password must be at least 8 characters long"
)
# Use service layer to change password
await auth_service.change_user_password(user_id, current_password, new_password)
logger.info(f"Password change successful via service layer, user_id={user_id}")
return {
"success": True,
"message": "Password changed successfully"
}
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.error(f"Password change failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Password change failed: {str(e)}"
) from e
@router.post("/verify-email",
response_model=Dict[str, Any],
summary="Verify user email")
async def verify_email(
request: Request,
email_data: Dict[str, Any],
auth_service: AuthService = Depends(get_auth_service)
) -> Dict[str, Any]:
"""
Verify user email (placeholder implementation)
Args:
email_data: Dictionary containing email and verification_token
Returns:
Email verification confirmation
"""
try:
logger.info("Email verification request initiated")
# Extract email and token
email = email_data.get("email")
verification_token = email_data.get("verification_token")
if not email or not verification_token:
logger.warning("Email verification missing required fields")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email and verification token are required"
)
# Use service layer to verify email
await auth_service.verify_user_email(email, verification_token)
logger.info("Email verification successful via service layer")
return {
"success": True,
"message": "Email verified successfully"
}
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.error(f"Email verification failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Email verification failed: {str(e)}"
) from e
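
The SetupIntent-first contract above always requires two calls from the client: start-registration to obtain the SetupIntent, then complete-registration once the SetupIntent is confirmed (with Stripe.js handling 3DS in the browser when requires_action is true). A condensed sketch of that sequence from a test client's point of view; the base URL, payload values, and the skipped browser-side confirmation are assumptions:

```python
# Hypothetical two-step registration flow against the endpoints above.
# In the requires_action=True case the SetupIntent must be confirmed with
# Stripe.js in the browser before step 2; that part is not shown here.
import httpx

BASE_URL = "http://auth-service:8000/api/v1/auth"  # placeholder

user_data = {
    "email": "owner@example.com",
    "password": "a-strong-password",
    "full_name": "Bakery Owner",
    "subscription_plan": "starter",         # placeholder plan id
    "payment_method_id": "pm_placeholder",  # placeholder payment method
}

with httpx.Client(timeout=30.0) as client:
    # Step 1: create the Stripe customer and SetupIntent (no user or subscription yet).
    start = client.post(f"{BASE_URL}/start-registration", json=user_data)
    start.raise_for_status()
    setup = start.json()

    if setup["requires_action"]:
        # Confirm the SetupIntent with setup["client_secret"] via Stripe.js (3DS).
        pass

    # Step 2: always complete the registration with the verified SetupIntent.
    complete = client.post(
        f"{BASE_URL}/complete-registration",
        json={"setup_intent_id": setup["setup_intent_id"], "user_data": user_data},
    )
    complete.raise_for_status()
    access_token = complete.json()["access_token"]
```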

View File

@@ -0,0 +1,372 @@
"""
User consent management API endpoints for GDPR compliance
"""
from typing import List, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Request
from pydantic import BaseModel, Field
from datetime import datetime, timezone
import structlog
import hashlib
from shared.auth.decorators import get_current_user_dep
from app.core.database import get_db
from app.models.consent import UserConsent, ConsentHistory
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_
logger = structlog.get_logger()
router = APIRouter()
class ConsentRequest(BaseModel):
"""Request model for granting/updating consent"""
terms_accepted: bool = Field(..., description="Accept terms of service")
privacy_accepted: bool = Field(..., description="Accept privacy policy")
marketing_consent: bool = Field(default=False, description="Consent to marketing communications")
analytics_consent: bool = Field(default=False, description="Consent to analytics cookies")
consent_method: str = Field(..., description="How consent was given (registration, settings, cookie_banner)")
consent_version: str = Field(default="1.0", description="Version of terms/privacy policy")
class ConsentResponse(BaseModel):
"""Response model for consent data"""
id: str
user_id: str
terms_accepted: bool
privacy_accepted: bool
marketing_consent: bool
analytics_consent: bool
consent_version: str
consent_method: str
consented_at: str
withdrawn_at: Optional[str]
class ConsentHistoryResponse(BaseModel):
"""Response model for consent history"""
id: str
user_id: str
action: str
consent_snapshot: dict
created_at: str
def hash_text(text: str) -> str:
"""Create hash of consent text for verification"""
return hashlib.sha256(text.encode()).hexdigest()
@router.post("/api/v1/auth/me/consent", response_model=ConsentResponse, status_code=status.HTTP_201_CREATED)
async def record_consent(
consent_data: ConsentRequest,
request: Request,
current_user: dict = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Record user consent for data processing
GDPR Article 7 - Conditions for consent
"""
try:
user_id = UUID(current_user["user_id"])
ip_address = request.client.host if request.client else None
user_agent = request.headers.get("user-agent")
consent = UserConsent(
user_id=user_id,
terms_accepted=consent_data.terms_accepted,
privacy_accepted=consent_data.privacy_accepted,
marketing_consent=consent_data.marketing_consent,
analytics_consent=consent_data.analytics_consent,
consent_version=consent_data.consent_version,
consent_method=consent_data.consent_method,
ip_address=ip_address,
user_agent=user_agent,
consented_at=datetime.now(timezone.utc)
)
db.add(consent)
await db.flush()
history = ConsentHistory(
user_id=user_id,
consent_id=consent.id,
action="granted",
consent_snapshot=consent_data.dict(),
ip_address=ip_address,
user_agent=user_agent,
consent_method=consent_data.consent_method,
created_at=datetime.now(timezone.utc)
)
db.add(history)
await db.commit()
await db.refresh(consent)
logger.info(
"consent_recorded",
user_id=str(user_id),
consent_version=consent_data.consent_version,
method=consent_data.consent_method
)
return ConsentResponse(
id=str(consent.id),
user_id=str(consent.user_id),
terms_accepted=consent.terms_accepted,
privacy_accepted=consent.privacy_accepted,
marketing_consent=consent.marketing_consent,
analytics_consent=consent.analytics_consent,
consent_version=consent.consent_version,
consent_method=consent.consent_method,
consented_at=consent.consented_at.isoformat(),
withdrawn_at=consent.withdrawn_at.isoformat() if consent.withdrawn_at else None
)
except Exception as e:
await db.rollback()
logger.error("error_recording_consent", error=str(e), user_id=current_user.get("user_id"))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to record consent"
)
@router.get("/api/v1/auth/me/consent/current", response_model=Optional[ConsentResponse])
async def get_current_consent(
current_user: dict = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get current active consent for user
"""
try:
user_id = UUID(current_user["user_id"])
query = select(UserConsent).where(
and_(
UserConsent.user_id == user_id,
UserConsent.withdrawn_at.is_(None)
)
).order_by(UserConsent.consented_at.desc())
result = await db.execute(query)
consent = result.scalar_one_or_none()
if not consent:
return None
return ConsentResponse(
id=str(consent.id),
user_id=str(consent.user_id),
terms_accepted=consent.terms_accepted,
privacy_accepted=consent.privacy_accepted,
marketing_consent=consent.marketing_consent,
analytics_consent=consent.analytics_consent,
consent_version=consent.consent_version,
consent_method=consent.consent_method,
consented_at=consent.consented_at.isoformat(),
withdrawn_at=consent.withdrawn_at.isoformat() if consent.withdrawn_at else None
)
except Exception as e:
logger.error("error_getting_consent", error=str(e), user_id=current_user.get("user_id"))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve consent"
)
@router.get("/api/v1/auth/me/consent/history", response_model=List[ConsentHistoryResponse])
async def get_consent_history(
current_user: dict = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get complete consent history for user
GDPR Article 7(1) - Demonstrating consent
"""
try:
user_id = UUID(current_user["user_id"])
query = select(ConsentHistory).where(
ConsentHistory.user_id == user_id
).order_by(ConsentHistory.created_at.desc())
result = await db.execute(query)
history = result.scalars().all()
return [
ConsentHistoryResponse(
id=str(h.id),
user_id=str(h.user_id),
action=h.action,
consent_snapshot=h.consent_snapshot,
created_at=h.created_at.isoformat()
)
for h in history
]
except Exception as e:
logger.error("error_getting_consent_history", error=str(e), user_id=current_user.get("user_id"))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve consent history"
)
@router.put("/api/v1/auth/me/consent", response_model=ConsentResponse)
async def update_consent(
consent_data: ConsentRequest,
request: Request,
current_user: dict = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Update user consent preferences
GDPR Article 7(3) - Withdrawal of consent
"""
try:
user_id = UUID(current_user["user_id"])
query = select(UserConsent).where(
and_(
UserConsent.user_id == user_id,
UserConsent.withdrawn_at.is_(None)
)
).order_by(UserConsent.consented_at.desc())
result = await db.execute(query)
current_consent = result.scalar_one_or_none()
if current_consent:
current_consent.withdrawn_at = datetime.now(timezone.utc)
history = ConsentHistory(
user_id=user_id,
consent_id=current_consent.id,
action="updated",
consent_snapshot=current_consent.to_dict(),
ip_address=request.client.host if request.client else None,
user_agent=request.headers.get("user-agent"),
consent_method=consent_data.consent_method,
created_at=datetime.now(timezone.utc)
)
db.add(history)
new_consent = UserConsent(
user_id=user_id,
terms_accepted=consent_data.terms_accepted,
privacy_accepted=consent_data.privacy_accepted,
marketing_consent=consent_data.marketing_consent,
analytics_consent=consent_data.analytics_consent,
consent_version=consent_data.consent_version,
consent_method=consent_data.consent_method,
ip_address=request.client.host if request.client else None,
user_agent=request.headers.get("user-agent"),
consented_at=datetime.now(timezone.utc)
)
db.add(new_consent)
await db.flush()
history = ConsentHistory(
user_id=user_id,
consent_id=new_consent.id,
action="granted" if not current_consent else "updated",
consent_snapshot=consent_data.dict(),
ip_address=request.client.host if request.client else None,
user_agent=request.headers.get("user-agent"),
consent_method=consent_data.consent_method,
created_at=datetime.now(timezone.utc)
)
db.add(history)
await db.commit()
await db.refresh(new_consent)
logger.info(
"consent_updated",
user_id=str(user_id),
consent_version=consent_data.consent_version
)
return ConsentResponse(
id=str(new_consent.id),
user_id=str(new_consent.user_id),
terms_accepted=new_consent.terms_accepted,
privacy_accepted=new_consent.privacy_accepted,
marketing_consent=new_consent.marketing_consent,
analytics_consent=new_consent.analytics_consent,
consent_version=new_consent.consent_version,
consent_method=new_consent.consent_method,
consented_at=new_consent.consented_at.isoformat(),
withdrawn_at=None
)
except Exception as e:
await db.rollback()
logger.error("error_updating_consent", error=str(e), user_id=current_user.get("user_id"))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update consent"
)
@router.post("/api/v1/auth/me/consent/withdraw", status_code=status.HTTP_200_OK)
async def withdraw_consent(
request: Request,
current_user: dict = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Withdraw all consent
GDPR Article 7(3) - Right to withdraw consent
"""
try:
user_id = UUID(current_user["user_id"])
query = select(UserConsent).where(
and_(
UserConsent.user_id == user_id,
UserConsent.withdrawn_at.is_(None)
)
)
result = await db.execute(query)
consents = result.scalars().all()
for consent in consents:
consent.withdrawn_at = datetime.now(timezone.utc)
history = ConsentHistory(
user_id=user_id,
consent_id=consent.id,
action="withdrawn",
consent_snapshot=consent.to_dict(),
ip_address=request.client.host if request.client else None,
user_agent=request.headers.get("user-agent"),
consent_method="user_withdrawal",
created_at=datetime.now(timezone.utc)
)
db.add(history)
await db.commit()
logger.info("consent_withdrawn", user_id=str(user_id), count=len(consents))
return {
"message": "Consent withdrawn successfully",
"withdrawn_count": len(consents)
}
except Exception as e:
await db.rollback()
logger.error("error_withdrawing_consent", error=str(e), user_id=current_user.get("user_id"))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to withdraw consent"
)
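
As a usage illustration (not part of this commit), a client holding a valid bearer token could exercise the withdrawal endpoint above roughly as follows; the base URL and token handling are assumptions:

```python
# Hypothetical client-side call to POST /api/v1/auth/me/consent/withdraw
import asyncio
import httpx

async def withdraw_my_consent(base_url: str, access_token: str) -> dict:
    async with httpx.AsyncClient(base_url=base_url) as client:
        response = await client.post(
            "/api/v1/auth/me/consent/withdraw",
            headers={"Authorization": f"Bearer {access_token}"},
        )
        response.raise_for_status()
        return response.json()  # {"message": ..., "withdrawn_count": <int>}

# Example: asyncio.run(withdraw_my_consent("http://localhost:8000", "<jwt>"))
```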

View File

@@ -0,0 +1,121 @@
"""
User data export API endpoints for GDPR compliance
Implements Article 15 (Right to Access) and Article 20 (Right to Data Portability)
"""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import JSONResponse
import structlog
from shared.auth.decorators import get_current_user_dep
from app.core.database import get_db
from app.services.data_export_service import DataExportService
logger = structlog.get_logger()
router = APIRouter()
@router.get("/api/v1/auth/me/export")
async def export_my_data(
current_user: dict = Depends(get_current_user_dep),
db = Depends(get_db)
):
"""
Export all personal data for the current user
GDPR Article 15 - Right of access by the data subject
GDPR Article 20 - Right to data portability
Returns complete user data in machine-readable JSON format including:
- Personal information
- Account data
- Consent history
- Security logs
- Audit trail
Response is provided in JSON format for easy data portability.
"""
try:
user_id = UUID(current_user["user_id"])
export_service = DataExportService(db)
data = await export_service.export_user_data(user_id)
logger.info(
"data_export_requested",
user_id=str(user_id),
email=current_user.get("email")
)
return JSONResponse(
content=data,
status_code=status.HTTP_200_OK,
headers={
"Content-Disposition": f'attachment; filename="user_data_export_{user_id}.json"',
"Content-Type": "application/json"
}
)
except Exception as e:
logger.error(
"data_export_failed",
user_id=current_user.get("user_id"),
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to export user data"
)
@router.get("/api/v1/auth/me/export/summary")
async def get_export_summary(
current_user: dict = Depends(get_current_user_dep),
db = Depends(get_db)
):
"""
Get a summary of what data would be exported
Useful for showing users what data we have about them
before they request full export.
"""
try:
user_id = UUID(current_user["user_id"])
export_service = DataExportService(db)
data = await export_service.export_user_data(user_id)
summary = {
"user_id": str(user_id),
"data_categories": {
"personal_data": bool(data.get("personal_data")),
"account_data": bool(data.get("account_data")),
"consent_data": bool(data.get("consent_data")),
"security_data": bool(data.get("security_data")),
"onboarding_data": bool(data.get("onboarding_data")),
"audit_logs": bool(data.get("audit_logs"))
},
"data_counts": {
"active_sessions": data.get("account_data", {}).get("active_sessions_count", 0),
"consent_changes": data.get("consent_data", {}).get("total_consent_changes", 0),
"login_attempts": len(data.get("security_data", {}).get("recent_login_attempts", [])),
"audit_logs": data.get("audit_logs", {}).get("total_logs_exported", 0)
},
"export_format": "JSON",
"gdpr_articles": ["Article 15 (Right to Access)", "Article 20 (Data Portability)"]
}
return summary
except Exception as e:
logger.error(
"export_summary_failed",
user_id=current_user.get("user_id"),
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate export summary"
)
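
For reference, a minimal sketch (assumed, not part of this commit) of how a frontend or script might fetch the export and persist the JSON attachment returned by the endpoint above:

```python
# Hypothetical download of GET /api/v1/auth/me/export
import httpx

def download_my_export(base_url: str, access_token: str, out_path: str) -> None:
    response = httpx.get(
        f"{base_url}/api/v1/auth/me/export",
        headers={"Authorization": f"Bearer {access_token}"},
        timeout=30.0,
    )
    response.raise_for_status()
    # The endpoint sets Content-Disposition, so the body is the full JSON export
    with open(out_path, "wb") as fh:
        fh.write(response.content)
```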

View File

@@ -0,0 +1,229 @@
"""
Internal Demo Cloning API for Auth Service
Service-to-service endpoint for cloning authentication and user data
"""
from fastapi import APIRouter, Depends, HTTPException, Header
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import structlog
import uuid
from datetime import datetime, timezone
from typing import Optional
import os
import sys
from pathlib import Path
import json
# Add shared path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent))
from app.core.database import get_db
from app.models.users import User
from app.core.config import settings
logger = structlog.get_logger()
router = APIRouter(prefix="/internal/demo", tags=["internal"])
# Base demo tenant IDs
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
@router.post("/clone")
async def clone_demo_data(
base_tenant_id: str,
virtual_tenant_id: str,
demo_account_type: str,
session_id: Optional[str] = None,
session_created_at: Optional[str] = None,
db: AsyncSession = Depends(get_db)
):
"""
Clone auth service data for a virtual demo tenant
Clones:
- Demo users (owner and staff)
Note: Tenant memberships are handled by the tenant service's internal_demo endpoint
Args:
base_tenant_id: Template tenant UUID to clone from
virtual_tenant_id: Target virtual tenant UUID
demo_account_type: Type of demo account
session_id: Originating session ID for tracing
Returns:
Cloning status and record counts
"""
start_time = datetime.now(timezone.utc)
# Parse session creation time
if session_created_at:
try:
session_time = datetime.fromisoformat(session_created_at.replace('Z', '+00:00'))
except (ValueError, AttributeError):
session_time = start_time
else:
session_time = start_time
logger.info(
"Starting auth data cloning",
base_tenant_id=base_tenant_id,
virtual_tenant_id=virtual_tenant_id,
demo_account_type=demo_account_type,
session_id=session_id,
session_created_at=session_created_at
)
try:
# Validate UUIDs
base_uuid = uuid.UUID(base_tenant_id)
virtual_uuid = uuid.UUID(virtual_tenant_id)
# Note: We don't check for existing users since User model doesn't have demo_session_id
# Demo users are identified by their email addresses from the seed data
# Idempotency is handled by checking if each user email already exists below
# Load demo users from JSON seed file
from shared.utils.seed_data_paths import get_seed_data_path
if demo_account_type == "professional":
json_file = get_seed_data_path("professional", "02-auth.json")
elif demo_account_type == "enterprise":
json_file = get_seed_data_path("enterprise", "02-auth.json")
elif demo_account_type == "enterprise_child":
# Child locations don't have separate auth data - they share parent's users
logger.info("enterprise_child uses parent tenant auth, skipping user cloning", virtual_tenant_id=virtual_tenant_id)
return {
"service": "auth",
"status": "completed",
"records_cloned": 0,
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"details": {"users": 0, "note": "Child locations share parent auth"}
}
else:
raise ValueError(f"Invalid demo account type: {demo_account_type}")
# Load JSON data
import json
with open(json_file, 'r', encoding='utf-8') as f:
seed_data = json.load(f)
# Get demo users for this account type
demo_users_data = seed_data.get("users", [])
records_cloned = 0
# Create users and tenant memberships
for user_data in demo_users_data:
user_id = uuid.UUID(user_data["id"])
# Create user if not exists
user_result = await db.execute(
select(User).where(User.id == user_id)
)
existing_user = user_result.scalars().first()
if not existing_user:
# Apply date adjustments to created_at and updated_at
from shared.utils.demo_dates import adjust_date_for_demo
# Adjust created_at date
created_at_str = user_data.get("created_at", session_time.isoformat())
if isinstance(created_at_str, str):
try:
original_created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
adjusted_created_at = adjust_date_for_demo(original_created_at, session_time)
except ValueError:
adjusted_created_at = session_time
else:
adjusted_created_at = session_time
# Adjust updated_at date (same as created_at for demo users)
adjusted_updated_at = adjusted_created_at
# Get full_name from either "name" or "full_name" field
full_name = user_data.get("full_name") or user_data.get("name", "Demo User")
# For demo users, use a placeholder hashed password (they won't actually log in)
# In production, this would be properly hashed
demo_hashed_password = "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/LewY5GyYqNlI.eFKW" # "demo_password"
user = User(
id=user_id,
email=user_data["email"],
full_name=full_name,
hashed_password=demo_hashed_password,
is_active=user_data.get("is_active", True),
is_verified=True,
role=user_data.get("role", "member"),
language=user_data.get("language", "es"),
timezone=user_data.get("timezone", "Europe/Madrid"),
created_at=adjusted_created_at,
updated_at=adjusted_updated_at
)
db.add(user)
records_cloned += 1
# Note: Tenant memberships are handled by tenant service
# Only create users in auth service
await db.commit()
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
logger.info(
"Auth data cloning completed",
virtual_tenant_id=virtual_tenant_id,
session_id=session_id,
records_cloned=records_cloned,
duration_ms=duration_ms
)
return {
"service": "auth",
"status": "completed",
"records_cloned": records_cloned,
"base_tenant_id": str(base_tenant_id),
"virtual_tenant_id": str(virtual_tenant_id),
"session_id": session_id,
"demo_account_type": demo_account_type,
"duration_ms": duration_ms
}
except ValueError as e:
logger.error("Invalid UUID format", error=str(e), virtual_tenant_id=virtual_tenant_id)
raise HTTPException(status_code=400, detail=f"Invalid UUID: {str(e)}")
except Exception as e:
logger.error(
"Failed to clone auth data",
error=str(e),
virtual_tenant_id=virtual_tenant_id,
exc_info=True
)
# Rollback on error
await db.rollback()
return {
"service": "auth",
"status": "failed",
"records_cloned": 0,
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"error": str(e)
}
@router.get("/clone/health")
async def clone_health_check():
"""
Health check for internal cloning endpoint
Used by orchestrator to verify service availability
"""
return {
"service": "auth",
"clone_endpoint": "available",
"version": "1.0.0"
}
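
Because the scalar arguments of `clone_demo_data` are not wrapped in a request model, FastAPI binds them as query parameters. A rough orchestrator-side sketch (hypothetical function name and base URL) could therefore look like this:

```python
# Hypothetical service-to-service call to POST /internal/demo/clone
import httpx

DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"  # template tenant from the module above

async def clone_auth_for_demo(auth_base_url: str, virtual_tenant_id: str, session_id: str) -> dict:
    async with httpx.AsyncClient(base_url=auth_base_url) as client:
        response = await client.post(
            "/internal/demo/clone",
            params={
                "base_tenant_id": DEMO_TENANT_PROFESSIONAL,
                "virtual_tenant_id": virtual_tenant_id,
                "demo_account_type": "professional",
                "session_id": session_id,
            },
        )
        response.raise_for_status()
        return response.json()  # {"service": "auth", "status": ..., "records_cloned": ...}
```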

File diff suppressed because it is too large

View File

@@ -0,0 +1,308 @@
# services/auth/app/api/password_reset.py
"""
Password reset API endpoints
Handles forgot password and password reset functionality
"""
import logging
from datetime import datetime, timedelta, timezone
from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from app.services.auth_service import auth_service, AuthService
from app.schemas.auth import PasswordReset, PasswordResetConfirm
from app.core.security import SecurityManager
from app.core.config import settings
from app.repositories.password_reset_repository import PasswordResetTokenRepository
from app.repositories.user_repository import UserRepository
from app.models.users import User
from shared.clients.notification_client import NotificationServiceClient
import structlog
# Configure logging
logger = structlog.get_logger()
# Create router
router = APIRouter(prefix="/api/v1/auth", tags=["password-reset"])
async def get_auth_service() -> AuthService:
"""Dependency injection for auth service"""
return auth_service
def generate_reset_token() -> str:
"""Generate a secure password reset token"""
import secrets
return secrets.token_urlsafe(32)
async def send_password_reset_email(email: str, reset_token: str, user_full_name: str):
"""Send password reset email in background using notification service"""
try:
# Construct reset link (this should match your frontend URL)
# Use FRONTEND_URL from settings if available, otherwise fall back to gateway URL
frontend_url = getattr(settings, 'FRONTEND_URL', settings.GATEWAY_URL)
reset_link = f"{frontend_url}/reset-password?token={reset_token}"
# Create HTML content for the password reset email in Spanish
html_content = f"""
<!DOCTYPE html>
<html lang="es">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Restablecer Contraseña</title>
<style>
body {{
font-family: Arial, sans-serif;
line-height: 1.6;
color: #333;
max-width: 600px;
margin: 0 auto;
padding: 20px;
background-color: #f9f9f9;
}}
.header {{
text-align: center;
margin-bottom: 30px;
background: linear-gradient(135deg, #4F46E5 0%, #7C3AED 100%);
color: white;
padding: 20px;
border-radius: 8px;
}}
.content {{
background: white;
padding: 30px;
border-radius: 8px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}}
.button {{
display: inline-block;
padding: 12px 30px;
background-color: #4F46E5;
color: white;
text-decoration: none;
border-radius: 5px;
margin: 20px 0;
font-weight: bold;
}}
.footer {{
margin-top: 40px;
text-align: center;
font-size: 0.9em;
color: #666;
padding-top: 20px;
border-top: 1px solid #eee;
}}
</style>
</head>
<body>
<div class="header">
<h1>Restablecer Contraseña</h1>
</div>
<div class="content">
<p>Hola {user_full_name},</p>
<p>Recibimos una solicitud para restablecer tu contraseña. Haz clic en el botón de abajo para crear una nueva contraseña:</p>
<p style="text-align: center; margin: 30px 0;">
<a href="{reset_link}" class="button">Restablecer Contraseña</a>
</p>
<p>Si no solicitaste un restablecimiento de contraseña, puedes ignorar este correo electrónico de forma segura.</p>
<p>Este enlace expirará en 1 hora por razones de seguridad.</p>
</div>
<div class="footer">
<p>Este es un mensaje automático de BakeWise. Por favor, no respondas a este correo electrónico.</p>
</div>
</body>
</html>
"""
# Create text content as fallback
text_content = f"""
Hola {user_full_name},
Recibimos una solicitud para restablecer tu contraseña. Haz clic en el siguiente enlace para crear una nueva contraseña:
{reset_link}
Si no solicitaste un restablecimiento de contraseña, puedes ignorar este correo electrónico de forma segura.
Este enlace expirará en 1 hora por razones de seguridad.
Este es un mensaje automático de BakeWise. Por favor, no respondas a este correo electrónico.
"""
# Send email using the notification service
notification_client = NotificationServiceClient(settings)
# Send the notification using the send_email method
await notification_client.send_email(
tenant_id="system", # Using system tenant for password resets
to_email=email,
subject="Restablecer Contraseña",
message=text_content,
html_content=html_content,
priority="high"
)
logger.info(f"Password reset email sent successfully to {email}")
except Exception as e:
logger.error(f"Failed to send password reset email to {email}: {str(e)}")
@router.post("/password/reset-request",
summary="Request password reset",
description="Send a password reset link to the user's email")
async def request_password_reset(
reset_request: PasswordReset,
background_tasks: BackgroundTasks,
auth_service: AuthService = Depends(get_auth_service)
) -> Dict[str, Any]:
"""
Request a password reset
This endpoint:
1. Finds the user by email
2. Generates a password reset token
3. Stores the token in the database
4. Sends a password reset email to the user
"""
try:
logger.info(f"Password reset request for email: {reset_request.email}")
# Find user by email
async with auth_service.database_manager.get_session() as session:
user_repo = UserRepository(User, session)
user = await user_repo.get_by_field("email", reset_request.email)
if not user:
# Don't reveal if email exists to prevent enumeration attacks
logger.info(f"Password reset request for non-existent email: {reset_request.email}")
return {"message": "If an account with this email exists, a reset link has been sent."}
# Generate a secure reset token
reset_token = generate_reset_token()
# Set token expiration (e.g., 1 hour)
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
# Store the reset token in the database
token_repo = PasswordResetTokenRepository(session)
# Clean up any existing unused tokens for this user
await token_repo.cleanup_expired_tokens()
# Create new reset token
await token_repo.create_token(
user_id=str(user.id),
token=reset_token,
expires_at=expires_at
)
# Commit the transaction
await session.commit()
# Send password reset email in background
background_tasks.add_task(
send_password_reset_email,
user.email,
reset_token,
user.full_name
)
logger.info(f"Password reset token created for user: {user.email}")
return {"message": "If an account with this email exists, a reset link has been sent."}
except HTTPException:
raise
except Exception as e:
logger.error(f"Password reset request failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Password reset request failed"
)
@router.post("/password/reset",
summary="Reset password with token",
description="Reset user password using a valid reset token")
async def reset_password(
reset_confirm: PasswordResetConfirm,
auth_service: AuthService = Depends(get_auth_service)
) -> Dict[str, Any]:
"""
Reset password using a valid reset token
This endpoint:
1. Validates the reset token
2. Checks if the token is valid and not expired
3. Updates the user's password
4. Marks the token as used
"""
try:
logger.info(f"Password reset attempt with token: {reset_confirm.token[:10]}...")
# Validate password strength
if not SecurityManager.validate_password(reset_confirm.new_password):
errors = SecurityManager.get_password_validation_errors(reset_confirm.new_password)
logger.warning(f"Password validation failed: {errors}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Password does not meet requirements: {'; '.join(errors)}"
)
# Find the reset token in the database
async with auth_service.database_manager.get_session() as session:
token_repo = PasswordResetTokenRepository(session)
reset_token_obj = await token_repo.get_token_by_value(reset_confirm.token)
if not reset_token_obj:
logger.warning(f"Invalid or expired password reset token: {reset_confirm.token[:10]}...")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired reset token"
)
# Get the user associated with this token
user_repo = UserRepository(User, session)
user = await user_repo.get_by_id(str(reset_token_obj.user_id))
if not user:
logger.error(f"User not found for reset token: {reset_confirm.token[:10]}...")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid reset token"
)
# Hash the new password
hashed_password = SecurityManager.hash_password(reset_confirm.new_password)
# Update user's password
await user_repo.update(str(user.id), {
"hashed_password": hashed_password
})
# Mark the reset token as used
await token_repo.mark_token_as_used(str(reset_token_obj.id))
# Commit the transactions
await session.commit()
logger.info(f"Password successfully reset for user: {user.email}")
return {"message": "Password has been reset successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Password reset failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Password reset failed"
)
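
Taken together, the two endpoints implement a standard two-step reset flow. A hedged client-side sketch follows; the JSON field names mirror how `PasswordReset` and `PasswordResetConfirm` are used above, and everything else (base URL, transport) is an assumption:

```python
# Hypothetical client-side flow for the password reset endpoints above
import httpx

def request_reset(base_url: str, email: str) -> None:
    # Step 1: always returns a generic message to avoid email enumeration
    r = httpx.post(f"{base_url}/api/v1/auth/password/reset-request", json={"email": email})
    r.raise_for_status()

def confirm_reset(base_url: str, token: str, new_password: str) -> None:
    # Step 2: token arrives via the email link (?token=...) and is single-use
    r = httpx.post(
        f"{base_url}/api/v1/auth/password/reset",
        json={"token": token, "new_password": new_password},
    )
    r.raise_for_status()
```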

View File

@@ -0,0 +1,662 @@
"""
User management API routes
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Path, Body
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Dict, Any
import structlog
import uuid
from datetime import datetime, timezone
from app.core.database import get_db, get_background_db_session
from app.schemas.auth import UserResponse, PasswordChange
from app.schemas.users import UserUpdate, BatchUserRequest, OwnerUserCreate
from app.services.user_service import EnhancedUserService
from app.models.users import User
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.admin_delete import AdminUserDeleteService
from app.models import AuditLog
# Import unified authentication from shared library
from shared.auth.decorators import (
get_current_user_dep,
require_admin_role_dep
)
from shared.security import create_audit_logger, AuditSeverity, AuditAction
logger = structlog.get_logger()
router = APIRouter(tags=["users"])
# Initialize audit logger
audit_logger = create_audit_logger("auth-service", AuditLog)
@router.delete("/api/v1/auth/users/{user_id}")
async def delete_admin_user(
background_tasks: BackgroundTasks,
user_id: str = Path(..., description="User ID"),
current_user = Depends(require_admin_role_dep),
db: AsyncSession = Depends(get_db)
):
"""
Delete an admin user and all associated data across all services.
This operation will:
1. Cancel any active training jobs for user's tenants
2. Delete all trained models and artifacts
3. Delete all forecasts and predictions
4. Delete notification preferences and logs
5. Handle tenant ownership (transfer or delete)
6. Delete user account and authentication data
**Warning: This operation is irreversible!**
"""
# Validate user_id format
try:
uuid.UUID(user_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid user ID format"
)
# Quick validation that user exists before starting background task
deletion_service = AdminUserDeleteService(db)
user_info = await deletion_service._validate_admin_user(user_id)
if not user_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Admin user {user_id} not found"
)
# Log audit event for user deletion
try:
# Get tenant_id from current_user or use a placeholder for system-level operations
tenant_id_str = current_user.get("tenant_id", "00000000-0000-0000-0000-000000000000")
await audit_logger.log_deletion(
db_session=db,
tenant_id=tenant_id_str,
user_id=current_user["user_id"],
resource_type="user",
resource_id=user_id,
resource_data=user_info,
description=f"Admin {current_user.get('email', current_user['user_id'])} initiated deletion of user {user_info.get('email', user_id)}",
endpoint="/delete/{user_id}",
method="DELETE"
)
except Exception as audit_error:
logger.warning("Failed to log audit event", error=str(audit_error))
# Start deletion as background task for better performance
background_tasks.add_task(
execute_admin_user_deletion,
user_id=user_id,
requesting_user_id=current_user["user_id"]
)
return {
"success": True,
"message": f"Admin user deletion for {user_id} has been initiated",
"status": "processing",
"user_info": user_info,
"initiated_at": datetime.utcnow().isoformat(),
"note": "Deletion is processing in the background. Check logs for completion status."
}
# Background task helper for admin user deletion (runs outside the request lifecycle)
async def execute_admin_user_deletion(user_id: str, requesting_user_id: str):
"""
Background task using shared infrastructure
"""
# ✅ Use the shared background session
async with get_background_db_session() as session:
deletion_service = AdminUserDeleteService(session)
result = await deletion_service.delete_admin_user_complete(
user_id=user_id,
requesting_user_id=requesting_user_id
)
logger.info("Background admin user deletion completed successfully",
user_id=user_id,
requesting_user=requesting_user_id,
result=result)
@router.get("/api/v1/auth/users/{user_id}/deletion-preview")
async def preview_user_deletion(
user_id: str = Path(..., description="User ID"),
db: AsyncSession = Depends(get_db)
):
"""
Preview what data would be deleted for an admin user.
This endpoint provides a dry-run preview of the deletion operation
without actually deleting any data.
"""
try:
uuid.UUID(user_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid user ID format"
)
deletion_service = AdminUserDeleteService(db)
# Get user info
user_info = await deletion_service._validate_admin_user(user_id)
if not user_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Admin user {user_id} not found"
)
# Get tenant associations
tenant_info = await deletion_service._get_user_tenant_info(user_id)
# Build preview
preview = {
"user": user_info,
"tenant_associations": tenant_info,
"estimated_deletions": {
"training_models": "All models for associated tenants",
"forecasts": "All forecasts for associated tenants",
"notifications": "All user notification data",
"tenant_memberships": tenant_info['total_tenants'],
"owned_tenants": f"{tenant_info['owned_tenants']} (will be transferred or deleted)"
},
"warning": "This operation is irreversible and will permanently delete all associated data"
}
return preview
@router.get("/api/v1/auth/users/{user_id}", response_model=UserResponse)
async def get_user_by_id(
user_id: str = Path(..., description="User ID"),
db: AsyncSession = Depends(get_db)
):
"""
Get user information by user ID.
This endpoint is for internal service-to-service communication.
It returns user details needed by other services (e.g., tenant service for enriching member data).
"""
try:
# Validate UUID format
try:
uuid.UUID(user_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid user ID format"
)
# Fetch user from database
from app.repositories import UserRepository
user_repo = UserRepository(User, db)
user = await user_repo.get_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User {user_id} not found"
)
logger.debug("Retrieved user by ID", user_id=user_id, email=user.email)
return UserResponse(
id=str(user.id),
email=user.email,
full_name=user.full_name,
is_active=user.is_active,
is_verified=user.is_verified,
phone=user.phone,
language=user.language or "es",
timezone=user.timezone or "Europe/Madrid",
created_at=user.created_at,
last_login=user.last_login,
role=user.role,
tenant_id=None,
payment_customer_id=user.payment_customer_id,
default_payment_method_id=user.default_payment_method_id
)
except HTTPException:
raise
except Exception as e:
logger.error("Get user by ID error", user_id=user_id, error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get user information"
)
@router.put("/api/v1/auth/users/{user_id}", response_model=UserResponse)
async def update_user_profile(
user_id: str = Path(..., description="User ID"),
update_data: UserUpdate = Body(..., description="User profile update data"),
current_user = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Update user profile information.
This endpoint allows users to update their profile information including:
- Full name
- Phone number
- Language preference
- Timezone
**Permissions:** Users can update their own profile, admins can update any user's profile
"""
try:
# Validate UUID format
try:
uuid.UUID(user_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid user ID format"
)
# Check permissions - user can update their own profile, admins can update any
if current_user["user_id"] != user_id:
# Check if current user has admin privileges
user_role = current_user.get("role", "user")
if user_role not in ["admin", "super_admin", "manager"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions to update this user's profile"
)
# Fetch user from database
from app.repositories import UserRepository
user_repo = UserRepository(User, db)
user = await user_repo.get_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User {user_id} not found"
)
# Prepare update data (only include fields that are provided)
update_fields = update_data.dict(exclude_unset=True)
if not update_fields:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No update data provided"
)
# Update user
updated_user = await user_repo.update(user_id, update_fields)
logger.info("User profile updated", user_id=user_id, updated_fields=list(update_fields.keys()))
# Log audit event for user profile update
try:
# Get tenant_id from current_user or use a placeholder for system-level operations
tenant_id_str = current_user.get("tenant_id", "00000000-0000-0000-0000-000000000000")
await audit_logger.log_event(
db_session=db,
tenant_id=tenant_id_str,
user_id=current_user["user_id"],
action=AuditAction.UPDATE.value,
resource_type="user",
resource_id=user_id,
severity=AuditSeverity.MEDIUM.value,
description=f"User {current_user.get('email', current_user['user_id'])} updated profile for user {user.email}",
changes={"updated_fields": list(update_fields.keys())},
audit_metadata={"updated_data": update_fields},
endpoint="/users/{user_id}",
method="PUT"
)
except Exception as audit_error:
logger.warning("Failed to log audit event", error=str(audit_error))
return UserResponse(
id=str(updated_user.id),
email=updated_user.email,
full_name=updated_user.full_name,
is_active=updated_user.is_active,
is_verified=updated_user.is_verified,
phone=updated_user.phone,
language=updated_user.language or "es",
timezone=updated_user.timezone or "Europe/Madrid",
created_at=updated_user.created_at,
last_login=updated_user.last_login,
role=updated_user.role,
tenant_id=None,
payment_customer_id=updated_user.payment_customer_id,
default_payment_method_id=updated_user.default_payment_method_id
)
except HTTPException:
raise
except Exception as e:
logger.error("Update user profile error", user_id=user_id, error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update user profile"
)
@router.post("/api/v1/auth/users/create-by-owner", response_model=UserResponse)
async def create_user_by_owner(
user_data: OwnerUserCreate,
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Create a new user account (owner/admin only - for pilot phase).
This endpoint allows tenant owners to directly create user accounts
with passwords during the pilot phase. In production, this will be
replaced with an invitation-based flow.
**Permissions:** Owner or Admin role required
**Security:** Password is hashed server-side before storage
"""
try:
# Verify caller has admin or owner privileges
# In pilot phase, we allow 'admin' role from auth service
user_role = current_user.get("role", "user")
if user_role not in ["admin", "super_admin", "manager"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only administrators can create users directly"
)
# Validate email uniqueness
from app.repositories import UserRepository
user_repo = UserRepository(User, db)
existing_user = await user_repo.get_by_email(user_data.email)
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"User with email {user_data.email} already exists"
)
# Hash password
from app.core.security import SecurityManager
hashed_password = SecurityManager.hash_password(user_data.password)
# Create user
create_data = {
"email": user_data.email,
"full_name": user_data.full_name,
"hashed_password": hashed_password,
"phone": user_data.phone,
"role": user_data.role,
"language": user_data.language or "es",
"timezone": user_data.timezone or "Europe/Madrid",
"is_active": True,
"is_verified": False # Can be verified later if needed
}
new_user = await user_repo.create_user(create_data)
logger.info(
"User created by owner",
created_user_id=str(new_user.id),
created_user_email=new_user.email,
created_by=current_user.get("user_id"),
created_by_email=current_user.get("email")
)
# Return user response
return UserResponse(
id=str(new_user.id),
email=new_user.email,
full_name=new_user.full_name,
is_active=new_user.is_active,
is_verified=new_user.is_verified,
phone=new_user.phone,
language=new_user.language,
timezone=new_user.timezone,
created_at=new_user.created_at,
last_login=new_user.last_login,
role=new_user.role,
tenant_id=None # Will be set when added to tenant
)
except HTTPException:
raise
except Exception as e:
logger.error(
"Failed to create user by owner",
email=user_data.email,
error=str(e),
created_by=current_user.get("user_id")
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create user account"
)
@router.post("/api/v1/auth/users/batch", response_model=Dict[str, Any])
async def get_users_batch(
request: BatchUserRequest,
db: AsyncSession = Depends(get_db)
):
"""
Get multiple users by their IDs in a single request.
This endpoint is for internal service-to-service communication.
It efficiently fetches multiple user records needed by other services
(e.g., tenant service for enriching member lists).
Returns a dict mapping user_id -> user data, with null for non-existent users.
"""
try:
# Validate all UUIDs
validated_ids = []
for user_id in request.user_ids:
try:
uuid.UUID(user_id)
validated_ids.append(user_id)
except ValueError:
logger.warning(f"Invalid user ID format in batch request: {user_id}")
continue
if not validated_ids:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No valid user IDs provided"
)
# Fetch users from database
from app.repositories import UserRepository
user_repo = UserRepository(User, db)
# Build response map
user_map = {}
for user_id in validated_ids:
user = await user_repo.get_by_id(user_id)
if user:
user_map[user_id] = {
"id": str(user.id),
"email": user.email,
"full_name": user.full_name,
"is_active": user.is_active,
"is_verified": user.is_verified,
"phone": user.phone,
"language": user.language or "es",
"timezone": user.timezone or "Europe/Madrid",
"created_at": user.created_at.isoformat() if user.created_at else None,
"last_login": user.last_login.isoformat() if user.last_login else None,
"role": user.role
}
else:
user_map[user_id] = None
logger.debug(
"Batch user fetch completed",
requested_count=len(request.user_ids),
found_count=sum(1 for v in user_map.values() if v is not None)
)
return {
"users": user_map,
"requested_count": len(request.user_ids),
"found_count": sum(1 for v in user_map.values() if v is not None)
}
except HTTPException:
raise
except Exception as e:
logger.error("Batch user fetch error", error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to fetch users"
)
@router.get("/api/v1/auth/users/{user_id}/activity")
async def get_user_activity(
user_id: str = Path(..., description="User ID"),
current_user = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get user activity information.
This endpoint returns detailed activity information for a user including:
- Last login timestamp
- Account creation date
- Active session count
- Last activity timestamp
- User status information
**Permissions:** User can view their own activity, admins can view any user's activity
"""
try:
# Validate UUID format
try:
uuid.UUID(user_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid user ID format"
)
# Check permissions - user can view their own activity, admins can view any
if current_user["user_id"] != user_id:
# Check if current user has admin privileges
user_role = current_user.get("role", "user")
if user_role not in ["admin", "super_admin", "manager"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions to view this user's activity"
)
# Initialize enhanced user service
from app.core.config import settings
from shared.database.base import create_database_manager
database_manager = create_database_manager(settings.DATABASE_URL, "auth-service")
user_service = EnhancedUserService(database_manager)
# Get user activity data
activity_data = await user_service.get_user_activity(user_id)
if "error" in activity_data:
if activity_data["error"] == "User not found":
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get user activity: {activity_data['error']}"
)
logger.debug("Retrieved user activity", user_id=user_id)
return activity_data
except HTTPException:
raise
except Exception as e:
logger.error("Get user activity error", user_id=user_id, error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get user activity information"
)
@router.patch("/api/v1/auth/users/{user_id}/tenant")
async def update_user_tenant(
user_id: str = Path(..., description="User ID"),
tenant_data: Dict[str, Any] = Body(..., description="Tenant data containing tenant_id"),
db: AsyncSession = Depends(get_db)
):
"""
Update user's tenant_id after tenant registration
This endpoint is called by the tenant service after a user creates their tenant.
It links the user to their newly created tenant.
"""
try:
# Log the incoming request data for debugging
logger.debug("Received tenant update request",
user_id=user_id,
tenant_data=tenant_data)
tenant_id = tenant_data.get("tenant_id")
if not tenant_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="tenant_id is required"
)
logger.info("Updating user tenant_id",
user_id=user_id,
tenant_id=tenant_id)
user_service = EnhancedUserService(db)
user = await user_service.get_user_by_id(uuid.UUID(user_id), session=db)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
# DEPRECATED: User-tenant relationships are now managed by tenant service
# This endpoint is kept for backward compatibility but does nothing
# The tenant service should manage user-tenant relationships internally
logger.warning("DEPRECATED: update_user_tenant endpoint called - user-tenant relationships are now managed by tenant service",
user_id=user_id,
tenant_id=tenant_id)
# Return success for backward compatibility, but don't actually update anything
return {
"success": True,
"user_id": str(user.id),
"tenant_id": tenant_id,
"message": "User-tenant relationships are now managed by tenant service. This endpoint is deprecated."
}
except HTTPException:
raise
except Exception as e:
logger.error("Failed to update user tenant_id",
user_id=user_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update user tenant_id"
)
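
As an illustration of the batch endpoint above, a consuming service (for example the tenant service enriching member lists) might call it roughly like this; the body shape assumes `BatchUserRequest` only carries a `user_ids` list, as used in the handler:

```python
# Hypothetical internal call to POST /api/v1/auth/users/batch
import httpx

async def fetch_users_batch(auth_base_url: str, user_ids: list[str]) -> dict:
    async with httpx.AsyncClient(base_url=auth_base_url) as client:
        response = await client.post(
            "/api/v1/auth/users/batch",
            json={"user_ids": user_ids},
        )
        response.raise_for_status()
        payload = response.json()
        # Maps user_id -> user dict, or None when the user does not exist
        return payload["users"]
```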

View File

View File

@@ -0,0 +1,132 @@
"""
Authentication dependency for auth service
services/auth/app/core/auth.py
"""
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from jose import JWTError, jwt
import structlog
from app.core.config import settings
from app.core.database import get_db
from app.models.users import User
logger = structlog.get_logger()
security = HTTPBearer()
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db)
) -> User:
"""
Dependency to get the current authenticated user
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
# Decode JWT token
payload = jwt.decode(
credentials.credentials,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
)
# Get user identifier from token
user_id: str = payload.get("sub")
if user_id is None:
logger.warning("Token payload missing 'sub' field")
raise credentials_exception
logger.info(f"Authenticating user: {user_id}")
except JWTError as e:
logger.warning(f"JWT decode error: {e}")
raise credentials_exception
try:
# Get user from database
result = await db.execute(
select(User).where(User.id == user_id)
)
user = result.scalar_one_or_none()
if user is None:
logger.warning(f"User not found for ID: {user_id}")
raise credentials_exception
if not user.is_active:
logger.warning(f"Inactive user attempted access: {user_id}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Inactive user"
)
logger.info(f"User authenticated: {user.email}")
return user
except Exception as e:
logger.error(f"Error getting user: {e}")
raise credentials_exception
async def get_current_active_user(
current_user: User = Depends(get_current_user)
) -> User:
"""
Dependency to get the current active user
"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Inactive user"
)
return current_user
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash"""
from passlib.context import CryptContext
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""Generate password hash"""
from passlib.context import CryptContext
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
return pwd_context.hash(password)
def create_access_token(data: dict, expires_delta=None):
"""Create JWT access token"""
from datetime import datetime, timedelta
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
return encoded_jwt
def create_refresh_token(data: dict):
"""Create JWT refresh token"""
from datetime import datetime, timedelta
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
return encoded_jwt
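
A minimal sketch of how these helpers compose, assuming it runs alongside them in `app/core/auth.py` with the same `settings`; it is illustrative only, not part of the module:

```python
# Hypothetical round-trip: issue an access token and decode it back
from datetime import timedelta
from jose import jwt as jose_jwt

def issue_and_inspect(user_id: str) -> dict:
    token = create_access_token({"sub": user_id}, expires_delta=timedelta(minutes=15))
    # Decoding with the same secret and algorithm yields the payload read by get_current_user
    return jose_jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
```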

View File

@@ -0,0 +1,70 @@
# ================================================================
# AUTH SERVICE CONFIGURATION
# services/auth/app/core/config.py
# ================================================================
"""
Authentication service configuration
User management and JWT token handling
"""
from shared.config.base import BaseServiceSettings
import os
class AuthSettings(BaseServiceSettings):
"""Auth service specific settings"""
# Service Identity
APP_NAME: str = "Authentication Service"
SERVICE_NAME: str = "auth-service"
DESCRIPTION: str = "User authentication and authorization service"
# Database configuration (secure approach - build from components)
@property
def DATABASE_URL(self) -> str:
"""Build database URL from secure components"""
# Try complete URL first (for backward compatibility)
complete_url = os.getenv("AUTH_DATABASE_URL")
if complete_url:
return complete_url
# Build from components (secure approach)
user = os.getenv("AUTH_DB_USER", "auth_user")
password = os.getenv("AUTH_DB_PASSWORD", "auth_pass123")
host = os.getenv("AUTH_DB_HOST", "localhost")
port = os.getenv("AUTH_DB_PORT", "5432")
name = os.getenv("AUTH_DB_NAME", "auth_db")
return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"
# Redis Database (dedicated for auth)
REDIS_DB: int = 0
# Enhanced Password Requirements for Spain
PASSWORD_MIN_LENGTH: int = 8
PASSWORD_REQUIRE_UPPERCASE: bool = True
PASSWORD_REQUIRE_LOWERCASE: bool = True
PASSWORD_REQUIRE_NUMBERS: bool = True
PASSWORD_REQUIRE_SYMBOLS: bool = False
# Spanish GDPR Compliance
GDPR_COMPLIANCE_ENABLED: bool = True
DATA_RETENTION_DAYS: int = int(os.getenv("AUTH_DATA_RETENTION_DAYS", "365"))
CONSENT_REQUIRED: bool = True
PRIVACY_POLICY_URL: str = os.getenv("PRIVACY_POLICY_URL", "/privacy")
# Account Security
ACCOUNT_LOCKOUT_ENABLED: bool = True
MAX_LOGIN_ATTEMPTS: int = 5
LOCKOUT_DURATION_MINUTES: int = 30
PASSWORD_HISTORY_COUNT: int = 5
# Session Management
SESSION_TIMEOUT_MINUTES: int = int(os.getenv("SESSION_TIMEOUT_MINUTES", "60"))
CONCURRENT_SESSIONS_LIMIT: int = int(os.getenv("CONCURRENT_SESSIONS_LIMIT", "3"))
# Email Verification
EMAIL_VERIFICATION_REQUIRED: bool = os.getenv("EMAIL_VERIFICATION_REQUIRED", "true").lower() == "true"
EMAIL_VERIFICATION_EXPIRE_HOURS: int = int(os.getenv("EMAIL_VERIFICATION_EXPIRE_HOURS", "24"))
settings = AuthSettings()
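
To make the fallback behaviour of the `DATABASE_URL` property concrete, here is an illustrative check with hypothetical values (it assumes `AUTH_DATABASE_URL` is unset, so the URL is built from the individual components):

```python
# Hypothetical smoke test for AuthSettings.DATABASE_URL composition
import os

for key, value in {
    "AUTH_DB_USER": "auth_user",
    "AUTH_DB_PASSWORD": "s3cret",
    "AUTH_DB_HOST": "postgres",
    "AUTH_DB_PORT": "5432",
    "AUTH_DB_NAME": "auth_db",
}.items():
    os.environ[key] = value

print(AuthSettings().DATABASE_URL)
# -> postgresql+asyncpg://auth_user:s3cret@postgres:5432/auth_db
```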

View File

@@ -0,0 +1,290 @@
# ================================================================
# services/auth/app/core/database.py (ENHANCED VERSION)
# ================================================================
"""
Database configuration for authentication service
"""
import structlog
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.pool import NullPool
from app.core.config import settings
from shared.database.base import Base
logger = structlog.get_logger()
# Create async engine
engine = create_async_engine(
settings.DATABASE_URL,
poolclass=NullPool,
echo=settings.DEBUG,
future=True
)
# Create session factory
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False
)
async def get_db() -> AsyncSession:
"""Database dependency"""
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception as e:
await session.rollback()
logger.error(f"Database session error: {e}")
raise
finally:
await session.close()
async def create_tables():
"""Create database tables"""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("Database tables created successfully")
# ================================================================
# services/auth/app/core/database.py - UPDATED TO USE SHARED INFRASTRUCTURE
# ================================================================
"""
Database configuration for authentication service
Uses shared database infrastructure for consistency
"""
import structlog
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from shared.database.base import DatabaseManager, Base
from app.core.config import settings
logger = structlog.get_logger()
# ✅ Initialize database manager using shared infrastructure
database_manager = DatabaseManager(settings.DATABASE_URL)
# ✅ Alias for convenience - matches the existing interface
get_db = database_manager.get_db
# ✅ Use the shared background session method
get_background_db_session = database_manager.get_background_session
async def get_db_health() -> bool:
"""
Health check function for database connectivity
"""
try:
async with database_manager.async_engine.begin() as conn:
await conn.execute(text("SELECT 1"))
logger.debug("Database health check passed")
return True
except Exception as e:
logger.error(f"Database health check failed: {str(e)}")
return False
async def create_tables():
"""Create database tables using shared infrastructure"""
await database_manager.create_tables()
logger.info("Auth database tables created successfully")
# ✅ Auth service specific database utilities
class AuthDatabaseUtils:
"""Auth service specific database utilities"""
@staticmethod
async def cleanup_old_refresh_tokens(days_old: int = 30):
"""Clean up old refresh tokens"""
try:
async with database_manager.get_background_session() as session:
if settings.DATABASE_URL.startswith("sqlite"):
query = text(
"DELETE FROM refresh_tokens "
"WHERE created_at < datetime('now', :days_param)"
)
params = {"days_param": f"-{days_old} days"}
else:
# PostgreSQL: INTERVAL cannot take a bind parameter directly, so cast the string instead
query = text(
"DELETE FROM refresh_tokens "
"WHERE created_at < NOW() - CAST(:days_param AS INTERVAL)"
)
params = {"days_param": f"{days_old} days"}
result = await session.execute(query, params)
# No need to commit - get_background_session() handles it
logger.info("Cleaned up old refresh tokens",
deleted_count=result.rowcount,
days_old=days_old)
return result.rowcount
except Exception as e:
logger.error("Failed to cleanup old refresh tokens",
error=str(e))
return 0
@staticmethod
async def get_auth_statistics(tenant_id: str = None) -> dict:
"""Get authentication statistics"""
try:
async with database_manager.get_background_session() as session:
# Base query for users
users_query = text("SELECT COUNT(*) as count FROM users WHERE is_active = :is_active")
params = {}
if tenant_id:
# If tenant filtering is needed (though auth service might not have tenant_id in users table)
# This is just an example - adjust based on your actual schema
pass
# Get active users count
active_users_result = await session.execute(
users_query,
{**params, "is_active": True}
)
active_users = active_users_result.scalar() or 0
# Get inactive users count
inactive_users_result = await session.execute(
users_query,
{**params, "is_active": False}
)
inactive_users = inactive_users_result.scalar() or 0
# Get refresh tokens count
tokens_query = text("SELECT COUNT(*) as count FROM refresh_tokens")
tokens_result = await session.execute(tokens_query)
active_tokens = tokens_result.scalar() or 0
return {
"active_users": active_users,
"inactive_users": inactive_users,
"total_users": active_users + inactive_users,
"active_tokens": active_tokens
}
except Exception as e:
logger.error(f"Failed to get auth statistics: {str(e)}")
return {
"active_users": 0,
"inactive_users": 0,
"total_users": 0,
"active_tokens": 0
}
@staticmethod
async def check_user_exists(user_id: str) -> bool:
"""Check if user exists"""
try:
async with database_manager.get_background_session() as session:
query = text(
"SELECT COUNT(*) as count "
"FROM users "
"WHERE id = :user_id "
"LIMIT 1"
)
result = await session.execute(query, {"user_id": user_id})
count = result.scalar() or 0
return count > 0
except Exception as e:
logger.error("Failed to check user existence",
user_id=user_id, error=str(e))
return False
@staticmethod
async def get_user_token_count(user_id: str) -> int:
"""Get count of active refresh tokens for a user"""
try:
async with database_manager.get_background_session() as session:
query = text(
"SELECT COUNT(*) as count "
"FROM refresh_tokens "
"WHERE user_id = :user_id"
)
result = await session.execute(query, {"user_id": user_id})
count = result.scalar() or 0
return count
except Exception as e:
logger.error("Failed to get user token count",
user_id=user_id, error=str(e))
return 0
# Enhanced database session dependency with better error handling
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""
Enhanced database session dependency with better logging and error handling
"""
async with database_manager.async_session_local() as session:
try:
logger.debug("Database session created")
yield session
except Exception as e:
logger.error(f"Database session error: {str(e)}", exc_info=True)
await session.rollback()
raise
finally:
await session.close()
logger.debug("Database session closed")
# Database initialization for auth service
async def initialize_auth_database():
"""Initialize database tables for auth service"""
try:
logger.info("Initializing auth service database")
# Import models to ensure they're registered
from app.models.users import User
from app.models.refresh_tokens import RefreshToken
# Create tables using shared infrastructure
await database_manager.create_tables()
logger.info("Auth service database initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize auth service database: {str(e)}")
raise
# Database cleanup for auth service
async def cleanup_auth_database():
"""Cleanup database connections for auth service"""
try:
logger.info("Cleaning up auth service database connections")
# Close engine connections
if hasattr(database_manager, 'async_engine') and database_manager.async_engine:
await database_manager.async_engine.dispose()
logger.info("Auth service database cleanup completed")
except Exception as e:
logger.error(f"Failed to cleanup auth service database: {str(e)}")
# Export the commonly used items to maintain compatibility
__all__ = [
'Base',
'database_manager',
'get_db',
'get_background_db_session',
'get_db_session',
'get_db_health',
'AuthDatabaseUtils',
'initialize_auth_database',
'cleanup_auth_database',
'create_tables'
]
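
As a usage sketch (the scheduling and caller are assumptions, not part of this module, and it relies on `AuthDatabaseUtils` and `logger` defined above), the background-session utilities could back a periodic maintenance job:

```python
# Hypothetical periodic maintenance using AuthDatabaseUtils
import asyncio

async def nightly_auth_maintenance() -> None:
    deleted = await AuthDatabaseUtils.cleanup_old_refresh_tokens(days_old=30)
    stats = await AuthDatabaseUtils.get_auth_statistics()
    logger.info("auth_maintenance_completed", deleted_tokens=deleted, **stats)

# Example: asyncio.run(nightly_auth_maintenance())
```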

View File

@@ -0,0 +1,453 @@
# services/auth/app/core/security.py - FIXED VERSION
"""
Security utilities for authentication service
FIXED VERSION - Consistent password hashing using passlib
"""
import re
import hashlib
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, List
from shared.redis_utils import get_redis_client
from fastapi import HTTPException, status
import structlog
from passlib.context import CryptContext
from app.core.config import settings
from shared.auth.jwt_handler import JWTHandler
logger = structlog.get_logger()
# ✅ FIX: Use passlib for consistent password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Initialize JWT handler with SAME configuration as gateway
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
# Note: Redis client is now accessed via get_redis_client() from shared.redis_utils
class SecurityManager:
"""Security utilities for authentication - FIXED VERSION"""
@staticmethod
def validate_password(password: str) -> bool:
"""Validate password strength"""
if len(password) < settings.PASSWORD_MIN_LENGTH:
return False
if len(password) > 128: # Max length from Pydantic schema
return False
if settings.PASSWORD_REQUIRE_UPPERCASE and not re.search(r'[A-Z]', password):
return False
if settings.PASSWORD_REQUIRE_LOWERCASE and not re.search(r'[a-z]', password):
return False
if settings.PASSWORD_REQUIRE_NUMBERS and not re.search(r'\d', password):
return False
if settings.PASSWORD_REQUIRE_SYMBOLS and not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
return False
return True
@staticmethod
def get_password_validation_errors(password: str) -> List[str]:
"""Get detailed password validation errors for better UX"""
errors = []
if len(password) < settings.PASSWORD_MIN_LENGTH:
errors.append(f"Password must be at least {settings.PASSWORD_MIN_LENGTH} characters long")
if len(password) > 128:
errors.append("Password cannot exceed 128 characters")
if settings.PASSWORD_REQUIRE_UPPERCASE and not re.search(r'[A-Z]', password):
errors.append("Password must contain at least one uppercase letter")
if settings.PASSWORD_REQUIRE_LOWERCASE and not re.search(r'[a-z]', password):
errors.append("Password must contain at least one lowercase letter")
if settings.PASSWORD_REQUIRE_NUMBERS and not re.search(r'\d', password):
errors.append("Password must contain at least one number")
if settings.PASSWORD_REQUIRE_SYMBOLS and not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
errors.append("Password must contain at least one symbol (!@#$%^&*(),.?\":{}|<>)")
return errors
@staticmethod
def hash_password(password: str) -> str:
"""Hash password using passlib bcrypt - FIXED"""
return pwd_context.hash(password)
@staticmethod
def verify_password(password: str, hashed_password: str) -> bool:
"""Verify password against hash using passlib - FIXED"""
try:
return pwd_context.verify(password, hashed_password)
except Exception as e:
logger.error(f"Password verification error: {e}")
return False
@staticmethod
def create_access_token(user_data: Dict[str, Any]) -> str:
"""
Create JWT ACCESS token with proper payload structure
✅ FIXED: Only creates access tokens
"""
# Validate required fields for access token
if "user_id" not in user_data:
raise ValueError("user_id required for access token creation")
if "email" not in user_data:
raise ValueError("email required for access token creation")
try:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
# ✅ FIX 1: ACCESS TOKEN payload structure
payload = {
"sub": user_data["user_id"],
"user_id": user_data["user_id"],
"email": user_data["email"],
"type": "access", # ✅ EXPLICITLY set as access token
"exp": expire,
"iat": datetime.now(timezone.utc),
"iss": "bakery-auth"
}
# Add optional fields for access tokens
if "full_name" in user_data:
payload["full_name"] = user_data["full_name"]
if "is_verified" in user_data:
payload["is_verified"] = user_data["is_verified"]
if "is_active" in user_data:
payload["is_active"] = user_data["is_active"]
# ✅ CRITICAL FIX: Include role in access token!
if "role" in user_data:
payload["role"] = user_data["role"]
else:
payload["role"] = "admin" # Default role if not specified
# NEW: Add subscription data to JWT payload
if "tenant_id" in user_data:
payload["tenant_id"] = user_data["tenant_id"]
if "tenant_role" in user_data:
payload["tenant_role"] = user_data["tenant_role"]
if "subscription" in user_data:
payload["subscription"] = user_data["subscription"]
if "tenant_access" in user_data:
# Limit tenant_access to 10 entries to prevent JWT size explosion
tenant_access = user_data["tenant_access"]
if tenant_access and len(tenant_access) > 10:
tenant_access = tenant_access[:10]
logger.warning(f"Truncated tenant_access to 10 entries for user {user_data['user_id']}")
payload["tenant_access"] = tenant_access
logger.debug(f"Creating access token with payload keys: {list(payload.keys())}")
# ✅ FIX 2: Use JWT handler to create access token
token = jwt_handler.create_access_token_from_payload(payload)
logger.debug(f"Access token created successfully for user {user_data['email']}")
return token
except Exception as e:
logger.error(f"Access token creation failed for {user_data.get('email', 'unknown')}: {e}")
raise ValueError(f"Failed to create access token: {str(e)}")
@staticmethod
def create_refresh_token(user_data: Dict[str, Any]) -> str:
"""
Create JWT REFRESH token with minimal payload structure
✅ FIXED: Only creates refresh tokens, different from access tokens
"""
# Validate required fields for refresh token
if "user_id" not in user_data:
raise ValueError("user_id required for refresh token creation")
if not user_data.get("user_id"):
raise ValueError("user_id cannot be empty")
try:
expire = datetime.now(timezone.utc) + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
# ✅ FIX 3: REFRESH TOKEN payload structure (minimal, different from access)
payload = {
"sub": user_data["user_id"],
"user_id": user_data["user_id"],
"type": "refresh", # ✅ EXPLICITLY set as refresh token
"exp": expire,
"iat": datetime.now(timezone.utc),
"iss": "bakery-auth"
}
# Add unique JTI for refresh tokens to prevent duplicates
if "jti" in user_data:
payload["jti"] = user_data["jti"]
else:
import uuid
payload["jti"] = str(uuid.uuid4())
# Include email only if available (optional for refresh tokens)
if "email" in user_data and user_data["email"]:
payload["email"] = user_data["email"]
logger.debug(f"Creating refresh token with payload keys: {list(payload.keys())}")
# ✅ FIX 4: Use JWT handler to create REFRESH token (not access token!)
token = jwt_handler.create_refresh_token_from_payload(payload)
logger.debug(f"Refresh token created successfully for user {user_data['user_id']}")
return token
except Exception as e:
logger.error(f"Refresh token creation failed for {user_data.get('user_id', 'unknown')}: {e}")
raise ValueError(f"Failed to create refresh token: {str(e)}")
@staticmethod
def verify_token(token: str) -> Optional[Dict[str, Any]]:
"""Verify JWT token with enhanced error handling"""
try:
payload = jwt_handler.verify_token(token)
if payload:
logger.debug(f"Token verified successfully for user: {payload.get('email', 'unknown')}")
return payload
except Exception as e:
logger.warning(f"Token verification failed: {e}")
return None
@staticmethod
def decode_token(token: str) -> Dict[str, Any]:
"""Decode JWT token without verification (for refresh token handling)"""
try:
payload = jwt_handler.decode_token_no_verify(token)
return payload
except Exception as e:
logger.error(f"Token decoding failed: {e}")
raise ValueError("Invalid token format")
@staticmethod
def generate_secure_hash(data: str) -> str:
"""Generate secure hash for token storage"""
return hashlib.sha256(data.encode()).hexdigest()
@staticmethod
def create_service_token(service_name: str, tenant_id: Optional[str] = None) -> str:
"""
Create JWT service token for inter-service communication
✅ UNIFIED: Uses shared JWT handler for consistent token creation
✅ ENHANCED: Supports tenant context for tenant-scoped operations
Args:
service_name: Name of the service (e.g., 'auth-service', 'tenant-service')
tenant_id: Optional tenant ID for tenant-scoped service operations
Returns:
Encoded JWT service token
"""
try:
# Use unified JWT handler to create service token
token = jwt_handler.create_service_token(
service_name=service_name,
tenant_id=tenant_id
)
logger.debug(f"Created service token for {service_name}", tenant_id=tenant_id)
return token
except Exception as e:
logger.error(f"Failed to create service token for {service_name}: {e}")
raise ValueError(f"Failed to create service token: {str(e)}")
@staticmethod
def is_token_expired(token: str) -> bool:
"""Check if token is expired"""
try:
payload = SecurityManager.decode_token(token)
exp_timestamp = payload.get("exp")
if exp_timestamp:
exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
return datetime.now(timezone.utc) > exp_datetime
return True
except Exception:
return True
@staticmethod
async def track_login_attempt(email: str, ip_address: str, success: bool) -> None:
"""Track login attempts for security monitoring"""
try:
redis_client = await get_redis_client()
key = f"login_attempts:{email}:{ip_address}"
if success:
# Clear failed attempts on successful login
await redis_client.delete(key)
else:
# Increment failed attempts
attempts = await redis_client.incr(key)
if attempts == 1:
# Set expiration on first failed attempt
await redis_client.expire(key, settings.LOCKOUT_DURATION_MINUTES * 60)
if attempts >= settings.MAX_LOGIN_ATTEMPTS:
logger.warning(f"Account locked for {email} from {ip_address}")
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"Too many failed login attempts. Try again in {settings.LOCKOUT_DURATION_MINUTES} minutes."
)
except HTTPException:
raise # Re-raise HTTPException
except Exception as e:
logger.error(f"Failed to track login attempt: {e}")
@staticmethod
async def is_account_locked(email: str, ip_address: str) -> bool:
"""Check if account is locked due to failed login attempts"""
try:
redis_client = await get_redis_client()
key = f"login_attempts:{email}:{ip_address}"
attempts = await redis_client.get(key)
if attempts:
attempts = int(attempts)
return attempts >= settings.MAX_LOGIN_ATTEMPTS
except Exception as e:
logger.error(f"Failed to check account lock status: {e}")
return False
@staticmethod
def hash_api_key(api_key: str) -> str:
"""Hash API key for storage"""
return hashlib.sha256(api_key.encode()).hexdigest()
@staticmethod
def generate_secure_token(length: int = 32) -> str:
"""Generate secure random token"""
import secrets
return secrets.token_urlsafe(length)
@staticmethod
def generate_reset_token() -> str:
"""Generate a secure password reset token"""
import secrets
return secrets.token_urlsafe(32)
@staticmethod
def mask_sensitive_data(data: str, visible_chars: int = 4) -> str:
"""Mask sensitive data for logging"""
if not data or len(data) <= visible_chars:
return "*" * len(data) if data else ""
return data[:visible_chars] + "*" * (len(data) - visible_chars)
@staticmethod
async def check_login_attempts(email: str) -> bool:
"""Check if user has exceeded login attempts"""
try:
redis_client = await get_redis_client()
key = f"login_attempts:{email}"
attempts = await redis_client.get(key)
if attempts is None:
return True
return int(attempts) < settings.MAX_LOGIN_ATTEMPTS
except Exception as e:
logger.error(f"Error checking login attempts: {e}")
return True # Allow on error
@staticmethod
async def increment_login_attempts(email: str) -> None:
"""Increment login attempts for email"""
try:
redis_client = await get_redis_client()
key = f"login_attempts:{email}"
await redis_client.incr(key)
await redis_client.expire(key, settings.LOCKOUT_DURATION_MINUTES * 60)
except Exception as e:
logger.error(f"Error incrementing login attempts: {e}")
@staticmethod
async def clear_login_attempts(email: str) -> None:
"""Clear login attempts for email after successful login"""
try:
redis_client = await get_redis_client()
key = f"login_attempts:{email}"
await redis_client.delete(key)
logger.debug(f"Cleared login attempts for {email}")
except Exception as e:
logger.error(f"Error clearing login attempts: {e}")
@staticmethod
async def store_refresh_token(user_id: str, token: str) -> None:
"""Store refresh token in Redis"""
try:
redis_client = await get_redis_client()
token_hash = SecurityManager.hash_api_key(token) # Reuse hash method
key = f"refresh_token:{user_id}:{token_hash}"
# Store with expiration matching JWT refresh token expiry
expire_seconds = settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
await redis_client.setex(key, expire_seconds, "valid")
except Exception as e:
logger.error(f"Error storing refresh token: {e}")
@staticmethod
async def is_refresh_token_valid(user_id: str, token: str) -> bool:
"""Check if refresh token is still valid in Redis"""
try:
redis_client = await get_redis_client()
token_hash = SecurityManager.hash_api_key(token)
key = f"refresh_token:{user_id}:{token_hash}"
exists = await redis_client.exists(key)
return bool(exists)
except Exception as e:
logger.error(f"Error checking refresh token validity: {e}")
return False
@staticmethod
async def revoke_refresh_token(user_id: str, token: str) -> None:
"""Revoke refresh token by removing from Redis"""
try:
redis_client = await get_redis_client()
token_hash = SecurityManager.hash_api_key(token)
key = f"refresh_token:{user_id}:{token_hash}"
await redis_client.delete(key)
logger.debug(f"Revoked refresh token for user {user_id}")
except Exception as e:
logger.error(f"Error revoking refresh token: {e}")

services/auth/app/main.py
View File

@@ -0,0 +1,225 @@
"""
Authentication Service Main Application
"""
from fastapi import FastAPI
from sqlalchemy import text
from app.core.config import settings
from app.core.database import database_manager
from app.api import auth_operations, users, onboarding_progress, consent, data_export, account_deletion, internal_demo, password_reset
from shared.service_base import StandardFastAPIService
from shared.messaging import UnifiedEventPublisher
class AuthService(StandardFastAPIService):
"""Authentication Service with standardized setup"""
async def on_startup(self, app):
"""Custom startup logic including migration verification and Redis initialization"""
self.logger.info("Starting auth service on_startup")
await self.verify_migrations()
# Initialize Redis if not already done during service creation
if not self.redis_initialized:
try:
from shared.redis_utils import initialize_redis, get_redis_client
await initialize_redis(settings.REDIS_URL_WITH_DB, db=settings.REDIS_DB, max_connections=getattr(settings, 'REDIS_MAX_CONNECTIONS', 50))
self.redis_client = await get_redis_client()
self.redis_initialized = True
self.logger.info("Connected to Redis for token management")
except Exception as e:
self.logger.error(f"Failed to connect to Redis during startup: {e}")
raise
await super().on_startup(app)
async def on_shutdown(self, app: FastAPI):
"""Custom shutdown logic for Auth Service"""
await super().on_shutdown(app)
# Close Redis
from shared.redis_utils import close_redis
await close_redis()
self.logger.info("Redis connection closed")
self.logger.info("Authentication Service shutdown complete")
async def verify_migrations(self):
"""Verify database schema matches the latest migrations."""
try:
async with self.database_manager.get_session() as session:
# Check if alembic_version table exists
result = await session.execute(text("""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = 'alembic_version'
)
"""))
table_exists = result.scalar()
if table_exists:
# If table exists, check the version
result = await session.execute(text("SELECT version_num FROM alembic_version"))
version = result.scalar()
self.logger.info(f"Migration verification successful: {version}")
else:
# If table doesn't exist, migrations might not have run yet
# This is OK - the migration job should create it
self.logger.warning("alembic_version table does not exist yet - migrations may not have run")
except Exception as e:
self.logger.warning(f"Migration verification failed (this may be expected during initial setup): {e}")
def __init__(self):
# Initialize Redis during service creation so it's available when needed
try:
import asyncio
# We need to run the async initialization in a sync context
try:
# Check if there's already a running event loop
loop = asyncio.get_running_loop()
# If there is, we'll initialize Redis later in on_startup
self.redis_initialized = False
self.redis_client = None
except RuntimeError:
# No event loop running, safe to run the async function
import nest_asyncio
nest_asyncio.apply() # Allow nested event loops
async def init_redis():
from shared.redis_utils import initialize_redis, get_redis_client
await initialize_redis(settings.REDIS_URL_WITH_DB, db=settings.REDIS_DB, max_connections=getattr(settings, 'REDIS_MAX_CONNECTIONS', 50))
return await get_redis_client()
self.redis_client = asyncio.run(init_redis())
self.redis_initialized = True
self.logger.info("Connected to Redis for token management")
except Exception as e:
self.logger.error(f"Failed to initialize Redis during service creation: {e}")
self.redis_initialized = False
self.redis_client = None
# Define expected database tables for health checks
auth_expected_tables = [
'users', 'refresh_tokens', 'user_onboarding_progress',
'user_onboarding_summary', 'login_attempts', 'user_consents',
'consent_history', 'audit_logs'
]
# Define custom metrics for auth service
auth_custom_metrics = {
"registration_total": {
"type": "counter",
"description": "Total user registrations by status",
"labels": ["status"]
},
"login_success_total": {
"type": "counter",
"description": "Total successful user logins"
},
"login_failure_total": {
"type": "counter",
"description": "Total failed user logins by reason",
"labels": ["reason"]
},
"token_refresh_total": {
"type": "counter",
"description": "Total token refreshes by status",
"labels": ["status"]
},
"token_verify_total": {
"type": "counter",
"description": "Total token verifications by status",
"labels": ["status"]
},
"logout_total": {
"type": "counter",
"description": "Total user logouts by status",
"labels": ["status"]
},
"registration_duration_seconds": {
"type": "histogram",
"description": "Registration request duration"
},
"login_duration_seconds": {
"type": "histogram",
"description": "Login request duration"
},
"token_refresh_duration_seconds": {
"type": "histogram",
"description": "Token refresh duration"
}
}
super().__init__(
service_name="auth-service",
app_name="Authentication Service",
description="Handles user authentication and authorization for bakery forecasting platform",
version="1.0.0",
log_level=settings.LOG_LEVEL,
api_prefix="", # Empty because RouteBuilder already includes /api/v1
database_manager=database_manager,
expected_tables=auth_expected_tables,
enable_messaging=True,
custom_metrics=auth_custom_metrics
)
async def _setup_messaging(self):
"""Setup messaging for auth service"""
from shared.messaging import RabbitMQClient
try:
self.rabbitmq_client = RabbitMQClient(settings.RABBITMQ_URL, service_name="auth-service")
await self.rabbitmq_client.connect()
# Create event publisher
self.event_publisher = UnifiedEventPublisher(self.rabbitmq_client, "auth-service")
self.logger.info("Auth service messaging setup completed")
except Exception as e:
self.logger.error("Failed to setup auth messaging", error=str(e))
raise
async def _cleanup_messaging(self):
"""Cleanup messaging for auth service"""
try:
if self.rabbitmq_client:
await self.rabbitmq_client.disconnect()
self.logger.info("Auth service messaging cleanup completed")
except Exception as e:
self.logger.error("Error during auth messaging cleanup", error=str(e))
def get_service_features(self):
"""Return auth-specific features"""
return [
"user_authentication",
"token_management",
"user_onboarding",
"role_based_access",
"messaging_integration"
]
# Create service instance
service = AuthService()
# Create FastAPI app with standardized setup
app = service.create_app(
docs_url="/docs",
redoc_url="/redoc"
)
# Setup standard endpoints
service.setup_standard_endpoints()
# Include routers with specific configurations
# Note: Routes now use RouteBuilder which includes full paths, so no prefix needed
service.add_router(auth_operations.router, tags=["authentication"])
service.add_router(users.router, tags=["users"])
service.add_router(onboarding_progress.router, tags=["onboarding"])
service.add_router(consent.router, tags=["gdpr", "consent"])
service.add_router(data_export.router, tags=["gdpr", "data-export"])
service.add_router(account_deletion.router, tags=["gdpr", "account-deletion"])
service.add_router(internal_demo.router, tags=["internal-demo"])
service.add_router(password_reset.router, tags=["password-reset"])
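For local experimentation, the module-level `app` created above can be served directly with Uvicorn. The snippet below is a sketch rather than part of the repository; it assumes the working directory is `services/auth` and that the database, Redis, and RabbitMQ endpoints referenced by `app.core.config.settings` are reachable.

```python
# run_local.py - hypothetical helper script, not part of the repository
import uvicorn

if __name__ == "__main__":
    # Serves the FastAPI instance defined in services/auth/app/main.py
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
```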

View File

@@ -0,0 +1,31 @@
# services/auth/app/models/__init__.py
"""
Models export for auth service
"""
# Import AuditLog model for this service
from shared.security import create_audit_log_model
from shared.database.base import Base
# Create audit log model for this service
AuditLog = create_audit_log_model(Base)
from .users import User
from .tokens import RefreshToken, LoginAttempt
from .onboarding import UserOnboardingProgress, UserOnboardingSummary
from .consent import UserConsent, ConsentHistory
from .deletion_job import DeletionJob
from .password_reset_tokens import PasswordResetToken
__all__ = [
'User',
'RefreshToken',
'LoginAttempt',
'UserOnboardingProgress',
'UserOnboardingSummary',
'UserConsent',
'ConsentHistory',
'DeletionJob',
'PasswordResetToken',
"AuditLog",
]

View File

@@ -0,0 +1,110 @@
"""
User consent tracking models for GDPR compliance
"""
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Index
from sqlalchemy.dialects.postgresql import UUID, JSON
from datetime import datetime, timezone
import uuid
from shared.database.base import Base
class UserConsent(Base):
"""
Tracks user consent for various data processing activities
GDPR Article 7 - Conditions for consent
"""
__tablename__ = "user_consents"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
# Consent types
terms_accepted = Column(Boolean, nullable=False, default=False)
privacy_accepted = Column(Boolean, nullable=False, default=False)
marketing_consent = Column(Boolean, nullable=False, default=False)
analytics_consent = Column(Boolean, nullable=False, default=False)
# Consent metadata
consent_version = Column(String(20), nullable=False, default="1.0")
consent_method = Column(String(50), nullable=False) # registration, settings_update, cookie_banner
ip_address = Column(String(45), nullable=True)
user_agent = Column(Text, nullable=True)
# Consent text at time of acceptance
terms_text_hash = Column(String(64), nullable=True)
privacy_text_hash = Column(String(64), nullable=True)
# Timestamps
consented_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
withdrawn_at = Column(DateTime(timezone=True), nullable=True)
# Additional metadata (renamed from 'metadata' to avoid SQLAlchemy reserved word)
extra_data = Column(JSON, nullable=True)
__table_args__ = (
Index('idx_user_consent_user_id', 'user_id'),
Index('idx_user_consent_consented_at', 'consented_at'),
)
def __repr__(self):
return f"<UserConsent(user_id={self.user_id}, version={self.consent_version})>"
def to_dict(self):
return {
"id": str(self.id),
"user_id": str(self.user_id),
"terms_accepted": self.terms_accepted,
"privacy_accepted": self.privacy_accepted,
"marketing_consent": self.marketing_consent,
"analytics_consent": self.analytics_consent,
"consent_version": self.consent_version,
"consent_method": self.consent_method,
"consented_at": self.consented_at.isoformat() if self.consented_at else None,
"withdrawn_at": self.withdrawn_at.isoformat() if self.withdrawn_at else None,
}
class ConsentHistory(Base):
"""
Historical record of all consent changes
Provides audit trail for GDPR compliance
"""
__tablename__ = "consent_history"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
consent_id = Column(UUID(as_uuid=True), ForeignKey("user_consents.id", ondelete="SET NULL"), nullable=True)
# Action type
action = Column(String(50), nullable=False) # granted, updated, withdrawn, revoked
# Consent state at time of action
consent_snapshot = Column(JSON, nullable=False)
# Context
ip_address = Column(String(45), nullable=True)
user_agent = Column(Text, nullable=True)
consent_method = Column(String(50), nullable=True)
# Timestamp
created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc), index=True)
__table_args__ = (
Index('idx_consent_history_user_id', 'user_id'),
Index('idx_consent_history_created_at', 'created_at'),
Index('idx_consent_history_action', 'action'),
)
def __repr__(self):
return f"<ConsentHistory(user_id={self.user_id}, action={self.action})>"
def to_dict(self):
return {
"id": str(self.id),
"user_id": str(self.user_id),
"action": self.action,
"consent_snapshot": self.consent_snapshot,
"created_at": self.created_at.isoformat() if self.created_at else None,
}
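Because `ConsentHistory` is intended as an append-only audit trail, every change to a `UserConsent` row should be mirrored by a history entry. A minimal sketch, assuming an `AsyncSession` obtained from the service's session factory; the helper name and the call sequence are illustrative, not taken from the repository.

```python
# Sketch only: the session argument stands in for the real session provider.
import uuid

from sqlalchemy.ext.asyncio import AsyncSession

from app.models.consent import UserConsent, ConsentHistory


async def record_registration_consent(session: AsyncSession, user_id: uuid.UUID, ip: str) -> UserConsent:
    consent = UserConsent(
        user_id=user_id,
        terms_accepted=True,
        privacy_accepted=True,
        marketing_consent=False,
        analytics_consent=False,
        consent_method="registration",
        ip_address=ip,
    )
    session.add(consent)
    await session.flush()  # populate consent.id before referencing it

    history = ConsentHistory(
        user_id=user_id,
        consent_id=consent.id,
        action="granted",
        consent_snapshot=consent.to_dict(),
        ip_address=ip,
        consent_method="registration",
    )
    session.add(history)
    await session.commit()
    return consent
```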

View File

@@ -0,0 +1,64 @@
"""
Deletion Job Model
Tracks tenant deletion jobs for persistence and recovery
"""
from sqlalchemy import Column, String, DateTime, Text, JSON, Index, Integer
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.sql import func
import uuid
from shared.database.base import Base
class DeletionJob(Base):
"""
Persistent storage for tenant deletion jobs
Enables job recovery and tracking across service restarts
"""
__tablename__ = "deletion_jobs"
# Primary identifiers
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
job_id = Column(String(100), nullable=False, unique=True, index=True) # External job ID
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# Job Metadata
tenant_name = Column(String(255), nullable=True)
initiated_by = Column(UUID(as_uuid=True), nullable=True) # User ID who started deletion
# Job Status
status = Column(String(50), nullable=False, default="pending", index=True) # pending, in_progress, completed, failed, rolled_back
# Service Results
service_results = Column(JSON, nullable=True) # Dict of service_name -> result details
# Progress Tracking
total_items_deleted = Column(Integer, default=0, nullable=False)
services_completed = Column(Integer, default=0, nullable=False)
services_failed = Column(Integer, default=0, nullable=False)
# Error Tracking
error_log = Column(JSON, nullable=True) # Array of error messages
# Timestamps
started_at = Column(DateTime(timezone=True), nullable=True, index=True)
completed_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
# Additional Context
notes = Column(Text, nullable=True)
extra_metadata = Column(JSON, nullable=True) # Additional job-specific data
# Indexes for performance
__table_args__ = (
Index('idx_deletion_job_id', 'job_id'),
Index('idx_deletion_tenant_id', 'tenant_id'),
Index('idx_deletion_status', 'status'),
Index('idx_deletion_started_at', 'started_at'),
Index('idx_deletion_tenant_status', 'tenant_id', 'status'),
)
def __repr__(self):
return f"<DeletionJob(job_id='{self.job_id}', tenant_id={self.tenant_id}, status='{self.status}')>"

View File

@@ -0,0 +1,91 @@
# services/auth/app/models/onboarding.py
"""
User onboarding progress models
"""
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, JSON, UniqueConstraint
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime, timezone
import uuid
from shared.database.base import Base
class UserOnboardingProgress(Base):
"""User onboarding progress tracking model"""
__tablename__ = "user_onboarding_progress"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
# Step tracking
step_name = Column(String(50), nullable=False)
completed = Column(Boolean, default=False, nullable=False)
completed_at = Column(DateTime(timezone=True))
# Additional step data (JSON field for flexibility)
step_data = Column(JSON, default=dict)
# Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
# Unique constraint to prevent duplicate step entries per user
__table_args__ = (
UniqueConstraint('user_id', 'step_name', name='uq_user_step'),
{'extend_existing': True}
)
def __repr__(self):
return f"<UserOnboardingProgress(id={self.id}, user_id={self.user_id}, step={self.step_name}, completed={self.completed})>"
def to_dict(self):
"""Convert to dictionary"""
return {
"id": str(self.id),
"user_id": str(self.user_id),
"step_name": self.step_name,
"completed": self.completed,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"step_data": self.step_data or {},
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None
}
class UserOnboardingSummary(Base):
"""User onboarding summary for quick lookups"""
__tablename__ = "user_onboarding_summary"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, unique=True, index=True)
# Summary fields
current_step = Column(String(50), nullable=False, default="user_registered")
next_step = Column(String(50))
completion_percentage = Column(String(50), default="0.0") # Store as string for precision
fully_completed = Column(Boolean, default=False)
# Progress tracking
steps_completed_count = Column(String(50), default="0") # Store as string: "3/5"
# Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
last_activity_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
def __repr__(self):
return f"<UserOnboardingSummary(user_id={self.user_id}, current_step={self.current_step}, completion={self.completion_percentage}%)>"
def to_dict(self):
"""Convert to dictionary"""
return {
"id": str(self.id),
"user_id": str(self.user_id),
"current_step": self.current_step,
"next_step": self.next_step,
"completion_percentage": float(self.completion_percentage) if self.completion_percentage else 0.0,
"fully_completed": self.fully_completed,
"steps_completed_count": self.steps_completed_count,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_activity_at": self.last_activity_at.isoformat() if self.last_activity_at else None
}

View File

@@ -0,0 +1,39 @@
# services/auth/app/models/password_reset_tokens.py
"""
Password reset token model for authentication service
"""
import uuid
from datetime import datetime, timezone
from sqlalchemy import Column, String, DateTime, Boolean, Index
from sqlalchemy.dialects.postgresql import UUID
from shared.database.base import Base
class PasswordResetToken(Base):
"""
Password reset token model
Stores temporary tokens for password reset functionality
"""
__tablename__ = "password_reset_tokens"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
token = Column(String(255), nullable=False, unique=True, index=True)
expires_at = Column(DateTime(timezone=True), nullable=False)
is_used = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
used_at = Column(DateTime(timezone=True), nullable=True)
# Add indexes for better performance
__table_args__ = (
Index('ix_password_reset_tokens_user_id', 'user_id'),
Index('ix_password_reset_tokens_token', 'token'),
Index('ix_password_reset_tokens_expires_at', 'expires_at'),
Index('ix_password_reset_tokens_is_used', 'is_used'),
)
def __repr__(self):
return f"<PasswordResetToken(id={self.id}, user_id={self.user_id}, token={self.token[:10]}..., is_used={self.is_used})>"

View File

@@ -0,0 +1,92 @@
# ================================================================
# services/auth/app/models/tokens.py
# ================================================================
"""
Token models for authentication service
"""
import hashlib
import uuid
from datetime import datetime, timezone
from sqlalchemy import Column, String, Boolean, DateTime, Text, Index
from sqlalchemy.dialects.postgresql import UUID
from shared.database.base import Base
class RefreshToken(Base):
"""
Refresh token model - FIXED to prevent duplicate constraint violations
"""
__tablename__ = "refresh_tokens"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# ✅ FIX 1: Use TEXT instead of VARCHAR to handle longer tokens
token = Column(Text, nullable=False)
# ✅ FIX 2: Add token hash for uniqueness instead of full token
token_hash = Column(String(255), nullable=True, unique=True)
expires_at = Column(DateTime(timezone=True), nullable=False)
is_revoked = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
revoked_at = Column(DateTime(timezone=True), nullable=True)
# ✅ FIX 3: Add indexes for better performance
__table_args__ = (
Index('ix_refresh_tokens_user_id_active', 'user_id', 'is_revoked'),
Index('ix_refresh_tokens_expires_at', 'expires_at'),
Index('ix_refresh_tokens_token_hash', 'token_hash'),
)
def __init__(self, **kwargs):
"""Initialize refresh token with automatic hash generation"""
super().__init__(**kwargs)
if self.token and not self.token_hash:
self.token_hash = self._generate_token_hash(self.token)
@staticmethod
def _generate_token_hash(token: str) -> str:
"""Generate a hash of the token for uniqueness checking"""
return hashlib.sha256(token.encode()).hexdigest()
def update_token(self, new_token: str):
"""Update token and regenerate hash"""
self.token = new_token
self.token_hash = self._generate_token_hash(new_token)
@classmethod
async def create_refresh_token(cls, user_id: uuid.UUID, token: str, expires_at: datetime):
"""
Create a new refresh token with proper hash generation
"""
return cls(
id=uuid.uuid4(),
user_id=user_id,
token=token,
token_hash=cls._generate_token_hash(token),
expires_at=expires_at,
is_revoked=False,
created_at=datetime.now(timezone.utc)
)
def __repr__(self):
return f"<RefreshToken(id={self.id}, user_id={self.user_id}, expires_at={self.expires_at})>"
class LoginAttempt(Base):
"""Login attempt tracking model"""
__tablename__ = "login_attempts"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
email = Column(String(255), nullable=False, index=True)
ip_address = Column(String(45), nullable=False)
user_agent = Column(Text)
success = Column(Boolean, default=False)
failure_reason = Column(String(255))
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
def __repr__(self):
return f"<LoginAttempt(id={self.id}, email={self.email}, success={self.success})>"

View File

@@ -0,0 +1,61 @@
# services/auth/app/models/users.py - FIXED VERSION
"""
User models for authentication service - FIXED
Removed tenant relationships to eliminate cross-service dependencies
"""
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime, timezone
import uuid
from shared.database.base import Base
class User(Base):
"""User model - FIXED without cross-service relationships"""
__tablename__ = "users"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
email = Column(String(255), unique=True, index=True, nullable=False)
hashed_password = Column(String(255), nullable=False)
full_name = Column(String(255), nullable=False)
is_active = Column(Boolean, default=True)
is_verified = Column(Boolean, default=False)
# Timezone-aware datetime fields
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
last_login = Column(DateTime(timezone=True))
# Profile fields
phone = Column(String(20))
language = Column(String(10), default="es")
timezone = Column(String(50), default="Europe/Madrid")
role = Column(String(20), nullable=False)
# Payment integration fields
payment_customer_id = Column(String(255), nullable=True, index=True)
default_payment_method_id = Column(String(255), nullable=True)
# REMOVED: All tenant relationships - these are handled by tenant service
# No tenant_memberships, tenants relationships
def __repr__(self):
return f"<User(id={self.id}, email={self.email})>"
def to_dict(self):
"""Convert user to dictionary"""
return {
"id": str(self.id),
"email": self.email,
"full_name": self.full_name,
"is_active": self.is_active,
"is_verified": self.is_verified,
"phone": self.phone,
"language": self.language,
"timezone": self.timezone,
"role": self.role,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
"last_login": self.last_login.isoformat() if self.last_login else None
}

View File

@@ -0,0 +1,16 @@
"""
Auth Service Repositories
Repository implementations for authentication service
"""
from .base import AuthBaseRepository
from .user_repository import UserRepository
from .token_repository import TokenRepository
from .onboarding_repository import OnboardingRepository
__all__ = [
"AuthBaseRepository",
"UserRepository",
"TokenRepository",
"OnboardingRepository"
]

View File

@@ -0,0 +1,101 @@
"""
Base Repository for Auth Service
Service-specific repository base class with auth service utilities
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime, timezone
import structlog
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class AuthBaseRepository(BaseRepository):
"""Base repository for auth service with common auth operations"""
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Auth data benefits from longer caching (10 minutes)
super().__init__(model, session, cache_ttl)
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
"""Get active records (if model has is_active field)"""
if hasattr(self.model, 'is_active'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"is_active": True},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def get_by_email(self, email: str) -> Optional:
"""Get record by email (if model has email field)"""
if hasattr(self.model, 'email'):
return await self.get_by_field("email", email)
return None
async def get_by_username(self, username: str) -> Optional:
"""Get record by username (if model has username field)"""
if hasattr(self.model, 'username'):
return await self.get_by_field("username", username)
return None
async def deactivate_record(self, record_id: Any) -> Optional:
"""Deactivate a record instead of deleting it"""
if hasattr(self.model, 'is_active'):
return await self.update(record_id, {"is_active": False})
return await self.delete(record_id)
async def activate_record(self, record_id: Any) -> Optional:
"""Activate a record"""
if hasattr(self.model, 'is_active'):
return await self.update(record_id, {"is_active": True})
return await self.get_by_id(record_id)
async def cleanup_expired_records(self, field_name: str = "expires_at") -> int:
"""Clean up expired records (for tokens, sessions, etc.)"""
try:
if not hasattr(self.model, field_name):
logger.warning(f"Model {self.model.__name__} has no {field_name} field for cleanup")
return 0
# This would need custom implementation with raw SQL for date comparison
# For now, return 0 to indicate no cleanup performed
logger.info(f"Cleanup requested for {self.model.__name__} but not implemented")
return 0
except Exception as e:
logger.error("Failed to cleanup expired records",
model=self.model.__name__,
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
def _validate_auth_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
"""Validate authentication-related data"""
errors = []
for field in required_fields:
if field not in data or not data[field]:
errors.append(f"Missing required field: {field}")
# Validate email format if present
if "email" in data and data["email"]:
email = data["email"]
if "@" not in email or "." not in email.split("@")[-1]:
errors.append("Invalid email format")
# Validate password strength if present
if "password" in data and data["password"]:
password = data["password"]
if len(password) < 8:
errors.append("Password must be at least 8 characters long")
return {
"is_valid": len(errors) == 0,
"errors": errors
}
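Concrete repositories inherit these helpers, and the `hasattr(self.model, ...)` capability checks let the same base class serve models with and without soft-delete or email fields. A sketch of the validation helper, assuming a concrete subclass such as the `UserRepository` defined later in this commit; the payload values are illustrative.

```python
# Sketch: exercising the shared validation helper through a concrete repository.
from app.repositories.user_repository import UserRepository


def check_payload(repo: UserRepository) -> None:
    result = repo._validate_auth_data(
        {"email": "owner@bakery.example", "password": "short"},
        required_fields=["email", "password", "full_name"],
    )
    # Expected outcome: the missing full_name and the too-short password are both reported
    assert result["is_valid"] is False
    assert "Missing required field: full_name" in result["errors"]
    assert "Password must be at least 8 characters long" in result["errors"]
```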

View File

@@ -0,0 +1,110 @@
"""
Deletion Job Repository
Database operations for deletion job persistence
"""
from typing import List, Optional
from uuid import UUID
from sqlalchemy import select, and_, desc
from sqlalchemy.ext.asyncio import AsyncSession
import structlog
from app.models.deletion_job import DeletionJob
logger = structlog.get_logger()
class DeletionJobRepository:
"""Repository for deletion job database operations"""
def __init__(self, session: AsyncSession):
self.session = session
async def create(self, deletion_job: DeletionJob) -> DeletionJob:
"""Create a new deletion job record"""
try:
self.session.add(deletion_job)
await self.session.flush()
await self.session.refresh(deletion_job)
return deletion_job
except Exception as e:
logger.error("Failed to create deletion job", error=str(e))
raise
async def get_by_job_id(self, job_id: str) -> Optional[DeletionJob]:
"""Get deletion job by job_id"""
try:
query = select(DeletionJob).where(DeletionJob.job_id == job_id)
result = await self.session.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error("Failed to get deletion job", error=str(e), job_id=job_id)
raise
async def get_by_id(self, id: UUID) -> Optional[DeletionJob]:
"""Get deletion job by database ID"""
try:
return await self.session.get(DeletionJob, id)
except Exception as e:
logger.error("Failed to get deletion job by ID", error=str(e), id=str(id))
raise
async def list_by_tenant(
self,
tenant_id: UUID,
status: Optional[str] = None,
limit: int = 100
) -> List[DeletionJob]:
"""List deletion jobs for a tenant"""
try:
query = select(DeletionJob).where(DeletionJob.tenant_id == tenant_id)
if status:
query = query.where(DeletionJob.status == status)
query = query.order_by(desc(DeletionJob.started_at)).limit(limit)
result = await self.session.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error("Failed to list deletion jobs", error=str(e), tenant_id=str(tenant_id))
raise
async def list_all(
self,
status: Optional[str] = None,
limit: int = 100
) -> List[DeletionJob]:
"""List all deletion jobs with optional status filter"""
try:
query = select(DeletionJob)
if status:
query = query.where(DeletionJob.status == status)
query = query.order_by(desc(DeletionJob.started_at)).limit(limit)
result = await self.session.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error("Failed to list all deletion jobs", error=str(e))
raise
async def update(self, deletion_job: DeletionJob) -> DeletionJob:
"""Update a deletion job record"""
try:
await self.session.flush()
await self.session.refresh(deletion_job)
return deletion_job
except Exception as e:
logger.error("Failed to update deletion job", error=str(e))
raise
async def delete(self, deletion_job: DeletionJob) -> None:
"""Delete a deletion job record"""
try:
await self.session.delete(deletion_job)
await self.session.flush()
except Exception as e:
logger.error("Failed to delete deletion job", error=str(e))
raise

View File

@@ -0,0 +1,313 @@
# services/auth/app/repositories/onboarding_repository.py
"""
Onboarding Repository for database operations
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, and_, func
from sqlalchemy.dialects.postgresql import insert
from typing import List, Dict, Any, Optional
from datetime import datetime, timezone
import structlog
from app.models.onboarding import UserOnboardingProgress, UserOnboardingSummary
logger = structlog.get_logger()
class OnboardingRepository:
"""Repository for onboarding progress operations"""
def __init__(self, db: AsyncSession):
self.db = db
async def get_user_progress_steps(self, user_id: str) -> List[UserOnboardingProgress]:
"""Get all onboarding steps for a user"""
try:
result = await self.db.execute(
select(UserOnboardingProgress)
.where(UserOnboardingProgress.user_id == user_id)
.order_by(UserOnboardingProgress.created_at)
)
return result.scalars().all()
except Exception as e:
logger.error(f"Error getting user progress steps for {user_id}: {e}")
return []
async def get_user_step(self, user_id: str, step_name: str) -> Optional[UserOnboardingProgress]:
"""Get a specific step for a user"""
try:
result = await self.db.execute(
select(UserOnboardingProgress)
.where(
and_(
UserOnboardingProgress.user_id == user_id,
UserOnboardingProgress.step_name == step_name
)
)
)
return result.scalars().first()
except Exception as e:
logger.error(f"Error getting step {step_name} for user {user_id}: {e}")
return None
async def upsert_user_step(
self,
user_id: str,
step_name: str,
completed: bool,
step_data: Dict[str, Any] = None,
auto_commit: bool = True
) -> UserOnboardingProgress:
"""Insert or update a user's onboarding step
Args:
user_id: User ID
step_name: Name of the step
completed: Whether the step is completed
step_data: Additional data for the step
auto_commit: Whether to auto-commit (set to False when used within UnitOfWork)
"""
try:
completed_at = datetime.now(timezone.utc) if completed else None
step_data = step_data or {}
# Use PostgreSQL UPSERT (INSERT ... ON CONFLICT ... DO UPDATE)
stmt = insert(UserOnboardingProgress).values(
user_id=user_id,
step_name=step_name,
completed=completed,
completed_at=completed_at,
step_data=step_data,
updated_at=datetime.now(timezone.utc)
)
# On conflict, update the existing record
stmt = stmt.on_conflict_do_update(
index_elements=['user_id', 'step_name'],
set_=dict(
completed=stmt.excluded.completed,
completed_at=stmt.excluded.completed_at,
step_data=stmt.excluded.step_data,
updated_at=stmt.excluded.updated_at
)
)
# Return the updated record
stmt = stmt.returning(UserOnboardingProgress)
result = await self.db.execute(stmt)
# Only commit if auto_commit is True (not within a UnitOfWork)
if auto_commit:
await self.db.commit()
else:
# Flush to ensure the statement is executed
await self.db.flush()
return result.scalars().first()
except Exception as e:
logger.error(f"Error upserting step {step_name} for user {user_id}: {e}")
if auto_commit:
await self.db.rollback()
raise
async def get_user_summary(self, user_id: str) -> Optional[UserOnboardingSummary]:
"""Get user's onboarding summary"""
try:
result = await self.db.execute(
select(UserOnboardingSummary)
.where(UserOnboardingSummary.user_id == user_id)
)
return result.scalars().first()
except Exception as e:
logger.error(f"Error getting onboarding summary for user {user_id}: {e}")
return None
async def upsert_user_summary(
self,
user_id: str,
current_step: str,
next_step: Optional[str],
completion_percentage: float,
fully_completed: bool,
steps_completed_count: str
) -> UserOnboardingSummary:
"""Insert or update user's onboarding summary"""
try:
# Use PostgreSQL UPSERT
stmt = insert(UserOnboardingSummary).values(
user_id=user_id,
current_step=current_step,
next_step=next_step,
completion_percentage=str(completion_percentage),
fully_completed=fully_completed,
steps_completed_count=steps_completed_count,
updated_at=datetime.now(timezone.utc),
last_activity_at=datetime.now(timezone.utc)
)
# On conflict, update the existing record
stmt = stmt.on_conflict_do_update(
index_elements=['user_id'],
set_=dict(
current_step=stmt.excluded.current_step,
next_step=stmt.excluded.next_step,
completion_percentage=stmt.excluded.completion_percentage,
fully_completed=stmt.excluded.fully_completed,
steps_completed_count=stmt.excluded.steps_completed_count,
updated_at=stmt.excluded.updated_at,
last_activity_at=stmt.excluded.last_activity_at
)
)
# Return the updated record
stmt = stmt.returning(UserOnboardingSummary)
result = await self.db.execute(stmt)
await self.db.commit()
return result.scalars().first()
except Exception as e:
logger.error(f"Error upserting summary for user {user_id}: {e}")
await self.db.rollback()
raise
async def delete_user_progress(self, user_id: str) -> bool:
"""Delete all onboarding progress for a user"""
try:
# Delete steps
await self.db.execute(
delete(UserOnboardingProgress)
.where(UserOnboardingProgress.user_id == user_id)
)
# Delete summary
await self.db.execute(
delete(UserOnboardingSummary)
.where(UserOnboardingSummary.user_id == user_id)
)
await self.db.commit()
return True
except Exception as e:
logger.error(f"Error deleting progress for user {user_id}: {e}")
await self.db.rollback()
return False
async def save_step_data(
self,
user_id: str,
step_name: str,
step_data: Dict[str, Any],
auto_commit: bool = True
) -> UserOnboardingProgress:
"""Save data for a specific step without marking it as completed
Args:
user_id: User ID
step_name: Name of the step
step_data: Data to save
auto_commit: Whether to auto-commit (set to False when used within UnitOfWork)
"""
try:
# Get existing step or create new one
existing_step = await self.get_user_step(user_id, step_name)
if existing_step:
# Update existing step data (merge with existing data)
merged_data = {**(existing_step.step_data or {}), **step_data}
stmt = update(UserOnboardingProgress).where(
and_(
UserOnboardingProgress.user_id == user_id,
UserOnboardingProgress.step_name == step_name
)
).values(
step_data=merged_data,
updated_at=datetime.now(timezone.utc)
).returning(UserOnboardingProgress)
result = await self.db.execute(stmt)
if auto_commit:
await self.db.commit()
else:
await self.db.flush()
return result.scalars().first()
else:
# Create new step with data but not completed
return await self.upsert_user_step(
user_id=user_id,
step_name=step_name,
completed=False,
step_data=step_data,
auto_commit=auto_commit
)
except Exception as e:
logger.error(f"Error saving step data for {step_name}, user {user_id}: {e}")
if auto_commit:
await self.db.rollback()
raise
async def get_step_data(self, user_id: str, step_name: str) -> Optional[Dict[str, Any]]:
"""Get data for a specific step"""
try:
step = await self.get_user_step(user_id, step_name)
return step.step_data if step else None
except Exception as e:
logger.error(f"Error getting step data for {step_name}, user {user_id}: {e}")
return None
async def get_subscription_parameters(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get subscription parameters saved during onboarding for tenant creation"""
try:
step_data = await self.get_step_data(user_id, "user_registered")
if step_data:
# Extract subscription-related parameters
subscription_params = {
"subscription_plan": step_data.get("subscription_plan", "starter"),
"billing_cycle": step_data.get("billing_cycle", "monthly"),
"coupon_code": step_data.get("coupon_code"),
"payment_method_id": step_data.get("payment_method_id"),
"payment_customer_id": step_data.get("payment_customer_id"),
"saved_at": step_data.get("saved_at")
}
return subscription_params
return None
except Exception as e:
logger.error(f"Error getting subscription parameters for user {user_id}: {e}")
return None
async def get_completion_stats(self) -> Dict[str, Any]:
"""Get completion statistics across all users"""
try:
# Get total users with onboarding data
total_result = await self.db.execute(
select(func.count()).select_from(UserOnboardingSummary)
)
total_users = total_result.scalar() or 0
# Get completed users
completed_result = await self.db.execute(
select(func.count())
.select_from(UserOnboardingSummary)
.where(UserOnboardingSummary.fully_completed == True)
)
completed_users = completed_result.scalar() or 0
return {
"total_users_in_onboarding": total_users,
"fully_completed_users": completed_users,
"completion_rate": (completed_users / total_users * 100) if total_users > 0 else 0
}
except Exception as e:
logger.error(f"Error getting completion stats: {e}")
return {
"total_users_in_onboarding": 0,
"fully_completed_users": 0,
"completion_rate": 0
}
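The ON CONFLICT upserts above make step writes idempotent, which matters because onboarding events can arrive more than once. A usage sketch, assuming an `AsyncSession` from the service's session factory; the step data values are examples.

```python
# Sketch: recording onboarding progress idempotently (session wiring is illustrative).
from sqlalchemy.ext.asyncio import AsyncSession

from app.repositories.onboarding_repository import OnboardingRepository


async def record_registration(session: AsyncSession, user_id: str) -> None:
    repo = OnboardingRepository(session)

    choices = {"subscription_plan": "starter", "billing_cycle": "monthly"}

    # Mark the registration step complete; re-running this call updates the same
    # row thanks to the ON CONFLICT clause, so duplicate events are harmless.
    await repo.upsert_user_step(user_id, "user_registered", completed=True, step_data=choices)

    # Later, tenant creation can read the saved choices back
    params = await repo.get_subscription_parameters(user_id)
    print(params["subscription_plan"])  # "starter"
```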

View File

@@ -0,0 +1,124 @@
# services/auth/app/repositories/password_reset_repository.py
"""
Password reset token repository
Repository for password reset token operations
"""
from typing import Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text
from datetime import datetime, timezone
import structlog
import uuid
from .base import AuthBaseRepository
from app.models.password_reset_tokens import PasswordResetToken
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class PasswordResetTokenRepository(AuthBaseRepository):
"""Repository for password reset token operations"""
def __init__(self, session: AsyncSession):
super().__init__(PasswordResetToken, session)
async def create_token(self, user_id: str, token: str, expires_at: datetime) -> PasswordResetToken:
"""Create a new password reset token"""
try:
token_data = {
"user_id": user_id,
"token": token,
"expires_at": expires_at,
"is_used": False
}
reset_token = await self.create(token_data)
logger.debug("Password reset token created",
user_id=user_id,
token_id=reset_token.id,
expires_at=expires_at)
return reset_token
except Exception as e:
logger.error("Failed to create password reset token",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to create password reset token: {str(e)}")
async def get_token_by_value(self, token: str) -> Optional[PasswordResetToken]:
"""Get password reset token by token value"""
try:
stmt = select(PasswordResetToken).where(
and_(
PasswordResetToken.token == token,
PasswordResetToken.is_used == False,
PasswordResetToken.expires_at > datetime.now(timezone.utc)
)
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
except Exception as e:
logger.error("Failed to get password reset token by value", error=str(e))
raise DatabaseError(f"Failed to get password reset token: {str(e)}")
async def mark_token_as_used(self, token_id: str) -> Optional[PasswordResetToken]:
"""Mark a password reset token as used"""
try:
return await self.update(token_id, {
"is_used": True,
"used_at": datetime.now(timezone.utc)
})
except Exception as e:
logger.error("Failed to mark password reset token as used",
token_id=token_id,
error=str(e))
raise DatabaseError(f"Failed to mark token as used: {str(e)}")
async def cleanup_expired_tokens(self) -> int:
"""Clean up expired password reset tokens"""
try:
now = datetime.now(timezone.utc)
# Delete expired tokens
query = text("""
DELETE FROM password_reset_tokens
WHERE expires_at < :now OR is_used = true
""")
result = await self.session.execute(query, {"now": now})
deleted_count = result.rowcount
logger.info("Cleaned up expired password reset tokens",
deleted_count=deleted_count)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup expired password reset tokens", error=str(e))
raise DatabaseError(f"Token cleanup failed: {str(e)}")
async def get_valid_token_for_user(self, user_id: str) -> Optional[PasswordResetToken]:
"""Get a valid (unused, not expired) password reset token for a user"""
try:
stmt = select(PasswordResetToken).where(
and_(
PasswordResetToken.user_id == user_id,
PasswordResetToken.is_used == False,
PasswordResetToken.expires_at > datetime.now(timezone.utc)
)
).order_by(PasswordResetToken.created_at.desc())
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
except Exception as e:
logger.error("Failed to get valid token for user",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to get valid token for user: {str(e)}")

View File

@@ -0,0 +1,305 @@
"""
Token Repository
Repository for refresh token operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text
from datetime import datetime, timezone, timedelta
import structlog
from .base import AuthBaseRepository
from app.models.tokens import RefreshToken
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class TokenRepository(AuthBaseRepository):
"""Repository for refresh token operations"""
def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Tokens change frequently, shorter cache time
super().__init__(model, session, cache_ttl)
async def create_token(self, token_data: Dict[str, Any]) -> RefreshToken:
"""Create a new refresh token from dictionary data"""
return await self.create(token_data)
async def create_refresh_token(
self,
user_id: str,
token: str,
expires_at: datetime
) -> RefreshToken:
"""Create a new refresh token"""
try:
token_data = {
"user_id": user_id,
"token": token,
"expires_at": expires_at,
"is_revoked": False
}
refresh_token = await self.create(token_data)
logger.debug("Refresh token created",
user_id=user_id,
token_id=refresh_token.id,
expires_at=expires_at)
return refresh_token
except Exception as e:
            logger.error("Failed to create refresh token",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to create refresh token: {str(e)}")

    async def get_token_by_value(self, token: str) -> Optional[RefreshToken]:
        """Get refresh token by token value"""
        try:
            return await self.get_by_field("token", token)
        except Exception as e:
            logger.error("Failed to get token by value", error=str(e))
            raise DatabaseError(f"Failed to get token: {str(e)}")

    async def get_active_tokens_for_user(self, user_id: str) -> List[RefreshToken]:
        """Get all active (non-revoked, non-expired) tokens for a user"""
        try:
            now = datetime.now(timezone.utc)
            # Use raw query for complex filtering
            query = text("""
                SELECT * FROM refresh_tokens
                WHERE user_id = :user_id
                AND is_revoked = false
                AND expires_at > :now
                ORDER BY created_at DESC
            """)
            result = await self.session.execute(query, {
                "user_id": user_id,
                "now": now
            })
            # Convert rows to RefreshToken objects
            tokens = []
            for row in result.fetchall():
                token = RefreshToken(
                    id=row.id,
                    user_id=row.user_id,
                    token=row.token,
                    expires_at=row.expires_at,
                    is_revoked=row.is_revoked,
                    created_at=row.created_at,
                    revoked_at=row.revoked_at
                )
                tokens.append(token)
            return tokens
        except Exception as e:
            logger.error("Failed to get active tokens for user",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get active tokens: {str(e)}")

    async def revoke_token(self, token_id: str) -> Optional[RefreshToken]:
        """Revoke a refresh token"""
        try:
            return await self.update(token_id, {
                "is_revoked": True,
                "revoked_at": datetime.now(timezone.utc)
            })
        except Exception as e:
            logger.error("Failed to revoke token",
                         token_id=token_id,
                         error=str(e))
            raise DatabaseError(f"Failed to revoke token: {str(e)}")

    async def revoke_all_user_tokens(self, user_id: str) -> int:
        """Revoke all tokens for a user"""
        try:
            # Use bulk update for efficiency
            now = datetime.now(timezone.utc)
            query = text("""
                UPDATE refresh_tokens
                SET is_revoked = true, revoked_at = :revoked_at
                WHERE user_id = :user_id AND is_revoked = false
            """)
            result = await self.session.execute(query, {
                "user_id": user_id,
                "revoked_at": now
            })
            revoked_count = result.rowcount
            logger.info("Revoked all user tokens",
                        user_id=user_id,
                        revoked_count=revoked_count)
            return revoked_count
        except Exception as e:
            logger.error("Failed to revoke all user tokens",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to revoke user tokens: {str(e)}")

    async def is_token_valid(self, token: str) -> bool:
        """Check if a token is valid (exists, not revoked, not expired)"""
        try:
            refresh_token = await self.get_token_by_value(token)
            if not refresh_token:
                return False
            if refresh_token.is_revoked:
                return False
            if refresh_token.expires_at < datetime.now(timezone.utc):
                return False
            return True
        except Exception as e:
            logger.error("Failed to validate token", error=str(e))
            return False

    async def validate_refresh_token(self, token: str, user_id: str) -> bool:
        """Validate refresh token for a specific user"""
        try:
            refresh_token = await self.get_token_by_value(token)
            if not refresh_token:
                logger.debug("Refresh token not found", token_prefix=token[:10] + "...")
                return False
            # Convert both to strings for comparison to handle UUID vs string mismatch
            token_user_id = str(refresh_token.user_id)
            expected_user_id = str(user_id)
            if token_user_id != expected_user_id:
                logger.warning("Refresh token user_id mismatch",
                               expected_user_id=expected_user_id,
                               actual_user_id=token_user_id)
                return False
            if refresh_token.is_revoked:
                logger.debug("Refresh token is revoked", user_id=user_id)
                return False
            if refresh_token.expires_at < datetime.now(timezone.utc):
                logger.debug("Refresh token is expired", user_id=user_id)
                return False
            logger.debug("Refresh token is valid", user_id=user_id)
            return True
        except Exception as e:
            logger.error("Failed to validate refresh token",
                         user_id=user_id,
                         error=str(e))
            return False

    async def cleanup_expired_tokens(self) -> int:
        """Clean up expired refresh tokens"""
        try:
            now = datetime.now(timezone.utc)
            # Delete expired tokens
            query = text("""
                DELETE FROM refresh_tokens
                WHERE expires_at < :now
            """)
            result = await self.session.execute(query, {"now": now})
            deleted_count = result.rowcount
            logger.info("Cleaned up expired tokens",
                        deleted_count=deleted_count)
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup expired tokens", error=str(e))
            raise DatabaseError(f"Token cleanup failed: {str(e)}")

    async def cleanup_old_revoked_tokens(self, days_old: int = 30) -> int:
        """Clean up old revoked tokens"""
        try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
            query = text("""
                DELETE FROM refresh_tokens
                WHERE is_revoked = true
                AND revoked_at < :cutoff_date
            """)
            result = await self.session.execute(query, {
                "cutoff_date": cutoff_date
            })
            deleted_count = result.rowcount
            logger.info("Cleaned up old revoked tokens",
                        deleted_count=deleted_count,
                        days_old=days_old)
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup old revoked tokens",
                         days_old=days_old,
                         error=str(e))
            raise DatabaseError(f"Revoked token cleanup failed: {str(e)}")

    async def get_token_statistics(self) -> Dict[str, Any]:
        """Get token statistics"""
        try:
            now = datetime.now(timezone.utc)
            # Get counts with raw queries
            stats_query = text("""
                SELECT
                    COUNT(*) as total_tokens,
                    COUNT(CASE WHEN is_revoked = false AND expires_at > :now THEN 1 END) as active_tokens,
                    COUNT(CASE WHEN is_revoked = true THEN 1 END) as revoked_tokens,
                    COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_tokens,
                    COUNT(DISTINCT user_id) as users_with_tokens
                FROM refresh_tokens
            """)
            result = await self.session.execute(stats_query, {"now": now})
            row = result.fetchone()
            if row:
                return {
                    "total_tokens": row.total_tokens,
                    "active_tokens": row.active_tokens,
                    "revoked_tokens": row.revoked_tokens,
                    "expired_tokens": row.expired_tokens,
                    "users_with_tokens": row.users_with_tokens
                }
            return {
                "total_tokens": 0,
                "active_tokens": 0,
                "revoked_tokens": 0,
                "expired_tokens": 0,
                "users_with_tokens": 0
            }
        except Exception as e:
            logger.error("Failed to get token statistics", error=str(e))
            return {
                "total_tokens": 0,
                "active_tokens": 0,
                "revoked_tokens": 0,
                "expired_tokens": 0,
                "users_with_tokens": 0
            }
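
The repository above covers lookup, validation, revocation, and cleanup of refresh tokens. The following is a minimal usage sketch, not part of the committed code: the class name `RefreshTokenRepository`, its `(model, session)` constructor, the module paths, the `session_factory`, and the explicit `session.commit()` calls are all assumptions made for illustration; only the method calls mirror what is defined above.

```python
from app.models.refresh_tokens import RefreshToken              # assumed model location
from app.repositories.refresh_token import RefreshTokenRepository  # assumed module path


async def rotate_refresh_token(session_factory, raw_token: str, user_id: str) -> bool:
    """Validate a presented refresh token and revoke it so a new one can be issued."""
    async with session_factory() as session:
        repo = RefreshTokenRepository(RefreshToken, session)
        if not await repo.validate_refresh_token(raw_token, user_id):
            return False
        stored = await repo.get_token_by_value(raw_token)
        await repo.revoke_token(stored.id)  # single use: the old token can no longer be replayed
        await session.commit()              # assumes commit is the caller's responsibility
        return True


async def nightly_token_maintenance(session_factory) -> None:
    """Periodic cleanup of expired and long-revoked tokens."""
    async with session_factory() as session:
        repo = RefreshTokenRepository(RefreshToken, session)
        expired = await repo.cleanup_expired_tokens()
        revoked = await repo.cleanup_old_revoked_tokens(days_old=30)
        await session.commit()
        print(f"cleanup: removed {expired} expired and {revoked} old revoked tokens")
```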

View File

@@ -0,0 +1,277 @@
"""
User Repository
Repository for user operations with authentication-specific queries
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func, desc, text
from datetime import datetime, timezone, timedelta
import structlog

from .base import AuthBaseRepository
from app.models.users import User
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError

logger = structlog.get_logger()


class UserRepository(AuthBaseRepository):
    """Repository for user operations"""

    def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 600):
        super().__init__(model, session, cache_ttl)

    async def create_user(self, user_data: Dict[str, Any]) -> User:
        """Create a new user with validation"""
        try:
            # Validate user data
            validation_result = self._validate_auth_data(
                user_data,
                ["email", "hashed_password", "full_name", "role"]
            )
            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid user data: {validation_result['errors']}")
            # Check if user already exists
            existing_user = await self.get_by_email(user_data["email"])
            if existing_user:
                raise DuplicateRecordError(f"User with email {user_data['email']} already exists")
            # Create user
            user = await self.create(user_data)
            logger.info("User created successfully",
                        user_id=user.id,
                        email=user.email,
                        role=user.role)
            return user
        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to create user",
                         email=user_data.get("email"),
                         error=str(e))
            raise DatabaseError(f"Failed to create user: {str(e)}")

    async def get_user_by_email(self, email: str) -> Optional[User]:
        """Get user by email address"""
        return await self.get_by_email(email)

    async def get_active_users(self, skip: int = 0, limit: int = 100) -> List[User]:
        """Get all active users"""
        return await self.get_active_records(skip=skip, limit=limit)

    async def authenticate_user(self, email: str, password: str) -> Optional[User]:
        """Authenticate user with email and plain password"""
        try:
            user = await self.get_by_email(email)
            if not user:
                logger.debug("User not found for authentication", email=email)
                return None
            if not user.is_active:
                logger.debug("User account is inactive", email=email)
                return None
            # Verify password using security manager
            from app.core.security import SecurityManager
            if SecurityManager.verify_password(password, user.hashed_password):
                # Update last login
                await self.update_last_login(user.id)
                logger.info("User authenticated successfully",
                            user_id=user.id,
                            email=email)
                return user
            logger.debug("Invalid password for user", email=email)
            return None
        except Exception as e:
            logger.error("Authentication failed",
                         email=email,
                         error=str(e))
            raise DatabaseError(f"Authentication failed: {str(e)}")

    async def update_last_login(self, user_id: str) -> Optional[User]:
        """Update user's last login timestamp"""
        try:
            return await self.update(user_id, {
                "last_login": datetime.now(timezone.utc)
            })
        except Exception as e:
            logger.error("Failed to update last login",
                         user_id=user_id,
                         error=str(e))
            # Don't raise here - last login update is not critical
            return None

    async def update_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> Optional[User]:
        """Update user profile information"""
        try:
            # Remove sensitive fields that shouldn't be updated via profile
            profile_data.pop("id", None)
            profile_data.pop("hashed_password", None)
            profile_data.pop("created_at", None)
            profile_data.pop("is_active", None)
            # Validate email if being updated
            if "email" in profile_data:
                validation_result = self._validate_auth_data(
                    profile_data,
                    ["email"]
                )
                if not validation_result["is_valid"]:
                    raise ValidationError(f"Invalid profile data: {validation_result['errors']}")
                # Check for email conflicts
                existing_user = await self.get_by_email(profile_data["email"])
                if existing_user and str(existing_user.id) != str(user_id):
                    raise DuplicateRecordError(f"Email {profile_data['email']} is already in use")
            updated_user = await self.update(user_id, profile_data)
            if updated_user:
                logger.info("User profile updated",
                            user_id=user_id,
                            updated_fields=list(profile_data.keys()))
            return updated_user
        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to update user profile",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to update profile: {str(e)}")

    async def change_password(self, user_id: str, new_password_hash: str) -> bool:
        """Change user password"""
        try:
            updated_user = await self.update(user_id, {
                "hashed_password": new_password_hash
            })
            if updated_user:
                logger.info("Password changed successfully", user_id=user_id)
                return True
            return False
        except Exception as e:
            logger.error("Failed to change password",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to change password: {str(e)}")

    async def verify_user_email(self, user_id: str) -> Optional[User]:
        """Mark user email as verified"""
        try:
            return await self.update(user_id, {
                "is_verified": True
            })
        except Exception as e:
            logger.error("Failed to verify user email",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to verify email: {str(e)}")

    async def deactivate_user(self, user_id: str) -> Optional[User]:
        """Deactivate user account"""
        return await self.deactivate_record(user_id)

    async def activate_user(self, user_id: str) -> Optional[User]:
        """Activate user account"""
        return await self.activate_record(user_id)

    async def get_users_by_role(self, role: str, skip: int = 0, limit: int = 100) -> List[User]:
        """Get users by role"""
        try:
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"role": role, "is_active": True},
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get users by role",
                         role=role,
                         error=str(e))
            raise DatabaseError(f"Failed to get users by role: {str(e)}")

    async def search_users(self, search_term: str, skip: int = 0, limit: int = 50) -> List[User]:
        """Search users by email or full name"""
        try:
            return await self.search(
                search_term=search_term,
                search_fields=["email", "full_name"],
                skip=skip,
                limit=limit
            )
        except Exception as e:
            logger.error("Failed to search users",
                         search_term=search_term,
                         error=str(e))
            raise DatabaseError(f"Failed to search users: {str(e)}")

    async def get_user_statistics(self) -> Dict[str, Any]:
        """Get user statistics"""
        try:
            # Get basic counts
            total_users = await self.count()
            active_users = await self.count(filters={"is_active": True})
            verified_users = await self.count(filters={"is_verified": True})
            # Get users by role using raw query
            role_query = text("""
                SELECT role, COUNT(*) as count
                FROM users
                WHERE is_active = true
                GROUP BY role
                ORDER BY count DESC
            """)
            result = await self.session.execute(role_query)
            role_stats = {row.role: row.count for row in result.fetchall()}
            # Recent activity (users created in last 30 days)
            thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
            recent_users_query = text("""
                SELECT COUNT(*) as count
                FROM users
                WHERE created_at >= :thirty_days_ago
            """)
            recent_result = await self.session.execute(
                recent_users_query,
                {"thirty_days_ago": thirty_days_ago}
            )
            recent_users = recent_result.scalar() or 0
            return {
                "total_users": total_users,
                "active_users": active_users,
                "inactive_users": total_users - active_users,
                "verified_users": verified_users,
                "unverified_users": active_users - verified_users,
                "recent_registrations": recent_users,
                "users_by_role": role_stats
            }
        except Exception as e:
            logger.error("Failed to get user statistics", error=str(e))
            return {
                "total_users": 0,
                "active_users": 0,
                "inactive_users": 0,
                "verified_users": 0,
                "unverified_users": 0,
                "recent_registrations": 0,
                "users_by_role": {}
            }
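
As a companion to the user repository above, here is a hedged sketch of how registration and authentication might be wired together. The `session_factory`, the module paths, the explicit commit, and `SecurityManager.hash_password` are assumptions (only `verify_password` appears in this file); the repository methods are the ones defined above.

```python
from app.core.security import SecurityManager      # hash_password() is assumed to exist
from app.models.users import User
from app.repositories.user import UserRepository   # assumed module path


async def register_and_login(session_factory) -> None:
    async with session_factory() as session:
        repo = UserRepository(User, session)

        # create_user() validates required fields and rejects duplicate emails
        user = await repo.create_user({
            "email": "owner@example-bakery.test",
            "hashed_password": SecurityManager.hash_password("a-strong-password"),
            "full_name": "Bakery Owner",
            "role": "admin",
        })
        await session.commit()  # assumes commit is handled by the caller

        # authenticate_user() checks is_active, verifies the password hash,
        # and stamps last_login on success
        authenticated = await repo.authenticate_user(
            "owner@example-bakery.test", "a-strong-password"
        )
        assert authenticated is not None and authenticated.id == user.id
```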

View File

View File

@@ -0,0 +1,230 @@
# services/auth/app/schemas/auth.py - UPDATED WITH UNIFIED TOKEN RESPONSE
"""
Authentication schemas - Updated with unified token response format
Following industry best practices from Firebase, Cognito, etc.
"""
from pydantic import BaseModel, EmailStr, Field
from typing import Optional, Dict, Any
from datetime import datetime


# ================================================================
# REQUEST SCHEMAS
# ================================================================

class UserRegistration(BaseModel):
    """User registration request"""
    email: EmailStr
    password: str = Field(..., min_length=8, max_length=128)
    full_name: str = Field(..., min_length=1, max_length=255)
    tenant_name: Optional[str] = Field(None, max_length=255)
    role: Optional[str] = Field("admin", pattern=r'^(user|admin|manager|super_admin)$')
    subscription_plan: Optional[str] = Field("starter", description="Selected subscription plan (starter, professional, enterprise)")
    billing_cycle: Optional[str] = Field("monthly", description="Billing cycle (monthly, yearly)")
    coupon_code: Optional[str] = Field(None, description="Discount coupon code")
    payment_method_id: Optional[str] = Field(None, description="Stripe payment method ID")
    # GDPR Consent fields
    terms_accepted: Optional[bool] = Field(True, description="Accept terms of service")
    privacy_accepted: Optional[bool] = Field(True, description="Accept privacy policy")
    marketing_consent: Optional[bool] = Field(False, description="Consent to marketing communications")
    analytics_consent: Optional[bool] = Field(False, description="Consent to analytics cookies")


class UserLogin(BaseModel):
    """User login request"""
    email: EmailStr
    password: str


class RefreshTokenRequest(BaseModel):
    """Refresh token request"""
    refresh_token: str


class PasswordChange(BaseModel):
    """Password change request"""
    current_password: str
    new_password: str = Field(..., min_length=8, max_length=128)


class PasswordReset(BaseModel):
    """Password reset request"""
    email: EmailStr


class PasswordResetConfirm(BaseModel):
    """Password reset confirmation"""
    token: str
    new_password: str = Field(..., min_length=8, max_length=128)


# ================================================================
# RESPONSE SCHEMAS
# ================================================================

class UserData(BaseModel):
    """User data embedded in token responses"""
    id: str
    email: str
    full_name: str
    is_active: bool
    is_verified: bool
    created_at: str  # ISO format datetime string
    tenant_id: Optional[str] = None
    role: Optional[str] = "admin"


class TokenResponse(BaseModel):
    """
    Unified token response for both registration and login
    Follows industry standards (Firebase, AWS Cognito, etc.)
    """
    access_token: str
    refresh_token: Optional[str] = None
    token_type: str = "bearer"
    expires_in: int = 3600  # seconds
    user: Optional[UserData] = None
    subscription_id: Optional[str] = Field(None, description="Subscription ID if created during registration")
    # Payment action fields (3DS, SetupIntent, etc.)
    requires_action: Optional[bool] = Field(None, description="Whether payment action is required (3DS, SetupIntent confirmation)")
    action_type: Optional[str] = Field(None, description="Type of action required (setup_intent_confirmation, payment_intent_confirmation)")
    client_secret: Optional[str] = Field(None, description="Client secret for payment confirmation")
    payment_intent_id: Optional[str] = Field(None, description="Payment intent ID for 3DS authentication")
    setup_intent_id: Optional[str] = Field(None, description="SetupIntent ID for payment method verification")
    customer_id: Optional[str] = Field(None, description="Stripe customer ID")
    # Additional fields for post-confirmation subscription completion
    plan_id: Optional[str] = Field(None, description="Subscription plan ID")
    payment_method_id: Optional[str] = Field(None, description="Payment method ID")
    trial_period_days: Optional[int] = Field(None, description="Trial period in days")
    user_id: Optional[str] = Field(None, description="User ID for post-confirmation processing")
    billing_interval: Optional[str] = Field(None, description="Billing interval (monthly, yearly)")
    message: Optional[str] = Field(None, description="Additional message about payment action required")

    class Config:
        schema_extra = {
            "example": {
                "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
                "refresh_token": "def502004b8b7f8f...",
                "token_type": "bearer",
                "expires_in": 3600,
                "user": {
                    "id": "123e4567-e89b-12d3-a456-426614174000",
                    "email": "user@example.com",
                    "full_name": "John Doe",
                    "is_active": True,
                    "is_verified": False,
                    "created_at": "2025-07-22T10:00:00Z",
                    "role": "user"
                },
                "subscription_id": "sub_1234567890",
                "requires_action": True,
                "action_type": "setup_intent_confirmation",
                "client_secret": "seti_1234_secret_5678",
                "payment_intent_id": None,
                "setup_intent_id": "seti_1234567890",
                "customer_id": "cus_1234567890"
            }
        }


class UserResponse(BaseModel):
    """User response for user management endpoints - FIXED"""
    id: str
    email: str
    full_name: str
    is_active: bool
    is_verified: bool
    created_at: datetime  # ✅ Changed from str to datetime
    last_login: Optional[datetime] = None  # ✅ Added missing field
    phone: Optional[str] = None  # ✅ Added missing field
    language: Optional[str] = None  # ✅ Added missing field
    timezone: Optional[str] = None  # ✅ Added missing field
    tenant_id: Optional[str] = None
    role: Optional[str] = "admin"
    payment_customer_id: Optional[str] = None  # ✅ Added payment integration field
    default_payment_method_id: Optional[str] = None  # ✅ Added payment integration field

    class Config:
        from_attributes = True  # ✅ Enable ORM mode for SQLAlchemy objects


class TokenVerification(BaseModel):
    """Token verification response"""
    valid: bool
    user_id: Optional[str] = None
    email: Optional[str] = None
    exp: Optional[int] = None
    message: Optional[str] = None


class PasswordResetResponse(BaseModel):
    """Password reset response"""
    message: str
    reset_token: Optional[str] = None


class LogoutResponse(BaseModel):
    """Logout response"""
    message: str
    success: bool = True


# ================================================================
# ERROR SCHEMAS
# ================================================================

class ErrorDetail(BaseModel):
    """Error detail for API responses"""
    message: str
    code: Optional[str] = None
    field: Optional[str] = None


class ErrorResponse(BaseModel):
    """Standardized error response"""
    success: bool = False
    error: ErrorDetail
    timestamp: str

    class Config:
        schema_extra = {
            "example": {
                "success": False,
                "error": {
                    "message": "Invalid credentials",
                    "code": "AUTH_001"
                },
                "timestamp": "2025-07-22T10:00:00Z"
            }
        }


# ================================================================
# VALIDATION SCHEMAS
# ================================================================

class EmailVerificationRequest(BaseModel):
    """Email verification request"""
    email: EmailStr


class EmailVerificationConfirm(BaseModel):
    """Email verification confirmation"""
    token: str


class ProfileUpdate(BaseModel):
    """Profile update request"""
    full_name: Optional[str] = Field(None, min_length=1, max_length=255)
    email: Optional[EmailStr] = None


# ================================================================
# INTERNAL SCHEMAS (for service communication)
# ================================================================

class UserContext(BaseModel):
    """User context for internal service communication"""
    user_id: str
    email: str
    tenant_id: Optional[str] = None
    roles: list[str] = ["admin"]
    is_verified: bool = False


class TokenClaims(BaseModel):
    """JWT token claims structure"""
    sub: str  # subject (user_id)
    email: str
    full_name: str
    user_id: str
    is_verified: bool
    tenant_id: Optional[str] = None
    iat: int  # issued at
    exp: int  # expires at
    iss: str = "bakery-auth"  # issuer
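
To show how the unified response and the claims model fit together, here is an illustrative sketch of the assembly a login handler might do. Only the schema fields defined above come from the source; the `user` object attributes, the module path, and the token lifetime are assumptions made for this example.

```python
from datetime import datetime, timedelta, timezone

from app.schemas.auth import TokenClaims, TokenResponse, UserData  # assumed module path


def build_login_response(user, access_token: str, refresh_token: str) -> TokenResponse:
    """Assemble the unified token response returned by both /login and /register."""
    return TokenResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        expires_in=3600,
        user=UserData(
            id=str(user.id),
            email=user.email,
            full_name=user.full_name,
            is_active=user.is_active,
            is_verified=user.is_verified,
            created_at=user.created_at.isoformat(),
            tenant_id=str(user.tenant_id) if user.tenant_id else None,
            role=user.role,
        ),
    )


def build_claims(user, lifetime: timedelta = timedelta(hours=1)) -> dict:
    """Produce a JWT payload matching the TokenClaims structure."""
    now = datetime.now(timezone.utc)
    claims = TokenClaims(
        sub=str(user.id),
        email=user.email,
        full_name=user.full_name,
        user_id=str(user.id),
        is_verified=user.is_verified,
        tenant_id=str(user.tenant_id) if user.tenant_id else None,
        iat=int(now.timestamp()),
        exp=int((now + lifetime).timestamp()),
    )
    return claims.dict()  # use .model_dump() if the service is on Pydantic v2
```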

Some files were not shown because too many files have changed in this diff.