bakery-ia/services/forecasting/app/api/predictions.py
"""
Enhanced Predictions API Endpoints with Repository Pattern
Real-time prediction capabilities using repository pattern with dependency injection
"""
import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Request
from typing import List, Dict, Any, Optional
from datetime import date, datetime, timedelta
import uuid
from app.services.prediction_service import PredictionService
from app.services.forecasting_service import EnhancedForecastingService
from app.schemas.forecasts import ForecastRequest
from shared.auth.decorators import (
    get_current_user_dep,
    require_admin_role
)
from shared.database.base import create_database_manager
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings

logger = structlog.get_logger()

router = APIRouter(tags=["enhanced-predictions"])


def get_enhanced_prediction_service():
    """Dependency injection for enhanced PredictionService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    return PredictionService(database_manager)


def get_enhanced_forecasting_service():
    """Dependency injection for EnhancedForecastingService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    return EnhancedForecastingService(database_manager)
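
# Note: these providers build a fresh database manager on every request. In tests
# they can be swapped out through FastAPI's dependency_overrides mechanism, e.g.
# (illustrative sketch; FakePredictionService is a hypothetical test double, not
# part of this module):
#
#     app.dependency_overrides[get_enhanced_prediction_service] = lambda: FakePredictionService()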
@router.post("/tenants/{tenant_id}/predictions/realtime")
@track_execution_time("enhanced_realtime_prediction_duration_seconds", "forecasting-service")
async def generate_enhanced_realtime_prediction(
prediction_request: Dict[str, Any],
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
"""Generate real-time prediction using enhanced repository pattern"""
metrics = get_metrics_collector(request_obj)
try:
logger.info("Generating enhanced real-time prediction",
tenant_id=tenant_id,
inventory_product_id=prediction_request.get("inventory_product_id"))
# Record metrics
if metrics:
metrics.increment_counter("enhanced_realtime_predictions_total")
# Validate required fields
required_fields = ["inventory_product_id", "model_id", "features"]
missing_fields = [field for field in required_fields if field not in prediction_request]
if missing_fields:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Missing required fields: {missing_fields}"
)
# Generate prediction using enhanced service
prediction_result = await prediction_service.predict(
model_id=prediction_request["model_id"],
model_path=prediction_request.get("model_path", ""),
features=prediction_request["features"],
confidence_level=prediction_request.get("confidence_level", 0.8)
)
if metrics:
metrics.increment_counter("enhanced_realtime_predictions_success_total")
logger.info("Enhanced real-time prediction generated successfully",
tenant_id=tenant_id,
prediction_value=prediction_result.get("prediction"))
return {
"tenant_id": tenant_id,
"inventory_product_id": prediction_request["inventory_product_id"],
"model_id": prediction_request["model_id"],
"prediction": prediction_result,
"generated_at": datetime.now().isoformat(),
"enhanced_features": True,
"repository_integration": True
}
except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_prediction_validation_errors_total")
logger.error("Enhanced prediction validation error",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_realtime_predictions_errors_total")
logger.error("Enhanced real-time prediction failed",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Enhanced real-time prediction failed"
)
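
# Example request body for the realtime endpoint (illustrative only; the exact
# feature keys depend on how the referenced model was trained):
#
#     POST /tenants/{tenant_id}/predictions/realtime
#     {
#         "inventory_product_id": "<inventory product id>",
#         "model_id": "<trained model id>",
#         "model_path": "<optional path to the serialized model>",
#         "features": {"day_of_week": 5, "promotion_active": false},
#         "confidence_level": 0.9
#     }
#
# "model_path" and "confidence_level" are optional; confidence_level defaults to 0.8.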
@router.post("/tenants/{tenant_id}/predictions/batch")
@track_execution_time("enhanced_batch_prediction_duration_seconds", "forecasting-service")
async def generate_enhanced_batch_predictions(
batch_request: Dict[str, Any],
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Generate batch predictions using enhanced repository pattern"""
metrics = get_metrics_collector(request_obj)
try:
logger.info("Generating enhanced batch predictions",
tenant_id=tenant_id,
predictions_count=len(batch_request.get("predictions", [])))
# Record metrics
if metrics:
metrics.increment_counter("enhanced_batch_predictions_total")
metrics.histogram("enhanced_batch_predictions_count", len(batch_request.get("predictions", [])))
# Validate batch request
if "predictions" not in batch_request or not batch_request["predictions"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Batch request must contain 'predictions' array"
)
# Generate batch predictions using enhanced service
batch_result = await enhanced_forecasting_service.generate_batch_predictions(
tenant_id=tenant_id,
batch_request=batch_request
)
if metrics:
metrics.increment_counter("enhanced_batch_predictions_success_total")
logger.info("Enhanced batch predictions generated successfully",
tenant_id=tenant_id,
batch_id=batch_result.get("batch_id"),
predictions_generated=len(batch_result.get("predictions", [])))
return {
**batch_result,
"enhanced_features": True,
"repository_integration": True
}
except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_batch_prediction_validation_errors_total")
logger.error("Enhanced batch prediction validation error",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_batch_predictions_errors_total")
logger.error("Enhanced batch predictions failed",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Enhanced batch predictions failed"
)
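
# Example batch request body (illustrative; the shape of each entry is defined by
# EnhancedForecastingService.generate_batch_predictions and is assumed here to
# mirror the realtime payload):
#
#     POST /tenants/{tenant_id}/predictions/batch
#     {
#         "predictions": [
#             {"inventory_product_id": "...", "model_id": "...", "features": {...}},
#             {"inventory_product_id": "...", "model_id": "...", "features": {...}}
#         ]
#     }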
@router.get("/tenants/{tenant_id}/predictions/cache")
@track_execution_time("enhanced_get_prediction_cache_duration_seconds", "forecasting-service")
async def get_enhanced_prediction_cache(
tenant_id: str = Path(..., description="Tenant ID"),
inventory_product_id: Optional[str] = Query(None, description="Filter by inventory product ID"),
skip: int = Query(0, description="Number of records to skip"),
limit: int = Query(100, description="Number of records to return"),
request_obj: Request = None,
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Get cached predictions using enhanced repository pattern"""
metrics = get_metrics_collector(request_obj)
try:
# Record metrics
if metrics:
metrics.increment_counter("enhanced_get_prediction_cache_total")
# Get cached predictions using enhanced service
cached_predictions = await enhanced_forecasting_service.get_cached_predictions(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
skip=skip,
limit=limit
)
if metrics:
metrics.increment_counter("enhanced_get_prediction_cache_success_total")
return {
"tenant_id": tenant_id,
"cached_predictions": cached_predictions,
"total_returned": len(cached_predictions),
"filters": {
"inventory_product_id": inventory_product_id
},
"pagination": {
"skip": skip,
"limit": limit
},
"enhanced_features": True,
"repository_integration": True
}
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_get_prediction_cache_errors_total")
logger.error("Failed to get enhanced prediction cache",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get prediction cache"
)
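
# Example (illustrative): fetch the second page of cached predictions for a single product:
#
#     GET /tenants/{tenant_id}/predictions/cache?inventory_product_id=<id>&skip=100&limit=100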
@router.delete("/tenants/{tenant_id}/predictions/cache")
@track_execution_time("enhanced_clear_prediction_cache_duration_seconds", "forecasting-service")
async def clear_enhanced_prediction_cache(
tenant_id: str = Path(..., description="Tenant ID"),
inventory_product_id: Optional[str] = Query(None, description="Clear cache for specific inventory product ID"),
request_obj: Request = None,
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Clear prediction cache using enhanced repository pattern"""
metrics = get_metrics_collector(request_obj)
try:
# Record metrics
if metrics:
metrics.increment_counter("enhanced_clear_prediction_cache_total")
# Clear cache using enhanced service
cleared_count = await enhanced_forecasting_service.clear_prediction_cache(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id
)
if metrics:
metrics.increment_counter("enhanced_clear_prediction_cache_success_total")
metrics.histogram("enhanced_cache_cleared_count", cleared_count)
logger.info("Enhanced prediction cache cleared",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
cleared_count=cleared_count)
return {
"message": "Prediction cache cleared successfully",
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"cleared_count": cleared_count,
"enhanced_features": True,
"repository_integration": True
}
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_clear_prediction_cache_errors_total")
logger.error("Failed to clear enhanced prediction cache",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to clear prediction cache"
)
@router.get("/tenants/{tenant_id}/predictions/performance")
@track_execution_time("enhanced_get_prediction_performance_duration_seconds", "forecasting-service")
async def get_enhanced_prediction_performance(
tenant_id: str = Path(..., description="Tenant ID"),
model_id: Optional[str] = Query(None, description="Filter by model ID"),
start_date: Optional[date] = Query(None, description="Start date filter"),
end_date: Optional[date] = Query(None, description="End date filter"),
request_obj: Request = None,
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Get prediction performance metrics using enhanced repository pattern"""
metrics = get_metrics_collector(request_obj)
try:
# Record metrics
if metrics:
metrics.increment_counter("enhanced_get_prediction_performance_total")
# Get performance metrics using enhanced service
performance = await enhanced_forecasting_service.get_prediction_performance(
tenant_id=tenant_id,
model_id=model_id,
start_date=start_date,
end_date=end_date
)
if metrics:
metrics.increment_counter("enhanced_get_prediction_performance_success_total")
return {
"tenant_id": tenant_id,
"performance_metrics": performance,
"filters": {
"model_id": model_id,
"start_date": start_date.isoformat() if start_date else None,
"end_date": end_date.isoformat() if end_date else None
},
"enhanced_features": True,
"repository_integration": True
}
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_get_prediction_performance_errors_total")
logger.error("Failed to get enhanced prediction performance",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get prediction performance"
)
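
# Example (illustrative): performance metrics for one model over a date range.
# start_date and end_date are ISO dates parsed by FastAPI into datetime.date:
#
#     GET /tenants/{tenant_id}/predictions/performance?model_id=<id>&start_date=2025-01-01&end_date=2025-01-31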
@router.post("/tenants/{tenant_id}/predictions/validate")
@track_execution_time("enhanced_validate_prediction_duration_seconds", "forecasting-service")
async def validate_enhanced_prediction_request(
validation_request: Dict[str, Any],
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
"""Validate prediction request without generating prediction"""
metrics = get_metrics_collector(request_obj)
try:
# Record metrics
if metrics:
metrics.increment_counter("enhanced_validate_prediction_total")
# Validate prediction request
validation_result = await prediction_service.validate_prediction_request(
validation_request
)
if metrics:
if validation_result.get("is_valid"):
metrics.increment_counter("enhanced_validate_prediction_success_total")
else:
metrics.increment_counter("enhanced_validate_prediction_failed_total")
return {
"tenant_id": tenant_id,
"validation_result": validation_result,
"enhanced_features": True,
"repository_integration": True
}
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_validate_prediction_errors_total")
logger.error("Failed to validate enhanced prediction request",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to validate prediction request"
)
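
# Note: the validation payload is forwarded unchanged to
# PredictionService.validate_prediction_request, so it is expected to mirror the
# realtime request body. The returned validation_result is assumed to carry at
# least an "is_valid" flag, which drives the success/failure counters above.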
@router.get("/health")
async def enhanced_predictions_health_check():
"""Enhanced health check endpoint for predictions"""
return {
"status": "healthy",
"service": "enhanced-predictions-service",
"version": "2.0.0",
"features": [
"repository-pattern",
"dependency-injection",
"realtime-predictions",
"batch-predictions",
"prediction-caching",
"performance-metrics",
"request-validation"
],
"timestamp": datetime.now().isoformat()
}