REFACTOR - Database logic
This commit is contained in:
@@ -1,271 +1,468 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/predictions.py
|
||||
# ================================================================
|
||||
"""
|
||||
Prediction API endpoints - Real-time prediction capabilities
|
||||
Enhanced Predictions API Endpoints with Repository Pattern
|
||||
Real-time prediction capabilities using repository pattern with dependency injection
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Request
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import date, datetime, timedelta
|
||||
from sqlalchemy import select, delete, func
|
||||
import uuid
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.services.prediction_service import PredictionService
|
||||
from app.services.forecasting_service import EnhancedForecastingService
|
||||
from app.schemas.forecasts import ForecastRequest
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
get_current_tenant_id_dep,
|
||||
get_current_user_dep,
|
||||
require_admin_role
|
||||
)
|
||||
from app.services.prediction_service import PredictionService
|
||||
from app.schemas.forecasts import ForecastRequest
|
||||
from shared.database.base import create_database_manager
|
||||
from shared.monitoring.decorators import track_execution_time
|
||||
from shared.monitoring.metrics import get_metrics_collector
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
router = APIRouter(tags=["enhanced-predictions"])
|
||||
|
||||
# Initialize service
|
||||
prediction_service = PredictionService()
|
||||
def get_enhanced_prediction_service():
|
||||
"""Dependency injection for enhanced PredictionService"""
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
|
||||
return PredictionService(database_manager)
|
||||
|
||||
@router.post("/realtime")
|
||||
async def get_realtime_prediction(
|
||||
product_name: str,
|
||||
location: str,
|
||||
forecast_date: date,
|
||||
features: Dict[str, Any],
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep)
|
||||
def get_enhanced_forecasting_service():
|
||||
"""Dependency injection for EnhancedForecastingService"""
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
|
||||
return EnhancedForecastingService(database_manager)
|
||||
|
||||
@router.post("/tenants/{tenant_id}/predictions/realtime")
|
||||
@track_execution_time("enhanced_realtime_prediction_duration_seconds", "forecasting-service")
|
||||
async def generate_enhanced_realtime_prediction(
|
||||
prediction_request: Dict[str, Any],
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
|
||||
):
|
||||
"""Get real-time prediction without storing in database"""
|
||||
"""Generate real-time prediction using enhanced repository pattern"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
|
||||
# Get latest model
|
||||
from app.services.forecasting_service import ForecastingService
|
||||
forecasting_service = ForecastingService()
|
||||
|
||||
model_info = await forecasting_service._get_latest_model(
|
||||
tenant_id, product_name, location
|
||||
)
|
||||
|
||||
if not model_info:
|
||||
# Enhanced tenant validation
|
||||
if tenant_id != current_tenant:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_realtime_prediction_access_denied_total")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No trained model found for {product_name}"
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Generate prediction
|
||||
prediction = await prediction_service.predict(
|
||||
model_id=model_info["model_id"],
|
||||
features=features,
|
||||
confidence_level=0.8
|
||||
)
|
||||
logger.info("Generating enhanced real-time prediction",
|
||||
tenant_id=tenant_id,
|
||||
product_name=prediction_request.get("product_name"))
|
||||
|
||||
return {
|
||||
"product_name": product_name,
|
||||
"location": location,
|
||||
"forecast_date": forecast_date,
|
||||
"predicted_demand": prediction["demand"],
|
||||
"confidence_lower": prediction["lower_bound"],
|
||||
"confidence_upper": prediction["upper_bound"],
|
||||
"model_id": model_info["model_id"],
|
||||
"model_version": model_info["version"],
|
||||
"generated_at": datetime.now(),
|
||||
"features_used": features
|
||||
}
|
||||
# Record metrics
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_realtime_predictions_total")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error getting realtime prediction", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
@router.get("/quick/{product_name}")
|
||||
async def get_quick_prediction(
|
||||
product_name: str,
|
||||
location: str = Query(...),
|
||||
days_ahead: int = Query(1, ge=1, le=7),
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep)
|
||||
):
|
||||
"""Get quick prediction for next few days"""
|
||||
|
||||
try:
|
||||
|
||||
# Generate predictions for the next N days
|
||||
predictions = []
|
||||
|
||||
for day in range(1, days_ahead + 1):
|
||||
forecast_date = date.today() + timedelta(days=day)
|
||||
|
||||
# Prepare basic features
|
||||
features = {
|
||||
"date": forecast_date.isoformat(),
|
||||
"day_of_week": forecast_date.weekday(),
|
||||
"is_weekend": forecast_date.weekday() >= 5,
|
||||
"business_type": "individual"
|
||||
}
|
||||
|
||||
# Get model and predict
|
||||
from app.services.forecasting_service import ForecastingService
|
||||
forecasting_service = ForecastingService()
|
||||
|
||||
model_info = await forecasting_service._get_latest_model(
|
||||
tenant_id, product_name, location
|
||||
# Validate required fields
|
||||
required_fields = ["product_name", "model_id", "features"]
|
||||
missing_fields = [field for field in required_fields if field not in prediction_request]
|
||||
if missing_fields:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Missing required fields: {missing_fields}"
|
||||
)
|
||||
|
||||
if model_info:
|
||||
prediction = await prediction_service.predict(
|
||||
model_id=model_info["model_id"],
|
||||
features=features
|
||||
)
|
||||
|
||||
predictions.append({
|
||||
"date": forecast_date,
|
||||
"predicted_demand": prediction["demand"],
|
||||
"confidence_lower": prediction["lower_bound"],
|
||||
"confidence_upper": prediction["upper_bound"]
|
||||
})
|
||||
|
||||
# Generate prediction using enhanced service
|
||||
prediction_result = await prediction_service.predict(
|
||||
model_id=prediction_request["model_id"],
|
||||
model_path=prediction_request.get("model_path", ""),
|
||||
features=prediction_request["features"],
|
||||
confidence_level=prediction_request.get("confidence_level", 0.8)
|
||||
)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_realtime_predictions_success_total")
|
||||
|
||||
logger.info("Enhanced real-time prediction generated successfully",
|
||||
tenant_id=tenant_id,
|
||||
prediction_value=prediction_result.get("prediction"))
|
||||
|
||||
return {
|
||||
"product_name": product_name,
|
||||
"location": location,
|
||||
"predictions": predictions,
|
||||
"generated_at": datetime.now()
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": prediction_request["product_name"],
|
||||
"model_id": prediction_request["model_id"],
|
||||
"prediction": prediction_result,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting quick prediction", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
@router.post("/tenants/{tenant_id}/predictions/cancel-batches")
|
||||
async def cancel_tenant_prediction_batches(
|
||||
tenant_id: str,
|
||||
current_user = Depends(get_current_user_dep),
|
||||
_admin_check = Depends(require_admin_role),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Cancel all active prediction batches for a tenant (admin only)"""
|
||||
try:
|
||||
tenant_uuid = uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
except ValueError as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_prediction_validation_errors_total")
|
||||
logger.error("Enhanced prediction validation error",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid tenant ID format"
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_realtime_predictions_errors_total")
|
||||
logger.error("Enhanced real-time prediction failed",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Enhanced real-time prediction failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/predictions/batch")
|
||||
@track_execution_time("enhanced_batch_prediction_duration_seconds", "forecasting-service")
|
||||
async def generate_enhanced_batch_predictions(
|
||||
batch_request: Dict[str, Any],
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
||||
):
|
||||
"""Generate batch predictions using enhanced repository pattern"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
from app.models.forecasts import PredictionBatch
|
||||
# Enhanced tenant validation
|
||||
if tenant_id != current_tenant:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_batch_prediction_access_denied_total")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Find active prediction batches
|
||||
active_batches_query = select(PredictionBatch).where(
|
||||
PredictionBatch.tenant_id == tenant_uuid,
|
||||
PredictionBatch.status.in_(["queued", "running", "pending"])
|
||||
logger.info("Generating enhanced batch predictions",
|
||||
tenant_id=tenant_id,
|
||||
predictions_count=len(batch_request.get("predictions", [])))
|
||||
|
||||
# Record metrics
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_batch_predictions_total")
|
||||
metrics.histogram("enhanced_batch_predictions_count", len(batch_request.get("predictions", [])))
|
||||
|
||||
# Validate batch request
|
||||
if "predictions" not in batch_request or not batch_request["predictions"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Batch request must contain 'predictions' array"
|
||||
)
|
||||
|
||||
# Generate batch predictions using enhanced service
|
||||
batch_result = await enhanced_forecasting_service.generate_batch_predictions(
|
||||
tenant_id=tenant_id,
|
||||
batch_request=batch_request
|
||||
)
|
||||
active_batches_result = await db.execute(active_batches_query)
|
||||
active_batches = active_batches_result.scalars().all()
|
||||
|
||||
batches_cancelled = 0
|
||||
cancelled_batch_ids = []
|
||||
errors = []
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_batch_predictions_success_total")
|
||||
|
||||
for batch in active_batches:
|
||||
try:
|
||||
batch.status = "cancelled"
|
||||
batch.updated_at = datetime.utcnow()
|
||||
batch.cancelled_by = current_user.get("user_id")
|
||||
batches_cancelled += 1
|
||||
cancelled_batch_ids.append(str(batch.id))
|
||||
|
||||
logger.info("Cancelled prediction batch",
|
||||
batch_id=str(batch.id),
|
||||
tenant_id=tenant_id)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to cancel batch {batch.id}: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
if batches_cancelled > 0:
|
||||
await db.commit()
|
||||
logger.info("Enhanced batch predictions generated successfully",
|
||||
tenant_id=tenant_id,
|
||||
batch_id=batch_result.get("batch_id"),
|
||||
predictions_generated=len(batch_result.get("predictions", [])))
|
||||
|
||||
return {
|
||||
**batch_result,
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_batch_prediction_validation_errors_total")
|
||||
logger.error("Enhanced batch prediction validation error",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_batch_predictions_errors_total")
|
||||
logger.error("Enhanced batch predictions failed",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Enhanced batch predictions failed"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/predictions/cache")
|
||||
@track_execution_time("enhanced_get_prediction_cache_duration_seconds", "forecasting-service")
|
||||
async def get_enhanced_prediction_cache(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
product_name: Optional[str] = Query(None, description="Filter by product name"),
|
||||
skip: int = Query(0, description="Number of records to skip"),
|
||||
limit: int = Query(100, description="Number of records to return"),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
||||
):
|
||||
"""Get cached predictions using enhanced repository pattern"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
# Enhanced tenant validation
|
||||
if tenant_id != current_tenant:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_get_cache_access_denied_total")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Record metrics
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_get_prediction_cache_total")
|
||||
|
||||
# Get cached predictions using enhanced service
|
||||
cached_predictions = await enhanced_forecasting_service.get_cached_predictions(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_get_prediction_cache_success_total")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"tenant_id": tenant_id,
|
||||
"batches_cancelled": batches_cancelled,
|
||||
"cancelled_batch_ids": cancelled_batch_ids,
|
||||
"errors": errors,
|
||||
"cancelled_at": datetime.utcnow().isoformat()
|
||||
"cached_predictions": cached_predictions,
|
||||
"total_returned": len(cached_predictions),
|
||||
"filters": {
|
||||
"product_name": product_name
|
||||
},
|
||||
"pagination": {
|
||||
"skip": skip,
|
||||
"limit": limit
|
||||
},
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Failed to cancel tenant prediction batches",
|
||||
tenant_id=tenant_id,
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_get_prediction_cache_errors_total")
|
||||
logger.error("Failed to get enhanced prediction cache",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to cancel prediction batches"
|
||||
detail="Failed to get prediction cache"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/tenants/{tenant_id}/predictions/cache")
|
||||
async def clear_tenant_prediction_cache(
|
||||
tenant_id: str,
|
||||
current_user = Depends(get_current_user_dep),
|
||||
_admin_check = Depends(require_admin_role),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
@track_execution_time("enhanced_clear_prediction_cache_duration_seconds", "forecasting-service")
|
||||
async def clear_enhanced_prediction_cache(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
product_name: Optional[str] = Query(None, description="Clear cache for specific product"),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
||||
):
|
||||
"""Clear all prediction cache for a tenant (admin only)"""
|
||||
try:
|
||||
tenant_uuid = uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid tenant ID format"
|
||||
)
|
||||
"""Clear prediction cache using enhanced repository pattern"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
from app.models.forecasts import PredictionCache
|
||||
# Enhanced tenant validation
|
||||
if tenant_id != current_tenant:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_clear_cache_access_denied_total")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Count cache entries before deletion
|
||||
cache_count_query = select(func.count(PredictionCache.id)).where(
|
||||
PredictionCache.tenant_id == tenant_uuid
|
||||
# Record metrics
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_clear_prediction_cache_total")
|
||||
|
||||
# Clear cache using enhanced service
|
||||
cleared_count = await enhanced_forecasting_service.clear_prediction_cache(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name
|
||||
)
|
||||
cache_count_result = await db.execute(cache_count_query)
|
||||
cache_count = cache_count_result.scalar()
|
||||
|
||||
# Delete cache entries
|
||||
cache_delete_query = delete(PredictionCache).where(
|
||||
PredictionCache.tenant_id == tenant_uuid
|
||||
)
|
||||
cache_delete_result = await db.execute(cache_delete_query)
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_clear_prediction_cache_success_total")
|
||||
metrics.histogram("enhanced_cache_cleared_count", cleared_count)
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info("Cleared tenant prediction cache",
|
||||
logger.info("Enhanced prediction cache cleared",
|
||||
tenant_id=tenant_id,
|
||||
cache_cleared=cache_delete_result.rowcount)
|
||||
product_name=product_name,
|
||||
cleared_count=cleared_count)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Prediction cache cleared successfully",
|
||||
"tenant_id": tenant_id,
|
||||
"cache_cleared": cache_delete_result.rowcount,
|
||||
"expected_count": cache_count,
|
||||
"cleared_at": datetime.utcnow().isoformat()
|
||||
"product_name": product_name,
|
||||
"cleared_count": cleared_count,
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Failed to clear tenant prediction cache",
|
||||
tenant_id=tenant_id,
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_clear_prediction_cache_errors_total")
|
||||
logger.error("Failed to clear enhanced prediction cache",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to clear prediction cache"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/predictions/performance")
|
||||
@track_execution_time("enhanced_get_prediction_performance_duration_seconds", "forecasting-service")
|
||||
async def get_enhanced_prediction_performance(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
model_id: Optional[str] = Query(None, description="Filter by model ID"),
|
||||
start_date: Optional[date] = Query(None, description="Start date filter"),
|
||||
end_date: Optional[date] = Query(None, description="End date filter"),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
||||
):
|
||||
"""Get prediction performance metrics using enhanced repository pattern"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
# Enhanced tenant validation
|
||||
if tenant_id != current_tenant:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_get_performance_access_denied_total")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Record metrics
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_get_prediction_performance_total")
|
||||
|
||||
# Get performance metrics using enhanced service
|
||||
performance = await enhanced_forecasting_service.get_prediction_performance(
|
||||
tenant_id=tenant_id,
|
||||
model_id=model_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_get_prediction_performance_success_total")
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"performance_metrics": performance,
|
||||
"filters": {
|
||||
"model_id": model_id,
|
||||
"start_date": start_date.isoformat() if start_date else None,
|
||||
"end_date": end_date.isoformat() if end_date else None
|
||||
},
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_get_prediction_performance_errors_total")
|
||||
logger.error("Failed to get enhanced prediction performance",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get prediction performance"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/predictions/validate")
|
||||
@track_execution_time("enhanced_validate_prediction_duration_seconds", "forecasting-service")
|
||||
async def validate_enhanced_prediction_request(
|
||||
validation_request: Dict[str, Any],
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
|
||||
):
|
||||
"""Validate prediction request without generating prediction"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
# Enhanced tenant validation
|
||||
if tenant_id != current_tenant:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_validate_prediction_access_denied_total")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Record metrics
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_validate_prediction_total")
|
||||
|
||||
# Validate prediction request
|
||||
validation_result = await prediction_service.validate_prediction_request(
|
||||
validation_request
|
||||
)
|
||||
|
||||
if metrics:
|
||||
if validation_result.get("is_valid"):
|
||||
metrics.increment_counter("enhanced_validate_prediction_success_total")
|
||||
else:
|
||||
metrics.increment_counter("enhanced_validate_prediction_failed_total")
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"validation_result": validation_result,
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_validate_prediction_errors_total")
|
||||
logger.error("Failed to validate enhanced prediction request",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to validate prediction request"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def enhanced_predictions_health_check():
|
||||
"""Enhanced health check endpoint for predictions"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "enhanced-predictions-service",
|
||||
"version": "2.0.0",
|
||||
"features": [
|
||||
"repository-pattern",
|
||||
"dependency-injection",
|
||||
"realtime-predictions",
|
||||
"batch-predictions",
|
||||
"prediction-caching",
|
||||
"performance-metrics",
|
||||
"request-validation"
|
||||
],
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
Reference in New Issue
Block a user