Add role-based filtering and improve code
This commit is contained in:
@@ -12,6 +12,7 @@ from app.services.prediction_service import PredictionService
|
||||
from shared.database.base import create_database_manager
|
||||
from app.core.config import settings
|
||||
from shared.routing import RouteBuilder
|
||||
from shared.auth.access_control import analytics_tier_required
|
||||
|
||||
route_builder = RouteBuilder('forecasting')
|
||||
logger = structlog.get_logger()
|
||||
@@ -27,13 +28,14 @@ def get_enhanced_prediction_service():
|
||||
@router.get(
|
||||
route_builder.build_analytics_route("predictions-performance")
|
||||
)
|
||||
@analytics_tier_required
|
||||
async def get_predictions_performance(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
start_date: Optional[date] = Query(None),
|
||||
end_date: Optional[date] = Query(None),
|
||||
prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
|
||||
):
|
||||
"""Get predictions performance analytics"""
|
||||
"""Get predictions performance analytics (Professional+ tier required)"""
|
||||
try:
|
||||
logger.info("Getting predictions performance", tenant_id=tenant_id)
|
||||
|
||||
|
||||
@@ -23,11 +23,22 @@ from shared.monitoring.metrics import get_metrics_collector
|
||||
from app.core.config import settings
|
||||
from shared.routing import RouteBuilder
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
|
||||
from shared.subscription.plans import get_forecast_quota, get_forecast_horizon_limit
|
||||
from shared.redis_utils import get_redis_client
|
||||
|
||||
route_builder = RouteBuilder('forecasting')
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter(tags=["forecasting-operations"])
|
||||
|
||||
# Initialize audit logger
|
||||
audit_logger = create_audit_logger("forecasting-service")
|
||||
|
||||
async def get_rate_limiter():
|
||||
"""Dependency for rate limiter"""
|
||||
redis_client = await get_redis_client()
|
||||
return create_rate_limiter(redis_client)
|
||||
|
||||
|
||||
def get_enhanced_forecasting_service():
|
||||
"""Dependency injection for EnhancedForecastingService"""
|
||||
@@ -194,16 +205,17 @@ async def generate_multi_day_forecast(
|
||||
route_builder.build_operations_route("batch"),
|
||||
response_model=BatchForecastResponse
|
||||
)
|
||||
@require_user_role(['viewer', 'member', 'admin', 'owner'])
|
||||
@require_user_role(['admin', 'owner'])
|
||||
@track_execution_time("enhanced_batch_forecast_duration_seconds", "forecasting-service")
|
||||
async def generate_batch_forecast(
|
||||
request: BatchForecastRequest,
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
request_obj: Request = None,
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
||||
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service),
|
||||
rate_limiter = Depends(get_rate_limiter)
|
||||
):
|
||||
"""Generate forecasts for multiple products in batch"""
|
||||
"""Generate forecasts for multiple products in batch (Admin+ only, quota enforced)"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
@@ -217,6 +229,24 @@ async def generate_batch_forecast(
|
||||
if not request.inventory_product_ids:
|
||||
raise ValueError("inventory_product_ids cannot be empty")
|
||||
|
||||
# Get subscription tier and enforce quotas
|
||||
tier = current_user.get('subscription_tier', 'starter')
|
||||
|
||||
# Check daily quota for forecast generation
|
||||
quota_limit = get_forecast_quota(tier)
|
||||
quota_result = await rate_limiter.check_and_increment_quota(
|
||||
tenant_id,
|
||||
"forecast_generation",
|
||||
quota_limit,
|
||||
period=86400 # 24 hours
|
||||
)
|
||||
|
||||
# Validate forecast horizon if specified
|
||||
if request.horizon_days:
|
||||
await rate_limiter.validate_forecast_horizon(
|
||||
tenant_id, request.horizon_days, tier
|
||||
)
|
||||
|
||||
batch_result = await enhanced_forecasting_service.generate_batch_forecast(
|
||||
tenant_id=tenant_id,
|
||||
request=request
|
||||
|
||||
@@ -26,7 +26,7 @@ from shared.monitoring.decorators import track_execution_time
|
||||
from shared.monitoring.metrics import get_metrics_collector
|
||||
from app.core.config import settings
|
||||
from shared.routing import RouteBuilder
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.auth.access_control import require_user_role, enterprise_tier_required
|
||||
|
||||
route_builder = RouteBuilder('forecasting')
|
||||
logger = structlog.get_logger()
|
||||
@@ -43,12 +43,14 @@ def get_enhanced_forecasting_service():
|
||||
route_builder.build_analytics_route("scenario-simulation"),
|
||||
response_model=ScenarioSimulationResponse
|
||||
)
|
||||
@require_user_role(['viewer', 'member', 'admin', 'owner'])
|
||||
@require_user_role(['admin', 'owner'])
|
||||
@enterprise_tier_required
|
||||
@track_execution_time("scenario_simulation_duration_seconds", "forecasting-service")
|
||||
async def simulate_scenario(
|
||||
request: ScenarioSimulationRequest,
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
request_obj: Request = None,
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
||||
):
|
||||
"""
|
||||
@@ -62,7 +64,7 @@ async def simulate_scenario(
|
||||
- Promotions
|
||||
- Supply disruptions
|
||||
|
||||
**PROFESSIONAL/ENTERPRISE ONLY**
|
||||
**ENTERPRISE TIER ONLY - Admin+ role required**
|
||||
"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
@@ -4,6 +4,13 @@ Forecasting Service Models Package
|
||||
Import all models to ensure they are registered with SQLAlchemy Base.
|
||||
"""
|
||||
|
||||
# Import AuditLog model for this service
|
||||
from shared.security import create_audit_log_model
|
||||
from shared.database.base import Base
|
||||
|
||||
# Create audit log model for this service
|
||||
AuditLog = create_audit_log_model(Base)
|
||||
|
||||
# Import all models to register them with the Base metadata
|
||||
from .forecasts import Forecast, PredictionBatch
|
||||
from .predictions import ModelPerformanceMetric, PredictionCache
|
||||
@@ -14,4 +21,5 @@ __all__ = [
|
||||
"PredictionBatch",
|
||||
"ModelPerformanceMetric",
|
||||
"PredictionCache",
|
||||
"AuditLog",
|
||||
]
|
||||
@@ -14,11 +14,11 @@ Cache Strategy:
|
||||
"""
|
||||
|
||||
import json
|
||||
import redis
|
||||
from datetime import datetime, date, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
from shared.redis_utils import get_redis_client
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@@ -26,47 +26,20 @@ logger = structlog.get_logger()
|
||||
class ForecastCacheService:
|
||||
"""Service-level caching for forecast predictions"""
|
||||
|
||||
def __init__(self, redis_url: str):
|
||||
"""
|
||||
Initialize Redis connection for forecast caching
|
||||
def __init__(self):
|
||||
"""Initialize forecast cache service"""
|
||||
pass
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL
|
||||
"""
|
||||
self.redis_url = redis_url
|
||||
self._redis_client = None
|
||||
self._connect()
|
||||
async def _get_redis(self):
|
||||
"""Get shared Redis client"""
|
||||
return await get_redis_client()
|
||||
|
||||
def _connect(self):
|
||||
"""Establish Redis connection with retry logic"""
|
||||
try:
|
||||
self._redis_client = redis.from_url(
|
||||
self.redis_url,
|
||||
decode_responses=True,
|
||||
socket_keepalive=True,
|
||||
socket_keepalive_options={1: 1, 3: 3, 5: 5},
|
||||
retry_on_timeout=True,
|
||||
max_connections=100, # Higher limit for forecast service
|
||||
health_check_interval=30
|
||||
)
|
||||
# Test connection
|
||||
self._redis_client.ping()
|
||||
logger.info("Forecast cache Redis connection established")
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to forecast cache Redis", error=str(e))
|
||||
self._redis_client = None
|
||||
|
||||
@property
|
||||
def redis(self):
|
||||
"""Get Redis client with connection check"""
|
||||
if self._redis_client is None:
|
||||
self._connect()
|
||||
return self._redis_client
|
||||
|
||||
def is_available(self) -> bool:
|
||||
async def is_available(self) -> bool:
|
||||
"""Check if Redis cache is available"""
|
||||
try:
|
||||
return self.redis is not None and self.redis.ping()
|
||||
client = await self._get_redis()
|
||||
await client.ping()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -138,12 +111,13 @@ class ForecastCacheService:
|
||||
Returns:
|
||||
Cached forecast data or None if not found
|
||||
"""
|
||||
if not self.is_available():
|
||||
if not await self.is_available():
|
||||
return None
|
||||
|
||||
try:
|
||||
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
|
||||
cached_data = self.redis.get(key)
|
||||
client = await self._get_redis()
|
||||
cached_data = await client.get(key)
|
||||
|
||||
if cached_data:
|
||||
forecast_data = json.loads(cached_data)
|
||||
@@ -188,7 +162,7 @@ class ForecastCacheService:
|
||||
Returns:
|
||||
True if cached successfully, False otherwise
|
||||
"""
|
||||
if not self.is_available():
|
||||
if not await self.is_available():
|
||||
logger.warning("Redis not available, skipping forecast cache")
|
||||
return False
|
||||
|
||||
@@ -205,7 +179,8 @@ class ForecastCacheService:
|
||||
}
|
||||
|
||||
# Serialize and cache
|
||||
self.redis.setex(
|
||||
client = await self._get_redis()
|
||||
await client.setex(
|
||||
key,
|
||||
ttl,
|
||||
json.dumps(cache_entry, default=str)
|
||||
@@ -241,12 +216,13 @@ class ForecastCacheService:
|
||||
Returns:
|
||||
Cached batch forecast data or None
|
||||
"""
|
||||
if not self.is_available():
|
||||
if not await self.is_available():
|
||||
return None
|
||||
|
||||
try:
|
||||
key = self._get_batch_forecast_key(tenant_id, product_ids, forecast_date)
|
||||
cached_data = self.redis.get(key)
|
||||
client = await self._get_redis()
|
||||
cached_data = await client.get(key)
|
||||
|
||||
if cached_data:
|
||||
forecast_data = json.loads(cached_data)
|
||||
@@ -273,7 +249,7 @@ class ForecastCacheService:
|
||||
forecast_data: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Cache batch forecast result"""
|
||||
if not self.is_available():
|
||||
if not await self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -287,7 +263,8 @@ class ForecastCacheService:
|
||||
'ttl_seconds': ttl
|
||||
}
|
||||
|
||||
self.redis.setex(key, ttl, json.dumps(cache_entry, default=str))
|
||||
client = await self._get_redis()
|
||||
await client.setex(key, ttl, json.dumps(cache_entry, default=str))
|
||||
|
||||
logger.info("Batch forecast cached successfully",
|
||||
tenant_id=str(tenant_id),
|
||||
@@ -320,16 +297,17 @@ class ForecastCacheService:
|
||||
Returns:
|
||||
Number of cache entries invalidated
|
||||
"""
|
||||
if not self.is_available():
|
||||
if not await self.is_available():
|
||||
return 0
|
||||
|
||||
try:
|
||||
# Find all keys matching this product
|
||||
pattern = f"forecast:{tenant_id}:{product_id}:*"
|
||||
keys = self.redis.keys(pattern)
|
||||
client = await self._get_redis()
|
||||
keys = await client.keys(pattern)
|
||||
|
||||
if keys:
|
||||
deleted = self.redis.delete(*keys)
|
||||
deleted = await client.delete(*keys)
|
||||
logger.info("Invalidated product forecast cache",
|
||||
tenant_id=str(tenant_id),
|
||||
product_id=str(product_id),
|
||||
@@ -359,7 +337,7 @@ class ForecastCacheService:
|
||||
Returns:
|
||||
Number of cache entries invalidated
|
||||
"""
|
||||
if not self.is_available():
|
||||
if not await self.is_available():
|
||||
return 0
|
||||
|
||||
try:
|
||||
@@ -368,10 +346,11 @@ class ForecastCacheService:
|
||||
else:
|
||||
pattern = f"forecast:{tenant_id}:*"
|
||||
|
||||
keys = self.redis.keys(pattern)
|
||||
client = await self._get_redis()
|
||||
keys = await client.keys(pattern)
|
||||
|
||||
if keys:
|
||||
deleted = self.redis.delete(*keys)
|
||||
deleted = await client.delete(*keys)
|
||||
logger.info("Invalidated tenant forecast cache",
|
||||
tenant_id=str(tenant_id),
|
||||
forecast_date=str(forecast_date) if forecast_date else "all",
|
||||
@@ -391,15 +370,16 @@ class ForecastCacheService:
|
||||
Returns:
|
||||
Number of cache entries invalidated
|
||||
"""
|
||||
if not self.is_available():
|
||||
if not await self.is_available():
|
||||
return 0
|
||||
|
||||
try:
|
||||
pattern = "forecast:*"
|
||||
keys = self.redis.keys(pattern)
|
||||
client = await self._get_redis()
|
||||
keys = await client.keys(pattern)
|
||||
|
||||
if keys:
|
||||
deleted = self.redis.delete(*keys)
|
||||
deleted = await client.delete(*keys)
|
||||
logger.warning("Invalidated ALL forecast cache", keys_deleted=deleted)
|
||||
return deleted
|
||||
|
||||
@@ -413,22 +393,23 @@ class ForecastCacheService:
|
||||
# CACHE STATISTICS & MONITORING
|
||||
# ================================================================
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
async def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics for monitoring
|
||||
|
||||
Returns:
|
||||
Dictionary with cache metrics
|
||||
"""
|
||||
if not self.is_available():
|
||||
if not await self.is_available():
|
||||
return {"available": False}
|
||||
|
||||
try:
|
||||
info = self.redis.info()
|
||||
client = await self._get_redis()
|
||||
info = await client.info()
|
||||
|
||||
# Get forecast-specific stats
|
||||
forecast_keys = self.redis.keys("forecast:*")
|
||||
batch_keys = self.redis.keys("forecast:batch:*")
|
||||
forecast_keys = await client.keys("forecast:*")
|
||||
batch_keys = await client.keys("forecast:batch:*")
|
||||
|
||||
return {
|
||||
"available": True,
|
||||
@@ -471,12 +452,13 @@ class ForecastCacheService:
|
||||
Returns:
|
||||
Cache metadata or None
|
||||
"""
|
||||
if not self.is_available():
|
||||
if not await self.is_available():
|
||||
return None
|
||||
|
||||
try:
|
||||
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
|
||||
ttl = self.redis.ttl(key)
|
||||
client = await self._get_redis()
|
||||
ttl = await client.ttl(key)
|
||||
|
||||
if ttl > 0:
|
||||
return {
|
||||
@@ -498,21 +480,16 @@ class ForecastCacheService:
|
||||
_cache_service = None
|
||||
|
||||
|
||||
def get_forecast_cache_service(redis_url: Optional[str] = None) -> ForecastCacheService:
|
||||
def get_forecast_cache_service() -> ForecastCacheService:
|
||||
"""
|
||||
Get the global forecast cache service instance
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL (required for first call)
|
||||
|
||||
Returns:
|
||||
ForecastCacheService instance
|
||||
"""
|
||||
global _cache_service
|
||||
|
||||
if _cache_service is None:
|
||||
if redis_url is None:
|
||||
raise ValueError("redis_url required for first initialization")
|
||||
_cache_service = ForecastCacheService(redis_url)
|
||||
_cache_service = ForecastCacheService()
|
||||
|
||||
return _cache_service
|
||||
|
||||
Reference in New Issue
Block a user