"""
Enhanced Predictions API Endpoints with Repository Pattern
Real-time prediction capabilities using repository pattern with dependency injection
"""
import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Request
from typing import List, Dict, Any, Optional
from datetime import date, datetime, timedelta
import uuid

from app.services.prediction_service import PredictionService
from app.services.forecasting_service import EnhancedForecastingService
from app.schemas.forecasts import ForecastRequest
from shared.auth.decorators import (
    get_current_user_dep,
    get_current_tenant_id_dep,
    require_admin_role
)
from shared.database.base import create_database_manager
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings

logger = structlog.get_logger()

router = APIRouter(tags=["enhanced-predictions"])


def get_enhanced_prediction_service():
    """Dependency injection for enhanced PredictionService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    return PredictionService(database_manager)


def get_enhanced_forecasting_service():
    """Dependency injection for EnhancedForecastingService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    return EnhancedForecastingService(database_manager)
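
# Test-time sketch (not part of the original module): because these providers are
# plain callables, FastAPI's dependency_overrides can swap them for stubs, e.g.
#
#     app.dependency_overrides[get_enhanced_prediction_service] = lambda: fake_prediction_service
#
# where `app` is the FastAPI instance that includes this router and
# `fake_prediction_service` is a test double; both names are assumptions.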


@router.post("/tenants/{tenant_id}/predictions/realtime")
@track_execution_time("enhanced_realtime_prediction_duration_seconds", "forecasting-service")
async def generate_enhanced_realtime_prediction(
    prediction_request: Dict[str, Any],
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
    """Generate a real-time prediction using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_realtime_prediction_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        logger.info("Generating enhanced real-time prediction",
                    tenant_id=tenant_id,
                    product_name=prediction_request.get("product_name"))

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_realtime_predictions_total")

        # Validate required fields
        required_fields = ["product_name", "model_id", "features"]
        missing_fields = [field for field in required_fields if field not in prediction_request]
        if missing_fields:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Missing required fields: {missing_fields}"
            )

        # Generate prediction using the enhanced service
        prediction_result = await prediction_service.predict(
            model_id=prediction_request["model_id"],
            model_path=prediction_request.get("model_path", ""),
            features=prediction_request["features"],
            confidence_level=prediction_request.get("confidence_level", 0.8)
        )

        if metrics:
            metrics.increment_counter("enhanced_realtime_predictions_success_total")

        logger.info("Enhanced real-time prediction generated successfully",
                    tenant_id=tenant_id,
                    prediction_value=prediction_result.get("prediction"))

        return {
            "tenant_id": tenant_id,
            "product_name": prediction_request["product_name"],
            "model_id": prediction_request["model_id"],
            "prediction": prediction_result,
            "generated_at": datetime.now().isoformat(),
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Propagate the 403/400 responses raised above instead of masking them as 500s
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_prediction_validation_errors_total")
        logger.error("Enhanced prediction validation error",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_realtime_predictions_errors_total")
        logger.error("Enhanced real-time prediction failed",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced real-time prediction failed"
        )
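
# Illustrative request body for the realtime endpoint above (a sketch; only the
# keys checked in the handler come from the code, the values are assumptions):
#
#     {
#         "product_name": "baguette",
#         "model_id": "model-2024-10",
#         "model_path": "/models/model-2024-10.joblib",
#         "features": {"day_of_week": 4, "is_holiday": false},
#         "confidence_level": 0.9
#     }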


@router.post("/tenants/{tenant_id}/predictions/batch")
@track_execution_time("enhanced_batch_prediction_duration_seconds", "forecasting-service")
async def generate_enhanced_batch_predictions(
    batch_request: Dict[str, Any],
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Generate batch predictions using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_batch_prediction_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        logger.info("Generating enhanced batch predictions",
                    tenant_id=tenant_id,
                    predictions_count=len(batch_request.get("predictions", [])))

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_batch_predictions_total")
            metrics.histogram("enhanced_batch_predictions_count", len(batch_request.get("predictions", [])))

        # Validate batch request
        if "predictions" not in batch_request or not batch_request["predictions"]:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Batch request must contain a non-empty 'predictions' array"
            )

        # Generate batch predictions using the enhanced service
        batch_result = await enhanced_forecasting_service.generate_batch_predictions(
            tenant_id=tenant_id,
            batch_request=batch_request
        )

        if metrics:
            metrics.increment_counter("enhanced_batch_predictions_success_total")

        logger.info("Enhanced batch predictions generated successfully",
                    tenant_id=tenant_id,
                    batch_id=batch_result.get("batch_id"),
                    predictions_generated=len(batch_result.get("predictions", [])))

        return {
            **batch_result,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Propagate the 403/400 responses raised above instead of masking them as 500s
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_batch_prediction_validation_errors_total")
        logger.error("Enhanced batch prediction validation error",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_batch_predictions_errors_total")
        logger.error("Enhanced batch predictions failed",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced batch predictions failed"
        )
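
# Hypothetical batch request shape (only the presence of a non-empty "predictions"
# array is validated here; the per-item fields are assumptions about what
# EnhancedForecastingService.generate_batch_predictions expects):
#
#     {
#         "predictions": [
#             {"product_name": "croissant", "model_id": "model-a", "features": {"day_of_week": 4}},
#             {"product_name": "baguette", "model_id": "model-b", "features": {"day_of_week": 4}}
#         ]
#     }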


@router.get("/tenants/{tenant_id}/predictions/cache")
@track_execution_time("enhanced_get_prediction_cache_duration_seconds", "forecasting-service")
async def get_enhanced_prediction_cache(
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: Optional[str] = Query(None, description="Filter by product name"),
    skip: int = Query(0, description="Number of records to skip"),
    limit: int = Query(100, description="Number of records to return"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Get cached predictions using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_get_cache_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_get_prediction_cache_total")

        # Get cached predictions using the enhanced service
        cached_predictions = await enhanced_forecasting_service.get_cached_predictions(
            tenant_id=tenant_id,
            product_name=product_name,
            skip=skip,
            limit=limit
        )

        if metrics:
            metrics.increment_counter("enhanced_get_prediction_cache_success_total")

        return {
            "tenant_id": tenant_id,
            "cached_predictions": cached_predictions,
            "total_returned": len(cached_predictions),
            "filters": {
                "product_name": product_name
            },
            "pagination": {
                "skip": skip,
                "limit": limit
            },
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Propagate the 403 response raised above instead of masking it as a 500
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_get_prediction_cache_errors_total")
        logger.error("Failed to get enhanced prediction cache",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get prediction cache"
        )
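
# Example call (illustrative; any route prefix is added where this router is mounted):
#     GET /tenants/{tenant_id}/predictions/cache?product_name=baguette&skip=0&limit=50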


@router.delete("/tenants/{tenant_id}/predictions/cache")
@track_execution_time("enhanced_clear_prediction_cache_duration_seconds", "forecasting-service")
async def clear_enhanced_prediction_cache(
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: Optional[str] = Query(None, description="Clear cache for specific product"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Clear the prediction cache using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_clear_cache_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_clear_prediction_cache_total")

        # Clear cache using the enhanced service
        cleared_count = await enhanced_forecasting_service.clear_prediction_cache(
            tenant_id=tenant_id,
            product_name=product_name
        )

        if metrics:
            metrics.increment_counter("enhanced_clear_prediction_cache_success_total")
            metrics.histogram("enhanced_cache_cleared_count", cleared_count)

        logger.info("Enhanced prediction cache cleared",
                    tenant_id=tenant_id,
                    product_name=product_name,
                    cleared_count=cleared_count)

        return {
            "message": "Prediction cache cleared successfully",
            "tenant_id": tenant_id,
            "product_name": product_name,
            "cleared_count": cleared_count,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Propagate the 403 response raised above instead of masking it as a 500
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_clear_prediction_cache_errors_total")
        logger.error("Failed to clear enhanced prediction cache",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to clear prediction cache"
        )
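
# Example call (illustrative): DELETE /tenants/{tenant_id}/predictions/cache?product_name=baguette
# Omitting product_name presumably clears the tenant's entire prediction cache,
# since the filter is optional; the exact semantics live in the service layer.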


@router.get("/tenants/{tenant_id}/predictions/performance")
@track_execution_time("enhanced_get_prediction_performance_duration_seconds", "forecasting-service")
async def get_enhanced_prediction_performance(
    tenant_id: str = Path(..., description="Tenant ID"),
    model_id: Optional[str] = Query(None, description="Filter by model ID"),
    start_date: Optional[date] = Query(None, description="Start date filter"),
    end_date: Optional[date] = Query(None, description="End date filter"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Get prediction performance metrics using the enhanced repository pattern"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_get_performance_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_get_prediction_performance_total")

        # Get performance metrics using the enhanced service
        performance = await enhanced_forecasting_service.get_prediction_performance(
            tenant_id=tenant_id,
            model_id=model_id,
            start_date=start_date,
            end_date=end_date
        )

        if metrics:
            metrics.increment_counter("enhanced_get_prediction_performance_success_total")

        return {
            "tenant_id": tenant_id,
            "performance_metrics": performance,
            "filters": {
                "model_id": model_id,
                "start_date": start_date.isoformat() if start_date else None,
                "end_date": end_date.isoformat() if end_date else None
            },
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Propagate the 403 response raised above instead of masking it as a 500
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_get_prediction_performance_errors_total")
        logger.error("Failed to get enhanced prediction performance",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get prediction performance"
        )
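
# Example call (illustrative): GET /tenants/{tenant_id}/predictions/performance?model_id=model-a&start_date=2025-07-01&end_date=2025-07-31
# start_date / end_date are parsed by FastAPI as ISO-8601 dates.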


@router.post("/tenants/{tenant_id}/predictions/validate")
@track_execution_time("enhanced_validate_prediction_duration_seconds", "forecasting-service")
async def validate_enhanced_prediction_request(
    validation_request: Dict[str, Any],
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_tenant: str = Depends(get_current_tenant_id_dep),
    prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
    """Validate a prediction request without generating a prediction"""
    metrics = get_metrics_collector(request_obj)

    try:
        # Enhanced tenant validation
        if tenant_id != current_tenant:
            if metrics:
                metrics.increment_counter("enhanced_validate_prediction_access_denied_total")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_validate_prediction_total")

        # Validate the prediction request
        validation_result = await prediction_service.validate_prediction_request(
            validation_request
        )

        if metrics:
            if validation_result.get("is_valid"):
                metrics.increment_counter("enhanced_validate_prediction_success_total")
            else:
                metrics.increment_counter("enhanced_validate_prediction_failed_total")

        return {
            "tenant_id": tenant_id,
            "validation_result": validation_result,
            "enhanced_features": True,
            "repository_integration": True
        }

    except HTTPException:
        # Propagate the 403 response raised above instead of masking it as a 500
        raise
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_validate_prediction_errors_total")
        logger.error("Failed to validate enhanced prediction request",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to validate prediction request"
        )


@router.get("/health")
async def enhanced_predictions_health_check():
    """Enhanced health check endpoint for predictions"""
    return {
        "status": "healthy",
        "service": "enhanced-predictions-service",
        "version": "2.0.0",
        "features": [
            "repository-pattern",
            "dependency-injection",
            "realtime-predictions",
            "batch-predictions",
            "prediction-caching",
            "performance-metrics",
            "request-validation"
        ],
        "timestamp": datetime.now().isoformat()
    }
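
# Wiring sketch (assumption, not from this file): the router is expected to be
# included by the service's FastAPI application, e.g.
#
#     from fastapi import FastAPI
#     app = FastAPI()
#     app.include_router(router, prefix="/api/v1")
#
# The "/api/v1" prefix is a placeholder; the real mount point is defined elsewhere.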