Add role-based filtering and improve code

This commit is contained in:
Urtzi Alfaro
2025-10-15 16:12:49 +02:00
parent 96ad5c6692
commit 8f9e9a7edc
158 changed files with 11033 additions and 1544 deletions

View File

@@ -16,9 +16,13 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements
COPY shared/requirements-tracing.txt /tmp/
COPY services/forecasting/requirements.txt .
# Install Python dependencies
RUN pip install --no-cache-dir -r /tmp/requirements-tracing.txt
RUN pip install --no-cache-dir -r requirements.txt
# Copy shared libraries from the shared stage

View File

@@ -12,6 +12,7 @@ from app.services.prediction_service import PredictionService
from shared.database.base import create_database_manager
from app.core.config import settings
from shared.routing import RouteBuilder
from shared.auth.access_control import analytics_tier_required
route_builder = RouteBuilder('forecasting')
logger = structlog.get_logger()
@@ -27,13 +28,14 @@ def get_enhanced_prediction_service():
@router.get(
route_builder.build_analytics_route("predictions-performance")
)
@analytics_tier_required
async def get_predictions_performance(
tenant_id: str = Path(..., description="Tenant ID"),
start_date: Optional[date] = Query(None),
end_date: Optional[date] = Query(None),
prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
"""Get predictions performance analytics"""
"""Get predictions performance analytics (Professional+ tier required)"""
try:
logger.info("Getting predictions performance", tenant_id=tenant_id)
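The `analytics_tier_required` decorator lives in the shared auth package and its body is not part of this diff. A minimal sketch of how such a tier gate could work, assuming the decorator can read a `current_user` dict from the endpoint's keyword arguments (the `ALLOWED_TIERS` set and the 403 detail text are illustrative, not the shared implementation):

    import functools
    from fastapi import HTTPException

    ALLOWED_TIERS = {"professional", "enterprise"}  # assumption: Professional+ unlocks analytics

    def analytics_tier_required(func):
        """Reject callers whose subscription tier is below Professional."""
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            current_user = kwargs.get("current_user") or {}
            tier = current_user.get("subscription_tier", "starter")
            if tier not in ALLOWED_TIERS:
                raise HTTPException(
                    status_code=403,
                    detail="Analytics requires a Professional or Enterprise plan",
                )
            return await func(*args, **kwargs)
        return wrapper

`functools.wraps` preserves the wrapped signature so FastAPI's dependency injection still sees the original parameters.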

View File

@@ -23,11 +23,22 @@ from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings
from shared.routing import RouteBuilder
from shared.auth.access_control import require_user_role
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
from shared.subscription.plans import get_forecast_quota, get_forecast_horizon_limit
from shared.redis_utils import get_redis_client
route_builder = RouteBuilder('forecasting')
logger = structlog.get_logger()
router = APIRouter(tags=["forecasting-operations"])
# Initialize audit logger
audit_logger = create_audit_logger("forecasting-service")
async def get_rate_limiter():
"""Dependency for rate limiter"""
redis_client = await get_redis_client()
return create_rate_limiter(redis_client)
def get_enhanced_forecasting_service():
"""Dependency injection for EnhancedForecastingService"""
@@ -194,16 +205,17 @@ async def generate_multi_day_forecast(
route_builder.build_operations_route("batch"),
response_model=BatchForecastResponse
)
@require_user_role(['viewer', 'member', 'admin', 'owner'])
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_batch_forecast_duration_seconds", "forecasting-service")
async def generate_batch_forecast(
request: BatchForecastRequest,
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
current_user: dict = Depends(get_current_user_dep),
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service),
rate_limiter = Depends(get_rate_limiter)
):
"""Generate forecasts for multiple products in batch"""
"""Generate forecasts for multiple products in batch (Admin+ only, quota enforced)"""
metrics = get_metrics_collector(request_obj)
try:
@@ -217,6 +229,24 @@ async def generate_batch_forecast(
if not request.inventory_product_ids:
raise ValueError("inventory_product_ids cannot be empty")
# Get subscription tier and enforce quotas
tier = current_user.get('subscription_tier', 'starter')
# Check daily quota for forecast generation
quota_limit = get_forecast_quota(tier)
quota_result = await rate_limiter.check_and_increment_quota(
tenant_id,
"forecast_generation",
quota_limit,
period=86400 # 24 hours
)
# Validate forecast horizon if specified
if request.horizon_days:
await rate_limiter.validate_forecast_horizon(
tenant_id, request.horizon_days, tier
)
batch_result = await enhanced_forecasting_service.generate_batch_forecast(
tenant_id=tenant_id,
request=request
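`check_and_increment_quota` comes from the shared security package and is not shown in this commit. A plausible fixed-window implementation on top of `redis.asyncio`, assuming the limiter raises HTTP 429 when the daily quota is exhausted (the key naming scheme and return shape are assumptions):

    from fastapi import HTTPException

    class RateLimiter:
        def __init__(self, redis_client):
            self.redis = redis_client  # redis.asyncio client from get_redis_client()

        async def check_and_increment_quota(self, tenant_id: str, action: str,
                                            limit: int, period: int = 86400) -> dict:
            """Fixed-window quota: INCR a per-tenant counter, expire it after `period`."""
            key = f"quota:{tenant_id}:{action}"
            count = await self.redis.incr(key)
            if count == 1:
                await self.redis.expire(key, period)  # window starts on the first hit
            if count > limit:
                raise HTTPException(
                    status_code=429,
                    detail=f"Daily quota of {limit} {action} requests exceeded",
                )
            return {"used": count, "limit": limit}

That would explain why `quota_result` needs no explicit check in the endpoint: an exhausted quota never returns.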

View File

@@ -26,7 +26,7 @@ from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings
from shared.routing import RouteBuilder
from shared.auth.access_control import require_user_role
from shared.auth.access_control import require_user_role, enterprise_tier_required
route_builder = RouteBuilder('forecasting')
logger = structlog.get_logger()
@@ -43,12 +43,14 @@ def get_enhanced_forecasting_service():
route_builder.build_analytics_route("scenario-simulation"),
response_model=ScenarioSimulationResponse
)
@require_user_role(['viewer', 'member', 'admin', 'owner'])
@require_user_role(['admin', 'owner'])
@enterprise_tier_required
@track_execution_time("scenario_simulation_duration_seconds", "forecasting-service")
async def simulate_scenario(
request: ScenarioSimulationRequest,
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
current_user: dict = Depends(get_current_user_dep),
forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""
@@ -62,7 +64,7 @@ async def simulate_scenario(
- Promotions
- Supply disruptions
**PROFESSIONAL/ENTERPRISE ONLY**
**ENTERPRISE TIER ONLY - Admin+ role required**
"""
metrics = get_metrics_collector(request_obj)
start_time = datetime.now(timezone.utc)
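`require_user_role` is the gate this commit tightens from all four roles down to `['admin', 'owner']`. A minimal sketch of the parameterised decorator, assuming it reads the `current_user` kwarg injected by `get_current_user_dep` (the `role` key is an assumption):

    import functools
    from fastapi import HTTPException

    def require_user_role(allowed_roles: list[str]):
        """Parameterised decorator: 403 unless the caller's role is in the allowed set."""
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                current_user = kwargs.get("current_user") or {}
                if current_user.get("role") not in allowed_roles:
                    raise HTTPException(
                        status_code=403,
                        detail=f"Requires one of: {', '.join(allowed_roles)}",
                    )
                return await func(*args, **kwargs)
            return wrapper
        return decorator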

View File

@@ -4,6 +4,13 @@ Forecasting Service Models Package
Import all models to ensure they are registered with SQLAlchemy Base.
"""
# Import AuditLog model for this service
from shared.security import create_audit_log_model
from shared.database.base import Base
# Create audit log model for this service
AuditLog = create_audit_log_model(Base)
# Import all models to register them with the Base metadata
from .forecasts import Forecast, PredictionBatch
from .predictions import ModelPerformanceMetric, PredictionCache
@@ -14,4 +21,5 @@ __all__ = [
"PredictionBatch",
"ModelPerformanceMetric",
"PredictionCache",
"AuditLog",
]
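`create_audit_log_model` is a factory from the shared security package; judging by the Alembic migration further down, it likely builds a declarative model along these lines (a sketch, not the shared implementation; the column set and composite indexes are taken from the migration):

    import uuid
    from sqlalchemy import Column, DateTime, Index, String, Text, func
    from sqlalchemy.dialects.postgresql import JSON, UUID

    def create_audit_log_model(Base):
        """Return an AuditLog model bound to this service's declarative Base."""
        class AuditLog(Base):
            __tablename__ = "audit_logs"
            __table_args__ = (
                Index("idx_audit_resource_type_action", "resource_type", "action"),
                Index("idx_audit_service_created", "service_name", "created_at"),
                Index("idx_audit_severity_created", "severity", "created_at"),
                Index("idx_audit_tenant_created", "tenant_id", "created_at"),
                Index("idx_audit_user_created", "user_id", "created_at"),
            )
            id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
            tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
            user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
            action = Column(String(100), nullable=False, index=True)
            resource_type = Column(String(100), nullable=False, index=True)
            resource_id = Column(String(255), index=True)
            severity = Column(String(20), nullable=False, index=True)
            service_name = Column(String(100), nullable=False, index=True)
            description = Column(Text)
            changes = Column(JSON)
            audit_metadata = Column(JSON)
            ip_address = Column(String(45))  # 45 chars fits a full IPv6 address
            user_agent = Column(Text)
            endpoint = Column(String(255))
            method = Column(String(10))
            created_at = Column(DateTime(timezone=True), nullable=False,
                                server_default=func.now(), index=True)
        return AuditLog

The factory pattern lets each service register the same table shape against its own `Base` metadata, so Alembic autogenerate picks it up per service.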

View File

@@ -14,11 +14,11 @@ Cache Strategy:
"""
import json
import redis
from datetime import datetime, date, timedelta
from typing import Optional, Dict, Any, List
from uuid import UUID
import structlog
from shared.redis_utils import get_redis_client
logger = structlog.get_logger()
@@ -26,47 +26,20 @@ logger = structlog.get_logger()
class ForecastCacheService:
"""Service-level caching for forecast predictions"""
def __init__(self, redis_url: str):
"""
Initialize Redis connection for forecast caching
Args:
redis_url: Redis connection URL
"""
self.redis_url = redis_url
self._redis_client = None
self._connect()
def __init__(self):
"""Initialize forecast cache service"""
pass
async def _get_redis(self):
"""Get shared Redis client"""
return await get_redis_client()
def _connect(self):
"""Establish Redis connection with retry logic"""
try:
self._redis_client = redis.from_url(
self.redis_url,
decode_responses=True,
socket_keepalive=True,
socket_keepalive_options={1: 1, 3: 3, 5: 5},
retry_on_timeout=True,
max_connections=100, # Higher limit for forecast service
health_check_interval=30
)
# Test connection
self._redis_client.ping()
logger.info("Forecast cache Redis connection established")
except Exception as e:
logger.error("Failed to connect to forecast cache Redis", error=str(e))
self._redis_client = None
@property
def redis(self):
"""Get Redis client with connection check"""
if self._redis_client is None:
self._connect()
return self._redis_client
def is_available(self) -> bool:
async def is_available(self) -> bool:
"""Check if Redis cache is available"""
try:
return self.redis is not None and self.redis.ping()
client = await self._get_redis()
await client.ping()
return True
except Exception:
return False
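The refactor replaces the per-service synchronous `redis.from_url` pool above with the shared async client from `shared.redis_utils`. That helper is not shown in this commit, but a typical process-wide singleton built on `redis.asyncio` would look like this (the module-level cache and hard-coded URL are assumptions; the real URL presumably comes from settings):

    import redis.asyncio as aioredis

    _client = None

    async def get_redis_client():
        """Lazily create one shared async Redis client per process."""
        global _client
        if _client is None:
            _client = aioredis.from_url(
                "redis://localhost:6379/0",  # assumption: sourced from configuration
                decode_responses=True,
                health_check_interval=30,
            )
        return _client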
@@ -138,12 +111,13 @@ class ForecastCacheService:
Returns:
Cached forecast data or None if not found
"""
if not self.is_available():
if not await self.is_available():
return None
try:
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
cached_data = self.redis.get(key)
client = await self._get_redis()
cached_data = await client.get(key)
if cached_data:
forecast_data = json.loads(cached_data)
@@ -188,7 +162,7 @@ class ForecastCacheService:
Returns:
True if cached successfully, False otherwise
"""
if not self.is_available():
if not await self.is_available():
logger.warning("Redis not available, skipping forecast cache")
return False
@@ -205,7 +179,8 @@ class ForecastCacheService:
}
# Serialize and cache
self.redis.setex(
client = await self._get_redis()
await client.setex(
key,
ttl,
json.dumps(cache_entry, default=str)
@@ -241,12 +216,13 @@ class ForecastCacheService:
Returns:
Cached batch forecast data or None
"""
if not self.is_available():
if not await self.is_available():
return None
try:
key = self._get_batch_forecast_key(tenant_id, product_ids, forecast_date)
cached_data = self.redis.get(key)
client = await self._get_redis()
cached_data = await client.get(key)
if cached_data:
forecast_data = json.loads(cached_data)
@@ -273,7 +249,7 @@ class ForecastCacheService:
forecast_data: Dict[str, Any]
) -> bool:
"""Cache batch forecast result"""
if not self.is_available():
if not await self.is_available():
return False
try:
@@ -287,7 +263,8 @@ class ForecastCacheService:
'ttl_seconds': ttl
}
self.redis.setex(key, ttl, json.dumps(cache_entry, default=str))
client = await self._get_redis()
await client.setex(key, ttl, json.dumps(cache_entry, default=str))
logger.info("Batch forecast cached successfully",
tenant_id=str(tenant_id),
@@ -320,16 +297,17 @@ class ForecastCacheService:
Returns:
Number of cache entries invalidated
"""
if not self.is_available():
if not await self.is_available():
return 0
try:
# Find all keys matching this product
pattern = f"forecast:{tenant_id}:{product_id}:*"
keys = self.redis.keys(pattern)
client = await self._get_redis()
keys = await client.keys(pattern)
if keys:
deleted = self.redis.delete(*keys)
deleted = await client.delete(*keys)
logger.info("Invalidated product forecast cache",
tenant_id=str(tenant_id),
product_id=str(product_id),
@@ -359,7 +337,7 @@ class ForecastCacheService:
Returns:
Number of cache entries invalidated
"""
if not self.is_available():
if not await self.is_available():
return 0
try:
@@ -368,10 +346,11 @@ class ForecastCacheService:
else:
pattern = f"forecast:{tenant_id}:*"
keys = self.redis.keys(pattern)
client = await self._get_redis()
keys = await client.keys(pattern)
if keys:
deleted = self.redis.delete(*keys)
deleted = await client.delete(*keys)
logger.info("Invalidated tenant forecast cache",
tenant_id=str(tenant_id),
forecast_date=str(forecast_date) if forecast_date else "all",
@@ -391,15 +370,16 @@ class ForecastCacheService:
Returns:
Number of cache entries invalidated
"""
if not self.is_available():
if not await self.is_available():
return 0
try:
pattern = "forecast:*"
keys = self.redis.keys(pattern)
client = await self._get_redis()
keys = await client.keys(pattern)
if keys:
deleted = self.redis.delete(*keys)
deleted = await client.delete(*keys)
logger.warning("Invalidated ALL forecast cache", keys_deleted=deleted)
return deleted
@@ -413,22 +393,23 @@ class ForecastCacheService:
# CACHE STATISTICS & MONITORING
# ================================================================
def get_cache_stats(self) -> Dict[str, Any]:
async def get_cache_stats(self) -> Dict[str, Any]:
"""
Get cache statistics for monitoring
Returns:
Dictionary with cache metrics
"""
if not self.is_available():
if not await self.is_available():
return {"available": False}
try:
info = self.redis.info()
client = await self._get_redis()
info = await client.info()
# Get forecast-specific stats
forecast_keys = self.redis.keys("forecast:*")
batch_keys = self.redis.keys("forecast:batch:*")
forecast_keys = await client.keys("forecast:*")
batch_keys = await client.keys("forecast:batch:*")
return {
"available": True,
@@ -471,12 +452,13 @@ class ForecastCacheService:
Returns:
Cache metadata or None
"""
if not self.is_available():
if not await self.is_available():
return None
try:
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
ttl = self.redis.ttl(key)
client = await self._get_redis()
ttl = await client.ttl(key)
if ttl > 0:
return {
@@ -498,21 +480,16 @@ class ForecastCacheService:
_cache_service = None
def get_forecast_cache_service(redis_url: Optional[str] = None) -> ForecastCacheService:
def get_forecast_cache_service() -> ForecastCacheService:
"""
Get the global forecast cache service instance
Args:
redis_url: Redis connection URL (required for first call)
Returns:
ForecastCacheService instance
"""
global _cache_service
if _cache_service is None:
if redis_url is None:
raise ValueError("redis_url required for first initialization")
_cache_service = ForecastCacheService(redis_url)
_cache_service = ForecastCacheService()
return _cache_service
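After this change the service is constructed without arguments and every cache call is awaited. A short usage sketch using only methods visible in this diff (the import path is assumed):

    from app.services.forecast_cache import get_forecast_cache_service  # assumed module path

    async def report_cache_health() -> dict:
        cache = get_forecast_cache_service()   # no redis_url needed any more
        if not await cache.is_available():     # now async: pings the shared client
            return {"available": False}
        return await cache.get_cache_stats()   # counts forecast:* and forecast:batch:* keys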

View File

@@ -1,18 +1,18 @@
"""initial_schema_20251009_2039
"""initial_schema_20251015_1230
Revision ID: cae963fbc2af
Revision ID: 301bc59f6dfb
Revises:
Create Date: 2025-10-09 20:39:42.106460+02:00
Create Date: 2025-10-15 12:30:42.311369+02:00
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'cae963fbc2af'
revision: str = '301bc59f6dfb'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -20,6 +20,38 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('audit_logs',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('tenant_id', sa.UUID(), nullable=False),
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('action', sa.String(length=100), nullable=False),
sa.Column('resource_type', sa.String(length=100), nullable=False),
sa.Column('resource_id', sa.String(length=255), nullable=True),
sa.Column('severity', sa.String(length=20), nullable=False),
sa.Column('service_name', sa.String(length=100), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('changes', postgresql.JSON(astext_type=sa.Text()), nullable=True),
sa.Column('audit_metadata', postgresql.JSON(astext_type=sa.Text()), nullable=True),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.Text(), nullable=True),
sa.Column('endpoint', sa.String(length=255), nullable=True),
sa.Column('method', sa.String(length=10), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index('idx_audit_resource_type_action', 'audit_logs', ['resource_type', 'action'], unique=False)
op.create_index('idx_audit_service_created', 'audit_logs', ['service_name', 'created_at'], unique=False)
op.create_index('idx_audit_severity_created', 'audit_logs', ['severity', 'created_at'], unique=False)
op.create_index('idx_audit_tenant_created', 'audit_logs', ['tenant_id', 'created_at'], unique=False)
op.create_index('idx_audit_user_created', 'audit_logs', ['user_id', 'created_at'], unique=False)
op.create_index(op.f('ix_audit_logs_action'), 'audit_logs', ['action'], unique=False)
op.create_index(op.f('ix_audit_logs_created_at'), 'audit_logs', ['created_at'], unique=False)
op.create_index(op.f('ix_audit_logs_resource_id'), 'audit_logs', ['resource_id'], unique=False)
op.create_index(op.f('ix_audit_logs_resource_type'), 'audit_logs', ['resource_type'], unique=False)
op.create_index(op.f('ix_audit_logs_service_name'), 'audit_logs', ['service_name'], unique=False)
op.create_index(op.f('ix_audit_logs_severity'), 'audit_logs', ['severity'], unique=False)
op.create_index(op.f('ix_audit_logs_tenant_id'), 'audit_logs', ['tenant_id'], unique=False)
op.create_index(op.f('ix_audit_logs_user_id'), 'audit_logs', ['user_id'], unique=False)
op.create_table('forecasts',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('tenant_id', sa.UUID(), nullable=False),
@@ -125,4 +157,18 @@ def downgrade() -> None:
op.drop_index(op.f('ix_forecasts_inventory_product_id'), table_name='forecasts')
op.drop_index(op.f('ix_forecasts_forecast_date'), table_name='forecasts')
op.drop_table('forecasts')
op.drop_index(op.f('ix_audit_logs_user_id'), table_name='audit_logs')
op.drop_index(op.f('ix_audit_logs_tenant_id'), table_name='audit_logs')
op.drop_index(op.f('ix_audit_logs_severity'), table_name='audit_logs')
op.drop_index(op.f('ix_audit_logs_service_name'), table_name='audit_logs')
op.drop_index(op.f('ix_audit_logs_resource_type'), table_name='audit_logs')
op.drop_index(op.f('ix_audit_logs_resource_id'), table_name='audit_logs')
op.drop_index(op.f('ix_audit_logs_created_at'), table_name='audit_logs')
op.drop_index(op.f('ix_audit_logs_action'), table_name='audit_logs')
op.drop_index('idx_audit_user_created', table_name='audit_logs')
op.drop_index('idx_audit_tenant_created', table_name='audit_logs')
op.drop_index('idx_audit_severity_created', table_name='audit_logs')
op.drop_index('idx_audit_service_created', table_name='audit_logs')
op.drop_index('idx_audit_resource_type_action', table_name='audit_logs')
op.drop_table('audit_logs')
# ### end Alembic commands ###
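The composite indexes mirror the most common audit queries; `idx_audit_tenant_created`, for example, serves a per-tenant timeline ordered by recency. A sketch of such a query with SQLAlchemy against the `AuditLog` model registered above, assuming an async session (import path and session handling are illustrative):

    from sqlalchemy import select
    from app.models import AuditLog  # assumed import path

    async def recent_audit_entries(session, tenant_id, limit: int = 50):
        """Latest audit rows for one tenant; matches idx_audit_tenant_created."""
        stmt = (
            select(AuditLog)
            .where(AuditLog.tenant_id == tenant_id)
            .order_by(AuditLog.created_at.desc())
            .limit(limit)
        )
        result = await session.execute(stmt)
        return result.scalars().all()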