413 lines
16 KiB
Python
413 lines
16 KiB
Python
"""
|
|
Enhanced Predictions API Endpoints with Repository Pattern
|
|
Real-time prediction capabilities using repository pattern with dependency injection
|
|
"""
|
|
|
|
import structlog
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Request
|
|
from typing import List, Dict, Any, Optional
|
|
from datetime import date, datetime, timedelta
|
|
import uuid
|
|
|
|
from app.services.prediction_service import PredictionService
|
|
from app.services.forecasting_service import EnhancedForecastingService
|
|
from app.schemas.forecasts import ForecastRequest
|
|
from shared.auth.decorators import (
|
|
get_current_user_dep,
|
|
require_admin_role
|
|
)
|
|
from shared.database.base import create_database_manager
|
|
from shared.monitoring.decorators import track_execution_time
|
|
from shared.monitoring.metrics import get_metrics_collector
|
|
from app.core.config import settings
|
|
|
|
# Router that every enhanced-prediction endpoint in this module registers on;
# the tag groups these routes together in the generated OpenAPI docs.
router = APIRouter(tags=["enhanced-predictions"])

# Structured logger shared by all endpoints below.
logger = structlog.get_logger()
|
|
|
|
def get_enhanced_prediction_service():
    """FastAPI dependency that builds a ``PredictionService``.

    Each call creates a database manager for ``settings.DATABASE_URL``
    (labelled "forecasting-service") and wraps it in a new service instance.
    """
    # NOTE(review): a new database manager is created per call (i.e. per
    # request when used via Depends) — confirm create_database_manager is
    # cheap or shares its pool internally.
    return PredictionService(
        create_database_manager(settings.DATABASE_URL, "forecasting-service")
    )
|
|
|
|
def get_enhanced_forecasting_service():
    """FastAPI dependency that builds an ``EnhancedForecastingService``.

    Each call creates a database manager for ``settings.DATABASE_URL``
    (labelled "forecasting-service") and wraps it in a new service instance.
    """
    return EnhancedForecastingService(
        create_database_manager(settings.DATABASE_URL, "forecasting-service")
    )
|
|
|
|
@router.post("/tenants/{tenant_id}/predictions/realtime")
|
|
@track_execution_time("enhanced_realtime_prediction_duration_seconds", "forecasting-service")
|
|
async def generate_enhanced_realtime_prediction(
|
|
prediction_request: Dict[str, Any],
|
|
tenant_id: str = Path(..., description="Tenant ID"),
|
|
request_obj: Request = None,
|
|
prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
|
|
):
|
|
"""Generate real-time prediction using enhanced repository pattern"""
|
|
metrics = get_metrics_collector(request_obj)
|
|
|
|
try:
|
|
|
|
logger.info("Generating enhanced real-time prediction",
|
|
tenant_id=tenant_id,
|
|
inventory_product_id=prediction_request.get("inventory_product_id"))
|
|
|
|
# Record metrics
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_realtime_predictions_total")
|
|
|
|
# Validate required fields
|
|
required_fields = ["inventory_product_id", "model_id", "features"]
|
|
missing_fields = [field for field in required_fields if field not in prediction_request]
|
|
if missing_fields:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail=f"Missing required fields: {missing_fields}"
|
|
)
|
|
|
|
# Generate prediction using enhanced service
|
|
prediction_result = await prediction_service.predict(
|
|
model_id=prediction_request["model_id"],
|
|
model_path=prediction_request.get("model_path", ""),
|
|
features=prediction_request["features"],
|
|
confidence_level=prediction_request.get("confidence_level", 0.8)
|
|
)
|
|
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_realtime_predictions_success_total")
|
|
|
|
logger.info("Enhanced real-time prediction generated successfully",
|
|
tenant_id=tenant_id,
|
|
prediction_value=prediction_result.get("prediction"))
|
|
|
|
return {
|
|
"tenant_id": tenant_id,
|
|
"inventory_product_id": prediction_request["inventory_product_id"],
|
|
"model_id": prediction_request["model_id"],
|
|
"prediction": prediction_result,
|
|
"generated_at": datetime.now().isoformat(),
|
|
"enhanced_features": True,
|
|
"repository_integration": True
|
|
}
|
|
|
|
except ValueError as e:
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_prediction_validation_errors_total")
|
|
logger.error("Enhanced prediction validation error",
|
|
error=str(e),
|
|
tenant_id=tenant_id)
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail=str(e)
|
|
)
|
|
except Exception as e:
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_realtime_predictions_errors_total")
|
|
logger.error("Enhanced real-time prediction failed",
|
|
error=str(e),
|
|
tenant_id=tenant_id)
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail="Enhanced real-time prediction failed"
|
|
)
|
|
|
|
|
|
@router.post("/tenants/{tenant_id}/predictions/batch")
|
|
@track_execution_time("enhanced_batch_prediction_duration_seconds", "forecasting-service")
|
|
async def generate_enhanced_batch_predictions(
|
|
batch_request: Dict[str, Any],
|
|
tenant_id: str = Path(..., description="Tenant ID"),
|
|
request_obj: Request = None,
|
|
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
|
):
|
|
"""Generate batch predictions using enhanced repository pattern"""
|
|
metrics = get_metrics_collector(request_obj)
|
|
|
|
try:
|
|
|
|
logger.info("Generating enhanced batch predictions",
|
|
tenant_id=tenant_id,
|
|
predictions_count=len(batch_request.get("predictions", [])))
|
|
|
|
# Record metrics
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_batch_predictions_total")
|
|
metrics.histogram("enhanced_batch_predictions_count", len(batch_request.get("predictions", [])))
|
|
|
|
# Validate batch request
|
|
if "predictions" not in batch_request or not batch_request["predictions"]:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail="Batch request must contain 'predictions' array"
|
|
)
|
|
|
|
# Generate batch predictions using enhanced service
|
|
batch_result = await enhanced_forecasting_service.generate_batch_predictions(
|
|
tenant_id=tenant_id,
|
|
batch_request=batch_request
|
|
)
|
|
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_batch_predictions_success_total")
|
|
|
|
logger.info("Enhanced batch predictions generated successfully",
|
|
tenant_id=tenant_id,
|
|
batch_id=batch_result.get("batch_id"),
|
|
predictions_generated=len(batch_result.get("predictions", [])))
|
|
|
|
return {
|
|
**batch_result,
|
|
"enhanced_features": True,
|
|
"repository_integration": True
|
|
}
|
|
|
|
except ValueError as e:
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_batch_prediction_validation_errors_total")
|
|
logger.error("Enhanced batch prediction validation error",
|
|
error=str(e),
|
|
tenant_id=tenant_id)
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail=str(e)
|
|
)
|
|
except Exception as e:
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_batch_predictions_errors_total")
|
|
logger.error("Enhanced batch predictions failed",
|
|
error=str(e),
|
|
tenant_id=tenant_id)
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail="Enhanced batch predictions failed"
|
|
)
|
|
|
|
|
|
@router.get("/tenants/{tenant_id}/predictions/cache")
|
|
@track_execution_time("enhanced_get_prediction_cache_duration_seconds", "forecasting-service")
|
|
async def get_enhanced_prediction_cache(
|
|
tenant_id: str = Path(..., description="Tenant ID"),
|
|
inventory_product_id: Optional[str] = Query(None, description="Filter by inventory product ID"),
|
|
skip: int = Query(0, description="Number of records to skip"),
|
|
limit: int = Query(100, description="Number of records to return"),
|
|
request_obj: Request = None,
|
|
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
|
):
|
|
"""Get cached predictions using enhanced repository pattern"""
|
|
metrics = get_metrics_collector(request_obj)
|
|
|
|
try:
|
|
|
|
# Record metrics
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_get_prediction_cache_total")
|
|
|
|
# Get cached predictions using enhanced service
|
|
cached_predictions = await enhanced_forecasting_service.get_cached_predictions(
|
|
tenant_id=tenant_id,
|
|
inventory_product_id=inventory_product_id,
|
|
skip=skip,
|
|
limit=limit
|
|
)
|
|
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_get_prediction_cache_success_total")
|
|
|
|
return {
|
|
"tenant_id": tenant_id,
|
|
"cached_predictions": cached_predictions,
|
|
"total_returned": len(cached_predictions),
|
|
"filters": {
|
|
"inventory_product_id": inventory_product_id
|
|
},
|
|
"pagination": {
|
|
"skip": skip,
|
|
"limit": limit
|
|
},
|
|
"enhanced_features": True,
|
|
"repository_integration": True
|
|
}
|
|
|
|
except Exception as e:
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_get_prediction_cache_errors_total")
|
|
logger.error("Failed to get enhanced prediction cache",
|
|
tenant_id=tenant_id,
|
|
error=str(e))
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail="Failed to get prediction cache"
|
|
)
|
|
|
|
|
|
@router.delete("/tenants/{tenant_id}/predictions/cache")
|
|
@track_execution_time("enhanced_clear_prediction_cache_duration_seconds", "forecasting-service")
|
|
async def clear_enhanced_prediction_cache(
|
|
tenant_id: str = Path(..., description="Tenant ID"),
|
|
inventory_product_id: Optional[str] = Query(None, description="Clear cache for specific inventory product ID"),
|
|
request_obj: Request = None,
|
|
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
|
):
|
|
"""Clear prediction cache using enhanced repository pattern"""
|
|
metrics = get_metrics_collector(request_obj)
|
|
|
|
try:
|
|
|
|
# Record metrics
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_clear_prediction_cache_total")
|
|
|
|
# Clear cache using enhanced service
|
|
cleared_count = await enhanced_forecasting_service.clear_prediction_cache(
|
|
tenant_id=tenant_id,
|
|
inventory_product_id=inventory_product_id
|
|
)
|
|
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_clear_prediction_cache_success_total")
|
|
metrics.histogram("enhanced_cache_cleared_count", cleared_count)
|
|
|
|
logger.info("Enhanced prediction cache cleared",
|
|
tenant_id=tenant_id,
|
|
inventory_product_id=inventory_product_id,
|
|
cleared_count=cleared_count)
|
|
|
|
return {
|
|
"message": "Prediction cache cleared successfully",
|
|
"tenant_id": tenant_id,
|
|
"inventory_product_id": inventory_product_id,
|
|
"cleared_count": cleared_count,
|
|
"enhanced_features": True,
|
|
"repository_integration": True
|
|
}
|
|
|
|
except Exception as e:
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_clear_prediction_cache_errors_total")
|
|
logger.error("Failed to clear enhanced prediction cache",
|
|
tenant_id=tenant_id,
|
|
error=str(e))
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail="Failed to clear prediction cache"
|
|
)
|
|
|
|
|
|
@router.get("/tenants/{tenant_id}/predictions/performance")
|
|
@track_execution_time("enhanced_get_prediction_performance_duration_seconds", "forecasting-service")
|
|
async def get_enhanced_prediction_performance(
|
|
tenant_id: str = Path(..., description="Tenant ID"),
|
|
model_id: Optional[str] = Query(None, description="Filter by model ID"),
|
|
start_date: Optional[date] = Query(None, description="Start date filter"),
|
|
end_date: Optional[date] = Query(None, description="End date filter"),
|
|
request_obj: Request = None,
|
|
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
|
):
|
|
"""Get prediction performance metrics using enhanced repository pattern"""
|
|
metrics = get_metrics_collector(request_obj)
|
|
|
|
try:
|
|
|
|
# Record metrics
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_get_prediction_performance_total")
|
|
|
|
# Get performance metrics using enhanced service
|
|
performance = await enhanced_forecasting_service.get_prediction_performance(
|
|
tenant_id=tenant_id,
|
|
model_id=model_id,
|
|
start_date=start_date,
|
|
end_date=end_date
|
|
)
|
|
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_get_prediction_performance_success_total")
|
|
|
|
return {
|
|
"tenant_id": tenant_id,
|
|
"performance_metrics": performance,
|
|
"filters": {
|
|
"model_id": model_id,
|
|
"start_date": start_date.isoformat() if start_date else None,
|
|
"end_date": end_date.isoformat() if end_date else None
|
|
},
|
|
"enhanced_features": True,
|
|
"repository_integration": True
|
|
}
|
|
|
|
except Exception as e:
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_get_prediction_performance_errors_total")
|
|
logger.error("Failed to get enhanced prediction performance",
|
|
tenant_id=tenant_id,
|
|
error=str(e))
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail="Failed to get prediction performance"
|
|
)
|
|
|
|
|
|
@router.post("/tenants/{tenant_id}/predictions/validate")
|
|
@track_execution_time("enhanced_validate_prediction_duration_seconds", "forecasting-service")
|
|
async def validate_enhanced_prediction_request(
|
|
validation_request: Dict[str, Any],
|
|
tenant_id: str = Path(..., description="Tenant ID"),
|
|
request_obj: Request = None,
|
|
prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
|
|
):
|
|
"""Validate prediction request without generating prediction"""
|
|
metrics = get_metrics_collector(request_obj)
|
|
|
|
try:
|
|
|
|
# Record metrics
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_validate_prediction_total")
|
|
|
|
# Validate prediction request
|
|
validation_result = await prediction_service.validate_prediction_request(
|
|
validation_request
|
|
)
|
|
|
|
if metrics:
|
|
if validation_result.get("is_valid"):
|
|
metrics.increment_counter("enhanced_validate_prediction_success_total")
|
|
else:
|
|
metrics.increment_counter("enhanced_validate_prediction_failed_total")
|
|
|
|
return {
|
|
"tenant_id": tenant_id,
|
|
"validation_result": validation_result,
|
|
"enhanced_features": True,
|
|
"repository_integration": True
|
|
}
|
|
|
|
except Exception as e:
|
|
if metrics:
|
|
metrics.increment_counter("enhanced_validate_prediction_errors_total")
|
|
logger.error("Failed to validate enhanced prediction request",
|
|
tenant_id=tenant_id,
|
|
error=str(e))
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail="Failed to validate prediction request"
|
|
)
|
|
|
|
|
|
@router.get("/health")
|
|
async def enhanced_predictions_health_check():
|
|
"""Enhanced health check endpoint for predictions"""
|
|
return {
|
|
"status": "healthy",
|
|
"service": "enhanced-predictions-service",
|
|
"version": "2.0.0",
|
|
"features": [
|
|
"repository-pattern",
|
|
"dependency-injection",
|
|
"realtime-predictions",
|
|
"batch-predictions",
|
|
"prediction-caching",
|
|
"performance-metrics",
|
|
"request-validation"
|
|
],
|
|
"timestamp": datetime.now().isoformat()
|
|
} |