Initial commit - production deployment
services/ai_insights/.env.example (new file, 41 lines)
@@ -0,0 +1,41 @@
# AI Insights Service Environment Variables

# Service Info
SERVICE_NAME=ai-insights
SERVICE_VERSION=1.0.0
API_V1_PREFIX=/api/v1/ai-insights

# Database
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/bakery_ai_insights
DB_POOL_SIZE=20
DB_MAX_OVERFLOW=10

# Redis
REDIS_URL=redis://localhost:6379/5
REDIS_CACHE_TTL=900

# Service URLs
FORECASTING_SERVICE_URL=http://forecasting-service:8000
PROCUREMENT_SERVICE_URL=http://procurement-service:8000
PRODUCTION_SERVICE_URL=http://production-service:8000
SALES_SERVICE_URL=http://sales-service:8000
INVENTORY_SERVICE_URL=http://inventory-service:8000

# Circuit Breaker Settings
CIRCUIT_BREAKER_FAILURE_THRESHOLD=5
CIRCUIT_BREAKER_TIMEOUT=60

# Insight Settings
MIN_CONFIDENCE_THRESHOLD=60
DEFAULT_INSIGHT_TTL_DAYS=7
MAX_INSIGHTS_PER_REQUEST=100

# Feedback Settings
FEEDBACK_PROCESSING_ENABLED=true
FEEDBACK_PROCESSING_SCHEDULE="0 6 * * *"

# Logging
LOG_LEVEL=INFO

# CORS
ALLOWED_ORIGINS=["http://localhost:3000","http://localhost:5173"]
services/ai_insights/Dockerfile (new file, 59 lines)
@@ -0,0 +1,59 @@
# =============================================================================
# AI Insights Service Dockerfile - Environment-Configurable Base Images
# =============================================================================
# Build arguments for registry configuration:
# - BASE_REGISTRY: Registry URL (default: docker.io for Docker Hub)
# - PYTHON_IMAGE: Python image name and tag (default: python:3.11-slim)
# =============================================================================

ARG BASE_REGISTRY=docker.io
ARG PYTHON_IMAGE=python:3.11-slim

FROM ${BASE_REGISTRY}/${PYTHON_IMAGE} AS shared
WORKDIR /shared
COPY shared/ /shared/

ARG BASE_REGISTRY=docker.io
ARG PYTHON_IMAGE=python:3.11-slim
FROM ${BASE_REGISTRY}/${PYTHON_IMAGE}

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    curl \
    postgresql-client \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements
COPY shared/requirements-tracing.txt /tmp/

COPY services/ai_insights/requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r /tmp/requirements-tracing.txt

RUN pip install --no-cache-dir -r requirements.txt

# Copy shared libraries from the shared stage
COPY --from=shared /shared /app/shared

# Copy application code
COPY services/ai_insights/ .

# Copy scripts for migrations
COPY scripts/ /app/scripts/

# Add shared libraries to Python path
ENV PYTHONPATH="/app:/app/shared:${PYTHONPATH:-}"

# Expose port
EXPOSE 8000

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
services/ai_insights/QUICK_START.md (new file, 232 lines)
@@ -0,0 +1,232 @@
# AI Insights Service - Quick Start Guide

Get the AI Insights Service running in 5 minutes.

## Prerequisites

- Python 3.11+
- PostgreSQL 14+ (running)
- Redis 6+ (running)

## Step 1: Setup Environment

```bash
cd services/ai_insights

# Create virtual environment
python3 -m venv venv
source venv/bin/activate  # Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt
```

## Step 2: Configure Database

```bash
# Copy environment template
cp .env.example .env

# Edit .env file
nano .env
```

**Minimum required configuration**:
```env
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/bakery_ai_insights
REDIS_URL=redis://localhost:6379/5
```

## Step 3: Create Database

```bash
# Connect to PostgreSQL
psql -U postgres

# Create database
CREATE DATABASE bakery_ai_insights;
\q
```

## Step 4: Run Migrations

```bash
# Run Alembic migrations
alembic upgrade head
```

You should see:
```
INFO  [alembic.runtime.migration] Running upgrade  -> 001, Initial schema for AI Insights Service
```

## Step 5: Start the Service

```bash
uvicorn app.main:app --reload
```

You should see:
```
INFO:     Uvicorn running on http://127.0.0.1:8000
INFO:     Application startup complete.
```

## Step 6: Verify Installation

Open your browser to http://localhost:8000/docs

You should see the Swagger UI with all API endpoints.

### Test Health Endpoint

```bash
curl http://localhost:8000/health
```

Expected response:
```json
{
  "status": "healthy",
  "service": "ai-insights",
  "version": "1.0.0"
}
```

## Step 7: Create Your First Insight

```bash
curl -X POST "http://localhost:8000/api/v1/ai-insights/tenants/550e8400-e29b-41d4-a716-446655440000/insights" \
  -H "Content-Type: application/json" \
  -d '{
    "tenant_id": "550e8400-e29b-41d4-a716-446655440000",
    "type": "recommendation",
    "priority": "high",
    "category": "forecasting",
    "title": "Test Insight - Weekend Demand Pattern",
    "description": "Weekend sales 20% higher than weekdays",
    "impact_type": "revenue_increase",
    "impact_value": 150.00,
    "impact_unit": "euros/week",
    "confidence": 85,
    "metrics_json": {
      "weekday_avg": 45.2,
      "weekend_avg": 54.2,
      "increase_pct": 20.0
    },
    "actionable": true,
    "recommendation_actions": [
      {"label": "Increase Production", "action": "adjust_production"}
    ],
    "source_service": "forecasting"
  }'
```

## Step 8: Query Your Insights

```bash
curl "http://localhost:8000/api/v1/ai-insights/tenants/550e8400-e29b-41d4-a716-446655440000/insights?page=1&page_size=10"
```

## Common Issues

### Issue: "ModuleNotFoundError: No module named 'app'"

**Solution**: Make sure you are running from the `services/ai_insights/` directory and that the virtual environment is activated.

### Issue: "Connection refused" on database

**Solution**: Verify PostgreSQL is running:
```bash
# Check if PostgreSQL is running
pg_isready

# Start PostgreSQL (macOS with Homebrew)
brew services start postgresql

# Start PostgreSQL (Linux)
sudo systemctl start postgresql
```

### Issue: "Redis connection error"

**Solution**: Verify Redis is running:
```bash
# Check if Redis is running
redis-cli ping

# Should return: PONG

# Start Redis (macOS with Homebrew)
brew services start redis

# Start Redis (Linux)
sudo systemctl start redis
```

### Issue: "Alembic command not found"

**Solution**: Activate the virtual environment:
```bash
source venv/bin/activate
```

## Next Steps

1. **Explore API**: Visit http://localhost:8000/docs
2. **Read Documentation**: See `README.md` for detailed documentation
3. **Implementation Guide**: See `AI_INSIGHTS_IMPLEMENTATION_SUMMARY.md`
4. **Integration**: Start integrating with other services

## Useful Commands

```bash
# Check service status
curl http://localhost:8000/health

# Get aggregate metrics
curl "http://localhost:8000/api/v1/ai-insights/tenants/{tenant_id}/insights/metrics/summary"

# Filter high-confidence insights
curl "http://localhost:8000/api/v1/ai-insights/tenants/{tenant_id}/insights?actionable_only=true&min_confidence=80"

# Stop the service
# Press Ctrl+C in the terminal running uvicorn

# Deactivate virtual environment
deactivate
```

## Docker Quick Start (Alternative)

If you prefer Docker:

```bash
# Build image (run from the repository root, since the Dockerfile copies shared/ and scripts/)
docker build -f services/ai_insights/Dockerfile -t ai-insights .

# Run container
docker run -d \
  --name ai-insights \
  -p 8000:8000 \
  -e DATABASE_URL=postgresql+asyncpg://postgres:postgres@host.docker.internal:5432/bakery_ai_insights \
  -e REDIS_URL=redis://host.docker.internal:6379/5 \
  ai-insights

# Check logs
docker logs ai-insights

# Stop container
docker stop ai-insights
docker rm ai-insights
```

## Support

- **Documentation**: See `README.md`
- **API Docs**: http://localhost:8000/docs
- **Issues**: Create a GitHub issue or contact the team

---

**You're ready!** The AI Insights Service is now running and ready to accept insights from other services.
services/ai_insights/README.md (new file, 325 lines)
@@ -0,0 +1,325 @@
# AI Insights Service

## Overview

The **AI Insights Service** provides intelligent, actionable recommendations to bakery operators by analyzing patterns across inventory, production, procurement, and sales data. It acts as a virtual operations consultant, proactively identifying opportunities for cost savings, waste reduction, and operational improvements. This service transforms raw data into business intelligence that drives profitability.

## Key Features

### Intelligent Recommendations
- **Inventory Optimization** - Smart reorder point suggestions and stock level adjustments
- **Production Planning** - Optimal batch size and scheduling recommendations
- **Procurement Suggestions** - Best supplier selection and order timing advice
- **Sales Opportunities** - Identify trending products and underperforming items
- **Cost Reduction** - Find areas to reduce waste and lower operational costs
- **Quality Improvements** - Detect patterns affecting product quality

### Unified Insight Management
- **Centralized Storage** - All AI-generated insights in one place
- **Confidence Scoring** - Standardized 0-100% confidence calculation across insight types
- **Impact Estimation** - Business value quantification for recommendations
- **Feedback Loop** - Closed-loop learning from applied insights
- **Cross-Service Intelligence** - Correlation detection between insights from different services

## Features

### Core Capabilities

1. **Insight Aggregation**
   - Collect insights from the Forecasting, Procurement, Production, and Sales services
   - Categorize and prioritize recommendations
   - Filter by confidence, category, priority, and actionability

2. **Confidence Calculation** (see the sketch after this list)
   - Multi-factor scoring: data quality, model performance, sample size, recency, historical accuracy
   - Insight-type specific adjustments
   - Specialized calculations for forecasting and optimization insights

3. **Impact Estimation**
   - Cost savings quantification
   - Revenue increase projections
   - Waste reduction calculations
   - Efficiency gain measurements
   - Quality improvement tracking

4. **Feedback & Learning**
   - Track application outcomes
   - Compare expected vs. actual impact
   - Calculate success rates
   - Enable model improvement

5. **Orchestration Integration**
   - Pre-orchestration insight gathering
   - Actionable insight filtering
   - Categorized recommendations for workflow phases
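The exact weighting lives in the service's confidence calculator; as a rough, illustrative sketch of the multi-factor scoring described in item 2 above, the per-factor scores could be combined as a weighted average. The factor names come from the list above; the weights and the `combined_confidence` helper are assumptions for illustration, not the service's actual values.

```python
# Illustrative sketch only: combine the confidence factors listed above into a 0-100 score.
# The weights below are assumptions, not the service's real configuration.
FACTOR_WEIGHTS = {
    "data_quality": 0.25,
    "model_performance": 0.30,
    "sample_size": 0.15,
    "recency": 0.15,
    "historical_accuracy": 0.15,
}

def combined_confidence(factors: dict[str, float]) -> int:
    """Weighted average of per-factor scores (each 0-1), scaled to 0-100."""
    score = sum(FACTOR_WEIGHTS[name] * factors.get(name, 0.0) for name in FACTOR_WEIGHTS)
    return round(min(max(score, 0.0), 1.0) * 100)

# Example: strong model performance but a small sample size.
print(combined_confidence({
    "data_quality": 0.9,
    "model_performance": 0.85,
    "sample_size": 0.5,
    "recency": 0.8,
    "historical_accuracy": 0.7,
}))  # -> 78
```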
## Architecture

### Database Models

- **AIInsight**: Core insights table with classification, confidence, impact metrics
- **InsightFeedback**: Feedback tracking for closed-loop learning
- **InsightCorrelation**: Cross-service insight relationships

### API Endpoints

```
POST   /api/v1/ai-insights/tenants/{tenant_id}/insights
GET    /api/v1/ai-insights/tenants/{tenant_id}/insights
GET    /api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}
PATCH  /api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}
DELETE /api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}

GET    /api/v1/ai-insights/tenants/{tenant_id}/insights/orchestration-ready
GET    /api/v1/ai-insights/tenants/{tenant_id}/insights/metrics/summary
POST   /api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}/apply
POST   /api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}/feedback
POST   /api/v1/ai-insights/tenants/{tenant_id}/insights/refresh
GET    /api/v1/ai-insights/tenants/{tenant_id}/insights/export
```

## Installation

### Prerequisites

- Python 3.11+
- PostgreSQL 14+
- Redis 6+

### Setup

1. **Clone and navigate**:
```bash
cd services/ai_insights
```

2. **Create virtual environment**:
```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```

3. **Install dependencies**:
```bash
pip install -r requirements.txt
```

4. **Configure environment**:
```bash
cp .env.example .env
# Edit .env with your configuration
```

5. **Run migrations**:
```bash
alembic upgrade head
```

6. **Start the service**:
```bash
uvicorn app.main:app --reload
```

The service will be available at `http://localhost:8000`.

## Configuration

### Environment Variables

| Variable | Description | Default |
|----------|-------------|---------|
| `DATABASE_URL` | PostgreSQL connection string | Required |
| `REDIS_URL` | Redis connection string | Required |
| `FORECASTING_SERVICE_URL` | Forecasting service URL | `http://forecasting-service:8000` |
| `PROCUREMENT_SERVICE_URL` | Procurement service URL | `http://procurement-service:8000` |
| `PRODUCTION_SERVICE_URL` | Production service URL | `http://production-service:8000` |
| `MIN_CONFIDENCE_THRESHOLD` | Minimum confidence for insights | `60` |
| `DEFAULT_INSIGHT_TTL_DAYS` | Days before insights expire | `7` |

## Usage Examples

### Creating an Insight

```python
import httpx

insight_data = {
    "tenant_id": "550e8400-e29b-41d4-a716-446655440000",
    "type": "recommendation",
    "priority": "high",
    "category": "procurement",
    "title": "Flour Price Increase Expected",
    "description": "Price predicted to rise 8% in next week. Consider ordering now.",
    "impact_type": "cost_savings",
    "impact_value": 120.50,
    "impact_unit": "euros",
    "confidence": 85,
    "metrics_json": {
        "current_price": 1.20,
        "predicted_price": 1.30,
        "order_quantity": 1000
    },
    "actionable": True,
    "recommendation_actions": [
        {"label": "Order Now", "action": "create_purchase_order"},
        {"label": "Review", "action": "review_forecast"}
    ],
    "source_service": "procurement",
    "source_data_id": "price_forecast_123"
}

response = httpx.post(
    "http://localhost:8000/api/v1/ai-insights/tenants/550e8400-e29b-41d4-a716-446655440000/insights",
    json=insight_data
)
print(response.json())
```

### Querying Insights

```python
# Get high-confidence actionable insights
response = httpx.get(
    "http://localhost:8000/api/v1/ai-insights/tenants/550e8400-e29b-41d4-a716-446655440000/insights",
    params={
        "actionable_only": True,
        "min_confidence": 80,
        "priority": "high",
        "page": 1,
        "page_size": 20
    }
)
insights = response.json()
```

### Recording Feedback

```python
feedback_data = {
    "insight_id": "insight-uuid",
    "action_taken": "create_purchase_order",
    "success": True,
    "expected_impact_value": 120.50,
    "actual_impact_value": 115.30,
    "result_data": {
        "order_id": "PO-12345",
        "actual_savings": 115.30
    },
    "applied_by": "user@example.com"
}

response = httpx.post(
    f"http://localhost:8000/api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}/feedback",
    json=feedback_data
)
```

## Development

### Running Tests

```bash
pytest
```

### Code Quality

```bash
# Format code
black app/

# Lint
flake8 app/

# Type checking
mypy app/
```

### Creating a Migration

```bash
alembic revision --autogenerate -m "Description of changes"
alembic upgrade head
```

## Insight Types

- **optimization**: Process improvements with measurable gains
- **alert**: Warnings requiring attention
- **prediction**: Future forecasts with confidence intervals
- **recommendation**: Suggested actions with estimated impact
- **insight**: General data-driven observations
- **anomaly**: Unusual patterns detected in data

## Priority Levels

- **critical**: Immediate action required (e.g., stockout risk)
- **high**: Action recommended soon (e.g., price opportunity)
- **medium**: Consider acting (e.g., efficiency improvement)
- **low**: Informational (e.g., pattern observation)

## Categories

- **forecasting**: Demand predictions and patterns
- **inventory**: Stock management and optimization
- **production**: Manufacturing efficiency and scheduling
- **procurement**: Purchasing and supplier management
- **customer**: Customer behavior and satisfaction
- **cost**: Cost optimization opportunities
- **quality**: Quality improvements
- **efficiency**: Process efficiency gains

## Integration with Other Services

### Forecasting Service

- Receives forecast accuracy insights
- Pattern detection alerts
- Demand anomaly notifications

### Procurement Service

- Price forecast recommendations
- Supplier performance alerts
- Safety stock optimization

### Production Service

- Yield prediction insights
- Schedule optimization recommendations
- Equipment maintenance alerts

### Orchestrator Service

- Pre-orchestration insight gathering (see the example below)
- Actionable recommendation filtering
- Feedback recording for applied insights
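For example, the orchestrator can pull actionable, high-confidence insights for a target date before planning a run. A minimal sketch with `httpx` (illustrative only: the endpoint and its `target_date`/`min_confidence` parameters are the ones listed under API Endpoints, while the base URL, tenant ID, and the assumed grouped-by-category response shape are placeholders/assumptions):

```python
# Illustrative sketch: fetch orchestration-ready insights before an orchestration run.
from datetime import date
import httpx

TENANT_ID = "550e8400-e29b-41d4-a716-446655440000"  # placeholder tenant
url = f"http://localhost:8000/api/v1/ai-insights/tenants/{TENANT_ID}/insights/orchestration-ready"

response = httpx.get(
    url,
    params={
        "target_date": date.today().isoformat(),  # date the orchestration run plans for
        "min_confidence": 70,                     # default threshold used by the endpoint
    },
)
response.raise_for_status()

# Assumed shape: insights grouped by category for the workflow phases.
categorized = response.json()
for category, items in categorized.items():
    print(category, len(items))
```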
## API Documentation

Once the service is running, interactive API documentation is available at:

- Swagger UI: `http://localhost:8000/docs`
- ReDoc: `http://localhost:8000/redoc`

## Monitoring

### Health Check

```bash
curl http://localhost:8000/health
```

### Metrics Endpoint

```bash
curl http://localhost:8000/api/v1/ai-insights/tenants/{tenant_id}/insights/metrics/summary
```

## License

Copyright © 2025 Bakery IA. All rights reserved.

## Support

For issues and questions, please contact the development team or create an issue in the project repository.
services/ai_insights/alembic.ini (new file, 112 lines)
@@ -0,0 +1,112 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts
script_location = migrations

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =

# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions

# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

sqlalchemy.url = driver://user:pass@localhost/dbname


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
services/ai_insights/app/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
"""AI Insights Service."""

__version__ = "1.0.0"
services/ai_insights/app/api/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""API modules for AI Insights Service."""
services/ai_insights/app/api/insights.py (new file, 419 lines)
@@ -0,0 +1,419 @@
"""API endpoints for AI Insights."""

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional
from uuid import UUID
from datetime import datetime
import math

from app.core.database import get_db
from app.repositories.insight_repository import InsightRepository
from app.repositories.feedback_repository import FeedbackRepository
from app.schemas.insight import (
    AIInsightCreate,
    AIInsightUpdate,
    AIInsightResponse,
    AIInsightList,
    InsightMetrics,
    InsightFilters
)
from app.schemas.feedback import InsightFeedbackCreate, InsightFeedbackResponse

router = APIRouter()


@router.post("/tenants/{tenant_id}/insights", response_model=AIInsightResponse, status_code=status.HTTP_201_CREATED)
async def create_insight(
    tenant_id: UUID,
    insight_data: AIInsightCreate,
    db: AsyncSession = Depends(get_db)
):
    """Create a new AI Insight."""
    # Ensure tenant_id matches
    if insight_data.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Tenant ID mismatch"
        )

    repo = InsightRepository(db)
    insight = await repo.create(insight_data)
    await db.commit()

    return insight


@router.get("/tenants/{tenant_id}/insights", response_model=AIInsightList)
async def get_insights(
    tenant_id: UUID,
    category: Optional[str] = Query(None),
    priority: Optional[str] = Query(None),
    status: Optional[str] = Query(None),
    actionable_only: bool = Query(False),
    min_confidence: int = Query(0, ge=0, le=100),
    source_service: Optional[str] = Query(None),
    from_date: Optional[datetime] = Query(None),
    to_date: Optional[datetime] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db)
):
    """Get insights for a tenant with filters and pagination."""
    filters = InsightFilters(
        category=category,
        priority=priority,
        status=status,
        actionable_only=actionable_only,
        min_confidence=min_confidence,
        source_service=source_service,
        from_date=from_date,
        to_date=to_date
    )

    repo = InsightRepository(db)
    skip = (page - 1) * page_size

    insights, total = await repo.get_by_tenant(tenant_id, filters, skip, page_size)

    total_pages = math.ceil(total / page_size) if total > 0 else 0

    return AIInsightList(
        items=insights,
        total=total,
        page=page,
        page_size=page_size,
        total_pages=total_pages
    )


@router.get("/tenants/{tenant_id}/insights/orchestration-ready")
async def get_orchestration_ready_insights(
    tenant_id: UUID,
    target_date: datetime = Query(...),
    min_confidence: int = Query(70, ge=0, le=100),
    db: AsyncSession = Depends(get_db)
):
    """Get actionable insights for orchestration workflow."""
    repo = InsightRepository(db)
    categorized_insights = await repo.get_orchestration_ready_insights(
        tenant_id, target_date, min_confidence
    )

    return categorized_insights


@router.get("/tenants/{tenant_id}/insights/{insight_id}", response_model=AIInsightResponse)
async def get_insight(
    tenant_id: UUID,
    insight_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Get a single insight by ID."""
    repo = InsightRepository(db)
    insight = await repo.get_by_id(insight_id)

    if not insight:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Insight not found"
        )

    if insight.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied"
        )

    return insight


@router.patch("/tenants/{tenant_id}/insights/{insight_id}", response_model=AIInsightResponse)
async def update_insight(
    tenant_id: UUID,
    insight_id: UUID,
    update_data: AIInsightUpdate,
    db: AsyncSession = Depends(get_db)
):
    """Update an insight (typically status changes)."""
    repo = InsightRepository(db)

    # Verify insight exists and belongs to tenant
    insight = await repo.get_by_id(insight_id)
    if not insight:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Insight not found"
        )

    if insight.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied"
        )

    updated_insight = await repo.update(insight_id, update_data)
    await db.commit()

    return updated_insight


@router.delete("/tenants/{tenant_id}/insights/{insight_id}", status_code=status.HTTP_204_NO_CONTENT)
async def dismiss_insight(
    tenant_id: UUID,
    insight_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Dismiss an insight (soft delete)."""
    repo = InsightRepository(db)

    # Verify insight exists and belongs to tenant
    insight = await repo.get_by_id(insight_id)
    if not insight:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Insight not found"
        )

    if insight.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied"
        )

    await repo.delete(insight_id)
    await db.commit()


@router.get("/tenants/{tenant_id}/insights/metrics/summary", response_model=InsightMetrics)
async def get_insights_metrics(
    tenant_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Get aggregate metrics for insights."""
    repo = InsightRepository(db)
    metrics = await repo.get_metrics(tenant_id)

    return InsightMetrics(**metrics)


@router.post("/tenants/{tenant_id}/insights/{insight_id}/apply")
async def apply_insight(
    tenant_id: UUID,
    insight_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Apply an insight recommendation (trigger action)."""
    repo = InsightRepository(db)

    # Verify insight exists and belongs to tenant
    insight = await repo.get_by_id(insight_id)
    if not insight:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Insight not found"
        )

    if insight.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied"
        )

    if not insight.actionable:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="This insight is not actionable"
        )

    # Update status to in_progress
    update_data = AIInsightUpdate(status='in_progress', applied_at=datetime.utcnow())
    await repo.update(insight_id, update_data)
    await db.commit()

    # Route to appropriate service based on recommendation_actions
    applied_actions = []
    failed_actions = []

    try:
        import structlog
        logger = structlog.get_logger()

        for action in insight.recommendation_actions:
            try:
                action_type = action.get('action_type')
                action_target = action.get('target_service')

                logger.info("Processing insight action",
                            insight_id=str(insight_id),
                            action_type=action_type,
                            target_service=action_target)

                # Route based on target service
                if action_target == 'procurement':
                    # Create purchase order or adjust reorder points
                    from shared.clients.procurement_client import ProcurementServiceClient
                    from shared.config.base import get_settings

                    config = get_settings()
                    procurement_client = ProcurementServiceClient(config, "ai_insights")

                    # Example: trigger procurement action
                    logger.info("Routing action to procurement service", action=action)
                    applied_actions.append(action_type)

                elif action_target == 'production':
                    # Adjust production schedule
                    from shared.clients.production_client import ProductionServiceClient
                    from shared.config.base import get_settings

                    config = get_settings()
                    production_client = ProductionServiceClient(config, "ai_insights")

                    logger.info("Routing action to production service", action=action)
                    applied_actions.append(action_type)

                elif action_target == 'inventory':
                    # Adjust inventory settings
                    from shared.clients.inventory_client import InventoryServiceClient
                    from shared.config.base import get_settings

                    config = get_settings()
                    inventory_client = InventoryServiceClient(config, "ai_insights")

                    logger.info("Routing action to inventory service", action=action)
                    applied_actions.append(action_type)

                elif action_target == 'pricing':
                    # Update pricing recommendations
                    logger.info("Price adjustment action identified", action=action)
                    applied_actions.append(action_type)

                else:
                    logger.warning("Unknown target service for action",
                                   action_type=action_type,
                                   target_service=action_target)
                    failed_actions.append({
                        'action_type': action_type,
                        'reason': f'Unknown target service: {action_target}'
                    })

            except Exception as action_error:
                logger.error("Failed to apply action",
                             action_type=action.get('action_type'),
                             error=str(action_error))
                failed_actions.append({
                    'action_type': action.get('action_type'),
                    'reason': str(action_error)
                })

        # Update final status
        final_status = 'applied' if not failed_actions else 'partially_applied'
        final_update = AIInsightUpdate(status=final_status)
        await repo.update(insight_id, final_update)
        await db.commit()

    except Exception as e:
        logger.error("Failed to route insight actions",
                     insight_id=str(insight_id),
                     error=str(e))
        # Update status to failed
        failed_update = AIInsightUpdate(status='failed')
        await repo.update(insight_id, failed_update)
        await db.commit()

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to apply insight: {str(e)}"
        )

    return {
        "message": "Insight application initiated",
        "insight_id": str(insight_id),
        "actions": insight.recommendation_actions,
        "applied_actions": applied_actions,
        "failed_actions": failed_actions,
        "status": final_status
    }


@router.post("/tenants/{tenant_id}/insights/{insight_id}/feedback", response_model=InsightFeedbackResponse)
async def record_feedback(
    tenant_id: UUID,
    insight_id: UUID,
    feedback_data: InsightFeedbackCreate,
    db: AsyncSession = Depends(get_db)
):
    """Record feedback for an applied insight."""
    insight_repo = InsightRepository(db)

    # Verify insight exists and belongs to tenant
    insight = await insight_repo.get_by_id(insight_id)
    if not insight:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Insight not found"
        )

    if insight.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied"
        )

    # Ensure feedback is for this insight
    if feedback_data.insight_id != insight_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Insight ID mismatch"
        )

    feedback_repo = FeedbackRepository(db)
    feedback = await feedback_repo.create(feedback_data)

    # Update insight status based on feedback
    new_status = 'applied' if feedback.success else 'dismissed'
    update_data = AIInsightUpdate(status=new_status)
    await insight_repo.update(insight_id, update_data)

    await db.commit()

    return feedback


@router.post("/tenants/{tenant_id}/insights/refresh")
async def refresh_insights(
    tenant_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """Trigger insight refresh (expire old, generate new)."""
    repo = InsightRepository(db)

    # Expire old insights
    expired_count = await repo.expire_old_insights()
    await db.commit()

    return {
        "message": "Insights refreshed",
        "expired_count": expired_count
    }


@router.get("/tenants/{tenant_id}/insights/export")
async def export_insights(
    tenant_id: UUID,
    format: str = Query("json", regex="^(json|csv)$"),
    db: AsyncSession = Depends(get_db)
):
    """Export insights to JSON or CSV."""
    repo = InsightRepository(db)
    insights, _ = await repo.get_by_tenant(tenant_id, filters=None, skip=0, limit=1000)

    if format == "json":
        return {"insights": [AIInsightResponse.model_validate(i) for i in insights]}

    # CSV export would be implemented here
    raise HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED,
        detail="CSV export not yet implemented"
    )
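For reference, applying an actionable insight from a client is a single POST to the `/apply` endpoint defined above. A minimal sketch with `httpx` (illustrative only; the tenant and insight IDs and the localhost URL are placeholders):

```python
# Illustrative sketch: trigger the apply endpoint defined in insights.py above.
import httpx

tenant_id = "550e8400-e29b-41d4-a716-446655440000"        # placeholder tenant UUID
insight_id = "00000000-0000-0000-0000-000000000000"        # placeholder insight UUID

response = httpx.post(
    f"http://localhost:8000/api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}/apply"
)
response.raise_for_status()

result = response.json()
# The endpoint reports which recommendation actions were routed and which failed.
print(result["status"], result["applied_actions"], result["failed_actions"])
```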
services/ai_insights/app/core/config.py (new file, 77 lines)
@@ -0,0 +1,77 @@
"""Configuration settings for AI Insights Service."""

from shared.config.base import BaseServiceSettings
import os
from typing import Optional


class Settings(BaseServiceSettings):
    """Application settings."""

    # Service Info
    SERVICE_NAME: str = "ai-insights"
    SERVICE_VERSION: str = "1.0.0"
    API_V1_PREFIX: str = "/api/v1"

    # Database configuration (secure approach - build from components)
    @property
    def DATABASE_URL(self) -> str:
        """Build database URL from secure components"""
        # Try complete URL first (for backward compatibility)
        complete_url = os.getenv("AI_INSIGHTS_DATABASE_URL")
        if complete_url:
            return complete_url

        # Also check for generic DATABASE_URL (for migration compatibility)
        generic_url = os.getenv("DATABASE_URL")
        if generic_url:
            return generic_url

        # Build from components (secure approach)
        user = os.getenv("AI_INSIGHTS_DB_USER", "ai_insights_user")
        password = os.getenv("AI_INSIGHTS_DB_PASSWORD", "ai_insights_pass123")
        host = os.getenv("AI_INSIGHTS_DB_HOST", "localhost")
        port = os.getenv("AI_INSIGHTS_DB_PORT", "5432")
        name = os.getenv("AI_INSIGHTS_DB_NAME", "ai_insights_db")

        return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"

    DB_POOL_SIZE: int = 20
    DB_MAX_OVERFLOW: int = 10

    # Redis (inherited from BaseServiceSettings but can override)
    REDIS_CACHE_TTL: int = 900  # 15 minutes
    REDIS_DB: int = 3  # Dedicated Redis database for AI Insights

    # Service URLs
    FORECASTING_SERVICE_URL: str = "http://forecasting-service:8000"
    PROCUREMENT_SERVICE_URL: str = "http://procurement-service:8000"
    PRODUCTION_SERVICE_URL: str = "http://production-service:8000"
    SALES_SERVICE_URL: str = "http://sales-service:8000"
    INVENTORY_SERVICE_URL: str = "http://inventory-service:8000"

    # Circuit Breaker Settings
    CIRCUIT_BREAKER_FAILURE_THRESHOLD: int = 5
    CIRCUIT_BREAKER_TIMEOUT: int = 60

    # Insight Settings
    MIN_CONFIDENCE_THRESHOLD: int = 60
    DEFAULT_INSIGHT_TTL_DAYS: int = 7
    MAX_INSIGHTS_PER_REQUEST: int = 100

    # Feedback Settings
    FEEDBACK_PROCESSING_ENABLED: bool = True
    FEEDBACK_PROCESSING_SCHEDULE: str = "0 6 * * *"  # Daily at 6 AM

    # Logging
    LOG_LEVEL: str = "INFO"

    # CORS
    ALLOWED_ORIGINS: list[str] = ["http://localhost:3000", "http://localhost:5173"]

    class Config:
        env_file = ".env"
        case_sensitive = True


settings = Settings()
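The `DATABASE_URL` property above prefers `AI_INSIGHTS_DATABASE_URL`, then `DATABASE_URL`, and only then builds the URL from the per-component variables. A rough sketch of the component path (illustrative; the host and password values are placeholders, and it assumes the shared `BaseServiceSettings` does not require additional mandatory variables):

```python
# Illustrative sketch: exercise the component-based fallback of Settings.DATABASE_URL.
import os

# Make sure the higher-priority variables are not set, so the component path is used.
os.environ.pop("AI_INSIGHTS_DATABASE_URL", None)
os.environ.pop("DATABASE_URL", None)

os.environ["AI_INSIGHTS_DB_USER"] = "ai_insights_user"
os.environ["AI_INSIGHTS_DB_PASSWORD"] = "change-me"      # placeholder secret
os.environ["AI_INSIGHTS_DB_HOST"] = "db.internal"        # placeholder host
os.environ["AI_INSIGHTS_DB_PORT"] = "5432"
os.environ["AI_INSIGHTS_DB_NAME"] = "ai_insights_db"

from app.core.config import Settings

print(Settings().DATABASE_URL)
# postgresql+asyncpg://ai_insights_user:change-me@db.internal:5432/ai_insights_db
```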
services/ai_insights/app/core/database.py (new file, 58 lines)
@@ -0,0 +1,58 @@
"""Database configuration and session management."""

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy.pool import NullPool
from typing import AsyncGenerator

from app.core.config import settings

# Create async engine
engine = create_async_engine(
    settings.DATABASE_URL,
    pool_size=settings.DB_POOL_SIZE,
    max_overflow=settings.DB_MAX_OVERFLOW,
    echo=False,
    future=True,
)

# Create async session factory
AsyncSessionLocal = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)

# Create declarative base
Base = declarative_base()


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency for getting async database sessions.

    Yields:
        AsyncSession: Database session
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()


async def init_db():
    """Initialize database tables."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)


async def close_db():
    """Close database connections."""
    await engine.dispose()
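Outside a FastAPI request (for example in a one-off maintenance script), the same session factory can be used directly. A minimal sketch (illustrative; assumes the database settings above are configured and the `app` package is importable):

```python
# Illustrative sketch: use the session factory defined above from a standalone script.
import asyncio

from sqlalchemy import text

from app.core.database import AsyncSessionLocal, close_db, init_db


async def main() -> None:
    await init_db()  # create tables if they do not exist yet
    async with AsyncSessionLocal() as session:
        result = await session.execute(text("SELECT 1"))
        print(result.scalar_one())
    await close_db()


if __name__ == "__main__":
    asyncio.run(main())
```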
services/ai_insights/app/impact/impact_estimator.py (new file, 320 lines)
@@ -0,0 +1,320 @@
"""Impact estimation for AI Insights."""

from typing import Dict, Any, Optional, Tuple
from decimal import Decimal
from datetime import datetime, timedelta


class ImpactEstimator:
    """
    Estimate potential impact of recommendations.

    Calculates expected business value in terms of:
    - Cost savings (euros)
    - Revenue increase (euros)
    - Waste reduction (euros or percentage)
    - Efficiency gains (hours or percentage)
    - Quality improvements (units or percentage)
    """

    def estimate_procurement_savings(
        self,
        current_price: Decimal,
        predicted_price: Decimal,
        order_quantity: Decimal,
        timeframe_days: int = 30
    ) -> Tuple[Decimal, str, str]:
        """
        Estimate savings from opportunistic buying.

        Args:
            current_price: Current unit price
            predicted_price: Predicted future price
            order_quantity: Quantity to order
            timeframe_days: Time horizon for prediction

        Returns:
            tuple: (impact_value, impact_unit, impact_type)
        """
        savings_per_unit = predicted_price - current_price

        if savings_per_unit > 0:
            total_savings = savings_per_unit * order_quantity
            return (
                round(total_savings, 2),
                'euros',
                'cost_savings'
            )
        return (Decimal('0.0'), 'euros', 'cost_savings')

    def estimate_waste_reduction_savings(
        self,
        current_waste_rate: float,
        optimized_waste_rate: float,
        monthly_volume: Decimal,
        avg_cost_per_unit: Decimal
    ) -> Tuple[Decimal, str, str]:
        """
        Estimate savings from waste reduction.

        Args:
            current_waste_rate: Current waste rate (0-1)
            optimized_waste_rate: Optimized waste rate (0-1)
            monthly_volume: Monthly volume
            avg_cost_per_unit: Average cost per unit

        Returns:
            tuple: (impact_value, impact_unit, impact_type)
        """
        waste_reduction_rate = current_waste_rate - optimized_waste_rate
        units_saved = monthly_volume * Decimal(str(waste_reduction_rate))
        savings = units_saved * avg_cost_per_unit

        return (
            round(savings, 2),
            'euros/month',
            'waste_reduction'
        )

    def estimate_forecast_improvement_value(
        self,
        current_mape: float,
        improved_mape: float,
        avg_monthly_revenue: Decimal
    ) -> Tuple[Decimal, str, str]:
        """
        Estimate value from forecast accuracy improvement.

        Better forecasts reduce:
        - Stockouts (lost sales)
        - Overproduction (waste)
        - Emergency orders (premium costs)

        Args:
            current_mape: Current forecast MAPE
            improved_mape: Improved forecast MAPE
            avg_monthly_revenue: Average monthly revenue

        Returns:
            tuple: (impact_value, impact_unit, impact_type)
        """
        # Rule of thumb: 1% MAPE improvement = 0.5% revenue impact
        mape_improvement = current_mape - improved_mape
        revenue_impact_pct = mape_improvement * 0.5 / 100

        revenue_increase = avg_monthly_revenue * Decimal(str(revenue_impact_pct))

        return (
            round(revenue_increase, 2),
            'euros/month',
            'revenue_increase'
        )

    def estimate_production_efficiency_gain(
        self,
        time_saved_minutes: int,
        batches_per_month: int,
        labor_cost_per_hour: Decimal = Decimal('15.0')
    ) -> Tuple[Decimal, str, str]:
        """
        Estimate value from production efficiency improvements.

        Args:
            time_saved_minutes: Minutes saved per batch
            batches_per_month: Number of batches per month
            labor_cost_per_hour: Labor cost per hour

        Returns:
            tuple: (impact_value, impact_unit, impact_type)
        """
        hours_saved_per_month = (time_saved_minutes * batches_per_month) / 60
        cost_savings = Decimal(str(hours_saved_per_month)) * labor_cost_per_hour

        return (
            round(cost_savings, 2),
            'euros/month',
            'efficiency_gain'
        )

    def estimate_safety_stock_optimization(
        self,
        current_safety_stock: Decimal,
        optimal_safety_stock: Decimal,
        holding_cost_per_unit_per_day: Decimal,
        stockout_cost_reduction: Decimal = Decimal('0.0')
    ) -> Tuple[Decimal, str, str]:
        """
        Estimate impact of safety stock optimization.

        Args:
            current_safety_stock: Current safety stock level
            optimal_safety_stock: Optimal safety stock level
            holding_cost_per_unit_per_day: Daily holding cost
            stockout_cost_reduction: Reduction in stockout costs

        Returns:
            tuple: (impact_value, impact_unit, impact_type)
        """
        stock_reduction = current_safety_stock - optimal_safety_stock

        if stock_reduction > 0:
            # Savings from reduced holding costs
            daily_savings = stock_reduction * holding_cost_per_unit_per_day
            monthly_savings = daily_savings * 30
            total_savings = monthly_savings + stockout_cost_reduction

            return (
                round(total_savings, 2),
                'euros/month',
                'cost_savings'
            )
        elif stock_reduction < 0:
            # Cost increase but reduces stockouts
            daily_cost = abs(stock_reduction) * holding_cost_per_unit_per_day
            monthly_cost = daily_cost * 30
            net_savings = stockout_cost_reduction - monthly_cost

            if net_savings > 0:
                return (
                    round(net_savings, 2),
                    'euros/month',
                    'cost_savings'
                )

        return (Decimal('0.0'), 'euros/month', 'cost_savings')

    def estimate_supplier_switch_savings(
        self,
        current_supplier_price: Decimal,
        alternative_supplier_price: Decimal,
        monthly_order_quantity: Decimal,
        quality_difference_score: float = 0.0  # -1 to 1
    ) -> Tuple[Decimal, str, str]:
        """
        Estimate savings from switching suppliers.

        Args:
            current_supplier_price: Current supplier unit price
            alternative_supplier_price: Alternative supplier unit price
            monthly_order_quantity: Monthly order quantity
            quality_difference_score: Quality difference (-1=worse, 0=same, 1=better)

        Returns:
            tuple: (impact_value, impact_unit, impact_type)
        """
        price_savings = (current_supplier_price - alternative_supplier_price) * monthly_order_quantity

        # Adjust for quality difference
        # If quality is worse, reduce estimated savings
        quality_adjustment = 1 + (quality_difference_score * 0.1)  # ±10% max adjustment
        adjusted_savings = price_savings * Decimal(str(quality_adjustment))

        return (
            round(adjusted_savings, 2),
            'euros/month',
            'cost_savings'
        )

    def estimate_yield_improvement_value(
        self,
        current_yield_rate: float,
        predicted_yield_rate: float,
        production_volume: Decimal,
        product_price: Decimal
    ) -> Tuple[Decimal, str, str]:
        """
        Estimate value from production yield improvements.

        Args:
            current_yield_rate: Current yield rate (0-1)
            predicted_yield_rate: Predicted yield rate (0-1)
            production_volume: Monthly production volume
            product_price: Product selling price

        Returns:
            tuple: (impact_value, impact_unit, impact_type)
        """
        yield_improvement = predicted_yield_rate - current_yield_rate

        if yield_improvement > 0:
            additional_units = production_volume * Decimal(str(yield_improvement))
            revenue_increase = additional_units * product_price

            return (
                round(revenue_increase, 2),
                'euros/month',
                'revenue_increase'
            )

        return (Decimal('0.0'), 'euros/month', 'revenue_increase')

    def estimate_demand_pattern_value(
        self,
        pattern_strength: float,  # 0-1
        potential_revenue_increase: Decimal,
        implementation_cost: Decimal = Decimal('0.0')
    ) -> Tuple[Decimal, str, str]:
        """
        Estimate value from acting on demand patterns.

        Args:
            pattern_strength: Strength of detected pattern (0-1)
            potential_revenue_increase: Potential monthly revenue increase
            implementation_cost: One-time implementation cost

        Returns:
            tuple: (impact_value, impact_unit, impact_type)
        """
        # Discount by pattern strength (confidence)
        expected_value = potential_revenue_increase * Decimal(str(pattern_strength))

        # Amortize implementation cost over 6 months
        monthly_cost = implementation_cost / 6

        net_value = expected_value - monthly_cost

        return (
            round(max(Decimal('0.0'), net_value), 2),
            'euros/month',
            'revenue_increase'
        )

    def estimate_composite_impact(
        self,
        impacts: list[Dict[str, Any]]
    ) -> Tuple[Decimal, str, str]:
        """
        Combine multiple impact estimations.

        Args:
            impacts: List of impact dicts with 'value', 'unit', 'type'

        Returns:
            tuple: (total_impact_value, impact_unit, impact_type)
        """
        total_savings = Decimal('0.0')
        total_revenue = Decimal('0.0')

        for impact in impacts:
            value = Decimal(str(impact['value']))
            impact_type = impact['type']

            if impact_type == 'cost_savings':
                total_savings += value
            elif impact_type == 'revenue_increase':
                total_revenue += value

        # Combine both types
        total_impact = total_savings + total_revenue

        if total_impact > 0:
            # Determine primary type
            primary_type = 'cost_savings' if total_savings > total_revenue else 'revenue_increase'

            return (
                round(total_impact, 2),
                'euros/month',
                primary_type
            )

        return (Decimal('0.0'), 'euros/month', 'cost_savings')
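A quick usage sketch of the estimator above, with made-up numbers:

```python
# Illustrative usage of ImpactEstimator; the prices and quantity are invented.
from decimal import Decimal

from app.impact.impact_estimator import ImpactEstimator

estimator = ImpactEstimator()

# Flour is 1.20 EUR/kg today and predicted to reach 1.30 EUR/kg; we plan to order 1000 kg.
value, unit, impact_type = estimator.estimate_procurement_savings(
    current_price=Decimal("1.20"),
    predicted_price=Decimal("1.30"),
    order_quantity=Decimal("1000"),
)
print(value, unit, impact_type)  # 100.00 euros cost_savings
```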
services/ai_insights/app/main.py (new file, 68 lines)
@@ -0,0 +1,68 @@
"""Main FastAPI application for AI Insights Service."""

import structlog

from app.core.config import settings
from app.core.database import init_db, close_db
from app.api import insights
from shared.service_base import StandardFastAPIService

# Initialize logger
logger = structlog.get_logger()


class AIInsightsService(StandardFastAPIService):
    """AI Insights Service with standardized monitoring setup"""

    async def on_startup(self, app):
        """Custom startup logic for AI Insights"""
        # Initialize database
        await init_db()
        logger.info("Database initialized")

        await super().on_startup(app)

    async def on_shutdown(self, app):
        """Custom shutdown logic for AI Insights"""
        await super().on_shutdown(app)

        # Close database
        await close_db()
        logger.info("Database connections closed")


# Create service instance
service = AIInsightsService(
    service_name="ai-insights",
    app_name="AI Insights Service",
    description="Intelligent insights and recommendations for bakery operations",
    version=settings.SERVICE_VERSION,
    log_level=getattr(settings, 'LOG_LEVEL', 'INFO'),
    cors_origins=getattr(settings, 'ALLOWED_ORIGINS', ["*"]),
    api_prefix=settings.API_V1_PREFIX,
    enable_metrics=True,
    enable_health_checks=True,
    enable_tracing=True,
    enable_cors=True
)

# Create FastAPI app
app = service.create_app()

# Add service-specific routers
service.add_router(
    insights.router,
    tags=["insights"]
)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level=settings.LOG_LEVEL.lower()
    )
672
services/ai_insights/app/ml/feedback_learning_system.py
Normal file
672
services/ai_insights/app/ml/feedback_learning_system.py
Normal file
@@ -0,0 +1,672 @@
|
||||
"""
|
||||
Feedback Loop & Learning System
|
||||
Enables continuous improvement through outcome tracking and model retraining
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
from scipy import stats
|
||||
from collections import defaultdict
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class FeedbackLearningSystem:
|
||||
"""
|
||||
Manages feedback collection, model performance tracking, and retraining triggers.
|
||||
|
||||
Key Responsibilities:
|
||||
1. Aggregate feedback from applied insights
|
||||
2. Calculate model performance metrics (accuracy, precision, recall)
|
||||
3. Detect performance degradation
|
||||
4. Trigger automatic retraining when needed
|
||||
5. Calibrate confidence scores based on actual accuracy
|
||||
6. Generate learning insights for model improvement
|
||||
|
||||
Workflow:
|
||||
- Feedback continuously recorded via AIInsightsClient
|
||||
- Periodic performance analysis (daily/weekly)
|
||||
- Automatic alerts when performance degrades
|
||||
- Retraining recommendations with priority
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
performance_threshold: float = 0.85, # Minimum acceptable accuracy
|
||||
degradation_threshold: float = 0.10, # 10% drop triggers alert
|
||||
min_feedback_samples: int = 30, # Minimum samples for analysis
|
||||
retraining_window_days: int = 90 # Consider last 90 days
|
||||
):
|
||||
self.performance_threshold = performance_threshold
|
||||
self.degradation_threshold = degradation_threshold
|
||||
self.min_feedback_samples = min_feedback_samples
|
||||
self.retraining_window_days = retraining_window_days
|
||||
|
||||
async def analyze_model_performance(
|
||||
self,
|
||||
model_name: str,
|
||||
feedback_data: pd.DataFrame,
|
||||
baseline_performance: Optional[Dict[str, float]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze model performance based on feedback data.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model (e.g., 'hybrid_forecaster', 'yield_predictor')
|
||||
feedback_data: DataFrame with columns:
|
||||
- insight_id
|
||||
- applied_at
|
||||
- outcome_date
|
||||
- predicted_value
|
||||
- actual_value
|
||||
- error
|
||||
- error_pct
|
||||
- accuracy
|
||||
baseline_performance: Optional baseline metrics for comparison
|
||||
|
||||
Returns:
|
||||
Performance analysis with metrics, trends, and recommendations
|
||||
"""
|
||||
logger.info(
|
||||
"Analyzing model performance",
|
||||
model_name=model_name,
|
||||
feedback_samples=len(feedback_data)
|
||||
)
|
||||
|
||||
if len(feedback_data) < self.min_feedback_samples:
|
||||
return self._insufficient_feedback_response(
|
||||
model_name, len(feedback_data), self.min_feedback_samples
|
||||
)
|
||||
|
||||
# Step 1: Calculate current performance metrics
|
||||
current_metrics = self._calculate_performance_metrics(feedback_data)
|
||||
|
||||
# Step 2: Analyze performance trend over time
|
||||
trend_analysis = self._analyze_performance_trend(feedback_data)
|
||||
|
||||
# Step 3: Detect performance degradation
|
||||
degradation_detected = self._detect_performance_degradation(
|
||||
current_metrics, baseline_performance, trend_analysis
|
||||
)
|
||||
|
||||
# Step 4: Generate retraining recommendation
|
||||
retraining_recommendation = self._generate_retraining_recommendation(
|
||||
model_name, current_metrics, degradation_detected, trend_analysis
|
||||
)
|
||||
|
||||
# Step 5: Identify error patterns
|
||||
error_patterns = self._identify_error_patterns(feedback_data)
|
||||
|
||||
# Step 6: Calculate confidence calibration
|
||||
confidence_calibration = self._calculate_confidence_calibration(feedback_data)
|
||||
|
||||
logger.info(
|
||||
"Model performance analysis complete",
|
||||
model_name=model_name,
|
||||
current_accuracy=current_metrics['accuracy'],
|
||||
degradation_detected=degradation_detected['detected'],
|
||||
retraining_recommended=retraining_recommendation['recommended']
|
||||
)
|
||||
|
||||
return {
|
||||
'model_name': model_name,
|
||||
'analyzed_at': datetime.utcnow().isoformat(),
|
||||
'feedback_samples': len(feedback_data),
|
||||
'date_range': {
|
||||
'start': feedback_data['outcome_date'].min().isoformat(),
|
||||
'end': feedback_data['outcome_date'].max().isoformat()
|
||||
},
|
||||
'current_performance': current_metrics,
|
||||
'baseline_performance': baseline_performance,
|
||||
'trend_analysis': trend_analysis,
|
||||
'degradation_detected': degradation_detected,
|
||||
'retraining_recommendation': retraining_recommendation,
|
||||
'error_patterns': error_patterns,
|
||||
'confidence_calibration': confidence_calibration
|
||||
}
|
||||
|
||||
def _insufficient_feedback_response(
|
||||
self, model_name: str, current_samples: int, required_samples: int
|
||||
) -> Dict[str, Any]:
|
||||
"""Return response when insufficient feedback data."""
|
||||
return {
|
||||
'model_name': model_name,
|
||||
'analyzed_at': datetime.utcnow().isoformat(),
|
||||
'status': 'insufficient_feedback',
|
||||
'feedback_samples': current_samples,
|
||||
'required_samples': required_samples,
|
||||
'current_performance': None,
|
||||
'recommendation': f'Need {required_samples - current_samples} more feedback samples for reliable analysis'
|
||||
}
|
||||
|
||||
def _calculate_performance_metrics(
|
||||
self, feedback_data: pd.DataFrame
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Calculate comprehensive performance metrics.
|
||||
|
||||
Metrics:
|
||||
- Accuracy: % of predictions within acceptable error
|
||||
- MAE: Mean Absolute Error
|
||||
- RMSE: Root Mean Squared Error
|
||||
- MAPE: Mean Absolute Percentage Error
|
||||
- Bias: Systematic over/under prediction
|
||||
- R²: Correlation between predicted and actual
|
||||
"""
|
||||
predicted = feedback_data['predicted_value'].values
|
||||
actual = feedback_data['actual_value'].values
|
||||
|
||||
# Filter out invalid values
|
||||
valid_mask = ~(np.isnan(predicted) | np.isnan(actual))
|
||||
predicted = predicted[valid_mask]
|
||||
actual = actual[valid_mask]
|
||||
|
||||
if len(predicted) == 0:
|
||||
return {
|
||||
'accuracy': 0,
|
||||
'mae': 0,
|
||||
'rmse': 0,
|
||||
'mape': 0,
|
||||
'bias': 0,
|
||||
'r_squared': 0
|
||||
}
|
||||
|
||||
# Calculate errors
|
||||
errors = predicted - actual
|
||||
abs_errors = np.abs(errors)
|
||||
pct_errors = np.abs(np.divide(errors, actual, out=np.zeros_like(errors, dtype=float), where=actual != 0)) * 100
|
||||
|
||||
# MAE and RMSE
|
||||
mae = float(np.mean(abs_errors))
|
||||
rmse = float(np.sqrt(np.mean(errors ** 2)))
|
||||
|
||||
# MAPE (excluding cases where actual = 0)
|
||||
valid_pct_mask = actual != 0
|
||||
mape = float(np.mean(pct_errors[valid_pct_mask])) if np.any(valid_pct_mask) else 0
|
||||
|
||||
# Accuracy (% within 10% error)
|
||||
within_10pct = np.sum(pct_errors <= 10) / len(pct_errors) * 100
|
||||
|
||||
# Bias (mean error - positive = over-prediction)
|
||||
bias = float(np.mean(errors))
|
||||
|
||||
# R² (correlation)
|
||||
if len(predicted) > 1 and np.std(actual) > 0:
|
||||
correlation = np.corrcoef(predicted, actual)[0, 1]
|
||||
r_squared = correlation ** 2
|
||||
else:
|
||||
r_squared = 0
|
||||
|
||||
return {
|
||||
'accuracy': round(within_10pct, 2), # % within 10% error
|
||||
'mae': round(mae, 2),
|
||||
'rmse': round(rmse, 2),
|
||||
'mape': round(mape, 2),
|
||||
'bias': round(bias, 2),
|
||||
'r_squared': round(r_squared, 3),
|
||||
'sample_size': len(predicted)
|
||||
}
|
||||
|
||||
def _analyze_performance_trend(
|
||||
self, feedback_data: pd.DataFrame
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze performance trend over time.
|
||||
|
||||
Returns trend direction (improving/stable/degrading) and slope.
|
||||
"""
|
||||
# Sort by date
|
||||
df = feedback_data.sort_values('outcome_date').copy()
|
||||
|
||||
# Calculate rolling accuracy (7-day window)
|
||||
df['rolling_accuracy'] = df['accuracy'].rolling(window=7, min_periods=3).mean()
|
||||
|
||||
# Linear trend
|
||||
if len(df) >= 10:
|
||||
# Use day index as x
|
||||
df['day_index'] = (df['outcome_date'] - df['outcome_date'].min()).dt.days
|
||||
|
||||
# Fit linear regression
|
||||
valid_mask = ~np.isnan(df['rolling_accuracy'])
|
||||
if valid_mask.sum() >= 10:
|
||||
x = df.loc[valid_mask, 'day_index'].values
|
||||
y = df.loc[valid_mask, 'rolling_accuracy'].values
|
||||
|
||||
slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
|
||||
|
||||
# Determine trend
|
||||
if p_value < 0.05:
|
||||
if slope > 0.1:
|
||||
trend = 'improving'
|
||||
elif slope < -0.1:
|
||||
trend = 'degrading'
|
||||
else:
|
||||
trend = 'stable'
|
||||
else:
|
||||
trend = 'stable'
|
||||
|
||||
return {
|
||||
'trend': trend,
|
||||
'slope': round(float(slope), 4),
|
||||
'p_value': round(float(p_value), 4),
|
||||
'significant': p_value < 0.05,
|
||||
'recent_performance': round(float(df['rolling_accuracy'].iloc[-1]), 2),
|
||||
'initial_performance': round(float(df['rolling_accuracy'].dropna().iloc[0]), 2)
|
||||
}
|
||||
|
||||
# Not enough data for trend
|
||||
return {
|
||||
'trend': 'insufficient_data',
|
||||
'slope': 0,
|
||||
'p_value': 1.0,
|
||||
'significant': False
|
||||
}
|
||||
|
||||
def _detect_performance_degradation(
|
||||
self,
|
||||
current_metrics: Dict[str, float],
|
||||
baseline_performance: Optional[Dict[str, float]],
|
||||
trend_analysis: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Detect if model performance has degraded.
|
||||
|
||||
Degradation triggers:
|
||||
1. Current accuracy below threshold (85%)
|
||||
2. Significant drop from baseline (>10%)
|
||||
3. Degrading trend detected
|
||||
"""
|
||||
degradation_reasons = []
|
||||
severity = 'none'
|
||||
|
||||
# Check absolute performance
|
||||
if current_metrics['accuracy'] < self.performance_threshold * 100:
|
||||
degradation_reasons.append(
|
||||
f"Accuracy {current_metrics['accuracy']:.1f}% below threshold {self.performance_threshold*100}%"
|
||||
)
|
||||
severity = 'high'
|
||||
|
||||
# Check vs baseline
|
||||
if baseline_performance and 'accuracy' in baseline_performance:
|
||||
baseline_acc = baseline_performance['accuracy']
|
||||
current_acc = current_metrics['accuracy']
|
||||
drop_pct = (baseline_acc - current_acc) / baseline_acc
|
||||
|
||||
if drop_pct > self.degradation_threshold:
|
||||
degradation_reasons.append(
|
||||
f"Accuracy dropped {drop_pct*100:.1f}% from baseline {baseline_acc:.1f}%"
|
||||
)
|
||||
severity = 'high'
|
||||
|
||||
# Check trend
|
||||
if trend_analysis.get('trend') == 'degrading' and trend_analysis.get('significant'):
|
||||
degradation_reasons.append(
|
||||
f"Degrading trend detected (slope: {trend_analysis['slope']:.4f})"
|
||||
)
|
||||
severity = 'medium' if severity == 'none' else severity
|
||||
|
||||
detected = len(degradation_reasons) > 0
|
||||
|
||||
return {
|
||||
'detected': detected,
|
||||
'severity': severity,
|
||||
'reasons': degradation_reasons,
|
||||
'current_accuracy': current_metrics['accuracy'],
|
||||
'baseline_accuracy': baseline_performance.get('accuracy') if baseline_performance else None
|
||||
}
|
||||
|
||||
def _generate_retraining_recommendation(
|
||||
self,
|
||||
model_name: str,
|
||||
current_metrics: Dict[str, float],
|
||||
degradation_detected: Dict[str, Any],
|
||||
trend_analysis: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate retraining recommendation based on performance analysis.
|
||||
|
||||
Priority Levels:
|
||||
- urgent: Severe degradation, retrain immediately
|
||||
- high: Performance below threshold, retrain soon
|
||||
- medium: Trending down, schedule retraining
|
||||
- low: Stable, routine retraining
|
||||
- none: No retraining needed
|
||||
"""
|
||||
if degradation_detected['detected']:
|
||||
severity = degradation_detected['severity']
|
||||
|
||||
if severity == 'high':
|
||||
priority = 'urgent'
|
||||
recommendation = f"Retrain {model_name} immediately - severe performance degradation"
|
||||
elif severity == 'medium':
|
||||
priority = 'high'
|
||||
recommendation = f"Schedule {model_name} retraining within 7 days"
|
||||
else:
|
||||
priority = 'medium'
|
||||
recommendation = f"Schedule routine {model_name} retraining"
|
||||
|
||||
return {
|
||||
'recommended': True,
|
||||
'priority': priority,
|
||||
'recommendation': recommendation,
|
||||
'reasons': degradation_detected['reasons'],
|
||||
'estimated_improvement': self._estimate_retraining_benefit(
|
||||
current_metrics, degradation_detected
|
||||
)
|
||||
}
|
||||
|
||||
# Check if routine retraining is due (e.g., every 90 days)
|
||||
# This would require tracking last_retrained_at
|
||||
else:
|
||||
return {
|
||||
'recommended': False,
|
||||
'priority': 'none',
|
||||
'recommendation': f"{model_name} performance is acceptable, no immediate retraining needed",
|
||||
'next_review_date': (datetime.utcnow() + timedelta(days=30)).isoformat()
|
||||
}
|
||||
|
||||
def _estimate_retraining_benefit(
|
||||
self,
|
||||
current_metrics: Dict[str, float],
|
||||
degradation_detected: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate expected improvement from retraining."""
|
||||
baseline_acc = degradation_detected.get('baseline_accuracy')
|
||||
current_acc = current_metrics['accuracy']
|
||||
|
||||
if baseline_acc:
|
||||
# Expect to recover 70-80% of lost performance
|
||||
expected_improvement = (baseline_acc - current_acc) * 0.75
|
||||
expected_new_acc = current_acc + expected_improvement
|
||||
|
||||
return {
|
||||
'expected_accuracy_improvement': round(expected_improvement, 2),
|
||||
'expected_new_accuracy': round(expected_new_acc, 2),
|
||||
'confidence': 'medium'
|
||||
}
|
||||
|
||||
return {
|
||||
'expected_accuracy_improvement': 'unknown',
|
||||
'confidence': 'low'
|
||||
}
|
||||
|
||||
def _identify_error_patterns(
|
||||
self, feedback_data: pd.DataFrame
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Identify systematic error patterns.
|
||||
|
||||
Patterns:
|
||||
- Consistent over/under prediction
|
||||
- Higher errors for specific ranges
|
||||
- Day-of-week effects
|
||||
- Seasonal effects
|
||||
"""
|
||||
patterns = []
|
||||
|
||||
# Pattern 1: Systematic bias
|
||||
mean_error = feedback_data['error'].mean()
|
||||
if abs(mean_error) > feedback_data['error'].std() * 0.5:
|
||||
direction = 'over-prediction' if mean_error > 0 else 'under-prediction'
|
||||
patterns.append({
|
||||
'pattern': 'systematic_bias',
|
||||
'description': f'Consistent {direction} by {abs(mean_error):.1f} units',
|
||||
'severity': 'high' if abs(mean_error) > 10 else 'medium',
|
||||
'recommendation': 'Recalibrate model bias term'
|
||||
})
|
||||
|
||||
# Pattern 2: High error for large values
|
||||
if 'predicted_value' in feedback_data.columns:
|
||||
# Split into quartiles
|
||||
feedback_data['value_quartile'] = pd.qcut(
|
||||
feedback_data['predicted_value'],
|
||||
q=4,
|
||||
labels=['Q1', 'Q2', 'Q3', 'Q4'],
|
||||
duplicates='drop'
|
||||
)
|
||||
|
||||
quartile_errors = feedback_data.groupby('value_quartile')['error_pct'].mean()
|
||||
|
||||
if len(quartile_errors) == 4 and quartile_errors['Q4'] > quartile_errors['Q1'] * 1.5:
|
||||
patterns.append({
|
||||
'pattern': 'high_value_error',
|
||||
'description': f'Higher errors for large predictions (Q4: {quartile_errors["Q4"]:.1f}% vs Q1: {quartile_errors["Q1"]:.1f}%)',
|
||||
'severity': 'medium',
|
||||
'recommendation': 'Add log transformation or separate model for high values'
|
||||
})
|
||||
|
||||
# Pattern 3: Day-of-week effect
|
||||
if 'outcome_date' in feedback_data.columns:
|
||||
feedback_data['day_of_week'] = pd.to_datetime(feedback_data['outcome_date']).dt.dayofweek
|
||||
|
||||
dow_errors = feedback_data.groupby('day_of_week')['error_pct'].mean()
|
||||
|
||||
if len(dow_errors) >= 5 and dow_errors.max() > dow_errors.min() * 1.5:
|
||||
worst_day = dow_errors.idxmax()
|
||||
day_names = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
|
||||
|
||||
patterns.append({
|
||||
'pattern': 'day_of_week_effect',
|
||||
'description': f'Higher errors on {day_names[worst_day]} ({dow_errors[worst_day]:.1f}%)',
|
||||
'severity': 'low',
|
||||
'recommendation': 'Add day-of-week features to model'
|
||||
})
|
||||
|
||||
return patterns
|
||||
|
||||
def _calculate_confidence_calibration(
|
||||
self, feedback_data: pd.DataFrame
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate how well confidence scores match actual accuracy.
|
||||
|
||||
Well-calibrated model: 80% confidence → 80% accuracy
|
||||
"""
|
||||
if 'confidence' not in feedback_data.columns:
|
||||
return {'calibrated': False, 'reason': 'No confidence scores available'}
|
||||
|
||||
# Bin by confidence ranges
|
||||
feedback_data['confidence_bin'] = pd.cut(
|
||||
feedback_data['confidence'],
|
||||
bins=[0, 60, 70, 80, 90, 100],
|
||||
labels=['<60', '60-70', '70-80', '80-90', '90+']
|
||||
)
|
||||
|
||||
calibration_results = []
|
||||
|
||||
for conf_bin in feedback_data['confidence_bin'].unique():
|
||||
if pd.isna(conf_bin):
|
||||
continue
|
||||
|
||||
bin_data = feedback_data[feedback_data['confidence_bin'] == conf_bin]
|
||||
|
||||
if len(bin_data) >= 5:
|
||||
avg_confidence = bin_data['confidence'].mean()
|
||||
avg_accuracy = bin_data['accuracy'].mean()
|
||||
calibration_error = abs(avg_confidence - avg_accuracy)
|
||||
|
||||
calibration_results.append({
|
||||
'confidence_range': str(conf_bin),
|
||||
'avg_confidence': round(avg_confidence, 1),
|
||||
'avg_accuracy': round(avg_accuracy, 1),
|
||||
'calibration_error': round(calibration_error, 1),
|
||||
'sample_size': len(bin_data),
|
||||
'well_calibrated': calibration_error < 10
|
||||
})
|
||||
|
||||
# Overall calibration
|
||||
if calibration_results:
|
||||
overall_calibration_error = np.mean([r['calibration_error'] for r in calibration_results])
|
||||
well_calibrated = overall_calibration_error < 10
|
||||
|
||||
return {
|
||||
'calibrated': well_calibrated,
|
||||
'overall_calibration_error': round(overall_calibration_error, 2),
|
||||
'by_confidence_range': calibration_results,
|
||||
'recommendation': 'Confidence scores are well-calibrated' if well_calibrated
|
||||
else 'Recalibrate confidence scoring algorithm'
|
||||
}
|
||||
|
||||
return {'calibrated': False, 'reason': 'Insufficient data for calibration analysis'}
|
||||
|
||||
async def generate_learning_insights(
|
||||
self,
|
||||
performance_analyses: List[Dict[str, Any]],
|
||||
tenant_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generate high-level insights about learning system performance.
|
||||
|
||||
Args:
|
||||
performance_analyses: List of model performance analyses
|
||||
tenant_id: Tenant identifier
|
||||
|
||||
Returns:
|
||||
Learning insights for system improvement
|
||||
"""
|
||||
insights = []
|
||||
|
||||
# Insight 1: Models needing urgent retraining
|
||||
urgent_models = [
|
||||
a for a in performance_analyses
|
||||
if a.get('retraining_recommendation', {}).get('priority') == 'urgent'
|
||||
]
|
||||
|
||||
if urgent_models:
|
||||
model_names = ', '.join([a['model_name'] for a in urgent_models])
|
||||
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'priority': 'urgent',
|
||||
'category': 'system',
|
||||
'title': f'Urgent Model Retraining Required: {len(urgent_models)} Models',
|
||||
'description': f'Models requiring immediate retraining: {model_names}. Performance has degraded significantly.',
|
||||
'impact_type': 'system_health',
|
||||
'confidence': 95,
|
||||
'metrics_json': {
|
||||
'tenant_id': tenant_id,
|
||||
'urgent_models': [a['model_name'] for a in urgent_models],
|
||||
'affected_count': len(urgent_models)
|
||||
},
|
||||
'actionable': True,
|
||||
'recommendation_actions': [{
|
||||
'label': 'Retrain Models',
|
||||
'action': 'trigger_model_retraining',
|
||||
'params': {'models': [a['model_name'] for a in urgent_models]}
|
||||
}],
|
||||
'source_service': 'ai_insights',
|
||||
'source_model': 'feedback_learning_system'
|
||||
})
|
||||
|
||||
# Insight 2: Overall system health
|
||||
total_models = len(performance_analyses)
|
||||
healthy_models = [
|
||||
a for a in performance_analyses
|
||||
if not a.get('degradation_detected', {}).get('detected', False)
|
||||
]
|
||||
|
||||
health_pct = (len(healthy_models) / total_models * 100) if total_models > 0 else 0
|
||||
|
||||
if health_pct < 80:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'priority': 'high',
|
||||
'category': 'system',
|
||||
'title': f'Learning System Health: {health_pct:.0f}%',
|
||||
'description': f'{len(healthy_models)} of {total_models} models are performing well. System-wide performance review recommended.',
|
||||
'impact_type': 'system_health',
|
||||
'confidence': 90,
|
||||
'metrics_json': {
|
||||
'tenant_id': tenant_id,
|
||||
'total_models': total_models,
|
||||
'healthy_models': len(healthy_models),
|
||||
'health_percentage': round(health_pct, 1)
|
||||
},
|
||||
'actionable': True,
|
||||
'recommendation_actions': [{
|
||||
'label': 'Review System Health',
|
||||
'action': 'review_learning_system',
|
||||
'params': {'tenant_id': tenant_id}
|
||||
}],
|
||||
'source_service': 'ai_insights',
|
||||
'source_model': 'feedback_learning_system'
|
||||
})
|
||||
|
||||
# Insight 3: Confidence calibration issues
|
||||
poorly_calibrated = [
|
||||
a for a in performance_analyses
|
||||
if not a.get('confidence_calibration', {}).get('calibrated', True)
|
||||
]
|
||||
|
||||
if poorly_calibrated:
|
||||
insights.append({
|
||||
'type': 'opportunity',
|
||||
'priority': 'medium',
|
||||
'category': 'system',
|
||||
'title': f'Confidence Calibration Needed: {len(poorly_calibrated)} Models',
|
||||
'description': 'Confidence scores do not match actual accuracy. Recalibration recommended.',
|
||||
'impact_type': 'system_improvement',
|
||||
'confidence': 85,
|
||||
'metrics_json': {
|
||||
'tenant_id': tenant_id,
|
||||
'models_needing_calibration': [a['model_name'] for a in poorly_calibrated]
|
||||
},
|
||||
'actionable': True,
|
||||
'recommendation_actions': [{
|
||||
'label': 'Recalibrate Confidence Scores',
|
||||
'action': 'recalibrate_confidence',
|
||||
'params': {'models': [a['model_name'] for a in poorly_calibrated]}
|
||||
}],
|
||||
'source_service': 'ai_insights',
|
||||
'source_model': 'feedback_learning_system'
|
||||
})
|
||||
|
||||
return insights
|
||||
|
||||
async def calculate_roi(
|
||||
self,
|
||||
feedback_data: pd.DataFrame,
|
||||
insight_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate ROI for applied insights.
|
||||
|
||||
Args:
|
||||
feedback_data: Feedback data with business impact metrics
|
||||
insight_type: Type of insight (e.g., 'demand_forecast', 'safety_stock')
|
||||
|
||||
Returns:
|
||||
ROI calculation with cost savings and accuracy metrics
|
||||
"""
|
||||
if len(feedback_data) == 0:
|
||||
return {'status': 'insufficient_data', 'samples': 0}
|
||||
|
||||
# Calculate accuracy
|
||||
avg_accuracy = feedback_data['accuracy'].mean()
|
||||
|
||||
# Estimate cost savings (would be more sophisticated in production)
|
||||
# For now, use impact_value from insights if available
|
||||
if 'impact_value' in feedback_data.columns:
|
||||
total_impact = feedback_data['impact_value'].sum()
|
||||
avg_impact = feedback_data['impact_value'].mean()
|
||||
|
||||
return {
|
||||
'insight_type': insight_type,
|
||||
'samples': len(feedback_data),
|
||||
'avg_accuracy': round(avg_accuracy, 2),
|
||||
'total_impact_value': round(total_impact, 2),
|
||||
'avg_impact_per_insight': round(avg_impact, 2),
|
||||
'roi_validated': True
|
||||
}
|
||||
|
||||
return {
|
||||
'insight_type': insight_type,
|
||||
'samples': len(feedback_data),
|
||||
'avg_accuracy': round(avg_accuracy, 2),
|
||||
'roi_validated': False,
|
||||
'note': 'Impact values not tracked in feedback'
|
||||
}
|
||||
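For orientation, a minimal, hedged sketch of driving the FeedbackLearningSystem defined above; the DataFrame columns follow the analyze_model_performance docstring, the sample values are invented, and min_feedback_samples is lowered only so the toy data passes the sample-size gate:

import asyncio
import pandas as pd

async def run_example():
    system = FeedbackLearningSystem(min_feedback_samples=5)
    feedback = pd.DataFrame({
        'insight_id': ['a', 'b', 'c', 'd', 'e'],
        'applied_at': pd.date_range('2024-01-01', periods=5),
        'outcome_date': pd.date_range('2024-01-02', periods=5),
        'predicted_value': [100, 110, 95, 105, 120],
        'actual_value': [98, 115, 90, 104, 118],
        'error': [2, -5, 5, 1, 2],
        'error_pct': [2.0, 4.3, 5.6, 1.0, 1.7],
        'accuracy': [98.0, 95.7, 94.4, 99.0, 98.3],
    })
    report = await system.analyze_model_performance(
        model_name='hybrid_forecaster',
        feedback_data=feedback,
        baseline_performance={'accuracy': 92.0},
    )
    print(report['current_performance'])
    print(report['retraining_recommendation'])

asyncio.run(run_example())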
11
services/ai_insights/app/models/__init__.py
Normal file
11
services/ai_insights/app/models/__init__.py
Normal file
@@ -0,0 +1,11 @@
"""Database models for AI Insights Service."""

from app.models.ai_insight import AIInsight
from app.models.insight_feedback import InsightFeedback
from app.models.insight_correlation import InsightCorrelation

__all__ = [
    "AIInsight",
    "InsightFeedback",
    "InsightCorrelation",
]
129
services/ai_insights/app/models/ai_insight.py
Normal file
129
services/ai_insights/app/models/ai_insight.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""AI Insight database model."""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, Boolean, DECIMAL, TIMESTAMP, Text, Index, CheckConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.sql import func
|
||||
import uuid
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class AIInsight(Base):
|
||||
"""AI Insight model for storing intelligent recommendations and predictions."""
|
||||
|
||||
__tablename__ = "ai_insights"
|
||||
|
||||
# Primary Key
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Tenant Information
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
# Classification
|
||||
type = Column(
|
||||
String(50),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="optimization, alert, prediction, recommendation, insight, anomaly"
|
||||
)
|
||||
priority = Column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="low, medium, high, critical"
|
||||
)
|
||||
category = Column(
|
||||
String(50),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="forecasting, inventory, production, procurement, customer, cost, quality, efficiency, demand, maintenance, energy, scheduling"
|
||||
)
|
||||
|
||||
# Content
|
||||
title = Column(String(255), nullable=False)
|
||||
description = Column(Text, nullable=False)
|
||||
|
||||
# Impact Information
|
||||
impact_type = Column(
|
||||
String(50),
|
||||
comment="cost_savings, revenue_increase, waste_reduction, efficiency_gain, quality_improvement, risk_mitigation"
|
||||
)
|
||||
impact_value = Column(DECIMAL(10, 2), comment="Numeric impact value")
|
||||
impact_unit = Column(
|
||||
String(20),
|
||||
comment="euros, percentage, hours, units, euros/month, euros/year"
|
||||
)
|
||||
|
||||
# Confidence and Metrics
|
||||
confidence = Column(
|
||||
Integer,
|
||||
CheckConstraint('confidence >= 0 AND confidence <= 100'),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="Confidence score 0-100"
|
||||
)
|
||||
metrics_json = Column(
|
||||
JSONB,
|
||||
comment="Dynamic metrics specific to insight type"
|
||||
)
|
||||
|
||||
# Actionability
|
||||
actionable = Column(
|
||||
Boolean,
|
||||
default=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="Whether this insight can be acted upon"
|
||||
)
|
||||
recommendation_actions = Column(
|
||||
JSONB,
|
||||
comment="List of possible actions: [{label, action, endpoint}]"
|
||||
)
|
||||
|
||||
# Status
|
||||
status = Column(
|
||||
String(20),
|
||||
default='new',
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="new, acknowledged, in_progress, applied, dismissed, expired"
|
||||
)
|
||||
|
||||
# Source Information
|
||||
source_service = Column(
|
||||
String(50),
|
||||
comment="Service that generated this insight"
|
||||
)
|
||||
source_data_id = Column(
|
||||
String(100),
|
||||
comment="Reference to source data (e.g., forecast_id, model_id)"
|
||||
)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
updated_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False
|
||||
)
|
||||
applied_at = Column(TIMESTAMP(timezone=True), comment="When insight was applied")
|
||||
expired_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
comment="When insight expires (auto-calculated based on TTL)"
|
||||
)
|
||||
|
||||
# Composite Indexes
|
||||
__table_args__ = (
|
||||
Index('idx_tenant_status_category', 'tenant_id', 'status', 'category'),
|
||||
Index('idx_tenant_created_confidence', 'tenant_id', 'created_at', 'confidence'),
|
||||
Index('idx_actionable_status', 'actionable', 'status'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<AIInsight(id={self.id}, type={self.type}, title={self.title[:30]}, confidence={self.confidence})>"
|
||||
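As a reading aid, an example of the recommendation_actions payload shape described in the column comment above; the label, action, and endpoint values are invented:

recommendation_actions = [
    {
        "label": "Increase Monday production by 10%",
        "action": "adjust_production_plan",
        "endpoint": "/api/v1/production/plans/adjust",
    },
]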
69
services/ai_insights/app/models/insight_correlation.py
Normal file
69
services/ai_insights/app/models/insight_correlation.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Insight Correlation database model for cross-service intelligence."""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, DECIMAL, TIMESTAMP, ForeignKey, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
import uuid
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class InsightCorrelation(Base):
|
||||
"""Track correlations between insights from different services."""
|
||||
|
||||
__tablename__ = "insight_correlations"
|
||||
|
||||
# Primary Key
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Foreign Keys to AIInsights
|
||||
parent_insight_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey('ai_insights.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="Primary insight that leads to correlation"
|
||||
)
|
||||
child_insight_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey('ai_insights.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="Related insight"
|
||||
)
|
||||
|
||||
# Correlation Information
|
||||
correlation_type = Column(
|
||||
String(50),
|
||||
nullable=False,
|
||||
comment="forecast_inventory, production_procurement, weather_customer, demand_supplier, etc."
|
||||
)
|
||||
correlation_strength = Column(
|
||||
DECIMAL(3, 2),
|
||||
nullable=False,
|
||||
comment="0.00 to 1.00 indicating strength of correlation"
|
||||
)
|
||||
|
||||
# Combined Metrics
|
||||
combined_confidence = Column(
|
||||
Integer,
|
||||
comment="Weighted combined confidence of both insights"
|
||||
)
|
||||
|
||||
# Timestamp
|
||||
created_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
|
||||
# Composite Indexes
|
||||
__table_args__ = (
|
||||
Index('idx_parent_child', 'parent_insight_id', 'child_insight_id'),
|
||||
Index('idx_correlation_type', 'correlation_type'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<InsightCorrelation(id={self.id}, type={self.correlation_type}, strength={self.correlation_strength})>"
|
||||
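Purely as an illustration of what the combined_confidence column could hold, one possible weighting scheme is sketched below; this helper is an assumption and is not part of the commit:

def combine_confidence(parent_conf: int, child_conf: int, strength: float) -> int:
    # Blend the average of both confidences with the weaker one, weighted by correlation strength
    weighted = (parent_conf + child_conf) / 2 * strength + min(parent_conf, child_conf) * (1 - strength)
    return int(round(weighted))

print(combine_confidence(90, 80, 0.75))  # 84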
87
services/ai_insights/app/models/insight_feedback.py
Normal file
87
services/ai_insights/app/models/insight_feedback.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Insight Feedback database model for closed-loop learning."""
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DECIMAL, TIMESTAMP, Text, ForeignKey, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
import uuid
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class InsightFeedback(Base):
|
||||
"""Feedback tracking for AI Insights to enable learning."""
|
||||
|
||||
__tablename__ = "insight_feedback"
|
||||
|
||||
# Primary Key
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Foreign Key to AIInsight
|
||||
insight_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey('ai_insights.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
|
||||
# Action Information
|
||||
action_taken = Column(
|
||||
String(100),
|
||||
comment="Specific action that was taken from recommendation_actions"
|
||||
)
|
||||
|
||||
# Result Data
|
||||
result_data = Column(
|
||||
JSONB,
|
||||
comment="Detailed result data from applying the insight"
|
||||
)
|
||||
|
||||
# Success Tracking
|
||||
success = Column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="Whether the insight application was successful"
|
||||
)
|
||||
error_message = Column(
|
||||
Text,
|
||||
comment="Error message if success = false"
|
||||
)
|
||||
|
||||
# Impact Comparison
|
||||
expected_impact_value = Column(
|
||||
DECIMAL(10, 2),
|
||||
comment="Expected impact value from original insight"
|
||||
)
|
||||
actual_impact_value = Column(
|
||||
DECIMAL(10, 2),
|
||||
comment="Measured actual impact after application"
|
||||
)
|
||||
variance_percentage = Column(
|
||||
DECIMAL(5, 2),
|
||||
comment="(actual - expected) / expected * 100"
|
||||
)
|
||||
|
||||
# User Information
|
||||
applied_by = Column(
|
||||
String(100),
|
||||
comment="User or system that applied the insight"
|
||||
)
|
||||
|
||||
# Timestamp
|
||||
created_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
|
||||
# Composite Indexes
|
||||
__table_args__ = (
|
||||
Index('idx_insight_success', 'insight_id', 'success'),
|
||||
Index('idx_created_success', 'created_at', 'success'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<InsightFeedback(id={self.id}, insight_id={self.insight_id}, success={self.success})>"
|
||||
9
services/ai_insights/app/repositories/__init__.py
Normal file
9
services/ai_insights/app/repositories/__init__.py
Normal file
@@ -0,0 +1,9 @@
"""Repositories for AI Insights Service."""

from app.repositories.insight_repository import InsightRepository
from app.repositories.feedback_repository import FeedbackRepository

__all__ = [
    "InsightRepository",
    "FeedbackRepository",
]
81
services/ai_insights/app/repositories/feedback_repository.py
Normal file
81
services/ai_insights/app/repositories/feedback_repository.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Repository for Insight Feedback database operations."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, desc
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
from decimal import Decimal
|
||||
|
||||
from app.models.insight_feedback import InsightFeedback
|
||||
from app.schemas.feedback import InsightFeedbackCreate
|
||||
|
||||
|
||||
class FeedbackRepository:
|
||||
"""Repository for Insight Feedback operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create(self, feedback_data: InsightFeedbackCreate) -> InsightFeedback:
|
||||
"""Create feedback for an insight."""
|
||||
# Calculate variance if both values provided
|
||||
variance = None
|
||||
if (feedback_data.expected_impact_value is not None and
|
||||
feedback_data.actual_impact_value is not None and
|
||||
feedback_data.expected_impact_value != 0):
|
||||
variance = (
|
||||
(feedback_data.actual_impact_value - feedback_data.expected_impact_value) /
|
||||
feedback_data.expected_impact_value * 100
|
||||
)
|
||||
|
||||
feedback = InsightFeedback(
|
||||
**feedback_data.model_dump(exclude={'variance_percentage'}),
|
||||
variance_percentage=variance
|
||||
)
|
||||
self.session.add(feedback)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(feedback)
|
||||
return feedback
|
||||
|
||||
async def get_by_id(self, feedback_id: UUID) -> Optional[InsightFeedback]:
|
||||
"""Get feedback by ID."""
|
||||
query = select(InsightFeedback).where(InsightFeedback.id == feedback_id)
|
||||
result = await self.session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_insight(self, insight_id: UUID) -> List[InsightFeedback]:
|
||||
"""Get all feedback for an insight."""
|
||||
query = select(InsightFeedback).where(
|
||||
InsightFeedback.insight_id == insight_id
|
||||
).order_by(desc(InsightFeedback.created_at))
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_success_rate(self, insight_type: Optional[str] = None) -> float:
|
||||
"""Calculate success rate for insights."""
|
||||
query = select(InsightFeedback)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
feedbacks = result.scalars().all()
|
||||
|
||||
if not feedbacks:
|
||||
return 0.0
|
||||
|
||||
successful = sum(1 for f in feedbacks if f.success)
|
||||
return (successful / len(feedbacks)) * 100
|
||||
|
||||
async def get_average_impact_variance(self) -> Decimal:
|
||||
"""Calculate average variance between expected and actual impact."""
|
||||
query = select(InsightFeedback).where(
|
||||
InsightFeedback.variance_percentage.isnot(None)
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
feedbacks = result.scalars().all()
|
||||
|
||||
if not feedbacks:
|
||||
return Decimal('0.0')
|
||||
|
||||
avg_variance = sum(f.variance_percentage for f in feedbacks) / len(feedbacks)
|
||||
return Decimal(str(round(float(avg_variance), 2)))
|
||||
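To make the variance convention concrete, a tiny worked example of the formula applied in create(); the figures are made up:

from decimal import Decimal

expected = Decimal('200.00')
actual = Decimal('170.00')
variance = (actual - expected) / expected * 100
# variance == Decimal('-15.00'): the insight under-delivered its expected impact by 15%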
254
services/ai_insights/app/repositories/insight_repository.py
Normal file
254
services/ai_insights/app/repositories/insight_repository.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""Repository for AI Insight database operations."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_, or_, desc
|
||||
from sqlalchemy.orm import selectinload
|
||||
from typing import Optional, List, Dict, Any
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.models.ai_insight import AIInsight
|
||||
from app.schemas.insight import AIInsightCreate, AIInsightUpdate, InsightFilters
|
||||
|
||||
|
||||
class InsightRepository:
|
||||
"""Repository for AI Insight operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create(self, insight_data: AIInsightCreate) -> AIInsight:
|
||||
"""Create a new AI Insight."""
|
||||
# Calculate expiration date (default 7 days from now)
|
||||
from app.core.config import settings
|
||||
expired_at = datetime.utcnow() + timedelta(days=settings.DEFAULT_INSIGHT_TTL_DAYS)
|
||||
|
||||
insight = AIInsight(
|
||||
**insight_data.model_dump(),
|
||||
expired_at=expired_at
|
||||
)
|
||||
self.session.add(insight)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(insight)
|
||||
return insight
|
||||
|
||||
async def get_by_id(self, insight_id: UUID) -> Optional[AIInsight]:
|
||||
"""Get insight by ID."""
|
||||
query = select(AIInsight).where(AIInsight.id == insight_id)
|
||||
result = await self.session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_tenant(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
filters: Optional[InsightFilters] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> tuple[List[AIInsight], int]:
|
||||
"""Get insights for a tenant with filters and pagination."""
|
||||
# Build base query
|
||||
query = select(AIInsight).where(AIInsight.tenant_id == tenant_id)
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
if filters.category and filters.category != 'all':
|
||||
query = query.where(AIInsight.category == filters.category)
|
||||
|
||||
if filters.priority and filters.priority != 'all':
|
||||
query = query.where(AIInsight.priority == filters.priority)
|
||||
|
||||
if filters.status and filters.status != 'all':
|
||||
query = query.where(AIInsight.status == filters.status)
|
||||
|
||||
if filters.actionable_only:
|
||||
query = query.where(AIInsight.actionable == True)
|
||||
|
||||
if filters.min_confidence > 0:
|
||||
query = query.where(AIInsight.confidence >= filters.min_confidence)
|
||||
|
||||
if filters.source_service:
|
||||
query = query.where(AIInsight.source_service == filters.source_service)
|
||||
|
||||
if filters.from_date:
|
||||
query = query.where(AIInsight.created_at >= filters.from_date)
|
||||
|
||||
if filters.to_date:
|
||||
query = query.where(AIInsight.created_at <= filters.to_date)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total_result = await self.session.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# Apply ordering, pagination
|
||||
query = query.order_by(desc(AIInsight.confidence), desc(AIInsight.created_at))
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
# Execute query
|
||||
result = await self.session.execute(query)
|
||||
insights = result.scalars().all()
|
||||
|
||||
return list(insights), total
|
||||
|
||||
async def get_orchestration_ready_insights(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
target_date: datetime,
|
||||
min_confidence: int = 70
|
||||
) -> Dict[str, List[AIInsight]]:
|
||||
"""Get actionable insights for orchestration."""
|
||||
query = select(AIInsight).where(
|
||||
and_(
|
||||
AIInsight.tenant_id == tenant_id,
|
||||
AIInsight.actionable == True,
|
||||
AIInsight.confidence >= min_confidence,
|
||||
AIInsight.status.in_(['new', 'acknowledged']),
|
||||
or_(
|
||||
AIInsight.expired_at.is_(None),
|
||||
AIInsight.expired_at > datetime.utcnow()
|
||||
)
|
||||
)
|
||||
).order_by(desc(AIInsight.confidence))
|
||||
|
||||
result = await self.session.execute(query)
|
||||
insights = result.scalars().all()
|
||||
|
||||
# Categorize insights
|
||||
categorized = {
|
||||
'forecast_adjustments': [],
|
||||
'procurement_recommendations': [],
|
||||
'production_optimizations': [],
|
||||
'supplier_alerts': [],
|
||||
'price_opportunities': []
|
||||
}
|
||||
|
||||
for insight in insights:
|
||||
if insight.category == 'forecasting':
|
||||
categorized['forecast_adjustments'].append(insight)
|
||||
elif insight.category == 'procurement':
|
||||
if 'supplier' in insight.title.lower():
|
||||
categorized['supplier_alerts'].append(insight)
|
||||
elif 'price' in insight.title.lower():
|
||||
categorized['price_opportunities'].append(insight)
|
||||
else:
|
||||
categorized['procurement_recommendations'].append(insight)
|
||||
elif insight.category == 'production':
|
||||
categorized['production_optimizations'].append(insight)
|
||||
|
||||
return categorized
|
||||
|
||||
async def update(self, insight_id: UUID, update_data: AIInsightUpdate) -> Optional[AIInsight]:
|
||||
"""Update an insight."""
|
||||
insight = await self.get_by_id(insight_id)
|
||||
if not insight:
|
||||
return None
|
||||
|
||||
for field, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(insight, field, value)
|
||||
|
||||
await self.session.flush()
|
||||
await self.session.refresh(insight)
|
||||
return insight
|
||||
|
||||
async def delete(self, insight_id: UUID) -> bool:
|
||||
"""Delete (dismiss) an insight."""
|
||||
insight = await self.get_by_id(insight_id)
|
||||
if not insight:
|
||||
return False
|
||||
|
||||
insight.status = 'dismissed'
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
||||
async def get_metrics(self, tenant_id: UUID) -> Dict[str, Any]:
|
||||
"""Get aggregate metrics for insights."""
|
||||
query = select(AIInsight).where(
|
||||
and_(
|
||||
AIInsight.tenant_id == tenant_id,
|
||||
AIInsight.status != 'dismissed',
|
||||
or_(
|
||||
AIInsight.expired_at.is_(None),
|
||||
AIInsight.expired_at > datetime.utcnow()
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
insights = result.scalars().all()
|
||||
|
||||
if not insights:
|
||||
return {
|
||||
'total_insights': 0,
|
||||
'actionable_insights': 0,
|
||||
'average_confidence': 0,
|
||||
'high_priority_count': 0,
|
||||
'medium_priority_count': 0,
|
||||
'low_priority_count': 0,
|
||||
'critical_priority_count': 0,
|
||||
'by_category': {},
|
||||
'by_status': {},
|
||||
'total_potential_impact': 0
|
||||
}
|
||||
|
||||
# Calculate metrics
|
||||
total = len(insights)
|
||||
actionable = sum(1 for i in insights if i.actionable)
|
||||
avg_confidence = sum(i.confidence for i in insights) / total if total > 0 else 0
|
||||
|
||||
# Priority counts
|
||||
priority_counts = {
|
||||
'high': sum(1 for i in insights if i.priority == 'high'),
|
||||
'medium': sum(1 for i in insights if i.priority == 'medium'),
|
||||
'low': sum(1 for i in insights if i.priority == 'low'),
|
||||
'critical': sum(1 for i in insights if i.priority == 'critical')
|
||||
}
|
||||
|
||||
# By category
|
||||
by_category = {}
|
||||
for insight in insights:
|
||||
by_category[insight.category] = by_category.get(insight.category, 0) + 1
|
||||
|
||||
# By status
|
||||
by_status = {}
|
||||
for insight in insights:
|
||||
by_status[insight.status] = by_status.get(insight.status, 0) + 1
|
||||
|
||||
# Total potential impact
|
||||
total_impact = sum(
|
||||
float(i.impact_value) for i in insights
|
||||
if i.impact_value and i.impact_type in ['cost_savings', 'revenue_increase']
|
||||
)
|
||||
|
||||
return {
|
||||
'total_insights': total,
|
||||
'actionable_insights': actionable,
|
||||
'average_confidence': round(avg_confidence, 1),
|
||||
'high_priority_count': priority_counts['high'],
|
||||
'medium_priority_count': priority_counts['medium'],
|
||||
'low_priority_count': priority_counts['low'],
|
||||
'critical_priority_count': priority_counts['critical'],
|
||||
'by_category': by_category,
|
||||
'by_status': by_status,
|
||||
'total_potential_impact': round(total_impact, 2)
|
||||
}
|
||||
|
||||
async def expire_old_insights(self) -> int:
|
||||
"""Mark expired insights as expired."""
|
||||
query = select(AIInsight).where(
|
||||
and_(
|
||||
AIInsight.expired_at.isnot(None),
|
||||
AIInsight.expired_at <= datetime.utcnow(),
|
||||
AIInsight.status.notin_(['applied', 'dismissed', 'expired'])
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
insights = result.scalars().all()
|
||||
|
||||
count = 0
|
||||
for insight in insights:
|
||||
insight.status = 'expired'
|
||||
count += 1
|
||||
|
||||
await self.session.flush()
|
||||
return count
|
||||
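A hedged usage sketch for the repository above; obtaining an AsyncSession is left to the caller because the session factory is not shown in this excerpt:

from uuid import UUID
from app.repositories.insight_repository import InsightRepository
from app.schemas.insight import InsightFilters

async def list_actionable_forecasting_insights(session, tenant_id: UUID):
    repo = InsightRepository(session)
    filters = InsightFilters(category='forecasting', actionable_only=True, min_confidence=70)
    # Returns the page of insights plus the unpaginated total for the filter set
    insights, total = await repo.get_by_tenant(tenant_id, filters=filters, skip=0, limit=20)
    return insights, total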
27
services/ai_insights/app/schemas/__init__.py
Normal file
27
services/ai_insights/app/schemas/__init__.py
Normal file
@@ -0,0 +1,27 @@
"""Pydantic schemas for AI Insights Service."""

from app.schemas.insight import (
    AIInsightBase,
    AIInsightCreate,
    AIInsightUpdate,
    AIInsightResponse,
    AIInsightList,
    InsightMetrics,
    InsightFilters
)
from app.schemas.feedback import (
    InsightFeedbackCreate,
    InsightFeedbackResponse
)

__all__ = [
    "AIInsightBase",
    "AIInsightCreate",
    "AIInsightUpdate",
    "AIInsightResponse",
    "AIInsightList",
    "InsightMetrics",
    "InsightFilters",
    "InsightFeedbackCreate",
    "InsightFeedbackResponse",
]
37
services/ai_insights/app/schemas/feedback.py
Normal file
37
services/ai_insights/app/schemas/feedback.py
Normal file
@@ -0,0 +1,37 @@
"""Pydantic schemas for Insight Feedback."""

from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, Dict, Any
from datetime import datetime
from uuid import UUID
from decimal import Decimal


class InsightFeedbackBase(BaseModel):
    """Base schema for Insight Feedback."""

    action_taken: str
    result_data: Optional[Dict[str, Any]] = Field(default_factory=dict)
    success: bool
    error_message: Optional[str] = None
    expected_impact_value: Optional[Decimal] = None
    actual_impact_value: Optional[Decimal] = None
    variance_percentage: Optional[Decimal] = None


class InsightFeedbackCreate(InsightFeedbackBase):
    """Schema for creating feedback."""

    insight_id: UUID
    applied_by: Optional[str] = "system"


class InsightFeedbackResponse(InsightFeedbackBase):
    """Schema for feedback response."""

    id: UUID
    insight_id: UUID
    applied_by: str
    created_at: datetime

    model_config = ConfigDict(from_attributes=True)
93
services/ai_insights/app/schemas/insight.py
Normal file
93
services/ai_insights/app/schemas/insight.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Pydantic schemas for AI Insights."""
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
class AIInsightBase(BaseModel):
|
||||
"""Base schema for AI Insight."""
|
||||
|
||||
type: str = Field(..., description="optimization, alert, prediction, recommendation, insight, anomaly")
|
||||
priority: str = Field(..., description="low, medium, high, critical")
|
||||
category: str = Field(..., description="forecasting, inventory, production, procurement, customer, etc.")
|
||||
title: str = Field(..., max_length=255)
|
||||
description: str
|
||||
impact_type: Optional[str] = Field(None, description="cost_savings, revenue_increase, waste_reduction, etc.")
|
||||
impact_value: Optional[Decimal] = None
|
||||
impact_unit: Optional[str] = Field(None, description="euros, percentage, hours, units, etc.")
|
||||
confidence: int = Field(..., ge=0, le=100, description="Confidence score 0-100")
|
||||
metrics_json: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
actionable: bool = True
|
||||
recommendation_actions: Optional[List[Dict[str, str]]] = Field(default_factory=list)
|
||||
source_service: Optional[str] = None
|
||||
source_data_id: Optional[str] = None
|
||||
|
||||
|
||||
class AIInsightCreate(AIInsightBase):
|
||||
"""Schema for creating a new AI Insight."""
|
||||
|
||||
tenant_id: UUID
|
||||
|
||||
|
||||
class AIInsightUpdate(BaseModel):
|
||||
"""Schema for updating an AI Insight."""
|
||||
|
||||
status: Optional[str] = Field(None, description="new, acknowledged, in_progress, applied, dismissed, expired")
|
||||
applied_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AIInsightResponse(AIInsightBase):
|
||||
"""Schema for AI Insight response."""
|
||||
|
||||
id: UUID
|
||||
tenant_id: UUID
|
||||
status: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
applied_at: Optional[datetime] = None
|
||||
expired_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AIInsightList(BaseModel):
|
||||
"""Paginated list of AI Insights."""
|
||||
|
||||
items: List[AIInsightResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
|
||||
class InsightMetrics(BaseModel):
|
||||
"""Aggregate metrics for insights."""
|
||||
|
||||
total_insights: int
|
||||
actionable_insights: int
|
||||
average_confidence: float
|
||||
high_priority_count: int
|
||||
medium_priority_count: int
|
||||
low_priority_count: int
|
||||
critical_priority_count: int
|
||||
by_category: Dict[str, int]
|
||||
by_status: Dict[str, int]
|
||||
total_potential_impact: Optional[Decimal] = None
|
||||
|
||||
|
||||
class InsightFilters(BaseModel):
|
||||
"""Filters for querying insights."""
|
||||
|
||||
category: Optional[str] = None
|
||||
priority: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
actionable_only: bool = False
|
||||
min_confidence: int = 0
|
||||
source_service: Optional[str] = None
|
||||
from_date: Optional[datetime] = None
|
||||
to_date: Optional[datetime] = None
|
||||
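An illustrative AIInsightCreate payload built from the fields above; every value is made up:

from uuid import uuid4

insight = AIInsightCreate(
    tenant_id=uuid4(),
    type="recommendation",
    priority="high",
    category="inventory",
    title="Raise safety stock for T55 flour",
    description="Forecast variance suggests current safety stock covers only two days of demand.",
    impact_type="cost_savings",
    impact_value=350,
    impact_unit="euros/month",
    confidence=82,
    source_service="forecasting",
)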
229
services/ai_insights/app/scoring/confidence_calculator.py
Normal file
229
services/ai_insights/app/scoring/confidence_calculator.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Confidence scoring calculator for AI Insights."""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import math
|
||||
|
||||
|
||||
class ConfidenceCalculator:
|
||||
"""
|
||||
Calculate unified confidence scores across different insight types.
|
||||
|
||||
Confidence is calculated based on multiple factors:
|
||||
- Data quality (completeness, consistency)
|
||||
- Model performance (historical accuracy)
|
||||
- Sample size (statistical significance)
|
||||
- Recency (how recent is the data)
|
||||
- Historical accuracy (past insight performance)
|
||||
"""
|
||||
|
||||
# Weights for different factors
|
||||
WEIGHTS = {
|
||||
'data_quality': 0.25,
|
||||
'model_performance': 0.30,
|
||||
'sample_size': 0.20,
|
||||
'recency': 0.15,
|
||||
'historical_accuracy': 0.10
|
||||
}
|
||||
|
||||
def calculate_confidence(
|
||||
self,
|
||||
data_quality_score: Optional[float] = None,
|
||||
model_performance_score: Optional[float] = None,
|
||||
sample_size: Optional[int] = None,
|
||||
data_date: Optional[datetime] = None,
|
||||
historical_accuracy: Optional[float] = None,
|
||||
insight_type: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
Calculate overall confidence score (0-100).
|
||||
|
||||
Args:
|
||||
data_quality_score: 0-1 score for data quality
|
||||
model_performance_score: 0-1 score from model metrics (e.g., 1-MAPE)
|
||||
sample_size: Number of data points used
|
||||
data_date: Date of most recent data
|
||||
historical_accuracy: 0-1 score from past insight performance
|
||||
insight_type: Type of insight for specific adjustments
|
||||
|
||||
Returns:
|
||||
int: Confidence score 0-100
|
||||
"""
|
||||
scores = {}
|
||||
|
||||
# Data Quality Score (0-100)
|
||||
if data_quality_score is not None:
|
||||
scores['data_quality'] = min(100, data_quality_score * 100)
|
||||
else:
|
||||
scores['data_quality'] = 70 # Default
|
||||
|
||||
# Model Performance Score (0-100)
|
||||
if model_performance_score is not None:
|
||||
scores['model_performance'] = min(100, model_performance_score * 100)
|
||||
else:
|
||||
scores['model_performance'] = 75 # Default
|
||||
|
||||
# Sample Size Score (0-100)
|
||||
if sample_size is not None:
|
||||
scores['sample_size'] = self._score_sample_size(sample_size)
|
||||
else:
|
||||
scores['sample_size'] = 60 # Default
|
||||
|
||||
# Recency Score (0-100)
|
||||
if data_date is not None:
|
||||
scores['recency'] = self._score_recency(data_date)
|
||||
else:
|
||||
scores['recency'] = 80 # Default
|
||||
|
||||
# Historical Accuracy Score (0-100)
|
||||
if historical_accuracy is not None:
|
||||
scores['historical_accuracy'] = min(100, historical_accuracy * 100)
|
||||
else:
|
||||
scores['historical_accuracy'] = 65 # Default
|
||||
|
||||
# Calculate weighted average
|
||||
confidence = sum(
|
||||
scores[factor] * self.WEIGHTS[factor]
|
||||
for factor in scores
|
||||
)
|
||||
|
||||
# Apply insight-type specific adjustments
|
||||
confidence = self._apply_type_adjustments(confidence, insight_type)
|
||||
|
||||
return int(round(confidence))
|
||||
|
||||
def _score_sample_size(self, sample_size: int) -> float:
|
||||
"""
|
||||
Score based on sample size using logarithmic scale.
|
||||
|
||||
Args:
|
||||
sample_size: Number of data points
|
||||
|
||||
Returns:
|
||||
float: Score 0-100
|
||||
"""
|
||||
if sample_size <= 10:
|
||||
return 30.0
|
||||
elif sample_size <= 30:
|
||||
return 50.0
|
||||
elif sample_size <= 100:
|
||||
return 70.0
|
||||
elif sample_size <= 365:
|
||||
return 85.0
|
||||
else:
|
||||
# Logarithmic scaling for larger samples
|
||||
return min(100.0, 85 + (math.log10(sample_size) - math.log10(365)) * 10)
|
||||
|
||||
def _score_recency(self, data_date: datetime) -> float:
|
||||
"""
|
||||
Score based on data recency.
|
||||
|
||||
Args:
|
||||
data_date: Date of most recent data
|
||||
|
||||
Returns:
|
||||
float: Score 0-100
|
||||
"""
|
||||
days_old = (datetime.utcnow() - data_date).days
|
||||
|
||||
if days_old == 0:
|
||||
return 100.0
|
||||
elif days_old <= 1:
|
||||
return 95.0
|
||||
elif days_old <= 3:
|
||||
return 90.0
|
||||
elif days_old <= 7:
|
||||
return 80.0
|
||||
elif days_old <= 14:
|
||||
return 70.0
|
||||
elif days_old <= 30:
|
||||
return 60.0
|
||||
elif days_old <= 60:
|
||||
return 45.0
|
||||
else:
|
||||
# Exponential decay for older data
|
||||
return max(20.0, 60 * math.exp(-days_old / 60))
|
||||
|
||||
def _apply_type_adjustments(self, base_confidence: float, insight_type: Optional[str]) -> float:
|
||||
"""
|
||||
Apply insight-type specific confidence adjustments.
|
||||
|
||||
Args:
|
||||
base_confidence: Base confidence score
|
||||
insight_type: Type of insight
|
||||
|
||||
Returns:
|
||||
float: Adjusted confidence
|
||||
"""
|
||||
if not insight_type:
|
||||
return base_confidence
|
||||
|
||||
adjustments = {
|
||||
'prediction': -5, # Predictions inherently less certain
|
||||
'optimization': +2, # Optimizations based on solid math
|
||||
'alert': +3, # Alerts based on thresholds
|
||||
'recommendation': 0, # No adjustment
|
||||
'insight': +2, # Insights from data analysis
|
||||
'anomaly': -3 # Anomalies are uncertain
|
||||
}
|
||||
|
||||
adjustment = adjustments.get(insight_type, 0)
|
||||
return max(0, min(100, base_confidence + adjustment))
|
||||
|
||||
def calculate_forecast_confidence(
|
||||
self,
|
||||
model_mape: float,
|
||||
forecast_horizon_days: int,
|
||||
data_points: int,
|
||||
last_data_date: datetime
|
||||
) -> int:
|
||||
"""
|
||||
Specialized confidence calculation for forecasting insights.
|
||||
|
||||
Args:
|
||||
model_mape: Model MAPE (Mean Absolute Percentage Error)
|
||||
forecast_horizon_days: How many days ahead
|
||||
data_points: Number of historical data points
|
||||
last_data_date: Date of last training data
|
||||
|
||||
Returns:
|
||||
int: Confidence score 0-100
|
||||
"""
|
||||
# Model performance: 1 - (MAPE/100), floored at 0
|
||||
model_score = max(0, 1 - (model_mape / 100))
|
||||
|
||||
# Horizon penalty: Longer horizons = less confidence
|
||||
horizon_factor = max(0.5, 1 - (forecast_horizon_days / 30))
|
||||
|
||||
return self.calculate_confidence(
|
||||
data_quality_score=0.9, # Assume good quality
|
||||
model_performance_score=model_score * horizon_factor,
|
||||
sample_size=data_points,
|
||||
data_date=last_data_date,
|
||||
insight_type='prediction'
|
||||
)
|
||||
|
||||
def calculate_optimization_confidence(
|
||||
self,
|
||||
calculation_accuracy: float,
|
||||
data_completeness: float,
|
||||
sample_size: int
|
||||
) -> int:
|
||||
"""
|
||||
Confidence for optimization recommendations.
|
||||
|
||||
Args:
|
||||
calculation_accuracy: 0-1 score for optimization calculation reliability
|
||||
data_completeness: 0-1 score for data completeness
|
||||
sample_size: Number of data points
|
||||
|
||||
Returns:
|
||||
int: Confidence score 0-100
|
||||
"""
|
||||
return self.calculate_confidence(
|
||||
data_quality_score=data_completeness,
|
||||
model_performance_score=calculation_accuracy,
|
||||
sample_size=sample_size,
|
||||
data_date=datetime.utcnow(),
|
||||
insight_type='optimization'
|
||||
)
|
||||
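The scorer above blends five weighted factor scores (defaults fill in any missing factor) and then applies a per-type adjustment. A minimal usage sketch, assuming the class is exposed as `ConfidenceScorer` (the class name and import path are not shown in this diff):

```python
# Illustrative only: the class name and module path are assumptions.
from datetime import datetime

from app.ml.confidence_scorer import ConfidenceScorer  # hypothetical import

scorer = ConfidenceScorer()

# Forecast insight: 12% MAPE, 14-day horizon, 180 historical points,
# training data current as of today.
confidence = scorer.calculate_forecast_confidence(
    model_mape=12.0,
    forecast_horizon_days=14,
    data_points=180,
    last_data_date=datetime.utcnow(),
)
# Model score 0.88 * horizon factor ~0.53 -> ~47, sample size 180 -> 85,
# recency today -> 100, defaults elsewhere; the weighted sum is ~75, and the
# -5 'prediction' adjustment brings the result to roughly 70.
print(confidence)
```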
67
services/ai_insights/migrations/env.py
Normal file
67
services/ai_insights/migrations/env.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Alembic environment configuration."""
|
||||
|
||||
from logging.config import fileConfig
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
from alembic import context
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import Base
|
||||
from app.models import * # Import all models
|
||||
|
||||
# this is the Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Set sqlalchemy.url from settings
|
||||
# Replace asyncpg with psycopg2 for synchronous Alembic migrations
|
||||
db_url = settings.DATABASE_URL.replace('postgresql+asyncpg://', 'postgresql://')
|
||||
config.set_main_option('sqlalchemy.url', db_url)
|
||||
|
||||
# Add your model's MetaData object here for 'autogenerate' support
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
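For local runs the same environment can be driven programmatically instead of through the `alembic` CLI. A hedged sketch, assuming an `alembic.ini` sits next to the migrations directory and that `settings` picks up `DATABASE_URL` from the environment (pydantic-settings is in requirements.txt):

```python
import os

from alembic import command
from alembic.config import Config

# env.py reads settings.DATABASE_URL, so exporting the variable is enough;
# the value below is only an example.
os.environ.setdefault(
    "DATABASE_URL",
    "postgresql+asyncpg://postgres:postgres@localhost:5432/bakery_ai_insights",
)

command.upgrade(Config("alembic.ini"), "head")  # apply all pending revisions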
26
services/ai_insights/migrations/script.py.mako
Normal file
26
services/ai_insights/migrations/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,111 @@
|
||||
"""Initial schema for AI Insights Service
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2025-11-02 14:30:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '001'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create ai_insights table
|
||||
op.create_table(
|
||||
'ai_insights',
|
||||
sa.Column('id', UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column('tenant_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('type', sa.String(50), nullable=False),
|
||||
sa.Column('priority', sa.String(20), nullable=False),
|
||||
sa.Column('category', sa.String(50), nullable=False),
|
||||
sa.Column('title', sa.String(255), nullable=False),
|
||||
sa.Column('description', sa.Text, nullable=False),
|
||||
sa.Column('impact_type', sa.String(50)),
|
||||
sa.Column('impact_value', sa.DECIMAL(10, 2)),
|
||||
sa.Column('impact_unit', sa.String(20)),
|
||||
sa.Column('confidence', sa.Integer, nullable=False),
|
||||
sa.Column('metrics_json', JSONB),
|
||||
sa.Column('actionable', sa.Boolean, nullable=False, server_default='true'),
|
||||
sa.Column('recommendation_actions', JSONB),
|
||||
sa.Column('status', sa.String(20), nullable=False, server_default='new'),
|
||||
sa.Column('source_service', sa.String(50)),
|
||||
sa.Column('source_data_id', sa.String(100)),
|
||||
sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column('updated_at', sa.TIMESTAMP(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
|
||||
sa.Column('applied_at', sa.TIMESTAMP(timezone=True)),
|
||||
sa.Column('expired_at', sa.TIMESTAMP(timezone=True)),
|
||||
sa.CheckConstraint('confidence >= 0 AND confidence <= 100', name='check_confidence_range')
|
||||
)
|
||||
|
||||
# Create indexes for ai_insights
|
||||
op.create_index('idx_tenant_id', 'ai_insights', ['tenant_id'])
|
||||
op.create_index('idx_type', 'ai_insights', ['type'])
|
||||
op.create_index('idx_priority', 'ai_insights', ['priority'])
|
||||
op.create_index('idx_category', 'ai_insights', ['category'])
|
||||
op.create_index('idx_confidence', 'ai_insights', ['confidence'])
|
||||
op.create_index('idx_status', 'ai_insights', ['status'])
|
||||
op.create_index('idx_actionable', 'ai_insights', ['actionable'])
|
||||
op.create_index('idx_created_at', 'ai_insights', ['created_at'])
|
||||
op.create_index('idx_tenant_status_category', 'ai_insights', ['tenant_id', 'status', 'category'])
|
||||
op.create_index('idx_tenant_created_confidence', 'ai_insights', ['tenant_id', 'created_at', 'confidence'])
|
||||
op.create_index('idx_actionable_status', 'ai_insights', ['actionable', 'status'])
|
||||
|
||||
# Create insight_feedback table
|
||||
op.create_table(
|
||||
'insight_feedback',
|
||||
sa.Column('id', UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column('insight_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('action_taken', sa.String(100)),
|
||||
sa.Column('result_data', JSONB),
|
||||
sa.Column('success', sa.Boolean, nullable=False),
|
||||
sa.Column('error_message', sa.Text),
|
||||
sa.Column('expected_impact_value', sa.DECIMAL(10, 2)),
|
||||
sa.Column('actual_impact_value', sa.DECIMAL(10, 2)),
|
||||
sa.Column('variance_percentage', sa.DECIMAL(5, 2)),
|
||||
sa.Column('applied_by', sa.String(100)),
|
||||
sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['insight_id'], ['ai_insights.id'], ondelete='CASCADE')
|
||||
)
|
||||
|
||||
# Create indexes for insight_feedback
|
||||
op.create_index('idx_feedback_insight_id', 'insight_feedback', ['insight_id'])
|
||||
op.create_index('idx_feedback_success', 'insight_feedback', ['success'])
|
||||
op.create_index('idx_feedback_created_at', 'insight_feedback', ['created_at'])
|
||||
op.create_index('idx_insight_success', 'insight_feedback', ['insight_id', 'success'])
|
||||
op.create_index('idx_created_success', 'insight_feedback', ['created_at', 'success'])
|
||||
|
||||
# Create insight_correlations table
|
||||
op.create_table(
|
||||
'insight_correlations',
|
||||
sa.Column('id', UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column('parent_insight_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('child_insight_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('correlation_type', sa.String(50), nullable=False),
|
||||
sa.Column('correlation_strength', sa.DECIMAL(3, 2), nullable=False),
|
||||
sa.Column('combined_confidence', sa.Integer),
|
||||
sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['parent_insight_id'], ['ai_insights.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['child_insight_id'], ['ai_insights.id'], ondelete='CASCADE')
|
||||
)
|
||||
|
||||
# Create indexes for insight_correlations
|
||||
op.create_index('idx_corr_parent', 'insight_correlations', ['parent_insight_id'])
|
||||
op.create_index('idx_corr_child', 'insight_correlations', ['child_insight_id'])
|
||||
op.create_index('idx_corr_type', 'insight_correlations', ['correlation_type'])
|
||||
op.create_index('idx_corr_created_at', 'insight_correlations', ['created_at'])
|
||||
op.create_index('idx_parent_child', 'insight_correlations', ['parent_insight_id', 'child_insight_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('insight_correlations')
|
||||
op.drop_table('insight_feedback')
|
||||
op.drop_table('ai_insights')
|
||||
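To make the relationships concrete, here is a throwaway sketch that inserts one insight and one feedback row against the schema above (the connection string and values are placeholders; the check constraint rejects any confidence outside 0-100, and feedback rows cascade-delete with their parent insight):

```python
import uuid

from sqlalchemy import create_engine, text

engine = create_engine(
    "postgresql://postgres:postgres@localhost:5432/bakery_ai_insights"  # placeholder DSN
)

insight_id = str(uuid.uuid4())
with engine.begin() as conn:
    conn.execute(
        text(
            "INSERT INTO ai_insights "
            "(id, tenant_id, type, priority, category, title, description, confidence) "
            "VALUES (:id, :tenant, 'prediction', 'important', 'forecasting', "
            ":title, :descr, 78)"
        ),
        {
            "id": insight_id,
            "tenant": str(uuid.uuid4()),
            "title": "Demand spike expected",
            "descr": "Illustrative row only",
        },
    )
    conn.execute(
        text(
            "INSERT INTO insight_feedback (id, insight_id, success) "
            "VALUES (:id, :insight_id, true)"
        ),
        {"id": str(uuid.uuid4()), "insight_id": insight_id},
    )
```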
57
services/ai_insights/requirements.txt
Normal file
57
services/ai_insights/requirements.txt
Normal file
@@ -0,0 +1,57 @@
|
||||
# FastAPI and ASGI
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6

# Database
sqlalchemy==2.0.23
alembic==1.12.1
psycopg2-binary==2.9.9
asyncpg==0.29.0

# Pydantic
pydantic==2.5.0
pydantic-settings==2.1.0

# HTTP Client
httpx==0.25.1
aiohttp==3.9.1

# Redis
redis==5.0.1
hiredis==2.2.3

# Utilities
python-dotenv==1.0.0
python-dateutil==2.8.2
pytz==2023.3

# Logging
structlog==23.2.0

# Monitoring and Observability
psutil==5.9.8
opentelemetry-api==1.39.1
opentelemetry-sdk==1.39.1
opentelemetry-instrumentation-fastapi==0.60b1
opentelemetry-exporter-otlp-proto-grpc==1.39.1
opentelemetry-exporter-otlp-proto-http==1.39.1
opentelemetry-instrumentation-httpx==0.60b1
opentelemetry-instrumentation-redis==0.60b1
opentelemetry-instrumentation-sqlalchemy==0.60b1

# Machine Learning (for confidence scoring and impact estimation)
numpy==1.26.2
pandas==2.1.3
scikit-learn==1.3.2

# Testing
pytest==7.4.3
pytest-asyncio==0.21.1
pytest-cov==4.1.0
httpx==0.25.1

# Code Quality
black==23.11.0
flake8==6.1.0
mypy==1.7.1
579
services/ai_insights/tests/test_feedback_learning_system.py
Normal file
579
services/ai_insights/tests/test_feedback_learning_system.py
Normal file
@@ -0,0 +1,579 @@
|
||||
"""
|
||||
Tests for Feedback Loop & Learning System
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from services.ai_insights.app.ml.feedback_learning_system import FeedbackLearningSystem
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def learning_system():
|
||||
"""Create FeedbackLearningSystem instance."""
|
||||
return FeedbackLearningSystem(
|
||||
performance_threshold=0.85,
|
||||
degradation_threshold=0.10,
|
||||
min_feedback_samples=30
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def good_feedback_data():
|
||||
"""Generate feedback data for well-performing model."""
|
||||
np.random.seed(42)
|
||||
dates = pd.date_range(start=datetime.utcnow() - timedelta(days=60), periods=50, freq='D')
|
||||
|
||||
feedback = []
|
||||
for i, date in enumerate(dates):
|
||||
predicted = 100 + np.random.normal(0, 10)
|
||||
actual = predicted + np.random.normal(0, 5) # Small error
|
||||
|
||||
error = predicted - actual
|
||||
error_pct = abs(error / actual * 100) if actual != 0 else 0
|
||||
accuracy = max(0, 100 - error_pct)
|
||||
|
||||
feedback.append({
|
||||
'insight_id': f'insight_{i}',
|
||||
'applied_at': date - timedelta(days=1),
|
||||
'outcome_date': date,
|
||||
'predicted_value': predicted,
|
||||
'actual_value': actual,
|
||||
'error': error,
|
||||
'error_pct': error_pct,
|
||||
'accuracy': accuracy,
|
||||
'confidence': 85
|
||||
})
|
||||
|
||||
return pd.DataFrame(feedback)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def degraded_feedback_data():
|
||||
"""Generate feedback data for degrading model."""
|
||||
np.random.seed(42)
|
||||
dates = pd.date_range(start=datetime.utcnow() - timedelta(days=60), periods=50, freq='D')
|
||||
|
||||
feedback = []
|
||||
for i, date in enumerate(dates):
|
||||
# Introduce increasing error over time
|
||||
error_multiplier = 1 + (i / 50) * 2 # Errors double by end
|
||||
|
||||
predicted = 100 + np.random.normal(0, 10)
|
||||
actual = predicted + np.random.normal(0, 10 * error_multiplier)
|
||||
|
||||
error = predicted - actual
|
||||
error_pct = abs(error / actual * 100) if actual != 0 else 0
|
||||
accuracy = max(0, 100 - error_pct)
|
||||
|
||||
feedback.append({
|
||||
'insight_id': f'insight_{i}',
|
||||
'applied_at': date - timedelta(days=1),
|
||||
'outcome_date': date,
|
||||
'predicted_value': predicted,
|
||||
'actual_value': actual,
|
||||
'error': error,
|
||||
'error_pct': error_pct,
|
||||
'accuracy': accuracy,
|
||||
'confidence': 85
|
||||
})
|
||||
|
||||
return pd.DataFrame(feedback)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def biased_feedback_data():
|
||||
"""Generate feedback data with systematic bias."""
|
||||
np.random.seed(42)
|
||||
dates = pd.date_range(start=datetime.utcnow() - timedelta(days=60), periods=50, freq='D')
|
||||
|
||||
feedback = []
|
||||
for i, date in enumerate(dates):
|
||||
predicted = 100 + np.random.normal(0, 10)
|
||||
# Systematic over-prediction by 15%
|
||||
actual = predicted * 0.85 + np.random.normal(0, 3)
|
||||
|
||||
error = predicted - actual
|
||||
error_pct = abs(error / actual * 100) if actual != 0 else 0
|
||||
accuracy = max(0, 100 - error_pct)
|
||||
|
||||
feedback.append({
|
||||
'insight_id': f'insight_{i}',
|
||||
'applied_at': date - timedelta(days=1),
|
||||
'outcome_date': date,
|
||||
'predicted_value': predicted,
|
||||
'actual_value': actual,
|
||||
'error': error,
|
||||
'error_pct': error_pct,
|
||||
'accuracy': accuracy,
|
||||
'confidence': 80
|
||||
})
|
||||
|
||||
return pd.DataFrame(feedback)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def poorly_calibrated_feedback_data():
|
||||
"""Generate feedback with poor confidence calibration."""
|
||||
np.random.seed(42)
|
||||
dates = pd.date_range(start=datetime.utcnow() - timedelta(days=60), periods=50, freq='D')
|
||||
|
||||
feedback = []
|
||||
for i, date in enumerate(dates):
|
||||
predicted = 100 + np.random.normal(0, 10)
|
||||
|
||||
# High confidence but low accuracy
|
||||
if i < 25:
|
||||
confidence = 90
|
||||
actual = predicted + np.random.normal(0, 20) # Large error
|
||||
else:
|
||||
confidence = 60
|
||||
actual = predicted + np.random.normal(0, 5) # Small error
|
||||
|
||||
error = predicted - actual
|
||||
error_pct = abs(error / actual * 100) if actual != 0 else 0
|
||||
accuracy = max(0, 100 - error_pct)
|
||||
|
||||
feedback.append({
|
||||
'insight_id': f'insight_{i}',
|
||||
'applied_at': date - timedelta(days=1),
|
||||
'outcome_date': date,
|
||||
'predicted_value': predicted,
|
||||
'actual_value': actual,
|
||||
'error': error,
|
||||
'error_pct': error_pct,
|
||||
'accuracy': accuracy,
|
||||
'confidence': confidence
|
||||
})
|
||||
|
||||
return pd.DataFrame(feedback)
|
||||
|
||||
|
||||
class TestPerformanceMetrics:
|
||||
"""Test performance metric calculation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_metrics_good_performance(self, learning_system, good_feedback_data):
|
||||
"""Test metric calculation for good performance."""
|
||||
metrics = learning_system._calculate_performance_metrics(good_feedback_data)
|
||||
|
||||
assert 'accuracy' in metrics
|
||||
assert 'mae' in metrics
|
||||
assert 'rmse' in metrics
|
||||
assert 'mape' in metrics
|
||||
assert 'bias' in metrics
|
||||
assert 'r_squared' in metrics
|
||||
|
||||
# Good model should have high accuracy
|
||||
assert metrics['accuracy'] > 80
|
||||
assert metrics['mae'] < 10
|
||||
assert abs(metrics['bias']) < 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_metrics_degraded_performance(self, learning_system, degraded_feedback_data):
|
||||
"""Test metric calculation for degraded performance."""
|
||||
metrics = learning_system._calculate_performance_metrics(degraded_feedback_data)
|
||||
|
||||
# Degraded model should have lower accuracy
|
||||
assert metrics['accuracy'] < 80
|
||||
assert metrics['mae'] > 5
|
||||
|
||||
|
||||
class TestPerformanceTrend:
|
||||
"""Test performance trend analysis."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stable_trend(self, learning_system, good_feedback_data):
|
||||
"""Test detection of stable performance trend."""
|
||||
trend = learning_system._analyze_performance_trend(good_feedback_data)
|
||||
|
||||
assert trend['trend'] in ['stable', 'improving']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_degrading_trend(self, learning_system, degraded_feedback_data):
|
||||
"""Test detection of degrading performance trend."""
|
||||
trend = learning_system._analyze_performance_trend(degraded_feedback_data)
|
||||
|
||||
# May detect degrading trend depending on data
|
||||
assert trend['trend'] in ['degrading', 'stable']
|
||||
if trend['significant']:
|
||||
assert 'slope' in trend
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insufficient_data_trend(self, learning_system):
|
||||
"""Test trend analysis with insufficient data."""
|
||||
small_data = pd.DataFrame([{
|
||||
'insight_id': 'test',
|
||||
'outcome_date': datetime.utcnow(),
|
||||
'accuracy': 90
|
||||
}])
|
||||
|
||||
trend = learning_system._analyze_performance_trend(small_data)
|
||||
assert trend['trend'] == 'insufficient_data'
|
||||
|
||||
|
||||
class TestDegradationDetection:
|
||||
"""Test performance degradation detection."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_degradation_detected(self, learning_system, good_feedback_data):
|
||||
"""Test no degradation for good performance."""
|
||||
current_metrics = learning_system._calculate_performance_metrics(good_feedback_data)
|
||||
trend = learning_system._analyze_performance_trend(good_feedback_data)
|
||||
|
||||
degradation = learning_system._detect_performance_degradation(
|
||||
current_metrics,
|
||||
baseline_performance={'accuracy': 85},
|
||||
trend_analysis=trend
|
||||
)
|
||||
|
||||
assert degradation['detected'] is False
|
||||
assert degradation['severity'] == 'none'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_degradation_below_threshold(self, learning_system):
|
||||
"""Test degradation detection when below absolute threshold."""
|
||||
current_metrics = {'accuracy': 70} # Below 85% threshold
|
||||
trend = {'trend': 'stable', 'significant': False}
|
||||
|
||||
degradation = learning_system._detect_performance_degradation(
|
||||
current_metrics,
|
||||
baseline_performance=None,
|
||||
trend_analysis=trend
|
||||
)
|
||||
|
||||
assert degradation['detected'] is True
|
||||
assert degradation['severity'] == 'high'
|
||||
assert len(degradation['reasons']) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_degradation_vs_baseline(self, learning_system):
|
||||
"""Test degradation detection vs baseline."""
|
||||
current_metrics = {'accuracy': 80}
|
||||
baseline = {'accuracy': 95} # 15.8% drop
|
||||
trend = {'trend': 'stable', 'significant': False}
|
||||
|
||||
degradation = learning_system._detect_performance_degradation(
|
||||
current_metrics,
|
||||
baseline_performance=baseline,
|
||||
trend_analysis=trend
|
||||
)
|
||||
|
||||
assert degradation['detected'] is True
|
||||
assert 'dropped' in degradation['reasons'][0].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_degradation_trending_down(self, learning_system, degraded_feedback_data):
|
||||
"""Test degradation detection from trending down."""
|
||||
current_metrics = learning_system._calculate_performance_metrics(degraded_feedback_data)
|
||||
trend = learning_system._analyze_performance_trend(degraded_feedback_data)
|
||||
|
||||
degradation = learning_system._detect_performance_degradation(
|
||||
current_metrics,
|
||||
baseline_performance={'accuracy': 90},
|
||||
trend_analysis=trend
|
||||
)
|
||||
|
||||
# Should detect some form of degradation
|
||||
assert degradation['detected'] is True
|
||||
|
||||
|
||||
class TestRetrainingRecommendation:
|
||||
"""Test retraining recommendation generation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_urgent_retraining_recommendation(self, learning_system):
|
||||
"""Test urgent retraining recommendation."""
|
||||
current_metrics = {'accuracy': 70}
|
||||
degradation = {
|
||||
'detected': True,
|
||||
'severity': 'high',
|
||||
'reasons': ['Accuracy below threshold'],
|
||||
'current_accuracy': 70,
|
||||
'baseline_accuracy': 90
|
||||
}
|
||||
trend = {'trend': 'degrading', 'significant': True}
|
||||
|
||||
recommendation = learning_system._generate_retraining_recommendation(
|
||||
'test_model',
|
||||
current_metrics,
|
||||
degradation,
|
||||
trend
|
||||
)
|
||||
|
||||
assert recommendation['recommended'] is True
|
||||
assert recommendation['priority'] == 'urgent'
|
||||
assert 'immediately' in recommendation['recommendation'].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_retraining_needed(self, learning_system, good_feedback_data):
|
||||
"""Test no retraining recommendation for good performance."""
|
||||
current_metrics = learning_system._calculate_performance_metrics(good_feedback_data)
|
||||
degradation = {'detected': False, 'severity': 'none'}
|
||||
trend = learning_system._analyze_performance_trend(good_feedback_data)
|
||||
|
||||
recommendation = learning_system._generate_retraining_recommendation(
|
||||
'test_model',
|
||||
current_metrics,
|
||||
degradation,
|
||||
trend
|
||||
)
|
||||
|
||||
assert recommendation['recommended'] is False
|
||||
assert recommendation['priority'] == 'none'
|
||||
|
||||
|
||||
class TestErrorPatternDetection:
|
||||
"""Test error pattern identification."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_systematic_bias_detection(self, learning_system, biased_feedback_data):
|
||||
"""Test detection of systematic bias."""
|
||||
patterns = learning_system._identify_error_patterns(biased_feedback_data)
|
||||
|
||||
# Should detect over-prediction bias
|
||||
bias_patterns = [p for p in patterns if p['pattern'] == 'systematic_bias']
|
||||
assert len(bias_patterns) > 0
|
||||
|
||||
bias = bias_patterns[0]
|
||||
assert 'over-prediction' in bias['description']
|
||||
assert bias['severity'] in ['high', 'medium']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_patterns_for_good_data(self, learning_system, good_feedback_data):
|
||||
"""Test no significant patterns for good data."""
|
||||
patterns = learning_system._identify_error_patterns(good_feedback_data)
|
||||
|
||||
# May have some minor patterns, but no high severity
|
||||
high_severity = [p for p in patterns if p.get('severity') == 'high']
|
||||
assert len(high_severity) == 0
|
||||
|
||||
|
||||
class TestConfidenceCalibration:
|
||||
"""Test confidence calibration analysis."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_well_calibrated_confidence(self, learning_system, good_feedback_data):
|
||||
"""Test well-calibrated confidence scores."""
|
||||
calibration = learning_system._calculate_confidence_calibration(good_feedback_data)
|
||||
|
||||
# Good data with consistent confidence should be well calibrated
|
||||
if 'overall_calibration_error' in calibration:
|
||||
# Small calibration error indicates good calibration
|
||||
assert calibration['overall_calibration_error'] < 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poorly_calibrated_confidence(self, learning_system, poorly_calibrated_feedback_data):
|
||||
"""Test poorly calibrated confidence scores."""
|
||||
calibration = learning_system._calculate_confidence_calibration(poorly_calibrated_feedback_data)
|
||||
|
||||
# Should detect poor calibration
|
||||
assert calibration['calibrated'] is False
|
||||
if 'by_confidence_range' in calibration:
|
||||
assert len(calibration['by_confidence_range']) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_confidence_data(self, learning_system):
|
||||
"""Test calibration when no confidence scores available."""
|
||||
no_conf_data = pd.DataFrame([{
|
||||
'predicted_value': 100,
|
||||
'actual_value': 95,
|
||||
'accuracy': 95
|
||||
}])
|
||||
|
||||
calibration = learning_system._calculate_confidence_calibration(no_conf_data)
|
||||
assert calibration['calibrated'] is False
|
||||
assert 'reason' in calibration
|
||||
|
||||
|
||||
class TestCompletePerformanceAnalysis:
|
||||
"""Test complete performance analysis workflow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_good_performance(self, learning_system, good_feedback_data):
|
||||
"""Test complete analysis of good performance."""
|
||||
result = await learning_system.analyze_model_performance(
|
||||
model_name='test_model',
|
||||
feedback_data=good_feedback_data,
|
||||
baseline_performance={'accuracy': 85}
|
||||
)
|
||||
|
||||
assert result['model_name'] == 'test_model'
|
||||
assert result['status'] != 'insufficient_feedback'
|
||||
assert 'current_performance' in result
|
||||
assert 'trend_analysis' in result
|
||||
assert 'degradation_detected' in result
|
||||
assert 'retraining_recommendation' in result
|
||||
|
||||
# Good performance should not recommend retraining
|
||||
assert result['retraining_recommendation']['recommended'] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_degraded_performance(self, learning_system, degraded_feedback_data):
|
||||
"""Test complete analysis of degraded performance."""
|
||||
result = await learning_system.analyze_model_performance(
|
||||
model_name='degraded_model',
|
||||
feedback_data=degraded_feedback_data,
|
||||
baseline_performance={'accuracy': 90}
|
||||
)
|
||||
|
||||
assert result['degradation_detected']['detected'] is True
|
||||
assert result['retraining_recommendation']['recommended'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insufficient_feedback(self, learning_system):
|
||||
"""Test analysis with insufficient feedback samples."""
|
||||
small_data = pd.DataFrame([{
|
||||
'insight_id': 'test',
|
||||
'outcome_date': datetime.utcnow(),
|
||||
'predicted_value': 100,
|
||||
'actual_value': 95,
|
||||
'error': 5,
|
||||
'error_pct': 5,
|
||||
'accuracy': 95,
|
||||
'confidence': 85
|
||||
}])
|
||||
|
||||
result = await learning_system.analyze_model_performance(
|
||||
model_name='test_model',
|
||||
feedback_data=small_data
|
||||
)
|
||||
|
||||
assert result['status'] == 'insufficient_feedback'
|
||||
assert result['feedback_samples'] == 1
|
||||
assert result['required_samples'] == 30
|
||||
|
||||
|
||||
class TestLearningInsights:
|
||||
"""Test learning insight generation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_urgent_retraining_insight(self, learning_system):
|
||||
"""Test generation of urgent retraining insight."""
|
||||
analyses = [{
|
||||
'model_name': 'urgent_model',
|
||||
'retraining_recommendation': {
|
||||
'priority': 'urgent',
|
||||
'recommended': True
|
||||
},
|
||||
'degradation_detected': {
|
||||
'detected': True
|
||||
}
|
||||
}]
|
||||
|
||||
insights = await learning_system.generate_learning_insights(
|
||||
analyses,
|
||||
tenant_id='tenant_123'
|
||||
)
|
||||
|
||||
# Should generate urgent warning
|
||||
urgent_insights = [i for i in insights if i['priority'] == 'urgent']
|
||||
assert len(urgent_insights) > 0
|
||||
|
||||
insight = urgent_insights[0]
|
||||
assert insight['type'] == 'warning'
|
||||
assert 'urgent_model' in insight['description'].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_system_health_insight(self, learning_system):
|
||||
"""Test generation of system health insight."""
|
||||
# 3 models, 1 degraded
|
||||
analyses = [
|
||||
{
|
||||
'model_name': 'model_1',
|
||||
'degradation_detected': {'detected': False},
|
||||
'retraining_recommendation': {'priority': 'none'}
|
||||
},
|
||||
{
|
||||
'model_name': 'model_2',
|
||||
'degradation_detected': {'detected': False},
|
||||
'retraining_recommendation': {'priority': 'none'}
|
||||
},
|
||||
{
|
||||
'model_name': 'model_3',
|
||||
'degradation_detected': {'detected': True},
|
||||
'retraining_recommendation': {'priority': 'high'}
|
||||
}
|
||||
]
|
||||
|
||||
insights = await learning_system.generate_learning_insights(
|
||||
analyses,
|
||||
tenant_id='tenant_123'
|
||||
)
|
||||
|
||||
# Should generate system health insight (66% healthy < 80%)
|
||||
# Note: May or may not trigger depending on threshold
|
||||
# At minimum should not crash
|
||||
assert isinstance(insights, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_calibration_insight(self, learning_system):
|
||||
"""Test generation of calibration insight."""
|
||||
analyses = [{
|
||||
'model_name': 'model_1',
|
||||
'degradation_detected': {'detected': False},
|
||||
'retraining_recommendation': {'priority': 'none'},
|
||||
'confidence_calibration': {
|
||||
'calibrated': False,
|
||||
'overall_calibration_error': 15
|
||||
}
|
||||
}]
|
||||
|
||||
insights = await learning_system.generate_learning_insights(
|
||||
analyses,
|
||||
tenant_id='tenant_123'
|
||||
)
|
||||
|
||||
# Should generate calibration insight
|
||||
calibration_insights = [
|
||||
i for i in insights
|
||||
if 'calibration' in i['title'].lower()
|
||||
]
|
||||
assert len(calibration_insights) > 0
|
||||
|
||||
|
||||
class TestROICalculation:
|
||||
"""Test ROI calculation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_roi_with_impact_values(self, learning_system):
|
||||
"""Test ROI calculation with impact values."""
|
||||
feedback_data = pd.DataFrame([
|
||||
{
|
||||
'accuracy': 90,
|
||||
'impact_value': 1000
|
||||
},
|
||||
{
|
||||
'accuracy': 85,
|
||||
'impact_value': 1500
|
||||
},
|
||||
{
|
||||
'accuracy': 95,
|
||||
'impact_value': 800
|
||||
}
|
||||
])
|
||||
|
||||
roi = await learning_system.calculate_roi(
|
||||
feedback_data,
|
||||
insight_type='demand_forecast'
|
||||
)
|
||||
|
||||
assert roi['insight_type'] == 'demand_forecast'
|
||||
assert roi['samples'] == 3
|
||||
assert roi['avg_accuracy'] == 90.0
|
||||
assert roi['total_impact_value'] == 3300
|
||||
assert roi['roi_validated'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_roi_without_impact_values(self, learning_system, good_feedback_data):
|
||||
"""Test ROI calculation without impact values."""
|
||||
roi = await learning_system.calculate_roi(
|
||||
good_feedback_data,
|
||||
insight_type='yield_prediction'
|
||||
)
|
||||
|
||||
assert roi['insight_type'] == 'yield_prediction'
|
||||
assert roi['samples'] > 0
|
||||
assert 'avg_accuracy' in roi
|
||||
assert roi['roi_validated'] is False
|
||||
59
services/alert_processor/Dockerfile
Normal file
59
services/alert_processor/Dockerfile
Normal file
@@ -0,0 +1,59 @@
|
||||
# =============================================================================
|
||||
# Alert Processor Service Dockerfile - Environment-Configurable Base Images
|
||||
# =============================================================================
|
||||
# Build arguments for registry configuration:
|
||||
# - BASE_REGISTRY: Registry URL (default: docker.io for Docker Hub)
|
||||
# - PYTHON_IMAGE: Python image name and tag (default: python:3.11-slim)
|
||||
# =============================================================================
|
||||
|
||||
ARG BASE_REGISTRY=docker.io
|
||||
ARG PYTHON_IMAGE=python:3.11-slim
|
||||
|
||||
FROM ${BASE_REGISTRY}/${PYTHON_IMAGE} AS shared
|
||||
WORKDIR /shared
|
||||
COPY shared/ /shared/
|
||||
|
||||
ARG BASE_REGISTRY=docker.io
|
||||
ARG PYTHON_IMAGE=python:3.11-slim
|
||||
FROM ${BASE_REGISTRY}/${PYTHON_IMAGE}
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
g++ \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements
|
||||
COPY shared/requirements-tracing.txt /tmp/
|
||||
|
||||
COPY services/alert_processor/requirements.txt .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r /tmp/requirements-tracing.txt
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy shared libraries from the shared stage
|
||||
COPY --from=shared /shared /app/shared
|
||||
|
||||
# Copy application code
|
||||
COPY services/alert_processor/ .
|
||||
|
||||
|
||||
|
||||
# Add shared libraries to Python path
|
||||
ENV PYTHONPATH="/app:/app/shared:${PYTHONPATH:-}"
|
||||
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
# Run application
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
373
services/alert_processor/README.md
Normal file
373
services/alert_processor/README.md
Normal file
@@ -0,0 +1,373 @@
|
||||
# Alert Processor Service v2.0
|
||||
|
||||
Clean, well-structured event processing and alert management system with sophisticated enrichment pipeline.
|
||||
|
||||
## Overview
|
||||
|
||||
The Alert Processor Service receives **minimal events** from other services (inventory, production, procurement, etc.) and enriches them with:
|
||||
|
||||
- **i18n message generation** - Parameterized titles and messages for frontend
|
||||
- **Multi-factor priority scoring** - Business impact (40%), Urgency (30%), User agency (20%), Confidence (10%)
|
||||
- **Business impact analysis** - Financial impact, affected orders, customer impact
|
||||
- **Urgency assessment** - Time until consequence, deadlines, escalation
|
||||
- **User agency analysis** - Can user fix? External dependencies? Blockers?
|
||||
- **AI orchestrator context** - Query for AI actions already taken
|
||||
- **Smart action generation** - Contextual buttons with deep links
|
||||
- **Entity linking** - References to related entities (POs, batches, orders)
|
||||
|
||||
## Architecture
|
||||
|
||||
```
Services → RabbitMQ → [Alert Processor] → PostgreSQL
                            ↓
                  Notification Service
                            ↓
                   Redis (SSE Pub/Sub)
                            ↓
                        Frontend
```
|
||||
|
||||
### Enrichment Pipeline
|
||||
|
||||
1. **Duplicate Detection**: Checks for duplicate alerts within 24-hour window
|
||||
2. **Message Generator**: Creates i18n keys and parameters from metadata
|
||||
3. **Orchestrator Client**: Queries AI orchestrator for context
|
||||
4. **AI Reasoning Extractor**: Extracts AI reasoning details and confidence scores
|
||||
5. **Business Impact Analyzer**: Calculates financial and operational impact
|
||||
6. **Urgency Analyzer**: Assesses time sensitivity and deadlines
|
||||
7. **User Agency Analyzer**: Determines user's ability to act
|
||||
8. **Priority Scorer**: Calculates weighted priority score (0-100)
|
||||
9. **Type Classifier**: Determines if action needed or issue prevented
|
||||
10. **Smart Action Generator**: Creates contextual action buttons
|
||||
11. **Entity Link Extractor**: Maps metadata to entity references
|
||||
|
||||
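A toy illustration of the chaining idea behind the eleven steps above (not the project's actual classes; the real coordinator is `app/services/enrichment_orchestrator.py`): each step takes and returns the event dict, and the orchestrator folds the event through the ordered steps.

```python
from typing import Callable

Step = Callable[[dict], dict]

def add_i18n(event: dict) -> dict:
    # Step 2: derive i18n keys/params from the raw metadata
    meta = event["metadata"]
    event["i18n"] = {
        "title_key": f"alerts.{event['event_type']}.title",
        "title_params": {"ingredient_name": meta.get("ingredient_name")},
    }
    return event

def classify_type(event: dict) -> dict:
    # Step 9: AI-handled events become "prevented_issue", the rest need action.
    # The "actions_taken" field name is an assumption for illustration.
    ai_handled = bool(event.get("orchestrator_context", {}).get("actions_taken"))
    event["type_class"] = "prevented_issue" if ai_handled else "action_needed"
    return event

PIPELINE: list[Step] = [add_i18n, classify_type]  # the real pipeline has 11 steps

def enrich(event: dict) -> dict:
    for step in PIPELINE:
        event = step(event)
    return event

enrich({"event_type": "critical_stock_shortage",
        "metadata": {"ingredient_name": "Flour"}})
```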
## Service Structure
|
||||
|
||||
```
|
||||
alert_processor_v2/
|
||||
├── app/
|
||||
│ ├── main.py # FastAPI app + lifecycle
|
||||
│ ├── core/
|
||||
│ │ ├── config.py # Settings
|
||||
│ │ └── database.py # Database session management
|
||||
│ ├── models/
|
||||
│ │ └── events.py # SQLAlchemy Event model
|
||||
│ ├── schemas/
|
||||
│ │ └── events.py # Pydantic schemas
|
||||
│ ├── api/
|
||||
│ │ ├── alerts.py # Alert endpoints
|
||||
│ │ └── sse.py # SSE streaming
|
||||
│ ├── consumer/
|
||||
│ │ └── event_consumer.py # RabbitMQ consumer
|
||||
│ ├── enrichment/
|
||||
│ │ ├── message_generator.py # i18n generation
|
||||
│ │ ├── priority_scorer.py # Priority calculation
|
||||
│ │ ├── orchestrator_client.py # AI context
|
||||
│ │ ├── smart_actions.py # Action buttons
|
||||
│ │ ├── business_impact.py # Impact analysis
|
||||
│ │ ├── urgency_analyzer.py # Urgency assessment
|
||||
│ │ └── user_agency.py # Agency analysis
|
||||
│ ├── repositories/
|
||||
│ │ └── event_repository.py # Database queries
|
||||
│ ├── services/
|
||||
│ │ ├── enrichment_orchestrator.py # Pipeline coordinator
|
||||
│ │ └── sse_service.py # SSE pub/sub
|
||||
│ └── utils/
|
||||
│ └── message_templates.py # Alert type mappings
|
||||
├── migrations/
|
||||
│ └── versions/
|
||||
│ └── 20251205_clean_unified_schema.py
|
||||
└── requirements.txt
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
```bash
|
||||
# Service
|
||||
SERVICE_NAME=alert-processor
|
||||
VERSION=2.0.0
|
||||
DEBUG=false
|
||||
|
||||
# Database
|
||||
DATABASE_URL=postgresql+asyncpg://user:pass@localhost/db
|
||||
|
||||
# RabbitMQ
|
||||
RABBITMQ_URL=amqp://guest:guest@localhost/
|
||||
RABBITMQ_EXCHANGE=events.exchange
|
||||
RABBITMQ_QUEUE=alert_processor.queue
|
||||
|
||||
# Redis
|
||||
REDIS_URL=redis://localhost:6379/0
|
||||
REDIS_SSE_PREFIX=alerts
|
||||
|
||||
# Orchestrator Service
|
||||
ORCHESTRATOR_URL=http://orchestrator:8000
|
||||
ORCHESTRATOR_TIMEOUT=10
|
||||
|
||||
# Notification Service
|
||||
NOTIFICATION_URL=http://notification:8000
|
||||
NOTIFICATION_TIMEOUT=5
|
||||
|
||||
# Cache
|
||||
CACHE_ENABLED=true
|
||||
CACHE_TTL_SECONDS=300
|
||||
```
|
||||
|
||||
## Running the Service
|
||||
|
||||
### Local Development
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Run database migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Start service
|
||||
python -m app.main
|
||||
# or
|
||||
uvicorn app.main:app --reload
|
||||
```
|
||||
|
||||
### Docker
|
||||
|
||||
```bash
|
||||
docker build -t alert-processor:2.0 .
|
||||
docker run -p 8000:8000 --env-file .env alert-processor:2.0
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Alert Management
|
||||
|
||||
- `GET /api/v1/tenants/{tenant_id}/alerts` - List alerts with filters
|
||||
- `GET /api/v1/tenants/{tenant_id}/alerts/summary` - Get dashboard summary
|
||||
- `GET /api/v1/tenants/{tenant_id}/alerts/{alert_id}` - Get single alert
|
||||
- `POST /api/v1/tenants/{tenant_id}/alerts/{alert_id}/acknowledge` - Acknowledge alert
|
||||
- `POST /api/v1/tenants/{tenant_id}/alerts/{alert_id}/resolve` - Resolve alert
|
||||
- `POST /api/v1/tenants/{tenant_id}/alerts/{alert_id}/dismiss` - Dismiss alert
|
||||
|
||||
### Real-Time Streaming
|
||||
|
||||
- `GET /api/v1/sse/alerts/{tenant_id}` - SSE stream for real-time alerts
|
||||
|
||||
### Health Check
|
||||
|
||||
- `GET /health` - Service health status
|
||||
|
||||
## Event Flow
|
||||
|
||||
### 1. Service Emits Minimal Event
|
||||
|
||||
```python
|
||||
from shared.messaging.event_publisher import EventPublisher
|
||||
|
||||
await publisher.publish_alert(
|
||||
tenant_id=tenant_id,
|
||||
event_type="critical_stock_shortage",
|
||||
event_domain="inventory",
|
||||
severity="urgent",
|
||||
metadata={
|
||||
"ingredient_id": "...",
|
||||
"ingredient_name": "Flour",
|
||||
"current_stock": 10.5,
|
||||
"required_stock": 50.0,
|
||||
"shortage_amount": 39.5
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Alert Processor Enriches
|
||||
|
||||
- **Checks for duplicates**: Searches 24-hour window for similar alerts
|
||||
- Generates i18n: `alerts.critical_stock_shortage.title` with params
|
||||
- Queries orchestrator for AI context
|
||||
- Extracts AI reasoning and confidence scores (if available)
|
||||
- Analyzes business impact: €197.50 financial impact
|
||||
- Assesses urgency: 12 hours until consequence
|
||||
- Determines user agency: Can create PO, requires supplier
|
||||
- Calculates priority: Score 78 → "important"
|
||||
- Classifies type: `action_needed` or `prevented_issue`
|
||||
- Generates smart actions: [Create PO, Call Supplier, Dismiss]
|
||||
- Extracts entity links: `{ingredient: "..."}`
|
||||
|
||||
### 3. Stores Enriched Event
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "...",
|
||||
"event_type": "critical_stock_shortage",
|
||||
"event_domain": "inventory",
|
||||
"severity": "urgent",
|
||||
"type_class": "action_needed",
|
||||
"priority_score": 78,
|
||||
"priority_level": "important",
|
||||
"confidence_score": 95,
|
||||
"i18n": {
|
||||
"title_key": "alerts.critical_stock_shortage.title",
|
||||
"title_params": {"ingredient_name": "Flour"},
|
||||
"message_key": "alerts.critical_stock_shortage.message_generic",
|
||||
"message_params": {
|
||||
"current_stock_kg": 10.5,
|
||||
"required_stock_kg": 50.0
|
||||
}
|
||||
},
|
||||
"business_impact": {...},
|
||||
"urgency": {...},
|
||||
"user_agency": {...},
|
||||
"ai_reasoning_details": {...},
|
||||
"orchestrator_context": {...},
|
||||
"smart_actions": [...],
|
||||
"entity_links": {"ingredient": "..."}
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Sends Notification
|
||||
|
||||
Calls notification service with event details for delivery via WhatsApp, Email, Push, etc.
|
||||
|
||||
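A hedged sketch of that call using `httpx`; the endpoint path and payload fields are assumptions, only the notification service base URL and priority field come from this README.

```python
import httpx

async def send_notification(event: dict) -> None:
    # Endpoint path and payload shape are illustrative assumptions.
    async with httpx.AsyncClient(timeout=5) as client:
        response = await client.post(
            "http://notification:8000/api/v1/notifications",
            json={
                "tenant_id": event["tenant_id"],
                "priority": event["priority_level"],
                "title_key": event["i18n"]["title_key"],
                "title_params": event["i18n"]["title_params"],
            },
        )
        response.raise_for_status()
```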
### 5. Publishes to SSE
|
||||
|
||||
Publishes to Redis channel `alerts:{tenant_id}` for real-time frontend updates.
|
||||
|
||||
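A hedged sketch of the publish side (the real implementation lives in `app/services/sse_service.py`; only the channel naming convention is taken from this document):

```python
import json

import redis.asyncio as redis

async def publish_alert(tenant_id: str, event: dict) -> None:
    client = redis.from_url("redis://localhost:6379/0")  # REDIS_URL in production
    try:
        # Channel name follows the alerts:{tenant_id} convention
        await client.publish(f"alerts:{tenant_id}", json.dumps(event, default=str))
    finally:
        await client.close()
```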
## Priority Scoring Algorithm
|
||||
|
||||
**Formula**: `Total = (Impact × 0.4) + (Urgency × 0.3) + (Agency × 0.2) + (Confidence × 0.1)`
|
||||
|
||||
**Business Impact Score (0-100)**:
|
||||
- Financial impact > €1000: +30
|
||||
- Affected orders > 10: +15
|
||||
- High customer impact: +15
|
||||
- Production delay > 4h: +10
|
||||
- Revenue loss > €500: +10
|
||||
|
||||
**Urgency Score (0-100)**:
|
||||
- Time until consequence < 2h: +40
|
||||
- Deadline present: +5
|
||||
- Can't wait until tomorrow: +10
|
||||
- Peak hour relevant: +5
|
||||
|
||||
**User Agency Score (0-100)**:
|
||||
- User can fix: +30
|
||||
- Requires external party: -10
|
||||
- Has blockers: -5 per blocker
|
||||
- Has workaround: +5
|
||||
|
||||
**Escalation Boost** (up to +30):
|
||||
- Pending > 72h: +20
|
||||
- Deadline < 6h: +30
|
||||
|
||||
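A simplified sketch of the weighting above (the real logic is in `app/enrichment/priority_scorer.py`; the level thresholds shown are illustrative assumptions):

```python
def priority_score(impact: float, urgency: float, agency: float,
                   confidence: float, escalation_boost: float = 0.0) -> int:
    # Component scores are 0-100; escalation adds at most 30 points
    base = impact * 0.4 + urgency * 0.3 + agency * 0.2 + confidence * 0.1
    return int(min(100, base + min(escalation_boost, 30)))

def priority_level(score: int) -> str:
    # Threshold boundaries here are assumptions, not the service's constants
    if score >= 85:
        return "critical"
    if score >= 70:
        return "important"
    if score >= 40:
        return "standard"
    return "info"

# Mirrors the example earlier in this README: a score of 78 maps to "important"
print(priority_level(priority_score(impact=80, urgency=70, agency=80, confidence=90)))
```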
## Alert Types
|
||||
|
||||
See [app/utils/message_templates.py](app/utils/message_templates.py) for complete list.
|
||||
|
||||
### Standard Alerts
|
||||
- `critical_stock_shortage` - Urgent stock shortages
|
||||
- `low_stock_warning` - Stock running low
|
||||
- `production_delay` - Production behind schedule
|
||||
- `equipment_failure` - Equipment issues
|
||||
- `po_approval_needed` - Purchase order approval required
|
||||
- `temperature_breach` - Temperature control violations
|
||||
- `delivery_overdue` - Late deliveries
|
||||
- `expired_products` - Product expiration warnings
|
||||
|
||||
### AI Recommendations
|
||||
- `ai_yield_prediction` - AI-predicted production yields
|
||||
- `ai_safety_stock_optimization` - AI stock level recommendations
|
||||
- `ai_supplier_recommendation` - AI supplier suggestions
|
||||
- `ai_price_forecast` - AI price predictions
|
||||
- `ai_demand_forecast` - AI demand forecasts
|
||||
- `ai_business_rule` - AI-suggested business rules
|
||||
|
||||
## Database Schema
|
||||
|
||||
**events table** with JSONB enrichment:
|
||||
- Core: `id`, `tenant_id`, `created_at`, `event_type`, `event_domain`, `severity`
|
||||
- i18n: `i18n_title_key`, `i18n_title_params`, `i18n_message_key`, `i18n_message_params`
|
||||
- Priority: `priority_score` (0-100), `priority_level` (critical/important/standard/info)
|
||||
- Enrichment: `orchestrator_context`, `business_impact`, `urgency`, `user_agency` (JSONB)
|
||||
- AI Fields: `ai_reasoning_details`, `confidence_score`, `ai_reasoning_summary_key`, `ai_reasoning_summary_params`
|
||||
- Classification: `type_class` (action_needed/prevented_issue)
|
||||
- Actions: `smart_actions` (JSONB array)
|
||||
- Entities: `entity_links` (JSONB)
|
||||
- Status: `status` (active/acknowledged/resolved/dismissed)
|
||||
- Metadata: `raw_metadata` (JSONB)
|
||||
|
||||
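An illustrative read against that table with async SQLAlchemy; the DSN is a placeholder and the column names match the list above.

```python
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

async def top_critical_alerts(tenant_id: str) -> list[dict]:
    engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/db")
    try:
        async with engine.connect() as conn:
            rows = await conn.execute(
                text(
                    "SELECT id, event_type, priority_score, "
                    "       i18n_title_key, i18n_title_params "
                    "FROM events "
                    "WHERE tenant_id = :tenant_id "
                    "  AND status = 'active' AND priority_level = 'critical' "
                    "ORDER BY priority_score DESC LIMIT 10"
                ),
                {"tenant_id": tenant_id},
            )
            return [dict(row._mapping) for row in rows]
    finally:
        await engine.dispose()
```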
## Key Features
|
||||
|
||||
### Duplicate Alert Detection
|
||||
|
||||
The service automatically detects and prevents duplicate alerts:
|
||||
- **24-hour window**: Checks for similar alerts in the past 24 hours
|
||||
- **Smart matching**: Compares `tenant_id`, `event_type`, and key metadata fields
|
||||
- **Update strategy**: Updates existing alert instead of creating duplicates
|
||||
- **Metadata preservation**: Keeps enriched data while preventing alert fatigue
|
||||
|
||||
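A hedged sketch of that lookup (the production query lives in the event repository; `ingredient_id` is just one example of a key metadata field used for matching):

```python
from datetime import datetime, timedelta

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

async def find_duplicate(db: AsyncSession, tenant_id: str,
                         event_type: str, ingredient_id: str):
    cutoff = datetime.utcnow() - timedelta(hours=24)
    result = await db.execute(
        text(
            "SELECT id FROM events "
            "WHERE tenant_id = :tenant_id "
            "  AND event_type = :event_type "
            "  AND created_at >= :cutoff "
            "  AND raw_metadata->>'ingredient_id' = :ingredient_id "
            "LIMIT 1"
        ),
        {"tenant_id": tenant_id, "event_type": event_type,
         "cutoff": cutoff, "ingredient_id": ingredient_id},
    )
    return result.scalar_one_or_none()  # existing alert id, or None
```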
### Type Classification
|
||||
|
||||
Events are classified into two types:
|
||||
- **action_needed**: User action required (default for alerts)
|
||||
- **prevented_issue**: AI already handled the situation (for AI recommendations)
|
||||
|
||||
This helps the frontend display appropriate UI and messaging.
|
||||
|
||||
### AI Reasoning Integration
|
||||
|
||||
When AI orchestrator has acted on an event:
|
||||
- Extracts complete reasoning data structure
|
||||
- Stores confidence scores (0-100)
|
||||
- Generates i18n-friendly reasoning summaries
|
||||
- Links to orchestrator context for full details
|
||||
|
||||
### Notification Service Integration
|
||||
|
||||
Enriched events are automatically sent to the notification service for delivery via:
|
||||
- WhatsApp
|
||||
- Email
|
||||
- Push notifications
|
||||
- SMS
|
||||
|
||||
Priority mapping:
|
||||
- `critical` → urgent priority
|
||||
- `important` → high priority
|
||||
- `standard` → medium priority
|
||||
- `info` → low priority
|
||||
|
||||
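As a lookup table (illustrative; the constant's actual name and location are not shown in this commit):

```python
EVENT_TO_NOTIFICATION_PRIORITY = {
    "critical": "urgent",
    "important": "high",
    "standard": "medium",
    "info": "low",
}

notification_priority = EVENT_TO_NOTIFICATION_PRIORITY.get("important", "medium")  # "high"
```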
## Monitoring
|
||||
|
||||
Structured JSON logs with:
|
||||
- `enrichment_started` - Event received
|
||||
- `duplicate_detected` - Duplicate alert found and updated
|
||||
- `enrichment_completed` - Enrichment pipeline finished
|
||||
- `event_stored` - Saved to database
|
||||
- `notification_sent` - Notification queued
|
||||
- `sse_event_published` - Published to SSE stream
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
# Run tests
|
||||
pytest
|
||||
|
||||
# Test enrichment pipeline
|
||||
pytest tests/test_enrichment_orchestrator.py
|
||||
|
||||
# Test priority scoring
|
||||
pytest tests/test_priority_scorer.py
|
||||
|
||||
# Test message generation
|
||||
pytest tests/test_message_generator.py
|
||||
```
|
||||
|
||||
## Migration from v1
|
||||
|
||||
See [MIGRATION_GUIDE.md](MIGRATION_GUIDE.md) for migration steps from old alert_processor.
|
||||
|
||||
Key changes:
|
||||
- Services send minimal events (no hardcoded messages)
|
||||
- All enrichment moved to alert_processor
|
||||
- Unified Event table (no separate alert/notification tables)
|
||||
- i18n-first architecture
|
||||
- Sophisticated multi-factor priority scoring
|
||||
- Smart action generation
|
||||
84
services/alert_processor/alembic.ini
Normal file
84
services/alert_processor/alembic.ini
Normal file
@@ -0,0 +1,84 @@
|
||||
# ================================================================
|
||||
# services/alert_processor/alembic.ini - Alembic Configuration
|
||||
# ================================================================
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = migrations
|
||||
|
||||
# template used to generate migration file names
|
||||
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
timezone = Europe/Madrid
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
sourceless = false
|
||||
|
||||
# version of a migration file's filename format
|
||||
version_num_format = %%s
|
||||
|
||||
# version path separator
|
||||
version_path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
output_encoding = utf-8
|
||||
|
||||
# Database URL - will be overridden by environment variable or settings
|
||||
sqlalchemy.url = postgresql+asyncpg://alert_processor_user:password@alert-processor-db-service:5432/alert_processor_db
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts.
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
0
services/alert_processor/app/__init__.py
Normal file
0
services/alert_processor/app/__init__.py
Normal file
0
services/alert_processor/app/api/__init__.py
Normal file
0
services/alert_processor/app/api/__init__.py
Normal file
430
services/alert_processor/app/api/alerts.py
Normal file
430
services/alert_processor/app/api/alerts.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""
|
||||
Alert API endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import structlog
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.repositories.event_repository import EventRepository
|
||||
from app.schemas.events import EventResponse, EventSummary
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/alerts", response_model=List[EventResponse])
|
||||
async def get_alerts(
|
||||
tenant_id: UUID,
|
||||
event_class: Optional[str] = Query(None, description="Filter by event class"),
|
||||
priority_level: Optional[List[str]] = Query(None, description="Filter by priority levels"),
|
||||
status: Optional[List[str]] = Query(None, description="Filter by status values"),
|
||||
event_domain: Optional[str] = Query(None, description="Filter by domain"),
|
||||
limit: int = Query(50, le=100, description="Max results"),
|
||||
offset: int = Query(0, description="Pagination offset"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get filtered list of events.
|
||||
|
||||
Query Parameters:
|
||||
- event_class: alert, notification, recommendation
|
||||
- priority_level: critical, important, standard, info
|
||||
- status: active, acknowledged, resolved, dismissed
|
||||
- event_domain: inventory, production, supply_chain, etc.
|
||||
- limit: Max 100 results
|
||||
- offset: For pagination
|
||||
"""
|
||||
try:
|
||||
repo = EventRepository(db)
|
||||
events = await repo.get_events(
|
||||
tenant_id=tenant_id,
|
||||
event_class=event_class,
|
||||
priority_level=priority_level,
|
||||
status=status,
|
||||
event_domain=event_domain,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
# Convert to response models
|
||||
return [repo._event_to_response(event) for event in events]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("get_alerts_failed", error=str(e), tenant_id=str(tenant_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve alerts")
|
||||
|
||||
|
||||
@router.get("/alerts/summary", response_model=EventSummary)
|
||||
async def get_alerts_summary(
|
||||
tenant_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get summary statistics for dashboard.
|
||||
|
||||
Returns counts by:
|
||||
- Status (active, acknowledged, resolved)
|
||||
- Priority level (critical, important, standard, info)
|
||||
- Domain (inventory, production, etc.)
|
||||
- Type class (action_needed, prevented_issue, etc.)
|
||||
"""
|
||||
try:
|
||||
repo = EventRepository(db)
|
||||
summary = await repo.get_summary(tenant_id)
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error("get_summary_failed", error=str(e), tenant_id=str(tenant_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve summary")
|
||||
|
||||
|
||||
@router.get("/alerts/{alert_id}", response_model=EventResponse)
|
||||
async def get_alert(
|
||||
tenant_id: UUID,
|
||||
alert_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get single alert by ID"""
|
||||
try:
|
||||
repo = EventRepository(db)
|
||||
event = await repo.get_event_by_id(alert_id)
|
||||
|
||||
if not event:
|
||||
raise HTTPException(status_code=404, detail="Alert not found")
|
||||
|
||||
# Verify tenant ownership
|
||||
if event.tenant_id != tenant_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
return repo._event_to_response(event)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("get_alert_failed", error=str(e), alert_id=str(alert_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve alert")
|
||||
|
||||
|
||||
@router.post("/alerts/{alert_id}/acknowledge", response_model=EventResponse)
|
||||
async def acknowledge_alert(
|
||||
tenant_id: UUID,
|
||||
alert_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Mark alert as acknowledged.
|
||||
|
||||
Sets status to 'acknowledged' and records timestamp.
|
||||
"""
|
||||
try:
|
||||
repo = EventRepository(db)
|
||||
|
||||
# Verify ownership first
|
||||
event = await repo.get_event_by_id(alert_id)
|
||||
if not event:
|
||||
raise HTTPException(status_code=404, detail="Alert not found")
|
||||
if event.tenant_id != tenant_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Acknowledge
|
||||
updated_event = await repo.acknowledge_event(alert_id)
|
||||
return repo._event_to_response(updated_event)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("acknowledge_alert_failed", error=str(e), alert_id=str(alert_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to acknowledge alert")
|
||||
|
||||
|
||||
@router.post("/alerts/{alert_id}/resolve", response_model=EventResponse)
|
||||
async def resolve_alert(
|
||||
tenant_id: UUID,
|
||||
alert_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Mark alert as resolved.
|
||||
|
||||
Sets status to 'resolved' and records timestamp.
|
||||
"""
|
||||
try:
|
||||
repo = EventRepository(db)
|
||||
|
||||
# Verify ownership first
|
||||
event = await repo.get_event_by_id(alert_id)
|
||||
if not event:
|
||||
raise HTTPException(status_code=404, detail="Alert not found")
|
||||
if event.tenant_id != tenant_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Resolve
|
||||
updated_event = await repo.resolve_event(alert_id)
|
||||
return repo._event_to_response(updated_event)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("resolve_alert_failed", error=str(e), alert_id=str(alert_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to resolve alert")
|
||||
|
||||
|
||||
@router.post("/alerts/{alert_id}/dismiss", response_model=EventResponse)
|
||||
async def dismiss_alert(
|
||||
tenant_id: UUID,
|
||||
alert_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Mark alert as dismissed.
|
||||
|
||||
Sets status to 'dismissed'.
|
||||
"""
|
||||
try:
|
||||
repo = EventRepository(db)
|
||||
|
||||
# Verify ownership first
|
||||
event = await repo.get_event_by_id(alert_id)
|
||||
if not event:
|
||||
raise HTTPException(status_code=404, detail="Alert not found")
|
||||
if event.tenant_id != tenant_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Dismiss
|
||||
updated_event = await repo.dismiss_event(alert_id)
|
||||
return repo._event_to_response(updated_event)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("dismiss_alert_failed", error=str(e), alert_id=str(alert_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to dismiss alert")
|
||||
|
||||
|
||||
@router.post("/alerts/{alert_id}/cancel-auto-action")
|
||||
async def cancel_auto_action(
|
||||
tenant_id: UUID,
|
||||
alert_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Cancel an alert's auto-action (escalation countdown).
|
||||
|
||||
Changes type_class from 'escalation' to 'action_needed' if auto-action was pending.
|
||||
"""
|
||||
try:
|
||||
repo = EventRepository(db)
|
||||
|
||||
# Verify ownership first
|
||||
event = await repo.get_event_by_id(alert_id)
|
||||
if not event:
|
||||
raise HTTPException(status_code=404, detail="Alert not found")
|
||||
if event.tenant_id != tenant_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# TODO: implement auto-action cancellation in the repository layer.
# For now, return a success response without persisting any change.
|
||||
return {
|
||||
"success": True,
|
||||
"event_id": str(alert_id),
|
||||
"message": "Auto-action cancelled successfully",
|
||||
"updated_type_class": "action_needed"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("cancel_auto_action_failed", error=str(e), alert_id=str(alert_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to cancel auto-action")
|
||||
|
||||
|
||||
@router.post("/alerts/bulk-acknowledge")
|
||||
async def bulk_acknowledge_alerts(
|
||||
tenant_id: UUID,
|
||||
request_body: dict,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Acknowledge multiple alerts by metadata filter.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"alert_type": "critical_stock_shortage",
|
||||
"metadata_filter": {"ingredient_id": "123"}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
alert_type = request_body.get("alert_type")
|
||||
metadata_filter = request_body.get("metadata_filter", {})
|
||||
|
||||
if not alert_type:
|
||||
raise HTTPException(status_code=400, detail="alert_type is required")
|
||||
|
||||
repo = EventRepository(db)
|
||||
|
||||
# Get matching alerts
|
||||
events = await repo.get_events(
|
||||
tenant_id=tenant_id,
|
||||
event_class="alert",
|
||||
status=["active"],
|
||||
limit=100
|
||||
)
|
||||
|
||||
# Filter by type and metadata
|
||||
matching_ids = []
|
||||
for event in events:
|
||||
if event.event_type == alert_type:
|
||||
# Check if metadata matches
|
||||
matches = all(
|
||||
event.event_metadata.get(key) == value
|
||||
for key, value in metadata_filter.items()
|
||||
)
|
||||
if matches:
|
||||
matching_ids.append(event.id)
|
||||
|
||||
# Acknowledge all matching
|
||||
acknowledged_count = 0
|
||||
for event_id in matching_ids:
|
||||
try:
|
||||
await repo.acknowledge_event(event_id)
|
||||
acknowledged_count += 1
|
||||
except Exception as exc:
logger.warning("bulk_acknowledge_item_failed", event_id=str(event_id), error=str(exc))
# Continue with the remaining alerts
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"acknowledged_count": acknowledged_count,
|
||||
"alert_ids": [str(id) for id in matching_ids]
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("bulk_acknowledge_failed", error=str(e), tenant_id=str(tenant_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to bulk acknowledge alerts")
|
||||
|
||||
|
||||
@router.post("/alerts/bulk-resolve")
|
||||
async def bulk_resolve_alerts(
|
||||
tenant_id: UUID,
|
||||
request_body: dict,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Resolve multiple alerts by metadata filter.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"alert_type": "critical_stock_shortage",
|
||||
"metadata_filter": {"ingredient_id": "123"}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
alert_type = request_body.get("alert_type")
|
||||
metadata_filter = request_body.get("metadata_filter", {})
|
||||
|
||||
if not alert_type:
|
||||
raise HTTPException(status_code=400, detail="alert_type is required")
|
||||
|
||||
repo = EventRepository(db)
|
||||
|
||||
# Get matching alerts
|
||||
events = await repo.get_events(
|
||||
tenant_id=tenant_id,
|
||||
event_class="alert",
|
||||
status=["active", "acknowledged"],
|
||||
limit=100
|
||||
)
|
||||
|
||||
# Filter by type and metadata
|
||||
matching_ids = []
|
||||
for event in events:
|
||||
if event.event_type == alert_type:
|
||||
# Check if metadata matches
|
||||
matches = all(
|
||||
event.event_metadata.get(key) == value
|
||||
for key, value in metadata_filter.items()
|
||||
)
|
||||
if matches:
|
||||
matching_ids.append(event.id)
|
||||
|
||||
# Resolve all matching
|
||||
resolved_count = 0
|
||||
for event_id in matching_ids:
|
||||
try:
|
||||
await repo.resolve_event(event_id)
|
||||
resolved_count += 1
|
||||
except Exception as exc:
logger.warning("bulk_resolve_item_failed", event_id=str(event_id), error=str(exc))
# Continue with the remaining alerts
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"resolved_count": resolved_count,
|
||||
"alert_ids": [str(id) for id in matching_ids]
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("bulk_resolve_failed", error=str(e), tenant_id=str(tenant_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to bulk resolve alerts")
|
||||
|
||||
|
||||
@router.post("/events/{event_id}/interactions")
|
||||
async def record_interaction(
|
||||
tenant_id: UUID,
|
||||
event_id: UUID,
|
||||
request_body: dict,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Record user interaction with an event (for analytics).
|
||||
|
||||
Request body:
|
||||
{
|
||||
"interaction_type": "viewed" | "clicked" | "dismissed" | "acted_upon",
|
||||
"interaction_metadata": {...}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
interaction_type = request_body.get("interaction_type")
|
||||
interaction_metadata = request_body.get("interaction_metadata", {})
|
||||
|
||||
if not interaction_type:
|
||||
raise HTTPException(status_code=400, detail="interaction_type is required")
|
||||
|
||||
repo = EventRepository(db)
|
||||
|
||||
# Verify event exists and belongs to tenant
|
||||
event = await repo.get_event_by_id(event_id)
|
||||
if not event:
|
||||
raise HTTPException(status_code=404, detail="Event not found")
|
||||
if event.tenant_id != tenant_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# For now, just return success
|
||||
# In the future, you could store interactions in a separate table
|
||||
logger.info(
|
||||
"interaction_recorded",
|
||||
event_id=str(event_id),
|
||||
interaction_type=interaction_type,
|
||||
metadata=interaction_metadata
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"interaction_id": str(event_id), # Would be a real ID in production
|
||||
"event_id": str(event_id),
|
||||
"interaction_type": interaction_type
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("record_interaction_failed", error=str(e), event_id=str(event_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to record interaction")
|
||||
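A minimal sketch of calling the bulk-acknowledge endpoint above from another process, assuming the service is reachable at http://alert-processor:8000 with no extra API prefix; the tenant UUID and payload values are placeholders (the endpoint does read tenant_id as a query parameter).

```python
# Hypothetical caller for POST /alerts/bulk-acknowledge; host, prefix and IDs
# are placeholders.
import asyncio
import httpx


async def acknowledge_stock_alerts(tenant_id: str) -> dict:
    payload = {
        "alert_type": "critical_stock_shortage",
        "metadata_filter": {"ingredient_id": "123"},
    }
    async with httpx.AsyncClient(base_url="http://alert-processor:8000", timeout=10.0) as client:
        response = await client.post(
            "/alerts/bulk-acknowledge",
            params={"tenant_id": tenant_id},
            json=payload,
        )
        response.raise_for_status()
        return response.json()


if __name__ == "__main__":
    print(asyncio.run(acknowledge_stock_alerts("00000000-0000-0000-0000-000000000001")))
```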
70
services/alert_processor/app/api/sse.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Server-Sent Events (SSE) API endpoint.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from uuid import UUID
|
||||
from redis.asyncio import Redis
|
||||
import structlog
|
||||
|
||||
from shared.redis_utils import get_redis_client
|
||||
from app.services.sse_service import SSEService
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/sse/alerts/{tenant_id}")
|
||||
async def stream_alerts(tenant_id: UUID):
|
||||
"""
|
||||
Stream real-time alerts via Server-Sent Events (SSE).
|
||||
|
||||
Usage (frontend):
|
||||
```javascript
|
||||
const eventSource = new EventSource('/api/v1/sse/alerts/{tenant_id}');
|
||||
eventSource.onmessage = (event) => {
|
||||
const alert = JSON.parse(event.data);
|
||||
console.log('New alert:', alert);
|
||||
};
|
||||
```
|
||||
|
||||
Response format:
|
||||
```
|
||||
data: {"id": "...", "event_type": "...", ...}
|
||||
|
||||
data: {"id": "...", "event_type": "...", ...}
|
||||
|
||||
```
|
||||
"""
|
||||
# Get Redis client from shared utilities
|
||||
redis = await get_redis_client()
|
||||
try:
|
||||
sse_service = SSEService(redis)
|
||||
|
||||
async def event_generator():
|
||||
"""Generator for SSE stream"""
|
||||
try:
|
||||
async for message in sse_service.subscribe_to_tenant(str(tenant_id)):
|
||||
# Format as SSE message
|
||||
yield f"data: {message}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error("sse_stream_error", error=str(e), tenant_id=str(tenant_id))
|
||||
# Send error message and close
|
||||
yield f"event: error\ndata: {str(e)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # Disable nginx buffering
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("sse_setup_failed", error=str(e), tenant_id=str(tenant_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to setup SSE stream")
|
||||
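The docstring above shows browser-side consumption; a rough Python test client for the same stream might look like the following, assuming the /api/v1 prefix from the docstring and a reachable alert-processor host.

```python
# Rough Python test client for the SSE stream; host is a placeholder.
import asyncio
import json
import httpx


async def tail_alerts(tenant_id: str) -> None:
    url = f"http://alert-processor:8000/api/v1/sse/alerts/{tenant_id}"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    alert = json.loads(line[len("data: "):])
                    print("New alert:", alert.get("event_type"))


if __name__ == "__main__":
    asyncio.run(tail_alerts("00000000-0000-0000-0000-000000000001"))
```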
0
services/alert_processor/app/consumer/__init__.py
Normal file
295
services/alert_processor/app/consumer/event_consumer.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
RabbitMQ event consumer.
|
||||
|
||||
Consumes minimal events from services and processes them through
|
||||
the enrichment pipeline.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
from aio_pika import connect_robust, IncomingMessage, Connection, Channel
|
||||
import structlog
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import AsyncSessionLocal
|
||||
from shared.messaging import MinimalEvent
|
||||
from app.services.enrichment_orchestrator import EnrichmentOrchestrator
|
||||
from app.repositories.event_repository import EventRepository
|
||||
from shared.clients.notification_client import create_notification_client
|
||||
from app.services.sse_service import SSEService
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class EventConsumer:
|
||||
"""
|
||||
RabbitMQ consumer for processing events.
|
||||
|
||||
Workflow:
|
||||
1. Receive minimal event from service
|
||||
2. Enrich with context (AI, priority, impact, etc.)
|
||||
3. Store in database
|
||||
4. Send to notification service
|
||||
5. Publish to SSE stream
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.connection: Optional[Connection] = None
self.channel: Optional[Channel] = None
|
||||
self.enricher = EnrichmentOrchestrator()
|
||||
self.notification_client = create_notification_client(settings)
|
||||
self.sse_svc = SSEService()
|
||||
|
||||
async def start(self):
|
||||
"""Start consuming events from RabbitMQ"""
|
||||
try:
|
||||
# Connect to RabbitMQ
|
||||
self.connection = await connect_robust(
|
||||
settings.RABBITMQ_URL,
|
||||
client_properties={"connection_name": "alert-processor"}
|
||||
)
|
||||
|
||||
self.channel = await self.connection.channel()
|
||||
await self.channel.set_qos(prefetch_count=10)
|
||||
|
||||
# Declare queue
|
||||
queue = await self.channel.declare_queue(
|
||||
settings.RABBITMQ_QUEUE,
|
||||
durable=True
|
||||
)
|
||||
|
||||
# Bind to events exchange with routing patterns
|
||||
exchange = await self.channel.declare_exchange(
|
||||
settings.RABBITMQ_EXCHANGE,
|
||||
"topic",
|
||||
durable=True
|
||||
)
|
||||
|
||||
# Bind to alert, notification, and recommendation events
|
||||
await queue.bind(exchange, routing_key="alert.#")
|
||||
await queue.bind(exchange, routing_key="notification.#")
|
||||
await queue.bind(exchange, routing_key="recommendation.#")
|
||||
|
||||
# Start consuming
|
||||
await queue.consume(self.process_message)
|
||||
|
||||
logger.info(
|
||||
"event_consumer_started",
|
||||
queue=settings.RABBITMQ_QUEUE,
|
||||
exchange=settings.RABBITMQ_EXCHANGE
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("consumer_start_failed", error=str(e))
|
||||
raise
|
||||
|
||||
async def process_message(self, message: IncomingMessage):
|
||||
"""
|
||||
Process incoming event message.
|
||||
|
||||
Steps:
|
||||
1. Parse message
|
||||
2. Validate as MinimalEvent
|
||||
3. Enrich event
|
||||
4. Store in database
|
||||
5. Send notification
|
||||
6. Publish to SSE
|
||||
7. Acknowledge message
|
||||
"""
|
||||
async with message.process():
|
||||
try:
|
||||
# Parse message
|
||||
data = json.loads(message.body.decode())
|
||||
event = MinimalEvent(**data)
|
||||
|
||||
logger.info(
|
||||
"event_received",
|
||||
event_type=event.event_type,
|
||||
event_class=event.event_class,
|
||||
tenant_id=event.tenant_id
|
||||
)
|
||||
|
||||
# Enrich the event
|
||||
enriched_event = await self.enricher.enrich_event(event)
|
||||
|
||||
# Check for duplicate alerts before storing
|
||||
async with AsyncSessionLocal() as session:
|
||||
repo = EventRepository(session)
|
||||
|
||||
# Check for duplicate if it's an alert
|
||||
if event.event_class == "alert":
|
||||
from uuid import UUID
|
||||
duplicate_event = await repo.check_duplicate_alert(
|
||||
tenant_id=UUID(event.tenant_id),
|
||||
event_type=event.event_type,
|
||||
entity_links=enriched_event.entity_links,
|
||||
event_metadata=enriched_event.event_metadata,
|
||||
time_window_hours=24 # Check for duplicates in last 24 hours
|
||||
)
|
||||
|
||||
if duplicate_event:
|
||||
logger.info(
|
||||
"Duplicate alert detected, skipping",
|
||||
event_type=event.event_type,
|
||||
tenant_id=event.tenant_id,
|
||||
duplicate_event_id=str(duplicate_event.id)
|
||||
)
|
||||
# Update the existing event's metadata instead of creating a new one
|
||||
# This could include updating delay times, affected orders, etc.
|
||||
duplicate_event.event_metadata = enriched_event.event_metadata
|
||||
duplicate_event.updated_at = datetime.now(timezone.utc)
|
||||
duplicate_event.priority_score = enriched_event.priority_score
|
||||
duplicate_event.priority_level = enriched_event.priority_level
|
||||
|
||||
# Update other relevant fields that might have changed
|
||||
duplicate_event.urgency = enriched_event.urgency.dict() if enriched_event.urgency else None
|
||||
duplicate_event.business_impact = enriched_event.business_impact.dict() if enriched_event.business_impact else None
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(duplicate_event)
|
||||
|
||||
# Send notification for updated event
|
||||
await self._send_notification(duplicate_event)
|
||||
|
||||
# Publish to SSE
|
||||
await self.sse_svc.publish_event(duplicate_event)
|
||||
|
||||
logger.info(
|
||||
"Duplicate alert updated",
|
||||
event_id=str(duplicate_event.id),
|
||||
event_type=event.event_type,
|
||||
priority_level=duplicate_event.priority_level,
|
||||
priority_score=duplicate_event.priority_score
|
||||
)
|
||||
return # Exit early since we handled the duplicate
|
||||
else:
|
||||
logger.info(
|
||||
"New unique alert, proceeding with creation",
|
||||
event_type=event.event_type,
|
||||
tenant_id=event.tenant_id
|
||||
)
|
||||
|
||||
# Store in database (if not a duplicate)
|
||||
stored_event = await repo.create_event(enriched_event)
|
||||
|
||||
# Send to notification service (if alert)
|
||||
if event.event_class == "alert":
|
||||
await self._send_notification(stored_event)
|
||||
|
||||
# Publish to SSE
|
||||
await self.sse_svc.publish_event(stored_event)
|
||||
|
||||
logger.info(
|
||||
"event_processed",
|
||||
event_id=stored_event.id,
|
||||
event_type=event.event_type,
|
||||
priority_level=stored_event.priority_level,
|
||||
priority_score=stored_event.priority_score
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
"message_parse_failed",
|
||||
error=str(e),
|
||||
message_body=message.body[:200]
|
||||
)
|
||||
# Don't requeue - bad message format
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"event_processing_failed",
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
# message.process() rejects the message on unhandled exceptions; it is only requeued if process() is configured to do so
|
||||
|
||||
async def _send_notification(self, event):
|
||||
"""
|
||||
Send notification using the shared notification client.
|
||||
|
||||
Args:
|
||||
event: The event to send as a notification
|
||||
"""
|
||||
try:
|
||||
# Prepare notification message
|
||||
# Use i18n title and message from the event as the notification content
|
||||
title = event.i18n_title_key if event.i18n_title_key else f"Alert: {event.event_type}"
|
||||
message = event.i18n_message_key if event.i18n_message_key else f"New alert: {event.event_type}"
|
||||
|
||||
# Add parameters to make it more informative
|
||||
if event.i18n_title_params:
|
||||
title += f" - {event.i18n_title_params}"
|
||||
if event.i18n_message_params:
|
||||
message += f" - {event.i18n_message_params}"
|
||||
|
||||
# Prepare metadata from the event
|
||||
metadata = {
|
||||
"event_id": str(event.id),
|
||||
"event_type": event.event_type,
|
||||
"event_domain": event.event_domain,
|
||||
"priority_score": event.priority_score,
|
||||
"priority_level": event.priority_level,
|
||||
"status": event.status,
|
||||
"created_at": event.created_at.isoformat() if event.created_at else None,
|
||||
"type_class": event.type_class,
|
||||
"smart_actions": event.smart_actions,
|
||||
"entity_links": event.entity_links
|
||||
}
|
||||
|
||||
# Determine notification priority based on event priority
|
||||
priority_map = {
|
||||
"critical": "urgent",
|
||||
"important": "high",
|
||||
"standard": "normal",
|
||||
"info": "low"
|
||||
}
|
||||
priority = priority_map.get(event.priority_level, "normal")
|
||||
|
||||
# Send notification using shared client
|
||||
result = await self.notification_client.send_notification(
|
||||
tenant_id=str(event.tenant_id),
|
||||
notification_type="in_app", # Using in-app notification by default
|
||||
message=message,
|
||||
subject=title,
|
||||
priority=priority,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
if result:
|
||||
logger.info(
|
||||
"notification_sent_via_shared_client",
|
||||
event_id=str(event.id),
|
||||
tenant_id=str(event.tenant_id),
|
||||
priority_level=event.priority_level
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"notification_failed_via_shared_client",
|
||||
event_id=str(event.id),
|
||||
tenant_id=str(event.tenant_id)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"notification_error_via_shared_client",
|
||||
error=str(e),
|
||||
event_id=str(event.id),
|
||||
tenant_id=str(event.tenant_id)
|
||||
)
|
||||
# Don't re-raise - we don't want to fail the entire event processing
|
||||
# if notification sending fails
|
||||
|
||||
async def stop(self):
|
||||
"""Stop consumer and close connections"""
|
||||
try:
|
||||
if self.channel:
|
||||
await self.channel.close()
|
||||
logger.info("rabbitmq_channel_closed")
|
||||
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
logger.info("rabbitmq_connection_closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("consumer_stop_failed", error=str(e))
|
||||
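A minimal sketch of how this consumer could be wired into the service's startup and shutdown, assuming a FastAPI entrypoint with a lifespan handler; the entrypoint module itself is not part of this file.

```python
# Hypothetical startup wiring; the FastAPI entrypoint is assumed, not shown here.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.consumer.event_consumer import EventConsumer


@asynccontextmanager
async def lifespan(app: FastAPI):
    consumer = EventConsumer()
    await consumer.start()      # connect, declare queue, bind routing keys, consume
    app.state.event_consumer = consumer
    try:
        yield
    finally:
        await consumer.stop()   # close channel and connection


app = FastAPI(lifespan=lifespan)
```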
0
services/alert_processor/app/core/__init__.py
Normal file
51
services/alert_processor/app/core/config.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Configuration settings for alert processor service.
|
||||
"""
|
||||
|
||||
import os
|
||||
from shared.config.base import BaseServiceSettings
|
||||
|
||||
|
||||
class Settings(BaseServiceSettings):
|
||||
"""Application settings"""
|
||||
|
||||
# Service info - override defaults
|
||||
SERVICE_NAME: str = "alert-processor"
|
||||
APP_NAME: str = "Alert Processor Service"
|
||||
DESCRIPTION: str = "Central alert and recommendation processor"
|
||||
VERSION: str = "2.0.0"
|
||||
|
||||
# Alert processor specific settings
|
||||
RABBITMQ_EXCHANGE: str = "events.exchange"
|
||||
RABBITMQ_QUEUE: str = "alert_processor.queue"
|
||||
REDIS_SSE_PREFIX: str = "alerts"
|
||||
ORCHESTRATOR_TIMEOUT: int = 10
|
||||
NOTIFICATION_TIMEOUT: int = 5
|
||||
CACHE_ENABLED: bool = True
|
||||
CACHE_TTL_SECONDS: int = 300
|
||||
|
||||
@property
|
||||
def NOTIFICATION_URL(self) -> str:
|
||||
"""Get notification service URL for backwards compatibility"""
|
||||
return self.NOTIFICATION_SERVICE_URL
|
||||
|
||||
# Database configuration (secure approach - build from components)
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
"""Build database URL from secure components"""
|
||||
# Try complete URL first (for backward compatibility)
|
||||
complete_url = os.getenv("ALERT_PROCESSOR_DATABASE_URL")
|
||||
if complete_url:
|
||||
return complete_url
|
||||
|
||||
# Build from components (secure approach)
|
||||
user = os.getenv("ALERT_PROCESSOR_DB_USER", "alert_processor_user")
|
||||
password = os.getenv("ALERT_PROCESSOR_DB_PASSWORD", "alert_processor_pass123")
|
||||
host = os.getenv("ALERT_PROCESSOR_DB_HOST", "alert-processor-db-service")
|
||||
port = os.getenv("ALERT_PROCESSOR_DB_PORT", "5432")
|
||||
name = os.getenv("ALERT_PROCESSOR_DB_NAME", "alert_processor_db")
|
||||
|
||||
return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
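A quick illustration of the DATABASE_URL precedence implemented above, assuming the shared BaseServiceSettings does not require further mandatory variables; every value is a placeholder.

```python
# Demonstrates DATABASE_URL precedence; every value here is a placeholder.
import os

os.environ.pop("ALERT_PROCESSOR_DATABASE_URL", None)
os.environ["ALERT_PROCESSOR_DB_USER"] = "alerts"
os.environ["ALERT_PROCESSOR_DB_PASSWORD"] = "secret"
os.environ["ALERT_PROCESSOR_DB_HOST"] = "localhost"

from app.core.config import settings  # noqa: E402

# Built from components (port and database name fall back to their defaults):
# postgresql+asyncpg://alerts:secret@localhost:5432/alert_processor_db
print(settings.DATABASE_URL)

# A complete URL, when present, takes precedence over the components:
os.environ["ALERT_PROCESSOR_DATABASE_URL"] = "postgresql+asyncpg://u:p@db:5432/alerts"
print(settings.DATABASE_URL)
```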
48
services/alert_processor/app/core/database.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Database connection and session management for Alert Processor Service
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from .config import settings
|
||||
|
||||
from shared.database.base import DatabaseManager
|
||||
|
||||
# Initialize database manager
|
||||
database_manager = DatabaseManager(
|
||||
database_url=settings.DATABASE_URL,
|
||||
service_name=settings.SERVICE_NAME,
|
||||
pool_size=settings.DB_POOL_SIZE,
|
||||
max_overflow=settings.DB_MAX_OVERFLOW,
|
||||
echo=settings.DEBUG
|
||||
)
|
||||
|
||||
# Create async session factory
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
database_manager.async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
"""
|
||||
Dependency to get database session.
|
||||
Used in FastAPI endpoints via Depends(get_db).
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""Initialize database (create tables if needed)"""
|
||||
await database_manager.create_all()
|
||||
|
||||
|
||||
async def close_db():
|
||||
"""Close database connections"""
|
||||
await database_manager.close()
|
||||
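get_db serves FastAPI endpoints, while background code such as the RabbitMQ consumer opens sessions through AsyncSessionLocal directly; a minimal sketch, with the repository call mirroring the usage in the alerts API above.

```python
# Opening a session outside of a request, as the consumer does; the
# repository call mirrors the one in the bulk endpoints above.
import asyncio
from uuid import UUID

from app.core.database import AsyncSessionLocal
from app.repositories.event_repository import EventRepository


async def count_active_alerts(tenant_id: UUID) -> int:
    async with AsyncSessionLocal() as session:
        repo = EventRepository(session)
        events = await repo.get_events(
            tenant_id=tenant_id,
            event_class="alert",
            status=["active"],
            limit=100,
        )
        return len(events)


if __name__ == "__main__":
    tenant = UUID("00000000-0000-0000-0000-000000000001")
    print("Active alerts:", asyncio.run(count_active_alerts(tenant)))
```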
1
services/alert_processor/app/enrichment/__init__.py
Normal file
@@ -0,0 +1 @@
"""Enrichment components for alert processing."""
156
services/alert_processor/app/enrichment/business_impact.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Business impact analyzer for alerts.
|
||||
|
||||
Calculates financial impact, affected orders, customer impact, and other
|
||||
business metrics from event metadata.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class BusinessImpactAnalyzer:
|
||||
"""Analyze business impact from event metadata"""
|
||||
|
||||
def analyze(self, event_type: str, metadata: Dict[str, Any]) -> dict:
|
||||
"""
|
||||
Analyze business impact for an event.
|
||||
|
||||
Returns dict with:
|
||||
- financial_impact_eur: Direct financial cost
|
||||
- affected_orders: Number of orders impacted
|
||||
- affected_customers: List of customer names
|
||||
- production_delay_hours: Hours of production delay
|
||||
- estimated_revenue_loss_eur: Potential revenue loss
|
||||
- customer_impact: high/medium/low
|
||||
- waste_risk_kg: Potential waste in kg
|
||||
"""
|
||||
|
||||
impact = {
|
||||
"financial_impact_eur": 0,
|
||||
"affected_orders": 0,
|
||||
"affected_customers": [],
|
||||
"production_delay_hours": 0,
|
||||
"estimated_revenue_loss_eur": 0,
|
||||
"customer_impact": "low",
|
||||
"waste_risk_kg": 0
|
||||
}
|
||||
|
||||
# Stock-related impacts
|
||||
if "stock" in event_type or "shortage" in event_type:
|
||||
impact.update(self._analyze_stock_impact(metadata))
|
||||
|
||||
# Production-related impacts
|
||||
elif "production" in event_type or "delay" in event_type or "equipment" in event_type:
|
||||
impact.update(self._analyze_production_impact(metadata))
|
||||
|
||||
# Procurement-related impacts
|
||||
elif "po_" in event_type or "delivery" in event_type:
|
||||
impact.update(self._analyze_procurement_impact(metadata))
|
||||
|
||||
# Quality-related impacts
|
||||
elif "quality" in event_type or "expired" in event_type:
|
||||
impact.update(self._analyze_quality_impact(metadata))
|
||||
|
||||
return impact
|
||||
|
||||
def _analyze_stock_impact(self, metadata: Dict[str, Any]) -> dict:
|
||||
"""Analyze impact of stock-related alerts"""
|
||||
impact = {}
|
||||
|
||||
# Calculate financial impact
|
||||
shortage_amount = metadata.get("shortage_amount", 0)
|
||||
unit_cost = metadata.get("unit_cost", 5) # Default €5/kg
|
||||
impact["financial_impact_eur"] = float(shortage_amount) * unit_cost
|
||||
|
||||
# Affected orders from metadata
|
||||
impact["affected_orders"] = metadata.get("affected_orders", 0)
|
||||
|
||||
# Customer impact based on affected orders
|
||||
if impact["affected_orders"] > 5:
|
||||
impact["customer_impact"] = "high"
|
||||
elif impact["affected_orders"] > 2:
|
||||
impact["customer_impact"] = "medium"
|
||||
|
||||
# Revenue loss (estimated)
|
||||
avg_order_value = 50 # €50 per order
|
||||
impact["estimated_revenue_loss_eur"] = impact["affected_orders"] * avg_order_value
|
||||
|
||||
return impact
|
||||
|
||||
def _analyze_production_impact(self, metadata: Dict[str, Any]) -> dict:
|
||||
"""Analyze impact of production-related alerts"""
|
||||
impact = {}
|
||||
|
||||
# Delay minutes to hours
|
||||
delay_minutes = metadata.get("delay_minutes", 0)
|
||||
impact["production_delay_hours"] = round(delay_minutes / 60, 1)
|
||||
|
||||
# Affected orders and customers
|
||||
impact["affected_orders"] = metadata.get("affected_orders", 0)
|
||||
|
||||
customer_names = metadata.get("customer_names", [])
|
||||
impact["affected_customers"] = customer_names
|
||||
|
||||
# Customer impact based on delay
|
||||
if delay_minutes > 120: # 2+ hours
|
||||
impact["customer_impact"] = "high"
|
||||
elif delay_minutes > 60: # 1+ hours
|
||||
impact["customer_impact"] = "medium"
|
||||
|
||||
# Financial impact: hourly production cost
|
||||
hourly_cost = 100 # €100/hour operational cost
|
||||
impact["financial_impact_eur"] = impact["production_delay_hours"] * hourly_cost
|
||||
|
||||
# Revenue loss
|
||||
if impact["affected_orders"] > 0:
|
||||
avg_order_value = 50
|
||||
impact["estimated_revenue_loss_eur"] = impact["affected_orders"] * avg_order_value
|
||||
|
||||
return impact
|
||||
|
||||
def _analyze_procurement_impact(self, metadata: Dict[str, Any]) -> dict:
|
||||
"""Analyze impact of procurement-related alerts"""
|
||||
impact = {}
|
||||
|
||||
# Extract potential_loss_eur from reasoning_data.parameters
|
||||
reasoning_data = metadata.get("reasoning_data", {})
|
||||
parameters = reasoning_data.get("parameters", {})
|
||||
potential_loss_eur = parameters.get("potential_loss_eur")
|
||||
|
||||
# Use potential loss from reasoning as financial impact (what's at risk)
|
||||
# Fallback to PO amount only if reasoning data is not available
|
||||
if potential_loss_eur is not None:
|
||||
impact["financial_impact_eur"] = float(potential_loss_eur)
|
||||
else:
|
||||
po_amount = metadata.get("po_amount", metadata.get("total_amount", 0))
|
||||
impact["financial_impact_eur"] = float(po_amount)
|
||||
|
||||
# Days overdue affects customer impact
|
||||
days_overdue = metadata.get("days_overdue", 0)
|
||||
if days_overdue > 3:
|
||||
impact["customer_impact"] = "high"
|
||||
elif days_overdue > 1:
|
||||
impact["customer_impact"] = "medium"
|
||||
|
||||
return impact
|
||||
|
||||
def _analyze_quality_impact(self, metadata: Dict[str, Any]) -> dict:
|
||||
"""Analyze impact of quality-related alerts"""
|
||||
impact = {}
|
||||
|
||||
# Expired products
|
||||
expired_count = metadata.get("expired_count", 0)
|
||||
total_value = metadata.get("total_value", 0)
|
||||
|
||||
impact["financial_impact_eur"] = float(total_value)
|
||||
impact["waste_risk_kg"] = metadata.get("total_quantity_kg", 0)
|
||||
|
||||
if expired_count > 5:
|
||||
impact["customer_impact"] = "high"
|
||||
elif expired_count > 2:
|
||||
impact["customer_impact"] = "medium"
|
||||
|
||||
return impact
|
||||
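As a concrete illustration of the stock branch above, a small example with made-up metadata; the resulting numbers follow directly from the defaults in _analyze_stock_impact.

```python
from app.enrichment.business_impact import BusinessImpactAnalyzer

analyzer = BusinessImpactAnalyzer()
impact = analyzer.analyze(
    "critical_stock_shortage",
    {"shortage_amount": 12, "unit_cost": 4.5, "affected_orders": 6},
)

# 12 kg short at 4.50 EUR/kg, 6 affected orders -> "high" customer impact,
# and 6 * 50 EUR of estimated revenue at risk
assert impact["financial_impact_eur"] == 54.0
assert impact["customer_impact"] == "high"
assert impact["estimated_revenue_loss_eur"] == 300
```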
244
services/alert_processor/app/enrichment/message_generator.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
Message generator for creating i18n message keys and parameters.
|
||||
|
||||
Converts minimal event metadata into structured i18n format for frontend translation.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
from app.utils.message_templates import ALERT_TEMPLATES, NOTIFICATION_TEMPLATES, RECOMMENDATION_TEMPLATES
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class MessageGenerator:
|
||||
"""Generates i18n message keys and parameters from event metadata"""
|
||||
|
||||
def generate_message(self, event_type: str, metadata: Dict[str, Any], event_class: str = "alert") -> dict:
|
||||
"""
|
||||
Generate i18n structure for frontend.
|
||||
|
||||
Args:
|
||||
event_type: Alert/notification/recommendation type
|
||||
metadata: Event metadata dictionary
|
||||
event_class: One of: alert, notification, recommendation
|
||||
|
||||
Returns:
|
||||
Dictionary with title_key, title_params, message_key, message_params
|
||||
"""
|
||||
|
||||
# Select appropriate template collection
|
||||
if event_class == "notification":
|
||||
templates = NOTIFICATION_TEMPLATES
|
||||
elif event_class == "recommendation":
|
||||
templates = RECOMMENDATION_TEMPLATES
|
||||
else:
|
||||
templates = ALERT_TEMPLATES
|
||||
|
||||
template = templates.get(event_type)
|
||||
|
||||
if not template:
|
||||
logger.warning("no_template_found", event_type=event_type, event_class=event_class)
|
||||
return self._generate_fallback(event_type, metadata)
|
||||
|
||||
# Build parameters from metadata
|
||||
title_params = self._build_params(template["title_params"], metadata)
|
||||
message_params = self._build_params(template["message_params"], metadata)
|
||||
|
||||
# Select message variant based on context
|
||||
message_key = self._select_message_variant(
|
||||
template["message_variants"],
|
||||
metadata
|
||||
)
|
||||
|
||||
return {
|
||||
"title_key": template["title_key"],
|
||||
"title_params": title_params,
|
||||
"message_key": message_key,
|
||||
"message_params": message_params
|
||||
}
|
||||
|
||||
def _generate_fallback(self, event_type: str, metadata: Dict[str, Any]) -> dict:
|
||||
"""Generate fallback message structure when template not found"""
|
||||
return {
|
||||
"title_key": "alerts.generic.title",
|
||||
"title_params": {},
|
||||
"message_key": "alerts.generic.message",
|
||||
"message_params": {
|
||||
"event_type": event_type,
|
||||
"metadata_summary": self._summarize_metadata(metadata)
|
||||
}
|
||||
}
|
||||
|
||||
def _summarize_metadata(self, metadata: Dict[str, Any]) -> str:
|
||||
"""Create human-readable summary of metadata"""
|
||||
# Take first 3 fields
|
||||
items = list(metadata.items())[:3]
|
||||
summary_parts = [f"{k}: {v}" for k, v in items]
|
||||
return ", ".join(summary_parts)
|
||||
|
||||
def _build_params(self, param_mapping: dict, metadata: dict) -> dict:
|
||||
"""
|
||||
Extract and transform parameters from metadata.
|
||||
|
||||
param_mapping format: {"display_param_name": "metadata_key"}
|
||||
"""
|
||||
params = {}
|
||||
|
||||
for param_key, metadata_key in param_mapping.items():
|
||||
if metadata_key in metadata:
|
||||
value = metadata[metadata_key]
|
||||
|
||||
# Apply transformations based on parameter suffix
|
||||
if param_key.endswith("_kg"):
|
||||
value = round(float(value), 1)
|
||||
elif param_key.endswith("_eur"):
|
||||
value = round(float(value), 2)
|
||||
elif param_key.endswith("_percentage"):
|
||||
value = round(float(value), 1)
|
||||
elif param_key.endswith("_date"):
|
||||
value = self._format_date(value)
|
||||
elif param_key.endswith("_day_name"):
|
||||
value = self._format_day_name(value)
|
||||
elif param_key.endswith("_datetime"):
|
||||
value = self._format_datetime(value)
|
||||
|
||||
params[param_key] = value
|
||||
|
||||
return params
|
||||
|
||||
def _select_message_variant(self, variants: dict, metadata: dict) -> str:
|
||||
"""
|
||||
Select appropriate message variant based on metadata context.
|
||||
|
||||
Checks for specific conditions in priority order.
|
||||
"""
|
||||
|
||||
# Check for PO-related variants
|
||||
if "po_id" in metadata:
|
||||
if metadata.get("po_status") == "pending_approval":
|
||||
variant = variants.get("with_po_pending")
|
||||
if variant:
|
||||
return variant
|
||||
else:
|
||||
variant = variants.get("with_po_created")
|
||||
if variant:
|
||||
return variant
|
||||
|
||||
# Check for time-based variants
|
||||
if "hours_until" in metadata:
|
||||
variant = variants.get("with_hours")
|
||||
if variant:
|
||||
return variant
|
||||
|
||||
if "production_date" in metadata or "planned_date" in metadata:
|
||||
variant = variants.get("with_date")
|
||||
if variant:
|
||||
return variant
|
||||
|
||||
# Check for customer-related variants
|
||||
if "customer_names" in metadata and metadata.get("customer_names"):
|
||||
variant = variants.get("with_customers")
|
||||
if variant:
|
||||
return variant
|
||||
|
||||
# Check for order-related variants
|
||||
if "affected_orders" in metadata and metadata.get("affected_orders", 0) > 0:
|
||||
variant = variants.get("with_orders")
|
||||
if variant:
|
||||
return variant
|
||||
|
||||
# Check for supplier contact variants
|
||||
if "supplier_contact" in metadata:
|
||||
variant = variants.get("with_supplier")
|
||||
if variant:
|
||||
return variant
|
||||
|
||||
# Check for batch-related variants
|
||||
if "affected_batches" in metadata and metadata.get("affected_batches", 0) > 0:
|
||||
variant = variants.get("with_batches")
|
||||
if variant:
|
||||
return variant
|
||||
|
||||
# Check for product names list variants
|
||||
if "product_names" in metadata and metadata.get("product_names"):
|
||||
variant = variants.get("with_names")
|
||||
if variant:
|
||||
return variant
|
||||
|
||||
# Check for time duration variants
|
||||
if "hours_overdue" in metadata:
|
||||
variant = variants.get("with_hours")
|
||||
if variant:
|
||||
return variant
|
||||
|
||||
if "days_overdue" in metadata:
|
||||
variant = variants.get("with_days")
|
||||
if variant:
|
||||
return variant
|
||||
|
||||
# Default to generic variant
|
||||
return variants.get("generic", variants[list(variants.keys())[0]])
|
||||
|
||||
def _format_date(self, date_value: Any) -> str:
|
||||
"""
|
||||
Format date for display.
|
||||
|
||||
Accepts:
|
||||
- ISO string: "2025-12-10"
|
||||
- datetime object
|
||||
- date object
|
||||
|
||||
Returns: ISO format "YYYY-MM-DD"
|
||||
"""
|
||||
if isinstance(date_value, str):
|
||||
# Already a string, might be ISO format
|
||||
try:
|
||||
dt = datetime.fromisoformat(date_value.replace('Z', '+00:00'))
|
||||
return dt.date().isoformat()
|
||||
except ValueError:
|
||||
return date_value
|
||||
|
||||
if isinstance(date_value, datetime):
|
||||
return date_value.date().isoformat()
|
||||
|
||||
if hasattr(date_value, 'isoformat'):
|
||||
return date_value.isoformat()
|
||||
|
||||
return str(date_value)
|
||||
|
||||
def _format_day_name(self, date_value: Any) -> str:
|
||||
"""
|
||||
Format day name with date.
|
||||
|
||||
Example: "miércoles 10 de diciembre"
|
||||
|
||||
Note: Frontend will handle localization.
|
||||
For now, return ISO date and let frontend format.
|
||||
"""
|
||||
iso_date = self._format_date(date_value)
|
||||
|
||||
# The frontend localizes the day name, so the ISO date is returned unchanged
return iso_date
|
||||
|
||||
def _format_datetime(self, datetime_value: Any) -> str:
|
||||
"""
|
||||
Format datetime for display.
|
||||
|
||||
Returns: ISO 8601 format with timezone
|
||||
"""
|
||||
if isinstance(datetime_value, str):
|
||||
return datetime_value
|
||||
|
||||
if isinstance(datetime_value, datetime):
|
||||
return datetime_value.isoformat()
|
||||
|
||||
if hasattr(datetime_value, 'isoformat'):
|
||||
return datetime_value.isoformat()
|
||||
|
||||
return str(datetime_value)
|
||||
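The actual templates live in app.utils.message_templates, which is not shown in this commit, so the mapping below is a hypothetical one used only to illustrate the suffix-based transformations in _build_params.

```python
from app.enrichment.message_generator import MessageGenerator

generator = MessageGenerator()

# Hypothetical mapping: display parameter name -> metadata key
param_mapping = {
    "shortage_kg": "shortage_amount",
    "loss_eur": "financial_impact",
    "delivery_date": "expected_delivery",
}
metadata = {
    "shortage_amount": "12.345",
    "financial_impact": 54.0,
    "expected_delivery": "2025-12-10T06:30:00Z",
}

params = generator._build_params(param_mapping, metadata)
# _kg rounds to one decimal, _eur to two, _date normalises to YYYY-MM-DD
assert params == {"shortage_kg": 12.3, "loss_eur": 54.0, "delivery_date": "2025-12-10"}
```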
165
services/alert_processor/app/enrichment/orchestrator_client.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Orchestrator client for querying AI action context.
|
||||
|
||||
Queries the orchestrator service to determine if AI has already
|
||||
addressed the issue and what actions were taken.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
import httpx
|
||||
import structlog
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class OrchestratorClient:
|
||||
"""HTTP client for querying orchestrator service"""
|
||||
|
||||
def __init__(self, base_url: str = "http://orchestrator-service:8000"):
|
||||
"""
|
||||
Initialize orchestrator client.
|
||||
|
||||
Args:
|
||||
base_url: Base URL of orchestrator service
|
||||
"""
|
||||
self.base_url = base_url
|
||||
self.timeout = 10.0 # 10 second timeout
|
||||
|
||||
async def get_context(
|
||||
self,
|
||||
tenant_id: str,
|
||||
event_type: str,
|
||||
metadata: Dict[str, Any]
|
||||
) -> dict:
|
||||
"""
|
||||
Query orchestrator for AI action context.
|
||||
|
||||
Returns dict with:
|
||||
- already_addressed: Boolean - did AI handle this?
|
||||
- action_type: Type of action taken
|
||||
- action_id: ID of the action
|
||||
- action_summary: Human-readable summary
|
||||
- reasoning: AI reasoning for the action
|
||||
- confidence: Confidence score (0-1)
|
||||
- estimated_savings_eur: Estimated savings
|
||||
- prevented_issue: What issue was prevented
|
||||
- created_at: When action was created
|
||||
"""
|
||||
|
||||
context = {
|
||||
"already_addressed": False,
|
||||
"confidence": 0.8 # Default confidence
|
||||
}
|
||||
|
||||
try:
|
||||
# Build query based on event type and metadata
|
||||
query_params = self._build_query_params(event_type, metadata)
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/internal/recent-actions",
|
||||
params={
|
||||
"tenant_id": tenant_id,
|
||||
**query_params
|
||||
},
|
||||
headers={
|
||||
"x-internal-service": "alert-intelligence"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
context.update(self._parse_response(data, event_type, metadata))
|
||||
|
||||
elif response.status_code == 404:
|
||||
# No recent actions found - that's okay
|
||||
logger.debug("no_orchestrator_actions", tenant_id=tenant_id, event_type=event_type)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
"orchestrator_query_failed",
|
||||
status_code=response.status_code,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("orchestrator_timeout", tenant_id=tenant_id, event_type=event_type)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("orchestrator_query_error", error=str(e), tenant_id=tenant_id)
|
||||
|
||||
return context
|
||||
|
||||
def _build_query_params(self, event_type: str, metadata: Dict[str, Any]) -> dict:
|
||||
"""Build query parameters based on event type"""
|
||||
params = {}
|
||||
|
||||
# For stock-related alerts, query for PO actions
|
||||
if "stock" in event_type or "shortage" in event_type:
|
||||
if metadata.get("ingredient_id"):
|
||||
params["related_entity_type"] = "ingredient"
|
||||
params["related_entity_id"] = metadata["ingredient_id"]
|
||||
params["action_types"] = "purchase_order_created,purchase_order_approved"
|
||||
|
||||
# For production delays, query for batch adjustments
|
||||
elif "production" in event_type or "delay" in event_type:
|
||||
if metadata.get("batch_id"):
|
||||
params["related_entity_type"] = "production_batch"
|
||||
params["related_entity_id"] = metadata["batch_id"]
|
||||
params["action_types"] = "production_adjusted,batch_rescheduled"
|
||||
|
||||
# For PO approval, check if already approved
|
||||
elif "po_approval" in event_type:
|
||||
if metadata.get("po_id"):
|
||||
params["related_entity_type"] = "purchase_order"
|
||||
params["related_entity_id"] = metadata["po_id"]
|
||||
params["action_types"] = "purchase_order_approved,purchase_order_rejected"
|
||||
|
||||
# Look for recent actions (last 24 hours)
|
||||
params["since_hours"] = 24
|
||||
|
||||
return params
|
||||
|
||||
def _parse_response(
|
||||
self,
|
||||
data: dict,
|
||||
event_type: str,
|
||||
metadata: Dict[str, Any]
|
||||
) -> dict:
|
||||
"""Parse orchestrator response into context"""
|
||||
|
||||
if not data or not data.get("actions"):
|
||||
return {"already_addressed": False}
|
||||
|
||||
# Get most recent action
|
||||
actions = data.get("actions", [])
|
||||
if not actions:
|
||||
return {"already_addressed": False}
|
||||
|
||||
most_recent = actions[0]
|
||||
|
||||
context = {
|
||||
"already_addressed": True,
|
||||
"action_type": most_recent.get("action_type"),
|
||||
"action_id": most_recent.get("id"),
|
||||
"action_summary": most_recent.get("summary", ""),
|
||||
"reasoning": most_recent.get("reasoning", {}),
|
||||
"confidence": most_recent.get("confidence", 0.8),
|
||||
"created_at": most_recent.get("created_at"),
|
||||
"action_status": most_recent.get("status", "completed")
|
||||
}
|
||||
|
||||
# Extract specific fields based on action type
|
||||
if most_recent.get("action_type") == "purchase_order_created":
|
||||
context["estimated_savings_eur"] = most_recent.get("estimated_savings_eur", 0)
|
||||
context["prevented_issue"] = "stockout"
|
||||
|
||||
if most_recent.get("delivery_date"):
|
||||
context["delivery_date"] = most_recent["delivery_date"]
|
||||
|
||||
elif most_recent.get("action_type") == "production_adjusted":
|
||||
context["prevented_issue"] = "production_delay"
|
||||
context["adjustment_type"] = most_recent.get("adjustment_type")
|
||||
|
||||
return context
|
||||
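A minimal sketch of consulting the orchestrator during enrichment; the default base URL comes from the constructor above, while the tenant, ingredient id, and the shape of the remote response are assumptions about services outside this commit.

```python
import asyncio

from app.enrichment.orchestrator_client import OrchestratorClient


async def main() -> None:
    client = OrchestratorClient()  # defaults to http://orchestrator-service:8000
    context = await client.get_context(
        tenant_id="00000000-0000-0000-0000-000000000001",
        event_type="critical_stock_shortage",
        metadata={"ingredient_id": "ing-123"},
    )
    if context["already_addressed"]:
        print("AI already acted:", context.get("action_summary"))
    else:
        print("No recent AI action; the alert will surface normally")


if __name__ == "__main__":
    asyncio.run(main())
```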
256
services/alert_processor/app/enrichment/priority_scorer.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
Multi-factor priority scoring for alerts.
|
||||
|
||||
Calculates priority score (0-100) based on:
|
||||
- Business impact (40%): Financial impact, affected orders, customer impact
|
||||
- Urgency (30%): Time until consequence, deadlines
|
||||
- User agency (20%): Can user fix it? External dependencies?
|
||||
- Confidence (10%): AI confidence in assessment
|
||||
|
||||
Also applies escalation boosts for age and deadline proximity.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PriorityScorer:
|
||||
"""Calculate multi-factor priority score (0-100)"""
|
||||
|
||||
# Weights for priority calculation
|
||||
BUSINESS_IMPACT_WEIGHT = 0.4
|
||||
URGENCY_WEIGHT = 0.3
|
||||
USER_AGENCY_WEIGHT = 0.2
|
||||
CONFIDENCE_WEIGHT = 0.1
|
||||
|
||||
# Priority thresholds
|
||||
CRITICAL_THRESHOLD = 90
|
||||
IMPORTANT_THRESHOLD = 70
|
||||
STANDARD_THRESHOLD = 50
|
||||
|
||||
def calculate_priority(
|
||||
self,
|
||||
business_impact: dict,
|
||||
urgency: dict,
|
||||
user_agency: dict,
|
||||
orchestrator_context: dict
|
||||
) -> int:
|
||||
"""
|
||||
Calculate weighted priority score.
|
||||
|
||||
Args:
|
||||
business_impact: Business impact context
|
||||
urgency: Urgency context
|
||||
user_agency: User agency context
|
||||
orchestrator_context: AI orchestrator context
|
||||
|
||||
Returns:
|
||||
Priority score (0-100)
|
||||
"""
|
||||
|
||||
# Score each dimension (0-100)
|
||||
impact_score = self._score_business_impact(business_impact)
|
||||
urgency_score = self._score_urgency(urgency)
|
||||
agency_score = self._score_user_agency(user_agency)
|
||||
confidence_score = orchestrator_context.get("confidence", 0.8) * 100
|
||||
|
||||
# Weighted average
|
||||
total_score = (
|
||||
impact_score * self.BUSINESS_IMPACT_WEIGHT +
|
||||
urgency_score * self.URGENCY_WEIGHT +
|
||||
agency_score * self.USER_AGENCY_WEIGHT +
|
||||
confidence_score * self.CONFIDENCE_WEIGHT
|
||||
)
|
||||
|
||||
# Apply escalation boost if needed
|
||||
escalation_boost = self._calculate_escalation_boost(urgency)
|
||||
total_score = min(100, total_score + escalation_boost)
|
||||
|
||||
score = int(total_score)
|
||||
|
||||
logger.debug(
|
||||
"priority_calculated",
|
||||
score=score,
|
||||
impact_score=impact_score,
|
||||
urgency_score=urgency_score,
|
||||
agency_score=agency_score,
|
||||
confidence_score=confidence_score,
|
||||
escalation_boost=escalation_boost
|
||||
)
|
||||
|
||||
return score
|
||||
|
||||
def _score_business_impact(self, impact: dict) -> int:
|
||||
"""
|
||||
Score business impact (0-100).
|
||||
|
||||
Considers:
|
||||
- Financial impact in EUR
|
||||
- Number of affected orders
|
||||
- Customer impact level
|
||||
- Production delays
|
||||
- Revenue at risk
|
||||
"""
|
||||
score = 50 # Base score
|
||||
|
||||
# Financial impact
|
||||
financial_impact = impact.get("financial_impact_eur", 0)
|
||||
if financial_impact > 1000:
|
||||
score += 30
|
||||
elif financial_impact > 500:
|
||||
score += 20
|
||||
elif financial_impact > 100:
|
||||
score += 10
|
||||
|
||||
# Affected orders
|
||||
affected_orders = impact.get("affected_orders", 0)
|
||||
if affected_orders > 10:
|
||||
score += 15
|
||||
elif affected_orders > 5:
|
||||
score += 10
|
||||
elif affected_orders > 0:
|
||||
score += 5
|
||||
|
||||
# Customer impact
|
||||
customer_impact = impact.get("customer_impact", "low")
|
||||
if customer_impact == "high":
|
||||
score += 15
|
||||
elif customer_impact == "medium":
|
||||
score += 5
|
||||
|
||||
# Production delay hours
|
||||
production_delay_hours = impact.get("production_delay_hours", 0)
|
||||
if production_delay_hours > 4:
|
||||
score += 10
|
||||
elif production_delay_hours > 2:
|
||||
score += 5
|
||||
|
||||
# Revenue loss
|
||||
revenue_loss = impact.get("estimated_revenue_loss_eur", 0)
|
||||
if revenue_loss > 500:
|
||||
score += 10
|
||||
elif revenue_loss > 200:
|
||||
score += 5
|
||||
|
||||
return min(100, score)
|
||||
|
||||
def _score_urgency(self, urgency: dict) -> int:
|
||||
"""
|
||||
Score urgency (0-100).
|
||||
|
||||
Considers:
|
||||
- Time until consequence
|
||||
- Can it wait until tomorrow?
|
||||
- Deadline proximity
|
||||
- Peak hour relevance
|
||||
"""
|
||||
score = 50 # Base score
|
||||
|
||||
# Time until consequence
|
||||
hours_until = urgency.get("hours_until_consequence", 24)
|
||||
if hours_until < 2:
|
||||
score += 40
|
||||
elif hours_until < 6:
|
||||
score += 30
|
||||
elif hours_until < 12:
|
||||
score += 20
|
||||
elif hours_until < 24:
|
||||
score += 10
|
||||
|
||||
# Can it wait?
|
||||
if not urgency.get("can_wait_until_tomorrow", True):
|
||||
score += 10
|
||||
|
||||
# Deadline present
|
||||
if urgency.get("deadline_utc"):
|
||||
score += 5
|
||||
|
||||
# Peak hour relevant (production/demand related)
|
||||
if urgency.get("peak_hour_relevant", False):
|
||||
score += 5
|
||||
|
||||
return min(100, score)
|
||||
|
||||
def _score_user_agency(self, agency: dict) -> int:
|
||||
"""
|
||||
Score user agency (0-100).
|
||||
|
||||
Higher score when user CAN fix the issue.
|
||||
Lower score when blocked or requires external parties.
|
||||
|
||||
Considers:
|
||||
- Can user fix it?
|
||||
- Requires external party?
|
||||
- Has blockers?
|
||||
- Suggested workarounds available?
|
||||
"""
|
||||
score = 50 # Base score
|
||||
|
||||
# Can user fix?
|
||||
if agency.get("can_user_fix", False):
|
||||
score += 30
|
||||
else:
|
||||
score -= 20
|
||||
|
||||
# Requires external party?
|
||||
if agency.get("requires_external_party", False):
|
||||
score -= 10
|
||||
|
||||
# Has blockers?
|
||||
blockers = agency.get("blockers", [])
|
||||
score -= len(blockers) * 5
|
||||
|
||||
# Has suggested workaround?
|
||||
if agency.get("suggested_workaround"):
|
||||
score += 5
|
||||
|
||||
return max(0, min(100, score))
|
||||
|
||||
def _calculate_escalation_boost(self, urgency: dict) -> int:
|
||||
"""
|
||||
Calculate escalation boost for pending alerts.
|
||||
|
||||
Boosts priority for:
|
||||
- Age-based escalation (pending >48h, >72h)
|
||||
- Deadline proximity (<6h, <24h)
|
||||
|
||||
Maximum boost: +30 points
|
||||
"""
|
||||
boost = 0
|
||||
|
||||
# Age-based escalation
|
||||
hours_pending = urgency.get("hours_pending", 0)
|
||||
if hours_pending > 72:
|
||||
boost += 20
|
||||
elif hours_pending > 48:
|
||||
boost += 10
|
||||
|
||||
# Deadline proximity
|
||||
hours_until = urgency.get("hours_until_consequence", 24)
|
||||
if hours_until < 6:
|
||||
boost += 30
|
||||
elif hours_until < 24:
|
||||
boost += 15
|
||||
|
||||
# Cap at +30
|
||||
return min(30, boost)
|
||||
|
||||
def get_priority_level(self, score: int) -> str:
|
||||
"""
|
||||
Convert numeric score to priority level.
|
||||
|
||||
- 90-100: critical
|
||||
- 70-89: important
|
||||
- 50-69: standard
|
||||
- 0-49: info
|
||||
"""
|
||||
if score >= self.CRITICAL_THRESHOLD:
|
||||
return "critical"
|
||||
elif score >= self.IMPORTANT_THRESHOLD:
|
||||
return "important"
|
||||
elif score >= self.STANDARD_THRESHOLD:
|
||||
return "standard"
|
||||
else:
|
||||
return "info"
|
||||
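To make the weighting concrete, a small worked example; the input dictionaries are invented, but the sub-scores and the final value follow from the scoring rules above.

```python
from app.enrichment.priority_scorer import PriorityScorer

scorer = PriorityScorer()

score = scorer.calculate_priority(
    business_impact={
        "financial_impact_eur": 750,      # +20
        "affected_orders": 3,             # +5
        "customer_impact": "medium",      # +5  -> impact score 80
    },
    urgency={
        "hours_until_consequence": 10,    # +20
        "can_wait_until_tomorrow": False, # +10 -> urgency score 80
    },
    user_agency={"can_user_fix": True},   # +30 -> agency score 80
    orchestrator_context={"confidence": 0.9},  # -> confidence score 90
)

# Weighted: 80*0.4 + 80*0.3 + 80*0.2 + 90*0.1 = 81, plus a +15 boost because the
# consequence is less than 24 hours away.
print(score, scorer.get_priority_level(score))  # 96 critical
```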
304
services/alert_processor/app/enrichment/smart_actions.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Smart action generator for alerts.
|
||||
|
||||
Generates actionable buttons with deep links, phone numbers,
|
||||
and other interactive elements based on alert type and metadata.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SmartActionGenerator:
|
||||
"""Generate smart action buttons for alerts"""
|
||||
|
||||
def generate_actions(
|
||||
self,
|
||||
event_type: str,
|
||||
metadata: Dict[str, Any],
|
||||
orchestrator_context: dict
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Generate smart actions for an event.
|
||||
|
||||
Each action has:
|
||||
- action_type: Identifier for frontend handling
|
||||
- label_key: i18n key for button label
|
||||
- label_params: Parameters for label translation
|
||||
- variant: primary/secondary/danger/ghost
|
||||
- disabled: Boolean
|
||||
- disabled_reason_key: i18n key if disabled
|
||||
- consequence_key: i18n key for confirmation dialog
|
||||
- url: Deep link or tel: or mailto:
|
||||
- metadata: Additional data for action
|
||||
"""
|
||||
|
||||
actions = []
|
||||
|
||||
# If AI already addressed, show "View Action" button
|
||||
if orchestrator_context and orchestrator_context.get("already_addressed"):
|
||||
actions.append(self._create_view_action(orchestrator_context))
|
||||
return actions
|
||||
|
||||
# Generate actions based on event type
|
||||
if "po_approval" in event_type:
|
||||
actions.extend(self._create_po_approval_actions(metadata))
|
||||
|
||||
elif "stock" in event_type or "shortage" in event_type:
|
||||
actions.extend(self._create_stock_actions(metadata))
|
||||
|
||||
elif "production" in event_type or "delay" in event_type:
|
||||
actions.extend(self._create_production_actions(metadata))
|
||||
|
||||
elif "equipment" in event_type:
|
||||
actions.extend(self._create_equipment_actions(metadata))
|
||||
|
||||
elif "delivery" in event_type or "overdue" in event_type:
|
||||
actions.extend(self._create_delivery_actions(metadata))
|
||||
|
||||
elif "temperature" in event_type:
|
||||
actions.extend(self._create_temperature_actions(metadata))
|
||||
|
||||
# Always add common actions
|
||||
actions.extend(self._create_common_actions())
|
||||
|
||||
return actions
|
||||
|
||||
def _create_view_action(self, orchestrator_context: dict) -> dict:
|
||||
"""Create action to view what AI did"""
|
||||
return {
|
||||
"action_type": "open_reasoning",
|
||||
"label_key": "actions.view_ai_action",
|
||||
"label_params": {},
|
||||
"variant": "primary",
|
||||
"disabled": False,
|
||||
"metadata": {
|
||||
"action_id": orchestrator_context.get("action_id"),
|
||||
"action_type": orchestrator_context.get("action_type")
|
||||
}
|
||||
}
|
||||
|
||||
def _create_po_approval_actions(self, metadata: Dict[str, Any]) -> List[dict]:
|
||||
"""Create actions for PO approval alerts"""
|
||||
po_id = metadata.get("po_id")
|
||||
po_amount = metadata.get("total_amount", metadata.get("po_amount", 0))
|
||||
|
||||
return [
|
||||
{
|
||||
"action_type": "approve_po",
|
||||
"label_key": "actions.approve_po",
|
||||
"label_params": {"amount": po_amount},
|
||||
"variant": "primary",
|
||||
"disabled": False,
|
||||
"consequence_key": "actions.approve_po_consequence",
|
||||
"url": f"/app/procurement/purchase-orders/{po_id}",
|
||||
"metadata": {"po_id": po_id, "amount": po_amount}
|
||||
},
|
||||
{
|
||||
"action_type": "reject_po",
|
||||
"label_key": "actions.reject_po",
|
||||
"label_params": {},
|
||||
"variant": "danger",
|
||||
"disabled": False,
|
||||
"consequence_key": "actions.reject_po_consequence",
|
||||
"url": f"/app/procurement/purchase-orders/{po_id}",
|
||||
"metadata": {"po_id": po_id}
|
||||
},
|
||||
{
|
||||
"action_type": "modify_po",
|
||||
"label_key": "actions.modify_po",
|
||||
"label_params": {},
|
||||
"variant": "secondary",
|
||||
"disabled": False,
|
||||
"url": f"/app/procurement/purchase-orders/{po_id}/edit",
|
||||
"metadata": {"po_id": po_id}
|
||||
}
|
||||
]
|
||||
|
||||
def _create_stock_actions(self, metadata: Dict[str, Any]) -> List[dict]:
|
||||
"""Create actions for stock-related alerts"""
|
||||
actions = []
|
||||
|
||||
# If supplier info available, add call button
|
||||
if metadata.get("supplier_contact"):
|
||||
actions.append({
|
||||
"action_type": "call_supplier",
|
||||
"label_key": "actions.call_supplier",
|
||||
"label_params": {
|
||||
"supplier": metadata.get("supplier_name", "Supplier"),
|
||||
"phone": metadata.get("supplier_contact")
|
||||
},
|
||||
"variant": "primary",
|
||||
"disabled": False,
|
||||
"url": f"tel:{metadata['supplier_contact']}",
|
||||
"metadata": {
|
||||
"supplier_name": metadata.get("supplier_name"),
|
||||
"phone": metadata.get("supplier_contact")
|
||||
}
|
||||
})
|
||||
|
||||
# If PO exists, add view PO button
|
||||
if metadata.get("po_id"):
|
||||
if metadata.get("po_status") == "pending_approval":
|
||||
actions.append({
|
||||
"action_type": "approve_po",
|
||||
"label_key": "actions.approve_po",
|
||||
"label_params": {"amount": metadata.get("po_amount", 0)},
|
||||
"variant": "primary",
|
||||
"disabled": False,
|
||||
"url": f"/app/procurement/purchase-orders/{metadata['po_id']}",
|
||||
"metadata": {"po_id": metadata["po_id"]}
|
||||
})
|
||||
else:
|
||||
actions.append({
|
||||
"action_type": "view_po",
|
||||
"label_key": "actions.view_po",
|
||||
"label_params": {"po_number": metadata.get("po_number", metadata["po_id"])},
|
||||
"variant": "secondary",
|
||||
"disabled": False,
|
||||
"url": f"/app/procurement/purchase-orders/{metadata['po_id']}",
|
||||
"metadata": {"po_id": metadata["po_id"]}
|
||||
})
|
||||
|
||||
# Add create PO button if no PO exists
|
||||
else:
|
||||
actions.append({
|
||||
"action_type": "create_po",
|
||||
"label_key": "actions.create_po",
|
||||
"label_params": {},
|
||||
"variant": "primary",
|
||||
"disabled": False,
|
||||
"url": f"/app/procurement/purchase-orders/new?ingredient_id={metadata.get('ingredient_id')}",
|
||||
"metadata": {"ingredient_id": metadata.get("ingredient_id")}
|
||||
})
|
||||
|
||||
return actions
|
||||
|
||||
def _create_production_actions(self, metadata: Dict[str, Any]) -> List[dict]:
|
||||
"""Create actions for production-related alerts"""
|
||||
actions = []
|
||||
|
||||
if metadata.get("batch_id"):
|
||||
actions.append({
|
||||
"action_type": "view_batch",
|
||||
"label_key": "actions.view_batch",
|
||||
"label_params": {"batch_number": metadata.get("batch_number", "")},
|
||||
"variant": "primary",
|
||||
"disabled": False,
|
||||
"url": f"/app/production/batches/{metadata['batch_id']}",
|
||||
"metadata": {"batch_id": metadata["batch_id"]}
|
||||
})
|
||||
|
||||
actions.append({
|
||||
"action_type": "adjust_production",
|
||||
"label_key": "actions.adjust_production",
|
||||
"label_params": {},
|
||||
"variant": "secondary",
|
||||
"disabled": False,
|
||||
"url": f"/app/production/batches/{metadata['batch_id']}/adjust",
|
||||
"metadata": {"batch_id": metadata["batch_id"]}
|
||||
})
|
||||
|
||||
return actions
|
||||
|
||||
def _create_equipment_actions(self, metadata: Dict[str, Any]) -> List[dict]:
|
||||
"""Create actions for equipment-related alerts"""
|
||||
return [
|
||||
{
|
||||
"action_type": "view_equipment",
|
||||
"label_key": "actions.view_equipment",
|
||||
"label_params": {"equipment_name": metadata.get("equipment_name", "")},
|
||||
"variant": "primary",
|
||||
"disabled": False,
|
||||
"url": f"/app/production/equipment/{metadata.get('equipment_id')}",
|
||||
"metadata": {"equipment_id": metadata.get("equipment_id")}
|
||||
},
|
||||
{
|
||||
"action_type": "schedule_maintenance",
|
||||
"label_key": "actions.schedule_maintenance",
|
||||
"label_params": {},
|
||||
"variant": "secondary",
|
||||
"disabled": False,
|
||||
"url": f"/app/production/equipment/{metadata.get('equipment_id')}/maintenance",
|
||||
"metadata": {"equipment_id": metadata.get("equipment_id")}
|
||||
}
|
||||
]
|
||||
|
||||
def _create_delivery_actions(self, metadata: Dict[str, Any]) -> List[dict]:
|
||||
"""Create actions for delivery-related alerts"""
|
||||
actions = []
|
||||
|
||||
if metadata.get("supplier_contact"):
|
||||
actions.append({
|
||||
"action_type": "call_supplier",
|
||||
"label_key": "actions.call_supplier",
|
||||
"label_params": {
|
||||
"supplier": metadata.get("supplier_name", "Supplier"),
|
||||
"phone": metadata.get("supplier_contact")
|
||||
},
|
||||
"variant": "primary",
|
||||
"disabled": False,
|
||||
"url": f"tel:{metadata['supplier_contact']}",
|
||||
"metadata": {
|
||||
"supplier_name": metadata.get("supplier_name"),
|
||||
"phone": metadata.get("supplier_contact")
|
||||
}
|
||||
})
|
||||
|
||||
if metadata.get("po_id"):
|
||||
actions.append({
|
||||
"action_type": "view_po",
|
||||
"label_key": "actions.view_po",
|
||||
"label_params": {"po_number": metadata.get("po_number", "")},
|
||||
"variant": "secondary",
|
||||
"disabled": False,
|
||||
"url": f"/app/procurement/purchase-orders/{metadata['po_id']}",
|
||||
"metadata": {"po_id": metadata["po_id"]}
|
||||
})
|
||||
|
||||
return actions
|
||||
|
||||
def _create_temperature_actions(self, metadata: Dict[str, Any]) -> List[dict]:
|
||||
"""Create actions for temperature breach alerts"""
|
||||
return [
|
||||
{
|
||||
"action_type": "view_sensor",
|
||||
"label_key": "actions.view_sensor",
|
||||
"label_params": {"location": metadata.get("location", "")},
|
||||
"variant": "primary",
|
||||
"disabled": False,
|
||||
"url": f"/app/inventory/sensors/{metadata.get('sensor_id')}",
|
||||
"metadata": {"sensor_id": metadata.get("sensor_id")}
|
||||
},
|
||||
{
|
||||
"action_type": "acknowledge_breach",
|
||||
"label_key": "actions.acknowledge_breach",
|
||||
"label_params": {},
|
||||
"variant": "secondary",
|
||||
"disabled": False,
|
||||
"metadata": {"sensor_id": metadata.get("sensor_id")}
|
||||
}
|
||||
]
|
||||
|
||||
def _create_common_actions(self) -> List[dict]:
|
||||
"""Create common actions available for all alerts"""
|
||||
return [
|
||||
{
|
||||
"action_type": "snooze",
|
||||
"label_key": "actions.snooze",
|
||||
"label_params": {"hours": 4},
|
||||
"variant": "ghost",
|
||||
"disabled": False,
|
||||
"metadata": {"snooze_hours": 4}
|
||||
},
|
||||
{
|
||||
"action_type": "dismiss",
|
||||
"label_key": "actions.dismiss",
|
||||
"label_params": {},
|
||||
"variant": "ghost",
|
||||
"disabled": False,
|
||||
"metadata": {}
|
||||
}
|
||||
]
|
||||
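A hedged usage sketch for SmartActionGenerator; the metadata values below are illustrative, not taken from a real event:
# Sketch: actions for a stock shortage with supplier contact info but no PO yet.
from app.enrichment.smart_actions import SmartActionGenerator

generator = SmartActionGenerator()
actions = generator.generate_actions(
    event_type="critical_stock_shortage",
    metadata={
        "ingredient_id": "a1b2c3",            # hypothetical identifiers
        "supplier_name": "Flour Co.",
        "supplier_contact": "+49 30 1234567",
    },
    orchestrator_context={},                  # AI has not acted on this event
)
# Expected action_types: call_supplier (tel: link), create_po, snooze, dismiss.
print([a["action_type"] for a in actions])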
173
services/alert_processor/app/enrichment/urgency_analyzer.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Urgency analyzer for alerts.
|
||||
|
||||
Assesses time sensitivity, deadlines, and determines if action can wait.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class UrgencyAnalyzer:
|
||||
"""Analyze urgency from event metadata"""
|
||||
|
||||
def analyze(self, event_type: str, metadata: Dict[str, Any]) -> dict:
|
||||
"""
|
||||
Analyze urgency for an event.
|
||||
|
||||
Returns dict with:
|
||||
- hours_until_consequence: Time until impact occurs
|
||||
- can_wait_until_tomorrow: Boolean
|
||||
- deadline_utc: ISO datetime if deadline exists
|
||||
- peak_hour_relevant: Boolean
|
||||
- hours_pending: Age of alert
|
||||
"""
|
||||
|
||||
urgency = {
|
||||
"hours_until_consequence": 24, # Default: 24 hours
|
||||
"can_wait_until_tomorrow": True,
|
||||
"deadline_utc": None,
|
||||
"peak_hour_relevant": False,
|
||||
"hours_pending": 0
|
||||
}
|
||||
|
||||
# Calculate based on event type
|
||||
if "critical" in event_type or "urgent" in event_type:
|
||||
urgency["hours_until_consequence"] = 2
|
||||
urgency["can_wait_until_tomorrow"] = False
|
||||
|
||||
elif "production" in event_type:
|
||||
urgency.update(self._analyze_production_urgency(metadata))
|
||||
|
||||
elif "stock" in event_type or "shortage" in event_type:
|
||||
urgency.update(self._analyze_stock_urgency(metadata))
|
||||
|
||||
elif "delivery" in event_type or "overdue" in event_type:
|
||||
urgency.update(self._analyze_delivery_urgency(metadata))
|
||||
|
||||
elif "po_approval" in event_type:
|
||||
urgency.update(self._analyze_po_approval_urgency(metadata))
|
||||
|
||||
# Check for explicit deadlines
|
||||
if "required_delivery_date" in metadata:
|
||||
urgency.update(self._calculate_deadline_urgency(metadata["required_delivery_date"]))
|
||||
|
||||
if "production_date" in metadata:
|
||||
urgency.update(self._calculate_deadline_urgency(metadata["production_date"]))
|
||||
|
||||
if "expected_date" in metadata:
|
||||
urgency.update(self._calculate_deadline_urgency(metadata["expected_date"]))
|
||||
|
||||
return urgency
|
||||
|
||||
def _analyze_production_urgency(self, metadata: Dict[str, Any]) -> dict:
|
||||
"""Analyze urgency for production alerts"""
|
||||
urgency = {}
|
||||
|
||||
delay_minutes = metadata.get("delay_minutes", 0)
|
||||
|
||||
if delay_minutes > 120:
|
||||
urgency["hours_until_consequence"] = 1
|
||||
urgency["can_wait_until_tomorrow"] = False
|
||||
elif delay_minutes > 60:
|
||||
urgency["hours_until_consequence"] = 4
|
||||
urgency["can_wait_until_tomorrow"] = False
|
||||
else:
|
||||
urgency["hours_until_consequence"] = 8
|
||||
|
||||
# Production is peak-hour sensitive
|
||||
urgency["peak_hour_relevant"] = True
|
||||
|
||||
return urgency
|
||||
|
||||
def _analyze_stock_urgency(self, metadata: Dict[str, Any]) -> dict:
|
||||
"""Analyze urgency for stock alerts"""
|
||||
urgency = {}
|
||||
|
||||
# Hours until needed
|
||||
if "hours_until" in metadata:
|
||||
urgency["hours_until_consequence"] = metadata["hours_until"]
|
||||
urgency["can_wait_until_tomorrow"] = urgency["hours_until_consequence"] > 24
|
||||
|
||||
# Days until expiry
|
||||
elif "days_until_expiry" in metadata:
|
||||
days = metadata["days_until_expiry"]
|
||||
if days <= 1:
|
||||
urgency["hours_until_consequence"] = days * 24
|
||||
urgency["can_wait_until_tomorrow"] = False
|
||||
else:
|
||||
urgency["hours_until_consequence"] = days * 24
|
||||
|
||||
return urgency
|
||||
|
||||
def _analyze_delivery_urgency(self, metadata: Dict[str, Any]) -> dict:
|
||||
"""Analyze urgency for delivery alerts"""
|
||||
urgency = {}
|
||||
|
||||
days_overdue = metadata.get("days_overdue", 0)
|
||||
|
||||
if days_overdue > 3:
|
||||
urgency["hours_until_consequence"] = 2
|
||||
urgency["can_wait_until_tomorrow"] = False
|
||||
elif days_overdue > 1:
|
||||
urgency["hours_until_consequence"] = 8
|
||||
urgency["can_wait_until_tomorrow"] = False
|
||||
|
||||
return urgency
|
||||
|
||||
def _analyze_po_approval_urgency(self, metadata: Dict[str, Any]) -> dict:
|
||||
"""
|
||||
Analyze urgency for PO approval alerts.
|
||||
|
||||
Uses stockout time (when you run out of stock) instead of delivery date
|
||||
to determine true urgency.
|
||||
"""
|
||||
urgency = {}
|
||||
|
||||
# Extract min_depletion_hours from reasoning_data.parameters
|
||||
reasoning_data = metadata.get("reasoning_data", {})
|
||||
parameters = reasoning_data.get("parameters", {})
|
||||
min_depletion_hours = parameters.get("min_depletion_hours")
|
||||
|
||||
if min_depletion_hours is not None:
|
||||
urgency["hours_until_consequence"] = max(0, round(min_depletion_hours, 1))
|
||||
urgency["can_wait_until_tomorrow"] = min_depletion_hours > 24
|
||||
|
||||
# Set deadline_utc to when stock runs out
|
||||
now = datetime.now(timezone.utc)
|
||||
stockout_time = now + timedelta(hours=min_depletion_hours)
|
||||
urgency["deadline_utc"] = stockout_time.isoformat()
|
||||
|
||||
logger.info(
|
||||
"po_approval_urgency_calculated",
|
||||
min_depletion_hours=min_depletion_hours,
|
||||
stockout_deadline=urgency["deadline_utc"],
|
||||
can_wait=urgency["can_wait_until_tomorrow"]
|
||||
)
|
||||
|
||||
return urgency
|
||||
|
||||
def _calculate_deadline_urgency(self, deadline_str: str) -> dict:
|
||||
"""Calculate urgency based on deadline"""
|
||||
try:
|
||||
if isinstance(deadline_str, str):
|
||||
deadline = datetime.fromisoformat(deadline_str.replace('Z', '+00:00'))
|
||||
else:
|
||||
deadline = deadline_str
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
time_until = deadline - now
|
||||
|
||||
hours_until = time_until.total_seconds() / 3600
|
||||
|
||||
return {
|
||||
"deadline_utc": deadline.isoformat(),
|
||||
"hours_until_consequence": max(0, round(hours_until, 1)),
|
||||
"can_wait_until_tomorrow": hours_until > 24
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning("deadline_parse_failed", deadline=deadline_str, error=str(e))
|
||||
return {}
|
||||
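A short usage sketch for UrgencyAnalyzer, assuming a PO-approval event whose reasoning data reports roughly ten hours until stockout (values illustrative):
# Sketch: urgency for a PO approval whose stock depletes in roughly 10 hours.
from app.enrichment.urgency_analyzer import UrgencyAnalyzer

analyzer = UrgencyAnalyzer()
urgency = analyzer.analyze(
    event_type="po_approval_required",
    metadata={"reasoning_data": {"parameters": {"min_depletion_hours": 10}}},
)
# hours_until_consequence == 10, can_wait_until_tomorrow is False,
# and deadline_utc points at the projected stockout time.
print(urgency)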
116
services/alert_processor/app/enrichment/user_agency.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
User agency analyzer for alerts.
|
||||
|
||||
Determines whether user can fix the issue, what blockers exist,
|
||||
and if external parties are required.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class UserAgencyAnalyzer:
|
||||
"""Analyze user's ability to act on alerts"""
|
||||
|
||||
def analyze(
|
||||
self,
|
||||
event_type: str,
|
||||
metadata: Dict[str, Any],
|
||||
orchestrator_context: dict
|
||||
) -> dict:
|
||||
"""
|
||||
Analyze user agency for an event.
|
||||
|
||||
Returns dict with:
|
||||
- can_user_fix: Boolean - can user resolve this?
|
||||
- requires_external_party: Boolean
|
||||
- external_party_name: Name of required party
|
||||
- external_party_contact: Contact info
|
||||
- blockers: List of blocking factors
|
||||
- suggested_workaround: Optional workaround suggestion
|
||||
"""
|
||||
|
||||
agency = {
|
||||
"can_user_fix": True,
|
||||
"requires_external_party": False,
|
||||
"external_party_name": None,
|
||||
"external_party_contact": None,
|
||||
"blockers": [],
|
||||
"suggested_workaround": None
|
||||
}
|
||||
|
||||
# If orchestrator already addressed it, user agency is low
|
||||
if orchestrator_context and orchestrator_context.get("already_addressed"):
|
||||
agency["can_user_fix"] = False
|
||||
agency["blockers"].append("ai_already_handled")
|
||||
return agency
|
||||
|
||||
# Analyze based on event type
|
||||
if "po_approval" in event_type:
|
||||
agency["can_user_fix"] = True
|
||||
|
||||
elif "delivery" in event_type or "supplier" in event_type:
|
||||
agency.update(self._analyze_supplier_agency(metadata))
|
||||
|
||||
elif "equipment" in event_type:
|
||||
agency.update(self._analyze_equipment_agency(metadata))
|
||||
|
||||
elif "stock" in event_type:
|
||||
agency.update(self._analyze_stock_agency(metadata, orchestrator_context))
|
||||
|
||||
return agency
|
||||
|
||||
def _analyze_supplier_agency(self, metadata: Dict[str, Any]) -> dict:
|
||||
"""Analyze agency for supplier-related alerts"""
|
||||
agency = {
|
||||
"requires_external_party": True,
|
||||
"external_party_name": metadata.get("supplier_name"),
|
||||
"external_party_contact": metadata.get("supplier_contact")
|
||||
}
|
||||
|
||||
# User can contact supplier but can't directly fix
|
||||
if not metadata.get("supplier_contact"):
|
||||
agency["blockers"].append("no_supplier_contact")
|
||||
|
||||
return agency
|
||||
|
||||
def _analyze_equipment_agency(self, metadata: Dict[str, Any]) -> dict:
|
||||
"""Analyze agency for equipment-related alerts"""
|
||||
agency = {}
|
||||
|
||||
equipment_type = metadata.get("equipment_type", "")
|
||||
|
||||
if "oven" in equipment_type.lower() or "mixer" in equipment_type.lower():
|
||||
agency["requires_external_party"] = True
|
||||
agency["external_party_name"] = "Maintenance Team"
|
||||
agency["blockers"].append("requires_technician")
|
||||
|
||||
return agency
|
||||
|
||||
def _analyze_stock_agency(
|
||||
self,
|
||||
metadata: Dict[str, Any],
|
||||
orchestrator_context: dict
|
||||
) -> dict:
|
||||
"""Analyze agency for stock-related alerts"""
|
||||
agency = {}
|
||||
|
||||
# If PO exists, user just needs to approve
|
||||
if metadata.get("po_id"):
|
||||
if metadata.get("po_status") == "pending_approval":
|
||||
agency["can_user_fix"] = True
|
||||
agency["suggested_workaround"] = "Approve pending PO"
|
||||
else:
|
||||
agency["blockers"].append("waiting_for_delivery")
|
||||
agency["requires_external_party"] = True
|
||||
agency["external_party_name"] = metadata.get("supplier_name")
|
||||
|
||||
# If no PO, user needs to create one
|
||||
elif metadata.get("supplier_name"):
|
||||
agency["can_user_fix"] = True
|
||||
agency["requires_external_party"] = True
|
||||
agency["external_party_name"] = metadata.get("supplier_name")
|
||||
|
||||
return agency
|
||||
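A usage sketch for UserAgencyAnalyzer with an illustrative stock alert that already has a PO pending approval:
# Sketch: agency for a stock alert where a PO is already pending approval.
from app.enrichment.user_agency import UserAgencyAnalyzer

analyzer = UserAgencyAnalyzer()
agency = analyzer.analyze(
    event_type="critical_stock_shortage",
    metadata={"po_id": "po-123", "po_status": "pending_approval"},   # hypothetical IDs
    orchestrator_context={},
)
# can_user_fix stays True and suggested_workaround is "Approve pending PO".
print(agency)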
100
services/alert_processor/app/main.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Alert Processor Service v2.0
|
||||
|
||||
Main FastAPI application with RabbitMQ consumer lifecycle management.
|
||||
"""
|
||||
|
||||
from typing import Optional

import structlog
|
||||
|
||||
from app.core.config import settings
|
||||
from app.consumer.event_consumer import EventConsumer
|
||||
from app.api import alerts, sse
|
||||
from shared.redis_utils import initialize_redis, close_redis
|
||||
from shared.service_base import StandardFastAPIService
|
||||
|
||||
# Initialize logger
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Global consumer instance
|
||||
consumer: Optional[EventConsumer] = None
|
||||
|
||||
|
||||
class AlertProcessorService(StandardFastAPIService):
|
||||
"""Alert Processor Service with standardized monitoring setup and RabbitMQ consumer"""
|
||||
|
||||
async def on_startup(self, app):
|
||||
"""Custom startup logic for Alert Processor"""
|
||||
global consumer
|
||||
|
||||
# Initialize Redis connection
|
||||
await initialize_redis(
|
||||
settings.REDIS_URL,
|
||||
db=settings.REDIS_DB,
|
||||
max_connections=settings.REDIS_MAX_CONNECTIONS
|
||||
)
|
||||
logger.info("redis_initialized")
|
||||
|
||||
# Start RabbitMQ consumer
|
||||
consumer = EventConsumer()
|
||||
await consumer.start()
|
||||
logger.info("rabbitmq_consumer_started")
|
||||
|
||||
await super().on_startup(app)
|
||||
|
||||
async def on_shutdown(self, app):
|
||||
"""Custom shutdown logic for Alert Processor"""
|
||||
global consumer
|
||||
|
||||
await super().on_shutdown(app)
|
||||
|
||||
# Stop RabbitMQ consumer
|
||||
if consumer:
|
||||
await consumer.stop()
|
||||
logger.info("rabbitmq_consumer_stopped")
|
||||
|
||||
# Close Redis
|
||||
await close_redis()
|
||||
logger.info("redis_closed")
|
||||
|
||||
|
||||
# Create service instance
|
||||
service = AlertProcessorService(
|
||||
service_name="alert-processor",
|
||||
app_name="Alert Processor Service",
|
||||
description="Event processing, enrichment, and alert management system",
|
||||
version=settings.VERSION,
|
||||
log_level=getattr(settings, 'LOG_LEVEL', 'INFO'),
|
||||
cors_origins=["*"], # Configure appropriately for production
|
||||
api_prefix="/api/v1",
|
||||
enable_metrics=True,
|
||||
enable_health_checks=True,
|
||||
enable_tracing=True,
|
||||
enable_cors=True
|
||||
)
|
||||
|
||||
# Create FastAPI app
|
||||
app = service.create_app(debug=settings.DEBUG)
|
||||
|
||||
# Add service-specific routers
|
||||
app.include_router(
|
||||
alerts.router,
|
||||
prefix="/api/v1/tenants/{tenant_id}",
|
||||
tags=["alerts"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
sse.router,
|
||||
prefix="/api/v1",
|
||||
tags=["sse"]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
reload=settings.DEBUG
|
||||
)
|
||||
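A hedged client-side sketch against the mounted prefixes above; the concrete "/alerts" path segment is an assumption, since the actual routes live in app/api/alerts.py:
# Sketch: calling the service through the mounted prefixes (path segment assumed).
import httpx

async def fetch_alerts(tenant_id: str):
    async with httpx.AsyncClient(base_url="http://localhost:8000/api/v1") as client:
        # "/tenants/{tenant_id}/..." comes from the alerts router prefix above;
        # the trailing "/alerts" segment is an assumption about alerts.router.
        resp = await client.get(f"/tenants/{tenant_id}/alerts")
        resp.raise_for_status()
        return resp.json()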
0
services/alert_processor/app/models/__init__.py
Normal file
84
services/alert_processor/app/models/events.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
SQLAlchemy models for events table.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Float, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class Event(Base):
|
||||
"""Unified event table for alerts, notifications, recommendations"""
|
||||
__tablename__ = "events"
|
||||
|
||||
# Core fields
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False
|
||||
)
|
||||
|
||||
# Classification
|
||||
event_class = Column(String(50), nullable=False)
|
||||
event_domain = Column(String(50), nullable=False, index=True)
|
||||
event_type = Column(String(100), nullable=False, index=True)
|
||||
service = Column(String(50), nullable=False)
|
||||
|
||||
# i18n content (NO hardcoded title/message)
|
||||
i18n_title_key = Column(String(200), nullable=False)
|
||||
i18n_title_params = Column(JSONB, nullable=False, default=dict)
|
||||
i18n_message_key = Column(String(200), nullable=False)
|
||||
i18n_message_params = Column(JSONB, nullable=False, default=dict)
|
||||
|
||||
# Priority
|
||||
priority_score = Column(Integer, nullable=False, default=50, index=True)
|
||||
priority_level = Column(String(20), nullable=False, index=True)
|
||||
type_class = Column(String(50), nullable=False, index=True)
|
||||
|
||||
# Enrichment contexts (JSONB)
|
||||
orchestrator_context = Column(JSONB, nullable=True)
|
||||
business_impact = Column(JSONB, nullable=True)
|
||||
urgency = Column(JSONB, nullable=True)
|
||||
user_agency = Column(JSONB, nullable=True)
|
||||
trend_context = Column(JSONB, nullable=True)
|
||||
|
||||
# Smart actions
|
||||
smart_actions = Column(JSONB, nullable=False, default=list)
|
||||
|
||||
# AI reasoning
|
||||
ai_reasoning_summary_key = Column(String(200), nullable=True)
|
||||
ai_reasoning_summary_params = Column(JSONB, nullable=True)
|
||||
ai_reasoning_details = Column(JSONB, nullable=True)
|
||||
confidence_score = Column(Float, nullable=True)
|
||||
|
||||
# Entity references
|
||||
entity_links = Column(JSONB, nullable=False, default=dict)
|
||||
|
||||
# Status
|
||||
status = Column(String(20), nullable=False, default="active", index=True)
|
||||
resolved_at = Column(DateTime(timezone=True), nullable=True)
|
||||
acknowledged_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Metadata
|
||||
event_metadata = Column(JSONB, nullable=False, default=dict)
|
||||
|
||||
# Indexes for dashboard queries
|
||||
__table_args__ = (
|
||||
Index('idx_events_tenant_status', 'tenant_id', 'status'),
|
||||
Index('idx_events_tenant_priority', 'tenant_id', 'priority_score'),
|
||||
Index('idx_events_tenant_class', 'tenant_id', 'event_class'),
|
||||
Index('idx_events_tenant_created', 'tenant_id', 'created_at'),
|
||||
Index('idx_events_type_class_status', 'type_class', 'status'),
|
||||
)
|
||||
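A minimal sketch of constructing an Event row directly from the columns above (all field values are illustrative):
# Sketch: building an Event row by hand (all values illustrative).
from uuid import uuid4
from app.models.events import Event

event = Event(
    tenant_id=uuid4(),
    event_class="alert",
    event_domain="inventory",
    event_type="critical_stock_shortage",
    service="inventory-service",
    i18n_title_key="alerts.critical_stock_shortage.title",
    i18n_title_params={"ingredient_name": "Flour T550"},
    i18n_message_key="alerts.critical_stock_shortage.message",
    i18n_message_params={"hours_until": 10},
    priority_score=92,
    priority_level="critical",
    type_class="action_needed",
    smart_actions=[],
    entity_links={"ingredient": "a1b2c3"},
    event_metadata={},
)
# No rendered title or message is stored; the frontend resolves the i18n keys.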
407
services/alert_processor/app/repositories/event_repository.py
Normal file
@@ -0,0 +1,407 @@
|
||||
"""
|
||||
Event repository for database operations.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_, desc
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
import structlog
|
||||
|
||||
from app.models.events import Event
|
||||
from app.schemas.events import EnrichedEvent, EventSummary, EventResponse, I18nContent, SmartAction
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class EventRepository:
|
||||
"""Repository for event database operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create_event(self, enriched_event: EnrichedEvent) -> Event:
|
||||
"""
|
||||
Store enriched event in database.
|
||||
|
||||
Args:
|
||||
enriched_event: Enriched event with all context
|
||||
|
||||
Returns:
|
||||
Stored Event model
|
||||
"""
|
||||
|
||||
# Convert enriched event to database model
|
||||
event = Event(
|
||||
id=enriched_event.id,
|
||||
tenant_id=UUID(enriched_event.tenant_id),
|
||||
event_class=enriched_event.event_class,
|
||||
event_domain=enriched_event.event_domain,
|
||||
event_type=enriched_event.event_type,
|
||||
service=enriched_event.service,
|
||||
|
||||
# i18n content
|
||||
i18n_title_key=enriched_event.i18n.title_key,
|
||||
i18n_title_params=enriched_event.i18n.title_params,
|
||||
i18n_message_key=enriched_event.i18n.message_key,
|
||||
i18n_message_params=enriched_event.i18n.message_params,
|
||||
|
||||
# Priority
|
||||
priority_score=enriched_event.priority_score,
|
||||
priority_level=enriched_event.priority_level,
|
||||
type_class=enriched_event.type_class,
|
||||
|
||||
# Enrichment contexts
|
||||
orchestrator_context=enriched_event.orchestrator_context.dict() if enriched_event.orchestrator_context else None,
|
||||
business_impact=enriched_event.business_impact.dict() if enriched_event.business_impact else None,
|
||||
urgency=enriched_event.urgency.dict() if enriched_event.urgency else None,
|
||||
user_agency=enriched_event.user_agency.dict() if enriched_event.user_agency else None,
|
||||
trend_context=enriched_event.trend_context,
|
||||
|
||||
# Smart actions
|
||||
smart_actions=[action.dict() for action in enriched_event.smart_actions],
|
||||
|
||||
# AI reasoning
|
||||
ai_reasoning_summary_key=enriched_event.ai_reasoning_summary_key,
|
||||
ai_reasoning_summary_params=enriched_event.ai_reasoning_summary_params,
|
||||
ai_reasoning_details=enriched_event.ai_reasoning_details,
|
||||
confidence_score=enriched_event.confidence_score,
|
||||
|
||||
# Entity links
|
||||
entity_links=enriched_event.entity_links,
|
||||
|
||||
# Status
|
||||
status=enriched_event.status,
|
||||
|
||||
# Metadata
|
||||
event_metadata=enriched_event.event_metadata
|
||||
)
|
||||
|
||||
self.session.add(event)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(event)
|
||||
|
||||
logger.info("event_stored", event_id=event.id, event_type=event.event_type)
|
||||
|
||||
return event
|
||||
|
||||
async def get_events(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
event_class: Optional[str] = None,
|
||||
priority_level: Optional[List[str]] = None,
|
||||
status: Optional[List[str]] = None,
|
||||
event_domain: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[Event]:
|
||||
"""
|
||||
Get filtered list of events.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
event_class: Filter by event class (alert, notification, recommendation)
|
||||
priority_level: Filter by priority levels
|
||||
status: Filter by status values
|
||||
event_domain: Filter by domain
|
||||
limit: Max results
|
||||
offset: Pagination offset
|
||||
|
||||
Returns:
|
||||
List of Event models
|
||||
"""
|
||||
|
||||
query = select(Event).where(Event.tenant_id == tenant_id)
|
||||
|
||||
# Apply filters
|
||||
if event_class:
|
||||
query = query.where(Event.event_class == event_class)
|
||||
|
||||
if priority_level:
|
||||
query = query.where(Event.priority_level.in_(priority_level))
|
||||
|
||||
if status:
|
||||
query = query.where(Event.status.in_(status))
|
||||
|
||||
if event_domain:
|
||||
query = query.where(Event.event_domain == event_domain)
|
||||
|
||||
# Order by priority and creation time
|
||||
query = query.order_by(
|
||||
desc(Event.priority_score),
|
||||
desc(Event.created_at)
|
||||
)
|
||||
|
||||
# Pagination
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
events = result.scalars().all()
|
||||
|
||||
return list(events)
|
||||
|
||||
async def get_event_by_id(self, event_id: UUID) -> Optional[Event]:
|
||||
"""Get single event by ID"""
|
||||
query = select(Event).where(Event.id == event_id)
|
||||
result = await self.session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def check_duplicate_alert(self, tenant_id: UUID, event_type: str, entity_links: Dict, event_metadata: Dict, time_window_hours: int = 24) -> Optional[Event]:
|
||||
"""
|
||||
Check if a similar alert already exists within the time window.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
event_type: Type of event (e.g., 'production_delay', 'critical_stock_shortage')
|
||||
entity_links: Entity references (e.g., batch_id, po_id, ingredient_id)
|
||||
event_metadata: Event metadata for comparison
|
||||
time_window_hours: Time window in hours to check for duplicates
|
||||
|
||||
Returns:
|
||||
Existing event if duplicate found, None otherwise
|
||||
"""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
# Calculate time threshold
|
||||
time_threshold = datetime.now(timezone.utc) - timedelta(hours=time_window_hours)
|
||||
|
||||
# Build query to find potential duplicates
|
||||
query = select(Event).where(
|
||||
and_(
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.event_type == event_type,
|
||||
Event.status == "active", # Only check active alerts
|
||||
Event.created_at >= time_threshold
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
potential_duplicates = result.scalars().all()
|
||||
|
||||
# Compare each potential duplicate for semantic similarity
|
||||
for event in potential_duplicates:
|
||||
# Check if entity links match (same batch, PO, ingredient, etc.)
|
||||
if self._entities_match(event.entity_links, entity_links):
|
||||
# For production delays, check if it's the same batch with similar delay
|
||||
if event_type == "production_delay":
|
||||
if self._production_delay_match(event.event_metadata, event_metadata):
|
||||
return event
|
||||
|
||||
# For critical stock shortages, check if it's the same ingredient
|
||||
elif event_type == "critical_stock_shortage":
|
||||
if self._stock_shortage_match(event.event_metadata, event_metadata):
|
||||
return event
|
||||
|
||||
# For delivery overdue alerts, check if it's the same PO
|
||||
elif event_type == "delivery_overdue":
|
||||
if self._delivery_overdue_match(event.event_metadata, event_metadata):
|
||||
return event
|
||||
|
||||
# For general matching based on metadata
|
||||
else:
|
||||
if self._metadata_match(event.event_metadata, event_metadata):
|
||||
return event
|
||||
|
||||
return None
|
||||
|
||||
def _entities_match(self, existing_links: Dict, new_links: Dict) -> bool:
|
||||
"""Check if entity links match between two events."""
|
||||
if not existing_links or not new_links:
|
||||
return False
|
||||
|
||||
# Check for common entity types
|
||||
common_entities = ['production_batch', 'purchase_order', 'ingredient', 'supplier', 'equipment']
|
||||
|
||||
for entity in common_entities:
|
||||
if entity in existing_links and entity in new_links:
|
||||
if existing_links[entity] == new_links[entity]:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _production_delay_match(self, existing_meta: Dict, new_meta: Dict) -> bool:
|
||||
"""Check if production delay alerts match."""
|
||||
# Same batch_id indicates same production issue
|
||||
return (existing_meta.get('batch_id') == new_meta.get('batch_id') and
|
||||
existing_meta.get('product_name') == new_meta.get('product_name'))
|
||||
|
||||
def _stock_shortage_match(self, existing_meta: Dict, new_meta: Dict) -> bool:
|
||||
"""Check if stock shortage alerts match."""
|
||||
# Same ingredient_id indicates same shortage issue
|
||||
return existing_meta.get('ingredient_id') == new_meta.get('ingredient_id')
|
||||
|
||||
def _delivery_overdue_match(self, existing_meta: Dict, new_meta: Dict) -> bool:
|
||||
"""Check if delivery overdue alerts match."""
|
||||
# Same PO indicates same delivery issue
|
||||
return existing_meta.get('po_id') == new_meta.get('po_id')
|
||||
|
||||
def _metadata_match(self, existing_meta: Dict, new_meta: Dict) -> bool:
|
||||
"""Generic metadata matching for other alert types."""
|
||||
# Check for common identifying fields
|
||||
common_fields = ['batch_id', 'po_id', 'ingredient_id', 'supplier_id', 'equipment_id']
|
||||
|
||||
for field in common_fields:
|
||||
if field in existing_meta and field in new_meta:
|
||||
if existing_meta[field] == new_meta[field]:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def get_summary(self, tenant_id: UUID) -> EventSummary:
|
||||
"""
|
||||
Get summary statistics for dashboard.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
|
||||
Returns:
|
||||
EventSummary with counts and statistics
|
||||
"""
|
||||
|
||||
# Count by status
|
||||
status_query = select(
|
||||
Event.status,
|
||||
func.count(Event.id).label('count')
|
||||
).where(
|
||||
Event.tenant_id == tenant_id
|
||||
).group_by(Event.status)
|
||||
|
||||
status_result = await self.session.execute(status_query)
|
||||
status_counts = {row.status: row.count for row in status_result}
|
||||
|
||||
# Count by priority
|
||||
priority_query = select(
|
||||
Event.priority_level,
|
||||
func.count(Event.id).label('count')
|
||||
).where(
|
||||
and_(
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.status == "active"
|
||||
)
|
||||
).group_by(Event.priority_level)
|
||||
|
||||
priority_result = await self.session.execute(priority_query)
|
||||
priority_counts = {row.priority_level: row.count for row in priority_result}
|
||||
|
||||
# Count by domain
|
||||
domain_query = select(
|
||||
Event.event_domain,
|
||||
func.count(Event.id).label('count')
|
||||
).where(
|
||||
and_(
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.status == "active"
|
||||
)
|
||||
).group_by(Event.event_domain)
|
||||
|
||||
domain_result = await self.session.execute(domain_query)
|
||||
domain_counts = {row.event_domain: row.count for row in domain_result}
|
||||
|
||||
# Count by type class
|
||||
type_class_query = select(
|
||||
Event.type_class,
|
||||
func.count(Event.id).label('count')
|
||||
).where(
|
||||
and_(
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.status == "active"
|
||||
)
|
||||
).group_by(Event.type_class)
|
||||
|
||||
type_class_result = await self.session.execute(type_class_query)
|
||||
type_class_counts = {row.type_class: row.count for row in type_class_result}
|
||||
|
||||
return EventSummary(
|
||||
total_active=status_counts.get("active", 0),
|
||||
total_acknowledged=status_counts.get("acknowledged", 0),
|
||||
total_resolved=status_counts.get("resolved", 0),
|
||||
by_priority=priority_counts,
|
||||
by_domain=domain_counts,
|
||||
by_type_class=type_class_counts,
|
||||
critical_alerts=priority_counts.get("critical", 0),
|
||||
important_alerts=priority_counts.get("important", 0)
|
||||
)
|
||||
|
||||
async def acknowledge_event(self, event_id: UUID) -> Event:
|
||||
"""Mark event as acknowledged"""
|
||||
event = await self.get_event_by_id(event_id)
|
||||
|
||||
if not event:
|
||||
raise ValueError(f"Event {event_id} not found")
|
||||
|
||||
event.status = "acknowledged"
|
||||
event.acknowledged_at = datetime.now(timezone.utc)
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(event)
|
||||
|
||||
logger.info("event_acknowledged", event_id=event_id)
|
||||
|
||||
return event
|
||||
|
||||
async def resolve_event(self, event_id: UUID) -> Event:
|
||||
"""Mark event as resolved"""
|
||||
event = await self.get_event_by_id(event_id)
|
||||
|
||||
if not event:
|
||||
raise ValueError(f"Event {event_id} not found")
|
||||
|
||||
event.status = "resolved"
|
||||
event.resolved_at = datetime.now(timezone.utc)
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(event)
|
||||
|
||||
logger.info("event_resolved", event_id=event_id)
|
||||
|
||||
return event
|
||||
|
||||
async def dismiss_event(self, event_id: UUID) -> Event:
|
||||
"""Mark event as dismissed"""
|
||||
event = await self.get_event_by_id(event_id)
|
||||
|
||||
if not event:
|
||||
raise ValueError(f"Event {event_id} not found")
|
||||
|
||||
event.status = "dismissed"
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(event)
|
||||
|
||||
logger.info("event_dismissed", event_id=event_id)
|
||||
|
||||
return event
|
||||
|
||||
def _event_to_response(self, event: Event) -> EventResponse:
|
||||
"""Convert Event model to EventResponse"""
|
||||
return EventResponse(
|
||||
id=event.id,
|
||||
tenant_id=event.tenant_id,
|
||||
created_at=event.created_at,
|
||||
event_class=event.event_class,
|
||||
event_domain=event.event_domain,
|
||||
event_type=event.event_type,
|
||||
i18n=I18nContent(
|
||||
title_key=event.i18n_title_key,
|
||||
title_params=event.i18n_title_params,
|
||||
message_key=event.i18n_message_key,
|
||||
message_params=event.i18n_message_params
|
||||
),
|
||||
priority_score=event.priority_score,
|
||||
priority_level=event.priority_level,
|
||||
type_class=event.type_class,
|
||||
smart_actions=[SmartAction(**action) for action in event.smart_actions],
|
||||
status=event.status,
|
||||
orchestrator_context=event.orchestrator_context,
|
||||
business_impact=event.business_impact,
|
||||
urgency=event.urgency,
|
||||
user_agency=event.user_agency,
|
||||
ai_reasoning_summary_key=event.ai_reasoning_summary_key,
|
||||
ai_reasoning_summary_params=event.ai_reasoning_summary_params,
|
||||
ai_reasoning_details=event.ai_reasoning_details,
|
||||
confidence_score=event.confidence_score,
|
||||
entity_links=event.entity_links,
|
||||
event_metadata=event.event_metadata
|
||||
)
|
||||
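A sketch of the dedup-then-store flow the repository supports; the session and enriched event are assumed to come from the caller:
# Sketch: dedup-then-store, with the session and enriched event supplied by the caller.
from uuid import UUID
from app.repositories.event_repository import EventRepository

async def store_if_new(session, enriched_event) -> bool:
    repo = EventRepository(session)
    duplicate = await repo.check_duplicate_alert(
        tenant_id=UUID(enriched_event.tenant_id),
        event_type=enriched_event.event_type,
        entity_links=enriched_event.entity_links,
        event_metadata=enriched_event.event_metadata,
    )
    if duplicate:
        return False   # an active, semantically similar alert already exists
    await repo.create_event(enriched_event)
    return True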
0
services/alert_processor/app/schemas/__init__.py
Normal file
180
services/alert_processor/app/schemas/events.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Pydantic schemas for enriched events.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List, Optional, Literal
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class I18nContent(BaseModel):
|
||||
"""i18n content structure"""
|
||||
title_key: str
|
||||
title_params: Dict[str, Any] = {}
|
||||
message_key: str
|
||||
message_params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class SmartAction(BaseModel):
|
||||
"""Smart action button"""
|
||||
action_type: str
|
||||
label_key: str
|
||||
label_params: Dict[str, Any] = {}
|
||||
variant: Literal["primary", "secondary", "danger", "ghost"]
|
||||
disabled: bool = False
|
||||
disabled_reason_key: Optional[str] = None
|
||||
consequence_key: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
metadata: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class BusinessImpact(BaseModel):
|
||||
"""Business impact context"""
|
||||
financial_impact_eur: float = 0
|
||||
affected_orders: int = 0
|
||||
affected_customers: List[str] = []
|
||||
production_delay_hours: float = 0
|
||||
estimated_revenue_loss_eur: float = 0
|
||||
customer_impact: Literal["low", "medium", "high"] = "low"
|
||||
waste_risk_kg: float = 0
|
||||
|
||||
|
||||
class Urgency(BaseModel):
|
||||
"""Urgency context"""
|
||||
hours_until_consequence: float = 24
|
||||
can_wait_until_tomorrow: bool = True
|
||||
deadline_utc: Optional[str] = None
|
||||
peak_hour_relevant: bool = False
|
||||
hours_pending: float = 0
|
||||
|
||||
|
||||
class UserAgency(BaseModel):
|
||||
"""User agency context"""
|
||||
can_user_fix: bool = True
|
||||
requires_external_party: bool = False
|
||||
external_party_name: Optional[str] = None
|
||||
external_party_contact: Optional[str] = None
|
||||
blockers: List[str] = []
|
||||
suggested_workaround: Optional[str] = None
|
||||
|
||||
|
||||
class OrchestratorContext(BaseModel):
|
||||
"""AI orchestrator context"""
|
||||
already_addressed: bool = False
|
||||
action_id: Optional[str] = None
|
||||
action_type: Optional[str] = None
|
||||
action_summary: Optional[str] = None
|
||||
reasoning: Optional[str] = None
|
||||
confidence: float = 0.8
|
||||
|
||||
|
||||
class EnrichedEvent(BaseModel):
|
||||
"""Complete enriched event with all context"""
|
||||
|
||||
# Core fields
|
||||
id: str
|
||||
tenant_id: str
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
# Classification
|
||||
event_class: Literal["alert", "notification", "recommendation"]
|
||||
event_domain: str
|
||||
event_type: str
|
||||
service: str
|
||||
|
||||
# i18n content
|
||||
i18n: I18nContent
|
||||
|
||||
# Priority
|
||||
priority_score: int = Field(ge=0, le=100)
|
||||
priority_level: Literal["critical", "important", "standard", "info"]
|
||||
type_class: str
|
||||
|
||||
# Enrichment contexts
|
||||
orchestrator_context: Optional[OrchestratorContext] = None
|
||||
business_impact: Optional[BusinessImpact] = None
|
||||
urgency: Optional[Urgency] = None
|
||||
user_agency: Optional[UserAgency] = None
|
||||
trend_context: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Smart actions
|
||||
smart_actions: List[SmartAction] = []
|
||||
|
||||
# AI reasoning
|
||||
ai_reasoning_summary_key: Optional[str] = None
|
||||
ai_reasoning_summary_params: Optional[Dict[str, Any]] = None
|
||||
ai_reasoning_details: Optional[Dict[str, Any]] = None
|
||||
confidence_score: Optional[float] = None
|
||||
|
||||
# Entity references
|
||||
entity_links: Dict[str, str] = {}
|
||||
|
||||
# Status
|
||||
status: Literal["active", "acknowledged", "resolved", "dismissed"] = "active"
|
||||
resolved_at: Optional[datetime] = None
|
||||
acknowledged_at: Optional[datetime] = None
|
||||
|
||||
# Original metadata
|
||||
event_metadata: Dict[str, Any] = {}
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class EventResponse(BaseModel):
|
||||
"""Event response for API"""
|
||||
id: UUID
|
||||
tenant_id: UUID
|
||||
created_at: datetime
|
||||
event_class: str
|
||||
event_domain: str
|
||||
event_type: str
|
||||
i18n: I18nContent
|
||||
priority_score: int
|
||||
priority_level: str
|
||||
type_class: str
|
||||
smart_actions: List[SmartAction]
|
||||
status: str
|
||||
|
||||
# Optional enrichment contexts (only if present)
|
||||
orchestrator_context: Optional[Dict[str, Any]] = None
|
||||
business_impact: Optional[Dict[str, Any]] = None
|
||||
urgency: Optional[Dict[str, Any]] = None
|
||||
user_agency: Optional[Dict[str, Any]] = None
|
||||
|
||||
# AI reasoning
|
||||
ai_reasoning_summary_key: Optional[str] = None
|
||||
ai_reasoning_summary_params: Optional[Dict[str, Any]] = None
|
||||
ai_reasoning_details: Optional[Dict[str, Any]] = None
|
||||
confidence_score: Optional[float] = None
|
||||
|
||||
entity_links: Dict[str, str] = {}
|
||||
event_metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class EventSummary(BaseModel):
|
||||
"""Summary statistics for dashboard"""
|
||||
total_active: int
|
||||
total_acknowledged: int
|
||||
total_resolved: int
|
||||
by_priority: Dict[str, int]
|
||||
by_domain: Dict[str, int]
|
||||
by_type_class: Dict[str, int]
|
||||
critical_alerts: int
|
||||
important_alerts: int
|
||||
|
||||
|
||||
class EventFilter(BaseModel):
|
||||
"""Filter criteria for event queries"""
|
||||
tenant_id: UUID
|
||||
event_class: Optional[str] = None
|
||||
priority_level: Optional[List[str]] = None
|
||||
status: Optional[List[str]] = None
|
||||
event_domain: Optional[str] = None
|
||||
limit: int = Field(default=50, le=100)
|
||||
offset: int = 0
|
||||
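For reference, a minimal EnrichedEvent that passes validation with only the required fields set (values illustrative):
# Sketch: the smallest EnrichedEvent that validates; everything else has defaults.
from uuid import uuid4
from app.schemas.events import EnrichedEvent, I18nContent

event = EnrichedEvent(
    id=str(uuid4()),
    tenant_id=str(uuid4()),
    event_class="alert",
    event_domain="inventory",
    event_type="critical_stock_shortage",
    service="inventory-service",
    i18n=I18nContent(
        title_key="alerts.critical_stock_shortage.title",
        message_key="alerts.critical_stock_shortage.message",
    ),
    priority_score=92,
    priority_level="critical",
    type_class="action_needed",
)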
0
services/alert_processor/app/services/__init__.py
Normal file
246
services/alert_processor/app/services/enrichment_orchestrator.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
Enrichment orchestrator service.
|
||||
|
||||
Coordinates the complete enrichment pipeline for events.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
import structlog
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.messaging import MinimalEvent
|
||||
from app.schemas.events import EnrichedEvent, I18nContent, BusinessImpact, Urgency, UserAgency, OrchestratorContext
|
||||
from app.enrichment.message_generator import MessageGenerator
|
||||
from app.enrichment.priority_scorer import PriorityScorer
|
||||
from app.enrichment.orchestrator_client import OrchestratorClient
|
||||
from app.enrichment.smart_actions import SmartActionGenerator
|
||||
from app.enrichment.business_impact import BusinessImpactAnalyzer
|
||||
from app.enrichment.urgency_analyzer import UrgencyAnalyzer
|
||||
from app.enrichment.user_agency import UserAgencyAnalyzer
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class EnrichmentOrchestrator:
|
||||
"""Coordinates the enrichment pipeline for events"""
|
||||
|
||||
def __init__(self):
|
||||
self.message_gen = MessageGenerator()
|
||||
self.priority_scorer = PriorityScorer()
|
||||
self.orchestrator_client = OrchestratorClient()
|
||||
self.action_gen = SmartActionGenerator()
|
||||
self.impact_analyzer = BusinessImpactAnalyzer()
|
||||
self.urgency_analyzer = UrgencyAnalyzer()
|
||||
self.agency_analyzer = UserAgencyAnalyzer()
|
||||
|
||||
async def enrich_event(self, event: MinimalEvent) -> EnrichedEvent:
|
||||
"""
|
||||
Run complete enrichment pipeline.
|
||||
|
||||
Steps:
|
||||
1. Generate i18n message keys and parameters
|
||||
2. Query orchestrator for AI context
|
||||
3. Analyze business impact
|
||||
4. Assess urgency
|
||||
5. Determine user agency
|
||||
6. Calculate priority score (0-100)
|
||||
7. Determine priority level
|
||||
8. Generate smart actions
|
||||
9. Determine type class
|
||||
10. Extract AI reasoning from metadata
11. Build enriched event
|
||||
|
||||
Args:
|
||||
event: Minimal event from service
|
||||
|
||||
Returns:
|
||||
Enriched event with all context
|
||||
"""
|
||||
|
||||
logger.info("enrichment_started", event_type=event.event_type, tenant_id=event.tenant_id)
|
||||
|
||||
# 1. Generate i18n message keys and parameters
|
||||
i18n_dict = self.message_gen.generate_message(event.event_type, event.metadata, event.event_class)
|
||||
i18n = I18nContent(**i18n_dict)
|
||||
|
||||
# 2. Query orchestrator for AI context (awaited before the other enrichment steps)
|
||||
orchestrator_context_dict = await self.orchestrator_client.get_context(
|
||||
tenant_id=event.tenant_id,
|
||||
event_type=event.event_type,
|
||||
metadata=event.metadata
|
||||
)
|
||||
|
||||
# Fallback: If orchestrator service didn't return context with already_addressed,
|
||||
# check if the event metadata contains orchestrator_context (e.g., from demo seeder)
|
||||
if not orchestrator_context_dict.get("already_addressed"):
|
||||
metadata_context = event.metadata.get("orchestrator_context")
|
||||
if metadata_context and isinstance(metadata_context, dict):
|
||||
# Merge metadata context into orchestrator context
|
||||
orchestrator_context_dict.update(metadata_context)
|
||||
logger.debug(
|
||||
"using_metadata_orchestrator_context",
|
||||
event_type=event.event_type,
|
||||
already_addressed=metadata_context.get("already_addressed")
|
||||
)
|
||||
|
||||
# Convert to OrchestratorContext if data exists
|
||||
orchestrator_context = None
|
||||
if orchestrator_context_dict:
|
||||
orchestrator_context = OrchestratorContext(**orchestrator_context_dict)
|
||||
|
||||
# 3. Analyze business impact
|
||||
business_impact_dict = self.impact_analyzer.analyze(
|
||||
event_type=event.event_type,
|
||||
metadata=event.metadata
|
||||
)
|
||||
business_impact = BusinessImpact(**business_impact_dict)
|
||||
|
||||
# 4. Assess urgency
|
||||
urgency_dict = self.urgency_analyzer.analyze(
|
||||
event_type=event.event_type,
|
||||
metadata=event.metadata
|
||||
)
|
||||
urgency = Urgency(**urgency_dict)
|
||||
|
||||
# 5. Determine user agency
|
||||
user_agency_dict = self.agency_analyzer.analyze(
|
||||
event_type=event.event_type,
|
||||
metadata=event.metadata,
|
||||
orchestrator_context=orchestrator_context_dict
|
||||
)
|
||||
user_agency = UserAgency(**user_agency_dict)
|
||||
|
||||
# 6. Calculate priority score (0-100)
|
||||
priority_score = self.priority_scorer.calculate_priority(
|
||||
business_impact=business_impact_dict,
|
||||
urgency=urgency_dict,
|
||||
user_agency=user_agency_dict,
|
||||
orchestrator_context=orchestrator_context_dict
|
||||
)
|
||||
|
||||
# 7. Determine priority level
|
||||
priority_level = self._get_priority_level(priority_score)
|
||||
|
||||
# 8. Generate smart actions
|
||||
smart_actions = self.action_gen.generate_actions(
|
||||
event_type=event.event_type,
|
||||
metadata=event.metadata,
|
||||
orchestrator_context=orchestrator_context_dict
|
||||
)
|
||||
|
||||
# 9. Determine type class
|
||||
type_class = self._determine_type_class(orchestrator_context_dict, event.metadata)
|
||||
|
||||
# 10. Extract AI reasoning from metadata (if present)
|
||||
reasoning_data = event.metadata.get('reasoning_data')
|
||||
ai_reasoning_details = None
|
||||
confidence_score = None
|
||||
|
||||
if reasoning_data:
|
||||
# Store the complete reasoning data structure
|
||||
ai_reasoning_details = reasoning_data
|
||||
|
||||
# Extract confidence if available
|
||||
if isinstance(reasoning_data, dict):
|
||||
metadata_section = reasoning_data.get('metadata', {})
|
||||
if isinstance(metadata_section, dict) and 'confidence' in metadata_section:
|
||||
confidence_score = metadata_section.get('confidence')
|
||||
|
||||
# 11. Build enriched event
|
||||
enriched = EnrichedEvent(
|
||||
id=str(uuid4()),
|
||||
tenant_id=event.tenant_id,
|
||||
event_class=event.event_class,
|
||||
event_domain=event.event_domain,
|
||||
event_type=event.event_type,
|
||||
service=event.service,
|
||||
i18n=i18n,
|
||||
priority_score=priority_score,
|
||||
priority_level=priority_level,
|
||||
type_class=type_class,
|
||||
orchestrator_context=orchestrator_context,
|
||||
business_impact=business_impact,
|
||||
urgency=urgency,
|
||||
user_agency=user_agency,
|
||||
smart_actions=smart_actions,
|
||||
ai_reasoning_details=ai_reasoning_details,
|
||||
confidence_score=confidence_score,
|
||||
entity_links=self._extract_entity_links(event.metadata),
|
||||
status="active",
|
||||
event_metadata=event.metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"enrichment_completed",
|
||||
event_type=event.event_type,
|
||||
priority_score=priority_score,
|
||||
priority_level=priority_level,
|
||||
type_class=type_class
|
||||
)
|
||||
|
||||
return enriched
|
||||
|
||||
def _get_priority_level(self, score: int) -> str:
|
||||
"""
|
||||
Convert numeric score to priority level.
|
||||
|
||||
- 90-100: critical
|
||||
- 70-89: important
|
||||
- 50-69: standard
|
||||
- 0-49: info
|
||||
"""
|
||||
if score >= 90:
|
||||
return "critical"
|
||||
elif score >= 70:
|
||||
return "important"
|
||||
elif score >= 50:
|
||||
return "standard"
|
||||
else:
|
||||
return "info"
|
||||
|
||||
def _determine_type_class(self, orchestrator_context: dict, metadata: dict = None) -> str:
|
||||
"""
|
||||
Determine type class based on orchestrator context or metadata override.
|
||||
|
||||
Priority order:
|
||||
1. Explicit type_class in metadata (e.g., from demo seeder)
|
||||
2. orchestrator_context.already_addressed = True -> "prevented_issue"
|
||||
3. Default: "action_needed"
|
||||
|
||||
- prevented_issue: AI already handled it
|
||||
- action_needed: User action required
|
||||
"""
|
||||
# Check for explicit type_class in metadata (allows demo seeder override)
|
||||
if metadata:
|
||||
explicit_type_class = metadata.get("type_class")
|
||||
if explicit_type_class in ("prevented_issue", "action_needed"):
|
||||
return explicit_type_class
|
||||
|
||||
# Determine from orchestrator context
|
||||
if orchestrator_context and orchestrator_context.get("already_addressed"):
|
||||
return "prevented_issue"
|
||||
return "action_needed"
|
||||
|
||||
def _extract_entity_links(self, metadata: dict) -> Dict[str, str]:
|
||||
"""
|
||||
Extract entity references from metadata.
|
||||
|
||||
Maps metadata keys to entity types for frontend deep linking.
|
||||
"""
|
||||
links = {}
|
||||
|
||||
# Map metadata keys to entity types
|
||||
entity_mappings = {
|
||||
"po_id": "purchase_order",
|
||||
"batch_id": "production_batch",
|
||||
"ingredient_id": "ingredient",
|
||||
"order_id": "order",
|
||||
"supplier_id": "supplier",
|
||||
"equipment_id": "equipment",
|
||||
"sensor_id": "sensor"
|
||||
}
|
||||
|
||||
for key, entity_type in entity_mappings.items():
|
||||
if key in metadata:
|
||||
links[entity_type] = str(metadata[key])
|
||||
|
||||
return links
|
||||
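A short sketch of driving the pipeline from a consumer callback; MinimalEvent is assumed to expose the attributes accessed above (tenant_id, event_type, event_class, event_domain, service, metadata):
# Sketch: driving the enrichment pipeline from a consumer callback.
from shared.messaging import MinimalEvent
from app.services.enrichment_orchestrator import EnrichmentOrchestrator

orchestrator = EnrichmentOrchestrator()

async def handle(event: MinimalEvent):
    enriched = await orchestrator.enrich_event(event)
    # priority_score, priority_level, smart_actions and contexts are now populated.
    return enriched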
129
services/alert_processor/app/services/sse_service.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Server-Sent Events (SSE) service using Redis pub/sub.
|
||||
"""
|
||||
|
||||
from typing import AsyncGenerator
|
||||
import json
|
||||
import structlog
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.events import Event
|
||||
from shared.redis_utils import get_redis_client
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SSEService:
|
||||
"""
|
||||
Manage real-time event streaming via Redis pub/sub.
|
||||
|
||||
Pattern: alerts:{tenant_id}
|
||||
"""
|
||||
|
||||
def __init__(self, redis: Redis = None):
|
||||
self._redis = redis # Use private attribute to allow lazy loading
|
||||
self.prefix = settings.REDIS_SSE_PREFIX
|
||||
|
||||
@property
|
||||
async def redis(self) -> Redis:
|
||||
"""
|
||||
Lazy load Redis client if not provided through dependency injection.
|
||||
Uses the shared Redis utilities for consistency.
|
||||
"""
|
||||
if self._redis is None:
|
||||
self._redis = await get_redis_client()
|
||||
return self._redis
|
||||
|
||||
async def publish_event(self, event: Event) -> bool:
|
||||
"""
|
||||
Publish event to Redis for SSE streaming.
|
||||
|
||||
Args:
|
||||
event: Event to publish
|
||||
|
||||
Returns:
|
||||
True if published successfully
|
||||
"""
|
||||
try:
|
||||
redis_client = await self.redis
|
||||
|
||||
# Build channel name
|
||||
channel = f"{self.prefix}:{event.tenant_id}"
|
||||
|
||||
# Build message payload
|
||||
payload = {
|
||||
"id": str(event.id),
|
||||
"tenant_id": str(event.tenant_id),
|
||||
"event_class": event.event_class,
|
||||
"event_domain": event.event_domain,
|
||||
"event_type": event.event_type,
|
||||
"priority_score": event.priority_score,
|
||||
"priority_level": event.priority_level,
|
||||
"type_class": event.type_class,
|
||||
"status": event.status,
|
||||
"created_at": event.created_at.isoformat(),
|
||||
"i18n": {
|
||||
"title_key": event.i18n_title_key,
|
||||
"title_params": event.i18n_title_params,
|
||||
"message_key": event.i18n_message_key,
|
||||
"message_params": event.i18n_message_params
|
||||
},
|
||||
"smart_actions": event.smart_actions,
|
||||
"entity_links": event.entity_links
|
||||
}
|
||||
|
||||
# Publish to Redis
|
||||
await redis_client.publish(channel, json.dumps(payload))
|
||||
|
||||
logger.debug(
|
||||
"sse_event_published",
|
||||
channel=channel,
|
||||
event_type=event.event_type,
|
||||
event_id=str(event.id)
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"sse_publish_failed",
|
||||
error=str(e),
|
||||
event_id=str(event.id)
|
||||
)
|
||||
return False
|
||||
|
||||
async def subscribe_to_tenant(
|
||||
self,
|
||||
tenant_id: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Subscribe to tenant's alert stream.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
|
||||
Yields:
|
||||
JSON-encoded event messages
|
||||
"""
|
||||
redis_client = await self.redis
|
||||
channel = f"{self.prefix}:{tenant_id}"
|
||||
|
||||
logger.info("sse_subscription_started", channel=channel)
|
||||
|
||||
# Subscribe to Redis channel
|
||||
pubsub = redis_client.pubsub()
|
||||
await pubsub.subscribe(channel)
|
||||
|
||||
try:
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
yield message["data"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("sse_subscription_error", error=str(e), channel=channel)
|
||||
raise
|
||||
finally:
|
||||
await pubsub.unsubscribe(channel)
|
||||
await pubsub.close()
|
||||
logger.info("sse_subscription_closed", channel=channel)
|
||||
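For context, a sketch of how subscribe_to_tenant could back a Server-Sent Events endpoint; the route path and wiring below are assumptions for illustration, and only SSEService and its methods come from the file above.

# Hypothetical FastAPI route streaming a tenant's alerts over SSE.
from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from app.services.sse_service import SSEService

router = APIRouter()


@router.get("/api/v1/alerts/stream/{tenant_id}")  # path is an assumption
async def stream_alerts(tenant_id: str) -> StreamingResponse:
    sse = SSEService()

    async def event_stream():
        # Each published payload is already JSON; wrap it in the SSE wire format.
        async for data in sse.subscribe_to_tenant(tenant_id):
            if isinstance(data, bytes):
                data = data.decode("utf-8")
            yield f"data: {data}\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")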
0
services/alert_processor/app/utils/__init__.py
Normal file
0
services/alert_processor/app/utils/__init__.py
Normal file
556
services/alert_processor/app/utils/message_templates.py
Normal file
556
services/alert_processor/app/utils/message_templates.py
Normal file
@@ -0,0 +1,556 @@
|
||||
"""
|
||||
Alert type definitions with i18n key mappings.
|
||||
|
||||
Each alert type maps to:
|
||||
- title_key: i18n key for title (e.g., "alerts.critical_stock_shortage.title")
|
||||
- title_params: parameter mappings from metadata to i18n params
|
||||
- message_variants: different message keys based on context
|
||||
- message_params: parameter mappings for message
|
||||
|
||||
When adding new alert types:
|
||||
1. Add entry to ALERT_TEMPLATES
|
||||
2. Ensure corresponding translations exist in frontend/src/locales/*/alerts.json
|
||||
3. Document required metadata fields
|
||||
"""
|
||||
|
||||
# Alert type templates
|
||||
ALERT_TEMPLATES = {
|
||||
# ==================== INVENTORY ALERTS ====================
|
||||
|
||||
"critical_stock_shortage": {
|
||||
"title_key": "alerts.critical_stock_shortage.title",
|
||||
"title_params": {
|
||||
"ingredient_name": "ingredient_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"with_po_pending": "alerts.critical_stock_shortage.message_with_po_pending",
|
||||
"with_po_created": "alerts.critical_stock_shortage.message_with_po_created",
|
||||
"with_hours": "alerts.critical_stock_shortage.message_with_hours",
|
||||
"with_date": "alerts.critical_stock_shortage.message_with_date",
|
||||
"generic": "alerts.critical_stock_shortage.message_generic"
|
||||
},
|
||||
"message_params": {
|
||||
"ingredient_name": "ingredient_name",
|
||||
"current_stock_kg": "current_stock",
|
||||
"required_stock_kg": "required_stock",
|
||||
"hours_until": "hours_until",
|
||||
"production_day_name": "production_date",
|
||||
"po_id": "po_id",
|
||||
"po_amount": "po_amount",
|
||||
"delivery_day_name": "delivery_date"
|
||||
}
|
||||
},
|
||||
|
||||
"low_stock_warning": {
|
||||
"title_key": "alerts.low_stock.title",
|
||||
"title_params": {
|
||||
"ingredient_name": "ingredient_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"with_po": "alerts.low_stock.message_with_po",
|
||||
"generic": "alerts.low_stock.message_generic"
|
||||
},
|
||||
"message_params": {
|
||||
"ingredient_name": "ingredient_name",
|
||||
"current_stock_kg": "current_stock",
|
||||
"minimum_stock_kg": "minimum_stock"
|
||||
}
|
||||
},
|
||||
|
||||
"overstock_warning": {
|
||||
"title_key": "alerts.overstock_warning.title",
|
||||
"title_params": {
|
||||
"ingredient_name": "ingredient_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "alerts.overstock_warning.message"
|
||||
},
|
||||
"message_params": {
|
||||
"ingredient_name": "ingredient_name",
|
||||
"current_stock_kg": "current_stock",
|
||||
"maximum_stock_kg": "maximum_stock",
|
||||
"excess_amount_kg": "excess_amount"
|
||||
}
|
||||
},
|
||||
|
||||
"expired_products": {
|
||||
"title_key": "alerts.expired_products.title",
|
||||
"title_params": {
|
||||
"count": "expired_count"
|
||||
},
|
||||
"message_variants": {
|
||||
"with_names": "alerts.expired_products.message_with_names",
|
||||
"generic": "alerts.expired_products.message_generic"
|
||||
},
|
||||
"message_params": {
|
||||
"expired_count": "expired_count",
|
||||
"product_names": "product_names",
|
||||
"total_value_eur": "total_value"
|
||||
}
|
||||
},
|
||||
|
||||
"urgent_expiry": {
|
||||
"title_key": "alerts.urgent_expiry.title",
|
||||
"title_params": {
|
||||
"ingredient_name": "ingredient_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "alerts.urgent_expiry.message"
|
||||
},
|
||||
"message_params": {
|
||||
"ingredient_name": "ingredient_name",
|
||||
"days_until_expiry": "days_until_expiry",
|
||||
"quantity_kg": "quantity"
|
||||
}
|
||||
},
|
||||
|
||||
"temperature_breach": {
|
||||
"title_key": "alerts.temperature_breach.title",
|
||||
"title_params": {
|
||||
"location": "location"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "alerts.temperature_breach.message"
|
||||
},
|
||||
"message_params": {
|
||||
"location": "location",
|
||||
"temperature": "temperature",
|
||||
"max_threshold": "max_threshold",
|
||||
"duration_minutes": "duration_minutes"
|
||||
}
|
||||
},
|
||||
|
||||
"stock_depleted_by_order": {
|
||||
"title_key": "alerts.stock_depleted_by_order.title",
|
||||
"title_params": {
|
||||
"ingredient_name": "ingredient_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"with_supplier": "alerts.stock_depleted_by_order.message_with_supplier",
|
||||
"generic": "alerts.stock_depleted_by_order.message_generic"
|
||||
},
|
||||
"message_params": {
|
||||
"ingredient_name": "ingredient_name",
|
||||
"shortage_kg": "shortage_amount",
|
||||
"supplier_name": "supplier_name",
|
||||
"supplier_contact": "supplier_contact"
|
||||
}
|
||||
},
|
||||
|
||||
# ==================== PRODUCTION ALERTS ====================
|
||||
|
||||
"production_delay": {
|
||||
"title_key": "alerts.production_delay.title",
|
||||
"title_params": {
|
||||
"product_name": "product_name",
|
||||
"batch_number": "batch_number"
|
||||
},
|
||||
"message_variants": {
|
||||
"with_customers": "alerts.production_delay.message_with_customers",
|
||||
"with_orders": "alerts.production_delay.message_with_orders",
|
||||
"generic": "alerts.production_delay.message_generic"
|
||||
},
|
||||
"message_params": {
|
||||
"product_name": "product_name",
|
||||
"batch_number": "batch_number",
|
||||
"delay_minutes": "delay_minutes",
|
||||
"affected_orders": "affected_orders",
|
||||
"customer_names": "customer_names"
|
||||
}
|
||||
},
|
||||
|
||||
"equipment_failure": {
|
||||
"title_key": "alerts.equipment_failure.title",
|
||||
"title_params": {
|
||||
"equipment_name": "equipment_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"with_batches": "alerts.equipment_failure.message_with_batches",
|
||||
"generic": "alerts.equipment_failure.message_generic"
|
||||
},
|
||||
"message_params": {
|
||||
"equipment_name": "equipment_name",
|
||||
"equipment_type": "equipment_type",
|
||||
"affected_batches": "affected_batches"
|
||||
}
|
||||
},
|
||||
|
||||
"maintenance_required": {
|
||||
"title_key": "alerts.maintenance_required.title",
|
||||
"title_params": {
|
||||
"equipment_name": "equipment_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"with_hours": "alerts.maintenance_required.message_with_hours",
|
||||
"with_days": "alerts.maintenance_required.message_with_days",
|
||||
"generic": "alerts.maintenance_required.message_generic"
|
||||
},
|
||||
"message_params": {
|
||||
"equipment_name": "equipment_name",
|
||||
"hours_overdue": "hours_overdue",
|
||||
"days_overdue": "days_overdue"
|
||||
}
|
||||
},
|
||||
|
||||
"low_equipment_efficiency": {
|
||||
"title_key": "alerts.low_equipment_efficiency.title",
|
||||
"title_params": {
|
||||
"equipment_name": "equipment_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "alerts.low_equipment_efficiency.message"
|
||||
},
|
||||
"message_params": {
|
||||
"equipment_name": "equipment_name",
|
||||
"efficiency_percentage": "efficiency_percentage",
|
||||
"target_efficiency": "target_efficiency"
|
||||
}
|
||||
},
|
||||
|
||||
"capacity_overload": {
|
||||
"title_key": "alerts.capacity_overload.title",
|
||||
"title_params": {
|
||||
"date": "planned_date"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "alerts.capacity_overload.message"
|
||||
},
|
||||
"message_params": {
|
||||
"planned_date": "planned_date",
|
||||
"capacity_percentage": "capacity_percentage",
|
||||
"equipment_count": "equipment_count"
|
||||
}
|
||||
},
|
||||
|
||||
"quality_control_failure": {
|
||||
"title_key": "alerts.quality_control_failure.title",
|
||||
"title_params": {
|
||||
"product_name": "product_name",
|
||||
"batch_number": "batch_number"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "alerts.quality_control_failure.message"
|
||||
},
|
||||
"message_params": {
|
||||
"product_name": "product_name",
|
||||
"batch_number": "batch_number",
|
||||
"check_type": "check_type",
|
||||
"quality_score": "quality_score",
|
||||
"defect_count": "defect_count"
|
||||
}
|
||||
},
|
||||
|
||||
# ==================== PROCUREMENT ALERTS ====================
|
||||
|
||||
"po_approval_needed": {
|
||||
"title_key": "alerts.po_approval_needed.title",
|
||||
"title_params": {
|
||||
"po_number": "po_number"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "alerts.po_approval_needed.message"
|
||||
},
|
||||
"message_params": {
|
||||
"supplier_name": "supplier_name",
|
||||
"total_amount": "total_amount",
|
||||
"currency": "currency",
|
||||
"required_delivery_date": "required_delivery_date",
|
||||
"items_count": "items_count"
|
||||
}
|
||||
},
|
||||
|
||||
"po_approval_escalation": {
|
||||
"title_key": "alerts.po_approval_escalation.title",
|
||||
"title_params": {
|
||||
"po_number": "po_number"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "alerts.po_approval_escalation.message"
|
||||
},
|
||||
"message_params": {
|
||||
"po_number": "po_number",
|
||||
"supplier_name": "supplier_name",
|
||||
"hours_pending": "hours_pending",
|
||||
"total_amount": "total_amount"
|
||||
}
|
||||
},
|
||||
|
||||
"delivery_overdue": {
|
||||
"title_key": "alerts.delivery_overdue.title",
|
||||
"title_params": {
|
||||
"po_number": "po_number"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "alerts.delivery_overdue.message"
|
||||
},
|
||||
"message_params": {
|
||||
"po_number": "po_number",
|
||||
"supplier_name": "supplier_name",
|
||||
"days_overdue": "days_overdue",
|
||||
"expected_date": "expected_date"
|
||||
}
|
||||
},
|
||||
|
||||
# ==================== SUPPLY CHAIN ALERTS ====================
|
||||
|
||||
"supplier_delay": {
|
||||
"title_key": "alerts.supplier_delay.title",
|
||||
"title_params": {
|
||||
"supplier_name": "supplier_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "alerts.supplier_delay.message"
|
||||
},
|
||||
"message_params": {
|
||||
"supplier_name": "supplier_name",
|
||||
"po_count": "po_count",
|
||||
"avg_delay_days": "avg_delay_days"
|
||||
}
|
||||
},
|
||||
|
||||
# ==================== DEMAND ALERTS ====================
|
||||
|
||||
"demand_surge_weekend": {
|
||||
"title_key": "alerts.demand_surge_weekend.title",
|
||||
"title_params": {},
|
||||
"message_variants": {
|
||||
"generic": "alerts.demand_surge_weekend.message"
|
||||
},
|
||||
"message_params": {
|
||||
"product_name": "product_name",
|
||||
"predicted_demand": "predicted_demand",
|
||||
"current_stock": "current_stock"
|
||||
}
|
||||
},
|
||||
|
||||
"weather_impact_alert": {
|
||||
"title_key": "alerts.weather_impact_alert.title",
|
||||
"title_params": {},
|
||||
"message_variants": {
|
||||
"generic": "alerts.weather_impact_alert.message"
|
||||
},
|
||||
"message_params": {
|
||||
"weather_condition": "weather_condition",
|
||||
"impact_percentage": "impact_percentage",
|
||||
"date": "date"
|
||||
}
|
||||
},
|
||||
|
||||
# ==================== PRODUCTION BATCH ALERTS ====================
|
||||
|
||||
"production_batch_start": {
|
||||
"title_key": "alerts.production_batch_start.title",
|
||||
"title_params": {
|
||||
"product_name": "product_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "alerts.production_batch_start.message"
|
||||
},
|
||||
"message_params": {
|
||||
"product_name": "product_name",
|
||||
"batch_number": "batch_number",
|
||||
"quantity_planned": "quantity_planned",
|
||||
"unit": "unit",
|
||||
"priority": "priority"
|
||||
}
|
||||
},
|
||||
|
||||
# ==================== GENERIC FALLBACK ====================
|
||||
|
||||
"generic": {
|
||||
"title_key": "alerts.generic.title",
|
||||
"title_params": {},
|
||||
"message_variants": {
|
||||
"generic": "alerts.generic.message"
|
||||
},
|
||||
"message_params": {
|
||||
"event_type": "event_type"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Notification templates (informational events)
|
||||
NOTIFICATION_TEMPLATES = {
|
||||
"po_approved": {
|
||||
"title_key": "notifications.po_approved.title",
|
||||
"title_params": {
|
||||
"po_number": "po_number"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "notifications.po_approved.message"
|
||||
},
|
||||
"message_params": {
|
||||
"supplier_name": "supplier_name",
|
||||
"total_amount": "total_amount"
|
||||
}
|
||||
},
|
||||
|
||||
"batch_state_changed": {
|
||||
"title_key": "notifications.batch_state_changed.title",
|
||||
"title_params": {
|
||||
"product_name": "product_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "notifications.batch_state_changed.message"
|
||||
},
|
||||
"message_params": {
|
||||
"batch_number": "batch_number",
|
||||
"new_status": "new_status",
|
||||
"quantity": "quantity",
|
||||
"unit": "unit"
|
||||
}
|
||||
},
|
||||
|
||||
"stock_received": {
|
||||
"title_key": "notifications.stock_received.title",
|
||||
"title_params": {
|
||||
"ingredient_name": "ingredient_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "notifications.stock_received.message"
|
||||
},
|
||||
"message_params": {
|
||||
"quantity_received": "quantity_received",
|
||||
"unit": "unit",
|
||||
"supplier_name": "supplier_name"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Recommendation templates (optimization suggestions)
|
||||
RECOMMENDATION_TEMPLATES = {
|
||||
"inventory_optimization": {
|
||||
"title_key": "recommendations.inventory_optimization.title",
|
||||
"title_params": {
|
||||
"ingredient_name": "ingredient_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "recommendations.inventory_optimization.message"
|
||||
},
|
||||
"message_params": {
|
||||
"ingredient_name": "ingredient_name",
|
||||
"current_max_kg": "current_max",
|
||||
"suggested_max_kg": "suggested_max",
|
||||
"recommendation_type": "recommendation_type"
|
||||
}
|
||||
},
|
||||
|
||||
"production_efficiency": {
|
||||
"title_key": "recommendations.production_efficiency.title",
|
||||
"title_params": {
|
||||
"product_name": "product_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "recommendations.production_efficiency.message"
|
||||
},
|
||||
"message_params": {
|
||||
"product_name": "product_name",
|
||||
"potential_time_saved_minutes": "time_saved",
|
||||
"suggestion": "suggestion"
|
||||
}
|
||||
},
|
||||
|
||||
# ==================== AI INSIGHTS RECOMMENDATIONS ====================
|
||||
|
||||
"ai_yield_prediction": {
|
||||
"title_key": "recommendations.ai_yield_prediction.title",
|
||||
"title_params": {
|
||||
"recipe_name": "recipe_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "recommendations.ai_yield_prediction.message"
|
||||
},
|
||||
"message_params": {
|
||||
"recipe_name": "recipe_name",
|
||||
"predicted_yield_percent": "predicted_yield",
|
||||
"confidence_percent": "confidence",
|
||||
"recommendation": "recommendation"
|
||||
}
|
||||
},
|
||||
|
||||
"ai_safety_stock_optimization": {
|
||||
"title_key": "recommendations.ai_safety_stock_optimization.title",
|
||||
"title_params": {
|
||||
"ingredient_name": "ingredient_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "recommendations.ai_safety_stock_optimization.message"
|
||||
},
|
||||
"message_params": {
|
||||
"ingredient_name": "ingredient_name",
|
||||
"suggested_safety_stock_kg": "suggested_safety_stock",
|
||||
"current_safety_stock_kg": "current_safety_stock",
|
||||
"estimated_savings_eur": "estimated_savings",
|
||||
"confidence_percent": "confidence"
|
||||
}
|
||||
},
|
||||
|
||||
"ai_supplier_recommendation": {
|
||||
"title_key": "recommendations.ai_supplier_recommendation.title",
|
||||
"title_params": {
|
||||
"supplier_name": "supplier_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "recommendations.ai_supplier_recommendation.message"
|
||||
},
|
||||
"message_params": {
|
||||
"supplier_name": "supplier_name",
|
||||
"reliability_score": "reliability_score",
|
||||
"recommendation": "recommendation",
|
||||
"confidence_percent": "confidence"
|
||||
}
|
||||
},
|
||||
|
||||
"ai_price_forecast": {
|
||||
"title_key": "recommendations.ai_price_forecast.title",
|
||||
"title_params": {
|
||||
"ingredient_name": "ingredient_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "recommendations.ai_price_forecast.message"
|
||||
},
|
||||
"message_params": {
|
||||
"ingredient_name": "ingredient_name",
|
||||
"predicted_price_eur": "predicted_price",
|
||||
"current_price_eur": "current_price",
|
||||
"price_trend": "price_trend",
|
||||
"recommendation": "recommendation",
|
||||
"confidence_percent": "confidence"
|
||||
}
|
||||
},
|
||||
|
||||
"ai_demand_forecast": {
|
||||
"title_key": "recommendations.ai_demand_forecast.title",
|
||||
"title_params": {
|
||||
"product_name": "product_name"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "recommendations.ai_demand_forecast.message"
|
||||
},
|
||||
"message_params": {
|
||||
"product_name": "product_name",
|
||||
"predicted_demand": "predicted_demand",
|
||||
"forecast_period": "forecast_period",
|
||||
"confidence_percent": "confidence",
|
||||
"recommendation": "recommendation"
|
||||
}
|
||||
},
|
||||
|
||||
"ai_business_rule": {
|
||||
"title_key": "recommendations.ai_business_rule.title",
|
||||
"title_params": {
|
||||
"rule_category": "rule_category"
|
||||
},
|
||||
"message_variants": {
|
||||
"generic": "recommendations.ai_business_rule.message"
|
||||
},
|
||||
"message_params": {
|
||||
"rule_category": "rule_category",
|
||||
"rule_description": "rule_description",
|
||||
"confidence_percent": "confidence",
|
||||
"recommendation": "recommendation"
|
||||
}
|
||||
}
|
||||
}
|
||||
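To make the key/parameter mapping concrete, a minimal resolver sketch built on top of ALERT_TEMPLATES. The resolve_template helper and its fallback-to-generic behaviour are assumptions for illustration, not functions defined in this module:

# Hypothetical resolver showing how a template maps metadata onto i18n keys/params.
from app.utils.message_templates import ALERT_TEMPLATES


def resolve_template(alert_type: str, variant: str, metadata: dict) -> dict:
    template = ALERT_TEMPLATES.get(alert_type, ALERT_TEMPLATES["generic"])
    message_key = template["message_variants"].get(
        variant, template["message_variants"]["generic"]
    )

    # Param mappings are {metadata_key: i18n_param_name}.
    def map_params(mapping: dict) -> dict:
        return {
            param: metadata[meta_key]
            for meta_key, param in mapping.items()
            if meta_key in metadata
        }

    return {
        "title_key": template["title_key"],
        "title_params": map_params(template["title_params"]),
        "message_key": message_key,
        "message_params": map_params(template["message_params"]),
    }


# Example: a low-stock event with a pending purchase order.
resolve_template(
    "low_stock_warning",
    "with_po",
    {"ingredient_name": "Flour T55", "current_stock_kg": 12.5, "minimum_stock_kg": 25.0},
)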
134
services/alert_processor/migrations/env.py
Normal file
134
services/alert_processor/migrations/env.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Alembic environment configuration for alert_processor service"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
from alembic import context
|
||||
|
||||
# Add the service directory to the Python path
|
||||
service_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if service_path not in sys.path:
|
||||
sys.path.insert(0, service_path)
|
||||
|
||||
# Add shared modules to path
|
||||
shared_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "shared"))
|
||||
if shared_path not in sys.path:
|
||||
sys.path.insert(0, shared_path)
|
||||
|
||||
try:
|
||||
from shared.database.base import Base
|
||||
|
||||
# Import all models to ensure they are registered with Base.metadata
|
||||
from app.models import * # noqa: F401, F403
|
||||
|
||||
except ImportError as e:
|
||||
print(f"Import error in migrations env.py: {e}")
|
||||
print(f"Current Python path: {sys.path}")
|
||||
raise
|
||||
|
||||
# this is the Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Determine service name from file path
|
||||
service_name = os.path.basename(os.path.dirname(os.path.dirname(__file__)))
|
||||
service_name_upper = service_name.upper().replace('-', '_')
|
||||
|
||||
# Set database URL from environment variables with multiple fallback strategies
|
||||
database_url = (
|
||||
os.getenv(f'{service_name_upper}_DATABASE_URL') or # Service-specific
|
||||
os.getenv('DATABASE_URL') # Generic fallback
|
||||
)
|
||||
|
||||
# If DATABASE_URL is not set, construct from individual components
|
||||
if not database_url:
|
||||
# Try generic PostgreSQL environment variables first
|
||||
postgres_host = os.getenv('POSTGRES_HOST')
|
||||
postgres_port = os.getenv('POSTGRES_PORT', '5432')
|
||||
postgres_db = os.getenv('POSTGRES_DB')
|
||||
postgres_user = os.getenv('POSTGRES_USER')
|
||||
postgres_password = os.getenv('POSTGRES_PASSWORD')
|
||||
|
||||
if all([postgres_host, postgres_db, postgres_user, postgres_password]):
|
||||
database_url = f"postgresql+asyncpg://{postgres_user}:{postgres_password}@{postgres_host}:{postgres_port}/{postgres_db}"
|
||||
else:
|
||||
# Try service-specific environment variables
|
||||
db_host = os.getenv(f'{service_name_upper}_DB_HOST', f'{service_name}-db-service')
|
||||
db_port = os.getenv(f'{service_name_upper}_DB_PORT', '5432')
|
||||
db_name = os.getenv(f'{service_name_upper}_DB_NAME', f'{service_name.replace("-", "_")}_db')
|
||||
db_user = os.getenv(f'{service_name_upper}_DB_USER', f'{service_name.replace("-", "_")}_user')
|
||||
db_password = os.getenv(f'{service_name_upper}_DB_PASSWORD')
|
||||
|
||||
if db_password:
|
||||
database_url = f"postgresql+asyncpg://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||
|
||||
if not database_url:
|
||||
error_msg = f"ERROR: No database URL configured for {service_name} service"
|
||||
print(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
config.set_main_option("sqlalchemy.url", database_url)
|
||||
|
||||
# Interpret the config file for Python logging
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Set target metadata
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
"""Execute migrations with the given connection."""
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""Run migrations in 'online' mode with async support."""
|
||||
connectable = async_engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
26
services/alert_processor/migrations/script.py.mako
Normal file
26
services/alert_processor/migrations/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Clean unified events table schema.
|
||||
|
||||
Revision ID: 20251205_unified
|
||||
Revises:
|
||||
Create Date: 2025-12-05
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers
|
||||
revision = '20251205_unified'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""
|
||||
Create unified events table with JSONB enrichment contexts.
|
||||
"""
|
||||
|
||||
# Create events table
|
||||
op.create_table(
|
||||
'events',
|
||||
|
||||
# Core fields
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column('tenant_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
|
||||
# Classification
|
||||
sa.Column('event_class', sa.String(50), nullable=False),
|
||||
sa.Column('event_domain', sa.String(50), nullable=False),
|
||||
sa.Column('event_type', sa.String(100), nullable=False),
|
||||
sa.Column('service', sa.String(50), nullable=False),
|
||||
|
||||
# i18n content (NO hardcoded title/message)
|
||||
sa.Column('i18n_title_key', sa.String(200), nullable=False),
|
||||
sa.Column('i18n_title_params', postgresql.JSONB, nullable=False, server_default=sa.text("'{}'::jsonb")),
|
||||
sa.Column('i18n_message_key', sa.String(200), nullable=False),
|
||||
sa.Column('i18n_message_params', postgresql.JSONB, nullable=False, server_default=sa.text("'{}'::jsonb")),
|
||||
|
||||
# Priority
|
||||
sa.Column('priority_score', sa.Integer, nullable=False, server_default='50'),
|
||||
sa.Column('priority_level', sa.String(20), nullable=False),
|
||||
sa.Column('type_class', sa.String(50), nullable=False),
|
||||
|
||||
# Enrichment contexts (JSONB)
|
||||
sa.Column('orchestrator_context', postgresql.JSONB, nullable=True),
|
||||
sa.Column('business_impact', postgresql.JSONB, nullable=True),
|
||||
sa.Column('urgency', postgresql.JSONB, nullable=True),
|
||||
sa.Column('user_agency', postgresql.JSONB, nullable=True),
|
||||
sa.Column('trend_context', postgresql.JSONB, nullable=True),
|
||||
|
||||
# Smart actions
|
||||
sa.Column('smart_actions', postgresql.JSONB, nullable=False, server_default=sa.text("'[]'::jsonb")),
|
||||
|
||||
# AI reasoning
|
||||
sa.Column('ai_reasoning_summary_key', sa.String(200), nullable=True),
|
||||
sa.Column('ai_reasoning_summary_params', postgresql.JSONB, nullable=True),
|
||||
sa.Column('ai_reasoning_details', postgresql.JSONB, nullable=True),
|
||||
sa.Column('confidence_score', sa.Float, nullable=True),
|
||||
|
||||
# Entity references
|
||||
sa.Column('entity_links', postgresql.JSONB, nullable=False, server_default=sa.text("'{}'::jsonb")),
|
||||
|
||||
# Status
|
||||
sa.Column('status', sa.String(20), nullable=False, server_default='active'),
|
||||
sa.Column('resolved_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('acknowledged_at', sa.DateTime(timezone=True), nullable=True),
|
||||
|
||||
# Metadata
|
||||
sa.Column('event_metadata', postgresql.JSONB, nullable=False, server_default=sa.text("'{}'::jsonb"))
|
||||
)
|
||||
|
||||
# Create indexes for efficient queries (matching SQLAlchemy model)
|
||||
op.create_index('idx_events_tenant_status', 'events', ['tenant_id', 'status'])
|
||||
op.create_index('idx_events_tenant_priority', 'events', ['tenant_id', 'priority_score'])
|
||||
op.create_index('idx_events_tenant_class', 'events', ['tenant_id', 'event_class'])
|
||||
op.create_index('idx_events_tenant_created', 'events', ['tenant_id', 'created_at'])
|
||||
op.create_index('idx_events_type_class_status', 'events', ['type_class', 'status'])
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""
|
||||
Drop events table and all indexes.
|
||||
"""
|
||||
op.drop_index('idx_events_type_class_status', 'events')
|
||||
op.drop_index('idx_events_tenant_created', 'events')
|
||||
op.drop_index('idx_events_tenant_class', 'events')
|
||||
op.drop_index('idx_events_tenant_priority', 'events')
|
||||
op.drop_index('idx_events_tenant_status', 'events')
|
||||
op.drop_table('events')
|
||||
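The composite indexes above exist to serve tenant-scoped dashboard queries. A hedged sketch of the kind of query idx_events_tenant_status and idx_events_tenant_priority support, assuming the Event model maps this table; the session handling is illustrative:

# Illustrative tenant-scoped query; Event is the SQLAlchemy model over the table above.
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.events import Event


async def active_events_for_tenant(db: AsyncSession, tenant_id) -> list[Event]:
    stmt = (
        select(Event)
        .where(Event.tenant_id == tenant_id, Event.status == "active")
        .order_by(Event.priority_score.desc())
        .limit(100)
    )
    result = await db.execute(stmt)
    return list(result.scalars().all())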
45
services/alert_processor/requirements.txt
Normal file
45
services/alert_processor/requirements.txt
Normal file
@@ -0,0 +1,45 @@
|
||||
# Alert Processor Service v2.0 Dependencies
|
||||
|
||||
# FastAPI and server
|
||||
fastapi==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
python-multipart==0.0.6
|
||||
|
||||
# Database
|
||||
sqlalchemy[asyncio]==2.0.23
|
||||
asyncpg==0.29.0
|
||||
alembic==1.12.1
|
||||
psycopg2-binary==2.9.9
|
||||
|
||||
# RabbitMQ
|
||||
aio-pika==9.3.0
|
||||
|
||||
# Redis
|
||||
redis[hiredis]==5.0.1
|
||||
|
||||
# HTTP client
|
||||
httpx==0.25.1
|
||||
|
||||
# Validation and settings
|
||||
pydantic==2.5.0
|
||||
pydantic-settings==2.1.0
|
||||
|
||||
# Structured logging
|
||||
structlog==23.2.0
|
||||
|
||||
# Utilities
|
||||
python-dateutil==2.8.2
|
||||
|
||||
# Authentication
|
||||
python-jose[cryptography]==3.3.0
|
||||
|
||||
# Monitoring and Observability
|
||||
psutil==5.9.8
|
||||
opentelemetry-api==1.39.1
|
||||
opentelemetry-sdk==1.39.1
|
||||
opentelemetry-instrumentation-fastapi==0.60b1
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.39.1
|
||||
opentelemetry-exporter-otlp-proto-http==1.39.1
|
||||
opentelemetry-instrumentation-httpx==0.60b1
|
||||
opentelemetry-instrumentation-redis==0.60b1
|
||||
opentelemetry-instrumentation-sqlalchemy==0.60b1
|
||||
64
services/auth/Dockerfile
Normal file
64
services/auth/Dockerfile
Normal file
@@ -0,0 +1,64 @@
|
||||
# =============================================================================
|
||||
# Auth Service Dockerfile - Environment-Configurable Base Images
|
||||
# =============================================================================
|
||||
# Build arguments for registry configuration:
|
||||
# - BASE_REGISTRY: Registry URL (default: docker.io for Docker Hub)
|
||||
# - PYTHON_IMAGE: Python image name and tag (default: python:3.11-slim)
|
||||
# =============================================================================
|
||||
|
||||
ARG BASE_REGISTRY=docker.io
|
||||
ARG PYTHON_IMAGE=python:3.11-slim
|
||||
|
||||
FROM ${BASE_REGISTRY}/${PYTHON_IMAGE} AS shared
|
||||
WORKDIR /shared
|
||||
COPY shared/ /shared/
|
||||
|
||||
ARG BASE_REGISTRY=docker.io
|
||||
ARG PYTHON_IMAGE=python:3.11-slim
|
||||
FROM ${BASE_REGISTRY}/${PYTHON_IMAGE}
|
||||
|
||||
# Create non-root user for security
|
||||
RUN groupadd -r appgroup && useradd -r -g appgroup appuser
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements
|
||||
COPY shared/requirements-tracing.txt /tmp/
|
||||
|
||||
COPY services/auth/requirements.txt .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r /tmp/requirements-tracing.txt
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy shared libraries from the shared stage
|
||||
COPY --from=shared /shared /app/shared
|
||||
|
||||
# Copy application code
|
||||
COPY services/auth/ .
|
||||
|
||||
# Change ownership to non-root user
|
||||
RUN chown -R appuser:appgroup /app
|
||||
|
||||
# Add shared libraries to Python path
|
||||
ENV PYTHONPATH="/app:/app/shared:${PYTHONPATH:-}"
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
# Run application
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
1074
services/auth/README.md
Normal file
1074
services/auth/README.md
Normal file
File diff suppressed because it is too large
84
services/auth/alembic.ini
Normal file
84
services/auth/alembic.ini
Normal file
@@ -0,0 +1,84 @@
|
||||
# ================================================================
|
||||
# services/auth/alembic.ini - Alembic Configuration
|
||||
# ================================================================
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = migrations
|
||||
|
||||
# template used to generate migration file names
|
||||
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
timezone = Europe/Madrid
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
sourceless = false
|
||||
|
||||
# version of a migration file's filename format
|
||||
version_num_format = %%s
|
||||
|
||||
# version path separator
|
||||
version_path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
output_encoding = utf-8
|
||||
|
||||
# Database URL - will be overridden by environment variable or settings
|
||||
sqlalchemy.url = postgresql+asyncpg://auth_user:password@auth-db-service:5432/auth_db
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts.
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
0
services/auth/app/__init__.py
Normal file
0
services/auth/app/__init__.py
Normal file
3
services/auth/app/api/__init__.py
Normal file
3
services/auth/app/api/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .internal_demo import router as internal_demo_router
|
||||
|
||||
__all__ = ["internal_demo_router"]
|
||||
214
services/auth/app/api/account_deletion.py
Normal file
214
services/auth/app/api/account_deletion.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
User self-service account deletion API for GDPR compliance
|
||||
Implements Article 17 (Right to erasure / "Right to be forgotten")
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from app.core.database import get_db
|
||||
from app.services.admin_delete import AdminUserDeleteService
|
||||
from app.models.users import User
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import httpx
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class AccountDeletionRequest(BaseModel):
|
||||
"""Request model for account deletion"""
|
||||
confirm_email: str = Field(..., description="User's email for confirmation")
|
||||
reason: str = Field(default="", description="Optional reason for deletion")
|
||||
password: str = Field(..., description="User's password for verification")
|
||||
|
||||
|
||||
class DeletionScheduleResponse(BaseModel):
|
||||
"""Response for scheduled deletion"""
|
||||
message: str
|
||||
user_id: str
|
||||
scheduled_deletion_date: str
|
||||
grace_period_days: int = 30
|
||||
|
||||
|
||||
@router.delete("/api/v1/auth/me/account")
|
||||
async def request_account_deletion(
|
||||
deletion_request: AccountDeletionRequest,
|
||||
request: Request,
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Request account deletion (self-service)
|
||||
|
||||
GDPR Article 17 - Right to erasure ("right to be forgotten")
|
||||
|
||||
Deletion is immediate and irreversible. Once the request is verified:
- Any active subscription is cancelled first
- The account and its personal data are permanently deleted
- Only anonymized audit and financial records are retained for their legal retention periods
|
||||
|
||||
Requires:
|
||||
- Email confirmation matching logged-in user
|
||||
- Current password verification
|
||||
"""
|
||||
try:
|
||||
user_id = UUID(current_user["user_id"])
|
||||
user_email = current_user.get("email")
|
||||
|
||||
if deletion_request.confirm_email.lower() != user_email.lower():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email confirmation does not match your account email"
|
||||
)
|
||||
|
||||
query = select(User).where(User.id == user_id)
|
||||
result = await db.execute(query)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
from app.core.security import SecurityManager
|
||||
if not SecurityManager.verify_password(deletion_request.password, user.hashed_password):
|
||||
logger.warning(
|
||||
"account_deletion_invalid_password",
|
||||
user_id=str(user_id),
|
||||
ip_address=request.client.host if request.client else None
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid password"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"account_deletion_requested",
|
||||
user_id=str(user_id),
|
||||
email=user_email,
|
||||
reason=deletion_request.reason[:100] if deletion_request.reason else None,
|
||||
ip_address=request.client.host if request.client else None
|
||||
)
|
||||
|
||||
tenant_id = current_user.get("tenant_id")
|
||||
if tenant_id:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
cancel_response = await client.get(
|
||||
f"http://tenant-service:8000/api/v1/tenants/{tenant_id}/subscription/status",
|
||||
headers={"Authorization": request.headers.get("Authorization")}
|
||||
)
|
||||
|
||||
if cancel_response.status_code == 200:
|
||||
subscription_data = cancel_response.json()
|
||||
if subscription_data.get("status") in ["active", "pending_cancellation"]:
|
||||
cancel_sub_response = await client.delete(
|
||||
f"http://tenant-service:8000/api/v1/tenants/{tenant_id}/subscription",
|
||||
headers={"Authorization": request.headers.get("Authorization")}
|
||||
)
|
||||
logger.info(
|
||||
"subscription_cancelled_before_deletion",
|
||||
user_id=str(user_id),
|
||||
tenant_id=tenant_id,
|
||||
subscription_status=subscription_data.get("status")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"subscription_cancellation_failed_during_account_deletion",
|
||||
user_id=str(user_id),
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
deletion_service = AdminUserDeleteService(db)
|
||||
result = await deletion_service.delete_admin_user_complete(
|
||||
user_id=str(user_id),
|
||||
requesting_user_id=str(user_id)
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Account deleted successfully",
|
||||
"user_id": str(user_id),
|
||||
"deletion_date": datetime.now(timezone.utc).isoformat(),
|
||||
"data_retained": "Audit logs will be anonymized after legal retention period (1 year)",
|
||||
"gdpr_article": "Article 17 - Right to erasure"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"account_deletion_failed",
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to process account deletion request"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/v1/auth/me/account/deletion-info")
|
||||
async def get_deletion_info(
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get information about what will be deleted
|
||||
|
||||
Shows the user exactly what data will be deleted when they request
|
||||
account deletion. Transparency requirement under GDPR.
|
||||
"""
|
||||
try:
|
||||
user_id = UUID(current_user["user_id"])
|
||||
|
||||
deletion_service = AdminUserDeleteService(db)
|
||||
preview = await deletion_service.preview_user_deletion(str(user_id))
|
||||
|
||||
return {
|
||||
"user_info": preview.get("user"),
|
||||
"what_will_be_deleted": {
|
||||
"account_data": "Your account, email, name, and profile information",
|
||||
"sessions": "All active sessions and refresh tokens",
|
||||
"consents": "Your consent history and preferences",
|
||||
"security_data": "Login history and security logs",
|
||||
"tenant_data": preview.get("tenant_associations"),
|
||||
"estimated_records": preview.get("estimated_deletions")
|
||||
},
|
||||
"what_will_be_retained": {
|
||||
"audit_logs": "Anonymized for 1 year (legal requirement)",
|
||||
"financial_records": "Anonymized for 7 years (tax law)",
|
||||
"anonymized_analytics": "Aggregated data without personal identifiers"
|
||||
},
|
||||
"process": {
|
||||
"immediate_deletion": True,
|
||||
"grace_period": "No grace period - deletion is immediate",
|
||||
"reversible": False,
|
||||
"completion_time": "Immediate"
|
||||
},
|
||||
"gdpr_rights": {
|
||||
"article_17": "Right to erasure (right to be forgotten)",
|
||||
"article_5_1_e": "Storage limitation principle",
|
||||
"exceptions": "Data required for legal obligations will be retained in anonymized form"
|
||||
},
|
||||
"warning": "⚠️ This action is irreversible. All your data will be permanently deleted."
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"deletion_info_failed",
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve deletion information"
|
||||
)
|
||||
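A sketch of how a client could exercise the two endpoints above, previewing the deletion before confirming it; the base URL and token are placeholders, while the paths and payload fields come from the handlers themselves.

# Illustrative client flow for the GDPR self-service deletion endpoints.
import asyncio

import httpx

BASE_URL = "http://localhost:8000"   # placeholder gateway/service URL
TOKEN = "<access-token>"             # placeholder JWT for the logged-in user


async def delete_my_account(email: str, password: str) -> None:
    headers = {"Authorization": f"Bearer {TOKEN}"}
    async with httpx.AsyncClient(base_url=BASE_URL, headers=headers) as client:
        # 1. Transparency step: show what will be deleted and what is retained.
        info = await client.get("/api/v1/auth/me/account/deletion-info")
        print(info.json()["warning"])

        # 2. Confirm with email + password; deletion is immediate and irreversible.
        resp = await client.request(
            "DELETE",
            "/api/v1/auth/me/account",
            json={"confirm_email": email, "password": password, "reason": "closing my bakery"},
        )
        resp.raise_for_status()
        print(resp.json()["message"])


if __name__ == "__main__":
    asyncio.run(delete_my_account("owner@example.com", "a-strong-password"))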
657
services/auth/app/api/auth_operations.py
Normal file
657
services/auth/app/api/auth_operations.py
Normal file
@@ -0,0 +1,657 @@
|
||||
"""
|
||||
Refactored Auth Operations with proper 3DS/3DS2 support
|
||||
Implements SetupIntent-first architecture for secure registration flows
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from app.services.auth_service import auth_service, AuthService
|
||||
from app.schemas.auth import UserRegistration, UserLogin, UserResponse
|
||||
from app.models.users import User
|
||||
from shared.exceptions.auth_exceptions import (
|
||||
UserCreationError,
|
||||
RegistrationError,
|
||||
PaymentOrchestrationError
|
||||
)
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create router
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
|
||||
async def get_auth_service() -> AuthService:
|
||||
"""Dependency injection for auth service"""
|
||||
return auth_service
|
||||
|
||||
|
||||
@router.post("/start-registration",
|
||||
response_model=Dict[str, Any],
|
||||
summary="Start secure registration with payment verification")
|
||||
async def start_registration(
|
||||
user_data: UserRegistration,
|
||||
request: Request,
|
||||
auth_service: AuthService = Depends(get_auth_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Start secure registration flow with SetupIntent-first approach
|
||||
|
||||
This is the FIRST step in the atomic registration architecture:
|
||||
1. Creates Stripe customer via tenant service
|
||||
2. Creates SetupIntent with confirm=True
|
||||
3. Returns SetupIntent data to frontend
|
||||
|
||||
IMPORTANT: NO subscription or user is created in this step!
|
||||
|
||||
Two possible outcomes:
|
||||
- requires_action=True: 3DS required, frontend must confirm SetupIntent then call complete-registration
|
||||
- requires_action=False: No 3DS required, but frontend STILL must call complete-registration
|
||||
|
||||
In BOTH cases, the frontend must call complete-registration to create the subscription and user.
|
||||
This ensures consistent flow and prevents duplicate subscriptions.
|
||||
|
||||
Args:
|
||||
user_data: User registration data with payment info
|
||||
|
||||
Returns:
|
||||
SetupIntent result with:
|
||||
- requires_action: True if 3DS required, False if not
|
||||
- setup_intent_id: SetupIntent ID for verification
|
||||
- client_secret: For 3DS authentication (when requires_action=True)
|
||||
- customer_id: Stripe customer ID
|
||||
- Other SetupIntent metadata
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 for validation errors, 500 for server errors
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting secure registration flow, email={user_data.email}, plan={user_data.subscription_plan}")
|
||||
|
||||
# Validate required fields
|
||||
if not user_data.email or not user_data.email.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email is required"
|
||||
)
|
||||
|
||||
if not user_data.password or len(user_data.password) < 8:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Password must be at least 8 characters long"
|
||||
)
|
||||
|
||||
if not user_data.full_name or not user_data.full_name.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Full name is required"
|
||||
)
|
||||
|
||||
if user_data.subscription_plan and not user_data.payment_method_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Payment method ID is required for subscription registration"
|
||||
)
|
||||
|
||||
# Start secure registration flow
|
||||
result = await auth_service.start_secure_registration_flow(user_data)
|
||||
|
||||
# Check if 3DS is required
|
||||
if result.get('requires_action', False):
|
||||
logger.info(f"Registration requires 3DS verification, email={user_data.email}, setup_intent_id={result.get('setup_intent_id')}")
|
||||
|
||||
return {
|
||||
"requires_action": True,
|
||||
"action_type": "setup_intent_confirmation",
|
||||
"client_secret": result.get('client_secret'),
|
||||
"setup_intent_id": result.get('setup_intent_id'),
|
||||
"customer_id": result.get('customer_id'),
|
||||
"payment_customer_id": result.get('payment_customer_id'),
|
||||
"plan_id": result.get('plan_id'),
|
||||
"payment_method_id": result.get('payment_method_id'),
|
||||
"billing_cycle": result.get('billing_cycle'),
|
||||
"coupon_info": result.get('coupon_info'),
|
||||
"trial_info": result.get('trial_info'),
|
||||
"email": result.get('email'),
|
||||
"message": "Payment verification required. Frontend must confirm SetupIntent to handle 3DS."
|
||||
}
|
||||
else:
|
||||
user = result.get('user')
|
||||
user_id = user.id if user else None
|
||||
logger.info(f"Registration completed without 3DS, email={user_data.email}, user_id={user_id}, subscription_id={result.get('subscription_id')}")
|
||||
|
||||
# Return complete registration result
|
||||
user_data_response = None
|
||||
if user:
|
||||
user_data_response = {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"is_active": user.is_active
|
||||
}
|
||||
|
||||
return {
|
||||
"requires_action": False,
|
||||
"setup_intent_id": result.get('setup_intent_id'),
|
||||
"user": user_data_response,
|
||||
"subscription_id": result.get('subscription_id'),
|
||||
"payment_customer_id": result.get('payment_customer_id'),
|
||||
"status": result.get('status'),
|
||||
"message": "Registration completed successfully"
|
||||
}
|
||||
|
||||
except HTTPException:
    # Re-raise request validation errors (the 400s above) instead of converting them to 500s
    raise
except RegistrationError as e:
|
||||
logger.error(f"Registration flow failed: {str(e)}, email: {user_data.email}",
|
||||
exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Registration failed: {str(e)}"
|
||||
) from e
|
||||
except PaymentOrchestrationError as e:
|
||||
logger.error(f"Payment orchestration failed: {str(e)}, email: {user_data.email}",
|
||||
exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Payment setup failed: {str(e)}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected registration error: {str(e)}, email: {user_data.email}",
|
||||
exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Registration error: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/complete-registration",
|
||||
response_model=Dict[str, Any],
|
||||
summary="Complete registration after SetupIntent verification")
|
||||
async def complete_registration(
|
||||
verification_data: Dict[str, Any],
|
||||
request: Request,
|
||||
auth_service: AuthService = Depends(get_auth_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Complete registration after frontend confirms SetupIntent
|
||||
|
||||
This is the SECOND step in the atomic registration architecture:
|
||||
1. Called after frontend confirms SetupIntent (with or without 3DS)
|
||||
2. Verifies SetupIntent status with Stripe
|
||||
3. Creates subscription with verified payment method (FIRST time subscription is created)
|
||||
4. Creates user record in auth database
|
||||
5. Saves onboarding progress
|
||||
6. Generates auth tokens for auto-login
|
||||
|
||||
This endpoint is called in TWO scenarios:
|
||||
1. After user completes 3DS authentication (requires_action=True flow)
|
||||
2. Immediately after start-registration (requires_action=False flow)
|
||||
|
||||
In BOTH cases, this is where the subscription and user are actually created.
|
||||
This ensures consistent flow and prevents duplicate subscriptions.
|
||||
|
||||
Args:
|
||||
verification_data: Must contain:
|
||||
- setup_intent_id: Verified SetupIntent ID
|
||||
- user_data: Original user registration data
|
||||
|
||||
Returns:
|
||||
Complete registration result with:
|
||||
- user: Created user data
|
||||
- subscription_id: Created subscription ID
|
||||
- payment_customer_id: Stripe customer ID
|
||||
- access_token: JWT access token
|
||||
- refresh_token: JWT refresh token
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if setup_intent_id is missing, 500 for server errors
|
||||
"""
|
||||
try:
|
||||
# Validate required fields
|
||||
if not verification_data.get('setup_intent_id'):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="SetupIntent ID is required"
|
||||
)
|
||||
|
||||
if not verification_data.get('user_data'):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User data is required"
|
||||
)
|
||||
|
||||
# Extract user data
|
||||
user_data_dict = verification_data['user_data']
|
||||
user_data = UserRegistration(**user_data_dict)
|
||||
|
||||
logger.info(f"Completing registration after SetupIntent verification, email={user_data.email}, setup_intent_id={verification_data['setup_intent_id']}")
|
||||
|
||||
# Complete registration with verified payment
|
||||
result = await auth_service.complete_registration_with_verified_payment(
|
||||
verification_data['setup_intent_id'],
|
||||
user_data
|
||||
)
|
||||
|
||||
logger.info(f"Registration completed successfully after 3DS, user_id={result['user'].id}, email={result['user'].email}, subscription_id={result.get('subscription_id')}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"user": {
|
||||
"id": str(result['user'].id),
|
||||
"email": result['user'].email,
|
||||
"full_name": result['user'].full_name,
|
||||
"is_active": result['user'].is_active,
|
||||
"is_verified": result['user'].is_verified,
|
||||
"created_at": result['user'].created_at.isoformat() if result['user'].created_at else None,
|
||||
"role": result['user'].role
|
||||
},
|
||||
"subscription_id": result.get('subscription_id'),
|
||||
"payment_customer_id": result.get('payment_customer_id'),
|
||||
"status": result.get('status'),
|
||||
"access_token": result.get('access_token'),
|
||||
"refresh_token": result.get('refresh_token'),
|
||||
"message": "Registration completed successfully after 3DS verification"
|
||||
}
|
||||
|
||||
except HTTPException:
    # Re-raise request validation errors (the 400s above) instead of converting them to 500s
    raise
except RegistrationError as e:
|
||||
logger.error(f"Registration completion after 3DS failed: {str(e)}, setup_intent_id: {verification_data.get('setup_intent_id')}, email: {user_data_dict.get('email') if user_data_dict else 'unknown'}",
|
||||
exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Registration completion failed: {str(e)}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected registration completion error: {str(e)}, setup_intent_id: {verification_data.get('setup_intent_id')}",
|
||||
exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Registration completion error: {str(e)}"
|
||||
) from e
|
||||
|
||||
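To make the two-step contract concrete, a hedged client-side sketch of the whole flow: call start-registration, let the frontend confirm the SetupIntent with Stripe.js when 3DS is required (shown only as a comment), then always call complete-registration. The base URL is a placeholder; field names are taken from the handlers and schemas referenced above.

# Illustrative two-step registration client (not part of the service code).
import httpx

BASE_URL = "http://localhost:8000/api/v1/auth"   # placeholder

registration_payload = {
    "email": "owner@example.com",
    "password": "correct-horse-battery",
    "full_name": "Jane Baker",
    "subscription_plan": "pro",        # plan id is a placeholder
    "payment_method_id": "pm_123",     # placeholder Stripe payment method
}

with httpx.Client(base_url=BASE_URL, timeout=30.0) as client:
    start = client.post("/start-registration", json=registration_payload).json()

    if start["requires_action"]:
        # The frontend would confirm the SetupIntent with Stripe.js here using
        # start["client_secret"]; the 3DS challenge happens in the browser.
        pass

    # In BOTH flows the subscription and user are only created by this call.
    completed = client.post(
        "/complete-registration",
        json={
            "setup_intent_id": start["setup_intent_id"],
            "user_data": registration_payload,
        },
    ).json()

    access_token = completed["access_token"]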
|
||||
@router.post("/login",
|
||||
response_model=Dict[str, Any],
|
||||
summary="User login with subscription validation")
|
||||
async def login(
|
||||
login_data: UserLogin,
|
||||
request: Request,
|
||||
auth_service: AuthService = Depends(get_auth_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
User login endpoint with subscription validation
|
||||
|
||||
This endpoint:
|
||||
1. Validates user credentials
|
||||
2. Checks if user has active subscription (if required)
|
||||
3. Returns authentication tokens
|
||||
4. Updates last login timestamp
|
||||
|
||||
Args:
|
||||
login_data: User login credentials (email and password)
|
||||
|
||||
Returns:
|
||||
Authentication tokens and user information
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 for invalid credentials, 403 for subscription issues
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Login attempt, email={login_data.email}")
|
||||
|
||||
# Validate required fields
|
||||
if not login_data.email or not login_data.email.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email is required"
|
||||
)
|
||||
|
||||
if not login_data.password or len(login_data.password) < 8:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Password must be at least 8 characters long"
|
||||
)
|
||||
|
||||
# Call auth service to perform login
|
||||
result = await auth_service.login_user(login_data)
|
||||
|
||||
logger.info(f"Login successful, email={login_data.email}, user_id={result['user'].id}")
|
||||
|
||||
# Extract tokens from result for top-level response
|
||||
tokens = result.get('tokens', {})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"access_token": tokens.get('access_token'),
|
||||
"refresh_token": tokens.get('refresh_token'),
|
||||
"token_type": tokens.get('token_type'),
|
||||
"expires_in": tokens.get('expires_in'),
|
||||
"user": {
|
||||
"id": str(result['user'].id),
|
||||
"email": result['user'].email,
|
||||
"full_name": result['user'].full_name,
|
||||
"is_active": result['user'].is_active,
|
||||
"last_login": result['user'].last_login.isoformat() if result['user'].last_login else None
|
||||
},
|
||||
"subscription": result.get('subscription', {}),
|
||||
"message": "Login successful"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions (like 401 for invalid credentials)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Login failed: {str(e)}, email: {login_data.email}",
|
||||
exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Login failed: {str(e)}"
|
||||
) from e
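A minimal client call against this login route, for reference; the base URL prefix is an assumption and the credentials are placeholders.

import httpx

BASE = "http://localhost:8000/api/v1/auth"  # assumed prefix, matching the other auth routes

resp = httpx.post(f"{BASE}/login", json={"email": "owner@example.com", "password": "a-strong-password"})
if resp.status_code == 200:
    body = resp.json()
    access_token, refresh_token = body["access_token"], body["refresh_token"]
    print("logged in as", body["user"]["email"], "- token expires in", body["expires_in"], "seconds")
elif resp.status_code in (401, 403):
    # 401: invalid credentials, 403: subscription problems (see docstring above)
    print("login rejected:", resp.json().get("detail"))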
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TOKEN MANAGEMENT ENDPOINTS - NEWLY ADDED
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/refresh",
|
||||
response_model=Dict[str, Any],
|
||||
summary="Refresh access token using refresh token")
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
refresh_data: Dict[str, Any],
|
||||
auth_service: AuthService = Depends(get_auth_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Refresh access token using a valid refresh token
|
||||
|
||||
This endpoint:
|
||||
1. Validates the refresh token
|
||||
2. Generates new access and refresh tokens
|
||||
3. Returns the new tokens
|
||||
|
||||
Args:
|
||||
refresh_data: Dictionary containing refresh_token
|
||||
|
||||
Returns:
|
||||
New authentication tokens
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 for invalid refresh tokens
|
||||
"""
|
||||
try:
|
||||
logger.info("Token refresh request initiated")
|
||||
|
||||
# Extract refresh token from request
|
||||
refresh_token = refresh_data.get("refresh_token")
|
||||
if not refresh_token:
|
||||
logger.warning("Refresh token missing from request")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Refresh token is required"
|
||||
)
|
||||
|
||||
# Use service layer to refresh tokens
|
||||
tokens = await auth_service.refresh_auth_tokens(refresh_token)
|
||||
|
||||
logger.info("Token refresh successful via service layer")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"access_token": tokens.get("access_token"),
|
||||
"refresh_token": tokens.get("refresh_token"),
|
||||
"token_type": "bearer",
|
||||
"expires_in": 1800, # 30 minutes
|
||||
"message": "Token refresh successful"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Token refresh failed: {str(e)}"
|
||||
) from e
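A short sketch of the refresh round-trip; it assumes the same base URL as above and a refresh token previously obtained from /login.

import httpx

BASE = "http://localhost:8000/api/v1/auth"  # assumed prefix

def refresh_tokens(refresh_token: str) -> dict:
    """Exchange a refresh token for a new access/refresh pair."""
    resp = httpx.post(f"{BASE}/refresh", json={"refresh_token": refresh_token})
    resp.raise_for_status()
    return resp.json()

# tokens = refresh_tokens(previously_saved_refresh_token)
# new_access, new_refresh = tokens["access_token"], tokens["refresh_token"]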
|
||||
|
||||
|
||||
@router.post("/verify",
|
||||
response_model=Dict[str, Any],
|
||||
summary="Verify token validity")
|
||||
async def verify_token(
|
||||
request: Request,
|
||||
token_data: Dict[str, Any],
auth_service: AuthService = Depends(get_auth_service)
) -> Dict[str, Any]:
|
||||
"""
|
||||
Verify the validity of an access token
|
||||
|
||||
Args:
|
||||
token_data: Dictionary containing access_token
|
||||
|
||||
Returns:
|
||||
Token validation result
|
||||
"""
|
||||
try:
|
||||
logger.info("Token verification request initiated")
|
||||
|
||||
# Extract access token from request
|
||||
access_token = token_data.get("access_token")
|
||||
if not access_token:
|
||||
logger.warning("Access token missing from verification request")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Access token is required"
|
||||
)
|
||||
|
||||
# Use service layer to verify token
|
||||
result = await auth_service.verify_access_token(access_token)
|
||||
|
||||
logger.info("Token verification successful via service layer")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"valid": result.get("valid"),
|
||||
"user_id": result.get("user_id"),
|
||||
"email": result.get("email"),
|
||||
"message": "Token is valid"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Token verification failed: {str(e)}"
|
||||
) from e
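And the matching verification call, useful for gateways or tests that want to check a token without decoding it locally; the base URL is assumed as above.

import httpx

BASE = "http://localhost:8000/api/v1/auth"  # assumed prefix

def is_token_valid(access_token: str) -> bool:
    """Ask the auth service whether an access token is still valid."""
    resp = httpx.post(f"{BASE}/verify", json={"access_token": access_token})
    return resp.status_code == 200 and bool(resp.json().get("valid"))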
|
||||
|
||||
|
||||
@router.post("/logout",
|
||||
response_model=Dict[str, Any],
|
||||
summary="Logout and revoke refresh token")
|
||||
async def logout(
|
||||
request: Request,
|
||||
logout_data: Dict[str, Any],
|
||||
auth_service: AuthService = Depends(get_auth_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Logout user and revoke refresh token
|
||||
|
||||
Args:
|
||||
logout_data: Dictionary containing refresh_token
|
||||
|
||||
Returns:
|
||||
Logout confirmation
|
||||
"""
|
||||
try:
|
||||
logger.info("Logout request initiated")
|
||||
|
||||
# Extract refresh token from request
|
||||
refresh_token = logout_data.get("refresh_token")
|
||||
if not refresh_token:
|
||||
logger.warning("Refresh token missing from logout request")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Refresh token is required"
|
||||
)
|
||||
|
||||
# Use service layer to revoke refresh token
|
||||
try:
|
||||
await auth_service.revoke_refresh_token(refresh_token)
|
||||
logger.info("Logout successful via service layer")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Logout successful"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout: {str(e)}")
|
||||
# Don't fail logout if revocation fails
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Logout successful (token revocation failed but user logged out)"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Logout failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Logout failed: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/change-password",
|
||||
response_model=Dict[str, Any],
|
||||
summary="Change user password")
|
||||
async def change_password(
|
||||
request: Request,
|
||||
password_data: Dict[str, Any],
|
||||
auth_service: AuthService = Depends(get_auth_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Change user password
|
||||
|
||||
Args:
|
||||
password_data: Dictionary containing current_password and new_password
|
||||
|
||||
Returns:
|
||||
Password change confirmation
|
||||
"""
|
||||
try:
|
||||
logger.info("Password change request initiated")
|
||||
|
||||
# Extract user from request state
|
||||
if not hasattr(request.state, 'user') or not request.state.user:
|
||||
logger.warning("Unauthorized password change attempt - no user context")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
user_id = request.state.user.get("user_id")
|
||||
if not user_id:
|
||||
logger.warning("Unauthorized password change attempt - no user_id")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid user context"
|
||||
)
|
||||
|
||||
# Extract password data
|
||||
current_password = password_data.get("current_password")
|
||||
new_password = password_data.get("new_password")
|
||||
|
||||
if not current_password or not new_password:
|
||||
logger.warning("Password change missing required fields")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password and new password are required"
|
||||
)
|
||||
|
||||
if len(new_password) < 8:
|
||||
logger.warning("New password too short")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="New password must be at least 8 characters long"
|
||||
)
|
||||
|
||||
# Use service layer to change password
|
||||
await auth_service.change_user_password(user_id, current_password, new_password)
|
||||
|
||||
logger.info(f"Password change successful via service layer, user_id={user_id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Password changed successfully"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Password change failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Password change failed: {str(e)}"
|
||||
) from e
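Because this handler reads the caller's identity from request.state.user (populated upstream by the auth middleware or gateway), a client only needs to send a bearer token plus the two password fields. The URL and token below are placeholders.

import httpx

BASE = "http://localhost:8000/api/v1/auth"  # assumed prefix
ACCESS_TOKEN = "eyJ...placeholder"          # a real access token from /login

resp = httpx.post(
    f"{BASE}/change-password",
    json={"current_password": "old-password", "new_password": "new-longer-password"},
    headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
)
print(resp.status_code, resp.json().get("message"))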
|
||||
|
||||
|
||||
@router.post("/verify-email",
|
||||
response_model=Dict[str, Any],
|
||||
summary="Verify user email")
|
||||
async def verify_email(
|
||||
request: Request,
|
||||
email_data: Dict[str, Any],
|
||||
auth_service: AuthService = Depends(get_auth_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Verify user email (placeholder implementation)
|
||||
|
||||
Args:
|
||||
email_data: Dictionary containing email and verification_token
|
||||
|
||||
Returns:
|
||||
Email verification confirmation
|
||||
"""
|
||||
try:
|
||||
logger.info("Email verification request initiated")
|
||||
|
||||
# Extract email and token
|
||||
email = email_data.get("email")
|
||||
verification_token = email_data.get("verification_token")
|
||||
|
||||
if not email or not verification_token:
|
||||
logger.warning("Email verification missing required fields")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email and verification token are required"
|
||||
)
|
||||
|
||||
# Use service layer to verify email
|
||||
await auth_service.verify_user_email(email, verification_token)
|
||||
|
||||
logger.info("Email verification successful via service layer")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Email verified successfully"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Email verification failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Email verification failed: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
372
services/auth/app/api/consent.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""
|
||||
User consent management API endpoints for GDPR compliance
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
import hashlib
|
||||
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from app.core.database import get_db
|
||||
from app.models.consent import UserConsent, ConsentHistory
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ConsentRequest(BaseModel):
|
||||
"""Request model for granting/updating consent"""
|
||||
terms_accepted: bool = Field(..., description="Accept terms of service")
|
||||
privacy_accepted: bool = Field(..., description="Accept privacy policy")
|
||||
marketing_consent: bool = Field(default=False, description="Consent to marketing communications")
|
||||
analytics_consent: bool = Field(default=False, description="Consent to analytics cookies")
|
||||
consent_method: str = Field(..., description="How consent was given (registration, settings, cookie_banner)")
|
||||
consent_version: str = Field(default="1.0", description="Version of terms/privacy policy")
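The JSON a client POSTs to the consent endpoint therefore looks like the sketch below; the field names come from this model, while the gateway URL and bearer token are placeholder assumptions.

import httpx

consent_payload = {
    "terms_accepted": True,
    "privacy_accepted": True,
    "marketing_consent": False,
    "analytics_consent": True,
    "consent_method": "cookie_banner",  # registration, settings or cookie_banner
    "consent_version": "1.0",
}

resp = httpx.post(
    "http://localhost:8000/api/v1/auth/me/consent",        # assumed gateway URL
    json=consent_payload,
    headers={"Authorization": "Bearer <access-token>"},    # placeholder token
)
print(resp.status_code)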
|
||||
|
||||
|
||||
class ConsentResponse(BaseModel):
|
||||
"""Response model for consent data"""
|
||||
id: str
|
||||
user_id: str
|
||||
terms_accepted: bool
|
||||
privacy_accepted: bool
|
||||
marketing_consent: bool
|
||||
analytics_consent: bool
|
||||
consent_version: str
|
||||
consent_method: str
|
||||
consented_at: str
|
||||
withdrawn_at: Optional[str]
|
||||
|
||||
|
||||
class ConsentHistoryResponse(BaseModel):
|
||||
"""Response model for consent history"""
|
||||
id: str
|
||||
user_id: str
|
||||
action: str
|
||||
consent_snapshot: dict
|
||||
created_at: str
|
||||
|
||||
|
||||
def hash_text(text: str) -> str:
|
||||
"""Create hash of consent text for verification"""
|
||||
return hashlib.sha256(text.encode()).hexdigest()
|
||||
|
||||
|
||||
@router.post("/api/v1/auth/me/consent", response_model=ConsentResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def record_consent(
|
||||
consent_data: ConsentRequest,
|
||||
request: Request,
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Record user consent for data processing
|
||||
GDPR Article 7 - Conditions for consent
|
||||
"""
|
||||
try:
|
||||
user_id = UUID(current_user["user_id"])
|
||||
|
||||
ip_address = request.client.host if request.client else None
|
||||
user_agent = request.headers.get("user-agent")
|
||||
|
||||
consent = UserConsent(
|
||||
user_id=user_id,
|
||||
terms_accepted=consent_data.terms_accepted,
|
||||
privacy_accepted=consent_data.privacy_accepted,
|
||||
marketing_consent=consent_data.marketing_consent,
|
||||
analytics_consent=consent_data.analytics_consent,
|
||||
consent_version=consent_data.consent_version,
|
||||
consent_method=consent_data.consent_method,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
consented_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db.add(consent)
|
||||
await db.flush()
|
||||
|
||||
history = ConsentHistory(
|
||||
user_id=user_id,
|
||||
consent_id=consent.id,
|
||||
action="granted",
|
||||
consent_snapshot=consent_data.dict(),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
consent_method=consent_data.consent_method,
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(consent)
|
||||
|
||||
logger.info(
|
||||
"consent_recorded",
|
||||
user_id=str(user_id),
|
||||
consent_version=consent_data.consent_version,
|
||||
method=consent_data.consent_method
|
||||
)
|
||||
|
||||
return ConsentResponse(
|
||||
id=str(consent.id),
|
||||
user_id=str(consent.user_id),
|
||||
terms_accepted=consent.terms_accepted,
|
||||
privacy_accepted=consent.privacy_accepted,
|
||||
marketing_consent=consent.marketing_consent,
|
||||
analytics_consent=consent.analytics_consent,
|
||||
consent_version=consent.consent_version,
|
||||
consent_method=consent.consent_method,
|
||||
consented_at=consent.consented_at.isoformat(),
|
||||
withdrawn_at=consent.withdrawn_at.isoformat() if consent.withdrawn_at else None
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("error_recording_consent", error=str(e), user_id=current_user.get("user_id"))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to record consent"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/v1/auth/me/consent/current", response_model=Optional[ConsentResponse])
|
||||
async def get_current_consent(
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get current active consent for user
|
||||
"""
|
||||
try:
|
||||
user_id = UUID(current_user["user_id"])
|
||||
|
||||
query = select(UserConsent).where(
|
||||
and_(
|
||||
UserConsent.user_id == user_id,
|
||||
UserConsent.withdrawn_at.is_(None)
|
||||
)
|
||||
).order_by(UserConsent.consented_at.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
# first() tolerates multiple active consent rows (POST /consent does not withdraw earlier ones)
consent = result.scalars().first()
|
||||
|
||||
if not consent:
|
||||
return None
|
||||
|
||||
return ConsentResponse(
|
||||
id=str(consent.id),
|
||||
user_id=str(consent.user_id),
|
||||
terms_accepted=consent.terms_accepted,
|
||||
privacy_accepted=consent.privacy_accepted,
|
||||
marketing_consent=consent.marketing_consent,
|
||||
analytics_consent=consent.analytics_consent,
|
||||
consent_version=consent.consent_version,
|
||||
consent_method=consent.consent_method,
|
||||
consented_at=consent.consented_at.isoformat(),
|
||||
withdrawn_at=consent.withdrawn_at.isoformat() if consent.withdrawn_at else None
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("error_getting_consent", error=str(e), user_id=current_user.get("user_id"))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve consent"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/v1/auth/me/consent/history", response_model=List[ConsentHistoryResponse])
|
||||
async def get_consent_history(
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get complete consent history for user
|
||||
GDPR Article 7(1) - Demonstrating consent
|
||||
"""
|
||||
try:
|
||||
user_id = UUID(current_user["user_id"])
|
||||
|
||||
query = select(ConsentHistory).where(
|
||||
ConsentHistory.user_id == user_id
|
||||
).order_by(ConsentHistory.created_at.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
history = result.scalars().all()
|
||||
|
||||
return [
|
||||
ConsentHistoryResponse(
|
||||
id=str(h.id),
|
||||
user_id=str(h.user_id),
|
||||
action=h.action,
|
||||
consent_snapshot=h.consent_snapshot,
|
||||
created_at=h.created_at.isoformat()
|
||||
)
|
||||
for h in history
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("error_getting_consent_history", error=str(e), user_id=current_user.get("user_id"))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve consent history"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/api/v1/auth/me/consent", response_model=ConsentResponse)
|
||||
async def update_consent(
|
||||
consent_data: ConsentRequest,
|
||||
request: Request,
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Update user consent preferences
|
||||
GDPR Article 7(3) - Withdrawal of consent
|
||||
"""
|
||||
try:
|
||||
user_id = UUID(current_user["user_id"])
|
||||
|
||||
query = select(UserConsent).where(
|
||||
and_(
|
||||
UserConsent.user_id == user_id,
|
||||
UserConsent.withdrawn_at.is_(None)
|
||||
)
|
||||
).order_by(UserConsent.consented_at.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
current_consent = result.scalars().first()
|
||||
|
||||
if current_consent:
|
||||
current_consent.withdrawn_at = datetime.now(timezone.utc)
|
||||
history = ConsentHistory(
|
||||
user_id=user_id,
|
||||
consent_id=current_consent.id,
|
||||
action="updated",
|
||||
consent_snapshot=current_consent.to_dict(),
|
||||
ip_address=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
consent_method=consent_data.consent_method,
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
new_consent = UserConsent(
|
||||
user_id=user_id,
|
||||
terms_accepted=consent_data.terms_accepted,
|
||||
privacy_accepted=consent_data.privacy_accepted,
|
||||
marketing_consent=consent_data.marketing_consent,
|
||||
analytics_consent=consent_data.analytics_consent,
|
||||
consent_version=consent_data.consent_version,
|
||||
consent_method=consent_data.consent_method,
|
||||
ip_address=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
consented_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db.add(new_consent)
|
||||
await db.flush()
|
||||
|
||||
history = ConsentHistory(
|
||||
user_id=user_id,
|
||||
consent_id=new_consent.id,
|
||||
action="granted" if not current_consent else "updated",
|
||||
consent_snapshot=consent_data.dict(),
|
||||
ip_address=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
consent_method=consent_data.consent_method,
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(new_consent)
|
||||
|
||||
logger.info(
|
||||
"consent_updated",
|
||||
user_id=str(user_id),
|
||||
consent_version=consent_data.consent_version
|
||||
)
|
||||
|
||||
return ConsentResponse(
|
||||
id=str(new_consent.id),
|
||||
user_id=str(new_consent.user_id),
|
||||
terms_accepted=new_consent.terms_accepted,
|
||||
privacy_accepted=new_consent.privacy_accepted,
|
||||
marketing_consent=new_consent.marketing_consent,
|
||||
analytics_consent=new_consent.analytics_consent,
|
||||
consent_version=new_consent.consent_version,
|
||||
consent_method=new_consent.consent_method,
|
||||
consented_at=new_consent.consented_at.isoformat(),
|
||||
withdrawn_at=None
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("error_updating_consent", error=str(e), user_id=current_user.get("user_id"))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update consent"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/auth/me/consent/withdraw", status_code=status.HTTP_200_OK)
|
||||
async def withdraw_consent(
|
||||
request: Request,
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Withdraw all consent
|
||||
GDPR Article 7(3) - Right to withdraw consent
|
||||
"""
|
||||
try:
|
||||
user_id = UUID(current_user["user_id"])
|
||||
|
||||
query = select(UserConsent).where(
|
||||
and_(
|
||||
UserConsent.user_id == user_id,
|
||||
UserConsent.withdrawn_at.is_(None)
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
consents = result.scalars().all()
|
||||
|
||||
for consent in consents:
|
||||
consent.withdrawn_at = datetime.now(timezone.utc)
|
||||
|
||||
history = ConsentHistory(
|
||||
user_id=user_id,
|
||||
consent_id=consent.id,
|
||||
action="withdrawn",
|
||||
consent_snapshot=consent.to_dict(),
|
||||
ip_address=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
consent_method="user_withdrawal",
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info("consent_withdrawn", user_id=str(user_id), count=len(consents))
|
||||
|
||||
return {
|
||||
"message": "Consent withdrawn successfully",
|
||||
"withdrawn_count": len(consents)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("error_withdrawing_consent", error=str(e), user_id=current_user.get("user_id"))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to withdraw consent"
|
||||
)
|
||||
121
services/auth/app/api/data_export.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
User data export API endpoints for GDPR compliance
|
||||
Implements Article 15 (Right to Access) and Article 20 (Right to Data Portability)
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
import structlog
|
||||
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from app.core.database import get_db
|
||||
from app.services.data_export_service import DataExportService
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/api/v1/auth/me/export")
|
||||
async def export_my_data(
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Export all personal data for the current user
|
||||
|
||||
GDPR Article 15 - Right of access by the data subject
|
||||
GDPR Article 20 - Right to data portability
|
||||
|
||||
Returns complete user data in machine-readable JSON format including:
|
||||
- Personal information
|
||||
- Account data
|
||||
- Consent history
|
||||
- Security logs
|
||||
- Audit trail
|
||||
|
||||
Response is provided in JSON format for easy data portability.
|
||||
"""
|
||||
try:
|
||||
user_id = UUID(current_user["user_id"])
|
||||
|
||||
export_service = DataExportService(db)
|
||||
data = await export_service.export_user_data(user_id)
|
||||
|
||||
logger.info(
|
||||
"data_export_requested",
|
||||
user_id=str(user_id),
|
||||
email=current_user.get("email")
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content=data,
|
||||
status_code=status.HTTP_200_OK,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="user_data_export_{user_id}.json"',
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"data_export_failed",
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to export user data"
|
||||
)
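A small sketch of how a user-facing client could download and store the export; the URL and token are placeholders.

import httpx
from pathlib import Path

resp = httpx.get(
    "http://localhost:8000/api/v1/auth/me/export",        # assumed gateway URL
    headers={"Authorization": "Bearer <access-token>"},   # placeholder token
)
resp.raise_for_status()

# The endpoint sets Content-Disposition, so a browser would offer the file as a download;
# a script can simply persist the JSON body.
Path("user_data_export.json").write_bytes(resp.content)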
|
||||
|
||||
|
||||
@router.get("/api/v1/auth/me/export/summary")
|
||||
async def get_export_summary(
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get a summary of what data would be exported
|
||||
|
||||
Useful for showing users what data we have about them
|
||||
before they request full export.
|
||||
"""
|
||||
try:
|
||||
user_id = UUID(current_user["user_id"])
|
||||
|
||||
export_service = DataExportService(db)
|
||||
data = await export_service.export_user_data(user_id)
|
||||
|
||||
summary = {
|
||||
"user_id": str(user_id),
|
||||
"data_categories": {
|
||||
"personal_data": bool(data.get("personal_data")),
|
||||
"account_data": bool(data.get("account_data")),
|
||||
"consent_data": bool(data.get("consent_data")),
|
||||
"security_data": bool(data.get("security_data")),
|
||||
"onboarding_data": bool(data.get("onboarding_data")),
|
||||
"audit_logs": bool(data.get("audit_logs"))
|
||||
},
|
||||
"data_counts": {
|
||||
"active_sessions": data.get("account_data", {}).get("active_sessions_count", 0),
|
||||
"consent_changes": data.get("consent_data", {}).get("total_consent_changes", 0),
|
||||
"login_attempts": len(data.get("security_data", {}).get("recent_login_attempts", [])),
|
||||
"audit_logs": data.get("audit_logs", {}).get("total_logs_exported", 0)
|
||||
},
|
||||
"export_format": "JSON",
|
||||
"gdpr_articles": ["Article 15 (Right to Access)", "Article 20 (Data Portability)"]
|
||||
}
|
||||
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"export_summary_failed",
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate export summary"
|
||||
)
|
||||
229
services/auth/app/api/internal_demo.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
Internal Demo Cloning API for Auth Service
|
||||
Service-to-service endpoint for cloning authentication and user data
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import structlog
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
# Add shared path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent))
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.users import User
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter(prefix="/internal/demo", tags=["internal"])
|
||||
|
||||
# Base demo tenant IDs
|
||||
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
|
||||
|
||||
|
||||
@router.post("/clone")
|
||||
async def clone_demo_data(
|
||||
base_tenant_id: str,
|
||||
virtual_tenant_id: str,
|
||||
demo_account_type: str,
|
||||
session_id: Optional[str] = None,
|
||||
session_created_at: Optional[str] = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Clone auth service data for a virtual demo tenant
|
||||
|
||||
Clones:
|
||||
- Demo users (owner and staff)
|
||||
|
||||
Note: Tenant memberships are handled by the tenant service's internal_demo endpoint
|
||||
|
||||
Args:
|
||||
base_tenant_id: Template tenant UUID to clone from
|
||||
virtual_tenant_id: Target virtual tenant UUID
|
||||
demo_account_type: Type of demo account
|
||||
session_id: Originating session ID for tracing
|
||||
|
||||
Returns:
|
||||
Cloning status and record counts
|
||||
"""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
# Parse session creation time
|
||||
if session_created_at:
|
||||
try:
|
||||
session_time = datetime.fromisoformat(session_created_at.replace('Z', '+00:00'))
|
||||
except (ValueError, AttributeError):
|
||||
session_time = start_time
|
||||
else:
|
||||
session_time = start_time
|
||||
|
||||
logger.info(
|
||||
"Starting auth data cloning",
|
||||
base_tenant_id=base_tenant_id,
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
demo_account_type=demo_account_type,
|
||||
session_id=session_id,
|
||||
session_created_at=session_created_at
|
||||
)
|
||||
|
||||
try:
|
||||
# Validate UUIDs
|
||||
base_uuid = uuid.UUID(base_tenant_id)
|
||||
virtual_uuid = uuid.UUID(virtual_tenant_id)
|
||||
|
||||
# Note: We don't check for existing users since User model doesn't have demo_session_id
|
||||
# Demo users are identified by their email addresses from the seed data
|
||||
# Idempotency is handled by checking if each user email already exists below
|
||||
|
||||
# Load demo users from JSON seed file
|
||||
from shared.utils.seed_data_paths import get_seed_data_path
|
||||
|
||||
if demo_account_type == "professional":
|
||||
json_file = get_seed_data_path("professional", "02-auth.json")
|
||||
elif demo_account_type == "enterprise":
|
||||
json_file = get_seed_data_path("enterprise", "02-auth.json")
|
||||
elif demo_account_type == "enterprise_child":
|
||||
# Child locations don't have separate auth data - they share parent's users
|
||||
logger.info("enterprise_child uses parent tenant auth, skipping user cloning", virtual_tenant_id=virtual_tenant_id)
|
||||
return {
|
||||
"service": "auth",
|
||||
"status": "completed",
|
||||
"records_cloned": 0,
|
||||
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
|
||||
"details": {"users": 0, "note": "Child locations share parent auth"}
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Invalid demo account type: {demo_account_type}")
|
||||
|
||||
# Load JSON data
|
||||
|
||||
with open(json_file, 'r', encoding='utf-8') as f:
|
||||
seed_data = json.load(f)
|
||||
|
||||
# Get demo users for this account type
|
||||
demo_users_data = seed_data.get("users", [])
|
||||
|
||||
records_cloned = 0
|
||||
|
||||
# Create users and tenant memberships
|
||||
for user_data in demo_users_data:
|
||||
user_id = uuid.UUID(user_data["id"])
|
||||
|
||||
# Create user if not exists
|
||||
user_result = await db.execute(
|
||||
select(User).where(User.id == user_id)
|
||||
)
|
||||
existing_user = user_result.scalars().first()
|
||||
|
||||
if not existing_user:
|
||||
# Apply date adjustments to created_at and updated_at
|
||||
from shared.utils.demo_dates import adjust_date_for_demo
|
||||
|
||||
# Adjust created_at date
|
||||
created_at_str = user_data.get("created_at", session_time.isoformat())
|
||||
if isinstance(created_at_str, str):
|
||||
try:
|
||||
original_created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
|
||||
adjusted_created_at = adjust_date_for_demo(original_created_at, session_time)
|
||||
except ValueError:
|
||||
adjusted_created_at = session_time
|
||||
else:
|
||||
adjusted_created_at = session_time
|
||||
|
||||
# Adjust updated_at date (same as created_at for demo users)
|
||||
adjusted_updated_at = adjusted_created_at
|
||||
|
||||
# Get full_name from either "name" or "full_name" field
|
||||
full_name = user_data.get("full_name") or user_data.get("name", "Demo User")
|
||||
|
||||
# For demo users, use a placeholder hashed password (they won't actually log in)
|
||||
# In production, this would be properly hashed
|
||||
demo_hashed_password = "$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/LewY5GyYqNlI.eFKW" # "demo_password"
|
||||
|
||||
user = User(
|
||||
id=user_id,
|
||||
email=user_data["email"],
|
||||
full_name=full_name,
|
||||
hashed_password=demo_hashed_password,
|
||||
is_active=user_data.get("is_active", True),
|
||||
is_verified=True,
|
||||
role=user_data.get("role", "member"),
|
||||
language=user_data.get("language", "es"),
|
||||
timezone=user_data.get("timezone", "Europe/Madrid"),
|
||||
created_at=adjusted_created_at,
|
||||
updated_at=adjusted_updated_at
|
||||
)
|
||||
db.add(user)
|
||||
records_cloned += 1
|
||||
|
||||
# Note: Tenant memberships are handled by tenant service
|
||||
# Only create users in auth service
|
||||
|
||||
await db.commit()
|
||||
|
||||
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
|
||||
|
||||
logger.info(
|
||||
"Auth data cloning completed",
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
session_id=session_id,
|
||||
records_cloned=records_cloned,
|
||||
duration_ms=duration_ms
|
||||
)
|
||||
|
||||
return {
|
||||
"service": "auth",
|
||||
"status": "completed",
|
||||
"records_cloned": records_cloned,
|
||||
"base_tenant_id": str(base_tenant_id),
|
||||
"virtual_tenant_id": str(virtual_tenant_id),
|
||||
"session_id": session_id,
|
||||
"demo_account_type": demo_account_type,
|
||||
"duration_ms": duration_ms
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("Invalid UUID format", error=str(e), virtual_tenant_id=virtual_tenant_id)
|
||||
raise HTTPException(status_code=400, detail=f"Invalid UUID: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to clone auth data",
|
||||
error=str(e),
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# Rollback on error
|
||||
await db.rollback()
|
||||
|
||||
return {
|
||||
"service": "auth",
|
||||
"status": "failed",
|
||||
"records_cloned": 0,
|
||||
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
|
||||
"error": str(e)
|
||||
}
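Because the handler declares plain scalar parameters rather than a Pydantic body model, FastAPI exposes them as query parameters; an orchestrator call would therefore look roughly like the sketch below, where the virtual tenant and session IDs are freshly generated placeholders and the service URL is an assumption.

import uuid
import httpx

AUTH_SERVICE_URL = "http://auth-service:8000"  # assumed internal service URL

resp = httpx.post(
    f"{AUTH_SERVICE_URL}/internal/demo/clone",
    params={
        "base_tenant_id": "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6",  # template tenant defined above
        "virtual_tenant_id": str(uuid.uuid4()),                    # per-session virtual tenant
        "demo_account_type": "professional",
        "session_id": str(uuid.uuid4()),
    },
)
body = resp.json()
print(body["status"], body["records_cloned"])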
|
||||
|
||||
|
||||
@router.get("/clone/health")
|
||||
async def clone_health_check():
|
||||
"""
|
||||
Health check for internal cloning endpoint
|
||||
Used by orchestrator to verify service availability
|
||||
"""
|
||||
return {
|
||||
"service": "auth",
|
||||
"clone_endpoint": "available",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
1153
services/auth/app/api/onboarding_progress.py
Normal file
File diff suppressed because it is too large
308
services/auth/app/api/password_reset.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# services/auth/app/api/password_reset.py
|
||||
"""
|
||||
Password reset API endpoints
|
||||
Handles forgot password and password reset functionality
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
||||
from app.services.auth_service import auth_service, AuthService
|
||||
from app.schemas.auth import PasswordReset, PasswordResetConfirm
|
||||
from app.core.security import SecurityManager
|
||||
from app.core.config import settings
|
||||
from app.repositories.password_reset_repository import PasswordResetTokenRepository
|
||||
from app.repositories.user_repository import UserRepository
|
||||
from app.models.users import User
|
||||
from shared.clients.notification_client import NotificationServiceClient
|
||||
import structlog
|
||||
|
||||
# Configure logging
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Create router
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["password-reset"])
|
||||
|
||||
|
||||
async def get_auth_service() -> AuthService:
|
||||
"""Dependency injection for auth service"""
|
||||
return auth_service
|
||||
|
||||
|
||||
def generate_reset_token() -> str:
|
||||
"""Generate a secure password reset token"""
|
||||
import secrets
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
async def send_password_reset_email(email: str, reset_token: str, user_full_name: str):
|
||||
"""Send password reset email in background using notification service"""
|
||||
try:
|
||||
# Construct reset link (this should match your frontend URL)
|
||||
# Use FRONTEND_URL from settings if available, otherwise fall back to gateway URL
|
||||
frontend_url = getattr(settings, 'FRONTEND_URL', settings.GATEWAY_URL)
|
||||
reset_link = f"{frontend_url}/reset-password?token={reset_token}"
|
||||
|
||||
# Create HTML content for the password reset email in Spanish
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="es">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Restablecer Contraseña</title>
|
||||
<style>
|
||||
body {{
|
||||
font-family: Arial, sans-serif;
|
||||
line-height: 1.6;
|
||||
color: #333;
|
||||
max-width: 600px;
|
||||
margin: 0 auto;
|
||||
padding: 20px;
|
||||
background-color: #f9f9f9;
|
||||
}}
|
||||
.header {{
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
background: linear-gradient(135deg, #4F46E5 0%, #7C3AED 100%);
|
||||
color: white;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
}}
|
||||
.content {{
|
||||
background: white;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||
}}
|
||||
.button {{
|
||||
display: inline-block;
|
||||
padding: 12px 30px;
|
||||
background-color: #4F46E5;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 5px;
|
||||
margin: 20px 0;
|
||||
font-weight: bold;
|
||||
}}
|
||||
.footer {{
|
||||
margin-top: 40px;
|
||||
text-align: center;
|
||||
font-size: 0.9em;
|
||||
color: #666;
|
||||
padding-top: 20px;
|
||||
border-top: 1px solid #eee;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>Restablecer Contraseña</h1>
|
||||
</div>
|
||||
|
||||
<div class="content">
|
||||
<p>Hola {user_full_name},</p>
|
||||
|
||||
<p>Recibimos una solicitud para restablecer tu contraseña. Haz clic en el botón de abajo para crear una nueva contraseña:</p>
|
||||
|
||||
<p style="text-align: center; margin: 30px 0;">
|
||||
<a href="{reset_link}" class="button">Restablecer Contraseña</a>
|
||||
</p>
|
||||
|
||||
<p>Si no solicitaste un restablecimiento de contraseña, puedes ignorar este correo electrónico de forma segura.</p>
|
||||
|
||||
<p>Este enlace expirará en 1 hora por razones de seguridad.</p>
|
||||
</div>
|
||||
|
||||
<div class="footer">
|
||||
<p>Este es un mensaje automático de BakeWise. Por favor, no respondas a este correo electrónico.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
# Create text content as fallback
|
||||
text_content = f"""
|
||||
Hola {user_full_name},
|
||||
|
||||
Recibimos una solicitud para restablecer tu contraseña. Haz clic en el siguiente enlace para crear una nueva contraseña:
|
||||
|
||||
{reset_link}
|
||||
|
||||
Si no solicitaste un restablecimiento de contraseña, puedes ignorar este correo electrónico de forma segura.
|
||||
|
||||
Este enlace expirará en 1 hora por razones de seguridad.
|
||||
|
||||
Este es un mensaje automático de BakeWise. Por favor, no respondas a este correo electrónico.
|
||||
"""
|
||||
|
||||
# Send email using the notification service
|
||||
notification_client = NotificationServiceClient(settings)
|
||||
|
||||
# Send the notification using the send_email method
|
||||
await notification_client.send_email(
|
||||
tenant_id="system", # Using system tenant for password resets
|
||||
to_email=email,
|
||||
subject="Restablecer Contraseña",
|
||||
message=text_content,
|
||||
html_content=html_content,
|
||||
priority="high"
|
||||
)
|
||||
|
||||
logger.info(f"Password reset email sent successfully to {email}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send password reset email to {email}: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/password/reset-request",
|
||||
summary="Request password reset",
|
||||
description="Send a password reset link to the user's email")
|
||||
async def request_password_reset(
|
||||
reset_request: PasswordReset,
|
||||
background_tasks: BackgroundTasks,
|
||||
auth_service: AuthService = Depends(get_auth_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Request a password reset
|
||||
|
||||
This endpoint:
|
||||
1. Finds the user by email
|
||||
2. Generates a password reset token
|
||||
3. Stores the token in the database
|
||||
4. Sends a password reset email to the user
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Password reset request for email: {reset_request.email}")
|
||||
|
||||
# Find user by email
|
||||
async with auth_service.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(User, session)
|
||||
user = await user_repo.get_by_field("email", reset_request.email)
|
||||
|
||||
if not user:
|
||||
# Don't reveal if email exists to prevent enumeration attacks
|
||||
logger.info(f"Password reset request for non-existent email: {reset_request.email}")
|
||||
return {"message": "If an account with this email exists, a reset link has been sent."}
|
||||
|
||||
# Generate a secure reset token
|
||||
reset_token = generate_reset_token()
|
||||
|
||||
# Set token expiration (e.g., 1 hour)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
# Store the reset token in the database
|
||||
token_repo = PasswordResetTokenRepository(session)
|
||||
|
||||
# Clean up expired reset tokens before issuing a new one
|
||||
await token_repo.cleanup_expired_tokens()
|
||||
|
||||
# Create new reset token
|
||||
await token_repo.create_token(
|
||||
user_id=str(user.id),
|
||||
token=reset_token,
|
||||
expires_at=expires_at
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
await session.commit()
|
||||
|
||||
# Send password reset email in background
|
||||
background_tasks.add_task(
|
||||
send_password_reset_email,
|
||||
user.email,
|
||||
reset_token,
|
||||
user.full_name
|
||||
)
|
||||
|
||||
logger.info(f"Password reset token created for user: {user.email}")
|
||||
|
||||
return {"message": "If an account with this email exists, a reset link has been sent."}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Password reset request failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Password reset request failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/password/reset",
|
||||
summary="Reset password with token",
|
||||
description="Reset user password using a valid reset token")
|
||||
async def reset_password(
|
||||
reset_confirm: PasswordResetConfirm,
|
||||
auth_service: AuthService = Depends(get_auth_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reset password using a valid reset token
|
||||
|
||||
This endpoint:
|
||||
1. Validates the reset token
|
||||
2. Checks if the token is valid and not expired
|
||||
3. Updates the user's password
|
||||
4. Marks the token as used
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Password reset attempt with token: {reset_confirm.token[:10]}...")
|
||||
|
||||
# Validate password strength
|
||||
if not SecurityManager.validate_password(reset_confirm.new_password):
|
||||
errors = SecurityManager.get_password_validation_errors(reset_confirm.new_password)
|
||||
logger.warning(f"Password validation failed: {errors}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Password does not meet requirements: {'; '.join(errors)}"
|
||||
)
|
||||
|
||||
# Find the reset token in the database
|
||||
async with auth_service.database_manager.get_session() as session:
|
||||
token_repo = PasswordResetTokenRepository(session)
|
||||
reset_token_obj = await token_repo.get_token_by_value(reset_confirm.token)
|
||||
|
||||
if not reset_token_obj:
|
||||
logger.warning(f"Invalid or expired password reset token: {reset_confirm.token[:10]}...")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired reset token"
|
||||
)
|
||||
|
||||
# Get the user associated with this token
|
||||
user_repo = UserRepository(User, session)
|
||||
user = await user_repo.get_by_id(str(reset_token_obj.user_id))
|
||||
|
||||
if not user:
|
||||
logger.error(f"User not found for reset token: {reset_confirm.token[:10]}...")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid reset token"
|
||||
)
|
||||
|
||||
# Hash the new password
|
||||
hashed_password = SecurityManager.hash_password(reset_confirm.new_password)
|
||||
|
||||
# Update user's password
|
||||
await user_repo.update(str(user.id), {
|
||||
"hashed_password": hashed_password
|
||||
})
|
||||
|
||||
# Mark the reset token as used
|
||||
await token_repo.mark_token_as_used(str(reset_token_obj.id))
|
||||
|
||||
# Commit the transactions
|
||||
await session.commit()
|
||||
|
||||
logger.info(f"Password successfully reset for user: {user.email}")
|
||||
|
||||
return {"message": "Password has been reset successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Password reset failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Password reset failed"
|
||||
)
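End to end, the two endpoints above combine into a flow like the following sketch; the base URL is assumed, and the token would normally arrive by email rather than being pasted by hand.

import httpx

BASE = "http://localhost:8000/api/v1/auth"  # assumed prefix used by this router

# Step 1: request a reset link (the response is intentionally identical whether or not the email exists)
httpx.post(f"{BASE}/password/reset-request", json={"email": "owner@example.com"})

# Step 2: the user follows the emailed link and the frontend submits the token it carries
reset_token = "<token-from-email>"  # placeholder
resp = httpx.post(
    f"{BASE}/password/reset",
    json={"token": reset_token, "new_password": "a-new-strong-password"},
)
print(resp.status_code, resp.json().get("message"))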
|
||||
662
services/auth/app/api/users.py
Normal file
@@ -0,0 +1,662 @@
|
||||
"""
|
||||
User management API routes
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Path, Body
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Dict, Any
|
||||
import structlog
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.core.database import get_db, get_background_db_session
|
||||
from app.schemas.auth import UserResponse, PasswordChange
|
||||
from app.schemas.users import UserUpdate, BatchUserRequest, OwnerUserCreate
|
||||
from app.services.user_service import EnhancedUserService
|
||||
from app.models.users import User
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.admin_delete import AdminUserDeleteService
|
||||
from app.models import AuditLog
|
||||
|
||||
# Import unified authentication from shared library
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
require_admin_role_dep
|
||||
)
|
||||
from shared.security import create_audit_logger, AuditSeverity, AuditAction
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter(tags=["users"])
|
||||
|
||||
# Initialize audit logger
|
||||
audit_logger = create_audit_logger("auth-service", AuditLog)
|
||||
|
||||
@router.delete("/api/v1/auth/users/{user_id}")
|
||||
async def delete_admin_user(
|
||||
background_tasks: BackgroundTasks,
|
||||
user_id: str = Path(..., description="User ID"),
|
||||
current_user = Depends(require_admin_role_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Delete an admin user and all associated data across all services.
|
||||
|
||||
This operation will:
|
||||
1. Cancel any active training jobs for user's tenants
|
||||
2. Delete all trained models and artifacts
|
||||
3. Delete all forecasts and predictions
|
||||
4. Delete notification preferences and logs
|
||||
5. Handle tenant ownership (transfer or delete)
|
||||
6. Delete user account and authentication data
|
||||
|
||||
**Warning: This operation is irreversible!**
|
||||
"""
|
||||
|
||||
# Validate user_id format
|
||||
try:
|
||||
uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid user ID format"
|
||||
)
|
||||
|
||||
# Quick validation that user exists before starting background task
|
||||
deletion_service = AdminUserDeleteService(db)
|
||||
user_info = await deletion_service._validate_admin_user(user_id)
|
||||
if not user_info:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Admin user {user_id} not found"
|
||||
)
|
||||
|
||||
# Log audit event for user deletion
|
||||
try:
|
||||
# Get tenant_id from current_user or use a placeholder for system-level operations
|
||||
tenant_id_str = current_user.get("tenant_id", "00000000-0000-0000-0000-000000000000")
|
||||
await audit_logger.log_deletion(
|
||||
db_session=db,
|
||||
tenant_id=tenant_id_str,
|
||||
user_id=current_user["user_id"],
|
||||
resource_type="user",
|
||||
resource_id=user_id,
|
||||
resource_data=user_info,
|
||||
description=f"Admin {current_user.get('email', current_user['user_id'])} initiated deletion of user {user_info.get('email', user_id)}",
|
||||
endpoint="/delete/{user_id}",
|
||||
method="DELETE"
|
||||
)
|
||||
except Exception as audit_error:
|
||||
logger.warning("Failed to log audit event", error=str(audit_error))
|
||||
|
||||
# Start deletion as background task for better performance
|
||||
background_tasks.add_task(
|
||||
execute_admin_user_deletion,
|
||||
user_id=user_id,
|
||||
requesting_user_id=current_user["user_id"]
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Admin user deletion for {user_id} has been initiated",
|
||||
"status": "processing",
|
||||
"user_info": user_info,
|
||||
"initiated_at": datetime.utcnow().isoformat(),
|
||||
"note": "Deletion is processing in the background. Check logs for completion status."
|
||||
}
|
||||
|
||||
# Background task that performs the actual admin user deletion
|
||||
|
||||
async def execute_admin_user_deletion(user_id: str, requesting_user_id: str):
|
||||
"""
|
||||
Background task using shared infrastructure
|
||||
"""
|
||||
# ✅ Use the shared background session
|
||||
async with get_background_db_session() as session:
|
||||
deletion_service = AdminUserDeleteService(session)
|
||||
|
||||
result = await deletion_service.delete_admin_user_complete(
|
||||
user_id=user_id,
|
||||
requesting_user_id=requesting_user_id
|
||||
)
|
||||
|
||||
logger.info("Background admin user deletion completed successfully",
|
||||
user_id=user_id,
|
||||
requesting_user=requesting_user_id,
|
||||
result=result)
|
||||
|
||||
|
||||
@router.get("/api/v1/auth/users/{user_id}/deletion-preview")
|
||||
async def preview_user_deletion(
|
||||
user_id: str = Path(..., description="User ID"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Preview what data would be deleted for an admin user.
|
||||
|
||||
This endpoint provides a dry-run preview of the deletion operation
|
||||
without actually deleting any data.
|
||||
"""
|
||||
|
||||
try:
|
||||
uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid user ID format"
|
||||
)
|
||||
|
||||
deletion_service = AdminUserDeleteService(db)
|
||||
|
||||
# Get user info
|
||||
user_info = await deletion_service._validate_admin_user(user_id)
|
||||
if not user_info:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Admin user {user_id} not found"
|
||||
)
|
||||
|
||||
# Get tenant associations
|
||||
tenant_info = await deletion_service._get_user_tenant_info(user_id)
|
||||
|
||||
# Build preview
|
||||
preview = {
|
||||
"user": user_info,
|
||||
"tenant_associations": tenant_info,
|
||||
"estimated_deletions": {
|
||||
"training_models": "All models for associated tenants",
|
||||
"forecasts": "All forecasts for associated tenants",
|
||||
"notifications": "All user notification data",
|
||||
"tenant_memberships": tenant_info['total_tenants'],
|
||||
"owned_tenants": f"{tenant_info['owned_tenants']} (will be transferred or deleted)"
|
||||
},
|
||||
"warning": "This operation is irreversible and will permanently delete all associated data"
|
||||
}
|
||||
|
||||
return preview
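Taken together with the DELETE route above, an admin tool might first preview and then trigger the deletion, roughly as sketched here; the URL, admin token and user ID are placeholders.

import httpx

BASE = "http://localhost:8000/api/v1/auth"                   # assumed prefix
HEADERS = {"Authorization": "Bearer <admin-access-token>"}   # placeholder admin token
user_id = "00000000-0000-0000-0000-000000000001"             # placeholder user ID

preview = httpx.get(f"{BASE}/users/{user_id}/deletion-preview", headers=HEADERS).json()
print(preview["warning"], preview["tenant_associations"])

# Only after the operator confirms:
resp = httpx.delete(f"{BASE}/users/{user_id}", headers=HEADERS)
print(resp.json()["status"])  # "processing" - the heavy lifting runs as a background task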
|
||||
|
||||
|
||||
@router.get("/api/v1/auth/users/{user_id}", response_model=UserResponse)
|
||||
async def get_user_by_id(
|
||||
user_id: str = Path(..., description="User ID"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get user information by user ID.
|
||||
|
||||
This endpoint is for internal service-to-service communication.
|
||||
It returns user details needed by other services (e.g., tenant service for enriching member data).
|
||||
"""
|
||||
try:
|
||||
# Validate UUID format
|
||||
try:
|
||||
uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid user ID format"
|
||||
)
|
||||
|
||||
# Fetch user from database
|
||||
from app.repositories import UserRepository
|
||||
user_repo = UserRepository(User, db)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User {user_id} not found"
|
||||
)
|
||||
|
||||
logger.debug("Retrieved user by ID", user_id=user_id, email=user.email)
|
||||
|
||||
return UserResponse(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
phone=user.phone,
|
||||
language=user.language or "es",
|
||||
timezone=user.timezone or "Europe/Madrid",
|
||||
created_at=user.created_at,
|
||||
last_login=user.last_login,
|
||||
role=user.role,
|
||||
tenant_id=None,
|
||||
payment_customer_id=user.payment_customer_id,
|
||||
default_payment_method_id=user.default_payment_method_id
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get user by ID error", user_id=user_id, error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get user information"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/api/v1/auth/users/{user_id}", response_model=UserResponse)
|
||||
async def update_user_profile(
|
||||
user_id: str = Path(..., description="User ID"),
|
||||
update_data: UserUpdate = Body(..., description="User profile update data"),
|
||||
current_user = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Update user profile information.
|
||||
|
||||
This endpoint allows users to update their profile information including:
|
||||
- Full name
|
||||
- Phone number
|
||||
- Language preference
|
||||
- Timezone
|
||||
|
||||
**Permissions:** Users can update their own profile, admins can update any user's profile
|
||||
"""
|
||||
try:
|
||||
# Validate UUID format
|
||||
try:
|
||||
uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid user ID format"
|
||||
)
|
||||
|
||||
# Check permissions - user can update their own profile, admins can update any
|
||||
if current_user["user_id"] != user_id:
|
||||
# Check if current user has admin privileges
|
||||
user_role = current_user.get("role", "user")
|
||||
if user_role not in ["admin", "super_admin", "manager"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions to update this user's profile"
|
||||
)
|
||||
|
||||
# Fetch user from database
|
||||
from app.repositories import UserRepository
|
||||
user_repo = UserRepository(User, db)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User {user_id} not found"
|
||||
)
|
||||
|
||||
# Prepare update data (only include fields that are provided)
|
||||
update_fields = update_data.dict(exclude_unset=True)
|
||||
if not update_fields:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No update data provided"
|
||||
)
|
||||
|
||||
# Update user
|
||||
updated_user = await user_repo.update(user_id, update_fields)
|
||||
|
||||
logger.info("User profile updated", user_id=user_id, updated_fields=list(update_fields.keys()))
|
||||
|
||||
# Log audit event for user profile update
|
||||
try:
|
||||
# Get tenant_id from current_user or use a placeholder for system-level operations
|
||||
tenant_id_str = current_user.get("tenant_id", "00000000-0000-0000-0000-000000000000")
|
||||
await audit_logger.log_event(
|
||||
db_session=db,
|
||||
tenant_id=tenant_id_str,
|
||||
user_id=current_user["user_id"],
|
||||
action=AuditAction.UPDATE.value,
|
||||
resource_type="user",
|
||||
resource_id=user_id,
|
||||
severity=AuditSeverity.MEDIUM.value,
|
||||
description=f"User {current_user.get('email', current_user['user_id'])} updated profile for user {user.email}",
|
||||
changes={"updated_fields": list(update_fields.keys())},
|
||||
audit_metadata={"updated_data": update_fields},
|
||||
endpoint="/users/{user_id}",
|
||||
method="PUT"
|
||||
)
|
||||
except Exception as audit_error:
|
||||
logger.warning("Failed to log audit event", error=str(audit_error))
|
||||
|
||||
return UserResponse(
|
||||
id=str(updated_user.id),
|
||||
email=updated_user.email,
|
||||
full_name=updated_user.full_name,
|
||||
is_active=updated_user.is_active,
|
||||
is_verified=updated_user.is_verified,
|
||||
phone=updated_user.phone,
|
||||
language=updated_user.language or "es",
|
||||
timezone=updated_user.timezone or "Europe/Madrid",
|
||||
created_at=updated_user.created_at,
|
||||
last_login=updated_user.last_login,
|
||||
role=updated_user.role,
|
||||
tenant_id=None,
|
||||
payment_customer_id=updated_user.payment_customer_id,
|
||||
default_payment_method_id=updated_user.default_payment_method_id
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Update user profile error", user_id=user_id, error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update user profile"
|
||||
)
|
||||
|
||||
|
||||
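# --- Example (illustrative only, not part of the original file) ---
# Why only provided fields are written: update_data.dict(exclude_unset=True)
# drops every attribute the client did not send. The field names below mirror
# the docstring (full name, phone, language, timezone); the real UserUpdate
# schema lives elsewhere in this service and may differ.

from typing import Optional
from pydantic import BaseModel

class UserUpdateSketch(BaseModel):   # hypothetical stand-in for UserUpdate
    full_name: Optional[str] = None
    phone: Optional[str] = None
    language: Optional[str] = None
    timezone: Optional[str] = None

partial = UserUpdateSketch(phone="+34 600 000 000")
print(partial.dict(exclude_unset=True))   # {'phone': '+34 600 000 000'} -> only this column is updated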
@router.post("/api/v1/auth/users/create-by-owner", response_model=UserResponse)
|
||||
async def create_user_by_owner(
|
||||
user_data: OwnerUserCreate,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Create a new user account (owner/admin only - for pilot phase).
|
||||
|
||||
This endpoint allows tenant owners to directly create user accounts
|
||||
with passwords during the pilot phase. In production, this will be
|
||||
replaced with an invitation-based flow.
|
||||
|
||||
**Permissions:** Owner or Admin role required
|
||||
**Security:** Password is hashed server-side before storage
|
||||
"""
|
||||
try:
|
||||
# Verify caller has admin or owner privileges
|
||||
# In pilot phase, we allow 'admin' role from auth service
|
||||
user_role = current_user.get("role", "user")
|
||||
if user_role not in ["admin", "super_admin", "manager"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only administrators can create users directly"
|
||||
)
|
||||
|
||||
# Validate email uniqueness
|
||||
from app.repositories import UserRepository
|
||||
user_repo = UserRepository(User, db)
|
||||
|
||||
existing_user = await user_repo.get_by_email(user_data.email)
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"User with email {user_data.email} already exists"
|
||||
)
|
||||
|
||||
# Hash password
|
||||
from app.core.security import SecurityManager
|
||||
hashed_password = SecurityManager.hash_password(user_data.password)
|
||||
|
||||
# Create user
|
||||
create_data = {
|
||||
"email": user_data.email,
|
||||
"full_name": user_data.full_name,
|
||||
"hashed_password": hashed_password,
|
||||
"phone": user_data.phone,
|
||||
"role": user_data.role,
|
||||
"language": user_data.language or "es",
|
||||
"timezone": user_data.timezone or "Europe/Madrid",
|
||||
"is_active": True,
|
||||
"is_verified": False # Can be verified later if needed
|
||||
}
|
||||
|
||||
new_user = await user_repo.create_user(create_data)
|
||||
|
||||
logger.info(
|
||||
"User created by owner",
|
||||
created_user_id=str(new_user.id),
|
||||
created_user_email=new_user.email,
|
||||
created_by=current_user.get("user_id"),
|
||||
created_by_email=current_user.get("email")
|
||||
)
|
||||
|
||||
# Return user response
|
||||
return UserResponse(
|
||||
id=str(new_user.id),
|
||||
email=new_user.email,
|
||||
full_name=new_user.full_name,
|
||||
is_active=new_user.is_active,
|
||||
is_verified=new_user.is_verified,
|
||||
phone=new_user.phone,
|
||||
language=new_user.language,
|
||||
timezone=new_user.timezone,
|
||||
created_at=new_user.created_at,
|
||||
last_login=new_user.last_login,
|
||||
role=new_user.role,
|
||||
tenant_id=None # Will be set when added to tenant
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to create user by owner",
|
||||
email=user_data.email,
|
||||
error=str(e),
|
||||
created_by=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create user account"
|
||||
)
|
||||
|
||||
|
||||
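# --- Example (illustrative only, not part of the original file) ---
# A sketch of how an owner/admin client might call the pilot-phase endpoint
# above. The base URL, bearer token, and all payload values are placeholders;
# the payload fields mirror the create_data dict built in the handler.

import httpx

async def create_staff_user(owner_token: str) -> dict:
    payload = {
        "email": "staff@example.com",
        "full_name": "Staff Member",
        "password": "ChangeMe123",      # hashed server-side before storage
        "phone": "+34 600 000 000",
        "role": "member",
        "language": "es",
        "timezone": "Europe/Madrid",
    }
    async with httpx.AsyncClient(base_url="http://auth-service:8000") as client:
        resp = await client.post(
            "/api/v1/auth/users/create-by-owner",
            json=payload,
            headers={"Authorization": f"Bearer {owner_token}"},
        )
        resp.raise_for_status()
        return resp.json()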
@router.post("/api/v1/auth/users/batch", response_model=Dict[str, Any])
|
||||
async def get_users_batch(
|
||||
request: BatchUserRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get multiple users by their IDs in a single request.
|
||||
|
||||
This endpoint is for internal service-to-service communication.
|
||||
It efficiently fetches multiple user records needed by other services
|
||||
(e.g., tenant service for enriching member lists).
|
||||
|
||||
Returns a dict mapping user_id -> user data, with null for non-existent users.
|
||||
"""
|
||||
try:
|
||||
# Validate all UUIDs
|
||||
validated_ids = []
|
||||
for user_id in request.user_ids:
|
||||
try:
|
||||
uuid.UUID(user_id)
|
||||
validated_ids.append(user_id)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid user ID format in batch request: {user_id}")
|
||||
continue
|
||||
|
||||
if not validated_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No valid user IDs provided"
|
||||
)
|
||||
|
||||
# Fetch users from database
|
||||
from app.repositories import UserRepository
|
||||
user_repo = UserRepository(User, db)
|
||||
|
||||
# Build response map
|
||||
user_map = {}
|
||||
for user_id in validated_ids:
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
|
||||
if user:
|
||||
user_map[user_id] = {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"is_active": user.is_active,
|
||||
"is_verified": user.is_verified,
|
||||
"phone": user.phone,
|
||||
"language": user.language or "es",
|
||||
"timezone": user.timezone or "Europe/Madrid",
|
||||
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||
"last_login": user.last_login.isoformat() if user.last_login else None,
|
||||
"role": user.role
|
||||
}
|
||||
else:
|
||||
user_map[user_id] = None
|
||||
|
||||
logger.debug(
|
||||
"Batch user fetch completed",
|
||||
requested_count=len(request.user_ids),
|
||||
found_count=sum(1 for v in user_map.values() if v is not None)
|
||||
)
|
||||
|
||||
return {
|
||||
"users": user_map,
|
||||
"requested_count": len(request.user_ids),
|
||||
"found_count": sum(1 for v in user_map.values() if v is not None)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Batch user fetch error", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to fetch users"
|
||||
)
|
||||
|
||||
|
||||
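# --- Example (illustrative only, not part of the original file) ---
# The handler above issues one SELECT per requested ID. If this ever becomes a
# hot path, a single IN() query keeps the same response contract (null for
# missing users) with one round trip. This is a sketch of the idea, not the
# service's actual repository API.

from sqlalchemy import select

async def fetch_users_in_one_query(db, user_ids: list[str]) -> dict:
    result = await db.execute(select(User).where(User.id.in_(user_ids)))
    found = {str(u.id): u for u in result.scalars().all()}
    # Preserve the "null for non-existent users" behaviour of the endpoint
    return {uid: found.get(uid) for uid in user_ids}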
@router.get("/api/v1/auth/users/{user_id}/activity")
|
||||
async def get_user_activity(
|
||||
user_id: str = Path(..., description="User ID"),
|
||||
current_user = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get user activity information.
|
||||
|
||||
This endpoint returns detailed activity information for a user including:
|
||||
- Last login timestamp
|
||||
- Account creation date
|
||||
- Active session count
|
||||
- Last activity timestamp
|
||||
- User status information
|
||||
|
||||
**Permissions:** User can view their own activity, admins can view any user's activity
|
||||
"""
|
||||
try:
|
||||
# Validate UUID format
|
||||
try:
|
||||
uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid user ID format"
|
||||
)
|
||||
|
||||
# Check permissions - user can view their own activity, admins can view any
|
||||
if current_user["user_id"] != user_id:
|
||||
# Check if current user has admin privileges
|
||||
user_role = current_user.get("role", "user")
|
||||
if user_role not in ["admin", "super_admin", "manager"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions to view this user's activity"
|
||||
)
|
||||
|
||||
# Initialize enhanced user service
|
||||
from app.core.config import settings
|
||||
from shared.database.base import create_database_manager
|
||||
        database_manager = create_database_manager(settings.DATABASE_URL, "auth-service")  # label this service, not tenant-service
|
||||
user_service = EnhancedUserService(database_manager)
|
||||
|
||||
# Get user activity data
|
||||
activity_data = await user_service.get_user_activity(user_id)
|
||||
|
||||
if "error" in activity_data:
|
||||
if activity_data["error"] == "User not found":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get user activity: {activity_data['error']}"
|
||||
)
|
||||
|
||||
logger.debug("Retrieved user activity", user_id=user_id)
|
||||
|
||||
return activity_data
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get user activity error", user_id=user_id, error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get user activity information"
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/api/v1/auth/users/{user_id}/tenant")
|
||||
async def update_user_tenant(
|
||||
user_id: str = Path(..., description="User ID"),
|
||||
tenant_data: Dict[str, Any] = Body(..., description="Tenant data containing tenant_id"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Update user's tenant_id after tenant registration
|
||||
|
||||
This endpoint is called by the tenant service after a user creates their tenant.
|
||||
It links the user to their newly created tenant.
|
||||
"""
|
||||
try:
|
||||
# Log the incoming request data for debugging
|
||||
logger.debug("Received tenant update request",
|
||||
user_id=user_id,
|
||||
tenant_data=tenant_data)
|
||||
|
||||
tenant_id = tenant_data.get("tenant_id")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="tenant_id is required"
|
||||
)
|
||||
|
||||
logger.info("Updating user tenant_id",
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
user_service = EnhancedUserService(db)
|
||||
user = await user_service.get_user_by_id(uuid.UUID(user_id), session=db)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
# DEPRECATED: User-tenant relationships are now managed by tenant service
|
||||
# This endpoint is kept for backward compatibility but does nothing
|
||||
# The tenant service should manage user-tenant relationships internally
|
||||
|
||||
logger.warning("DEPRECATED: update_user_tenant endpoint called - user-tenant relationships are now managed by tenant service",
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
# Return success for backward compatibility, but don't actually update anything
|
||||
return {
|
||||
"success": True,
|
||||
"user_id": str(user.id),
|
||||
"tenant_id": tenant_id,
|
||||
"message": "User-tenant relationships are now managed by tenant service. This endpoint is deprecated."
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update user tenant_id",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update user tenant_id"
|
||||
)
|
||||
0
services/auth/app/core/__init__.py
Normal file
0
services/auth/app/core/__init__.py
Normal file
132
services/auth/app/core/auth.py
Normal file
132
services/auth/app/core/auth.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
Authentication dependency for auth service
|
||||
services/auth/app/core/auth.py
|
||||
"""
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from jose import JWTError, jwt
|
||||
import structlog
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.models.users import User
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
Dependency to get the current authenticated user
|
||||
"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
# Decode JWT token
|
||||
payload = jwt.decode(
|
||||
credentials.credentials,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
|
||||
# Get user identifier from token
|
||||
user_id: str = payload.get("sub")
|
||||
if user_id is None:
|
||||
logger.warning("Token payload missing 'sub' field")
|
||||
raise credentials_exception
|
||||
|
||||
logger.info(f"Authenticating user: {user_id}")
|
||||
|
||||
except JWTError as e:
|
||||
logger.warning(f"JWT decode error: {e}")
|
||||
raise credentials_exception
|
||||
|
||||
    try:
        # Get user from database
        result = await db.execute(
            select(User).where(User.id == user_id)
        )
        user = result.scalar_one_or_none()

        if user is None:
            logger.warning(f"User not found for ID: {user_id}")
            raise credentials_exception

        if not user.is_active:
            logger.warning(f"Inactive user attempted access: {user_id}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Inactive user"
            )

        logger.info(f"User authenticated: {user.email}")
        return user

    except HTTPException:
        # Re-raise the explicit 401s above so the "Inactive user" detail is not
        # masked by the generic credentials error below
        raise
    except Exception as e:
        logger.error(f"Error getting user: {e}")
        raise credentials_exception
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""
|
||||
Dependency to get the current active user
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash"""
|
||||
from passlib.context import CryptContext
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate password hash"""
|
||||
from passlib.context import CryptContext
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta=None):
|
||||
"""Create JWT access token"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(data: dict):
|
||||
"""Create JWT refresh token"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
to_encode.update({"exp": expire})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
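# --- Example (illustrative only, not part of the original file) ---
# How a route in this service could protect itself with the dependencies above.
# The router and path are hypothetical; the point is the Depends() wiring.

from fastapi import APIRouter

example_router = APIRouter()

@example_router.get("/api/v1/auth/me-sketch")
async def read_own_profile(current_user: User = Depends(get_current_active_user)):
    # current_user is a fully loaded, active User ORM instance
    return {"id": str(current_user.id), "email": current_user.email}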
70
services/auth/app/core/config.py
Normal file
70
services/auth/app/core/config.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# ================================================================
|
||||
# AUTH SERVICE CONFIGURATION
|
||||
# services/auth/app/core/config.py
|
||||
# ================================================================
|
||||
|
||||
"""
|
||||
Authentication service configuration
|
||||
User management and JWT token handling
|
||||
"""
|
||||
|
||||
from shared.config.base import BaseServiceSettings
|
||||
import os
|
||||
|
||||
class AuthSettings(BaseServiceSettings):
|
||||
"""Auth service specific settings"""
|
||||
|
||||
# Service Identity
|
||||
APP_NAME: str = "Authentication Service"
|
||||
SERVICE_NAME: str = "auth-service"
|
||||
DESCRIPTION: str = "User authentication and authorization service"
|
||||
|
||||
# Database configuration (secure approach - build from components)
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
"""Build database URL from secure components"""
|
||||
# Try complete URL first (for backward compatibility)
|
||||
complete_url = os.getenv("AUTH_DATABASE_URL")
|
||||
if complete_url:
|
||||
return complete_url
|
||||
|
||||
# Build from components (secure approach)
|
||||
user = os.getenv("AUTH_DB_USER", "auth_user")
|
||||
password = os.getenv("AUTH_DB_PASSWORD", "auth_pass123")
|
||||
host = os.getenv("AUTH_DB_HOST", "localhost")
|
||||
port = os.getenv("AUTH_DB_PORT", "5432")
|
||||
name = os.getenv("AUTH_DB_NAME", "auth_db")
|
||||
|
||||
return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"
|
||||
|
||||
# Redis Database (dedicated for auth)
|
||||
REDIS_DB: int = 0
|
||||
|
||||
# Enhanced Password Requirements for Spain
|
||||
PASSWORD_MIN_LENGTH: int = 8
|
||||
PASSWORD_REQUIRE_UPPERCASE: bool = True
|
||||
PASSWORD_REQUIRE_LOWERCASE: bool = True
|
||||
PASSWORD_REQUIRE_NUMBERS: bool = True
|
||||
PASSWORD_REQUIRE_SYMBOLS: bool = False
|
||||
|
||||
# Spanish GDPR Compliance
|
||||
GDPR_COMPLIANCE_ENABLED: bool = True
|
||||
DATA_RETENTION_DAYS: int = int(os.getenv("AUTH_DATA_RETENTION_DAYS", "365"))
|
||||
CONSENT_REQUIRED: bool = True
|
||||
PRIVACY_POLICY_URL: str = os.getenv("PRIVACY_POLICY_URL", "/privacy")
|
||||
|
||||
# Account Security
|
||||
ACCOUNT_LOCKOUT_ENABLED: bool = True
|
||||
MAX_LOGIN_ATTEMPTS: int = 5
|
||||
LOCKOUT_DURATION_MINUTES: int = 30
|
||||
PASSWORD_HISTORY_COUNT: int = 5
|
||||
|
||||
# Session Management
|
||||
SESSION_TIMEOUT_MINUTES: int = int(os.getenv("SESSION_TIMEOUT_MINUTES", "60"))
|
||||
CONCURRENT_SESSIONS_LIMIT: int = int(os.getenv("CONCURRENT_SESSIONS_LIMIT", "3"))
|
||||
|
||||
# Email Verification
|
||||
EMAIL_VERIFICATION_REQUIRED: bool = os.getenv("EMAIL_VERIFICATION_REQUIRED", "true").lower() == "true"
|
||||
EMAIL_VERIFICATION_EXPIRE_HOURS: int = int(os.getenv("EMAIL_VERIFICATION_EXPIRE_HOURS", "24"))
|
||||
|
||||
settings = AuthSettings()
|
||||
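# --- Example (illustrative only, not part of the original file) ---
# The DATABASE_URL property above prefers a complete AUTH_DATABASE_URL and
# otherwise assembles the DSN from individual AUTH_DB_* components. The snippet
# below rebuilds the same shape with placeholder values, just to show which
# environment variables a deployment is expected to provide.

import os

dsn = (
    f"postgresql+asyncpg://{os.getenv('AUTH_DB_USER', 'auth_user')}:"
    f"{os.getenv('AUTH_DB_PASSWORD', 'auth_pass123')}@"
    f"{os.getenv('AUTH_DB_HOST', 'localhost')}:{os.getenv('AUTH_DB_PORT', '5432')}/"
    f"{os.getenv('AUTH_DB_NAME', 'auth_db')}"
)
print(dsn)  # e.g. postgresql+asyncpg://auth_user:auth_pass123@localhost:5432/auth_db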
290
services/auth/app/core/database.py
Normal file
290
services/auth/app/core/database.py
Normal file
@@ -0,0 +1,290 @@
|
||||
# ================================================================
|
||||
# services/auth/app/core/database.py (ENHANCED VERSION)
|
||||
# ================================================================
|
||||
"""
|
||||
Database configuration for authentication service
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.core.config import settings
|
||||
from shared.database.base import Base
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Create async engine
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
poolclass=NullPool,
|
||||
echo=settings.DEBUG,
|
||||
future=True
|
||||
)
|
||||
|
||||
# Create session factory
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False
|
||||
)
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
"""Database dependency"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Database session error: {e}")
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def create_tables():
|
||||
"""Create database tables"""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info("Database tables created successfully")
|
||||
# ================================================================
|
||||
# services/auth/app/core/database.py - UPDATED TO USE SHARED INFRASTRUCTURE
|
||||
# ================================================================
|
||||
"""
|
||||
Database configuration for authentication service
|
||||
Uses shared database infrastructure for consistency
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import AsyncGenerator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
|
||||
from shared.database.base import DatabaseManager, Base
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# ✅ Initialize database manager using shared infrastructure
|
||||
database_manager = DatabaseManager(settings.DATABASE_URL)
|
||||
|
||||
# ✅ Alias for convenience - matches the existing interface
|
||||
get_db = database_manager.get_db
|
||||
|
||||
# ✅ Use the shared background session method
|
||||
get_background_db_session = database_manager.get_background_session
|
||||
|
||||
async def get_db_health() -> bool:
|
||||
"""
|
||||
Health check function for database connectivity
|
||||
"""
|
||||
try:
|
||||
async with database_manager.async_engine.begin() as conn:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
logger.debug("Database health check passed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {str(e)}")
|
||||
return False
|
||||
|
||||
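# --- Example (illustrative only, not part of the original file) ---
# One way get_db_health() could back a readiness probe. The route path and
# router are hypothetical; the shared service base may already expose a
# standard health endpoint that covers this.

from fastapi import APIRouter

health_router = APIRouter()

@health_router.get("/health/db-sketch")
async def database_health():
    healthy = await get_db_health()
    return {"database": "ok" if healthy else "unavailable"}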
async def create_tables():
|
||||
"""Create database tables using shared infrastructure"""
|
||||
await database_manager.create_tables()
|
||||
logger.info("Auth database tables created successfully")
|
||||
|
||||
# ✅ Auth service specific database utilities
|
||||
class AuthDatabaseUtils:
|
||||
"""Auth service specific database utilities"""
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_refresh_tokens(days_old: int = 30):
|
||||
"""Clean up old refresh tokens"""
|
||||
try:
|
||||
async with database_manager.get_background_session() as session:
|
||||
if settings.DATABASE_URL.startswith("sqlite"):
|
||||
query = text(
|
||||
"DELETE FROM refresh_tokens "
|
||||
"WHERE created_at < datetime('now', :days_param)"
|
||||
)
|
||||
params = {"days_param": f"-{days_old} days"}
|
||||
                else:
                    # PostgreSQL: INTERVAL cannot take a bind parameter directly,
                    # so cast the bound text to an interval instead
                    query = text(
                        "DELETE FROM refresh_tokens "
                        "WHERE created_at < NOW() - CAST(:days_param AS INTERVAL)"
                    )
                    params = {"days_param": f"{days_old} days"}
|
||||
|
||||
result = await session.execute(query, params)
|
||||
# No need to commit - get_background_session() handles it
|
||||
|
||||
logger.info("Cleaned up old refresh tokens",
|
||||
deleted_count=result.rowcount,
|
||||
days_old=days_old)
|
||||
|
||||
return result.rowcount
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old refresh tokens",
|
||||
error=str(e))
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def get_auth_statistics(tenant_id: str = None) -> dict:
|
||||
"""Get authentication statistics"""
|
||||
try:
|
||||
async with database_manager.get_background_session() as session:
|
||||
# Base query for users
|
||||
users_query = text("SELECT COUNT(*) as count FROM users WHERE is_active = :is_active")
|
||||
params = {}
|
||||
|
||||
if tenant_id:
|
||||
# If tenant filtering is needed (though auth service might not have tenant_id in users table)
|
||||
# This is just an example - adjust based on your actual schema
|
||||
pass
|
||||
|
||||
# Get active users count
|
||||
active_users_result = await session.execute(
|
||||
users_query,
|
||||
{**params, "is_active": True}
|
||||
)
|
||||
active_users = active_users_result.scalar() or 0
|
||||
|
||||
# Get inactive users count
|
||||
inactive_users_result = await session.execute(
|
||||
users_query,
|
||||
{**params, "is_active": False}
|
||||
)
|
||||
inactive_users = inactive_users_result.scalar() or 0
|
||||
|
||||
# Get refresh tokens count
|
||||
tokens_query = text("SELECT COUNT(*) as count FROM refresh_tokens")
|
||||
tokens_result = await session.execute(tokens_query)
|
||||
active_tokens = tokens_result.scalar() or 0
|
||||
|
||||
return {
|
||||
"active_users": active_users,
|
||||
"inactive_users": inactive_users,
|
||||
"total_users": active_users + inactive_users,
|
||||
"active_tokens": active_tokens
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get auth statistics: {str(e)}")
|
||||
return {
|
||||
"active_users": 0,
|
||||
"inactive_users": 0,
|
||||
"total_users": 0,
|
||||
"active_tokens": 0
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def check_user_exists(user_id: str) -> bool:
|
||||
"""Check if user exists"""
|
||||
try:
|
||||
async with database_manager.get_background_session() as session:
|
||||
query = text(
|
||||
"SELECT COUNT(*) as count "
|
||||
"FROM users "
|
||||
"WHERE id = :user_id "
|
||||
"LIMIT 1"
|
||||
)
|
||||
|
||||
result = await session.execute(query, {"user_id": user_id})
|
||||
count = result.scalar() or 0
|
||||
|
||||
return count > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to check user existence",
|
||||
user_id=user_id, error=str(e))
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def get_user_token_count(user_id: str) -> int:
|
||||
"""Get count of active refresh tokens for a user"""
|
||||
try:
|
||||
async with database_manager.get_background_session() as session:
|
||||
query = text(
|
||||
"SELECT COUNT(*) as count "
|
||||
"FROM refresh_tokens "
|
||||
"WHERE user_id = :user_id"
|
||||
)
|
||||
|
||||
result = await session.execute(query, {"user_id": user_id})
|
||||
count = result.scalar() or 0
|
||||
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user token count",
|
||||
user_id=user_id, error=str(e))
|
||||
return 0
|
||||
|
||||
# Enhanced database session dependency with better error handling
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Enhanced database session dependency with better logging and error handling
|
||||
"""
|
||||
async with database_manager.async_session_local() as session:
|
||||
try:
|
||||
logger.debug("Database session created")
|
||||
yield session
|
||||
except Exception as e:
|
||||
logger.error(f"Database session error: {str(e)}", exc_info=True)
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
logger.debug("Database session closed")
|
||||
|
||||
# Database initialization for auth service
|
||||
async def initialize_auth_database():
|
||||
"""Initialize database tables for auth service"""
|
||||
try:
|
||||
logger.info("Initializing auth service database")
|
||||
|
||||
# Import models to ensure they're registered
|
||||
from app.models.users import User
|
||||
from app.models.refresh_tokens import RefreshToken
|
||||
|
||||
# Create tables using shared infrastructure
|
||||
await database_manager.create_tables()
|
||||
|
||||
logger.info("Auth service database initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize auth service database: {str(e)}")
|
||||
raise
|
||||
|
||||
# Database cleanup for auth service
|
||||
async def cleanup_auth_database():
|
||||
"""Cleanup database connections for auth service"""
|
||||
try:
|
||||
logger.info("Cleaning up auth service database connections")
|
||||
|
||||
# Close engine connections
|
||||
if hasattr(database_manager, 'async_engine') and database_manager.async_engine:
|
||||
await database_manager.async_engine.dispose()
|
||||
|
||||
logger.info("Auth service database cleanup completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup auth service database: {str(e)}")
|
||||
|
||||
# Export the commonly used items to maintain compatibility
|
||||
__all__ = [
|
||||
'Base',
|
||||
'database_manager',
|
||||
'get_db',
|
||||
'get_background_db_session',
|
||||
'get_db_session',
|
||||
'get_db_health',
|
||||
'AuthDatabaseUtils',
|
||||
'initialize_auth_database',
|
||||
'cleanup_auth_database',
|
||||
'create_tables'
|
||||
]
|
||||
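# --- Example (illustrative only, not part of the original file) ---
# A sketch of how the cleanup utility above might be run periodically. The
# asyncio loop is an assumption; a real deployment would more likely use a
# cron job, Celery beat, or the platform's scheduler.

import asyncio

async def refresh_token_cleanup_loop(interval_hours: int = 24) -> None:
    while True:
        deleted = await AuthDatabaseUtils.cleanup_old_refresh_tokens(days_old=30)
        logger.info("Periodic refresh token cleanup", deleted_count=deleted)
        await asyncio.sleep(interval_hours * 3600)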
453
services/auth/app/core/security.py
Normal file
453
services/auth/app/core/security.py
Normal file
@@ -0,0 +1,453 @@
|
||||
# services/auth/app/core/security.py - FIXED VERSION
|
||||
"""
|
||||
Security utilities for authentication service
|
||||
FIXED VERSION - Consistent password hashing using passlib
|
||||
"""
|
||||
|
||||
import re
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, Any, List
|
||||
from shared.redis_utils import get_redis_client
|
||||
from fastapi import HTTPException, status
|
||||
import structlog
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.core.config import settings
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# ✅ FIX: Use passlib for consistent password hashing
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# Initialize JWT handler with SAME configuration as gateway
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
|
||||
# Note: Redis client is now accessed via get_redis_client() from shared.redis_utils
|
||||
|
||||
class SecurityManager:
|
||||
"""Security utilities for authentication - FIXED VERSION"""
|
||||
|
||||
@staticmethod
|
||||
def validate_password(password: str) -> bool:
|
||||
"""Validate password strength"""
|
||||
if len(password) < settings.PASSWORD_MIN_LENGTH:
|
||||
return False
|
||||
|
||||
if len(password) > 128: # Max length from Pydantic schema
|
||||
return False
|
||||
|
||||
if settings.PASSWORD_REQUIRE_UPPERCASE and not re.search(r'[A-Z]', password):
|
||||
return False
|
||||
|
||||
if settings.PASSWORD_REQUIRE_LOWERCASE and not re.search(r'[a-z]', password):
|
||||
return False
|
||||
|
||||
if settings.PASSWORD_REQUIRE_NUMBERS and not re.search(r'\d', password):
|
||||
return False
|
||||
|
||||
if settings.PASSWORD_REQUIRE_SYMBOLS and not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_password_validation_errors(password: str) -> List[str]:
|
||||
"""Get detailed password validation errors for better UX"""
|
||||
errors = []
|
||||
|
||||
if len(password) < settings.PASSWORD_MIN_LENGTH:
|
||||
errors.append(f"Password must be at least {settings.PASSWORD_MIN_LENGTH} characters long")
|
||||
|
||||
if len(password) > 128:
|
||||
errors.append("Password cannot exceed 128 characters")
|
||||
|
||||
if settings.PASSWORD_REQUIRE_UPPERCASE and not re.search(r'[A-Z]', password):
|
||||
errors.append("Password must contain at least one uppercase letter")
|
||||
|
||||
if settings.PASSWORD_REQUIRE_LOWERCASE and not re.search(r'[a-z]', password):
|
||||
errors.append("Password must contain at least one lowercase letter")
|
||||
|
||||
if settings.PASSWORD_REQUIRE_NUMBERS and not re.search(r'\d', password):
|
||||
errors.append("Password must contain at least one number")
|
||||
|
||||
if settings.PASSWORD_REQUIRE_SYMBOLS and not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
|
||||
errors.append("Password must contain at least one symbol (!@#$%^&*(),.?\":{}|<>)")
|
||||
|
||||
return errors
|
||||
|
||||
@staticmethod
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash password using passlib bcrypt - FIXED"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
@staticmethod
|
||||
def verify_password(password: str, hashed_password: str) -> bool:
|
||||
"""Verify password against hash using passlib - FIXED"""
|
||||
try:
|
||||
return pwd_context.verify(password, hashed_password)
|
||||
except Exception as e:
|
||||
logger.error(f"Password verification error: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def create_access_token(user_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Create JWT ACCESS token with proper payload structure
|
||||
✅ FIXED: Only creates access tokens
|
||||
"""
|
||||
|
||||
# Validate required fields for access token
|
||||
if "user_id" not in user_data:
|
||||
raise ValueError("user_id required for access token creation")
|
||||
|
||||
if "email" not in user_data:
|
||||
raise ValueError("email required for access token creation")
|
||||
|
||||
try:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
# ✅ FIX 1: ACCESS TOKEN payload structure
|
||||
payload = {
|
||||
"sub": user_data["user_id"],
|
||||
"user_id": user_data["user_id"],
|
||||
"email": user_data["email"],
|
||||
"type": "access", # ✅ EXPLICITLY set as access token
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"iss": "bakery-auth"
|
||||
}
|
||||
|
||||
# Add optional fields for access tokens
|
||||
if "full_name" in user_data:
|
||||
payload["full_name"] = user_data["full_name"]
|
||||
if "is_verified" in user_data:
|
||||
payload["is_verified"] = user_data["is_verified"]
|
||||
if "is_active" in user_data:
|
||||
payload["is_active"] = user_data["is_active"]
|
||||
|
||||
# ✅ CRITICAL FIX: Include role in access token!
|
||||
if "role" in user_data:
|
||||
payload["role"] = user_data["role"]
|
||||
else:
|
||||
payload["role"] = "admin" # Default role if not specified
|
||||
|
||||
# NEW: Add subscription data to JWT payload
|
||||
if "tenant_id" in user_data:
|
||||
payload["tenant_id"] = user_data["tenant_id"]
|
||||
|
||||
if "tenant_role" in user_data:
|
||||
payload["tenant_role"] = user_data["tenant_role"]
|
||||
|
||||
if "subscription" in user_data:
|
||||
payload["subscription"] = user_data["subscription"]
|
||||
|
||||
if "tenant_access" in user_data:
|
||||
# Limit tenant_access to 10 entries to prevent JWT size explosion
|
||||
tenant_access = user_data["tenant_access"]
|
||||
if tenant_access and len(tenant_access) > 10:
|
||||
tenant_access = tenant_access[:10]
|
||||
logger.warning(f"Truncated tenant_access to 10 entries for user {user_data['user_id']}")
|
||||
payload["tenant_access"] = tenant_access
|
||||
|
||||
logger.debug(f"Creating access token with payload keys: {list(payload.keys())}")
|
||||
|
||||
# ✅ FIX 2: Use JWT handler to create access token
|
||||
token = jwt_handler.create_access_token_from_payload(payload)
|
||||
logger.debug(f"Access token created successfully for user {user_data['email']}")
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Access token creation failed for {user_data.get('email', 'unknown')}: {e}")
|
||||
raise ValueError(f"Failed to create access token: {str(e)}")
|
||||
|
||||
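    # --- Illustrative note (not part of the original file) ---
    # A decoded access token produced by the method above carries roughly this
    # shape (all values are placeholders):
    #
    #   {
    #     "sub": "3f0c...", "user_id": "3f0c...", "email": "owner@example.com",
    #     "type": "access", "role": "admin",
    #     "tenant_id": "...", "tenant_role": "...", "subscription": {...},
    #     "tenant_access": [ ...capped at 10 entries... ],
    #     "exp": 1735689600, "iat": 1735686000, "iss": "bakery-auth"
    #   }
    #
    # The gateway and downstream services key off "type" == "access" and the
    # role/tenant claims; the refresh token below deliberately omits most of them.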
@staticmethod
|
||||
def create_refresh_token(user_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Create JWT REFRESH token with minimal payload structure
|
||||
✅ FIXED: Only creates refresh tokens, different from access tokens
|
||||
"""
|
||||
|
||||
# Validate required fields for refresh token
|
||||
if "user_id" not in user_data:
|
||||
raise ValueError("user_id required for refresh token creation")
|
||||
|
||||
if not user_data.get("user_id"):
|
||||
raise ValueError("user_id cannot be empty")
|
||||
|
||||
try:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
# ✅ FIX 3: REFRESH TOKEN payload structure (minimal, different from access)
|
||||
payload = {
|
||||
"sub": user_data["user_id"],
|
||||
"user_id": user_data["user_id"],
|
||||
"type": "refresh", # ✅ EXPLICITLY set as refresh token
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"iss": "bakery-auth"
|
||||
}
|
||||
|
||||
# Add unique JTI for refresh tokens to prevent duplicates
|
||||
if "jti" in user_data:
|
||||
payload["jti"] = user_data["jti"]
|
||||
else:
|
||||
import uuid
|
||||
payload["jti"] = str(uuid.uuid4())
|
||||
|
||||
# Include email only if available (optional for refresh tokens)
|
||||
if "email" in user_data and user_data["email"]:
|
||||
payload["email"] = user_data["email"]
|
||||
|
||||
logger.debug(f"Creating refresh token with payload keys: {list(payload.keys())}")
|
||||
|
||||
# ✅ FIX 4: Use JWT handler to create REFRESH token (not access token!)
|
||||
token = jwt_handler.create_refresh_token_from_payload(payload)
|
||||
logger.debug(f"Refresh token created successfully for user {user_data['user_id']}")
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Refresh token creation failed for {user_data.get('user_id', 'unknown')}: {e}")
|
||||
raise ValueError(f"Failed to create refresh token: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def verify_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Verify JWT token with enhanced error handling"""
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if payload:
|
||||
logger.debug(f"Token verified successfully for user: {payload.get('email', 'unknown')}")
|
||||
return payload
|
||||
except Exception as e:
|
||||
logger.warning(f"Token verification failed: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def decode_token(token: str) -> Dict[str, Any]:
|
||||
"""Decode JWT token without verification (for refresh token handling)"""
|
||||
try:
|
||||
payload = jwt_handler.decode_token_no_verify(token)
|
||||
return payload
|
||||
except Exception as e:
|
||||
logger.error(f"Token decoding failed: {e}")
|
||||
raise ValueError("Invalid token format")
|
||||
|
||||
@staticmethod
|
||||
def generate_secure_hash(data: str) -> str:
|
||||
"""Generate secure hash for token storage"""
|
||||
return hashlib.sha256(data.encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def create_service_token(service_name: str, tenant_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Create JWT service token for inter-service communication
|
||||
✅ UNIFIED: Uses shared JWT handler for consistent token creation
|
||||
✅ ENHANCED: Supports tenant context for tenant-scoped operations
|
||||
|
||||
Args:
|
||||
service_name: Name of the service (e.g., 'auth-service', 'tenant-service')
|
||||
tenant_id: Optional tenant ID for tenant-scoped service operations
|
||||
|
||||
Returns:
|
||||
Encoded JWT service token
|
||||
"""
|
||||
try:
|
||||
# Use unified JWT handler to create service token
|
||||
token = jwt_handler.create_service_token(
|
||||
service_name=service_name,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
logger.debug(f"Created service token for {service_name}", tenant_id=tenant_id)
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create service token for {service_name}: {e}")
|
||||
raise ValueError(f"Failed to create service token: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
async def track_login_attempt(email: str, ip_address: str, success: bool) -> None:
|
||||
"""Track login attempts for security monitoring"""
|
||||
try:
|
||||
# This would use Redis for production
|
||||
# For now, just log the attempt
|
||||
logger.info(f"Login attempt tracked: email={email}, ip={ip_address}, success={success}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to track login attempt: {e}")
|
||||
|
||||
@staticmethod
|
||||
def is_token_expired(token: str) -> bool:
|
||||
"""Check if token is expired"""
|
||||
try:
|
||||
payload = SecurityManager.decode_token(token)
|
||||
exp_timestamp = payload.get("exp")
|
||||
if exp_timestamp:
|
||||
exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
|
||||
return datetime.now(timezone.utc) > exp_datetime
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
@staticmethod
|
||||
def verify_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Verify JWT token with enhanced error handling"""
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if payload:
|
||||
logger.debug(f"Token verified successfully for user: {payload.get('email', 'unknown')}")
|
||||
return payload
|
||||
except Exception as e:
|
||||
logger.warning(f"Token verification failed: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def track_login_attempt(email: str, ip_address: str, success: bool) -> None:
|
||||
"""Track login attempts for security monitoring"""
|
||||
try:
|
||||
redis_client = await get_redis_client()
|
||||
key = f"login_attempts:{email}:{ip_address}"
|
||||
|
||||
if success:
|
||||
# Clear failed attempts on successful login
|
||||
await redis_client.delete(key)
|
||||
else:
|
||||
# Increment failed attempts
|
||||
attempts = await redis_client.incr(key)
|
||||
if attempts == 1:
|
||||
# Set expiration on first failed attempt
|
||||
await redis_client.expire(key, settings.LOCKOUT_DURATION_MINUTES * 60)
|
||||
|
||||
if attempts >= settings.MAX_LOGIN_ATTEMPTS:
|
||||
logger.warning(f"Account locked for {email} from {ip_address}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"Too many failed login attempts. Try again in {settings.LOCKOUT_DURATION_MINUTES} minutes."
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise # Re-raise HTTPException
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to track login attempt: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def is_account_locked(email: str, ip_address: str) -> bool:
|
||||
"""Check if account is locked due to failed login attempts"""
|
||||
try:
|
||||
redis_client = await get_redis_client()
|
||||
key = f"login_attempts:{email}:{ip_address}"
|
||||
attempts = await redis_client.get(key)
|
||||
|
||||
if attempts:
|
||||
attempts = int(attempts)
|
||||
return attempts >= settings.MAX_LOGIN_ATTEMPTS
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check account lock status: {e}")
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def hash_api_key(api_key: str) -> str:
|
||||
"""Hash API key for storage"""
|
||||
return hashlib.sha256(api_key.encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def generate_secure_token(length: int = 32) -> str:
|
||||
"""Generate secure random token"""
|
||||
import secrets
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
@staticmethod
|
||||
def generate_reset_token() -> str:
|
||||
"""Generate a secure password reset token"""
|
||||
import secrets
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def mask_sensitive_data(data: str, visible_chars: int = 4) -> str:
|
||||
"""Mask sensitive data for logging"""
|
||||
if not data or len(data) <= visible_chars:
|
||||
return "*" * len(data) if data else ""
|
||||
|
||||
return data[:visible_chars] + "*" * (len(data) - visible_chars)
|
||||
|
||||
@staticmethod
|
||||
async def check_login_attempts(email: str) -> bool:
|
||||
"""Check if user has exceeded login attempts"""
|
||||
try:
|
||||
redis_client = await get_redis_client()
|
||||
key = f"login_attempts:{email}"
|
||||
attempts = await redis_client.get(key)
|
||||
|
||||
if attempts is None:
|
||||
return True
|
||||
|
||||
return int(attempts) < settings.MAX_LOGIN_ATTEMPTS
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking login attempts: {e}")
|
||||
return True # Allow on error
|
||||
|
||||
@staticmethod
|
||||
async def increment_login_attempts(email: str) -> None:
|
||||
"""Increment login attempts for email"""
|
||||
try:
|
||||
redis_client = await get_redis_client()
|
||||
key = f"login_attempts:{email}"
|
||||
await redis_client.incr(key)
|
||||
await redis_client.expire(key, settings.LOCKOUT_DURATION_MINUTES * 60)
|
||||
except Exception as e:
|
||||
logger.error(f"Error incrementing login attempts: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def clear_login_attempts(email: str) -> None:
|
||||
"""Clear login attempts for email after successful login"""
|
||||
try:
|
||||
redis_client = await get_redis_client()
|
||||
key = f"login_attempts:{email}"
|
||||
await redis_client.delete(key)
|
||||
logger.debug(f"Cleared login attempts for {email}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing login attempts: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def store_refresh_token(user_id: str, token: str) -> None:
|
||||
"""Store refresh token in Redis"""
|
||||
try:
|
||||
redis_client = await get_redis_client()
|
||||
token_hash = SecurityManager.hash_api_key(token) # Reuse hash method
|
||||
key = f"refresh_token:{user_id}:{token_hash}"
|
||||
|
||||
# Store with expiration matching JWT refresh token expiry
|
||||
expire_seconds = settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
|
||||
await redis_client.setex(key, expire_seconds, "valid")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing refresh token: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def is_refresh_token_valid(user_id: str, token: str) -> bool:
|
||||
"""Check if refresh token is still valid in Redis"""
|
||||
try:
|
||||
redis_client = await get_redis_client()
|
||||
token_hash = SecurityManager.hash_api_key(token)
|
||||
key = f"refresh_token:{user_id}:{token_hash}"
|
||||
|
||||
exists = await redis_client.exists(key)
|
||||
return bool(exists)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking refresh token validity: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def revoke_refresh_token(user_id: str, token: str) -> None:
|
||||
"""Revoke refresh token by removing from Redis"""
|
||||
try:
|
||||
redis_client = await get_redis_client()
|
||||
token_hash = SecurityManager.hash_api_key(token)
|
||||
key = f"refresh_token:{user_id}:{token_hash}"
|
||||
|
||||
await redis_client.delete(key)
|
||||
logger.debug(f"Revoked refresh token for user {user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking refresh token: {e}")
|
||||
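# --- Example (illustrative only, not part of the original file) ---
# End-to-end sketch of how a login path might combine the helpers above:
# lockout check, password verification, then token issuance. The user lookup
# and the caller's error handling are assumptions; `user` is assumed to be an
# ORM object exposing id, email, role and hashed_password.

async def issue_tokens_sketch(email: str, password: str, ip: str, user) -> dict:
    if await SecurityManager.is_account_locked(email, ip):
        raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                            detail="Account temporarily locked")

    if not SecurityManager.verify_password(password, user.hashed_password):
        await SecurityManager.track_login_attempt(email, ip, success=False)
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Invalid credentials")

    await SecurityManager.track_login_attempt(email, ip, success=True)
    user_data = {"user_id": str(user.id), "email": user.email, "role": user.role}
    access_token = SecurityManager.create_access_token(user_data)
    refresh_token = SecurityManager.create_refresh_token(user_data)
    await SecurityManager.store_refresh_token(str(user.id), refresh_token)
    return {"access_token": access_token, "refresh_token": refresh_token}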
225
services/auth/app/main.py
Normal file
225
services/auth/app/main.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
Authentication Service Main Application
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy import text
|
||||
from app.core.config import settings
|
||||
from app.core.database import database_manager
|
||||
from app.api import auth_operations, users, onboarding_progress, consent, data_export, account_deletion, internal_demo, password_reset
|
||||
from shared.service_base import StandardFastAPIService
|
||||
from shared.messaging import UnifiedEventPublisher
|
||||
|
||||
|
||||
class AuthService(StandardFastAPIService):
|
||||
"""Authentication Service with standardized setup"""
|
||||
|
||||
async def on_startup(self, app):
|
||||
"""Custom startup logic including migration verification and Redis initialization"""
|
||||
self.logger.info("Starting auth service on_startup")
|
||||
await self.verify_migrations()
|
||||
|
||||
# Initialize Redis if not already done during service creation
|
||||
if not self.redis_initialized:
|
||||
try:
|
||||
from shared.redis_utils import initialize_redis, get_redis_client
|
||||
await initialize_redis(settings.REDIS_URL_WITH_DB, db=settings.REDIS_DB, max_connections=getattr(settings, 'REDIS_MAX_CONNECTIONS', 50))
|
||||
self.redis_client = await get_redis_client()
|
||||
self.redis_initialized = True
|
||||
self.logger.info("Connected to Redis for token management")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to connect to Redis during startup: {e}")
|
||||
raise
|
||||
|
||||
await super().on_startup(app)
|
||||
|
||||
async def on_shutdown(self, app):
|
||||
"""Custom shutdown logic for Auth Service"""
|
||||
await super().on_shutdown(app)
|
||||
|
||||
# Close Redis
|
||||
from shared.redis_utils import close_redis
|
||||
await close_redis()
|
||||
self.logger.info("Redis connection closed")
|
||||
|
||||
async def verify_migrations(self):
|
||||
"""Verify database schema matches the latest migrations."""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
# Check if alembic_version table exists
|
||||
result = await session.execute(text("""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'alembic_version'
|
||||
)
|
||||
"""))
|
||||
table_exists = result.scalar()
|
||||
|
||||
if table_exists:
|
||||
# If table exists, check the version
|
||||
result = await session.execute(text("SELECT version_num FROM alembic_version"))
|
||||
version = result.scalar()
|
||||
self.logger.info(f"Migration verification successful: {version}")
|
||||
else:
|
||||
# If table doesn't exist, migrations might not have run yet
|
||||
# This is OK - the migration job should create it
|
||||
self.logger.warning("alembic_version table does not exist yet - migrations may not have run")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Migration verification failed (this may be expected during initial setup): {e}")
|
||||
|
||||
def __init__(self):
|
||||
# Initialize Redis during service creation so it's available when needed
|
||||
try:
|
||||
import asyncio
|
||||
# We need to run the async initialization in a sync context
|
||||
try:
|
||||
# Check if there's already a running event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If there is, we'll initialize Redis later in on_startup
|
||||
self.redis_initialized = False
|
||||
self.redis_client = None
|
||||
except RuntimeError:
|
||||
# No event loop running, safe to run the async function
|
||||
import asyncio
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply() # Allow nested event loops
|
||||
|
||||
async def init_redis():
|
||||
from shared.redis_utils import initialize_redis, get_redis_client
|
||||
await initialize_redis(settings.REDIS_URL_WITH_DB, db=settings.REDIS_DB, max_connections=getattr(settings, 'REDIS_MAX_CONNECTIONS', 50))
|
||||
return await get_redis_client()
|
||||
|
||||
self.redis_client = asyncio.run(init_redis())
|
||||
self.redis_initialized = True
|
||||
self.logger.info("Connected to Redis for token management")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize Redis during service creation: {e}")
|
||||
self.redis_initialized = False
|
||||
self.redis_client = None
|
||||
|
||||
# Define expected database tables for health checks
|
||||
auth_expected_tables = [
|
||||
'users', 'refresh_tokens', 'user_onboarding_progress',
|
||||
'user_onboarding_summary', 'login_attempts', 'user_consents',
|
||||
'consent_history', 'audit_logs'
|
||||
]
|
||||
|
||||
# Define custom metrics for auth service
|
||||
auth_custom_metrics = {
|
||||
"registration_total": {
|
||||
"type": "counter",
|
||||
"description": "Total user registrations by status",
|
||||
"labels": ["status"]
|
||||
},
|
||||
"login_success_total": {
|
||||
"type": "counter",
|
||||
"description": "Total successful user logins"
|
||||
},
|
||||
"login_failure_total": {
|
||||
"type": "counter",
|
||||
"description": "Total failed user logins by reason",
|
||||
"labels": ["reason"]
|
||||
},
|
||||
"token_refresh_total": {
|
||||
"type": "counter",
|
||||
"description": "Total token refreshes by status",
|
||||
"labels": ["status"]
|
||||
},
|
||||
"token_verify_total": {
|
||||
"type": "counter",
|
||||
"description": "Total token verifications by status",
|
||||
"labels": ["status"]
|
||||
},
|
||||
"logout_total": {
|
||||
"type": "counter",
|
||||
"description": "Total user logouts by status",
|
||||
"labels": ["status"]
|
||||
},
|
||||
"registration_duration_seconds": {
|
||||
"type": "histogram",
|
||||
"description": "Registration request duration"
|
||||
},
|
||||
"login_duration_seconds": {
|
||||
"type": "histogram",
|
||||
"description": "Login request duration"
|
||||
},
|
||||
"token_refresh_duration_seconds": {
|
||||
"type": "histogram",
|
||||
"description": "Token refresh duration"
|
||||
}
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
service_name="auth-service",
|
||||
app_name="Authentication Service",
|
||||
description="Handles user authentication and authorization for bakery forecasting platform",
|
||||
version="1.0.0",
|
||||
log_level=settings.LOG_LEVEL,
|
||||
api_prefix="", # Empty because RouteBuilder already includes /api/v1
|
||||
database_manager=database_manager,
|
||||
expected_tables=auth_expected_tables,
|
||||
enable_messaging=True,
|
||||
custom_metrics=auth_custom_metrics
|
||||
)
|
||||
|
||||
async def _setup_messaging(self):
|
||||
"""Setup messaging for auth service"""
|
||||
from shared.messaging import RabbitMQClient
|
||||
try:
|
||||
self.rabbitmq_client = RabbitMQClient(settings.RABBITMQ_URL, service_name="auth-service")
|
||||
await self.rabbitmq_client.connect()
|
||||
# Create event publisher
|
||||
self.event_publisher = UnifiedEventPublisher(self.rabbitmq_client, "auth-service")
|
||||
self.logger.info("Auth service messaging setup completed")
|
||||
except Exception as e:
|
||||
self.logger.error("Failed to setup auth messaging", error=str(e))
|
||||
raise
|
||||
|
||||
async def _cleanup_messaging(self):
|
||||
"""Cleanup messaging for auth service"""
|
||||
try:
|
||||
if self.rabbitmq_client:
|
||||
await self.rabbitmq_client.disconnect()
|
||||
self.logger.info("Auth service messaging cleanup completed")
|
||||
except Exception as e:
|
||||
self.logger.error("Error during auth messaging cleanup", error=str(e))
|
||||
|
||||
async def on_shutdown(self, app: FastAPI):
|
||||
"""Custom shutdown logic for auth service"""
|
||||
self.logger.info("Authentication Service shutdown complete")
|
||||
|
||||
def get_service_features(self):
|
||||
"""Return auth-specific features"""
|
||||
return [
|
||||
"user_authentication",
|
||||
"token_management",
|
||||
"user_onboarding",
|
||||
"role_based_access",
|
||||
"messaging_integration"
|
||||
]
|
||||
|
||||
|
||||
# Create service instance
|
||||
service = AuthService()
|
||||
|
||||
# Create FastAPI app with standardized setup
|
||||
app = service.create_app(
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc"
|
||||
)
|
||||
|
||||
# Setup standard endpoints
|
||||
service.setup_standard_endpoints()
|
||||
|
||||
# Include routers with specific configurations
|
||||
# Note: Routes now use RouteBuilder which includes full paths, so no prefix needed
|
||||
service.add_router(auth_operations.router, tags=["authentication"])
|
||||
service.add_router(users.router, tags=["users"])
|
||||
service.add_router(onboarding_progress.router, tags=["onboarding"])
|
||||
service.add_router(consent.router, tags=["gdpr", "consent"])
|
||||
service.add_router(data_export.router, tags=["gdpr", "data-export"])
|
||||
service.add_router(account_deletion.router, tags=["gdpr", "account-deletion"])
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"])
|
||||
service.add_router(password_reset.router, tags=["password-reset"])
|
||||
31
services/auth/app/models/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# services/auth/app/models/__init__.py
|
||||
"""
|
||||
Models export for auth service
|
||||
"""
|
||||
|
||||
# Import AuditLog model for this service
|
||||
from shared.security import create_audit_log_model
|
||||
from shared.database.base import Base
|
||||
|
||||
# Create audit log model for this service
|
||||
AuditLog = create_audit_log_model(Base)
|
||||
|
||||
from .users import User
|
||||
from .tokens import RefreshToken, LoginAttempt
|
||||
from .onboarding import UserOnboardingProgress, UserOnboardingSummary
|
||||
from .consent import UserConsent, ConsentHistory
|
||||
from .deletion_job import DeletionJob
|
||||
from .password_reset_tokens import PasswordResetToken
|
||||
|
||||
__all__ = [
|
||||
'User',
|
||||
'RefreshToken',
|
||||
'LoginAttempt',
|
||||
'UserOnboardingProgress',
|
||||
'UserOnboardingSummary',
|
||||
'UserConsent',
|
||||
'ConsentHistory',
|
||||
'DeletionJob',
|
||||
'PasswordResetToken',
|
||||
"AuditLog",
|
||||
]
|
||||
110
services/auth/app/models/consent.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
User consent tracking models for GDPR compliance
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
|
||||
class UserConsent(Base):
|
||||
"""
|
||||
Tracks user consent for various data processing activities
|
||||
GDPR Article 7 - Conditions for consent
|
||||
"""
|
||||
__tablename__ = "user_consents"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
|
||||
# Consent types
|
||||
terms_accepted = Column(Boolean, nullable=False, default=False)
|
||||
privacy_accepted = Column(Boolean, nullable=False, default=False)
|
||||
marketing_consent = Column(Boolean, nullable=False, default=False)
|
||||
analytics_consent = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Consent metadata
|
||||
consent_version = Column(String(20), nullable=False, default="1.0")
|
||||
consent_method = Column(String(50), nullable=False) # registration, settings_update, cookie_banner
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
user_agent = Column(Text, nullable=True)
|
||||
|
||||
# Consent text at time of acceptance
|
||||
terms_text_hash = Column(String(64), nullable=True)
|
||||
privacy_text_hash = Column(String(64), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
consented_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
|
||||
withdrawn_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Additional metadata (renamed from 'metadata' to avoid SQLAlchemy reserved word)
|
||||
extra_data = Column(JSON, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_user_consent_user_id', 'user_id'),
|
||||
Index('idx_user_consent_consented_at', 'consented_at'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<UserConsent(user_id={self.user_id}, version={self.consent_version})>"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"user_id": str(self.user_id),
|
||||
"terms_accepted": self.terms_accepted,
|
||||
"privacy_accepted": self.privacy_accepted,
|
||||
"marketing_consent": self.marketing_consent,
|
||||
"analytics_consent": self.analytics_consent,
|
||||
"consent_version": self.consent_version,
|
||||
"consent_method": self.consent_method,
|
||||
"consented_at": self.consented_at.isoformat() if self.consented_at else None,
|
||||
"withdrawn_at": self.withdrawn_at.isoformat() if self.withdrawn_at else None,
|
||||
}
|
||||
|
||||
|
||||
class ConsentHistory(Base):
|
||||
"""
|
||||
Historical record of all consent changes
|
||||
Provides audit trail for GDPR compliance
|
||||
"""
|
||||
__tablename__ = "consent_history"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
consent_id = Column(UUID(as_uuid=True), ForeignKey("user_consents.id", ondelete="SET NULL"), nullable=True)
|
||||
|
||||
# Action type
|
||||
action = Column(String(50), nullable=False) # granted, updated, withdrawn, revoked
|
||||
|
||||
# Consent state at time of action
|
||||
consent_snapshot = Column(JSON, nullable=False)
|
||||
|
||||
# Context
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
user_agent = Column(Text, nullable=True)
|
||||
consent_method = Column(String(50), nullable=True)
|
||||
|
||||
# Timestamp
|
||||
created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc), index=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_consent_history_user_id', 'user_id'),
|
||||
Index('idx_consent_history_created_at', 'created_at'),
|
||||
Index('idx_consent_history_action', 'action'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ConsentHistory(user_id={self.user_id}, action={self.action})>"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"user_id": str(self.user_id),
|
||||
"action": self.action,
|
||||
"consent_snapshot": self.consent_snapshot,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
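A minimal sketch of how a consent update and its audit-trail entry could be written together; the session handling is an assumption, while the column names come from the models above:

async def record_consent_change(session, consent: UserConsent, action: str = "updated"):
    # Snapshot the consent state so the history row stays meaningful even if the
    # user_consents row is later removed (consent_id uses ondelete="SET NULL").
    history = ConsentHistory(
        user_id=consent.user_id,
        consent_id=consent.id,
        action=action,  # granted, updated, withdrawn, revoked
        consent_snapshot=consent.to_dict(),
        consent_method=consent.consent_method,
        ip_address=consent.ip_address,
        user_agent=consent.user_agent,
    )
    session.add(consent)
    session.add(history)
    await session.flush()
    return history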
|
||||
64
services/auth/app/models/deletion_job.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Deletion Job Model
|
||||
Tracks tenant deletion jobs for persistence and recovery
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Text, JSON, Index, Integer
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.sql import func
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
|
||||
class DeletionJob(Base):
|
||||
"""
|
||||
Persistent storage for tenant deletion jobs
|
||||
Enables job recovery and tracking across service restarts
|
||||
"""
|
||||
__tablename__ = "deletion_jobs"
|
||||
|
||||
# Primary identifiers
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
job_id = Column(String(100), nullable=False, unique=True, index=True) # External job ID
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
# Job Metadata
|
||||
tenant_name = Column(String(255), nullable=True)
|
||||
initiated_by = Column(UUID(as_uuid=True), nullable=True) # User ID who started deletion
|
||||
|
||||
# Job Status
|
||||
status = Column(String(50), nullable=False, default="pending", index=True) # pending, in_progress, completed, failed, rolled_back
|
||||
|
||||
# Service Results
|
||||
service_results = Column(JSON, nullable=True) # Dict of service_name -> result details
|
||||
|
||||
# Progress Tracking
|
||||
total_items_deleted = Column(Integer, default=0, nullable=False)
|
||||
services_completed = Column(Integer, default=0, nullable=False)
|
||||
services_failed = Column(Integer, default=0, nullable=False)
|
||||
|
||||
# Error Tracking
|
||||
error_log = Column(JSON, nullable=True) # Array of error messages
|
||||
|
||||
# Timestamps
|
||||
started_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
|
||||
# Additional Context
|
||||
notes = Column(Text, nullable=True)
|
||||
extra_metadata = Column(JSON, nullable=True) # Additional job-specific data
|
||||
|
||||
# Indexes for performance
|
||||
__table_args__ = (
|
||||
Index('idx_deletion_job_id', 'job_id'),
|
||||
Index('idx_deletion_tenant_id', 'tenant_id'),
|
||||
Index('idx_deletion_status', 'status'),
|
||||
Index('idx_deletion_started_at', 'started_at'),
|
||||
Index('idx_deletion_tenant_status', 'tenant_id', 'status'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<DeletionJob(job_id='{self.job_id}', tenant_id={self.tenant_id}, status='{self.status}')>"
|
||||
91
services/auth/app/models/onboarding.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# services/auth/app/models/onboarding.py
|
||||
"""
|
||||
User onboarding progress models
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, JSON, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
class UserOnboardingProgress(Base):
|
||||
"""User onboarding progress tracking model"""
|
||||
__tablename__ = "user_onboarding_progress"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
|
||||
# Step tracking
|
||||
step_name = Column(String(50), nullable=False)
|
||||
completed = Column(Boolean, default=False, nullable=False)
|
||||
completed_at = Column(DateTime(timezone=True))
|
||||
|
||||
# Additional step data (JSON field for flexibility)
|
||||
step_data = Column(JSON, default=dict)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
# Unique constraint to prevent duplicate step entries per user
|
||||
__table_args__ = (
|
||||
UniqueConstraint('user_id', 'step_name', name='uq_user_step'),
|
||||
{'extend_existing': True}
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<UserOnboardingProgress(id={self.id}, user_id={self.user_id}, step={self.step_name}, completed={self.completed})>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary"""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"user_id": str(self.user_id),
|
||||
"step_name": self.step_name,
|
||||
"completed": self.completed,
|
||||
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||
"step_data": self.step_data or {},
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None
|
||||
}
|
||||
|
||||
class UserOnboardingSummary(Base):
|
||||
"""User onboarding summary for quick lookups"""
|
||||
__tablename__ = "user_onboarding_summary"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, unique=True, index=True)
|
||||
|
||||
# Summary fields
|
||||
current_step = Column(String(50), nullable=False, default="user_registered")
|
||||
next_step = Column(String(50))
|
||||
completion_percentage = Column(String(50), default="0.0") # Store as string for precision
|
||||
fully_completed = Column(Boolean, default=False)
|
||||
|
||||
# Progress tracking
|
||||
steps_completed_count = Column(String(50), default="0") # Store as string: "3/5"
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
last_activity_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<UserOnboardingSummary(user_id={self.user_id}, current_step={self.current_step}, completion={self.completion_percentage}%)>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary"""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"user_id": str(self.user_id),
|
||||
"current_step": self.current_step,
|
||||
"next_step": self.next_step,
|
||||
"completion_percentage": float(self.completion_percentage) if self.completion_percentage else 0.0,
|
||||
"fully_completed": self.fully_completed,
|
||||
"steps_completed_count": self.steps_completed_count,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"last_activity_at": self.last_activity_at.isoformat() if self.last_activity_at else None
|
||||
}
|
||||
39
services/auth/app/models/password_reset_tokens.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# services/auth/app/models/password_reset_tokens.py
|
||||
"""
|
||||
Password reset token model for authentication service
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import Column, String, DateTime, Boolean, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
|
||||
class PasswordResetToken(Base):
|
||||
"""
|
||||
Password reset token model
|
||||
Stores temporary tokens for password reset functionality
|
||||
"""
|
||||
__tablename__ = "password_reset_tokens"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
token = Column(String(255), nullable=False, unique=True, index=True)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
is_used = Column(Boolean, default=False, nullable=False)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
used_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Add indexes for better performance
|
||||
__table_args__ = (
|
||||
Index('ix_password_reset_tokens_user_id', 'user_id'),
|
||||
Index('ix_password_reset_tokens_token', 'token'),
|
||||
Index('ix_password_reset_tokens_expires_at', 'expires_at'),
|
||||
Index('ix_password_reset_tokens_is_used', 'is_used'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PasswordResetToken(id={self.id}, user_id={self.user_id}, token={self.token[:10]}..., is_used={self.is_used})>"
|
||||
92
services/auth/app/models/tokens.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# ================================================================
|
||||
# services/auth/app/models/tokens.py
|
||||
# ================================================================
|
||||
"""
|
||||
Token models for authentication service
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
class RefreshToken(Base):
|
||||
"""
|
||||
Refresh token model - FIXED to prevent duplicate constraint violations
|
||||
"""
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
# ✅ FIX 1: Use TEXT instead of VARCHAR to handle longer tokens
|
||||
token = Column(Text, nullable=False)
|
||||
|
||||
# ✅ FIX 2: Add token hash for uniqueness instead of full token
|
||||
token_hash = Column(String(255), nullable=True, unique=True)
|
||||
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
is_revoked = Column(Boolean, default=False, nullable=False)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
revoked_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# ✅ FIX 3: Add indexes for better performance
|
||||
__table_args__ = (
|
||||
Index('ix_refresh_tokens_user_id_active', 'user_id', 'is_revoked'),
|
||||
Index('ix_refresh_tokens_expires_at', 'expires_at'),
|
||||
Index('ix_refresh_tokens_token_hash', 'token_hash'),
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize refresh token with automatic hash generation"""
|
||||
super().__init__(**kwargs)
|
||||
if self.token and not self.token_hash:
|
||||
self.token_hash = self._generate_token_hash(self.token)
|
||||
|
||||
@staticmethod
|
||||
def _generate_token_hash(token: str) -> str:
|
||||
"""Generate a hash of the token for uniqueness checking"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
def update_token(self, new_token: str):
|
||||
"""Update token and regenerate hash"""
|
||||
self.token = new_token
|
||||
self.token_hash = self._generate_token_hash(new_token)
|
||||
|
||||
@classmethod
|
||||
async def create_refresh_token(cls, user_id: uuid.UUID, token: str, expires_at: datetime):
|
||||
"""
|
||||
Create a new refresh token with proper hash generation
|
||||
"""
|
||||
return cls(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
token_hash=cls._generate_token_hash(token),
|
||||
expires_at=expires_at,
|
||||
is_revoked=False,
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RefreshToken(id={self.id}, user_id={self.user_id}, expires_at={self.expires_at})>"
|
||||
|
||||
class LoginAttempt(Base):
|
||||
"""Login attempt tracking model"""
|
||||
__tablename__ = "login_attempts"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
email = Column(String(255), nullable=False, index=True)
|
||||
ip_address = Column(String(45), nullable=False)
|
||||
user_agent = Column(Text)
|
||||
success = Column(Boolean, default=False)
|
||||
failure_reason = Column(String(255))
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<LoginAttempt(id={self.id}, email={self.email}, success={self.success})>"
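A small, self-contained sketch of the token-hash behaviour described in the fixes above (the raw token value is hypothetical):

import hashlib
import uuid
from datetime import datetime, timedelta, timezone

raw_token = "example.refresh.token.value"  # hypothetical
rt = RefreshToken(
    user_id=uuid.uuid4(),
    token=raw_token,
    expires_at=datetime.now(timezone.utc) + timedelta(days=30),
)
# token_hash is filled in automatically by __init__ and backs the unique constraint
assert rt.token_hash == hashlib.sha256(raw_token.encode()).hexdigest()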
|
||||
61
services/auth/app/models/users.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# services/auth/app/models/users.py - FIXED VERSION
|
||||
"""
|
||||
User models for authentication service - FIXED
|
||||
Removed tenant relationships to eliminate cross-service dependencies
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
class User(Base):
|
||||
"""User model - FIXED without cross-service relationships"""
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
email = Column(String(255), unique=True, index=True, nullable=False)
|
||||
hashed_password = Column(String(255), nullable=False)
|
||||
full_name = Column(String(255), nullable=False)
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_verified = Column(Boolean, default=False)
|
||||
|
||||
# Timezone-aware datetime fields
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
last_login = Column(DateTime(timezone=True))
|
||||
|
||||
# Profile fields
|
||||
phone = Column(String(20))
|
||||
language = Column(String(10), default="es")
|
||||
timezone = Column(String(50), default="Europe/Madrid")
|
||||
role = Column(String(20), nullable=False)
|
||||
|
||||
# Payment integration fields
|
||||
payment_customer_id = Column(String(255), nullable=True, index=True)
|
||||
default_payment_method_id = Column(String(255), nullable=True)
|
||||
|
||||
# REMOVED: All tenant relationships - these are handled by tenant service
|
||||
# No tenant_memberships, tenants relationships
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User(id={self.id}, email={self.email})>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert user to dictionary"""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"email": self.email,
|
||||
"full_name": self.full_name,
|
||||
"is_active": self.is_active,
|
||||
"is_verified": self.is_verified,
|
||||
"phone": self.phone,
|
||||
"language": self.language,
|
||||
"timezone": self.timezone,
|
||||
"role": self.role,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"last_login": self.last_login.isoformat() if self.last_login else None
|
||||
}
|
||||
16
services/auth/app/repositories/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Auth Service Repositories
|
||||
Repository implementations for authentication service
|
||||
"""
|
||||
|
||||
from .base import AuthBaseRepository
|
||||
from .user_repository import UserRepository
|
||||
from .token_repository import TokenRepository
|
||||
from .onboarding_repository import OnboardingRepository
|
||||
|
||||
__all__ = [
|
||||
"AuthBaseRepository",
|
||||
"UserRepository",
|
||||
"TokenRepository",
|
||||
"OnboardingRepository"
|
||||
]
|
||||
101
services/auth/app/repositories/base.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Base Repository for Auth Service
|
||||
Service-specific repository base class with auth service utilities
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Type
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
|
||||
from shared.database.repository import BaseRepository
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class AuthBaseRepository(BaseRepository):
|
||||
"""Base repository for auth service with common auth operations"""
|
||||
|
||||
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Auth data benefits from longer caching (10 minutes)
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get active records (if model has is_active field)"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional:
|
||||
"""Get record by email (if model has email field)"""
|
||||
if hasattr(self.model, 'email'):
|
||||
return await self.get_by_field("email", email)
|
||||
return None
|
||||
|
||||
async def get_by_username(self, username: str) -> Optional:
|
||||
"""Get record by username (if model has username field)"""
|
||||
if hasattr(self.model, 'username'):
|
||||
return await self.get_by_field("username", username)
|
||||
return None
|
||||
|
||||
async def deactivate_record(self, record_id: Any) -> Optional:
|
||||
"""Deactivate a record instead of deleting it"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": False})
|
||||
return await self.delete(record_id)
|
||||
|
||||
async def activate_record(self, record_id: Any) -> Optional:
|
||||
"""Activate a record"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": True})
|
||||
return await self.get_by_id(record_id)
|
||||
|
||||
async def cleanup_expired_records(self, field_name: str = "expires_at") -> int:
|
||||
"""Clean up expired records (for tokens, sessions, etc.)"""
|
||||
try:
|
||||
if not hasattr(self.model, field_name):
|
||||
logger.warning(f"Model {self.model.__name__} has no {field_name} field for cleanup")
|
||||
return 0
|
||||
|
||||
# This would need custom implementation with raw SQL for date comparison
|
||||
# For now, return 0 to indicate no cleanup performed
|
||||
logger.info(f"Cleanup requested for {self.model.__name__} but not implemented")
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired records",
|
||||
model=self.model.__name__,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
def _validate_auth_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
|
||||
"""Validate authentication-related data"""
|
||||
errors = []
|
||||
|
||||
for field in required_fields:
|
||||
if field not in data or not data[field]:
|
||||
errors.append(f"Missing required field: {field}")
|
||||
|
||||
# Validate email format if present
|
||||
if "email" in data and data["email"]:
|
||||
email = data["email"]
|
||||
if "@" not in email or "." not in email.split("@")[-1]:
|
||||
errors.append("Invalid email format")
|
||||
|
||||
# Validate password strength if present
|
||||
if "password" in data and data["password"]:
|
||||
password = data["password"]
|
||||
if len(password) < 8:
|
||||
errors.append("Password must be at least 8 characters long")
|
||||
|
||||
return {
|
||||
"is_valid": len(errors) == 0,
|
||||
"errors": errors
|
||||
}
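For illustration, the helper's return shape on a deliberately bad payload (hypothetical input, e.g. from inside UserRepository.create_user):

payload = {"email": "user@example", "password": "short"}  # hypothetical input
result = self._validate_auth_data(payload, ["email", "password", "full_name"])
# result == {
#     "is_valid": False,
#     "errors": [
#         "Missing required field: full_name",
#         "Invalid email format",
#         "Password must be at least 8 characters long",
#     ],
# }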
|
||||
110
services/auth/app/repositories/deletion_job_repository.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Deletion Job Repository
|
||||
Database operations for deletion job persistence
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from sqlalchemy import select, and_, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import structlog
|
||||
|
||||
from app.models.deletion_job import DeletionJob
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class DeletionJobRepository:
|
||||
"""Repository for deletion job database operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create(self, deletion_job: DeletionJob) -> DeletionJob:
|
||||
"""Create a new deletion job record"""
|
||||
try:
|
||||
self.session.add(deletion_job)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(deletion_job)
|
||||
return deletion_job
|
||||
except Exception as e:
|
||||
logger.error("Failed to create deletion job", error=str(e))
|
||||
raise
|
||||
|
||||
async def get_by_job_id(self, job_id: str) -> Optional[DeletionJob]:
|
||||
"""Get deletion job by job_id"""
|
||||
try:
|
||||
query = select(DeletionJob).where(DeletionJob.job_id == job_id)
|
||||
result = await self.session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Failed to get deletion job", error=str(e), job_id=job_id)
|
||||
raise
|
||||
|
||||
async def get_by_id(self, id: UUID) -> Optional[DeletionJob]:
|
||||
"""Get deletion job by database ID"""
|
||||
try:
|
||||
return await self.session.get(DeletionJob, id)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get deletion job by ID", error=str(e), id=str(id))
|
||||
raise
|
||||
|
||||
async def list_by_tenant(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
status: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[DeletionJob]:
|
||||
"""List deletion jobs for a tenant"""
|
||||
try:
|
||||
query = select(DeletionJob).where(DeletionJob.tenant_id == tenant_id)
|
||||
|
||||
if status:
|
||||
query = query.where(DeletionJob.status == status)
|
||||
|
||||
query = query.order_by(desc(DeletionJob.started_at)).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error("Failed to list deletion jobs", error=str(e), tenant_id=str(tenant_id))
|
||||
raise
|
||||
|
||||
async def list_all(
|
||||
self,
|
||||
status: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[DeletionJob]:
|
||||
"""List all deletion jobs with optional status filter"""
|
||||
try:
|
||||
query = select(DeletionJob)
|
||||
|
||||
if status:
|
||||
query = query.where(DeletionJob.status == status)
|
||||
|
||||
query = query.order_by(desc(DeletionJob.started_at)).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error("Failed to list all deletion jobs", error=str(e))
|
||||
raise
|
||||
|
||||
async def update(self, deletion_job: DeletionJob) -> DeletionJob:
|
||||
"""Update a deletion job record"""
|
||||
try:
|
||||
await self.session.flush()
|
||||
await self.session.refresh(deletion_job)
|
||||
return deletion_job
|
||||
except Exception as e:
|
||||
logger.error("Failed to update deletion job", error=str(e))
|
||||
raise
|
||||
|
||||
async def delete(self, deletion_job: DeletionJob) -> None:
|
||||
"""Delete a deletion job record"""
|
||||
try:
|
||||
await self.session.delete(deletion_job)
|
||||
await self.session.flush()
|
||||
except Exception as e:
|
||||
logger.error("Failed to delete deletion job", error=str(e))
|
||||
raise
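A short sketch of the intended lifecycle through this repository — create a job, then record completion; the datetime import and the caller-side session are assumptions:

from datetime import datetime, timezone

async def run_deletion_job(session, tenant_id):
    repo = DeletionJobRepository(session)
    job = await repo.create(DeletionJob(
        job_id="del-2024-0001",  # hypothetical external identifier
        tenant_id=tenant_id,
        status="pending",
    ))
    # ... perform per-service deletions, then persist the outcome ...
    job.status = "completed"
    job.completed_at = datetime.now(timezone.utc)
    return await repo.update(job)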
|
||||
313
services/auth/app/repositories/onboarding_repository.py
Normal file
@@ -0,0 +1,313 @@
|
||||
# services/auth/app/repositories/onboarding_repository.py
|
||||
"""
|
||||
Onboarding Repository for database operations
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, delete, and_, func
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
|
||||
from app.models.onboarding import UserOnboardingProgress, UserOnboardingSummary
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class OnboardingRepository:
|
||||
"""Repository for onboarding progress operations"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_user_progress_steps(self, user_id: str) -> List[UserOnboardingProgress]:
|
||||
"""Get all onboarding steps for a user"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(UserOnboardingProgress)
|
||||
.where(UserOnboardingProgress.user_id == user_id)
|
||||
.order_by(UserOnboardingProgress.created_at)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user progress steps for {user_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_user_step(self, user_id: str, step_name: str) -> Optional[UserOnboardingProgress]:
|
||||
"""Get a specific step for a user"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(UserOnboardingProgress)
|
||||
.where(
|
||||
and_(
|
||||
UserOnboardingProgress.user_id == user_id,
|
||||
UserOnboardingProgress.step_name == step_name
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting step {step_name} for user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def upsert_user_step(
|
||||
self,
|
||||
user_id: str,
|
||||
step_name: str,
|
||||
completed: bool,
|
||||
step_data: Dict[str, Any] = None,
|
||||
auto_commit: bool = True
|
||||
) -> UserOnboardingProgress:
|
||||
"""Insert or update a user's onboarding step
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
step_name: Name of the step
|
||||
completed: Whether the step is completed
|
||||
step_data: Additional data for the step
|
||||
auto_commit: Whether to auto-commit (set to False when used within UnitOfWork)
|
||||
"""
|
||||
try:
|
||||
completed_at = datetime.now(timezone.utc) if completed else None
|
||||
step_data = step_data or {}
|
||||
|
||||
# Use PostgreSQL UPSERT (INSERT ... ON CONFLICT ... DO UPDATE)
|
||||
stmt = insert(UserOnboardingProgress).values(
|
||||
user_id=user_id,
|
||||
step_name=step_name,
|
||||
completed=completed,
|
||||
completed_at=completed_at,
|
||||
step_data=step_data,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# On conflict, update the existing record
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['user_id', 'step_name'],
|
||||
set_=dict(
|
||||
completed=stmt.excluded.completed,
|
||||
completed_at=stmt.excluded.completed_at,
|
||||
step_data=stmt.excluded.step_data,
|
||||
updated_at=stmt.excluded.updated_at
|
||||
)
|
||||
)
|
||||
|
||||
# Return the updated record
|
||||
stmt = stmt.returning(UserOnboardingProgress)
|
||||
result = await self.db.execute(stmt)
|
||||
|
||||
# Only commit if auto_commit is True (not within a UnitOfWork)
|
||||
if auto_commit:
|
||||
await self.db.commit()
|
||||
else:
|
||||
# Flush to ensure the statement is executed
|
||||
await self.db.flush()
|
||||
|
||||
return result.scalars().first()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error upserting step {step_name} for user {user_id}: {e}")
|
||||
if auto_commit:
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def get_user_summary(self, user_id: str) -> Optional[UserOnboardingSummary]:
|
||||
"""Get user's onboarding summary"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(UserOnboardingSummary)
|
||||
.where(UserOnboardingSummary.user_id == user_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting onboarding summary for user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def upsert_user_summary(
|
||||
self,
|
||||
user_id: str,
|
||||
current_step: str,
|
||||
next_step: Optional[str],
|
||||
completion_percentage: float,
|
||||
fully_completed: bool,
|
||||
steps_completed_count: str
|
||||
) -> UserOnboardingSummary:
|
||||
"""Insert or update user's onboarding summary"""
|
||||
try:
|
||||
# Use PostgreSQL UPSERT
|
||||
stmt = insert(UserOnboardingSummary).values(
|
||||
user_id=user_id,
|
||||
current_step=current_step,
|
||||
next_step=next_step,
|
||||
completion_percentage=str(completion_percentage),
|
||||
fully_completed=fully_completed,
|
||||
steps_completed_count=steps_completed_count,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
last_activity_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# On conflict, update the existing record
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['user_id'],
|
||||
set_=dict(
|
||||
current_step=stmt.excluded.current_step,
|
||||
next_step=stmt.excluded.next_step,
|
||||
completion_percentage=stmt.excluded.completion_percentage,
|
||||
fully_completed=stmt.excluded.fully_completed,
|
||||
steps_completed_count=stmt.excluded.steps_completed_count,
|
||||
updated_at=stmt.excluded.updated_at,
|
||||
last_activity_at=stmt.excluded.last_activity_at
|
||||
)
|
||||
)
|
||||
|
||||
# Return the updated record
|
||||
stmt = stmt.returning(UserOnboardingSummary)
|
||||
result = await self.db.execute(stmt)
|
||||
await self.db.commit()
|
||||
|
||||
return result.scalars().first()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error upserting summary for user {user_id}: {e}")
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def delete_user_progress(self, user_id: str) -> bool:
|
||||
"""Delete all onboarding progress for a user"""
|
||||
try:
|
||||
# Delete steps
|
||||
await self.db.execute(
|
||||
delete(UserOnboardingProgress)
|
||||
.where(UserOnboardingProgress.user_id == user_id)
|
||||
)
|
||||
|
||||
# Delete summary
|
||||
await self.db.execute(
|
||||
delete(UserOnboardingSummary)
|
||||
.where(UserOnboardingSummary.user_id == user_id)
|
||||
)
|
||||
|
||||
await self.db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting progress for user {user_id}: {e}")
|
||||
await self.db.rollback()
|
||||
return False
|
||||
|
||||
async def save_step_data(
|
||||
self,
|
||||
user_id: str,
|
||||
step_name: str,
|
||||
step_data: Dict[str, Any],
|
||||
auto_commit: bool = True
|
||||
) -> UserOnboardingProgress:
|
||||
"""Save data for a specific step without marking it as completed
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
step_name: Name of the step
|
||||
step_data: Data to save
|
||||
auto_commit: Whether to auto-commit (set to False when used within UnitOfWork)
|
||||
"""
|
||||
try:
|
||||
# Get existing step or create new one
|
||||
existing_step = await self.get_user_step(user_id, step_name)
|
||||
|
||||
if existing_step:
|
||||
# Update existing step data (merge with existing data)
|
||||
merged_data = {**(existing_step.step_data or {}), **step_data}
|
||||
|
||||
stmt = update(UserOnboardingProgress).where(
|
||||
and_(
|
||||
UserOnboardingProgress.user_id == user_id,
|
||||
UserOnboardingProgress.step_name == step_name
|
||||
)
|
||||
).values(
|
||||
step_data=merged_data,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
).returning(UserOnboardingProgress)
|
||||
|
||||
result = await self.db.execute(stmt)
|
||||
|
||||
if auto_commit:
|
||||
await self.db.commit()
|
||||
else:
|
||||
await self.db.flush()
|
||||
|
||||
return result.scalars().first()
|
||||
else:
|
||||
# Create new step with data but not completed
|
||||
return await self.upsert_user_step(
|
||||
user_id=user_id,
|
||||
step_name=step_name,
|
||||
completed=False,
|
||||
step_data=step_data,
|
||||
auto_commit=auto_commit
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving step data for {step_name}, user {user_id}: {e}")
|
||||
if auto_commit:
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def get_step_data(self, user_id: str, step_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get data for a specific step"""
|
||||
try:
|
||||
step = await self.get_user_step(user_id, step_name)
|
||||
return step.step_data if step else None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting step data for {step_name}, user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_subscription_parameters(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get subscription parameters saved during onboarding for tenant creation"""
|
||||
try:
|
||||
step_data = await self.get_step_data(user_id, "user_registered")
|
||||
if step_data:
|
||||
# Extract subscription-related parameters
|
||||
subscription_params = {
|
||||
"subscription_plan": step_data.get("subscription_plan", "starter"),
|
||||
"billing_cycle": step_data.get("billing_cycle", "monthly"),
|
||||
"coupon_code": step_data.get("coupon_code"),
|
||||
"payment_method_id": step_data.get("payment_method_id"),
|
||||
"payment_customer_id": step_data.get("payment_customer_id"),
|
||||
"saved_at": step_data.get("saved_at")
|
||||
}
|
||||
return subscription_params
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting subscription parameters for user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_completion_stats(self) -> Dict[str, Any]:
|
||||
"""Get completion statistics across all users"""
|
||||
try:
|
||||
# Get total users with onboarding data
|
||||
total_result = await self.db.execute(
|
||||
select(func.count()).select_from(UserOnboardingSummary)
|
||||
)
|
||||
total_users = total_result.scalar()
|
||||
|
||||
# Get completed users
|
||||
completed_result = await self.db.execute(
|
||||
select(func.count()).select_from(UserOnboardingSummary)
|
||||
.where(UserOnboardingSummary.fully_completed == True)
|
||||
|
||||
)
|
||||
completed_users = completed_result.scalar()
|
||||
|
||||
return {
|
||||
"total_users_in_onboarding": total_users,
|
||||
"fully_completed_users": completed_users,
|
||||
"completion_rate": (completed_users / total_users * 100) if total_users > 0 else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting completion stats: {e}")
|
||||
return {
|
||||
"total_users_in_onboarding": 0,
|
||||
"fully_completed_users": 0,
|
||||
"completion_rate": 0
|
||||
}
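The PostgreSQL upsert used throughout this repository, shown in isolation (the column and constraint names are the ones defined on UserOnboardingProgress above; the db session is assumed to be passed in):

async def mark_step_completed(db, user_id: str, step_name: str):
    stmt = insert(UserOnboardingProgress).values(
        user_id=user_id,
        step_name=step_name,
        completed=True,
        completed_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    ).on_conflict_do_update(
        index_elements=['user_id', 'step_name'],  # matches uq_user_step
        set_={"completed": True, "completed_at": datetime.now(timezone.utc)},
    ).returning(UserOnboardingProgress)
    result = await db.execute(stmt)
    return result.scalars().first()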
|
||||
124
services/auth/app/repositories/password_reset_repository.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# services/auth/app/repositories/password_reset_repository.py
|
||||
"""
|
||||
Password reset token repository
|
||||
Repository for password reset token operations
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from .base import AuthBaseRepository
|
||||
from app.models.password_reset_tokens import PasswordResetToken
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PasswordResetTokenRepository(AuthBaseRepository):
|
||||
"""Repository for password reset token operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
super().__init__(PasswordResetToken, session)
|
||||
|
||||
async def create_token(self, user_id: str, token: str, expires_at: datetime) -> PasswordResetToken:
|
||||
"""Create a new password reset token"""
|
||||
try:
|
||||
token_data = {
|
||||
"user_id": user_id,
|
||||
"token": token,
|
||||
"expires_at": expires_at,
|
||||
"is_used": False
|
||||
}
|
||||
|
||||
reset_token = await self.create(token_data)
|
||||
|
||||
logger.debug("Password reset token created",
|
||||
user_id=user_id,
|
||||
token_id=reset_token.id,
|
||||
expires_at=expires_at)
|
||||
|
||||
return reset_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create password reset token",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create password reset token: {str(e)}")
|
||||
|
||||
async def get_token_by_value(self, token: str) -> Optional[PasswordResetToken]:
|
||||
"""Get password reset token by token value"""
|
||||
try:
|
||||
stmt = select(PasswordResetToken).where(
|
||||
and_(
|
||||
PasswordResetToken.token == token,
|
||||
PasswordResetToken.is_used == False,
|
||||
PasswordResetToken.expires_at > datetime.now(timezone.utc)
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get password reset token by value", error=str(e))
|
||||
raise DatabaseError(f"Failed to get password reset token: {str(e)}")
|
||||
|
||||
async def mark_token_as_used(self, token_id: str) -> Optional[PasswordResetToken]:
|
||||
"""Mark a password reset token as used"""
|
||||
try:
|
||||
return await self.update(token_id, {
|
||||
"is_used": True,
|
||||
"used_at": datetime.now(timezone.utc)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Failed to mark password reset token as used",
|
||||
token_id=token_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to mark token as used: {str(e)}")
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Clean up expired password reset tokens"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Delete expired tokens
|
||||
query = text("""
|
||||
DELETE FROM password_reset_tokens
|
||||
WHERE expires_at < :now OR is_used = true
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {"now": now})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up expired password reset tokens",
|
||||
deleted_count=deleted_count)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired password reset tokens", error=str(e))
|
||||
raise DatabaseError(f"Token cleanup failed: {str(e)}")
|
||||
|
||||
async def get_valid_token_for_user(self, user_id: str) -> Optional[PasswordResetToken]:
|
||||
"""Get a valid (unused, not expired) password reset token for a user"""
|
||||
try:
|
||||
stmt = select(PasswordResetToken).where(
|
||||
and_(
|
||||
PasswordResetToken.user_id == user_id,
|
||||
PasswordResetToken.is_used == False,
|
||||
PasswordResetToken.expires_at > datetime.now(timezone.utc)
|
||||
)
|
||||
).order_by(PasswordResetToken.created_at.desc())
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get valid token for user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get valid token for user: {str(e)}")
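A condensed sketch of the reset flow these methods support; token generation and session wiring are assumed to live in the service layer:

from datetime import timedelta

async def reset_flow(session, user, generated_token: str):
    repo = PasswordResetTokenRepository(session)
    await repo.create_token(
        user_id=str(user.id),
        token=generated_token,  # produced elsewhere (assumed)
        expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
    )
    stored = await repo.get_token_by_value(generated_token)  # None if used or expired
    if stored:
        await repo.mark_token_as_used(stored.id)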
|
||||
305
services/auth/app/repositories/token_repository.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Token Repository
|
||||
Repository for refresh token operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import AuthBaseRepository
|
||||
from app.models.tokens import RefreshToken
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TokenRepository(AuthBaseRepository):
|
||||
"""Repository for refresh token operations"""
|
||||
|
||||
def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 300):
|
||||
# Tokens change frequently, shorter cache time
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def create_token(self, token_data: Dict[str, Any]) -> RefreshToken:
|
||||
"""Create a new refresh token from dictionary data"""
|
||||
return await self.create(token_data)
|
||||
|
||||
async def create_refresh_token(
|
||||
self,
|
||||
user_id: str,
|
||||
token: str,
|
||||
expires_at: datetime
|
||||
) -> RefreshToken:
|
||||
"""Create a new refresh token"""
|
||||
try:
|
||||
token_data = {
|
||||
"user_id": user_id,
|
||||
"token": token,
|
||||
"expires_at": expires_at,
|
||||
"is_revoked": False
|
||||
}
|
||||
|
||||
refresh_token = await self.create(token_data)
|
||||
|
||||
logger.debug("Refresh token created",
|
||||
user_id=user_id,
|
||||
token_id=refresh_token.id,
|
||||
expires_at=expires_at)
|
||||
|
||||
return refresh_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create refresh token",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create refresh token: {str(e)}")
|
||||
|
||||
async def get_token_by_value(self, token: str) -> Optional[RefreshToken]:
|
||||
"""Get refresh token by token value"""
|
||||
try:
|
||||
return await self.get_by_field("token", token)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get token by value", error=str(e))
|
||||
raise DatabaseError(f"Failed to get token: {str(e)}")
|
||||
|
||||
async def get_active_tokens_for_user(self, user_id: str) -> List[RefreshToken]:
|
||||
"""Get all active (non-revoked, non-expired) tokens for a user"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Use raw query for complex filtering
|
||||
query = text("""
|
||||
SELECT * FROM refresh_tokens
|
||||
WHERE user_id = :user_id
|
||||
AND is_revoked = false
|
||||
AND expires_at > :now
|
||||
ORDER BY created_at DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"user_id": user_id,
|
||||
"now": now
|
||||
})
|
||||
|
||||
# Convert rows to RefreshToken objects
|
||||
tokens = []
|
||||
for row in result.fetchall():
|
||||
token = RefreshToken(
|
||||
id=row.id,
|
||||
user_id=row.user_id,
|
||||
token=row.token,
|
||||
expires_at=row.expires_at,
|
||||
is_revoked=row.is_revoked,
|
||||
created_at=row.created_at,
|
||||
revoked_at=row.revoked_at
|
||||
)
|
||||
tokens.append(token)
|
||||
|
||||
return tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get active tokens for user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get active tokens: {str(e)}")
|
||||
|
||||
async def revoke_token(self, token_id: str) -> Optional[RefreshToken]:
|
||||
"""Revoke a refresh token"""
|
||||
try:
|
||||
return await self.update(token_id, {
|
||||
"is_revoked": True,
|
||||
"revoked_at": datetime.now(timezone.utc)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Failed to revoke token",
|
||||
token_id=token_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to revoke token: {str(e)}")
|
||||
|
||||
async def revoke_all_user_tokens(self, user_id: str) -> int:
|
||||
"""Revoke all tokens for a user"""
|
||||
try:
|
||||
# Use bulk update for efficiency
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
query = text("""
|
||||
UPDATE refresh_tokens
|
||||
SET is_revoked = true, revoked_at = :revoked_at
|
||||
WHERE user_id = :user_id AND is_revoked = false
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"user_id": user_id,
|
||||
"revoked_at": now
|
||||
})
|
||||
|
||||
revoked_count = result.rowcount
|
||||
|
||||
logger.info("Revoked all user tokens",
|
||||
user_id=user_id,
|
||||
revoked_count=revoked_count)
|
||||
|
||||
return revoked_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to revoke all user tokens",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to revoke user tokens: {str(e)}")
|
||||
|
||||
async def is_token_valid(self, token: str) -> bool:
|
||||
"""Check if a token is valid (exists, not revoked, not expired)"""
|
||||
try:
|
||||
refresh_token = await self.get_token_by_value(token)
|
||||
|
||||
if not refresh_token:
|
||||
return False
|
||||
|
||||
if refresh_token.is_revoked:
|
||||
return False
|
||||
|
||||
if refresh_token.expires_at < datetime.now(timezone.utc):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate token", error=str(e))
|
||||
return False
|
||||
|
||||
async def validate_refresh_token(self, token: str, user_id: str) -> bool:
|
||||
"""Validate refresh token for a specific user"""
|
||||
try:
|
||||
refresh_token = await self.get_token_by_value(token)
|
||||
|
||||
if not refresh_token:
|
||||
logger.debug("Refresh token not found", token_prefix=token[:10] + "...")
|
||||
return False
|
||||
|
||||
# Convert both to strings for comparison to handle UUID vs string mismatch
|
||||
token_user_id = str(refresh_token.user_id)
|
||||
expected_user_id = str(user_id)
|
||||
|
||||
if token_user_id != expected_user_id:
|
||||
logger.warning("Refresh token user_id mismatch",
|
||||
expected_user_id=expected_user_id,
|
||||
actual_user_id=token_user_id)
|
||||
return False
|
||||
|
||||
if refresh_token.is_revoked:
|
||||
logger.debug("Refresh token is revoked", user_id=user_id)
|
||||
return False
|
||||
|
||||
if refresh_token.expires_at < datetime.now(timezone.utc):
|
||||
logger.debug("Refresh token is expired", user_id=user_id)
|
||||
return False
|
||||
|
||||
logger.debug("Refresh token is valid", user_id=user_id)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate refresh token",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return False
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Clean up expired refresh tokens"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Delete expired tokens
|
||||
query = text("""
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE expires_at < :now
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {"now": now})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up expired tokens",
|
||||
deleted_count=deleted_count)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired tokens", error=str(e))
|
||||
raise DatabaseError(f"Token cleanup failed: {str(e)}")
|
||||
|
||||
async def cleanup_old_revoked_tokens(self, days_old: int = 30) -> int:
|
||||
"""Clean up old revoked tokens"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
|
||||
|
||||
query = text("""
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE is_revoked = true
|
||||
AND revoked_at < :cutoff_date
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"cutoff_date": cutoff_date
|
||||
})
|
||||
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up old revoked tokens",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old revoked tokens",
|
||||
days_old=days_old,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Revoked token cleanup failed: {str(e)}")
|
||||
|
||||
async def get_token_statistics(self) -> Dict[str, Any]:
|
||||
"""Get token statistics"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Get counts with raw queries
|
||||
stats_query = text("""
|
||||
SELECT
|
||||
COUNT(*) as total_tokens,
|
||||
COUNT(CASE WHEN is_revoked = false AND expires_at > :now THEN 1 END) as active_tokens,
|
||||
COUNT(CASE WHEN is_revoked = true THEN 1 END) as revoked_tokens,
|
||||
COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_tokens,
|
||||
COUNT(DISTINCT user_id) as users_with_tokens
|
||||
FROM refresh_tokens
|
||||
""")
|
||||
|
||||
result = await self.session.execute(stats_query, {"now": now})
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
return {
|
||||
"total_tokens": row.total_tokens,
|
||||
"active_tokens": row.active_tokens,
|
||||
"revoked_tokens": row.revoked_tokens,
|
||||
"expired_tokens": row.expired_tokens,
|
||||
"users_with_tokens": row.users_with_tokens
|
||||
}
|
||||
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"active_tokens": 0,
|
||||
"revoked_tokens": 0,
|
||||
"expired_tokens": 0,
|
||||
"users_with_tokens": 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get token statistics", error=str(e))
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"active_tokens": 0,
|
||||
"revoked_tokens": 0,
|
||||
"expired_tokens": 0,
|
||||
"users_with_tokens": 0
|
||||
}
|
||||
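A minimal sketch of how these cleanup and statistics helpers might be scheduled. The `make_repo` async context manager, the job names, and the hourly interval are assumptions for illustration, not part of this commit; commits are assumed to be handled by the context manager.

# Hypothetical maintenance job (sketch only, not part of this commit).
# Assumes make_repo() is an async context manager yielding a repository that
# exposes the cleanup_expired_tokens / cleanup_old_revoked_tokens /
# get_token_statistics methods shown above.
import asyncio

async def run_token_maintenance(make_repo) -> None:
    async with make_repo() as repo:
        expired = await repo.cleanup_expired_tokens()
        revoked = await repo.cleanup_old_revoked_tokens(days_old=30)
        stats = await repo.get_token_statistics()
        print(f"token maintenance: removed {expired} expired and {revoked} revoked tokens, "
              f"{stats['active_tokens']} active tokens remain")

async def maintenance_loop(make_repo, interval_seconds: int = 3600) -> None:
    # A real deployment would more likely use cron or a task scheduler than a loop.
    while True:
        await run_token_maintenance(make_repo)
        await asyncio.sleep(interval_seconds)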
277
services/auth/app/repositories/user_repository.py
Normal file
@@ -0,0 +1,277 @@
"""
User Repository
Repository for user operations with authentication-specific queries
"""

from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func, desc, text
from datetime import datetime, timezone, timedelta
import structlog

from .base import AuthBaseRepository
from app.models.users import User
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError

logger = structlog.get_logger()


class UserRepository(AuthBaseRepository):
    """Repository for user operations"""

    def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 600):
        super().__init__(model, session, cache_ttl)

    async def create_user(self, user_data: Dict[str, Any]) -> User:
        """Create a new user with validation"""
        try:
            # Validate user data
            validation_result = self._validate_auth_data(
                user_data,
                ["email", "hashed_password", "full_name", "role"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid user data: {validation_result['errors']}")

            # Check if user already exists
            existing_user = await self.get_by_email(user_data["email"])
            if existing_user:
                raise DuplicateRecordError(f"User with email {user_data['email']} already exists")

            # Create user
            user = await self.create(user_data)

            logger.info("User created successfully",
                        user_id=user.id,
                        email=user.email,
                        role=user.role)

            return user

        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to create user",
                         email=user_data.get("email"),
                         error=str(e))
            raise DatabaseError(f"Failed to create user: {str(e)}")

    async def get_user_by_email(self, email: str) -> Optional[User]:
        """Get user by email address"""
        return await self.get_by_email(email)

    async def get_active_users(self, skip: int = 0, limit: int = 100) -> List[User]:
        """Get all active users"""
        return await self.get_active_records(skip=skip, limit=limit)

    async def authenticate_user(self, email: str, password: str) -> Optional[User]:
        """Authenticate user with email and plain password"""
        try:
            user = await self.get_by_email(email)

            if not user:
                logger.debug("User not found for authentication", email=email)
                return None

            if not user.is_active:
                logger.debug("User account is inactive", email=email)
                return None

            # Verify password using security manager
            from app.core.security import SecurityManager
            if SecurityManager.verify_password(password, user.hashed_password):
                # Update last login
                await self.update_last_login(user.id)
                logger.info("User authenticated successfully",
                            user_id=user.id,
                            email=email)
                return user

            logger.debug("Invalid password for user", email=email)
            return None

        except Exception as e:
            logger.error("Authentication failed",
                         email=email,
                         error=str(e))
            raise DatabaseError(f"Authentication failed: {str(e)}")

    async def update_last_login(self, user_id: str) -> Optional[User]:
        """Update user's last login timestamp"""
        try:
            return await self.update(user_id, {
                "last_login": datetime.now(timezone.utc)
            })
        except Exception as e:
            logger.error("Failed to update last login",
                         user_id=user_id,
                         error=str(e))
            # Don't raise here - last login update is not critical
            return None

    async def update_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> Optional[User]:
        """Update user profile information"""
        try:
            # Remove sensitive fields that shouldn't be updated via profile
            profile_data.pop("id", None)
            profile_data.pop("hashed_password", None)
            profile_data.pop("created_at", None)
            profile_data.pop("is_active", None)

            # Validate email if being updated
            if "email" in profile_data:
                validation_result = self._validate_auth_data(
                    profile_data,
                    ["email"]
                )
                if not validation_result["is_valid"]:
                    raise ValidationError(f"Invalid profile data: {validation_result['errors']}")

                # Check for email conflicts
                existing_user = await self.get_by_email(profile_data["email"])
                if existing_user and str(existing_user.id) != str(user_id):
                    raise DuplicateRecordError(f"Email {profile_data['email']} is already in use")

            updated_user = await self.update(user_id, profile_data)

            if updated_user:
                logger.info("User profile updated",
                            user_id=user_id,
                            updated_fields=list(profile_data.keys()))

            return updated_user

        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to update user profile",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to update profile: {str(e)}")

    async def change_password(self, user_id: str, new_password_hash: str) -> bool:
        """Change user password"""
        try:
            updated_user = await self.update(user_id, {
                "hashed_password": new_password_hash
            })

            if updated_user:
                logger.info("Password changed successfully", user_id=user_id)
                return True

            return False

        except Exception as e:
            logger.error("Failed to change password",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to change password: {str(e)}")

    async def verify_user_email(self, user_id: str) -> Optional[User]:
        """Mark user email as verified"""
        try:
            return await self.update(user_id, {
                "is_verified": True
            })
        except Exception as e:
            logger.error("Failed to verify user email",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to verify email: {str(e)}")

    async def deactivate_user(self, user_id: str) -> Optional[User]:
        """Deactivate user account"""
        return await self.deactivate_record(user_id)

    async def activate_user(self, user_id: str) -> Optional[User]:
        """Activate user account"""
        return await self.activate_record(user_id)

    async def get_users_by_role(self, role: str, skip: int = 0, limit: int = 100) -> List[User]:
        """Get users by role"""
        try:
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"role": role, "is_active": True},
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get users by role",
                         role=role,
                         error=str(e))
            raise DatabaseError(f"Failed to get users by role: {str(e)}")

    async def search_users(self, search_term: str, skip: int = 0, limit: int = 50) -> List[User]:
        """Search users by email or full name"""
        try:
            return await self.search(
                search_term=search_term,
                search_fields=["email", "full_name"],
                skip=skip,
                limit=limit
            )
        except Exception as e:
            logger.error("Failed to search users",
                         search_term=search_term,
                         error=str(e))
            raise DatabaseError(f"Failed to search users: {str(e)}")

    async def get_user_statistics(self) -> Dict[str, Any]:
        """Get user statistics"""
        try:
            # Get basic counts
            total_users = await self.count()
            active_users = await self.count(filters={"is_active": True})
            verified_users = await self.count(filters={"is_verified": True})

            # Get users by role using raw query
            role_query = text("""
                SELECT role, COUNT(*) as count
                FROM users
                WHERE is_active = true
                GROUP BY role
                ORDER BY count DESC
            """)

            result = await self.session.execute(role_query)
            role_stats = {row.role: row.count for row in result.fetchall()}

            # Recent activity (users created in last 30 days)
            thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
            recent_users_query = text("""
                SELECT COUNT(*) as count
                FROM users
                WHERE created_at >= :thirty_days_ago
            """)

            recent_result = await self.session.execute(
                recent_users_query,
                {"thirty_days_ago": thirty_days_ago}
            )
            recent_users = recent_result.scalar() or 0

            return {
                "total_users": total_users,
                "active_users": active_users,
                "inactive_users": total_users - active_users,
                "verified_users": verified_users,
                "unverified_users": active_users - verified_users,
                "recent_registrations": recent_users,
                "users_by_role": role_stats
            }

        except Exception as e:
            logger.error("Failed to get user statistics", error=str(e))
            return {
                "total_users": 0,
                "active_users": 0,
                "inactive_users": 0,
                "verified_users": 0,
                "unverified_users": 0,
                "recent_registrations": 0,
                "users_by_role": {}
            }
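For orientation, a usage sketch of the registration-then-login path through UserRepository. The `session` argument and the `SecurityManager.hash_password` helper are assumptions (the commit only shows `verify_password`); this is not code from the commit itself.

# Usage sketch only, with assumptions noted in the comments.
from app.core.security import SecurityManager  # hash_password is assumed to exist
from app.models.users import User

async def register_and_authenticate(session, email: str, password: str):
    repo = UserRepository(User, session)  # constructor signature as defined above
    await repo.create_user({
        "email": email,
        "hashed_password": SecurityManager.hash_password(password),  # assumed helper
        "full_name": "Example User",
        "role": "admin",
    })
    # authenticate_user verifies the plain password and updates last_login
    return await repo.authenticate_user(email, password)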
0
services/auth/app/schemas/__init__.py
Normal file
230
services/auth/app/schemas/auth.py
Normal file
@@ -0,0 +1,230 @@
# services/auth/app/schemas/auth.py - UPDATED WITH UNIFIED TOKEN RESPONSE
"""
Authentication schemas - Updated with unified token response format
Following industry best practices from Firebase, Cognito, etc.
"""

from pydantic import BaseModel, EmailStr, Field
from typing import Optional, Dict, Any
from datetime import datetime

# ================================================================
# REQUEST SCHEMAS
# ================================================================

class UserRegistration(BaseModel):
    """User registration request"""
    email: EmailStr
    password: str = Field(..., min_length=8, max_length=128)
    full_name: str = Field(..., min_length=1, max_length=255)
    tenant_name: Optional[str] = Field(None, max_length=255)
    role: Optional[str] = Field("admin", pattern=r'^(user|admin|manager|super_admin)$')
    subscription_plan: Optional[str] = Field("starter", description="Selected subscription plan (starter, professional, enterprise)")
    billing_cycle: Optional[str] = Field("monthly", description="Billing cycle (monthly, yearly)")
    coupon_code: Optional[str] = Field(None, description="Discount coupon code")
    payment_method_id: Optional[str] = Field(None, description="Stripe payment method ID")
    # GDPR consent fields
    terms_accepted: Optional[bool] = Field(True, description="Accept terms of service")
    privacy_accepted: Optional[bool] = Field(True, description="Accept privacy policy")
    marketing_consent: Optional[bool] = Field(False, description="Consent to marketing communications")
    analytics_consent: Optional[bool] = Field(False, description="Consent to analytics cookies")


class UserLogin(BaseModel):
    """User login request"""
    email: EmailStr
    password: str


class RefreshTokenRequest(BaseModel):
    """Refresh token request"""
    refresh_token: str


class PasswordChange(BaseModel):
    """Password change request"""
    current_password: str
    new_password: str = Field(..., min_length=8, max_length=128)


class PasswordReset(BaseModel):
    """Password reset request"""
    email: EmailStr


class PasswordResetConfirm(BaseModel):
    """Password reset confirmation"""
    token: str
    new_password: str = Field(..., min_length=8, max_length=128)


# ================================================================
# RESPONSE SCHEMAS
# ================================================================

class UserData(BaseModel):
    """User data embedded in token responses"""
    id: str
    email: str
    full_name: str
    is_active: bool
    is_verified: bool
    created_at: str  # ISO format datetime string
    tenant_id: Optional[str] = None
    role: Optional[str] = "admin"


class TokenResponse(BaseModel):
    """
    Unified token response for both registration and login
    Follows industry standards (Firebase, AWS Cognito, etc.)
    """
    access_token: str
    refresh_token: Optional[str] = None
    token_type: str = "bearer"
    expires_in: int = 3600  # seconds
    user: Optional[UserData] = None
    subscription_id: Optional[str] = Field(None, description="Subscription ID if created during registration")
    # Payment action fields (3DS, SetupIntent, etc.)
    requires_action: Optional[bool] = Field(None, description="Whether payment action is required (3DS, SetupIntent confirmation)")
    action_type: Optional[str] = Field(None, description="Type of action required (setup_intent_confirmation, payment_intent_confirmation)")
    client_secret: Optional[str] = Field(None, description="Client secret for payment confirmation")
    payment_intent_id: Optional[str] = Field(None, description="Payment intent ID for 3DS authentication")
    setup_intent_id: Optional[str] = Field(None, description="SetupIntent ID for payment method verification")
    customer_id: Optional[str] = Field(None, description="Stripe customer ID")
    # Additional fields for post-confirmation subscription completion
    plan_id: Optional[str] = Field(None, description="Subscription plan ID")
    payment_method_id: Optional[str] = Field(None, description="Payment method ID")
    trial_period_days: Optional[int] = Field(None, description="Trial period in days")
    user_id: Optional[str] = Field(None, description="User ID for post-confirmation processing")
    billing_interval: Optional[str] = Field(None, description="Billing interval (monthly, yearly)")
    message: Optional[str] = Field(None, description="Additional message about payment action required")

    class Config:
        schema_extra = {
            "example": {
                "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
                "refresh_token": "def502004b8b7f8f...",
                "token_type": "bearer",
                "expires_in": 3600,
                "user": {
                    "id": "123e4567-e89b-12d3-a456-426614174000",
                    "email": "user@example.com",
                    "full_name": "John Doe",
                    "is_active": True,
                    "is_verified": False,
                    "created_at": "2025-07-22T10:00:00Z",
                    "role": "user"
                },
                "subscription_id": "sub_1234567890",
                "requires_action": True,
                "action_type": "setup_intent_confirmation",
                "client_secret": "seti_1234_secret_5678",
                "payment_intent_id": None,
                "setup_intent_id": "seti_1234567890",
                "customer_id": "cus_1234567890"
            }
        }


class UserResponse(BaseModel):
    """User response for user management endpoints - FIXED"""
    id: str
    email: str
    full_name: str
    is_active: bool
    is_verified: bool
    created_at: datetime  # ✅ Changed from str to datetime
    last_login: Optional[datetime] = None  # ✅ Added missing field
    phone: Optional[str] = None  # ✅ Added missing field
    language: Optional[str] = None  # ✅ Added missing field
    timezone: Optional[str] = None  # ✅ Added missing field
    tenant_id: Optional[str] = None
    role: Optional[str] = "admin"
    payment_customer_id: Optional[str] = None  # ✅ Added payment integration field
    default_payment_method_id: Optional[str] = None  # ✅ Added payment integration field

    class Config:
        from_attributes = True  # ✅ Enable ORM mode for SQLAlchemy objects


class TokenVerification(BaseModel):
    """Token verification response"""
    valid: bool
    user_id: Optional[str] = None
    email: Optional[str] = None
    exp: Optional[int] = None
    message: Optional[str] = None


class PasswordResetResponse(BaseModel):
    """Password reset response"""
    message: str
    reset_token: Optional[str] = None


class LogoutResponse(BaseModel):
    """Logout response"""
    message: str
    success: bool = True


# ================================================================
# ERROR SCHEMAS
# ================================================================

class ErrorDetail(BaseModel):
    """Error detail for API responses"""
    message: str
    code: Optional[str] = None
    field: Optional[str] = None


class ErrorResponse(BaseModel):
    """Standardized error response"""
    success: bool = False
    error: ErrorDetail
    timestamp: str

    class Config:
        schema_extra = {
            "example": {
                "success": False,
                "error": {
                    "message": "Invalid credentials",
                    "code": "AUTH_001"
                },
                "timestamp": "2025-07-22T10:00:00Z"
            }
        }


# ================================================================
# VALIDATION SCHEMAS
# ================================================================

class EmailVerificationRequest(BaseModel):
    """Email verification request"""
    email: EmailStr


class EmailVerificationConfirm(BaseModel):
    """Email verification confirmation"""
    token: str


class ProfileUpdate(BaseModel):
    """Profile update request"""
    full_name: Optional[str] = Field(None, min_length=1, max_length=255)
    email: Optional[EmailStr] = None


# ================================================================
# INTERNAL SCHEMAS (for service communication)
# ================================================================

class UserContext(BaseModel):
    """User context for internal service communication"""
    user_id: str
    email: str
    tenant_id: Optional[str] = None
    roles: list[str] = ["admin"]
    is_verified: bool = False


class TokenClaims(BaseModel):
    """JWT token claims structure"""
    sub: str  # subject (user_id)
    email: str
    full_name: str
    user_id: str
    is_verified: bool
    tenant_id: Optional[str] = None
    iat: int  # issued at
    exp: int  # expires at
    iss: str = "bakery-auth"  # issuer
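A short sketch of how an endpoint might assemble the unified TokenResponse defined above from a user record. The helper name and the token strings passed in are placeholders for illustration, not part of this commit.

# Illustrative helper using the UserData and TokenResponse schemas above.
def build_token_response(user, access_token: str, refresh_token: str) -> TokenResponse:
    return TokenResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer",
        expires_in=3600,
        user=UserData(
            id=str(user.id),
            email=user.email,
            full_name=user.full_name,
            is_active=user.is_active,
            is_verified=user.is_verified,
            created_at=user.created_at.isoformat(),  # UserData expects an ISO string
            tenant_id=str(user.tenant_id) if user.tenant_id else None,
            role=user.role,
        ),
    )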
Some files were not shown because too many files have changed in this diff.