Add forecasting service
This commit is contained in:
@@ -1,9 +1,9 @@
|
|||||||
"""
|
# ================================================================
|
||||||
Forecasting routes for gateway
|
# Gateway Integration: Update gateway/app/routes/forecasting.py
|
||||||
"""
|
# ================================================================
|
||||||
|
"""Forecasting service routes for API Gateway"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Request, HTTPException
|
from fastapi import APIRouter, Request
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
import httpx
|
import httpx
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -12,55 +12,49 @@ from app.core.config import settings
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.post("/predict")
|
@router.api_route("/forecasts/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||||
async def create_forecast(request: Request):
|
async def proxy_forecasts(request: Request, path: str):
|
||||||
"""Proxy forecast request to forecasting service"""
|
"""Proxy forecast requests to forecasting service"""
|
||||||
|
return await _proxy_request(request, f"/api/v1/forecasts/{path}")
|
||||||
|
|
||||||
|
@router.api_route("/predictions/{path:path}", methods=["GET", "POST"])
|
||||||
|
async def proxy_predictions(request: Request, path: str):
|
||||||
|
"""Proxy prediction requests to forecasting service"""
|
||||||
|
return await _proxy_request(request, f"/api/v1/predictions/{path}")
|
||||||
|
|
||||||
|
async def _proxy_request(request: Request, target_path: str):
|
||||||
|
"""Proxy request to forecasting service with user context"""
|
||||||
try:
|
try:
|
||||||
body = await request.body()
|
url = f"{settings.FORECASTING_SERVICE_URL}{target_path}"
|
||||||
auth_header = request.headers.get("Authorization")
|
|
||||||
|
# Forward headers and add user context
|
||||||
|
headers = dict(request.headers)
|
||||||
|
headers.pop("host", None)
|
||||||
|
|
||||||
|
# Add user context from gateway authentication
|
||||||
|
if hasattr(request.state, 'user'):
|
||||||
|
headers["X-User-ID"] = str(request.state.user.get("user_id"))
|
||||||
|
headers["X-User-Email"] = request.state.user.get("email", "")
|
||||||
|
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
|
||||||
|
headers["X-User-Roles"] = ",".join(request.state.user.get("roles", []))
|
||||||
|
|
||||||
|
# Get request body if present
|
||||||
|
body = None
|
||||||
|
if request.method in ["POST", "PUT", "PATCH"]:
|
||||||
|
body = await request.body()
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
response = await client.post(
|
response = await client.request(
|
||||||
f"{settings.FORECASTING_SERVICE_URL}/predict",
|
method=request.method,
|
||||||
|
url=url,
|
||||||
|
headers=headers,
|
||||||
content=body,
|
content=body,
|
||||||
headers={
|
params=request.query_params
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": auth_header
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
# Return response
|
||||||
status_code=response.status_code,
|
return response.json() if response.headers.get("content-type", "").startswith("application/json") else response.text
|
||||||
content=response.json()
|
|
||||||
)
|
|
||||||
|
|
||||||
except httpx.RequestError as e:
|
except Exception as e:
|
||||||
logger.error(f"Forecasting service unavailable: {e}")
|
logger.error(f"Error proxying to forecasting service: {e}")
|
||||||
raise HTTPException(
|
raise
|
||||||
status_code=503,
|
|
||||||
detail="Forecasting service unavailable"
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/forecasts")
|
|
||||||
async def get_forecasts(request: Request):
|
|
||||||
"""Get forecasts"""
|
|
||||||
try:
|
|
||||||
auth_header = request.headers.get("Authorization")
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
|
||||||
response = await client.get(
|
|
||||||
f"{settings.FORECASTING_SERVICE_URL}/forecasts",
|
|
||||||
headers={"Authorization": auth_header}
|
|
||||||
)
|
|
||||||
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=response.status_code,
|
|
||||||
content=response.json()
|
|
||||||
)
|
|
||||||
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.error(f"Forecasting service unavailable: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503,
|
|
||||||
detail="Forecasting service unavailable"
|
|
||||||
)
|
|
||||||
|
|||||||
78
infrastructure/kubernetes/base/forecasting-service.yaml
Normal file
78
infrastructure/kubernetes/base/forecasting-service.yaml
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
# ================================================================
|
||||||
|
# Kubernetes Deployment: infrastructure/kubernetes/base/forecasting-service.yaml
|
||||||
|
# ================================================================
|
||||||
|
apiVersion: apps/v1
|
||||||
|
kind: Deployment
|
||||||
|
metadata:
|
||||||
|
name: forecasting-service
|
||||||
|
labels:
|
||||||
|
app: forecasting-service
|
||||||
|
spec:
|
||||||
|
replicas: 2
|
||||||
|
selector:
|
||||||
|
matchLabels:
|
||||||
|
app: forecasting-service
|
||||||
|
template:
|
||||||
|
metadata:
|
||||||
|
labels:
|
||||||
|
app: forecasting-service
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: forecasting-service
|
||||||
|
image: bakery-forecasting/forecasting-service:latest
|
||||||
|
ports:
|
||||||
|
- containerPort: 8000
|
||||||
|
env:
|
||||||
|
- name: DATABASE_URL
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: forecasting-db-secret
|
||||||
|
key: database-url
|
||||||
|
- name: RABBITMQ_URL
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: rabbitmq-secret
|
||||||
|
key: url
|
||||||
|
- name: REDIS_URL
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: redis-secret
|
||||||
|
key: url
|
||||||
|
- name: TRAINING_SERVICE_URL
|
||||||
|
value: "http://training-service:8000"
|
||||||
|
- name: DATA_SERVICE_URL
|
||||||
|
value: "http://data-service:8000"
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
memory: "512Mi"
|
||||||
|
cpu: "250m"
|
||||||
|
limits:
|
||||||
|
memory: "1Gi"
|
||||||
|
cpu: "500m"
|
||||||
|
livenessProbe:
|
||||||
|
httpGet:
|
||||||
|
path: /health
|
||||||
|
port: 8000
|
||||||
|
initialDelaySeconds: 30
|
||||||
|
periodSeconds: 10
|
||||||
|
readinessProbe:
|
||||||
|
httpGet:
|
||||||
|
path: /health
|
||||||
|
port: 8000
|
||||||
|
initialDelaySeconds: 5
|
||||||
|
periodSeconds: 5
|
||||||
|
---
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: forecasting-service
|
||||||
|
labels:
|
||||||
|
app: forecasting-service
|
||||||
|
spec:
|
||||||
|
selector:
|
||||||
|
app: forecasting-service
|
||||||
|
ports:
|
||||||
|
- port: 8000
|
||||||
|
targetPort: 8000
|
||||||
|
type: ClusterIP
|
||||||
|
|
||||||
42
infrastructure/monitoring/prometheus/forecasting-service.yml
Normal file
42
infrastructure/monitoring/prometheus/forecasting-service.yml
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# ================================================================
|
||||||
|
# Monitoring Configuration: infrastructure/monitoring/prometheus/forecasting-service.yml
|
||||||
|
# ================================================================
|
||||||
|
groups:
|
||||||
|
- name: forecasting-service
|
||||||
|
rules:
|
||||||
|
- alert: ForecastingServiceDown
|
||||||
|
expr: up{job="forecasting-service"} == 0
|
||||||
|
for: 1m
|
||||||
|
labels:
|
||||||
|
severity: critical
|
||||||
|
annotations:
|
||||||
|
summary: "Forecasting service is down"
|
||||||
|
description: "Forecasting service has been down for more than 1 minute"
|
||||||
|
|
||||||
|
- alert: HighForecastingLatency
|
||||||
|
expr: histogram_quantile(0.95, forecast_processing_time_seconds) > 10
|
||||||
|
for: 5m
|
||||||
|
labels:
|
||||||
|
severity: warning
|
||||||
|
annotations:
|
||||||
|
summary: "High forecasting latency"
|
||||||
|
description: "95th percentile forecasting latency is {{ $value }}s"
|
||||||
|
|
||||||
|
- alert: ForecastingErrorRate
|
||||||
|
expr: rate(forecasting_errors_total[5m]) > 0.1
|
||||||
|
for: 5m
|
||||||
|
labels:
|
||||||
|
severity: critical
|
||||||
|
annotations:
|
||||||
|
summary: "High forecasting error rate"
|
||||||
|
description: "Forecasting error rate is {{ $value }} errors/sec"
|
||||||
|
|
||||||
|
- alert: LowModelAccuracy
|
||||||
|
expr: avg(model_accuracy_score) < 0.7
|
||||||
|
for: 10m
|
||||||
|
labels:
|
||||||
|
severity: warning
|
||||||
|
annotations:
|
||||||
|
summary: "Low model accuracy detected"
|
||||||
|
description: "Average model accuracy is {{ $value }}"
|
||||||
|
|
||||||
169
services/forecasting/README.md
Normal file
169
services/forecasting/README.md
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
================================================================
|
||||||
|
# Documentation: services/forecasting/README.md
|
||||||
|
# ================================================================
|
||||||
|
|
||||||
|
# Forecasting Service
|
||||||
|
|
||||||
|
AI-powered demand prediction service for bakery operations in Madrid, Spain.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The Forecasting Service is a specialized microservice responsible for generating accurate demand predictions for bakery products. It integrates trained ML models with real-time weather and traffic data to provide actionable forecasts for business planning.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
### Core Functionality
|
||||||
|
- **Single Product Forecasting**: Generate predictions for individual products
|
||||||
|
- **Batch Forecasting**: Process multiple products and time periods
|
||||||
|
- **Real-time Predictions**: On-demand forecasting with external data
|
||||||
|
- **Business Rules**: Spanish bakery-specific adjustments
|
||||||
|
- **Alert System**: Automated notifications for demand anomalies
|
||||||
|
|
||||||
|
### Integration Points
|
||||||
|
- **Training Service**: Loads trained Prophet models
|
||||||
|
- **Data Service**: Retrieves weather and traffic data
|
||||||
|
- **Notification Service**: Sends alerts and reports
|
||||||
|
- **Gateway Service**: Authentication and request routing
|
||||||
|
|
||||||
|
## API Endpoints
|
||||||
|
|
||||||
|
### Forecasts
|
||||||
|
- `POST /api/v1/forecasts/single` - Generate single forecast
|
||||||
|
- `POST /api/v1/forecasts/batch` - Generate batch forecasts
|
||||||
|
- `GET /api/v1/forecasts/list` - List historical forecasts
|
||||||
|
- `GET /api/v1/forecasts/alerts` - Get forecast alerts
|
||||||
|
- `PUT /api/v1/forecasts/alerts/{id}/acknowledge` - Acknowledge alert
|
||||||
|
|
||||||
|
### Predictions
|
||||||
|
- `POST /api/v1/predictions/realtime` - Real-time prediction
|
||||||
|
- `GET /api/v1/predictions/quick/{product}` - Quick multi-day forecast
|
||||||
|
|
||||||
|
## Business Logic
|
||||||
|
|
||||||
|
### Spanish Bakery Rules
|
||||||
|
- **Siesta Impact**: Reduced afternoon activity consideration
|
||||||
|
- **Weather Adjustments**: Rain reduces traffic, extreme temperatures affect product mix
|
||||||
|
- **Holiday Handling**: Spanish holiday calendar integration
|
||||||
|
- **Weekend Patterns**: Different demand patterns for weekends
|
||||||
|
|
||||||
|
### Business Types
|
||||||
|
- **Individual Bakery**: Single location with direct sales
|
||||||
|
- **Central Workshop**: Production facility supplying multiple locations
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
```bash
|
||||||
|
# Database
|
||||||
|
DATABASE_URL=postgresql+asyncpg://user:pass@host:port/db
|
||||||
|
|
||||||
|
# External Services
|
||||||
|
TRAINING_SERVICE_URL=http://training-service:8000
|
||||||
|
DATA_SERVICE_URL=http://data-service:8000
|
||||||
|
|
||||||
|
# Business Rules
|
||||||
|
WEEKEND_ADJUSTMENT_FACTOR=0.8
|
||||||
|
HOLIDAY_ADJUSTMENT_FACTOR=0.5
|
||||||
|
RAIN_IMPACT_FACTOR=0.7
|
||||||
|
```
|
||||||
|
|
||||||
|
### Performance Settings
|
||||||
|
```bash
|
||||||
|
MAX_FORECAST_DAYS=30
|
||||||
|
PREDICTION_CACHE_TTL_HOURS=6
|
||||||
|
FORECAST_BATCH_SIZE=100
|
||||||
|
```
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
```bash
|
||||||
|
cd services/forecasting
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
```bash
|
||||||
|
pytest tests/ -v --cov=app
|
||||||
|
```
|
||||||
|
|
||||||
|
### Running Locally
|
||||||
|
```bash
|
||||||
|
uvicorn app.main:app --reload --port 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
## Deployment
|
||||||
|
|
||||||
|
### Docker
|
||||||
|
```bash
|
||||||
|
docker build -t forecasting-service .
|
||||||
|
docker run -p 8000:8000 forecasting-service
|
||||||
|
```
|
||||||
|
|
||||||
|
### Kubernetes
|
||||||
|
```bash
|
||||||
|
kubectl apply -f infrastructure/kubernetes/base/forecasting-service.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Monitoring
|
||||||
|
|
||||||
|
### Metrics
|
||||||
|
- `forecasts_generated_total` - Total forecasts generated
|
||||||
|
- `predictions_served_total` - Total predictions served
|
||||||
|
- `forecast_processing_time_seconds` - Processing time histogram
|
||||||
|
- `active_models_count` - Number of active models
|
||||||
|
|
||||||
|
### Health Checks
|
||||||
|
- `/health` - Service health status
|
||||||
|
- `/metrics` - Prometheus metrics endpoint
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
### Benchmarks
|
||||||
|
- **Single Forecast**: < 2 seconds average
|
||||||
|
- **Batch Forecasting**: 100 products in < 30 seconds
|
||||||
|
- **Concurrent Load**: 95%+ success rate at 20 concurrent requests
|
||||||
|
|
||||||
|
### Optimization
|
||||||
|
- Model caching for faster predictions
|
||||||
|
- Feature preparation optimization
|
||||||
|
- Database query optimization
|
||||||
|
- Asynchronous external API calls
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
1. **No Model Found Error**
|
||||||
|
- Ensure training service has models for tenant/product
|
||||||
|
- Check model training logs in training service
|
||||||
|
|
||||||
|
2. **High Prediction Latency**
|
||||||
|
- Monitor model cache hit rate
|
||||||
|
- Check external service response times
|
||||||
|
- Review database query performance
|
||||||
|
|
||||||
|
3. **Inaccurate Predictions**
|
||||||
|
- Verify external data quality (weather/traffic)
|
||||||
|
- Check model performance metrics
|
||||||
|
- Review business rule configurations
|
||||||
|
|
||||||
|
### Logging
|
||||||
|
```bash
|
||||||
|
# View service logs
|
||||||
|
docker logs forecasting-service
|
||||||
|
|
||||||
|
# Debug level logging
|
||||||
|
LOG_LEVEL=DEBUG uvicorn app.main:app
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
1. Follow the existing code structure and patterns
|
||||||
|
2. Add tests for new functionality
|
||||||
|
3. Update documentation for API changes
|
||||||
|
4. Ensure performance benchmarks are maintained
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This service is part of the Bakery Forecasting Platform - MIT License
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
|
|
||||||
from typing import List, Optional, Dict, Any
|
|
||||||
from datetime import datetime, date
|
|
||||||
import structlog
|
|
||||||
|
|
||||||
from app.schemas.forecast import (
|
|
||||||
ForecastRequest,
|
|
||||||
ForecastResponse,
|
|
||||||
BatchForecastRequest,
|
|
||||||
ForecastPerformanceResponse
|
|
||||||
)
|
|
||||||
from app.services.forecast_service import ForecastService
|
|
||||||
from app.services.messaging import publish_forecast_generated
|
|
||||||
|
|
||||||
# Import unified authentication
|
|
||||||
from shared.auth.decorators import (
|
|
||||||
get_current_user_dep,
|
|
||||||
get_current_tenant_id_dep
|
|
||||||
)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/forecasts", tags=["forecasting"])
|
|
||||||
logger = structlog.get_logger()
|
|
||||||
|
|
||||||
@router.post("/generate", response_model=ForecastResponse)
|
|
||||||
async def generate_forecast(
|
|
||||||
request: ForecastRequest,
|
|
||||||
background_tasks: BackgroundTasks,
|
|
||||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
|
||||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
|
||||||
):
|
|
||||||
"""Generate forecast for products"""
|
|
||||||
try:
|
|
||||||
logger.info("Generating forecast",
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
user_id=current_user["user_id"],
|
|
||||||
products=len(request.products) if request.products else "all")
|
|
||||||
|
|
||||||
forecast_service = ForecastService()
|
|
||||||
|
|
||||||
# Ensure products belong to tenant
|
|
||||||
if request.products:
|
|
||||||
valid_products = await forecast_service.validate_products(
|
|
||||||
tenant_id, request.products
|
|
||||||
)
|
|
||||||
if len(valid_products) != len(request.products):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Some products not found or not accessible"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate forecast
|
|
||||||
forecast = await forecast_service.generate_forecast(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
request=request,
|
|
||||||
user_id=current_user["user_id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Publish event
|
|
||||||
background_tasks.add_task(
|
|
||||||
publish_forecast_generated,
|
|
||||||
forecast_id=forecast.id,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
user_id=current_user["user_id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return forecast
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to generate forecast", error=str(e))
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
326
services/forecasting/app/api/forecasts.py
Normal file
326
services/forecasting/app/api/forecasts.py
Normal file
@@ -0,0 +1,326 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/api/forecasts.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Forecast API endpoints
|
||||||
|
"""
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from typing import List, Optional
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.auth import get_current_user_from_headers
|
||||||
|
from app.services.forecasting_service import ForecastingService
|
||||||
|
from app.schemas.forecasts import (
|
||||||
|
ForecastRequest, ForecastResponse, BatchForecastRequest,
|
||||||
|
BatchForecastResponse, AlertResponse
|
||||||
|
)
|
||||||
|
from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# Initialize service
|
||||||
|
forecasting_service = ForecastingService()
|
||||||
|
|
||||||
|
@router.post("/single", response_model=ForecastResponse)
|
||||||
|
async def create_single_forecast(
|
||||||
|
request: ForecastRequest,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: dict = Depends(get_current_user_from_headers)
|
||||||
|
):
|
||||||
|
"""Generate a single product forecast"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify tenant access
|
||||||
|
if str(request.tenant_id) != str(current_user.get("tenant_id")):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Access denied to this tenant"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate forecast
|
||||||
|
forecast = await forecasting_service.generate_forecast(request, db)
|
||||||
|
|
||||||
|
# Convert to response model
|
||||||
|
return ForecastResponse(
|
||||||
|
id=str(forecast.id),
|
||||||
|
tenant_id=str(forecast.tenant_id),
|
||||||
|
product_name=forecast.product_name,
|
||||||
|
location=forecast.location,
|
||||||
|
forecast_date=forecast.forecast_date,
|
||||||
|
predicted_demand=forecast.predicted_demand,
|
||||||
|
confidence_lower=forecast.confidence_lower,
|
||||||
|
confidence_upper=forecast.confidence_upper,
|
||||||
|
confidence_level=forecast.confidence_level,
|
||||||
|
model_id=str(forecast.model_id),
|
||||||
|
model_version=forecast.model_version,
|
||||||
|
algorithm=forecast.algorithm,
|
||||||
|
business_type=forecast.business_type,
|
||||||
|
is_holiday=forecast.is_holiday,
|
||||||
|
is_weekend=forecast.is_weekend,
|
||||||
|
day_of_week=forecast.day_of_week,
|
||||||
|
weather_temperature=forecast.weather_temperature,
|
||||||
|
weather_precipitation=forecast.weather_precipitation,
|
||||||
|
weather_description=forecast.weather_description,
|
||||||
|
traffic_volume=forecast.traffic_volume,
|
||||||
|
created_at=forecast.created_at,
|
||||||
|
processing_time_ms=forecast.processing_time_ms,
|
||||||
|
features_used=forecast.features_used
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error creating single forecast", error=str(e))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Internal server error"
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post("/batch", response_model=BatchForecastResponse)
|
||||||
|
async def create_batch_forecast(
|
||||||
|
request: BatchForecastRequest,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: dict = Depends(get_current_user_from_headers)
|
||||||
|
):
|
||||||
|
"""Generate batch forecasts for multiple products"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify tenant access
|
||||||
|
if str(request.tenant_id) != str(current_user.get("tenant_id")):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Access denied to this tenant"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate batch forecast
|
||||||
|
batch = await forecasting_service.generate_batch_forecast(request, db)
|
||||||
|
|
||||||
|
# Get associated forecasts
|
||||||
|
forecasts = await forecasting_service.get_forecasts(
|
||||||
|
tenant_id=request.tenant_id,
|
||||||
|
location=request.location,
|
||||||
|
db=db
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert forecasts to response models
|
||||||
|
forecast_responses = []
|
||||||
|
for forecast in forecasts[:batch.total_products]: # Limit to batch size
|
||||||
|
forecast_responses.append(ForecastResponse(
|
||||||
|
id=str(forecast.id),
|
||||||
|
tenant_id=str(forecast.tenant_id),
|
||||||
|
product_name=forecast.product_name,
|
||||||
|
location=forecast.location,
|
||||||
|
forecast_date=forecast.forecast_date,
|
||||||
|
predicted_demand=forecast.predicted_demand,
|
||||||
|
confidence_lower=forecast.confidence_lower,
|
||||||
|
confidence_upper=forecast.confidence_upper,
|
||||||
|
confidence_level=forecast.confidence_level,
|
||||||
|
model_id=str(forecast.model_id),
|
||||||
|
model_version=forecast.model_version,
|
||||||
|
algorithm=forecast.algorithm,
|
||||||
|
business_type=forecast.business_type,
|
||||||
|
is_holiday=forecast.is_holiday,
|
||||||
|
is_weekend=forecast.is_weekend,
|
||||||
|
day_of_week=forecast.day_of_week,
|
||||||
|
weather_temperature=forecast.weather_temperature,
|
||||||
|
weather_precipitation=forecast.weather_precipitation,
|
||||||
|
weather_description=forecast.weather_description,
|
||||||
|
traffic_volume=forecast.traffic_volume,
|
||||||
|
created_at=forecast.created_at,
|
||||||
|
processing_time_ms=forecast.processing_time_ms,
|
||||||
|
features_used=forecast.features_used
|
||||||
|
))
|
||||||
|
|
||||||
|
return BatchForecastResponse(
|
||||||
|
id=str(batch.id),
|
||||||
|
tenant_id=str(batch.tenant_id),
|
||||||
|
batch_name=batch.batch_name,
|
||||||
|
status=batch.status,
|
||||||
|
total_products=batch.total_products,
|
||||||
|
completed_products=batch.completed_products,
|
||||||
|
failed_products=batch.failed_products,
|
||||||
|
requested_at=batch.requested_at,
|
||||||
|
completed_at=batch.completed_at,
|
||||||
|
processing_time_ms=batch.processing_time_ms,
|
||||||
|
forecasts=forecast_responses
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error creating batch forecast", error=str(e))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Internal server error"
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/list", response_model=List[ForecastResponse])
|
||||||
|
async def list_forecasts(
|
||||||
|
location: str,
|
||||||
|
start_date: Optional[date] = Query(None),
|
||||||
|
end_date: Optional[date] = Query(None),
|
||||||
|
product_name: Optional[str] = Query(None),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: dict = Depends(get_current_user_from_headers)
|
||||||
|
):
|
||||||
|
"""List forecasts with filtering"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
tenant_id = str(current_user.get("tenant_id"))
|
||||||
|
|
||||||
|
# Get forecasts
|
||||||
|
forecasts = await forecasting_service.get_forecasts(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
location=location,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
product_name=product_name,
|
||||||
|
db=db
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to response models
|
||||||
|
return [
|
||||||
|
ForecastResponse(
|
||||||
|
id=str(forecast.id),
|
||||||
|
tenant_id=str(forecast.tenant_id),
|
||||||
|
product_name=forecast.product_name,
|
||||||
|
location=forecast.location,
|
||||||
|
forecast_date=forecast.forecast_date,
|
||||||
|
predicted_demand=forecast.predicted_demand,
|
||||||
|
confidence_lower=forecast.confidence_lower,
|
||||||
|
confidence_upper=forecast.confidence_upper,
|
||||||
|
confidence_level=forecast.confidence_level,
|
||||||
|
model_id=str(forecast.model_id),
|
||||||
|
model_version=forecast.model_version,
|
||||||
|
algorithm=forecast.algorithm,
|
||||||
|
business_type=forecast.business_type,
|
||||||
|
is_holiday=forecast.is_holiday,
|
||||||
|
is_weekend=forecast.is_weekend,
|
||||||
|
day_of_week=forecast.day_of_week,
|
||||||
|
weather_temperature=forecast.weather_temperature,
|
||||||
|
weather_precipitation=forecast.weather_precipitation,
|
||||||
|
weather_description=forecast.weather_description,
|
||||||
|
traffic_volume=forecast.traffic_volume,
|
||||||
|
created_at=forecast.created_at,
|
||||||
|
processing_time_ms=forecast.processing_time_ms,
|
||||||
|
features_used=forecast.features_used
|
||||||
|
)
|
||||||
|
for forecast in forecasts
|
||||||
|
]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error listing forecasts", error=str(e))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Internal server error"
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/alerts", response_model=List[AlertResponse])
|
||||||
|
async def get_forecast_alerts(
|
||||||
|
active_only: bool = Query(True),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: dict = Depends(get_current_user_from_headers)
|
||||||
|
):
|
||||||
|
"""Get forecast alerts for tenant"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select, and_
|
||||||
|
|
||||||
|
tenant_id = current_user.get("tenant_id")
|
||||||
|
|
||||||
|
# Build query
|
||||||
|
query = select(ForecastAlert).where(
|
||||||
|
ForecastAlert.tenant_id == tenant_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if active_only:
|
||||||
|
query = query.where(ForecastAlert.is_active == True)
|
||||||
|
|
||||||
|
query = query.order_by(ForecastAlert.created_at.desc())
|
||||||
|
|
||||||
|
# Execute query
|
||||||
|
result = await db.execute(query)
|
||||||
|
alerts = result.scalars().all()
|
||||||
|
|
||||||
|
# Convert to response models
|
||||||
|
return [
|
||||||
|
AlertResponse(
|
||||||
|
id=str(alert.id),
|
||||||
|
tenant_id=str(alert.tenant_id),
|
||||||
|
forecast_id=str(alert.forecast_id),
|
||||||
|
alert_type=alert.alert_type,
|
||||||
|
severity=alert.severity,
|
||||||
|
message=alert.message,
|
||||||
|
is_active=alert.is_active,
|
||||||
|
created_at=alert.created_at,
|
||||||
|
acknowledged_at=alert.acknowledged_at,
|
||||||
|
notification_sent=alert.notification_sent
|
||||||
|
)
|
||||||
|
for alert in alerts
|
||||||
|
]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error getting forecast alerts", error=str(e))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Internal server error"
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.put("/alerts/{alert_id}/acknowledge")
|
||||||
|
async def acknowledge_alert(
|
||||||
|
alert_id: str,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: dict = Depends(get_current_user_from_headers)
|
||||||
|
):
|
||||||
|
"""Acknowledge a forecast alert"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
tenant_id = current_user.get("tenant_id")
|
||||||
|
|
||||||
|
# Get alert
|
||||||
|
result = await db.execute(
|
||||||
|
select(ForecastAlert).where(
|
||||||
|
and_(
|
||||||
|
ForecastAlert.id == alert_id,
|
||||||
|
ForecastAlert.tenant_id == tenant_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
alert = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not alert:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Alert not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update alert
|
||||||
|
alert.acknowledged_at = datetime.now()
|
||||||
|
alert.is_active = False
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"message": "Alert acknowledged successfully"}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error acknowledging alert", error=str(e))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Internal server error"
|
||||||
|
)
|
||||||
141
services/forecasting/app/api/predictions.py
Normal file
141
services/forecasting/app/api/predictions.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/api/predictions.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Prediction API endpoints - Real-time prediction capabilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from datetime import date, datetime, timedelta
|
||||||
|
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.auth import get_current_user_from_headers
|
||||||
|
from app.services.prediction_service import PredictionService
|
||||||
|
from app.schemas.forecasts import ForecastRequest
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# Initialize service
|
||||||
|
prediction_service = PredictionService()
|
||||||
|
|
||||||
|
@router.post("/realtime")
|
||||||
|
async def get_realtime_prediction(
|
||||||
|
product_name: str,
|
||||||
|
location: str,
|
||||||
|
forecast_date: date,
|
||||||
|
features: Dict[str, Any],
|
||||||
|
current_user: dict = Depends(get_current_user_from_headers)
|
||||||
|
):
|
||||||
|
"""Get real-time prediction without storing in database"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
tenant_id = str(current_user.get("tenant_id"))
|
||||||
|
|
||||||
|
# Get latest model
|
||||||
|
from app.services.forecasting_service import ForecastingService
|
||||||
|
forecasting_service = ForecastingService()
|
||||||
|
|
||||||
|
model_info = await forecasting_service._get_latest_model(
|
||||||
|
tenant_id, product_name, location
|
||||||
|
)
|
||||||
|
|
||||||
|
if not model_info:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"No trained model found for {product_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate prediction
|
||||||
|
prediction = await prediction_service.predict(
|
||||||
|
model_id=model_info["model_id"],
|
||||||
|
features=features,
|
||||||
|
confidence_level=0.8
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"product_name": product_name,
|
||||||
|
"location": location,
|
||||||
|
"forecast_date": forecast_date,
|
||||||
|
"predicted_demand": prediction["demand"],
|
||||||
|
"confidence_lower": prediction["lower_bound"],
|
||||||
|
"confidence_upper": prediction["upper_bound"],
|
||||||
|
"model_id": model_info["model_id"],
|
||||||
|
"model_version": model_info["version"],
|
||||||
|
"generated_at": datetime.now(),
|
||||||
|
"features_used": features
|
||||||
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error getting realtime prediction", error=str(e))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Internal server error"
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/quick/{product_name}")
|
||||||
|
async def get_quick_prediction(
|
||||||
|
product_name: str,
|
||||||
|
location: str = Query(...),
|
||||||
|
days_ahead: int = Query(1, ge=1, le=7),
|
||||||
|
current_user: dict = Depends(get_current_user_from_headers)
|
||||||
|
):
|
||||||
|
"""Get quick prediction for next few days"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
tenant_id = str(current_user.get("tenant_id"))
|
||||||
|
|
||||||
|
# Generate predictions for the next N days
|
||||||
|
predictions = []
|
||||||
|
|
||||||
|
for day in range(1, days_ahead + 1):
|
||||||
|
forecast_date = date.today() + timedelta(days=day)
|
||||||
|
|
||||||
|
# Prepare basic features
|
||||||
|
features = {
|
||||||
|
"date": forecast_date.isoformat(),
|
||||||
|
"day_of_week": forecast_date.weekday(),
|
||||||
|
"is_weekend": forecast_date.weekday() >= 5,
|
||||||
|
"business_type": "individual"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get model and predict
|
||||||
|
from app.services.forecasting_service import ForecastingService
|
||||||
|
forecasting_service = ForecastingService()
|
||||||
|
|
||||||
|
model_info = await forecasting_service._get_latest_model(
|
||||||
|
tenant_id, product_name, location
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_info:
|
||||||
|
prediction = await prediction_service.predict(
|
||||||
|
model_id=model_info["model_id"],
|
||||||
|
features=features
|
||||||
|
)
|
||||||
|
|
||||||
|
predictions.append({
|
||||||
|
"date": forecast_date,
|
||||||
|
"predicted_demand": prediction["demand"],
|
||||||
|
"confidence_lower": prediction["lower_bound"],
|
||||||
|
"confidence_upper": prediction["upper_bound"]
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"product_name": product_name,
|
||||||
|
"location": location,
|
||||||
|
"predictions": predictions,
|
||||||
|
"generated_at": datetime.now()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error getting quick prediction", error=str(e))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Internal server error"
|
||||||
|
)
|
||||||
|
|
||||||
@@ -1,22 +1,48 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/core/auth.py
|
||||||
|
# ================================================================
|
||||||
"""
|
"""
|
||||||
Authentication configuration for forecasting service
|
Authentication utilities for forecasting service
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from shared.auth.jwt_handler import JWTHandler
|
import structlog
|
||||||
from shared.auth.decorators import require_auth, require_role
|
from fastapi import HTTPException, status, Request
|
||||||
from app.core.config import settings
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
# Initialize JWT handler
|
logger = structlog.get_logger()
|
||||||
jwt_handler = JWTHandler(
|
|
||||||
secret_key=settings.JWT_SECRET_KEY,
|
|
||||||
algorithm=settings.JWT_ALGORITHM,
|
|
||||||
access_token_expire_minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
|
|
||||||
)
|
|
||||||
|
|
||||||
# Export commonly used functions
|
async def get_current_user_from_headers(request: Request) -> Dict[str, Any]:
|
||||||
verify_token = jwt_handler.verify_token
|
"""
|
||||||
create_access_token = jwt_handler.create_access_token
|
Get current user from gateway headers
|
||||||
get_current_user = jwt_handler.get_current_user
|
Gateway middleware adds user context to headers after JWT verification
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Extract user information from headers set by API Gateway
|
||||||
|
user_id = request.headers.get("X-User-ID")
|
||||||
|
user_email = request.headers.get("X-User-Email")
|
||||||
|
tenant_id = request.headers.get("X-Tenant-ID")
|
||||||
|
user_roles = request.headers.get("X-User-Roles", "").split(",")
|
||||||
|
|
||||||
|
if not user_id or not tenant_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Authentication required"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user_id": user_id,
|
||||||
|
"email": user_email,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"roles": [role.strip() for role in user_roles if role.strip()]
|
||||||
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error extracting user from headers", error=str(e))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid authentication"
|
||||||
|
)
|
||||||
|
|
||||||
# Export decorators
|
|
||||||
__all__ = ['verify_token', 'create_access_token', 'get_current_user', 'require_auth', 'require_role']
|
|
||||||
|
|||||||
@@ -1,12 +1,73 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/core/database.py
|
||||||
|
# ================================================================
|
||||||
"""
|
"""
|
||||||
Database configuration for forecasting service
|
Database configuration for forecasting service
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from shared.database.base import DatabaseManager
|
import structlog
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||||
|
from sqlalchemy.pool import NullPool
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from shared.database.base import Base
|
||||||
|
|
||||||
# Initialize database manager
|
logger = structlog.get_logger()
|
||||||
database_manager = DatabaseManager(settings.DATABASE_URL)
|
|
||||||
|
|
||||||
# Alias for convenience
|
# Create async engine
|
||||||
get_db = database_manager.get_db
|
engine = create_async_engine(
|
||||||
|
settings.DATABASE_URL,
|
||||||
|
poolclass=NullPool,
|
||||||
|
echo=settings.DEBUG,
|
||||||
|
future=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create session factory
|
||||||
|
AsyncSessionLocal = async_sessionmaker(
|
||||||
|
engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
autoflush=False,
|
||||||
|
autocommit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
class DatabaseManager:
|
||||||
|
"""Database management operations"""
|
||||||
|
|
||||||
|
async def create_tables(self):
|
||||||
|
"""Create database tables"""
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
logger.info("Forecasting database tables created successfully")
|
||||||
|
|
||||||
|
async def get_session(self) -> AsyncSession:
|
||||||
|
"""Get database session"""
|
||||||
|
async with AsyncSessionLocal() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
await session.rollback()
|
||||||
|
logger.error(f"Database session error: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
# Global database manager instance
|
||||||
|
database_manager = DatabaseManager()
|
||||||
|
|
||||||
|
async def get_db() -> AsyncSession:
|
||||||
|
"""Database dependency"""
|
||||||
|
async for session in database_manager.get_session():
|
||||||
|
yield session
|
||||||
|
|
||||||
|
async def get_db_health() -> bool:
|
||||||
|
"""Check database health"""
|
||||||
|
try:
|
||||||
|
async with AsyncSessionLocal() as session:
|
||||||
|
await session.execute(text("SELECT 1"))
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Database health check failed: {e}")
|
||||||
|
return False
|
||||||
|
|||||||
@@ -1,61 +1,116 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/main.py
|
||||||
|
# ================================================================
|
||||||
"""
|
"""
|
||||||
uLuforecasting Service
|
Forecasting Service Main Application
|
||||||
|
Demand prediction and forecasting service for bakery operations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from fastapi import FastAPI
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import database_manager
|
from app.core.database import database_manager, get_db_health
|
||||||
|
from app.api import forecasts, predictions
|
||||||
|
from app.services.messaging import setup_messaging, cleanup_messaging
|
||||||
from shared.monitoring.logging import setup_logging
|
from shared.monitoring.logging import setup_logging
|
||||||
from shared.monitoring.metrics import MetricsCollector
|
from shared.monitoring.metrics import MetricsCollector
|
||||||
|
|
||||||
# Setup logging
|
# Setup structured logging
|
||||||
setup_logging("forecasting-service", "INFO")
|
setup_logging("forecasting-service", settings.LOG_LEVEL)
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
# Create FastAPI app
|
|
||||||
app = FastAPI(
|
|
||||||
title="uLuforecasting Service",
|
|
||||||
description="uLuforecasting service for bakery forecasting",
|
|
||||||
version="1.0.0"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize metrics collector
|
# Initialize metrics collector
|
||||||
metrics_collector = MetricsCollector("forecasting-service")
|
metrics_collector = MetricsCollector("forecasting-service")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""Application lifespan manager for startup and shutdown events"""
|
||||||
|
# Startup
|
||||||
|
logger.info("Starting Forecasting Service", version="1.0.0")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize database
|
||||||
|
logger.info("Initializing database connection")
|
||||||
|
await database_manager.create_tables()
|
||||||
|
logger.info("Database initialized successfully")
|
||||||
|
|
||||||
|
# Initialize messaging
|
||||||
|
logger.info("Setting up messaging")
|
||||||
|
await setup_messaging()
|
||||||
|
logger.info("Messaging initialized")
|
||||||
|
|
||||||
|
# Register custom metrics
|
||||||
|
metrics_collector.register_counter("forecasts_generated_total", "Total forecasts generated")
|
||||||
|
metrics_collector.register_counter("predictions_served_total", "Total predictions served")
|
||||||
|
metrics_collector.register_histogram("forecast_processing_time_seconds", "Time to process forecast request")
|
||||||
|
metrics_collector.register_gauge("active_models_count", "Number of active models")
|
||||||
|
|
||||||
|
# Start metrics server
|
||||||
|
metrics_collector.start_metrics_server(8080)
|
||||||
|
|
||||||
|
logger.info("Forecasting Service started successfully")
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to start Forecasting Service", error=str(e))
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Shutdown
|
||||||
|
logger.info("Shutting down Forecasting Service")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await cleanup_messaging()
|
||||||
|
logger.info("Messaging cleanup completed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error during messaging cleanup", error=str(e))
|
||||||
|
|
||||||
|
# Create FastAPI app with lifespan
|
||||||
|
app = FastAPI(
|
||||||
|
title="Bakery Forecasting Service",
|
||||||
|
description="AI-powered demand prediction and forecasting service for bakery operations",
|
||||||
|
version="1.0.0",
|
||||||
|
docs_url="/docs",
|
||||||
|
redoc_url="/redoc",
|
||||||
|
lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
# CORS middleware
|
# CORS middleware
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"],
|
allow_origins=settings.CORS_ORIGINS_LIST,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.on_event("startup")
|
# Include API routers
|
||||||
async def startup_event():
|
app.include_router(forecasts.router, prefix="/api/v1/forecasts", tags=["forecasts"])
|
||||||
"""Application startup"""
|
app.include_router(predictions.router, prefix="/api/v1/predictions", tags=["predictions"])
|
||||||
logger.info("Starting uLuforecasting Service")
|
|
||||||
|
|
||||||
# Create database tables
|
|
||||||
await database_manager.create_tables()
|
|
||||||
|
|
||||||
# Start metrics server
|
|
||||||
metrics_collector.start_metrics_server(8080)
|
|
||||||
|
|
||||||
logger.info("uLuforecasting Service started successfully")
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""Health check endpoint"""
|
"""Health check endpoint"""
|
||||||
|
db_health = await get_db_health()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "healthy",
|
"status": "healthy" if db_health else "unhealthy",
|
||||||
"service": "forecasting-service",
|
"service": "forecasting-service",
|
||||||
"version": "1.0.0"
|
"version": "1.0.0",
|
||||||
|
"database": "connected" if db_health else "disconnected",
|
||||||
|
"timestamp": structlog.get_logger().info("Health check requested")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@app.get("/metrics")
|
||||||
|
async def get_metrics():
|
||||||
|
"""Metrics endpoint for Prometheus"""
|
||||||
|
return metrics_collector.generate_latest()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
|
|
||||||
|
|||||||
101
services/forecasting/app/ml/model_loader.py
Normal file
101
services/forecasting/app/ml/model_loader.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/ml/model_loader.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Model loading and management utilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
import pickle
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
class ModelLoader:
|
||||||
|
"""
|
||||||
|
Utility class for loading and managing ML models
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.model_cache = {}
|
||||||
|
self.metadata_cache = {}
|
||||||
|
|
||||||
|
async def load_model_with_metadata(self, model_id: str) -> Dict[str, Any]:
|
||||||
|
"""Load model along with its metadata"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get model metadata first
|
||||||
|
metadata = await self._get_model_metadata(model_id)
|
||||||
|
|
||||||
|
if not metadata:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Load the actual model
|
||||||
|
model = await self._load_model_binary(model_id)
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"model": model,
|
||||||
|
"metadata": metadata,
|
||||||
|
"loaded_at": datetime.now()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error loading model with metadata",
|
||||||
|
model_id=model_id,
|
||||||
|
error=str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get_model_metadata(self, model_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get model metadata from training service"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/metadata",
|
||||||
|
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.json()
|
||||||
|
else:
|
||||||
|
logger.warning("Model metadata not found",
|
||||||
|
model_id=model_id,
|
||||||
|
status_code=response.status_code)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error getting model metadata", error=str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _load_model_binary(self, model_id: str):
|
||||||
|
"""Load model binary from training service"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/download",
|
||||||
|
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
model = pickle.loads(response.content)
|
||||||
|
return model
|
||||||
|
else:
|
||||||
|
logger.error("Failed to download model binary",
|
||||||
|
model_id=model_id,
|
||||||
|
status_code=response.status_code)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error loading model binary", error=str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
305
services/forecasting/app/ml/predictor.py
Normal file
305
services/forecasting/app/ml/predictor.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/ml/predictor.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Enhanced predictor module with advanced forecasting capabilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from typing import Dict, List, Any, Optional, Tuple
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime, date, timedelta
|
||||||
|
import pickle
|
||||||
|
import json
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from shared.monitoring.metrics import MetricsCollector
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
metrics = MetricsCollector("forecasting-service")
|
||||||
|
|
||||||
|
class BakeryPredictor:
|
||||||
|
"""
|
||||||
|
Advanced predictor for bakery demand forecasting
|
||||||
|
Handles Prophet models and business-specific logic
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.model_cache = {}
|
||||||
|
self.business_rules = BakeryBusinessRules()
|
||||||
|
|
||||||
|
async def predict_demand(self, model, features: Dict[str, Any],
|
||||||
|
business_type: str = "individual") -> Dict[str, float]:
|
||||||
|
"""Generate demand prediction with business rules applied"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate base prediction
|
||||||
|
base_prediction = await self._generate_base_prediction(model, features)
|
||||||
|
|
||||||
|
# Apply business rules
|
||||||
|
adjusted_prediction = self.business_rules.apply_rules(
|
||||||
|
base_prediction, features, business_type
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add uncertainty estimation
|
||||||
|
final_prediction = self._add_uncertainty_bands(adjusted_prediction, features)
|
||||||
|
|
||||||
|
return final_prediction
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error in demand prediction", error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _generate_base_prediction(self, model, features: Dict[str, Any]) -> Dict[str, float]:
|
||||||
|
"""Generate base prediction from Prophet model"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert features to Prophet DataFrame
|
||||||
|
df = self._prepare_prophet_dataframe(features)
|
||||||
|
|
||||||
|
# Generate forecast
|
||||||
|
forecast = model.predict(df)
|
||||||
|
|
||||||
|
if len(forecast) > 0:
|
||||||
|
row = forecast.iloc[0]
|
||||||
|
return {
|
||||||
|
"yhat": float(row['yhat']),
|
||||||
|
"yhat_lower": float(row['yhat_lower']),
|
||||||
|
"yhat_upper": float(row['yhat_upper']),
|
||||||
|
"trend": float(row.get('trend', 0)),
|
||||||
|
"seasonal": float(row.get('seasonal', 0)),
|
||||||
|
"weekly": float(row.get('weekly', 0)),
|
||||||
|
"yearly": float(row.get('yearly', 0)),
|
||||||
|
"holidays": float(row.get('holidays', 0))
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError("No prediction generated from model")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error generating base prediction", error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _prepare_prophet_dataframe(self, features: Dict[str, Any]) -> pd.DataFrame:
|
||||||
|
"""Convert features to Prophet-compatible DataFrame"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create base DataFrame
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'ds': [pd.to_datetime(features['date'])]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Add regressor features
|
||||||
|
feature_mapping = {
|
||||||
|
'temperature': 'temperature',
|
||||||
|
'precipitation': 'precipitation',
|
||||||
|
'humidity': 'humidity',
|
||||||
|
'wind_speed': 'wind_speed',
|
||||||
|
'traffic_volume': 'traffic_volume',
|
||||||
|
'pedestrian_count': 'pedestrian_count'
|
||||||
|
}
|
||||||
|
|
||||||
|
for feature_key, df_column in feature_mapping.items():
|
||||||
|
if feature_key in features and features[feature_key] is not None:
|
||||||
|
df[df_column] = float(features[feature_key])
|
||||||
|
else:
|
||||||
|
df[df_column] = 0.0
|
||||||
|
|
||||||
|
# Add categorical features
|
||||||
|
df['day_of_week'] = int(features.get('day_of_week', 0))
|
||||||
|
df['is_weekend'] = int(features.get('is_weekend', False))
|
||||||
|
df['is_holiday'] = int(features.get('is_holiday', False))
|
||||||
|
|
||||||
|
# Business type
|
||||||
|
business_type = features.get('business_type', 'individual')
|
||||||
|
df['is_central_workshop'] = int(business_type == 'central_workshop')
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error preparing Prophet dataframe", error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _add_uncertainty_bands(self, prediction: Dict[str, float],
|
||||||
|
features: Dict[str, Any]) -> Dict[str, float]:
|
||||||
|
"""Add uncertainty estimation based on external factors"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
base_demand = prediction["yhat"]
|
||||||
|
base_lower = prediction["yhat_lower"]
|
||||||
|
base_upper = prediction["yhat_upper"]
|
||||||
|
|
||||||
|
# Weather uncertainty
|
||||||
|
weather_uncertainty = self._calculate_weather_uncertainty(features)
|
||||||
|
|
||||||
|
# Holiday uncertainty
|
||||||
|
holiday_uncertainty = self._calculate_holiday_uncertainty(features)
|
||||||
|
|
||||||
|
# Weekend uncertainty
|
||||||
|
weekend_uncertainty = self._calculate_weekend_uncertainty(features)
|
||||||
|
|
||||||
|
# Total uncertainty factor
|
||||||
|
total_uncertainty = 1.0 + weather_uncertainty + holiday_uncertainty + weekend_uncertainty
|
||||||
|
|
||||||
|
# Adjust bounds
|
||||||
|
uncertainty_range = (base_upper - base_lower) * total_uncertainty
|
||||||
|
center_point = base_demand
|
||||||
|
|
||||||
|
adjusted_lower = center_point - (uncertainty_range / 2)
|
||||||
|
adjusted_upper = center_point + (uncertainty_range / 2)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"demand": max(0, base_demand), # Never predict negative demand
|
||||||
|
"lower_bound": max(0, adjusted_lower),
|
||||||
|
"upper_bound": adjusted_upper,
|
||||||
|
"uncertainty_factor": total_uncertainty,
|
||||||
|
"trend": prediction.get("trend", 0),
|
||||||
|
"seasonal": prediction.get("seasonal", 0),
|
||||||
|
"holiday_effect": prediction.get("holidays", 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error adding uncertainty bands", error=str(e))
|
||||||
|
# Return basic prediction if uncertainty calculation fails
|
||||||
|
return {
|
||||||
|
"demand": max(0, prediction["yhat"]),
|
||||||
|
"lower_bound": max(0, prediction["yhat_lower"]),
|
||||||
|
"upper_bound": prediction["yhat_upper"],
|
||||||
|
"uncertainty_factor": 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
def _calculate_weather_uncertainty(self, features: Dict[str, Any]) -> float:
|
||||||
|
"""Calculate weather-based uncertainty"""
|
||||||
|
|
||||||
|
uncertainty = 0.0
|
||||||
|
|
||||||
|
# Temperature extremes add uncertainty
|
||||||
|
temp = features.get('temperature')
|
||||||
|
if temp is not None:
|
||||||
|
if temp < settings.TEMPERATURE_THRESHOLD_COLD or temp > settings.TEMPERATURE_THRESHOLD_HOT:
|
||||||
|
uncertainty += 0.1
|
||||||
|
|
||||||
|
# Rain adds uncertainty
|
||||||
|
precipitation = features.get('precipitation')
|
||||||
|
if precipitation is not None and precipitation > 0:
|
||||||
|
uncertainty += 0.05 * min(precipitation, 10) # Cap at 50mm
|
||||||
|
|
||||||
|
return uncertainty
|
||||||
|
|
||||||
|
def _calculate_holiday_uncertainty(self, features: Dict[str, Any]) -> float:
|
||||||
|
"""Calculate holiday-based uncertainty"""
|
||||||
|
|
||||||
|
if features.get('is_holiday', False):
|
||||||
|
return 0.2 # 20% additional uncertainty on holidays
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def _calculate_weekend_uncertainty(self, features: Dict[str, Any]) -> float:
|
||||||
|
"""Calculate weekend-based uncertainty"""
|
||||||
|
|
||||||
|
if features.get('is_weekend', False):
|
||||||
|
return 0.1 # 10% additional uncertainty on weekends
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class BakeryBusinessRules:
|
||||||
|
"""
|
||||||
|
Business rules for Spanish bakeries
|
||||||
|
Applies domain-specific adjustments to predictions
|
||||||
|
"""
|
||||||
|
|
||||||
|
def apply_rules(self, prediction: Dict[str, float], features: Dict[str, Any],
|
||||||
|
business_type: str) -> Dict[str, float]:
|
||||||
|
"""Apply all business rules to prediction"""
|
||||||
|
|
||||||
|
adjusted_prediction = prediction.copy()
|
||||||
|
|
||||||
|
# Apply weather rules
|
||||||
|
adjusted_prediction = self._apply_weather_rules(adjusted_prediction, features)
|
||||||
|
|
||||||
|
# Apply time-based rules
|
||||||
|
adjusted_prediction = self._apply_time_rules(adjusted_prediction, features)
|
||||||
|
|
||||||
|
# Apply business type rules
|
||||||
|
adjusted_prediction = self._apply_business_type_rules(adjusted_prediction, business_type)
|
||||||
|
|
||||||
|
# Apply Spanish-specific rules
|
||||||
|
adjusted_prediction = self._apply_spanish_rules(adjusted_prediction, features)
|
||||||
|
|
||||||
|
return adjusted_prediction
|
||||||
|
|
||||||
|
def _apply_weather_rules(self, prediction: Dict[str, float],
|
||||||
|
features: Dict[str, Any]) -> Dict[str, float]:
|
||||||
|
"""Apply weather-based business rules"""
|
||||||
|
|
||||||
|
# Rain reduces foot traffic
|
||||||
|
precipitation = features.get('precipitation', 0)
|
||||||
|
if precipitation > 0:
|
||||||
|
rain_factor = settings.RAIN_IMPACT_FACTOR
|
||||||
|
prediction["yhat"] *= rain_factor
|
||||||
|
prediction["yhat_lower"] *= rain_factor
|
||||||
|
prediction["yhat_upper"] *= rain_factor
|
||||||
|
|
||||||
|
# Extreme temperatures affect different products differently
|
||||||
|
temperature = features.get('temperature')
|
||||||
|
if temperature is not None:
|
||||||
|
if temperature > settings.TEMPERATURE_THRESHOLD_HOT:
|
||||||
|
# Hot weather reduces bread sales, increases cold drinks
|
||||||
|
prediction["yhat"] *= 0.9
|
||||||
|
elif temperature < settings.TEMPERATURE_THRESHOLD_COLD:
|
||||||
|
# Cold weather increases hot beverage sales
|
||||||
|
prediction["yhat"] *= 1.1
|
||||||
|
|
||||||
|
return prediction
|
||||||
|
|
||||||
|
def _apply_time_rules(self, prediction: Dict[str, float],
|
||||||
|
features: Dict[str, Any]) -> Dict[str, float]:
|
||||||
|
"""Apply time-based business rules"""
|
||||||
|
|
||||||
|
# Weekend adjustment
|
||||||
|
if features.get('is_weekend', False):
|
||||||
|
weekend_factor = settings.WEEKEND_ADJUSTMENT_FACTOR
|
||||||
|
prediction["yhat"] *= weekend_factor
|
||||||
|
prediction["yhat_lower"] *= weekend_factor
|
||||||
|
prediction["yhat_upper"] *= weekend_factor
|
||||||
|
|
||||||
|
# Holiday adjustment
|
||||||
|
if features.get('is_holiday', False):
|
||||||
|
holiday_factor = settings.HOLIDAY_ADJUSTMENT_FACTOR
|
||||||
|
prediction["yhat"] *= holiday_factor
|
||||||
|
prediction["yhat_lower"] *= holiday_factor
|
||||||
|
prediction["yhat_upper"] *= holiday_factor
|
||||||
|
|
||||||
|
return prediction
|
||||||
|
|
||||||
|
def _apply_business_type_rules(self, prediction: Dict[str, float],
|
||||||
|
business_type: str) -> Dict[str, float]:
|
||||||
|
"""Apply business type specific rules"""
|
||||||
|
|
||||||
|
if business_type == "central_workshop":
|
||||||
|
# Central workshops have more stable demand
|
||||||
|
uncertainty_reduction = 0.8
|
||||||
|
center = prediction["yhat"]
|
||||||
|
lower = prediction["yhat_lower"]
|
||||||
|
upper = prediction["yhat_upper"]
|
||||||
|
|
||||||
|
# Reduce uncertainty band
|
||||||
|
new_range = (upper - lower) * uncertainty_reduction
|
||||||
|
prediction["yhat_lower"] = center - (new_range / 2)
|
||||||
|
prediction["yhat_upper"] = center + (new_range / 2)
|
||||||
|
|
||||||
|
return prediction
|
||||||
|
|
||||||
|
def _apply_spanish_rules(self, prediction: Dict[str, float],
|
||||||
|
features: Dict[str, Any]) -> Dict[str, float]:
|
||||||
|
"""Apply Spanish bakery specific rules"""
|
||||||
|
|
||||||
|
# Spanish siesta time considerations
|
||||||
|
current_date = pd.to_datetime(features['date'])
|
||||||
|
day_of_week = current_date.weekday()
|
||||||
|
|
||||||
|
# Reduced activity during typical siesta hours (14:00-17:00)
|
||||||
|
# This affects afternoon sales planning
|
||||||
|
if day_of_week < 5: # Weekdays
|
||||||
|
prediction["yhat"] *= 0.95 # Slight reduction for siesta effect
|
||||||
|
|
||||||
|
return prediction
|
||||||
112
services/forecasting/app/models/forecasts.py
Normal file
112
services/forecasting/app/models/forecasts.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/models/forecasts.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Forecast models for the forecasting service
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text, JSON
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from shared.database.base import Base
|
||||||
|
|
||||||
|
class Forecast(Base):
|
||||||
|
"""Forecast model for storing prediction results"""
|
||||||
|
__tablename__ = "forecasts"
|
||||||
|
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||||
|
product_name = Column(String(255), nullable=False, index=True)
|
||||||
|
location = Column(String(255), nullable=False, index=True)
|
||||||
|
|
||||||
|
# Forecast period
|
||||||
|
forecast_date = Column(DateTime(timezone=True), nullable=False, index=True)
|
||||||
|
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
# Prediction results
|
||||||
|
predicted_demand = Column(Float, nullable=False)
|
||||||
|
confidence_lower = Column(Float, nullable=False)
|
||||||
|
confidence_upper = Column(Float, nullable=False)
|
||||||
|
confidence_level = Column(Float, default=0.8)
|
||||||
|
|
||||||
|
# Model information
|
||||||
|
model_id = Column(UUID(as_uuid=True), nullable=False)
|
||||||
|
model_version = Column(String(50), nullable=False)
|
||||||
|
algorithm = Column(String(50), default="prophet")
|
||||||
|
|
||||||
|
# Business context
|
||||||
|
business_type = Column(String(50), default="individual") # individual or central_workshop
|
||||||
|
day_of_week = Column(Integer, nullable=False)
|
||||||
|
is_holiday = Column(Boolean, default=False)
|
||||||
|
is_weekend = Column(Boolean, default=False)
|
||||||
|
|
||||||
|
# External factors
|
||||||
|
weather_temperature = Column(Float)
|
||||||
|
weather_precipitation = Column(Float)
|
||||||
|
weather_description = Column(String(100))
|
||||||
|
traffic_volume = Column(Integer)
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
processing_time_ms = Column(Integer)
|
||||||
|
features_used = Column(JSON)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<Forecast(id={self.id}, product={self.product_name}, date={self.forecast_date})>"
|
||||||
|
|
||||||
|
class PredictionBatch(Base):
|
||||||
|
"""Batch prediction requests"""
|
||||||
|
__tablename__ = "prediction_batches"
|
||||||
|
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||||
|
|
||||||
|
# Batch information
|
||||||
|
batch_name = Column(String(255), nullable=False)
|
||||||
|
requested_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||||
|
completed_at = Column(DateTime(timezone=True))
|
||||||
|
|
||||||
|
# Status
|
||||||
|
status = Column(String(50), default="pending") # pending, processing, completed, failed
|
||||||
|
total_products = Column(Integer, default=0)
|
||||||
|
completed_products = Column(Integer, default=0)
|
||||||
|
failed_products = Column(Integer, default=0)
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
forecast_days = Column(Integer, default=7)
|
||||||
|
business_type = Column(String(50), default="individual")
|
||||||
|
|
||||||
|
# Results
|
||||||
|
error_message = Column(Text)
|
||||||
|
processing_time_ms = Column(Integer)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<PredictionBatch(id={self.id}, status={self.status})>"
|
||||||
|
|
||||||
|
class ForecastAlert(Base):
|
||||||
|
"""Alerts based on forecast results"""
|
||||||
|
__tablename__ = "forecast_alerts"
|
||||||
|
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||||
|
forecast_id = Column(UUID(as_uuid=True), nullable=False)
|
||||||
|
|
||||||
|
# Alert information
|
||||||
|
alert_type = Column(String(50), nullable=False) # high_demand, low_demand, stockout_risk
|
||||||
|
severity = Column(String(20), default="medium") # low, medium, high, critical
|
||||||
|
message = Column(Text, nullable=False)
|
||||||
|
|
||||||
|
# Status
|
||||||
|
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||||
|
acknowledged_at = Column(DateTime(timezone=True))
|
||||||
|
resolved_at = Column(DateTime(timezone=True))
|
||||||
|
is_active = Column(Boolean, default=True)
|
||||||
|
|
||||||
|
# Notification
|
||||||
|
notification_sent = Column(Boolean, default=False)
|
||||||
|
notification_method = Column(String(50)) # email, whatsapp, sms
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<ForecastAlert(id={self.id}, type={self.alert_type})>"
|
||||||
|
|
||||||
67
services/forecasting/app/models/predictions.py
Normal file
67
services/forecasting/app/models/predictions.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/models/predictions.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Additional prediction models for the forecasting service
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text, JSON
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from shared.database.base import Base
|
||||||
|
|
||||||
|
class ModelPerformanceMetric(Base):
|
||||||
|
"""Track model performance over time"""
|
||||||
|
__tablename__ = "model_performance_metrics"
|
||||||
|
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
model_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||||
|
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||||
|
product_name = Column(String(255), nullable=False)
|
||||||
|
|
||||||
|
# Performance metrics
|
||||||
|
mae = Column(Float) # Mean Absolute Error
|
||||||
|
mape = Column(Float) # Mean Absolute Percentage Error
|
||||||
|
rmse = Column(Float) # Root Mean Square Error
|
||||||
|
accuracy_score = Column(Float)
|
||||||
|
|
||||||
|
# Evaluation period
|
||||||
|
evaluation_date = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
evaluation_period_start = Column(DateTime(timezone=True))
|
||||||
|
evaluation_period_end = Column(DateTime(timezone=True))
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
sample_size = Column(Integer)
|
||||||
|
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<ModelPerformanceMetric(model_id={self.model_id}, mae={self.mae})>"
|
||||||
|
|
||||||
|
class PredictionCache(Base):
|
||||||
|
"""Cache frequently requested predictions"""
|
||||||
|
__tablename__ = "prediction_cache"
|
||||||
|
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
cache_key = Column(String(255), unique=True, nullable=False, index=True)
|
||||||
|
|
||||||
|
# Cached data
|
||||||
|
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||||
|
product_name = Column(String(255), nullable=False)
|
||||||
|
location = Column(String(255), nullable=False)
|
||||||
|
forecast_date = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
|
||||||
|
# Cached results
|
||||||
|
predicted_demand = Column(Float, nullable=False)
|
||||||
|
confidence_lower = Column(Float, nullable=False)
|
||||||
|
confidence_upper = Column(Float, nullable=False)
|
||||||
|
model_id = Column(UUID(as_uuid=True), nullable=False)
|
||||||
|
|
||||||
|
# Cache metadata
|
||||||
|
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||||
|
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
hit_count = Column(Integer, default=0)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<PredictionCache(key={self.cache_key}, product={self.product_name})>"
|
||||||
123
services/forecasting/app/schemas/forecasts.py
Normal file
123
services/forecasting/app/schemas/forecasts.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/schemas/forecasts.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Forecast schemas for request/response validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
from datetime import datetime, date
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class BusinessType(str, Enum):
|
||||||
|
INDIVIDUAL = "individual"
|
||||||
|
CENTRAL_WORKSHOP = "central_workshop"
|
||||||
|
|
||||||
|
class AlertType(str, Enum):
|
||||||
|
HIGH_DEMAND = "high_demand"
|
||||||
|
LOW_DEMAND = "low_demand"
|
||||||
|
STOCKOUT_RISK = "stockout_risk"
|
||||||
|
OVERPRODUCTION = "overproduction"
|
||||||
|
|
||||||
|
class ForecastRequest(BaseModel):
|
||||||
|
"""Request schema for generating forecasts"""
|
||||||
|
tenant_id: str = Field(..., description="Tenant ID")
|
||||||
|
product_name: str = Field(..., description="Product name")
|
||||||
|
location: str = Field(..., description="Location identifier")
|
||||||
|
forecast_date: date = Field(..., description="Date for which to generate forecast")
|
||||||
|
business_type: BusinessType = Field(BusinessType.INDIVIDUAL, description="Business model type")
|
||||||
|
|
||||||
|
# Optional context
|
||||||
|
include_weather: bool = Field(True, description="Include weather data in forecast")
|
||||||
|
include_traffic: bool = Field(True, description="Include traffic data in forecast")
|
||||||
|
confidence_level: float = Field(0.8, ge=0.5, le=0.95, description="Confidence level for intervals")
|
||||||
|
|
||||||
|
@validator('forecast_date')
|
||||||
|
def validate_forecast_date(cls, v):
|
||||||
|
if v < date.today():
|
||||||
|
raise ValueError("Forecast date cannot be in the past")
|
||||||
|
return v
|
||||||
|
|
||||||
|
class BatchForecastRequest(BaseModel):
|
||||||
|
"""Request schema for batch forecasting"""
|
||||||
|
tenant_id: str = Field(..., description="Tenant ID")
|
||||||
|
batch_name: str = Field(..., description="Batch name for tracking")
|
||||||
|
products: List[str] = Field(..., description="List of product names")
|
||||||
|
location: str = Field(..., description="Location identifier")
|
||||||
|
forecast_days: int = Field(7, ge=1, le=30, description="Number of days to forecast")
|
||||||
|
business_type: BusinessType = Field(BusinessType.INDIVIDUAL, description="Business model type")
|
||||||
|
|
||||||
|
# Options
|
||||||
|
include_weather: bool = Field(True, description="Include weather data")
|
||||||
|
include_traffic: bool = Field(True, description="Include traffic data")
|
||||||
|
confidence_level: float = Field(0.8, ge=0.5, le=0.95, description="Confidence level")
|
||||||
|
|
||||||
|
class ForecastResponse(BaseModel):
|
||||||
|
"""Response schema for forecast results"""
|
||||||
|
id: str
|
||||||
|
tenant_id: str
|
||||||
|
product_name: str
|
||||||
|
location: str
|
||||||
|
forecast_date: datetime
|
||||||
|
|
||||||
|
# Predictions
|
||||||
|
predicted_demand: float
|
||||||
|
confidence_lower: float
|
||||||
|
confidence_upper: float
|
||||||
|
confidence_level: float
|
||||||
|
|
||||||
|
# Model info
|
||||||
|
model_id: str
|
||||||
|
model_version: str
|
||||||
|
algorithm: str
|
||||||
|
|
||||||
|
# Context
|
||||||
|
business_type: str
|
||||||
|
is_holiday: bool
|
||||||
|
is_weekend: bool
|
||||||
|
day_of_week: int
|
||||||
|
|
||||||
|
# External factors
|
||||||
|
weather_temperature: Optional[float]
|
||||||
|
weather_precipitation: Optional[float]
|
||||||
|
weather_description: Optional[str]
|
||||||
|
traffic_volume: Optional[int]
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
created_at: datetime
|
||||||
|
processing_time_ms: Optional[int]
|
||||||
|
features_used: Optional[Dict[str, Any]]
|
||||||
|
|
||||||
|
class BatchForecastResponse(BaseModel):
|
||||||
|
"""Response schema for batch forecast requests"""
|
||||||
|
id: str
|
||||||
|
tenant_id: str
|
||||||
|
batch_name: str
|
||||||
|
status: str
|
||||||
|
total_products: int
|
||||||
|
completed_products: int
|
||||||
|
failed_products: int
|
||||||
|
|
||||||
|
# Timing
|
||||||
|
requested_at: datetime
|
||||||
|
completed_at: Optional[datetime]
|
||||||
|
processing_time_ms: Optional[int]
|
||||||
|
|
||||||
|
# Results
|
||||||
|
forecasts: Optional[List[ForecastResponse]]
|
||||||
|
error_message: Optional[str]
|
||||||
|
|
||||||
|
class AlertResponse(BaseModel):
|
||||||
|
"""Response schema for forecast alerts"""
|
||||||
|
id: str
|
||||||
|
tenant_id: str
|
||||||
|
forecast_id: str
|
||||||
|
alert_type: str
|
||||||
|
severity: str
|
||||||
|
message: str
|
||||||
|
is_active: bool
|
||||||
|
created_at: datetime
|
||||||
|
acknowledged_at: Optional[datetime]
|
||||||
|
notification_sent: bool
|
||||||
|
|
||||||
438
services/forecasting/app/services/forecasting_service.py
Normal file
438
services/forecasting/app/services/forecasting_service.py
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/services/forecasting_service.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Main forecasting service business logic
|
||||||
|
Orchestrates demand prediction operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
from datetime import datetime, date, timedelta
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select, and_, desc
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert
|
||||||
|
from app.schemas.forecasts import ForecastRequest, BatchForecastRequest, BusinessType
|
||||||
|
from app.services.prediction_service import PredictionService
|
||||||
|
from app.services.messaging import publish_forecast_completed, publish_alert_created
|
||||||
|
from app.core.config import settings
|
||||||
|
from shared.monitoring.metrics import MetricsCollector
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
metrics = MetricsCollector("forecasting-service")
|
||||||
|
|
||||||
|
class ForecastingService:
|
||||||
|
"""
|
||||||
|
Main service class for managing forecasting operations.
|
||||||
|
Handles demand prediction, batch processing, and alert generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.prediction_service = PredictionService()
|
||||||
|
|
||||||
|
async def generate_forecast(self, request: ForecastRequest, db: AsyncSession) -> Forecast:
|
||||||
|
"""Generate a single forecast for a product"""
|
||||||
|
start_time = datetime.now()
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("Generating forecast",
|
||||||
|
tenant_id=request.tenant_id,
|
||||||
|
product=request.product_name,
|
||||||
|
date=request.forecast_date)
|
||||||
|
|
||||||
|
# Get the latest trained model for this tenant/product
|
||||||
|
model_info = await self._get_latest_model(
|
||||||
|
request.tenant_id,
|
||||||
|
request.product_name,
|
||||||
|
request.location
|
||||||
|
)
|
||||||
|
|
||||||
|
if not model_info:
|
||||||
|
raise ValueError(f"No trained model found for {request.product_name}")
|
||||||
|
|
||||||
|
# Prepare features for prediction
|
||||||
|
features = await self._prepare_forecast_features(request)
|
||||||
|
|
||||||
|
# Generate prediction using ML service
|
||||||
|
prediction_result = await self.prediction_service.predict(
|
||||||
|
model_id=model_info["model_id"],
|
||||||
|
features=features,
|
||||||
|
confidence_level=request.confidence_level
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create forecast record
|
||||||
|
forecast = Forecast(
|
||||||
|
tenant_id=uuid.UUID(request.tenant_id),
|
||||||
|
product_name=request.product_name,
|
||||||
|
location=request.location,
|
||||||
|
forecast_date=datetime.combine(request.forecast_date, datetime.min.time()),
|
||||||
|
|
||||||
|
# Prediction results
|
||||||
|
predicted_demand=prediction_result["demand"],
|
||||||
|
confidence_lower=prediction_result["lower_bound"],
|
||||||
|
confidence_upper=prediction_result["upper_bound"],
|
||||||
|
confidence_level=request.confidence_level,
|
||||||
|
|
||||||
|
# Model information
|
||||||
|
model_id=uuid.UUID(model_info["model_id"]),
|
||||||
|
model_version=model_info["version"],
|
||||||
|
algorithm=model_info.get("algorithm", "prophet"),
|
||||||
|
|
||||||
|
# Context
|
||||||
|
business_type=request.business_type.value,
|
||||||
|
day_of_week=request.forecast_date.weekday(),
|
||||||
|
is_holiday=features.get("is_holiday", False),
|
||||||
|
is_weekend=request.forecast_date.weekday() >= 5,
|
||||||
|
|
||||||
|
# External factors
|
||||||
|
weather_temperature=features.get("temperature"),
|
||||||
|
weather_precipitation=features.get("precipitation"),
|
||||||
|
weather_description=features.get("weather_description"),
|
||||||
|
traffic_volume=features.get("traffic_volume"),
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
processing_time_ms=int((datetime.now() - start_time).total_seconds() * 1000),
|
||||||
|
features_used=features
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(forecast)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(forecast)
|
||||||
|
|
||||||
|
# Check for alerts
|
||||||
|
await self._check_and_create_alerts(forecast, db)
|
||||||
|
|
||||||
|
# Update metrics
|
||||||
|
metrics.increment_counter("forecasts_generated_total",
|
||||||
|
{"product": request.product_name, "location": request.location})
|
||||||
|
|
||||||
|
# Publish event
|
||||||
|
await publish_forecast_completed({
|
||||||
|
"forecast_id": str(forecast.id),
|
||||||
|
"tenant_id": request.tenant_id,
|
||||||
|
"product_name": request.product_name,
|
||||||
|
"predicted_demand": forecast.predicted_demand
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info("Forecast generated successfully",
|
||||||
|
forecast_id=str(forecast.id),
|
||||||
|
predicted_demand=forecast.predicted_demand)
|
||||||
|
|
||||||
|
return forecast
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error generating forecast",
|
||||||
|
error=str(e),
|
||||||
|
tenant_id=request.tenant_id,
|
||||||
|
product=request.product_name)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def generate_batch_forecast(self, request: BatchForecastRequest, db: AsyncSession) -> PredictionBatch:
|
||||||
|
"""Generate forecasts for multiple products over multiple days"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("Starting batch forecast generation",
|
||||||
|
tenant_id=request.tenant_id,
|
||||||
|
batch_name=request.batch_name,
|
||||||
|
products_count=len(request.products),
|
||||||
|
forecast_days=request.forecast_days)
|
||||||
|
|
||||||
|
# Create batch record
|
||||||
|
batch = PredictionBatch(
|
||||||
|
tenant_id=uuid.UUID(request.tenant_id),
|
||||||
|
batch_name=request.batch_name,
|
||||||
|
status="processing",
|
||||||
|
total_products=len(request.products) * request.forecast_days,
|
||||||
|
business_type=request.business_type.value,
|
||||||
|
forecast_days=request.forecast_days
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(batch)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(batch)
|
||||||
|
|
||||||
|
# Generate forecasts for each product and day
|
||||||
|
completed_count = 0
|
||||||
|
failed_count = 0
|
||||||
|
|
||||||
|
for product in request.products:
|
||||||
|
for day_offset in range(request.forecast_days):
|
||||||
|
forecast_date = date.today() + timedelta(days=day_offset + 1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
forecast_request = ForecastRequest(
|
||||||
|
tenant_id=request.tenant_id,
|
||||||
|
product_name=product,
|
||||||
|
location=request.location,
|
||||||
|
forecast_date=forecast_date,
|
||||||
|
business_type=request.business_type,
|
||||||
|
include_weather=request.include_weather,
|
||||||
|
include_traffic=request.include_traffic,
|
||||||
|
confidence_level=request.confidence_level
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.generate_forecast(forecast_request, db)
|
||||||
|
completed_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to generate forecast for product",
|
||||||
|
product=product,
|
||||||
|
date=forecast_date,
|
||||||
|
error=str(e))
|
||||||
|
failed_count += 1
|
||||||
|
|
||||||
|
# Update batch status
|
||||||
|
batch.status = "completed" if failed_count == 0 else "partial"
|
||||||
|
batch.completed_products = completed_count
|
||||||
|
batch.failed_products = failed_count
|
||||||
|
batch.completed_at = datetime.now()
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
logger.info("Batch forecast generation completed",
|
||||||
|
batch_id=str(batch.id),
|
||||||
|
completed=completed_count,
|
||||||
|
failed=failed_count)
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error in batch forecast generation", error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_forecasts(self, tenant_id: str, location: str,
|
||||||
|
start_date: Optional[date] = None,
|
||||||
|
end_date: Optional[date] = None,
|
||||||
|
product_name: Optional[str] = None,
|
||||||
|
db: AsyncSession = None) -> List[Forecast]:
|
||||||
|
"""Retrieve forecasts with filtering"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = select(Forecast).where(
|
||||||
|
and_(
|
||||||
|
Forecast.tenant_id == uuid.UUID(tenant_id),
|
||||||
|
Forecast.location == location
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if start_date:
|
||||||
|
query = query.where(Forecast.forecast_date >= datetime.combine(start_date, datetime.min.time()))
|
||||||
|
|
||||||
|
if end_date:
|
||||||
|
query = query.where(Forecast.forecast_date <= datetime.combine(end_date, datetime.max.time()))
|
||||||
|
|
||||||
|
if product_name:
|
||||||
|
query = query.where(Forecast.product_name == product_name)
|
||||||
|
|
||||||
|
query = query.order_by(desc(Forecast.forecast_date))
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
forecasts = result.scalars().all()
|
||||||
|
|
||||||
|
logger.info("Retrieved forecasts",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
count=len(forecasts))
|
||||||
|
|
||||||
|
return list(forecasts)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error retrieving forecasts", error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _get_latest_model(self, tenant_id: str, product_name: str, location: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get the latest trained model for a tenant/product combination"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Call training service to get model information
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{settings.TRAINING_SERVICE_URL}/api/v1/models/latest",
|
||||||
|
params={
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"product_name": product_name,
|
||||||
|
"location": location
|
||||||
|
},
|
||||||
|
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.json()
|
||||||
|
elif response.status_code == 404:
|
||||||
|
logger.warning("No model found",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
product=product_name)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error getting latest model", error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _prepare_forecast_features(self, request: ForecastRequest) -> Dict[str, Any]:
|
||||||
|
"""Prepare features for forecasting model"""
|
||||||
|
|
||||||
|
features = {
|
||||||
|
"date": request.forecast_date.isoformat(),
|
||||||
|
"day_of_week": request.forecast_date.weekday(),
|
||||||
|
"is_weekend": request.forecast_date.weekday() >= 5,
|
||||||
|
"business_type": request.business_type.value
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add Spanish holidays
|
||||||
|
features["is_holiday"] = await self._is_spanish_holiday(request.forecast_date)
|
||||||
|
|
||||||
|
# Add weather data if requested
|
||||||
|
if request.include_weather:
|
||||||
|
weather_data = await self._get_weather_forecast(request.forecast_date)
|
||||||
|
features.update(weather_data)
|
||||||
|
|
||||||
|
# Add traffic data if requested
|
||||||
|
if request.include_traffic:
|
||||||
|
traffic_data = await self._get_traffic_forecast(request.forecast_date, request.location)
|
||||||
|
features.update(traffic_data)
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
async def _is_spanish_holiday(self, forecast_date: date) -> bool:
|
||||||
|
"""Check if date is a Spanish holiday"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Call data service for holiday information
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{settings.DATA_SERVICE_URL}/api/v1/holidays/check",
|
||||||
|
params={"date": forecast_date.isoformat()},
|
||||||
|
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.json().get("is_holiday", False)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Error checking holiday status", error=str(e))
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _get_weather_forecast(self, forecast_date: date) -> Dict[str, Any]:
|
||||||
|
"""Get weather forecast for the date"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Call data service for weather forecast
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{settings.DATA_SERVICE_URL}/api/v1/weather/forecast",
|
||||||
|
params={"date": forecast_date.isoformat()},
|
||||||
|
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
weather = response.json()
|
||||||
|
return {
|
||||||
|
"temperature": weather.get("temperature"),
|
||||||
|
"precipitation": weather.get("precipitation"),
|
||||||
|
"humidity": weather.get("humidity"),
|
||||||
|
"weather_description": weather.get("description")
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Error getting weather forecast", error=str(e))
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def _get_traffic_forecast(self, forecast_date: date, location: str) -> Dict[str, Any]:
|
||||||
|
"""Get traffic forecast for the date and location"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Call data service for traffic forecast
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{settings.DATA_SERVICE_URL}/api/v1/traffic/forecast",
|
||||||
|
params={
|
||||||
|
"date": forecast_date.isoformat(),
|
||||||
|
"location": location
|
||||||
|
},
|
||||||
|
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
traffic = response.json()
|
||||||
|
return {
|
||||||
|
"traffic_volume": traffic.get("volume"),
|
||||||
|
"pedestrian_count": traffic.get("pedestrian_count")
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Error getting traffic forecast", error=str(e))
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def _check_and_create_alerts(self, forecast: Forecast, db: AsyncSession):
|
||||||
|
"""Check forecast and create alerts if needed"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
alerts_to_create = []
|
||||||
|
|
||||||
|
# High demand alert
|
||||||
|
if forecast.predicted_demand > settings.HIGH_DEMAND_THRESHOLD * 100: # Assuming base of 100 units
|
||||||
|
alerts_to_create.append({
|
||||||
|
"type": "high_demand",
|
||||||
|
"severity": "medium",
|
||||||
|
"message": f"High demand predicted for {forecast.product_name}: {forecast.predicted_demand:.0f} units"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Low demand alert
|
||||||
|
if forecast.predicted_demand < settings.LOW_DEMAND_THRESHOLD * 100:
|
||||||
|
alerts_to_create.append({
|
||||||
|
"type": "low_demand",
|
||||||
|
"severity": "low",
|
||||||
|
"message": f"Low demand predicted for {forecast.product_name}: {forecast.predicted_demand:.0f} units"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Stockout risk alert
|
||||||
|
if forecast.confidence_upper > settings.STOCKOUT_RISK_THRESHOLD * forecast.predicted_demand:
|
||||||
|
alerts_to_create.append({
|
||||||
|
"type": "stockout_risk",
|
||||||
|
"severity": "high",
|
||||||
|
"message": f"Stockout risk for {forecast.product_name}. Upper confidence: {forecast.confidence_upper:.0f}"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Create alerts
|
||||||
|
for alert_data in alerts_to_create:
|
||||||
|
alert = ForecastAlert(
|
||||||
|
tenant_id=forecast.tenant_id,
|
||||||
|
forecast_id=forecast.id,
|
||||||
|
alert_type=alert_data["type"],
|
||||||
|
severity=alert_data["severity"],
|
||||||
|
message=alert_data["message"]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(alert)
|
||||||
|
|
||||||
|
# Publish alert event
|
||||||
|
await publish_alert_created({
|
||||||
|
"alert_id": str(alert.id),
|
||||||
|
"tenant_id": str(forecast.tenant_id),
|
||||||
|
"product_name": forecast.product_name,
|
||||||
|
"alert_type": alert_data["type"],
|
||||||
|
"severity": alert_data["severity"],
|
||||||
|
"message": alert_data["message"]
|
||||||
|
})
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
if alerts_to_create:
|
||||||
|
logger.info("Created forecast alerts",
|
||||||
|
forecast_id=str(forecast.id),
|
||||||
|
alerts_count=len(alerts_to_create))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error creating alerts", error=str(e))
|
||||||
|
# Don't raise - alerts are not critical for forecast generation
|
||||||
98
services/forecasting/app/services/messaging.py
Normal file
98
services/forecasting/app/services/messaging.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/services/messaging.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Messaging service for event publishing and consuming
|
||||||
|
"""
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from shared.messaging.rabbitmq import RabbitMQPublisher, RabbitMQConsumer
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
# Global messaging instances
|
||||||
|
publisher = None
|
||||||
|
consumer = None
|
||||||
|
|
||||||
|
async def setup_messaging():
|
||||||
|
"""Initialize messaging services"""
|
||||||
|
global publisher, consumer
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize publisher
|
||||||
|
publisher = RabbitMQPublisher(settings.RABBITMQ_URL)
|
||||||
|
await publisher.connect()
|
||||||
|
|
||||||
|
# Initialize consumer
|
||||||
|
consumer = RabbitMQConsumer(settings.RABBITMQ_URL)
|
||||||
|
await consumer.connect()
|
||||||
|
|
||||||
|
# Set up event handlers
|
||||||
|
await consumer.subscribe("training.model.updated", handle_model_updated)
|
||||||
|
await consumer.subscribe("data.weather.updated", handle_weather_updated)
|
||||||
|
|
||||||
|
logger.info("Messaging setup completed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to setup messaging", error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def cleanup_messaging():
|
||||||
|
"""Cleanup messaging connections"""
|
||||||
|
global publisher, consumer
|
||||||
|
|
||||||
|
try:
|
||||||
|
if consumer:
|
||||||
|
await consumer.close()
|
||||||
|
if publisher:
|
||||||
|
await publisher.close()
|
||||||
|
|
||||||
|
logger.info("Messaging cleanup completed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error during messaging cleanup", error=str(e))
|
||||||
|
|
||||||
|
async def publish_forecast_completed(data: Dict[str, Any]):
|
||||||
|
"""Publish forecast completed event"""
|
||||||
|
if publisher:
|
||||||
|
await publisher.publish("forecasting.forecast.completed", data)
|
||||||
|
|
||||||
|
async def publish_alert_created(data: Dict[str, Any]):
|
||||||
|
"""Publish alert created event"""
|
||||||
|
if publisher:
|
||||||
|
await publisher.publish("forecasting.alert.created", data)
|
||||||
|
|
||||||
|
async def publish_batch_completed(data: Dict[str, Any]):
|
||||||
|
"""Publish batch forecast completed event"""
|
||||||
|
if publisher:
|
||||||
|
await publisher.publish("forecasting.batch.completed", data)
|
||||||
|
|
||||||
|
# Event handlers
|
||||||
|
async def handle_model_updated(data: Dict[str, Any]):
|
||||||
|
"""Handle model updated event from training service"""
|
||||||
|
try:
|
||||||
|
logger.info("Received model updated event",
|
||||||
|
model_id=data.get("model_id"),
|
||||||
|
tenant_id=data.get("tenant_id"))
|
||||||
|
|
||||||
|
# Clear model cache for this model
|
||||||
|
# This will be handled by PredictionService
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error handling model updated event", error=str(e))
|
||||||
|
|
||||||
|
async def handle_weather_updated(data: Dict[str, Any]):
|
||||||
|
"""Handle weather data updated event"""
|
||||||
|
try:
|
||||||
|
logger.info("Received weather updated event",
|
||||||
|
date=data.get("date"))
|
||||||
|
|
||||||
|
# Could trigger re-forecasting if needed
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error handling weather updated event", error=str(e))
|
||||||
166
services/forecasting/app/services/prediction_service.py
Normal file
166
services/forecasting/app/services/prediction_service.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/app/services/prediction_service.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Prediction service for loading models and generating predictions
|
||||||
|
Handles the actual ML prediction logic
|
||||||
|
"""
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
import asyncio
|
||||||
|
import pickle
|
||||||
|
import json
|
||||||
|
from datetime import datetime, date
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import httpx
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from shared.monitoring.metrics import MetricsCollector
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
metrics = MetricsCollector("forecasting-service")
|
||||||
|
|
||||||
|
class PredictionService:
|
||||||
|
"""
|
||||||
|
Service for loading ML models and generating predictions
|
||||||
|
Interfaces with trained Prophet models from the training service
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.model_cache = {}
|
||||||
|
self.cache_ttl = 3600 # 1 hour cache
|
||||||
|
|
||||||
|
async def predict(self, model_id: str, features: Dict[str, Any],
|
||||||
|
confidence_level: float = 0.8) -> Dict[str, float]:
|
||||||
|
"""Generate prediction using trained model"""
|
||||||
|
|
||||||
|
start_time = datetime.now()
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("Generating prediction",
|
||||||
|
model_id=model_id,
|
||||||
|
features_count=len(features))
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model = await self._load_model(model_id)
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
raise ValueError(f"Model {model_id} not found or failed to load")
|
||||||
|
|
||||||
|
# Prepare features for Prophet
|
||||||
|
df = self._prepare_prophet_features(features)
|
||||||
|
|
||||||
|
# Generate prediction
|
||||||
|
forecast = model.predict(df)
|
||||||
|
|
||||||
|
# Extract prediction results
|
||||||
|
if len(forecast) > 0:
|
||||||
|
row = forecast.iloc[0]
|
||||||
|
result = {
|
||||||
|
"demand": float(row['yhat']),
|
||||||
|
"lower_bound": float(row[f'yhat_lower']),
|
||||||
|
"upper_bound": float(row[f'yhat_upper']),
|
||||||
|
"trend": float(row.get('trend', 0)),
|
||||||
|
"seasonal": float(row.get('seasonal', 0)),
|
||||||
|
"holiday": float(row.get('holidays', 0))
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError("No prediction generated from model")
|
||||||
|
|
||||||
|
# Update metrics
|
||||||
|
processing_time = (datetime.now() - start_time).total_seconds()
|
||||||
|
metrics.histogram_observe("forecast_processing_time_seconds", processing_time)
|
||||||
|
|
||||||
|
logger.info("Prediction generated successfully",
|
||||||
|
model_id=model_id,
|
||||||
|
predicted_demand=result["demand"],
|
||||||
|
processing_time_ms=int(processing_time * 1000))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error generating prediction",
|
||||||
|
model_id=model_id,
|
||||||
|
error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _load_model(self, model_id: str):
|
||||||
|
"""Load model from cache or training service"""
|
||||||
|
|
||||||
|
# Check cache first
|
||||||
|
if model_id in self.model_cache:
|
||||||
|
cached_model, cached_time = self.model_cache[model_id]
|
||||||
|
if (datetime.now() - cached_time).seconds < self.cache_ttl:
|
||||||
|
logger.debug("Using cached model", model_id=model_id)
|
||||||
|
return cached_model
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Download model from training service
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{settings.TRAINING_SERVICE_URL}/api/v1/models/{model_id}/download",
|
||||||
|
headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
# Load model from bytes
|
||||||
|
model_data = response.content
|
||||||
|
model = pickle.loads(model_data)
|
||||||
|
|
||||||
|
# Cache the model
|
||||||
|
self.model_cache[model_id] = (model, datetime.now())
|
||||||
|
|
||||||
|
logger.info("Model loaded successfully", model_id=model_id)
|
||||||
|
return model
|
||||||
|
else:
|
||||||
|
logger.error("Failed to download model",
|
||||||
|
model_id=model_id,
|
||||||
|
status_code=response.status_code)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error loading model", model_id=model_id, error=str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _prepare_prophet_features(self, features: Dict[str, Any]) -> pd.DataFrame:
|
||||||
|
"""Convert features to Prophet-compatible DataFrame"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create base DataFrame with required 'ds' column
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'ds': [pd.to_datetime(features['date'])]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Add numeric features
|
||||||
|
numeric_features = [
|
||||||
|
'temperature', 'precipitation', 'humidity', 'wind_speed',
|
||||||
|
'traffic_volume', 'pedestrian_count'
|
||||||
|
]
|
||||||
|
|
||||||
|
for feature in numeric_features:
|
||||||
|
if feature in features and features[feature] is not None:
|
||||||
|
df[feature] = float(features[feature])
|
||||||
|
else:
|
||||||
|
df[feature] = 0.0
|
||||||
|
|
||||||
|
# Add categorical features
|
||||||
|
df['day_of_week'] = int(features.get('day_of_week', 0))
|
||||||
|
df['is_weekend'] = int(features.get('is_weekend', False))
|
||||||
|
df['is_holiday'] = int(features.get('is_holiday', False))
|
||||||
|
|
||||||
|
# Business type encoding
|
||||||
|
business_type = features.get('business_type', 'individual')
|
||||||
|
df['is_central_workshop'] = int(business_type == 'central_workshop')
|
||||||
|
|
||||||
|
logger.debug("Prepared Prophet features",
|
||||||
|
features_count=len(df.columns),
|
||||||
|
date=features['date'])
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error preparing Prophet features", error=str(e))
|
||||||
|
raise
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/migrations/versions/001_initial_tables.py
|
||||||
|
# ================================================================
|
||||||
|
"""Initial forecasting tables
|
||||||
|
|
||||||
|
Revision ID: 001
|
||||||
|
Revises:
|
||||||
|
Create Date: 2024-01-15 10:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers
|
||||||
|
revision = '001'
|
||||||
|
down_revision = None
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# Create forecasts table
|
||||||
|
op.create_table('forecasts',
|
||||||
|
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column('tenant_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column('product_name', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('location', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('forecast_date', sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('predicted_demand', sa.Float(), nullable=False),
|
||||||
|
sa.Column('confidence_lower', sa.Float(), nullable=False),
|
||||||
|
sa.Column('confidence_upper', sa.Float(), nullable=False),
|
||||||
|
sa.Column('confidence_level', sa.Float(), nullable=True),
|
||||||
|
sa.Column('model_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column('model_version', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('algorithm', sa.String(length=50), nullable=True),
|
||||||
|
sa.Column('business_type', sa.String(length=50), nullable=True),
|
||||||
|
sa.Column('day_of_week', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('is_holiday', sa.Boolean(), nullable=True),
|
||||||
|
sa.Column('is_weekend', sa.Boolean(), nullable=True),
|
||||||
|
sa.Column('weather_temperature', sa.Float(), nullable=True),
|
||||||
|
sa.Column('weather_precipitation', sa.Float(), nullable=True),
|
||||||
|
sa.Column('weather_description', sa.String(length=100), nullable=True),
|
||||||
|
sa.Column('traffic_volume', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('processing_time_ms', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('features_used', sa.JSON(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create indexes
|
||||||
|
op.create_index('ix_forecasts_tenant_id', 'forecasts', ['tenant_id'])
|
||||||
|
op.create_index('ix_forecasts_product_name', 'forecasts', ['product_name'])
|
||||||
|
op.create_index('ix_forecasts_location', 'forecasts', ['location'])
|
||||||
|
op.create_index('ix_forecasts_forecast_date', 'forecasts', ['forecast_date'])
|
||||||
|
|
||||||
|
# Create prediction_batches table
|
||||||
|
op.create_table('prediction_batches',
|
||||||
|
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column('tenant_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column('batch_name', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('requested_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('status', sa.String(length=50), nullable=True),
|
||||||
|
sa.Column('total_products', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('completed_products', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('failed_products', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('forecast_days', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('business_type', sa.String(length=50), nullable=True),
|
||||||
|
sa.Column('error_message', sa.Text(), nullable=True),
|
||||||
|
sa.Column('processing_time_ms', sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_index('ix_prediction_batches_tenant_id', 'prediction_batches', ['tenant_id'])
|
||||||
|
|
||||||
|
# Create forecast_alerts table
|
||||||
|
op.create_table('forecast_alerts',
|
||||||
|
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column('tenant_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column('forecast_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column('alert_type', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('severity', sa.String(length=20), nullable=True),
|
||||||
|
sa.Column('message', sa.Text(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('acknowledged_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('resolved_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||||
|
sa.Column('notification_sent', sa.Boolean(), nullable=True),
|
||||||
|
sa.Column('notification_method', sa.String(length=50), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_index('ix_forecast_alerts_tenant_id', 'forecast_alerts', ['tenant_id'])
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
op.drop_table('forecast_alerts')
|
||||||
|
op.drop_table('prediction_batches')
|
||||||
|
op.drop_table('forecasts')
|
||||||
@@ -1,15 +1,36 @@
|
|||||||
|
# Core FastAPI dependencies
|
||||||
fastapi==0.104.1
|
fastapi==0.104.1
|
||||||
uvicorn[standard]==0.24.0
|
uvicorn[standard]==0.24.0
|
||||||
sqlalchemy==2.0.23
|
|
||||||
asyncpg==0.29.0
|
|
||||||
alembic==1.12.1
|
|
||||||
pydantic==2.5.0
|
pydantic==2.5.0
|
||||||
pydantic-settings==2.1.0
|
pydantic-settings==2.1.0
|
||||||
|
|
||||||
|
# Database
|
||||||
|
sqlalchemy[asyncio]==2.0.23
|
||||||
|
asyncpg==0.29.0
|
||||||
|
alembic==1.13.1
|
||||||
|
|
||||||
|
# Authentication & Security
|
||||||
|
python-jose[cryptography]==3.3.0
|
||||||
|
passlib[bcrypt]==1.7.4
|
||||||
|
python-multipart==0.0.6
|
||||||
|
|
||||||
|
# HTTP Client
|
||||||
httpx==0.25.2
|
httpx==0.25.2
|
||||||
redis==5.0.1
|
|
||||||
aio-pika==9.3.0
|
# Machine Learning
|
||||||
prometheus-client==0.17.1
|
prophet==1.1.4
|
||||||
python-json-logger==2.0.4
|
scikit-learn==1.3.2
|
||||||
pytz==2023.3
|
pandas==2.1.4
|
||||||
python-logstash==0.4.8
|
numpy==1.25.2
|
||||||
|
|
||||||
|
# Messaging
|
||||||
|
aio-pika==9.3.1
|
||||||
|
|
||||||
|
# Monitoring & Logging
|
||||||
structlog==23.2.0
|
structlog==23.2.0
|
||||||
|
prometheus-client==0.19.0
|
||||||
|
|
||||||
|
# Development dependencies
|
||||||
|
pytest==7.4.3
|
||||||
|
pytest-asyncio==0.21.1
|
||||||
|
pytest-cov==4.1.0
|
||||||
54
services/forecasting/tests/conftest.py
Normal file
54
services/forecasting/tests/conftest.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/tests/conftest.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Test configuration and fixtures for forecasting service
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from shared.database.base import Base
|
||||||
|
|
||||||
|
# Test database URL
|
||||||
|
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def event_loop():
|
||||||
|
"""Create an instance of the default event loop for the test session."""
|
||||||
|
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def test_db():
|
||||||
|
"""Create test database session"""
|
||||||
|
|
||||||
|
# Create test engine
|
||||||
|
engine = create_async_engine(
|
||||||
|
TEST_DATABASE_URL,
|
||||||
|
poolclass=StaticPool,
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
echo=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create tables
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
# Create session factory
|
||||||
|
TestSessionLocal = async_sessionmaker(
|
||||||
|
engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Provide session
|
||||||
|
async with TestSessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await engine.dispose()
|
||||||
114
services/forecasting/tests/integration/test_forecasting_flow.py
Normal file
114
services/forecasting/tests/integration/test_forecasting_flow.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
# ================================================================
|
||||||
|
# Integration Tests: tests/integration/test_forecasting_flow.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Integration tests for complete forecasting flow
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import httpx
|
||||||
|
import asyncio
|
||||||
|
from datetime import date, timedelta
|
||||||
|
import json
|
||||||
|
|
||||||
|
class TestForecastingFlow:
|
||||||
|
"""Test complete forecasting workflow"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_forecast_flow(self):
|
||||||
|
"""Test complete flow from training to forecasting"""
|
||||||
|
|
||||||
|
base_url = "http://localhost:8000" # API Gateway
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
tenant_id = "test-tenant-123"
|
||||||
|
product_name = "Pan Integral"
|
||||||
|
location = "madrid_centro"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# 1. Check if model exists
|
||||||
|
model_response = await client.get(
|
||||||
|
f"{base_url}/api/v1/training/models/latest",
|
||||||
|
params={
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"product_name": product_name,
|
||||||
|
"location": location
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Generate forecast
|
||||||
|
forecast_request = {
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"product_name": product_name,
|
||||||
|
"location": location,
|
||||||
|
"forecast_date": (date.today() + timedelta(days=1)).isoformat(),
|
||||||
|
"business_type": "individual",
|
||||||
|
"include_weather": True,
|
||||||
|
"include_traffic": True,
|
||||||
|
"confidence_level": 0.8
|
||||||
|
}
|
||||||
|
|
||||||
|
forecast_response = await client.post(
|
||||||
|
f"{base_url}/api/v1/forecasting/single",
|
||||||
|
json=forecast_request
|
||||||
|
)
|
||||||
|
|
||||||
|
assert forecast_response.status_code == 200
|
||||||
|
forecast_data = forecast_response.json()
|
||||||
|
|
||||||
|
# Verify forecast structure
|
||||||
|
assert "id" in forecast_data
|
||||||
|
assert "predicted_demand" in forecast_data
|
||||||
|
assert "confidence_lower" in forecast_data
|
||||||
|
assert "confidence_upper" in forecast_data
|
||||||
|
assert forecast_data["product_name"] == product_name
|
||||||
|
|
||||||
|
# 3. Get forecast list
|
||||||
|
list_response = await client.get(
|
||||||
|
f"{base_url}/api/v1/forecasting/list",
|
||||||
|
params={"location": location}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert list_response.status_code == 200
|
||||||
|
forecasts = list_response.json()
|
||||||
|
assert len(forecasts) > 0
|
||||||
|
|
||||||
|
# 4. Check for alerts
|
||||||
|
alerts_response = await client.get(
|
||||||
|
f"{base_url}/api/v1/forecasting/alerts"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert alerts_response.status_code == 200
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_forecasting(self):
|
||||||
|
"""Test batch forecasting functionality"""
|
||||||
|
|
||||||
|
base_url = "http://localhost:8000"
|
||||||
|
|
||||||
|
batch_request = {
|
||||||
|
"tenant_id": "test-tenant-123",
|
||||||
|
"batch_name": "Weekly Forecast Batch",
|
||||||
|
"products": ["Pan Integral", "Croissant", "Café con Leche"],
|
||||||
|
"location": "madrid_centro",
|
||||||
|
"forecast_days": 7,
|
||||||
|
"business_type": "individual",
|
||||||
|
"include_weather": True,
|
||||||
|
"include_traffic": True,
|
||||||
|
"confidence_level": 0.8
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{base_url}/api/v1/forecasting/batch",
|
||||||
|
json=batch_request
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
batch_data = response.json()
|
||||||
|
|
||||||
|
assert "id" in batch_data
|
||||||
|
assert batch_data["batch_name"] == "Weekly Forecast Batch"
|
||||||
|
assert batch_data["total_products"] == 21 # 3 products * 7 days
|
||||||
|
assert batch_data["status"] in ["completed", "partial"]
|
||||||
|
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
# ================================================================
|
||||||
|
# Performance Tests: tests/performance/test_forecasting_performance.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Performance tests for forecasting service
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import httpx
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import statistics
|
||||||
|
|
||||||
|
class TestForecastingPerformance:
|
||||||
|
"""Performance tests for forecasting operations"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_single_forecast_performance(self):
|
||||||
|
"""Test single forecast generation performance"""
|
||||||
|
|
||||||
|
base_url = "http://localhost:8000"
|
||||||
|
|
||||||
|
forecast_request = {
|
||||||
|
"tenant_id": "perf-test-tenant",
|
||||||
|
"product_name": "Pan Integral",
|
||||||
|
"location": "madrid_centro",
|
||||||
|
"forecast_date": "2024-01-17",
|
||||||
|
"business_type": "individual",
|
||||||
|
"confidence_level": 0.8
|
||||||
|
}
|
||||||
|
|
||||||
|
times = []
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
for _ in range(10):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"{base_url}/api/v1/forecasting/single",
|
||||||
|
json=forecast_request
|
||||||
|
)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
times.append(end_time - start_time)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Performance assertions
|
||||||
|
avg_time = statistics.mean(times)
|
||||||
|
p95_time = statistics.quantiles(times, n=20)[18] # 95th percentile
|
||||||
|
|
||||||
|
assert avg_time < 2.0, f"Average response time {avg_time}s exceeds 2s"
|
||||||
|
assert p95_time < 5.0, f"95th percentile {p95_time}s exceeds 5s"
|
||||||
|
|
||||||
|
print(f"Average response time: {avg_time:.2f}s")
|
||||||
|
print(f"95th percentile: {p95_time:.2f}s")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_forecasts(self):
|
||||||
|
"""Test concurrent forecast generation"""
|
||||||
|
|
||||||
|
base_url = "http://localhost:8000"
|
||||||
|
|
||||||
|
async def make_forecast_request(product_id):
|
||||||
|
forecast_request = {
|
||||||
|
"tenant_id": "perf-test-tenant",
|
||||||
|
"product_name": f"Product_{product_id}",
|
||||||
|
"location": "madrid_centro",
|
||||||
|
"forecast_date": "2024-01-17",
|
||||||
|
"business_type": "individual"
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
start_time = time.time()
|
||||||
|
response = await client.post(
|
||||||
|
f"{base_url}/api/v1/forecasting/single",
|
||||||
|
json=forecast_request
|
||||||
|
)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status_code": response.status_code,
|
||||||
|
"response_time": end_time - start_time,
|
||||||
|
"product_id": product_id
|
||||||
|
}
|
||||||
|
|
||||||
|
# Run 20 concurrent requests
|
||||||
|
tasks = [make_forecast_request(i) for i in range(20)]
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Analyze results
|
||||||
|
successful = [r for r in results if isinstance(r, dict) and r["status_code"] == 200]
|
||||||
|
failed = [r for r in results if not isinstance(r, dict) or r["status_code"] != 200]
|
||||||
|
|
||||||
|
success_rate = len(successful) / len(results)
|
||||||
|
|
||||||
|
assert success_rate >= 0.95, f"Success rate {success_rate} below 95%"
|
||||||
|
|
||||||
|
if successful:
|
||||||
|
avg_concurrent_time = statistics.mean([r["response_time"] for r in successful])
|
||||||
|
assert avg_concurrent_time < 10.0, f"Average concurrent time {avg_concurrent_time}s exceeds 10s"
|
||||||
|
|
||||||
|
print(f"Concurrent success rate: {success_rate:.2%}")
|
||||||
|
print(f"Average concurrent response time: {avg_concurrent_time:.2f}s")
|
||||||
|
|
||||||
135
services/forecasting/tests/test_forecasting.py
Normal file
135
services/forecasting/tests/test_forecasting.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
# ================================================================
|
||||||
|
# services/forecasting/tests/test_forecasting.py
|
||||||
|
# ================================================================
|
||||||
|
"""
|
||||||
|
Tests for forecasting service
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from datetime import date, datetime, timedelta
|
||||||
|
from unittest.mock import Mock, AsyncMock, patch
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from app.services.forecasting_service import ForecastingService
|
||||||
|
from app.schemas.forecasts import ForecastRequest, BusinessType
|
||||||
|
from app.models.forecasts import Forecast
|
||||||
|
|
||||||
|
|
||||||
|
class TestForecastingService:
|
||||||
|
"""Test cases for ForecastingService"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def forecasting_service(self):
|
||||||
|
return ForecastingService()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_forecast_request(self):
|
||||||
|
return ForecastRequest(
|
||||||
|
tenant_id=str(uuid.uuid4()),
|
||||||
|
product_name="Pan Integral",
|
||||||
|
location="madrid_centro",
|
||||||
|
forecast_date=date.today() + timedelta(days=1),
|
||||||
|
business_type=BusinessType.INDIVIDUAL,
|
||||||
|
include_weather=True,
|
||||||
|
include_traffic=True,
|
||||||
|
confidence_level=0.8
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_forecast_success(self, forecasting_service, sample_forecast_request):
|
||||||
|
"""Test successful forecast generation"""
|
||||||
|
|
||||||
|
# Mock database session
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
|
||||||
|
# Mock external dependencies
|
||||||
|
with patch.object(forecasting_service, '_get_latest_model') as mock_get_model, \
|
||||||
|
patch.object(forecasting_service, '_prepare_forecast_features') as mock_prepare_features, \
|
||||||
|
patch.object(forecasting_service.prediction_service, 'predict') as mock_predict, \
|
||||||
|
patch.object(forecasting_service, '_check_and_create_alerts') as mock_check_alerts:
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_get_model.return_value = {
|
||||||
|
"model_id": str(uuid.uuid4()),
|
||||||
|
"version": "1.0.0",
|
||||||
|
"algorithm": "prophet"
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_prepare_features.return_value = {
|
||||||
|
"date": "2024-01-16",
|
||||||
|
"day_of_week": 1,
|
||||||
|
"is_weekend": False,
|
||||||
|
"is_holiday": False,
|
||||||
|
"temperature": 15.0,
|
||||||
|
"precipitation": 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_predict.return_value = {
|
||||||
|
"demand": 85.5,
|
||||||
|
"lower_bound": 70.2,
|
||||||
|
"upper_bound": 100.8
|
||||||
|
}
|
||||||
|
|
||||||
|
# Execute test
|
||||||
|
result = await forecasting_service.generate_forecast(sample_forecast_request, mock_db)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert isinstance(result, Forecast)
|
||||||
|
assert result.product_name == "Pan Integral"
|
||||||
|
assert result.predicted_demand == 85.5
|
||||||
|
assert result.confidence_lower == 70.2
|
||||||
|
assert result.confidence_upper == 100.8
|
||||||
|
|
||||||
|
# Verify mocks were called
|
||||||
|
mock_get_model.assert_called_once()
|
||||||
|
mock_prepare_features.assert_called_once()
|
||||||
|
mock_predict.assert_called_once()
|
||||||
|
mock_check_alerts.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_forecast_no_model(self, forecasting_service, sample_forecast_request):
|
||||||
|
"""Test forecast generation when no model is found"""
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
|
||||||
|
with patch.object(forecasting_service, '_get_latest_model') as mock_get_model:
|
||||||
|
mock_get_model.return_value = None
|
||||||
|
|
||||||
|
# Should raise ValueError
|
||||||
|
with pytest.raises(ValueError, match="No trained model found"):
|
||||||
|
await forecasting_service.generate_forecast(sample_forecast_request, mock_db)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prepare_forecast_features(self, forecasting_service, sample_forecast_request):
|
||||||
|
"""Test feature preparation for forecasting"""
|
||||||
|
|
||||||
|
with patch.object(forecasting_service, '_is_spanish_holiday') as mock_holiday, \
|
||||||
|
patch.object(forecasting_service, '_get_weather_forecast') as mock_weather, \
|
||||||
|
patch.object(forecasting_service, '_get_traffic_forecast') as mock_traffic:
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_holiday.return_value = False
|
||||||
|
mock_weather.return_value = {
|
||||||
|
"temperature": 18.5,
|
||||||
|
"precipitation": 0.0,
|
||||||
|
"humidity": 65.0,
|
||||||
|
"weather_description": "Clear"
|
||||||
|
}
|
||||||
|
mock_traffic.return_value = {
|
||||||
|
"traffic_volume": 1200,
|
||||||
|
"pedestrian_count": 850
|
||||||
|
}
|
||||||
|
|
||||||
|
# Execute test
|
||||||
|
features = await forecasting_service._prepare_forecast_features(sample_forecast_request)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert "date" in features
|
||||||
|
assert "day_of_week" in features
|
||||||
|
assert "is_weekend" in features
|
||||||
|
assert "is_holiday" in features
|
||||||
|
assert features["business_type"] == "individual"
|
||||||
|
assert features["temperature"] == 18.5
|
||||||
|
assert features["traffic_volume"] == 1200
|
||||||
|
|
||||||
Reference in New Issue
Block a user