REFACTOR - Database logic

Urtzi Alfaro
2025-08-08 09:08:41 +02:00
parent 0154365bfc
commit 488bb3ef93
113 changed files with 22842 additions and 6503 deletions

View File

@@ -0,0 +1,20 @@
"""
Forecasting Service Repositories
Repository implementations for the forecasting service
"""
from .base import ForecastingBaseRepository
from .forecast_repository import ForecastRepository
from .prediction_batch_repository import PredictionBatchRepository
from .forecast_alert_repository import ForecastAlertRepository
from .performance_metric_repository import PerformanceMetricRepository
from .prediction_cache_repository import PredictionCacheRepository
__all__ = [
"ForecastingBaseRepository",
"ForecastRepository",
"PredictionBatchRepository",
"ForecastAlertRepository",
"PerformanceMetricRepository",
"PredictionCacheRepository"
]
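
For orientation, a minimal wiring sketch showing how a service layer might obtain these repositories; the `async_session_factory` callable and the `app.repositories` import path are assumptions for illustration, not part of this commit:

    # Illustrative wiring only; async_session_factory is a hypothetical
    # AsyncSession factory provided by the service.
    from contextlib import asynccontextmanager
    from app.repositories import ForecastRepository, ForecastAlertRepository

    @asynccontextmanager
    async def forecasting_repositories(async_session_factory):
        async with async_session_factory() as session:
            yield {
                "forecasts": ForecastRepository(session),
                "alerts": ForecastAlertRepository(session),
            }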

View File

@@ -0,0 +1,253 @@
"""
Base Repository for Forecasting Service
Service-specific repository base class with forecasting utilities
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, date, timedelta
import structlog
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class ForecastingBaseRepository(BaseRepository):
"""Base repository for forecasting service with common forecasting operations"""
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Forecasting data benefits from medium cache time (10 minutes)
super().__init__(model, session, cache_ttl)
async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
"""Get records by tenant ID"""
if hasattr(self.model, 'tenant_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"tenant_id": tenant_id},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def get_by_product_name(
self,
tenant_id: str,
product_name: str,
skip: int = 0,
limit: int = 100
) -> List:
"""Get records by tenant and product"""
if hasattr(self.model, 'product_name'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={
"tenant_id": tenant_id,
"product_name": product_name
},
order_by="created_at",
order_desc=True
)
return await self.get_by_tenant_id(tenant_id, skip, limit)
async def get_by_date_range(
self,
tenant_id: str,
start_date: datetime,
end_date: datetime,
skip: int = 0,
limit: int = 100
) -> List:
"""Get records within date range for a tenant"""
if not hasattr(self.model, 'forecast_date') and not hasattr(self.model, 'created_at'):
logger.warning(f"Model {self.model.__name__} has no date field for filtering")
return []
try:
table_name = self.model.__tablename__
date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"
query_text = f"""
SELECT * FROM {table_name}
WHERE tenant_id = :tenant_id
AND {date_field} >= :start_date
AND {date_field} <= :end_date
ORDER BY {date_field} DESC
LIMIT :limit OFFSET :skip
"""
result = await self.session.execute(text(query_text), {
"tenant_id": tenant_id,
"start_date": start_date,
"end_date": end_date,
"limit": limit,
"skip": skip
})
# Convert rows to model objects
records = []
for row in result.fetchall():
record_dict = dict(row._mapping)
record = self.model(**record_dict)
records.append(record)
return records
except Exception as e:
logger.error("Failed to get records by date range",
model=self.model.__name__,
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Date range query failed: {str(e)}")
async def get_recent_records(
self,
tenant_id: str,
hours: int = 24,
skip: int = 0,
limit: int = 100
) -> List:
"""Get recent records for a tenant"""
cutoff_time = datetime.utcnow() - timedelta(hours=hours)
return await self.get_by_date_range(
tenant_id, cutoff_time, datetime.utcnow(), skip, limit
)
async def cleanup_old_records(self, days_old: int = 90) -> int:
"""Clean up old forecasting records"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
table_name = self.model.__tablename__
# Use created_at or forecast_date for cleanup
date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"
query_text = f"""
DELETE FROM {table_name}
WHERE {date_field} < :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info(f"Cleaned up old {self.model.__name__} records",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old records",
model=self.model.__name__,
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
"""Get statistics for a tenant"""
try:
table_name = self.model.__tablename__
# Get basic counts
total_records = await self.count(filters={"tenant_id": tenant_id})
# Get recent activity (records in last 7 days)
seven_days_ago = datetime.utcnow() - timedelta(days=7)
recent_records = len(await self.get_by_date_range(
tenant_id, seven_days_ago, datetime.utcnow(), limit=1000
))
# Get records by product if applicable
product_stats = {}
if hasattr(self.model, 'product_name'):
product_query = text(f"""
SELECT product_name, COUNT(*) as count
FROM {table_name}
WHERE tenant_id = :tenant_id
GROUP BY product_name
ORDER BY count DESC
""")
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
product_stats = {row.product_name: row.count for row in result.fetchall()}
return {
"total_records": total_records,
"recent_records_7d": recent_records,
"records_by_product": product_stats
}
except Exception as e:
logger.error("Failed to get tenant statistics",
model=self.model.__name__,
tenant_id=tenant_id,
error=str(e))
return {
"total_records": 0,
"recent_records_7d": 0,
"records_by_product": {}
}
def _validate_forecast_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
"""Validate forecasting-related data"""
errors = []
for field in required_fields:
# Treat only None/empty values as missing so legitimate zero values pass validation
if data.get(field) is None or data.get(field) == "":
errors.append(f"Missing required field: {field}")
# Validate tenant_id format if present
if "tenant_id" in data and data["tenant_id"]:
tenant_id = data["tenant_id"]
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
errors.append("Invalid tenant_id format")
# Validate product_name if present
if "product_name" in data and data["product_name"]:
product_name = data["product_name"]
if not isinstance(product_name, str) or len(product_name) < 1:
errors.append("Invalid product_name format")
# Validate dates if present - accept datetime objects, date objects, and date strings
date_fields = ["forecast_date", "created_at", "evaluation_date", "expires_at"]
for field in date_fields:
if field in data and data[field]:
field_value = data[field]
field_type = type(field_value).__name__
if isinstance(field_value, (datetime, date)):
logger.debug(f"Date field {field} is valid {field_type}", field_value=str(field_value))
continue # Already a datetime or date, valid
elif isinstance(field_value, str):
# Try to parse the string date
try:
from dateutil.parser import parse
parse(field_value) # Just validate, don't convert yet
logger.debug(f"Date field {field} is valid string", field_value=field_value)
except (ValueError, TypeError) as e:
logger.error(f"Date parsing failed for {field}", field_value=field_value, error=str(e))
errors.append(f"Invalid {field} format - must be datetime or valid date string")
else:
logger.error(f"Date field {field} has invalid type {field_type}", field_value=str(field_value))
errors.append(f"Invalid {field} format - must be datetime or valid date string")
# Validate numeric fields
numeric_fields = [
"predicted_demand", "confidence_lower", "confidence_upper",
"mae", "mape", "rmse", "accuracy_score"
]
for field in numeric_fields:
if field in data and data[field] is not None:
try:
float(data[field])
except (ValueError, TypeError):
errors.append(f"Invalid {field} format - must be numeric")
return {
"is_valid": len(errors) == 0,
"errors": errors
}
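
As a rough illustration of how a concrete repository builds on this base class, a minimal sketch; `ExampleModel` and the surrounding session handling are placeholders, not part of this commit:

    # Hypothetical subclass; ExampleModel stands in for a real SQLAlchemy model
    # that has tenant_id and created_at columns.
    from datetime import datetime, timedelta

    class ExampleRepository(ForecastingBaseRepository):
        def __init__(self, session):
            super().__init__(ExampleModel, session, cache_ttl=600)

    async def records_last_week(repo: "ExampleRepository", tenant_id: str):
        # get_by_date_range falls back to created_at when the model
        # has no forecast_date column.
        now = datetime.utcnow()
        return await repo.get_by_date_range(tenant_id, now - timedelta(days=7), now)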

View File

@@ -0,0 +1,375 @@
"""
Forecast Alert Repository
Repository for forecast alert operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta
import structlog
from .base import ForecastingBaseRepository
from app.models.forecasts import ForecastAlert
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class ForecastAlertRepository(ForecastingBaseRepository):
"""Repository for forecast alert operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Alerts change frequently, shorter cache time (5 minutes)
super().__init__(ForecastAlert, session, cache_ttl)
async def create_alert(self, alert_data: Dict[str, Any]) -> ForecastAlert:
"""Create a new forecast alert"""
try:
# Validate alert data
validation_result = self._validate_forecast_data(
alert_data,
["tenant_id", "forecast_id", "alert_type", "message"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid alert data: {validation_result['errors']}")
# Set default values
if "severity" not in alert_data:
alert_data["severity"] = "medium"
if "is_active" not in alert_data:
alert_data["is_active"] = True
if "notification_sent" not in alert_data:
alert_data["notification_sent"] = False
alert = await self.create(alert_data)
logger.info("Forecast alert created",
alert_id=alert.id,
tenant_id=alert.tenant_id,
alert_type=alert.alert_type,
severity=alert.severity)
return alert
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create forecast alert",
tenant_id=alert_data.get("tenant_id"),
error=str(e))
raise DatabaseError(f"Failed to create alert: {str(e)}")
async def get_active_alerts(
self,
tenant_id: str,
alert_type: str = None,
severity: str = None
) -> List[ForecastAlert]:
"""Get active alerts for a tenant"""
try:
filters = {
"tenant_id": tenant_id,
"is_active": True
}
if alert_type:
filters["alert_type"] = alert_type
if severity:
filters["severity"] = severity
return await self.get_multi(
filters=filters,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get active alerts",
tenant_id=tenant_id,
error=str(e))
return []
async def acknowledge_alert(
self,
alert_id: str,
acknowledged_by: str = None
) -> Optional[ForecastAlert]:
"""Acknowledge an alert"""
try:
update_data = {
"acknowledged_at": datetime.utcnow()
}
if acknowledged_by:
# Store in message or create a new field if needed
current_alert = await self.get_by_id(alert_id)
if current_alert:
update_data["message"] = f"{current_alert.message} (Acknowledged by: {acknowledged_by})"
updated_alert = await self.update(alert_id, update_data)
logger.info("Alert acknowledged",
alert_id=alert_id,
acknowledged_by=acknowledged_by)
return updated_alert
except Exception as e:
logger.error("Failed to acknowledge alert",
alert_id=alert_id,
error=str(e))
raise DatabaseError(f"Failed to acknowledge alert: {str(e)}")
async def resolve_alert(
self,
alert_id: str,
resolved_by: str = None
) -> Optional[ForecastAlert]:
"""Resolve an alert"""
try:
update_data = {
"resolved_at": datetime.utcnow(),
"is_active": False
}
if resolved_by:
current_alert = await self.get_by_id(alert_id)
if current_alert:
update_data["message"] = f"{current_alert.message} (Resolved by: {resolved_by})"
updated_alert = await self.update(alert_id, update_data)
logger.info("Alert resolved",
alert_id=alert_id,
resolved_by=resolved_by)
return updated_alert
except Exception as e:
logger.error("Failed to resolve alert",
alert_id=alert_id,
error=str(e))
raise DatabaseError(f"Failed to resolve alert: {str(e)}")
async def mark_notification_sent(
self,
alert_id: str,
notification_method: str
) -> Optional[ForecastAlert]:
"""Mark alert notification as sent"""
try:
update_data = {
"notification_sent": True,
"notification_method": notification_method
}
updated_alert = await self.update(alert_id, update_data)
logger.debug("Alert notification marked as sent",
alert_id=alert_id,
method=notification_method)
return updated_alert
except Exception as e:
logger.error("Failed to mark notification as sent",
alert_id=alert_id,
error=str(e))
return None
async def get_unnotified_alerts(self, tenant_id: str = None) -> List[ForecastAlert]:
"""Get alerts that haven't been notified yet"""
try:
filters = {
"is_active": True,
"notification_sent": False
}
if tenant_id:
filters["tenant_id"] = tenant_id
return await self.get_multi(
filters=filters,
order_by="created_at",
order_desc=False # Oldest first for notification
)
except Exception as e:
logger.error("Failed to get unnotified alerts",
tenant_id=tenant_id,
error=str(e))
return []
async def get_alert_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get alert statistics for a tenant"""
try:
# Get counts by type
type_query = text("""
SELECT alert_type, COUNT(*) as count
FROM forecast_alerts
WHERE tenant_id = :tenant_id
GROUP BY alert_type
ORDER BY count DESC
""")
result = await self.session.execute(type_query, {"tenant_id": tenant_id})
alerts_by_type = {row.alert_type: row.count for row in result.fetchall()}
# Get counts by severity
severity_query = text("""
SELECT severity, COUNT(*) as count
FROM forecast_alerts
WHERE tenant_id = :tenant_id
GROUP BY severity
ORDER BY count DESC
""")
severity_result = await self.session.execute(severity_query, {"tenant_id": tenant_id})
alerts_by_severity = {row.severity: row.count for row in severity_result.fetchall()}
# Get status counts
total_alerts = await self.count(filters={"tenant_id": tenant_id})
active_alerts = await self.count(filters={
"tenant_id": tenant_id,
"is_active": True
})
# Acknowledged and resolved counts come from the response metrics query below;
# the generic filters cannot express "acknowledged_at IS NOT NULL"
# Get recent activity (alerts in last 7 days)
seven_days_ago = datetime.utcnow() - timedelta(days=7)
recent_alerts = len(await self.get_by_date_range(
tenant_id, seven_days_ago, datetime.utcnow(), limit=1000
))
# Calculate response metrics
response_query = text("""
SELECT
AVG(EXTRACT(EPOCH FROM (acknowledged_at - created_at))/60) as avg_acknowledgment_time_minutes,
AVG(EXTRACT(EPOCH FROM (resolved_at - created_at))/60) as avg_resolution_time_minutes,
COUNT(CASE WHEN acknowledged_at IS NOT NULL THEN 1 END) as acknowledged_count,
COUNT(CASE WHEN resolved_at IS NOT NULL THEN 1 END) as resolved_count
FROM forecast_alerts
WHERE tenant_id = :tenant_id
""")
response_result = await self.session.execute(response_query, {"tenant_id": tenant_id})
response_row = response_result.fetchone()
return {
"total_alerts": total_alerts,
"active_alerts": active_alerts,
"resolved_alerts": total_alerts - active_alerts,
"alerts_by_type": alerts_by_type,
"alerts_by_severity": alerts_by_severity,
"recent_alerts_7d": recent_alerts,
"response_metrics": {
"avg_acknowledgment_time_minutes": float(response_row.avg_acknowledgment_time_minutes or 0),
"avg_resolution_time_minutes": float(response_row.avg_resolution_time_minutes or 0),
"acknowledgment_rate": round((response_row.acknowledged_count / max(total_alerts, 1)) * 100, 2),
"resolution_rate": round((response_row.resolved_count / max(total_alerts, 1)) * 100, 2)
} if response_row else {
"avg_acknowledgment_time_minutes": 0.0,
"avg_resolution_time_minutes": 0.0,
"acknowledgment_rate": 0.0,
"resolution_rate": 0.0
}
}
except Exception as e:
logger.error("Failed to get alert statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_alerts": 0,
"active_alerts": 0,
"resolved_alerts": 0,
"alerts_by_type": {},
"alerts_by_severity": {},
"recent_alerts_7d": 0,
"response_metrics": {
"avg_acknowledgment_time_minutes": 0.0,
"avg_resolution_time_minutes": 0.0,
"acknowledgment_rate": 0.0,
"resolution_rate": 0.0
}
}
async def cleanup_old_alerts(self, days_old: int = 90) -> int:
"""Clean up old resolved alerts"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
query_text = """
DELETE FROM forecast_alerts
WHERE is_active = false
AND resolved_at IS NOT NULL
AND resolved_at < :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info("Cleaned up old forecast alerts",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old alerts",
error=str(e))
raise DatabaseError(f"Alert cleanup failed: {str(e)}")
async def bulk_resolve_alerts(
self,
tenant_id: str,
alert_type: str = None,
older_than_hours: int = 24
) -> int:
"""Bulk resolve old alerts"""
try:
cutoff_time = datetime.utcnow() - timedelta(hours=older_than_hours)
conditions = [
"tenant_id = :tenant_id",
"is_active = true",
"created_at < :cutoff_time"
]
params = {
"tenant_id": tenant_id,
"cutoff_time": cutoff_time
}
if alert_type:
conditions.append("alert_type = :alert_type")
params["alert_type"] = alert_type
query_text = f"""
UPDATE forecast_alerts
SET is_active = false, resolved_at = :resolved_at
WHERE {' AND '.join(conditions)}
"""
params["resolved_at"] = datetime.utcnow()
result = await self.session.execute(text(query_text), params)
resolved_count = result.rowcount
logger.info("Bulk resolved old alerts",
tenant_id=tenant_id,
alert_type=alert_type,
resolved_count=resolved_count,
older_than_hours=older_than_hours)
return resolved_count
except Exception as e:
logger.error("Failed to bulk resolve alerts",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Bulk resolve failed: {str(e)}")

View File

@@ -0,0 +1,429 @@
"""
Forecast Repository
Repository for forecast operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta, date
import structlog
from .base import ForecastingBaseRepository
from app.models.forecasts import Forecast
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class ForecastRepository(ForecastingBaseRepository):
"""Repository for forecast operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Forecasts are relatively stable, medium cache time (10 minutes)
super().__init__(Forecast, session, cache_ttl)
async def create_forecast(self, forecast_data: Dict[str, Any]) -> Forecast:
"""Create a new forecast with validation"""
try:
# Validate forecast data
validation_result = self._validate_forecast_data(
forecast_data,
["tenant_id", "product_name", "location", "forecast_date",
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid forecast data: {validation_result['errors']}")
# Set default values
if "confidence_level" not in forecast_data:
forecast_data["confidence_level"] = 0.8
if "algorithm" not in forecast_data:
forecast_data["algorithm"] = "prophet"
if "business_type" not in forecast_data:
forecast_data["business_type"] = "individual"
# Create forecast
forecast = await self.create(forecast_data)
logger.info("Forecast created successfully",
forecast_id=forecast.id,
tenant_id=forecast.tenant_id,
product_name=forecast.product_name,
forecast_date=forecast.forecast_date.isoformat())
return forecast
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create forecast",
tenant_id=forecast_data.get("tenant_id"),
product_name=forecast_data.get("product_name"),
error=str(e))
raise DatabaseError(f"Failed to create forecast: {str(e)}")
async def get_forecasts_by_date_range(
self,
tenant_id: str,
start_date: date,
end_date: date,
product_name: str = None,
location: str = None
) -> List[Forecast]:
"""Get forecasts within a date range"""
try:
filters = {"tenant_id": tenant_id}
if product_name:
filters["product_name"] = product_name
if location:
filters["location"] = location
# Convert dates to datetime for comparison
start_datetime = datetime.combine(start_date, datetime.min.time())
end_datetime = datetime.combine(end_date, datetime.max.time())
return await self.get_by_date_range(
tenant_id, start_datetime, end_datetime
)
except Exception as e:
logger.error("Failed to get forecasts by date range",
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
error=str(e))
raise DatabaseError(f"Failed to get forecasts: {str(e)}")
async def get_latest_forecast_for_product(
self,
tenant_id: str,
product_name: str,
location: str = None
) -> Optional[Forecast]:
"""Get the most recent forecast for a product"""
try:
filters = {
"tenant_id": tenant_id,
"product_name": product_name
}
if location:
filters["location"] = location
forecasts = await self.get_multi(
filters=filters,
limit=1,
order_by="forecast_date",
order_desc=True
)
return forecasts[0] if forecasts else None
except Exception as e:
logger.error("Failed to get latest forecast for product",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
raise DatabaseError(f"Failed to get latest forecast: {str(e)}")
async def get_forecasts_for_date(
self,
tenant_id: str,
forecast_date: date,
product_name: str = None
) -> List[Forecast]:
"""Get all forecasts for a specific date"""
try:
# Convert date to datetime range
start_datetime = datetime.combine(forecast_date, datetime.min.time())
end_datetime = datetime.combine(forecast_date, datetime.max.time())
records = await self.get_by_date_range(
tenant_id, start_datetime, end_datetime
)
# Apply the optional product filter, which get_by_date_range does not handle
if product_name:
records = [r for r in records if r.product_name == product_name]
return records
except Exception as e:
logger.error("Failed to get forecasts for date",
tenant_id=tenant_id,
forecast_date=forecast_date,
error=str(e))
raise DatabaseError(f"Failed to get forecasts for date: {str(e)}")
async def get_forecast_accuracy_metrics(
self,
tenant_id: str,
product_name: str = None,
days_back: int = 30
) -> Dict[str, Any]:
"""Get forecast accuracy metrics"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_back)
# Build base query conditions
conditions = ["tenant_id = :tenant_id", "forecast_date >= :cutoff_date"]
params = {
"tenant_id": tenant_id,
"cutoff_date": cutoff_date
}
if product_name:
conditions.append("product_name = :product_name")
params["product_name"] = product_name
query_text = f"""
SELECT
COUNT(*) as total_forecasts,
AVG(predicted_demand) as avg_predicted_demand,
MIN(predicted_demand) as min_predicted_demand,
MAX(predicted_demand) as max_predicted_demand,
AVG(confidence_upper - confidence_lower) as avg_confidence_interval,
AVG(processing_time_ms) as avg_processing_time_ms,
COUNT(DISTINCT product_name) as unique_products,
COUNT(DISTINCT model_id) as unique_models
FROM forecasts
WHERE {' AND '.join(conditions)}
"""
result = await self.session.execute(text(query_text), params)
row = result.fetchone()
if row and row.total_forecasts > 0:
return {
"total_forecasts": int(row.total_forecasts),
"avg_predicted_demand": float(row.avg_predicted_demand or 0),
"min_predicted_demand": float(row.min_predicted_demand or 0),
"max_predicted_demand": float(row.max_predicted_demand or 0),
"avg_confidence_interval": float(row.avg_confidence_interval or 0),
"avg_processing_time_ms": float(row.avg_processing_time_ms or 0),
"unique_products": int(row.unique_products or 0),
"unique_models": int(row.unique_models or 0),
"period_days": days_back
}
return {
"total_forecasts": 0,
"avg_predicted_demand": 0.0,
"min_predicted_demand": 0.0,
"max_predicted_demand": 0.0,
"avg_confidence_interval": 0.0,
"avg_processing_time_ms": 0.0,
"unique_products": 0,
"unique_models": 0,
"period_days": days_back
}
except Exception as e:
logger.error("Failed to get forecast accuracy metrics",
tenant_id=tenant_id,
error=str(e))
return {
"total_forecasts": 0,
"avg_predicted_demand": 0.0,
"min_predicted_demand": 0.0,
"max_predicted_demand": 0.0,
"avg_confidence_interval": 0.0,
"avg_processing_time_ms": 0.0,
"unique_products": 0,
"unique_models": 0,
"period_days": days_back
}
async def get_demand_trends(
self,
tenant_id: str,
product_name: str,
days_back: int = 30
) -> Dict[str, Any]:
"""Get demand trends for a product"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_back)
query_text = """
SELECT
DATE(forecast_date) as date,
AVG(predicted_demand) as avg_demand,
MIN(predicted_demand) as min_demand,
MAX(predicted_demand) as max_demand,
COUNT(*) as forecast_count
FROM forecasts
WHERE tenant_id = :tenant_id
AND product_name = :product_name
AND forecast_date >= :cutoff_date
GROUP BY DATE(forecast_date)
ORDER BY date DESC
"""
result = await self.session.execute(text(query_text), {
"tenant_id": tenant_id,
"product_name": product_name,
"cutoff_date": cutoff_date
})
trends = []
for row in result.fetchall():
trends.append({
"date": row.date.isoformat() if row.date else None,
"avg_demand": float(row.avg_demand),
"min_demand": float(row.min_demand),
"max_demand": float(row.max_demand),
"forecast_count": int(row.forecast_count)
})
# Calculate overall trend direction
if len(trends) >= 2:
recent_avg = sum(t["avg_demand"] for t in trends[:7]) / min(7, len(trends))
older_avg = sum(t["avg_demand"] for t in trends[-7:]) / len(trends[-7:])
if recent_avg > older_avg:
trend_direction = "increasing"
elif recent_avg < older_avg:
trend_direction = "decreasing"
else:
trend_direction = "stable"
else:
trend_direction = "stable"
return {
"product_name": product_name,
"period_days": days_back,
"trends": trends,
"trend_direction": trend_direction,
"total_data_points": len(trends)
}
except Exception as e:
logger.error("Failed to get demand trends",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
return {
"product_name": product_name,
"period_days": days_back,
"trends": [],
"trend_direction": "unknown",
"total_data_points": 0
}
async def get_model_usage_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get statistics about model usage"""
try:
# Get model usage counts
model_query = text("""
SELECT
model_id,
algorithm,
COUNT(*) as usage_count,
AVG(predicted_demand) as avg_prediction,
MAX(forecast_date) as last_used,
COUNT(DISTINCT product_name) as products_covered
FROM forecasts
WHERE tenant_id = :tenant_id
GROUP BY model_id, algorithm
ORDER BY usage_count DESC
""")
result = await self.session.execute(model_query, {"tenant_id": tenant_id})
model_stats = []
for row in result.fetchall():
model_stats.append({
"model_id": row.model_id,
"algorithm": row.algorithm,
"usage_count": int(row.usage_count),
"avg_prediction": float(row.avg_prediction),
"last_used": row.last_used.isoformat() if row.last_used else None,
"products_covered": int(row.products_covered)
})
# Get algorithm distribution
algorithm_query = text("""
SELECT algorithm, COUNT(*) as count
FROM forecasts
WHERE tenant_id = :tenant_id
GROUP BY algorithm
""")
algorithm_result = await self.session.execute(algorithm_query, {"tenant_id": tenant_id})
algorithm_distribution = {row.algorithm: row.count for row in algorithm_result.fetchall()}
return {
"model_statistics": model_stats,
"algorithm_distribution": algorithm_distribution,
"total_unique_models": len(model_stats)
}
except Exception as e:
logger.error("Failed to get model usage statistics",
tenant_id=tenant_id,
error=str(e))
return {
"model_statistics": [],
"algorithm_distribution": {},
"total_unique_models": 0
}
async def cleanup_old_forecasts(self, days_old: int = 90) -> int:
"""Clean up old forecasts"""
return await self.cleanup_old_records(days_old=days_old)
async def get_forecast_summary(self, tenant_id: str) -> Dict[str, Any]:
"""Get comprehensive forecast summary for a tenant"""
try:
# Get basic statistics
basic_stats = await self.get_statistics_by_tenant(tenant_id)
# Get accuracy metrics
accuracy_metrics = await self.get_forecast_accuracy_metrics(tenant_id)
# Get model usage
model_usage = await self.get_model_usage_statistics(tenant_id)
# Get recent activity
recent_forecasts = await self.get_recent_records(tenant_id, hours=24)
return {
"tenant_id": tenant_id,
"basic_statistics": basic_stats,
"accuracy_metrics": accuracy_metrics,
"model_usage": model_usage,
"recent_activity": {
"forecasts_last_24h": len(recent_forecasts),
"latest_forecast": recent_forecasts[0].forecast_date.isoformat() if recent_forecasts else None
}
}
except Exception as e:
logger.error("Failed to get forecast summary",
tenant_id=tenant_id,
error=str(e))
return {"error": f"Failed to get forecast summary: {str(e)}"}
async def bulk_create_forecasts(self, forecasts_data: List[Dict[str, Any]]) -> List[Forecast]:
"""Bulk create multiple forecasts"""
try:
created_forecasts = []
for forecast_data in forecasts_data:
# Validate each forecast
validation_result = self._validate_forecast_data(
forecast_data,
["tenant_id", "product_name", "location", "forecast_date",
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
)
if not validation_result["is_valid"]:
logger.warning("Skipping invalid forecast data",
errors=validation_result["errors"],
data=forecast_data)
continue
forecast = await self.create(forecast_data)
created_forecasts.append(forecast)
logger.info("Bulk created forecasts",
requested_count=len(forecasts_data),
created_count=len(created_forecasts))
return created_forecasts
except Exception as e:
logger.error("Failed to bulk create forecasts",
requested_count=len(forecasts_data),
error=str(e))
raise DatabaseError(f"Bulk forecast creation failed: {str(e)}")

View File

@@ -0,0 +1,170 @@
"""
Performance Metric Repository
Repository for model performance metrics in the forecasting service
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta
import structlog
from .base import ForecastingBaseRepository
from app.models.predictions import ModelPerformanceMetric
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PerformanceMetricRepository(ForecastingBaseRepository):
"""Repository for model performance metrics operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
# Performance metrics are stable, longer cache time (15 minutes)
super().__init__(ModelPerformanceMetric, session, cache_ttl)
async def create_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
"""Create a new performance metric"""
try:
# Validate metric data
validation_result = self._validate_forecast_data(
metric_data,
["model_id", "tenant_id", "product_name", "evaluation_date"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid metric data: {validation_result['errors']}")
metric = await self.create(metric_data)
logger.info("Performance metric created",
metric_id=metric.id,
model_id=metric.model_id,
tenant_id=metric.tenant_id,
product_name=metric.product_name)
return metric
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create performance metric",
model_id=metric_data.get("model_id"),
error=str(e))
raise DatabaseError(f"Failed to create metric: {str(e)}")
async def get_metrics_by_model(
self,
model_id: str,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get all metrics for a model"""
try:
return await self.get_multi(
filters={"model_id": model_id},
skip=skip,
limit=limit,
order_by="evaluation_date",
order_desc=True
)
except Exception as e:
logger.error("Failed to get metrics by model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get metrics: {str(e)}")
async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
"""Get the latest performance metric for a model"""
try:
metrics = await self.get_multi(
filters={"model_id": model_id},
limit=1,
order_by="evaluation_date",
order_desc=True
)
return metrics[0] if metrics else None
except Exception as e:
logger.error("Failed to get latest metric for model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get latest metric: {str(e)}")
async def get_performance_trends(
self,
tenant_id: str,
product_name: str = None,
days: int = 30
) -> Dict[str, Any]:
"""Get performance trends over time"""
try:
start_date = datetime.utcnow() - timedelta(days=days)
conditions = [
"tenant_id = :tenant_id",
"evaluation_date >= :start_date"
]
params = {
"tenant_id": tenant_id,
"start_date": start_date
}
if product_name:
conditions.append("product_name = :product_name")
params["product_name"] = product_name
query_text = f"""
SELECT
DATE(evaluation_date) as date,
product_name,
AVG(mae) as avg_mae,
AVG(mape) as avg_mape,
AVG(rmse) as avg_rmse,
AVG(accuracy_score) as avg_accuracy,
COUNT(*) as measurement_count
FROM model_performance_metrics
WHERE {' AND '.join(conditions)}
GROUP BY DATE(evaluation_date), product_name
ORDER BY date DESC, product_name
"""
result = await self.session.execute(text(query_text), params)
trends = []
for row in result.fetchall():
trends.append({
"date": row.date.isoformat() if row.date else None,
"product_name": row.product_name,
"metrics": {
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
"avg_rmse": float(row.avg_rmse) if row.avg_rmse else None,
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
},
"measurement_count": int(row.measurement_count)
})
return {
"tenant_id": tenant_id,
"product_name": product_name,
"period_days": days,
"trends": trends,
"total_measurements": len(trends)
}
except Exception as e:
logger.error("Failed to get performance trends",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
return {
"tenant_id": tenant_id,
"product_name": product_name,
"period_days": days,
"trends": [],
"total_measurements": 0
}
async def cleanup_old_metrics(self, days_old: int = 180) -> int:
"""Clean up old performance metrics"""
return await self.cleanup_old_records(days_old=days_old)
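
For illustration, a minimal sketch of reading back the trends this repository computes; `session` is an assumed AsyncSession:

    async def weekly_model_report(session, tenant_id: str):
        repo = PerformanceMetricRepository(session)
        result = await repo.get_performance_trends(tenant_id, days=7)
        # Each entry carries per-day, per-product averages of MAE, MAPE, RMSE and accuracy.
        return result["trends"]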

View File

@@ -0,0 +1,388 @@
"""
Prediction Batch Repository
Repository for prediction batch operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta
import structlog
from .base import ForecastingBaseRepository
from app.models.forecasts import PredictionBatch
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PredictionBatchRepository(ForecastingBaseRepository):
"""Repository for prediction batch operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Batch operations change frequently, shorter cache time (5 minutes)
super().__init__(PredictionBatch, session, cache_ttl)
async def create_batch(self, batch_data: Dict[str, Any]) -> PredictionBatch:
"""Create a new prediction batch"""
try:
# Validate batch data
validation_result = self._validate_forecast_data(
batch_data,
["tenant_id", "batch_name"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid batch data: {validation_result['errors']}")
# Set default values
if "status" not in batch_data:
batch_data["status"] = "pending"
if "forecast_days" not in batch_data:
batch_data["forecast_days"] = 7
if "business_type" not in batch_data:
batch_data["business_type"] = "individual"
batch = await self.create(batch_data)
logger.info("Prediction batch created",
batch_id=batch.id,
tenant_id=batch.tenant_id,
batch_name=batch.batch_name)
return batch
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create prediction batch",
tenant_id=batch_data.get("tenant_id"),
error=str(e))
raise DatabaseError(f"Failed to create batch: {str(e)}")
async def update_batch_progress(
self,
batch_id: str,
completed_products: int = None,
failed_products: int = None,
total_products: int = None,
status: str = None
) -> Optional[PredictionBatch]:
"""Update batch progress"""
try:
update_data = {}
if completed_products is not None:
update_data["completed_products"] = completed_products
if failed_products is not None:
update_data["failed_products"] = failed_products
if total_products is not None:
update_data["total_products"] = total_products
if status:
update_data["status"] = status
if status in ["completed", "failed"]:
update_data["completed_at"] = datetime.utcnow()
if not update_data:
return await self.get_by_id(batch_id)
updated_batch = await self.update(batch_id, update_data)
logger.debug("Batch progress updated",
batch_id=batch_id,
status=status,
completed=completed_products)
return updated_batch
except Exception as e:
logger.error("Failed to update batch progress",
batch_id=batch_id,
error=str(e))
raise DatabaseError(f"Failed to update batch: {str(e)}")
async def complete_batch(
self,
batch_id: str,
processing_time_ms: int = None
) -> Optional[PredictionBatch]:
"""Mark batch as completed"""
try:
update_data = {
"status": "completed",
"completed_at": datetime.utcnow()
}
if processing_time_ms is not None:
update_data["processing_time_ms"] = processing_time_ms
updated_batch = await self.update(batch_id, update_data)
logger.info("Batch completed",
batch_id=batch_id,
processing_time_ms=processing_time_ms)
return updated_batch
except Exception as e:
logger.error("Failed to complete batch",
batch_id=batch_id,
error=str(e))
raise DatabaseError(f"Failed to complete batch: {str(e)}")
async def fail_batch(
self,
batch_id: str,
error_message: str,
processing_time_ms: int = None
) -> Optional[PredictionBatch]:
"""Mark batch as failed"""
try:
update_data = {
"status": "failed",
"completed_at": datetime.utcnow(),
"error_message": error_message
}
if processing_time_ms is not None:
update_data["processing_time_ms"] = processing_time_ms
updated_batch = await self.update(batch_id, update_data)
logger.error("Batch failed",
batch_id=batch_id,
error_message=error_message)
return updated_batch
except Exception as e:
logger.error("Failed to mark batch as failed",
batch_id=batch_id,
error=str(e))
raise DatabaseError(f"Failed to fail batch: {str(e)}")
async def cancel_batch(
self,
batch_id: str,
cancelled_by: str = None
) -> Optional[PredictionBatch]:
"""Cancel a batch"""
try:
batch = await self.get_by_id(batch_id)
if not batch:
return None
if batch.status in ["completed", "failed"]:
logger.warning("Cannot cancel finished batch",
batch_id=batch_id,
status=batch.status)
return batch
update_data = {
"status": "cancelled",
"completed_at": datetime.utcnow(),
"cancelled_by": cancelled_by,
"error_message": f"Cancelled by {cancelled_by}" if cancelled_by else "Cancelled"
}
updated_batch = await self.update(batch_id, update_data)
logger.info("Batch cancelled",
batch_id=batch_id,
cancelled_by=cancelled_by)
return updated_batch
except Exception as e:
logger.error("Failed to cancel batch",
batch_id=batch_id,
error=str(e))
raise DatabaseError(f"Failed to cancel batch: {str(e)}")
async def get_active_batches(self, tenant_id: str = None) -> List[PredictionBatch]:
"""Get currently active (pending/processing) batches"""
try:
filters = {"status": "processing"}
if tenant_id:
# Need to handle multiple status values with raw query
query_text = """
SELECT * FROM prediction_batches
WHERE status IN ('pending', 'processing')
AND tenant_id = :tenant_id
ORDER BY requested_at DESC
"""
params = {"tenant_id": tenant_id}
else:
query_text = """
SELECT * FROM prediction_batches
WHERE status IN ('pending', 'processing')
ORDER BY requested_at DESC
"""
params = {}
result = await self.session.execute(text(query_text), params)
batches = []
for row in result.fetchall():
record_dict = dict(row._mapping)
batch = self.model(**record_dict)
batches.append(batch)
return batches
except Exception as e:
logger.error("Failed to get active batches",
tenant_id=tenant_id,
error=str(e))
return []
async def get_batch_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get batch processing statistics"""
try:
base_filter = "WHERE 1=1"
params = {}
if tenant_id:
base_filter = "WHERE tenant_id = :tenant_id"
params["tenant_id"] = tenant_id
# Get counts by status
status_query = text(f"""
SELECT
status,
COUNT(*) as count,
AVG(CASE WHEN processing_time_ms IS NOT NULL THEN processing_time_ms END) as avg_processing_time_ms
FROM prediction_batches
{base_filter}
GROUP BY status
""")
result = await self.session.execute(status_query, params)
status_stats = {}
total_batches = 0
avg_processing_times = {}
for row in result.fetchall():
status_stats[row.status] = row.count
total_batches += row.count
if row.avg_processing_time_ms:
avg_processing_times[row.status] = float(row.avg_processing_time_ms)
# Get recent activity (batches in last 7 days)
seven_days_ago = datetime.utcnow() - timedelta(days=7)
recent_query = text(f"""
SELECT COUNT(*) as count
FROM prediction_batches
{base_filter}
AND requested_at >= :seven_days_ago
""")
recent_result = await self.session.execute(recent_query, {
**params,
"seven_days_ago": seven_days_ago
})
recent_batches = recent_result.scalar() or 0
# Calculate success rate
completed = status_stats.get("completed", 0)
failed = status_stats.get("failed", 0)
cancelled = status_stats.get("cancelled", 0)
finished_batches = completed + failed + cancelled
success_rate = (completed / finished_batches * 100) if finished_batches > 0 else 0
return {
"total_batches": total_batches,
"batches_by_status": status_stats,
"success_rate": round(success_rate, 2),
"recent_batches_7d": recent_batches,
"avg_processing_times_ms": avg_processing_times
}
except Exception as e:
logger.error("Failed to get batch statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_batches": 0,
"batches_by_status": {},
"success_rate": 0.0,
"recent_batches_7d": 0,
"avg_processing_times_ms": {}
}
async def cleanup_old_batches(self, days_old: int = 30) -> int:
"""Clean up old completed/failed batches"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
query_text = """
DELETE FROM prediction_batches
WHERE status IN ('completed', 'failed', 'cancelled')
AND completed_at < :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info("Cleaned up old prediction batches",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old batches",
error=str(e))
raise DatabaseError(f"Batch cleanup failed: {str(e)}")
async def get_batch_details(self, batch_id: str) -> Dict[str, Any]:
"""Get detailed batch information"""
try:
batch = await self.get_by_id(batch_id)
if not batch:
return {"error": "Batch not found"}
# Calculate completion percentage
completion_percentage = 0
if batch.total_products > 0:
completion_percentage = (batch.completed_products / batch.total_products) * 100
# Calculate elapsed time
elapsed_time_ms = 0
if batch.completed_at:
elapsed_time_ms = int((batch.completed_at - batch.requested_at).total_seconds() * 1000)
elif batch.status in ["pending", "processing"]:
elapsed_time_ms = int((datetime.utcnow() - batch.requested_at).total_seconds() * 1000)
return {
"batch_id": str(batch.id),
"tenant_id": str(batch.tenant_id),
"batch_name": batch.batch_name,
"status": batch.status,
"progress": {
"total_products": batch.total_products,
"completed_products": batch.completed_products,
"failed_products": batch.failed_products,
"completion_percentage": round(completion_percentage, 2)
},
"timing": {
"requested_at": batch.requested_at.isoformat(),
"completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
"elapsed_time_ms": elapsed_time_ms,
"processing_time_ms": batch.processing_time_ms
},
"configuration": {
"forecast_days": batch.forecast_days,
"business_type": batch.business_type
},
"error_message": batch.error_message,
"cancelled_by": batch.cancelled_by
}
except Exception as e:
logger.error("Failed to get batch details",
batch_id=batch_id,
error=str(e))
return {"error": f"Failed to get batch details: {str(e)}"}

View File

@@ -0,0 +1,302 @@
"""
Prediction Cache Repository
Repository for prediction cache operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta
import structlog
import hashlib
from .base import ForecastingBaseRepository
from app.models.predictions import PredictionCache
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PredictionCacheRepository(ForecastingBaseRepository):
"""Repository for prediction cache operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 60):
# Cache entries change very frequently, short cache time (1 minute)
super().__init__(PredictionCache, session, cache_ttl)
def _generate_cache_key(
self,
tenant_id: str,
product_name: str,
location: str,
forecast_date: datetime
) -> str:
"""Generate cache key for prediction"""
key_data = f"{tenant_id}:{product_name}:{location}:{forecast_date.isoformat()}"
return hashlib.md5(key_data.encode()).hexdigest()
async def cache_prediction(
self,
tenant_id: str,
product_name: str,
location: str,
forecast_date: datetime,
predicted_demand: float,
confidence_lower: float,
confidence_upper: float,
model_id: str,
expires_in_hours: int = 24
) -> PredictionCache:
"""Cache a prediction result"""
try:
cache_key = self._generate_cache_key(tenant_id, product_name, location, forecast_date)
expires_at = datetime.utcnow() + timedelta(hours=expires_in_hours)
cache_data = {
"cache_key": cache_key,
"tenant_id": tenant_id,
"product_name": product_name,
"location": location,
"forecast_date": forecast_date,
"predicted_demand": predicted_demand,
"confidence_lower": confidence_lower,
"confidence_upper": confidence_upper,
"model_id": model_id,
"expires_at": expires_at,
"hit_count": 0
}
# Try to update existing cache entry first
existing_cache = await self.get_by_field("cache_key", cache_key)
if existing_cache:
cache_entry = await self.update(existing_cache.id, cache_data)
logger.debug("Updated cache entry", cache_key=cache_key)
else:
cache_entry = await self.create(cache_data)
logger.debug("Created cache entry", cache_key=cache_key)
return cache_entry
except Exception as e:
logger.error("Failed to cache prediction",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
raise DatabaseError(f"Failed to cache prediction: {str(e)}")
async def get_cached_prediction(
self,
tenant_id: str,
product_name: str,
location: str,
forecast_date: datetime
) -> Optional[PredictionCache]:
"""Get cached prediction if valid"""
try:
cache_key = self._generate_cache_key(tenant_id, product_name, location, forecast_date)
cache_entry = await self.get_by_field("cache_key", cache_key)
if not cache_entry:
logger.debug("Cache miss", cache_key=cache_key)
return None
# Check if cache entry has expired
if cache_entry.expires_at < datetime.utcnow():
logger.debug("Cache expired", cache_key=cache_key)
await self.delete(cache_entry.id)
return None
# Increment hit count
await self.update(cache_entry.id, {"hit_count": cache_entry.hit_count + 1})
logger.debug("Cache hit",
cache_key=cache_key,
hit_count=cache_entry.hit_count + 1)
return cache_entry
except Exception as e:
logger.error("Failed to get cached prediction",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
return None
async def invalidate_cache(
self,
tenant_id: str,
product_name: str = None,
location: str = None
) -> int:
"""Invalidate cache entries"""
try:
conditions = ["tenant_id = :tenant_id"]
params = {"tenant_id": tenant_id}
if product_name:
conditions.append("product_name = :product_name")
params["product_name"] = product_name
if location:
conditions.append("location = :location")
params["location"] = location
query_text = f"""
DELETE FROM prediction_cache
WHERE {' AND '.join(conditions)}
"""
result = await self.session.execute(text(query_text), params)
invalidated_count = result.rowcount
logger.info("Cache invalidated",
tenant_id=tenant_id,
product_name=product_name,
location=location,
invalidated_count=invalidated_count)
return invalidated_count
except Exception as e:
logger.error("Failed to invalidate cache",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Cache invalidation failed: {str(e)}")
async def cleanup_expired_cache(self) -> int:
"""Clean up expired cache entries"""
try:
query_text = """
DELETE FROM prediction_cache
WHERE expires_at < :now
"""
result = await self.session.execute(text(query_text), {"now": datetime.utcnow()})
deleted_count = result.rowcount
logger.info("Cleaned up expired cache entries",
deleted_count=deleted_count)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup expired cache",
error=str(e))
raise DatabaseError(f"Cache cleanup failed: {str(e)}")
async def get_cache_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get cache performance statistics"""
try:
base_filter = "WHERE 1=1"
params = {}
if tenant_id:
base_filter = "WHERE tenant_id = :tenant_id"
params["tenant_id"] = tenant_id
# Get cache statistics
stats_query = text(f"""
SELECT
COUNT(*) as total_entries,
COUNT(CASE WHEN expires_at > :now THEN 1 END) as active_entries,
COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_entries,
SUM(hit_count) as total_hits,
AVG(hit_count) as avg_hits_per_entry,
MAX(hit_count) as max_hits,
COUNT(DISTINCT product_name) as unique_products
FROM prediction_cache
{base_filter}
""")
params["now"] = datetime.utcnow()
result = await self.session.execute(stats_query, params)
row = result.fetchone()
if row:
return {
"total_entries": int(row.total_entries or 0),
"active_entries": int(row.active_entries or 0),
"expired_entries": int(row.expired_entries or 0),
"total_hits": int(row.total_hits or 0),
"avg_hits_per_entry": float(row.avg_hits_per_entry or 0),
"max_hits": int(row.max_hits or 0),
"unique_products": int(row.unique_products or 0),
"cache_hit_ratio": round((row.total_hits / max(row.total_entries, 1)), 2)
}
return {
"total_entries": 0,
"active_entries": 0,
"expired_entries": 0,
"total_hits": 0,
"avg_hits_per_entry": 0.0,
"max_hits": 0,
"unique_products": 0,
"cache_hit_ratio": 0.0
}
except Exception as e:
logger.error("Failed to get cache statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_entries": 0,
"active_entries": 0,
"expired_entries": 0,
"total_hits": 0,
"avg_hits_per_entry": 0.0,
"max_hits": 0,
"unique_products": 0,
"cache_hit_ratio": 0.0
}
async def get_most_accessed_predictions(
self,
tenant_id: str = None,
limit: int = 10
) -> List[Dict[str, Any]]:
"""Get most frequently accessed cached predictions"""
try:
base_filter = "WHERE hit_count > 0"
params = {"limit": limit}
if tenant_id:
base_filter = "WHERE tenant_id = :tenant_id AND hit_count > 0"
params["tenant_id"] = tenant_id
query_text = f"""
SELECT
product_name,
location,
hit_count,
predicted_demand,
created_at,
expires_at
FROM prediction_cache
{base_filter}
ORDER BY hit_count DESC
LIMIT :limit
"""
result = await self.session.execute(text(query_text), params)
popular_predictions = []
for row in result.fetchall():
popular_predictions.append({
"product_name": row.product_name,
"location": row.location,
"hit_count": int(row.hit_count),
"predicted_demand": float(row.predicted_demand),
"created_at": row.created_at.isoformat() if row.created_at else None,
"expires_at": row.expires_at.isoformat() if row.expires_at else None
})
return popular_predictions
except Exception as e:
logger.error("Failed to get most accessed predictions",
tenant_id=tenant_id,
error=str(e))
return []
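
Finally, a hedged sketch of the cache-aside pattern this repository enables; `predictor` is a stand-in for the real forecasting call and is not part of this commit:

    async def demand_with_cache(session, tenant_id, product_name, location, forecast_date, predictor):
        repo = PredictionCacheRepository(session)
        cached = await repo.get_cached_prediction(tenant_id, product_name, location, forecast_date)
        if cached:
            return cached.predicted_demand
        demand, lower, upper, model_id = await predictor(tenant_id, product_name, location, forecast_date)
        await repo.cache_prediction(
            tenant_id, product_name, location, forecast_date,
            demand, lower, upper, model_id, expires_in_hours=24
        )
        return demand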