REFACTOR - Database logic

This commit is contained in:
Urtzi Alfaro
2025-08-08 09:08:41 +02:00
parent 0154365bfc
commit 488bb3ef93
113 changed files with 22842 additions and 6503 deletions

View File

@@ -1,271 +1,468 @@
# ================================================================
# services/forecasting/app/api/predictions.py
# ================================================================
"""
Prediction API endpoints - Real-time prediction capabilities
Enhanced Predictions API Endpoints with Repository Pattern
Real-time prediction capabilities using repository pattern with dependency injection
"""
import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Request
from typing import List, Dict, Any, Optional
from datetime import date, datetime, timedelta
from sqlalchemy import select, delete, func
import uuid
from app.core.database import get_db
from app.services.prediction_service import PredictionService
from app.services.forecasting_service import EnhancedForecastingService
from app.schemas.forecasts import ForecastRequest
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep,
get_current_user_dep,
require_admin_role
)
from app.services.prediction_service import PredictionService
from app.schemas.forecasts import ForecastRequest
from shared.database.base import create_database_manager
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings
logger = structlog.get_logger()
router = APIRouter()
router = APIRouter(tags=["enhanced-predictions"])
# Initialize service
prediction_service = PredictionService()
def get_enhanced_prediction_service():
"""Dependency injection for enhanced PredictionService"""
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
return PredictionService(database_manager)
@router.post("/realtime")
async def get_realtime_prediction(
product_name: str,
location: str,
forecast_date: date,
features: Dict[str, Any],
tenant_id: str = Depends(get_current_tenant_id_dep)
def get_enhanced_forecasting_service():
"""Dependency injection for EnhancedForecastingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
return EnhancedForecastingService(database_manager)
@router.post("/tenants/{tenant_id}/predictions/realtime")
@track_execution_time("enhanced_realtime_prediction_duration_seconds", "forecasting-service")
async def generate_enhanced_realtime_prediction(
prediction_request: Dict[str, Any],
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
"""Get real-time prediction without storing in database"""
"""Generate real-time prediction using enhanced repository pattern"""
metrics = get_metrics_collector(request_obj)
try:
# Get latest model
from app.services.forecasting_service import ForecastingService
forecasting_service = ForecastingService()
model_info = await forecasting_service._get_latest_model(
tenant_id, product_name, location
)
if not model_info:
# Enhanced tenant validation
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_realtime_prediction_access_denied_total")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No trained model found for {product_name}"
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Generate prediction
prediction = await prediction_service.predict(
model_id=model_info["model_id"],
features=features,
confidence_level=0.8
)
logger.info("Generating enhanced real-time prediction",
tenant_id=tenant_id,
product_name=prediction_request.get("product_name"))
return {
"product_name": product_name,
"location": location,
"forecast_date": forecast_date,
"predicted_demand": prediction["demand"],
"confidence_lower": prediction["lower_bound"],
"confidence_upper": prediction["upper_bound"],
"model_id": model_info["model_id"],
"model_version": model_info["version"],
"generated_at": datetime.now(),
"features_used": features
}
# Record metrics
if metrics:
metrics.increment_counter("enhanced_realtime_predictions_total")
except HTTPException:
raise
except Exception as e:
logger.error("Error getting realtime prediction", error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error"
)
@router.get("/quick/{product_name}")
async def get_quick_prediction(
product_name: str,
location: str = Query(...),
days_ahead: int = Query(1, ge=1, le=7),
tenant_id: str = Depends(get_current_tenant_id_dep)
):
"""Get quick prediction for next few days"""
try:
# Generate predictions for the next N days
predictions = []
for day in range(1, days_ahead + 1):
forecast_date = date.today() + timedelta(days=day)
# Prepare basic features
features = {
"date": forecast_date.isoformat(),
"day_of_week": forecast_date.weekday(),
"is_weekend": forecast_date.weekday() >= 5,
"business_type": "individual"
}
# Get model and predict
from app.services.forecasting_service import ForecastingService
forecasting_service = ForecastingService()
model_info = await forecasting_service._get_latest_model(
tenant_id, product_name, location
# Validate required fields
required_fields = ["product_name", "model_id", "features"]
missing_fields = [field for field in required_fields if field not in prediction_request]
if missing_fields:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Missing required fields: {missing_fields}"
)
if model_info:
prediction = await prediction_service.predict(
model_id=model_info["model_id"],
features=features
)
predictions.append({
"date": forecast_date,
"predicted_demand": prediction["demand"],
"confidence_lower": prediction["lower_bound"],
"confidence_upper": prediction["upper_bound"]
})
# Generate prediction using enhanced service
prediction_result = await prediction_service.predict(
model_id=prediction_request["model_id"],
model_path=prediction_request.get("model_path", ""),
features=prediction_request["features"],
confidence_level=prediction_request.get("confidence_level", 0.8)
)
if metrics:
metrics.increment_counter("enhanced_realtime_predictions_success_total")
logger.info("Enhanced real-time prediction generated successfully",
tenant_id=tenant_id,
prediction_value=prediction_result.get("prediction"))
return {
"product_name": product_name,
"location": location,
"predictions": predictions,
"generated_at": datetime.now()
"tenant_id": tenant_id,
"product_name": prediction_request["product_name"],
"model_id": prediction_request["model_id"],
"prediction": prediction_result,
"generated_at": datetime.now().isoformat(),
"enhanced_features": True,
"repository_integration": True
}
except Exception as e:
logger.error("Error getting quick prediction", error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error"
)
@router.post("/tenants/{tenant_id}/predictions/cancel-batches")
async def cancel_tenant_prediction_batches(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""Cancel all active prediction batches for a tenant (admin only)"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_prediction_validation_errors_total")
logger.error("Enhanced prediction validation error",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_realtime_predictions_errors_total")
logger.error("Enhanced real-time prediction failed",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Enhanced real-time prediction failed"
)
@router.post("/tenants/{tenant_id}/predictions/batch")
@track_execution_time("enhanced_batch_prediction_duration_seconds", "forecasting-service")
async def generate_enhanced_batch_predictions(
batch_request: Dict[str, Any],
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Generate batch predictions using enhanced repository pattern"""
metrics = get_metrics_collector(request_obj)
try:
from app.models.forecasts import PredictionBatch
# Enhanced tenant validation
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_batch_prediction_access_denied_total")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Find active prediction batches
active_batches_query = select(PredictionBatch).where(
PredictionBatch.tenant_id == tenant_uuid,
PredictionBatch.status.in_(["queued", "running", "pending"])
logger.info("Generating enhanced batch predictions",
tenant_id=tenant_id,
predictions_count=len(batch_request.get("predictions", [])))
# Record metrics
if metrics:
metrics.increment_counter("enhanced_batch_predictions_total")
metrics.histogram("enhanced_batch_predictions_count", len(batch_request.get("predictions", [])))
# Validate batch request
if "predictions" not in batch_request or not batch_request["predictions"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Batch request must contain 'predictions' array"
)
# Generate batch predictions using enhanced service
batch_result = await enhanced_forecasting_service.generate_batch_predictions(
tenant_id=tenant_id,
batch_request=batch_request
)
active_batches_result = await db.execute(active_batches_query)
active_batches = active_batches_result.scalars().all()
batches_cancelled = 0
cancelled_batch_ids = []
errors = []
if metrics:
metrics.increment_counter("enhanced_batch_predictions_success_total")
for batch in active_batches:
try:
batch.status = "cancelled"
batch.updated_at = datetime.utcnow()
batch.cancelled_by = current_user.get("user_id")
batches_cancelled += 1
cancelled_batch_ids.append(str(batch.id))
logger.info("Cancelled prediction batch",
batch_id=str(batch.id),
tenant_id=tenant_id)
except Exception as e:
error_msg = f"Failed to cancel batch {batch.id}: {str(e)}"
errors.append(error_msg)
logger.error(error_msg)
if batches_cancelled > 0:
await db.commit()
logger.info("Enhanced batch predictions generated successfully",
tenant_id=tenant_id,
batch_id=batch_result.get("batch_id"),
predictions_generated=len(batch_result.get("predictions", [])))
return {
**batch_result,
"enhanced_features": True,
"repository_integration": True
}
except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_batch_prediction_validation_errors_total")
logger.error("Enhanced batch prediction validation error",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_batch_predictions_errors_total")
logger.error("Enhanced batch predictions failed",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Enhanced batch predictions failed"
)
@router.get("/tenants/{tenant_id}/predictions/cache")
@track_execution_time("enhanced_get_prediction_cache_duration_seconds", "forecasting-service")
async def get_enhanced_prediction_cache(
tenant_id: str = Path(..., description="Tenant ID"),
product_name: Optional[str] = Query(None, description="Filter by product name"),
skip: int = Query(0, description="Number of records to skip"),
limit: int = Query(100, description="Number of records to return"),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Get cached predictions using enhanced repository pattern"""
metrics = get_metrics_collector(request_obj)
try:
# Enhanced tenant validation
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_get_cache_access_denied_total")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Record metrics
if metrics:
metrics.increment_counter("enhanced_get_prediction_cache_total")
# Get cached predictions using enhanced service
cached_predictions = await enhanced_forecasting_service.get_cached_predictions(
tenant_id=tenant_id,
product_name=product_name,
skip=skip,
limit=limit
)
if metrics:
metrics.increment_counter("enhanced_get_prediction_cache_success_total")
return {
"success": True,
"tenant_id": tenant_id,
"batches_cancelled": batches_cancelled,
"cancelled_batch_ids": cancelled_batch_ids,
"errors": errors,
"cancelled_at": datetime.utcnow().isoformat()
"cached_predictions": cached_predictions,
"total_returned": len(cached_predictions),
"filters": {
"product_name": product_name
},
"pagination": {
"skip": skip,
"limit": limit
},
"enhanced_features": True,
"repository_integration": True
}
except Exception as e:
await db.rollback()
logger.error("Failed to cancel tenant prediction batches",
tenant_id=tenant_id,
if metrics:
metrics.increment_counter("enhanced_get_prediction_cache_errors_total")
logger.error("Failed to get enhanced prediction cache",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cancel prediction batches"
detail="Failed to get prediction cache"
)
@router.delete("/tenants/{tenant_id}/predictions/cache")
async def clear_tenant_prediction_cache(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
@track_execution_time("enhanced_clear_prediction_cache_duration_seconds", "forecasting-service")
async def clear_enhanced_prediction_cache(
tenant_id: str = Path(..., description="Tenant ID"),
product_name: Optional[str] = Query(None, description="Clear cache for specific product"),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Clear all prediction cache for a tenant (admin only)"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
"""Clear prediction cache using enhanced repository pattern"""
metrics = get_metrics_collector(request_obj)
try:
from app.models.forecasts import PredictionCache
# Enhanced tenant validation
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_clear_cache_access_denied_total")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Count cache entries before deletion
cache_count_query = select(func.count(PredictionCache.id)).where(
PredictionCache.tenant_id == tenant_uuid
# Record metrics
if metrics:
metrics.increment_counter("enhanced_clear_prediction_cache_total")
# Clear cache using enhanced service
cleared_count = await enhanced_forecasting_service.clear_prediction_cache(
tenant_id=tenant_id,
product_name=product_name
)
cache_count_result = await db.execute(cache_count_query)
cache_count = cache_count_result.scalar()
# Delete cache entries
cache_delete_query = delete(PredictionCache).where(
PredictionCache.tenant_id == tenant_uuid
)
cache_delete_result = await db.execute(cache_delete_query)
if metrics:
metrics.increment_counter("enhanced_clear_prediction_cache_success_total")
metrics.histogram("enhanced_cache_cleared_count", cleared_count)
await db.commit()
logger.info("Cleared tenant prediction cache",
logger.info("Enhanced prediction cache cleared",
tenant_id=tenant_id,
cache_cleared=cache_delete_result.rowcount)
product_name=product_name,
cleared_count=cleared_count)
return {
"success": True,
"message": "Prediction cache cleared successfully",
"tenant_id": tenant_id,
"cache_cleared": cache_delete_result.rowcount,
"expected_count": cache_count,
"cleared_at": datetime.utcnow().isoformat()
"product_name": product_name,
"cleared_count": cleared_count,
"enhanced_features": True,
"repository_integration": True
}
except Exception as e:
await db.rollback()
logger.error("Failed to clear tenant prediction cache",
tenant_id=tenant_id,
if metrics:
metrics.increment_counter("enhanced_clear_prediction_cache_errors_total")
logger.error("Failed to clear enhanced prediction cache",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to clear prediction cache"
)
)
@router.get("/tenants/{tenant_id}/predictions/performance")
@track_execution_time("enhanced_get_prediction_performance_duration_seconds", "forecasting-service")
async def get_enhanced_prediction_performance(
tenant_id: str = Path(..., description="Tenant ID"),
model_id: Optional[str] = Query(None, description="Filter by model ID"),
start_date: Optional[date] = Query(None, description="Start date filter"),
end_date: Optional[date] = Query(None, description="End date filter"),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Get prediction performance metrics using enhanced repository pattern"""
metrics = get_metrics_collector(request_obj)
try:
# Enhanced tenant validation
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_get_performance_access_denied_total")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Record metrics
if metrics:
metrics.increment_counter("enhanced_get_prediction_performance_total")
# Get performance metrics using enhanced service
performance = await enhanced_forecasting_service.get_prediction_performance(
tenant_id=tenant_id,
model_id=model_id,
start_date=start_date,
end_date=end_date
)
if metrics:
metrics.increment_counter("enhanced_get_prediction_performance_success_total")
return {
"tenant_id": tenant_id,
"performance_metrics": performance,
"filters": {
"model_id": model_id,
"start_date": start_date.isoformat() if start_date else None,
"end_date": end_date.isoformat() if end_date else None
},
"enhanced_features": True,
"repository_integration": True
}
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_get_prediction_performance_errors_total")
logger.error("Failed to get enhanced prediction performance",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get prediction performance"
)
@router.post("/tenants/{tenant_id}/predictions/validate")
@track_execution_time("enhanced_validate_prediction_duration_seconds", "forecasting-service")
async def validate_enhanced_prediction_request(
validation_request: Dict[str, Any],
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
"""Validate prediction request without generating prediction"""
metrics = get_metrics_collector(request_obj)
try:
# Enhanced tenant validation
if tenant_id != current_tenant:
if metrics:
metrics.increment_counter("enhanced_validate_prediction_access_denied_total")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Record metrics
if metrics:
metrics.increment_counter("enhanced_validate_prediction_total")
# Validate prediction request
validation_result = await prediction_service.validate_prediction_request(
validation_request
)
if metrics:
if validation_result.get("is_valid"):
metrics.increment_counter("enhanced_validate_prediction_success_total")
else:
metrics.increment_counter("enhanced_validate_prediction_failed_total")
return {
"tenant_id": tenant_id,
"validation_result": validation_result,
"enhanced_features": True,
"repository_integration": True
}
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_validate_prediction_errors_total")
logger.error("Failed to validate enhanced prediction request",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to validate prediction request"
)
@router.get("/health")
async def enhanced_predictions_health_check():
"""Enhanced health check endpoint for predictions"""
return {
"status": "healthy",
"service": "enhanced-predictions-service",
"version": "2.0.0",
"features": [
"repository-pattern",
"dependency-injection",
"realtime-predictions",
"batch-predictions",
"prediction-caching",
"performance-metrics",
"request-validation"
],
"timestamp": datetime.now().isoformat()
}