REFACTOR - Database logic
@@ -0,0 +1,16 @@
# ================================================================
# services/forecasting/app/api/__init__.py
# ================================================================
"""
Forecasting API Layer
HTTP endpoints for demand forecasting and prediction operations
"""

from .forecasts import router as forecasts_router
from .predictions import router as predictions_router

__all__ = [
    "forecasts_router",
    "predictions_router",
]
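The package only re-exports the two routers, so the service's application factory is expected to mount them itself. A minimal wiring sketch follows; the app/main module, the FastAPI app name, and the /api/v1 prefix are illustrative assumptions, not part of this commit.

    from fastapi import FastAPI

    from app.api import forecasts_router, predictions_router

    app = FastAPI(title="Forecasting Service")
    app.include_router(forecasts_router, prefix="/api/v1")
    app.include_router(predictions_router, prefix="/api/v1")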
@@ -1,494 +1,503 @@
# ================================================================
# services/forecasting/app/api/forecasts.py
# ================================================================
"""
Enhanced Forecast API Endpoints with Repository Pattern
Updated to use the repository pattern with dependency injection and improved error handling
"""

import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Request
from typing import List, Optional
from datetime import date, datetime

from app.services.forecasting_service import EnhancedForecastingService
from app.schemas.forecasts import (
    ForecastRequest, ForecastResponse, BatchForecastRequest,
    BatchForecastResponse, AlertResponse
)
from shared.auth.decorators import (
    get_current_user_dep,
    get_current_tenant_id_dep,
    require_admin_role
)
from shared.database.base import create_database_manager
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings

logger = structlog.get_logger()
router = APIRouter(tags=["enhanced-forecasts"])


def get_enhanced_forecasting_service():
    """Dependency injection for EnhancedForecastingService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    return EnhancedForecastingService(database_manager)

@router.post("/tenants/{tenant_id}/forecasts/single", response_model=ForecastResponse)
@track_execution_time("enhanced_single_forecast_duration_seconds", "forecasting-service")
async def create_enhanced_single_forecast(
    request: ForecastRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Generate a single product forecast using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_forecast_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        logger.info("Generating enhanced single forecast",
                    tenant_id=tenant_id,
                    product_name=request.product_name,
                    forecast_date=request.forecast_date.isoformat())

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_single_forecasts_total")

        # Generate forecast using enhanced service
        forecast = await enhanced_forecasting_service.generate_forecast(
            tenant_id=tenant_id,
            request=request
        )

        if metrics:
            metrics.increment_counter("enhanced_single_forecasts_success_total")

        logger.info("Enhanced single forecast generated successfully",
                    tenant_id=tenant_id,
                    forecast_id=forecast.id)

        return forecast

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 403 above) unchanged
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_forecast_validation_errors_total")
        logger.error("Enhanced forecast validation error",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_single_forecasts_errors_total")
        logger.error("Enhanced single forecast generation failed",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced forecast generation failed"
        )

@router.post("/tenants/{tenant_id}/forecasts/batch", response_model=BatchForecastResponse)
@track_execution_time("enhanced_batch_forecast_duration_seconds", "forecasting-service")
async def create_enhanced_batch_forecast(
    request: BatchForecastRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Generate batch forecasts using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_batch_forecast_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        logger.info("Generating enhanced batch forecasts",
                    tenant_id=tenant_id,
                    products_count=len(request.products),
                    forecast_dates_count=len(request.forecast_dates))

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_batch_forecasts_total")
            metrics.histogram("enhanced_batch_forecast_products_count", len(request.products))

        # Generate batch forecasts using enhanced service
        batch_result = await enhanced_forecasting_service.generate_batch_forecasts(
            tenant_id=tenant_id,
            request=request
        )

        if metrics:
            metrics.increment_counter("enhanced_batch_forecasts_success_total")

        logger.info("Enhanced batch forecasts generated successfully",
                    tenant_id=tenant_id,
                    batch_id=batch_result.get("batch_id"),
                    forecasts_generated=len(batch_result.get("forecasts", [])))

        return BatchForecastResponse(**batch_result)

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 403 above) unchanged
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_batch_forecast_validation_errors_total")
        logger.error("Enhanced batch forecast validation error",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_batch_forecasts_errors_total")
        logger.error("Enhanced batch forecast generation failed",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced batch forecast generation failed"
        )

@router.get("/tenants/{tenant_id}/forecasts")
@track_execution_time("enhanced_get_forecasts_duration_seconds", "forecasting-service")
async def get_enhanced_tenant_forecasts(
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: Optional[str] = Query(None, description="Filter by product name"),
    start_date: Optional[date] = Query(None, description="Start date filter"),
    end_date: Optional[date] = Query(None, description="End date filter"),
    skip: int = Query(0, description="Number of records to skip"),
    limit: int = Query(100, description="Number of records to return"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Get tenant forecasts with enhanced filtering using the repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_get_forecasts_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_get_forecasts_total")

        # Get forecasts using enhanced service
        forecasts = await enhanced_forecasting_service.get_tenant_forecasts(
            tenant_id=tenant_id,
            product_name=product_name,
            start_date=start_date,
            end_date=end_date,
            skip=skip,
            limit=limit
        )

        if metrics:
            metrics.increment_counter("enhanced_get_forecasts_success_total")

        return {
            "tenant_id": tenant_id,
            "forecasts": forecasts,
            "total_returned": len(forecasts),
            "filters": {
                "product_name": product_name,
                "start_date": start_date.isoformat() if start_date else None,
                "end_date": end_date.isoformat() if end_date else None
            },
            "pagination": {
                "skip": skip,
                "limit": limit
            },
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 403 above) unchanged
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_get_forecasts_errors_total")
        logger.error("Failed to get enhanced tenant forecasts",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get tenant forecasts"
        )

@router.get("/tenants/{tenant_id}/forecasts/{forecast_id}")
@track_execution_time("enhanced_get_forecast_duration_seconds", "forecasting-service")
async def get_enhanced_forecast_by_id(
    tenant_id: str = Path(..., description="Tenant ID"),
    forecast_id: str = Path(..., description="Forecast ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Get a specific forecast by ID using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_get_forecast_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_get_forecast_by_id_total")

        # Get forecast using enhanced service
        forecast = await enhanced_forecasting_service.get_forecast_by_id(forecast_id)

        if not forecast:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Forecast not found"
            )

        if metrics:
            metrics.increment_counter("enhanced_get_forecast_by_id_success_total")

        return {
            **forecast,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_get_forecast_by_id_errors_total")
        logger.error("Failed to get enhanced forecast by ID",
                     forecast_id=forecast_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get forecast"
        )

@router.delete("/tenants/{tenant_id}/forecasts/{forecast_id}")
@track_execution_time("enhanced_delete_forecast_duration_seconds", "forecasting-service")
async def delete_enhanced_forecast(
    tenant_id: str = Path(..., description="Tenant ID"),
    forecast_id: str = Path(..., description="Forecast ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Delete a forecast using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_delete_forecast_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_delete_forecast_total")

        # Delete forecast using enhanced service
        deleted = await enhanced_forecasting_service.delete_forecast(forecast_id)

        if not deleted:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Forecast not found"
            )

        if metrics:
            metrics.increment_counter("enhanced_delete_forecast_success_total")

        logger.info("Enhanced forecast deleted successfully",
                    forecast_id=forecast_id,
                    tenant_id=tenant_id)

        return {
            "message": "Forecast deleted successfully",
            "forecast_id": forecast_id,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_delete_forecast_errors_total")
        logger.error("Failed to delete enhanced forecast",
                     forecast_id=forecast_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete forecast"
        )

@router.get("/tenants/{tenant_id}/forecasts/alerts")
@track_execution_time("enhanced_get_alerts_duration_seconds", "forecasting-service")
async def get_enhanced_forecast_alerts(
    tenant_id: str = Path(..., description="Tenant ID"),
    active_only: bool = Query(True, description="Return only active alerts"),
    skip: int = Query(0, description="Number of records to skip"),
    limit: int = Query(50, description="Number of records to return"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Get forecast alerts using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_get_alerts_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_get_alerts_total")

        # Get alerts using enhanced service
        alerts = await enhanced_forecasting_service.get_tenant_alerts(
            tenant_id=tenant_id,
            active_only=active_only,
            skip=skip,
            limit=limit
        )

        if metrics:
            metrics.increment_counter("enhanced_get_alerts_success_total")

        return {
            "tenant_id": tenant_id,
            "alerts": alerts,
            "total_returned": len(alerts),
            "active_only": active_only,
            "pagination": {
                "skip": skip,
                "limit": limit
            },
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 403 above) unchanged
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_get_alerts_errors_total")
        logger.error("Failed to get enhanced forecast alerts",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get forecast alerts"
        )

@router.get("/tenants/{tenant_id}/forecasts/statistics")
@track_execution_time("enhanced_forecast_statistics_duration_seconds", "forecasting-service")
async def get_enhanced_forecast_statistics(
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Get comprehensive forecast statistics using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_forecast_statistics_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_forecast_statistics_total")

        # Get statistics using enhanced service
        statistics = await enhanced_forecasting_service.get_tenant_forecast_statistics(tenant_id)

        if statistics.get("error"):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=statistics["error"]
            )

        if metrics:
            metrics.increment_counter("enhanced_forecast_statistics_success_total")

        return {
            **statistics,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_forecast_statistics_errors_total")
        logger.error("Failed to get enhanced forecast statistics",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get forecast statistics"
        )

@router.get("/health")
async def enhanced_health_check():
    """Enhanced health check endpoint for the forecasting service"""
    return {
        "status": "healthy",
        "service": "enhanced-forecasting-service",
        "version": "2.0.0",
        "features": [
            "repository-pattern",
            "dependency-injection",
            "enhanced-error-handling",
            "metrics-tracking",
            "transactional-operations",
            "batch-processing"
        ],
        "timestamp": datetime.now().isoformat()
    }
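Because every endpoint above resolves its service through Depends(get_enhanced_forecasting_service), the database-backed service can be swapped out in tests with FastAPI's dependency_overrides. A minimal test sketch follows; the stub class, the tenant value, and the idea of also overriding get_current_tenant_id_dep are illustrative assumptions, not part of this commit.

    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from app.api.forecasts import router, get_enhanced_forecasting_service
    from shared.auth.decorators import get_current_tenant_id_dep

    class StubForecastingService:
        # Minimal stand-in that avoids any database access
        async def get_forecast_by_id(self, forecast_id):
            return {"id": forecast_id, "predicted_demand": 42.0}

    app = FastAPI()
    app.include_router(router)
    app.dependency_overrides[get_enhanced_forecasting_service] = StubForecastingService
    app.dependency_overrides[get_current_tenant_id_dep] = lambda: "tenant-123"

    client = TestClient(app)
    # e.g. client.get("/tenants/tenant-123/forecasts/some-forecast-id")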
@@ -1,271 +1,468 @@
# ================================================================
# services/forecasting/app/api/predictions.py
# ================================================================
"""
Enhanced Predictions API Endpoints with Repository Pattern
Real-time prediction capabilities using the repository pattern with dependency injection
"""

import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Request
from typing import List, Dict, Any, Optional
from datetime import date, datetime, timedelta

from app.services.forecasting_service import EnhancedForecastingService
from app.services.prediction_service import PredictionService
from app.schemas.forecasts import ForecastRequest
from shared.auth.decorators import (
    get_current_user_dep,
    get_current_tenant_id_dep,
    require_admin_role
)
from shared.database.base import create_database_manager
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings

logger = structlog.get_logger()
router = APIRouter(tags=["enhanced-predictions"])


def get_enhanced_prediction_service():
    """Dependency injection for the enhanced PredictionService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    return PredictionService(database_manager)


def get_enhanced_forecasting_service():
    """Dependency injection for EnhancedForecastingService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    return EnhancedForecastingService(database_manager)

@router.post("/tenants/{tenant_id}/predictions/realtime")
@track_execution_time("enhanced_realtime_prediction_duration_seconds", "forecasting-service")
async def generate_enhanced_realtime_prediction(
    prediction_request: Dict[str, Any],
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
    """Generate a real-time prediction using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_realtime_prediction_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        logger.info("Generating enhanced real-time prediction",
                    tenant_id=tenant_id,
                    product_name=prediction_request.get("product_name"))

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_realtime_predictions_total")

        # Validate required fields
        required_fields = ["product_name", "model_id", "features"]
        missing_fields = [field for field in required_fields if field not in prediction_request]
        if missing_fields:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Missing required fields: {missing_fields}"
            )

        # Generate prediction using enhanced service
        prediction_result = await prediction_service.predict(
            model_id=prediction_request["model_id"],
            model_path=prediction_request.get("model_path", ""),
            features=prediction_request["features"],
            confidence_level=prediction_request.get("confidence_level", 0.8)
        )

        if metrics:
            metrics.increment_counter("enhanced_realtime_predictions_success_total")

        logger.info("Enhanced real-time prediction generated successfully",
                    tenant_id=tenant_id,
                    prediction_value=prediction_result.get("prediction"))

        return {
            "tenant_id": tenant_id,
            "product_name": prediction_request["product_name"],
            "model_id": prediction_request["model_id"],
            "prediction": prediction_result,
            "generated_at": datetime.now().isoformat(),
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_prediction_validation_errors_total")
        logger.error("Enhanced prediction validation error",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_realtime_predictions_errors_total")
        logger.error("Enhanced real-time prediction failed",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced real-time prediction failed"
        )

@router.post("/tenants/{tenant_id}/predictions/batch")
@track_execution_time("enhanced_batch_prediction_duration_seconds", "forecasting-service")
async def generate_enhanced_batch_predictions(
    batch_request: Dict[str, Any],
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Generate batch predictions using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_batch_prediction_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        logger.info("Generating enhanced batch predictions",
                    tenant_id=tenant_id,
                    predictions_count=len(batch_request.get("predictions", [])))

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_batch_predictions_total")
            metrics.histogram("enhanced_batch_predictions_count", len(batch_request.get("predictions", [])))

        # Validate batch request
        if "predictions" not in batch_request or not batch_request["predictions"]:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Batch request must contain 'predictions' array"
            )

        # Generate batch predictions using enhanced service
        batch_result = await enhanced_forecasting_service.generate_batch_predictions(
            tenant_id=tenant_id,
            batch_request=batch_request
        )

        if metrics:
            metrics.increment_counter("enhanced_batch_predictions_success_total")

        logger.info("Enhanced batch predictions generated successfully",
                    tenant_id=tenant_id,
                    batch_id=batch_result.get("batch_id"),
                    predictions_generated=len(batch_result.get("predictions", [])))

        return {
            **batch_result,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 403/400 above) unchanged
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_batch_prediction_validation_errors_total")
        logger.error("Enhanced batch prediction validation error",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_batch_predictions_errors_total")
        logger.error("Enhanced batch predictions failed",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced batch predictions failed"
        )

@router.get("/tenants/{tenant_id}/predictions/cache")
@track_execution_time("enhanced_get_prediction_cache_duration_seconds", "forecasting-service")
async def get_enhanced_prediction_cache(
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: Optional[str] = Query(None, description="Filter by product name"),
    skip: int = Query(0, description="Number of records to skip"),
    limit: int = Query(100, description="Number of records to return"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Get cached predictions using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_get_cache_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_get_prediction_cache_total")

        # Get cached predictions using enhanced service
        cached_predictions = await enhanced_forecasting_service.get_cached_predictions(
            tenant_id=tenant_id,
            product_name=product_name,
            skip=skip,
            limit=limit
        )

        if metrics:
            metrics.increment_counter("enhanced_get_prediction_cache_success_total")

        return {
            "tenant_id": tenant_id,
            "cached_predictions": cached_predictions,
            "total_returned": len(cached_predictions),
            "filters": {
                "product_name": product_name
            },
            "pagination": {
                "skip": skip,
                "limit": limit
            },
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 403 above) unchanged
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_get_prediction_cache_errors_total")
        logger.error("Failed to get enhanced prediction cache",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get prediction cache"
        )

@router.delete("/tenants/{tenant_id}/predictions/cache")
@track_execution_time("enhanced_clear_prediction_cache_duration_seconds", "forecasting-service")
async def clear_enhanced_prediction_cache(
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: Optional[str] = Query(None, description="Clear cache for specific product"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Clear prediction cache using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_clear_cache_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_clear_prediction_cache_total")

        # Clear cache using enhanced service
        cleared_count = await enhanced_forecasting_service.clear_prediction_cache(
            tenant_id=tenant_id,
            product_name=product_name
        )

        if metrics:
            metrics.increment_counter("enhanced_clear_prediction_cache_success_total")
            metrics.histogram("enhanced_cache_cleared_count", cleared_count)

        logger.info("Enhanced prediction cache cleared",
                    tenant_id=tenant_id,
                    product_name=product_name,
                    cleared_count=cleared_count)

        return {
            "message": "Prediction cache cleared successfully",
            "tenant_id": tenant_id,
            "product_name": product_name,
            "cleared_count": cleared_count,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 403 above) unchanged
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_clear_prediction_cache_errors_total")
        logger.error("Failed to clear enhanced prediction cache",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to clear prediction cache"
        )

@router.get("/tenants/{tenant_id}/predictions/performance")
@track_execution_time("enhanced_get_prediction_performance_duration_seconds", "forecasting-service")
async def get_enhanced_prediction_performance(
    tenant_id: str = Path(..., description="Tenant ID"),
    model_id: Optional[str] = Query(None, description="Filter by model ID"),
    start_date: Optional[date] = Query(None, description="Start date filter"),
    end_date: Optional[date] = Query(None, description="End date filter"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Get prediction performance metrics using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_get_performance_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_get_prediction_performance_total")

        # Get performance metrics using enhanced service
        performance = await enhanced_forecasting_service.get_prediction_performance(
            tenant_id=tenant_id,
            model_id=model_id,
            start_date=start_date,
            end_date=end_date
        )

        if metrics:
            metrics.increment_counter("enhanced_get_prediction_performance_success_total")

        return {
            "tenant_id": tenant_id,
            "performance_metrics": performance,
            "filters": {
                "model_id": model_id,
                "start_date": start_date.isoformat() if start_date else None,
                "end_date": end_date.isoformat() if end_date else None
            },
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 403 above) unchanged
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_get_prediction_performance_errors_total")
        logger.error("Failed to get enhanced prediction performance",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get prediction performance"
        )

@router.post("/tenants/{tenant_id}/predictions/validate")
@track_execution_time("enhanced_validate_prediction_duration_seconds", "forecasting-service")
async def validate_enhanced_prediction_request(
    validation_request: Dict[str, Any],
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
    """Validate a prediction request without generating a prediction"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_validate_prediction_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_validate_prediction_total")

        # Validate prediction request
        validation_result = await prediction_service.validate_prediction_request(
            validation_request
        )

        if metrics:
            if validation_result.get("is_valid"):
                metrics.increment_counter("enhanced_validate_prediction_success_total")
            else:
                metrics.increment_counter("enhanced_validate_prediction_failed_total")

        return {
            "tenant_id": tenant_id,
            "validation_result": validation_result,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 403 above) unchanged
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_validate_prediction_errors_total")
        logger.error("Failed to validate enhanced prediction request",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to validate prediction request"
        )

@router.get("/health")
async def enhanced_predictions_health_check():
    """Enhanced health check endpoint for predictions"""
    return {
        "status": "healthy",
        "service": "enhanced-predictions-service",
        "version": "2.0.0",
        "features": [
            "repository-pattern",
            "dependency-injection",
            "realtime-predictions",
            "batch-predictions",
            "prediction-caching",
            "performance-metrics",
            "request-validation"
        ],
        "timestamp": datetime.now().isoformat()
    }
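The real-time endpoint above accepts a free-form JSON body and enforces product_name, model_id, and features itself, so a call is easiest to read as a plain HTTP request. A client sketch using httpx follows; the host, tenant ID, token, and field values are placeholders, and model_path / confidence_level are optional per the endpoint defaults.

    import httpx

    payload = {
        "product_name": "espresso_beans",   # illustrative value
        "model_id": "model-1234",           # illustrative value
        "features": {"day_of_week": 4, "is_weekend": False},
        "confidence_level": 0.9,            # optional, endpoint defaults to 0.8
    }

    response = httpx.post(
        "http://localhost:8000/tenants/<tenant-id>/predictions/realtime",
        json=payload,
        headers={"Authorization": "Bearer <token>"},  # auth scheme assumed
    )
    print(response.status_code, response.json())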