Add improvements
This commit is contained in:
@@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import date, datetime, timezone
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.forecasting_service import EnhancedForecastingService
|
||||
from app.services.prediction_service import PredictionService
|
||||
@@ -42,6 +43,30 @@ async def get_rate_limiter():
|
||||
return create_rate_limiter(redis_client)
|
||||
|
||||
|
||||
def validate_uuid(value: str, field_name: str = "ID") -> str:
|
||||
"""
|
||||
Validate that a string is a valid UUID.
|
||||
|
||||
Args:
|
||||
value: The string to validate
|
||||
field_name: Name of the field for error messages
|
||||
|
||||
Returns:
|
||||
The validated UUID string
|
||||
|
||||
Raises:
|
||||
HTTPException: If the value is not a valid UUID
|
||||
"""
|
||||
try:
|
||||
UUID(value)
|
||||
return value
|
||||
except (ValueError, AttributeError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"{field_name} must be a valid UUID, got: {value}"
|
||||
)
|
||||
|
||||
|
||||
def get_enhanced_forecasting_service():
|
||||
"""Dependency injection for EnhancedForecastingService"""
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
|
||||
@@ -68,6 +93,10 @@ async def generate_single_forecast(
|
||||
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
||||
):
|
||||
"""Generate a single product forecast with caching support"""
|
||||
# Validate UUID fields
|
||||
validate_uuid(tenant_id, "tenant_id")
|
||||
# inventory_product_id already validated by ForecastRequest schema
|
||||
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
|
||||
@@ -28,15 +28,6 @@ router = APIRouter(prefix="/internal/demo", tags=["internal"])
|
||||
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
|
||||
|
||||
|
||||
def verify_internal_api_key(x_internal_api_key: Optional[str] = Header(None)):
|
||||
"""Verify internal API key for service-to-service communication"""
|
||||
from app.core.config import settings
|
||||
if x_internal_api_key != settings.INTERNAL_API_KEY:
|
||||
logger.warning("Unauthorized internal API access attempted")
|
||||
raise HTTPException(status_code=403, detail="Invalid internal API key")
|
||||
return True
|
||||
|
||||
|
||||
def parse_date_field(date_value, session_time: datetime, field_name: str = "date") -> Optional[datetime]:
|
||||
"""
|
||||
Parse date field, handling both ISO strings and BASE_TS markers.
|
||||
@@ -98,8 +89,7 @@ async def clone_demo_data(
|
||||
demo_account_type: str,
|
||||
session_id: Optional[str] = None,
|
||||
session_created_at: Optional[str] = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: bool = Depends(verify_internal_api_key)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Clone forecasting service data for a virtual demo tenant
|
||||
@@ -406,7 +396,7 @@ async def clone_demo_data(
|
||||
|
||||
|
||||
@router.get("/clone/health")
|
||||
async def clone_health_check(_: bool = Depends(verify_internal_api_key)):
|
||||
async def clone_health_check():
|
||||
"""
|
||||
Health check for internal cloning endpoint
|
||||
Used by orchestrator to verify service availability
|
||||
@@ -421,8 +411,7 @@ async def clone_health_check(_: bool = Depends(verify_internal_api_key)):
|
||||
@router.delete("/tenant/{virtual_tenant_id}")
|
||||
async def delete_demo_tenant_data(
|
||||
virtual_tenant_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: bool = Depends(verify_internal_api_key)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Delete all demo data for a virtual tenant.
|
||||
|
||||
Reference in New Issue
Block a user