Initial commit - production deployment

2026-01-21 17:17:16 +01:00
commit c23d00dd92
2289 changed files with 638440 additions and 0 deletions


@@ -0,0 +1,18 @@
"""
Forecasting Service Repositories
Repository implementations for forecasting service
"""
from .base import ForecastingBaseRepository
from .forecast_repository import ForecastRepository
from .prediction_batch_repository import PredictionBatchRepository
from .performance_metric_repository import PerformanceMetricRepository
from .prediction_cache_repository import PredictionCacheRepository
__all__ = [
"ForecastingBaseRepository",
"ForecastRepository",
"PredictionBatchRepository",
"PerformanceMetricRepository",
"PredictionCacheRepository"
]
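
A minimal wiring sketch for these exports, assuming a standard SQLAlchemy async engine and session factory; the DSN, factory name, and tenant/product identifiers below are illustrative, not part of this commit.

# Illustrative wiring only; the DSN and identifiers are assumptions.
import asyncio

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from app.repositories import ForecastRepository

engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/forecasting")
SessionFactory = async_sessionmaker(engine, expire_on_commit=False)

async def main() -> None:
    async with SessionFactory() as session:
        repo = ForecastRepository(session)  # default 10-minute repository cache TTL
        latest = await repo.get_latest_forecast_for_product(
            tenant_id="tenant-123",
            inventory_product_id="prod-456",
        )
        print(latest)

if __name__ == "__main__":
    asyncio.run(main())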


@@ -0,0 +1,253 @@
"""
Base Repository for Forecasting Service
Service-specific repository base class with forecasting utilities
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, date, timedelta, timezone
import structlog
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class ForecastingBaseRepository(BaseRepository):
"""Base repository for forecasting service with common forecasting operations"""
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Forecasting data benefits from medium cache time (10 minutes)
super().__init__(model, session, cache_ttl)
async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
"""Get records by tenant ID"""
if hasattr(self.model, 'tenant_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"tenant_id": tenant_id},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def get_by_inventory_product_id(
self,
tenant_id: str,
inventory_product_id: str,
skip: int = 0,
limit: int = 100
) -> List:
"""Get records by tenant and inventory product"""
if hasattr(self.model, 'inventory_product_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id
},
order_by="created_at",
order_desc=True
)
return await self.get_by_tenant_id(tenant_id, skip, limit)
async def get_by_date_range(
self,
tenant_id: str,
start_date: datetime,
end_date: datetime,
skip: int = 0,
limit: int = 100
) -> List:
"""Get records within date range for a tenant"""
if not hasattr(self.model, 'forecast_date') and not hasattr(self.model, 'created_at'):
logger.warning(f"Model {self.model.__name__} has no date field for filtering")
return []
try:
table_name = self.model.__tablename__
date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"
query_text = f"""
SELECT * FROM {table_name}
WHERE tenant_id = :tenant_id
AND {date_field} >= :start_date
AND {date_field} <= :end_date
ORDER BY {date_field} DESC
LIMIT :limit OFFSET :skip
"""
result = await self.session.execute(text(query_text), {
"tenant_id": tenant_id,
"start_date": start_date,
"end_date": end_date,
"limit": limit,
"skip": skip
})
# Convert rows to model objects
records = []
for row in result.fetchall():
record_dict = dict(row._mapping)
record = self.model(**record_dict)
records.append(record)
return records
except Exception as e:
logger.error("Failed to get records by date range",
model=self.model.__name__,
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Date range query failed: {str(e)}")
async def get_recent_records(
self,
tenant_id: str,
hours: int = 24,
skip: int = 0,
limit: int = 100
) -> List:
"""Get recent records for a tenant"""
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
return await self.get_by_date_range(
tenant_id, cutoff_time, datetime.now(timezone.utc), skip, limit
)
async def cleanup_old_records(self, days_old: int = 90) -> int:
"""Clean up old forecasting records"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
table_name = self.model.__tablename__
# Use created_at or forecast_date for cleanup
date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"
query_text = f"""
DELETE FROM {table_name}
WHERE {date_field} < :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info(f"Cleaned up old {self.model.__name__} records",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old records",
model=self.model.__name__,
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
"""Get statistics for a tenant"""
try:
table_name = self.model.__tablename__
# Get basic counts
total_records = await self.count(filters={"tenant_id": tenant_id})
# Get recent activity (records in last 7 days)
seven_days_ago = datetime.now(timezone.utc) - timedelta(days=7)
recent_records = len(await self.get_by_date_range(
tenant_id, seven_days_ago, datetime.now(timezone.utc), limit=1000
))
# Get records by product if applicable
product_stats = {}
if hasattr(self.model, 'inventory_product_id'):
product_query = text(f"""
SELECT inventory_product_id, COUNT(*) as count
FROM {table_name}
WHERE tenant_id = :tenant_id
GROUP BY inventory_product_id
ORDER BY count DESC
""")
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}
return {
"total_records": total_records,
"recent_records_7d": recent_records,
"records_by_product": product_stats
}
except Exception as e:
logger.error("Failed to get tenant statistics",
model=self.model.__name__,
tenant_id=tenant_id,
error=str(e))
return {
"total_records": 0,
"recent_records_7d": 0,
"records_by_product": {}
}
def _validate_forecast_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
"""Validate forecasting-related data"""
errors = []
for field in required_fields:
if field not in data or data[field] is None:
errors.append(f"Missing required field: {field}")
# Validate tenant_id format if present
if "tenant_id" in data and data["tenant_id"]:
tenant_id = data["tenant_id"]
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
errors.append("Invalid tenant_id format")
# Validate inventory_product_id if present
if "inventory_product_id" in data and data["inventory_product_id"]:
inventory_product_id = data["inventory_product_id"]
if not isinstance(inventory_product_id, str) or len(inventory_product_id) < 1:
errors.append("Invalid inventory_product_id format")
# Validate dates if present - accept datetime objects, date objects, and date strings
date_fields = ["forecast_date", "created_at", "evaluation_date", "expires_at"]
for field in date_fields:
if field in data and data[field]:
field_value = data[field]
field_type = type(field_value).__name__
if isinstance(field_value, (datetime, date)):
logger.debug(f"Date field {field} is valid {field_type}", field_value=str(field_value))
continue # Already a datetime or date, valid
elif isinstance(field_value, str):
# Try to parse the string date
try:
from dateutil.parser import parse
parse(field_value) # Just validate, don't convert yet
logger.debug(f"Date field {field} is valid string", field_value=field_value)
except (ValueError, TypeError) as e:
logger.error(f"Date parsing failed for {field}", field_value=field_value, error=str(e))
errors.append(f"Invalid {field} format - must be datetime or valid date string")
else:
logger.error(f"Date field {field} has invalid type {field_type}", field_value=str(field_value))
errors.append(f"Invalid {field} format - must be datetime or valid date string")
# Validate numeric fields
numeric_fields = [
"predicted_demand", "confidence_lower", "confidence_upper",
"mae", "mape", "rmse", "accuracy_score"
]
for field in numeric_fields:
if field in data and data[field] is not None:
try:
float(data[field])
except (ValueError, TypeError):
errors.append(f"Invalid {field} format - must be numeric")
return {
"is_valid": len(errors) == 0,
"errors": errors
}
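
As a quick reference for the validation contract above, a sketch of the returned shape; the payloads are illustrative and `repo` stands for any ForecastingBaseRepository subclass instance supplied by the caller.

from app.repositories.base import ForecastingBaseRepository

def validation_demo(repo: ForecastingBaseRepository) -> None:
    # A payload that passes: date strings are parsed via dateutil, numerics via float().
    ok = repo._validate_forecast_data(
        {
            "tenant_id": "tenant-123",
            "inventory_product_id": "prod-456",
            "forecast_date": "2026-01-22",
            "predicted_demand": 42.0,
        },
        ["tenant_id", "inventory_product_id", "forecast_date"],
    )
    assert ok == {"is_valid": True, "errors": []}

    # A payload that fails: missing required field plus a non-numeric metric.
    bad = repo._validate_forecast_data({"predicted_demand": "n/a"}, ["tenant_id"])
    assert bad["is_valid"] is False and len(bad["errors"]) == 2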


@@ -0,0 +1,565 @@
"""
Forecast Repository
Repository for forecast operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc, func
from sqlalchemy.exc import IntegrityError
from datetime import datetime, timedelta, date, timezone
import structlog
from .base import ForecastingBaseRepository
from app.models.forecasts import Forecast
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class ForecastRepository(ForecastingBaseRepository):
"""Repository for forecast operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Forecasts are relatively stable, medium cache time (10 minutes)
super().__init__(Forecast, session, cache_ttl)
async def create_forecast(self, forecast_data: Dict[str, Any]) -> Forecast:
"""
Create a new forecast with validation.
Handles duplicate forecast race condition gracefully:
If a forecast already exists for the same (tenant, product, date, location),
it will be updated instead of creating a duplicate.
"""
try:
# Validate forecast data
validation_result = self._validate_forecast_data(
forecast_data,
["tenant_id", "inventory_product_id", "location", "forecast_date",
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid forecast data: {validation_result['errors']}")
# Set default values
if "confidence_level" not in forecast_data:
forecast_data["confidence_level"] = 0.8
if "algorithm" not in forecast_data:
forecast_data["algorithm"] = "prophet"
if "business_type" not in forecast_data:
forecast_data["business_type"] = "individual"
# Try to create forecast
try:
forecast = await self.create(forecast_data)
logger.info("Forecast created successfully",
forecast_id=forecast.id,
tenant_id=forecast.tenant_id,
inventory_product_id=forecast.inventory_product_id,
forecast_date=forecast.forecast_date.isoformat())
return forecast
except IntegrityError as ie:
# Handle unique constraint violation (duplicate forecast)
error_msg = str(ie).lower()
if "unique constraint" in error_msg or "duplicate" in error_msg or "uq_forecast_tenant_product_date_location" in error_msg:
logger.warning("Forecast already exists (race condition), updating instead",
tenant_id=forecast_data.get("tenant_id"),
inventory_product_id=forecast_data.get("inventory_product_id"),
forecast_date=str(forecast_data.get("forecast_date")))
# Rollback the failed insert
await self.session.rollback()
# Fetch the existing forecast
existing_forecast = await self.get_existing_forecast(
tenant_id=forecast_data["tenant_id"],
inventory_product_id=forecast_data["inventory_product_id"],
forecast_date=forecast_data["forecast_date"],
location=forecast_data["location"]
)
if existing_forecast:
# Update existing forecast with new prediction data
update_data = {
"predicted_demand": forecast_data["predicted_demand"],
"confidence_lower": forecast_data["confidence_lower"],
"confidence_upper": forecast_data["confidence_upper"],
"confidence_level": forecast_data.get("confidence_level", 0.8),
"model_id": forecast_data["model_id"],
"model_version": forecast_data.get("model_version"),
"algorithm": forecast_data.get("algorithm", "prophet"),
"processing_time_ms": forecast_data.get("processing_time_ms"),
"features_used": forecast_data.get("features_used"),
"weather_temperature": forecast_data.get("weather_temperature"),
"weather_precipitation": forecast_data.get("weather_precipitation"),
"weather_description": forecast_data.get("weather_description"),
}
updated_forecast = await self.update(str(existing_forecast.id), update_data)
logger.info("Existing forecast updated after duplicate detection",
forecast_id=updated_forecast.id,
tenant_id=updated_forecast.tenant_id,
inventory_product_id=updated_forecast.inventory_product_id)
return updated_forecast
else:
# This shouldn't happen, but log it
logger.error("Duplicate forecast detected but not found in database")
raise DatabaseError("Duplicate forecast detected but not found")
else:
# Different integrity error, re-raise
raise
except ValidationError:
raise
except IntegrityError as ie:
# Re-raise integrity errors that weren't handled above
logger.error("Database integrity error creating forecast",
tenant_id=forecast_data.get("tenant_id"),
error=str(ie))
raise DatabaseError(f"Database integrity error: {str(ie)}")
except Exception as e:
logger.error("Failed to create forecast",
tenant_id=forecast_data.get("tenant_id"),
inventory_product_id=forecast_data.get("inventory_product_id"),
error=str(e))
raise DatabaseError(f"Failed to create forecast: {str(e)}")
async def get_existing_forecast(
self,
tenant_id: str,
inventory_product_id: str,
forecast_date: datetime,
location: str
) -> Optional[Forecast]:
"""Get an existing forecast by unique key (tenant, product, date, location)"""
try:
query = select(Forecast).where(
and_(
Forecast.tenant_id == tenant_id,
Forecast.inventory_product_id == inventory_product_id,
Forecast.forecast_date == forecast_date,
Forecast.location == location
)
)
result = await self.session.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error("Failed to get existing forecast", error=str(e))
return None
async def get_forecasts_by_date_range(
self,
tenant_id: str,
start_date: date,
end_date: date,
inventory_product_id: str = None,
location: str = None
) -> List[Forecast]:
"""Get forecasts within a date range"""
        try:
            # Convert dates to datetime for comparison
            start_datetime = datetime.combine(start_date, datetime.min.time())
            end_datetime = datetime.combine(end_date, datetime.max.time())
            forecasts = await self.get_by_date_range(
                tenant_id, start_datetime, end_datetime
            )
            # get_by_date_range filters only by tenant and date, so apply the
            # optional product/location filters to the returned records
            if inventory_product_id:
                forecasts = [f for f in forecasts if f.inventory_product_id == inventory_product_id]
            if location:
                forecasts = [f for f in forecasts if f.location == location]
            return forecasts
except Exception as e:
logger.error("Failed to get forecasts by date range",
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
error=str(e))
raise DatabaseError(f"Failed to get forecasts: {str(e)}")
async def get_latest_forecast_for_product(
self,
tenant_id: str,
inventory_product_id: str,
location: str = None
) -> Optional[Forecast]:
"""Get the most recent forecast for a product"""
try:
filters = {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id
}
if location:
filters["location"] = location
forecasts = await self.get_multi(
filters=filters,
limit=1,
order_by="forecast_date",
order_desc=True
)
return forecasts[0] if forecasts else None
except Exception as e:
logger.error("Failed to get latest forecast for product",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
raise DatabaseError(f"Failed to get latest forecast: {str(e)}")
async def get_forecasts_for_date(
self,
tenant_id: str,
forecast_date: date,
inventory_product_id: str = None
) -> List[Forecast]:
"""Get all forecasts for a specific date"""
        try:
            # Convert date to datetime range
            start_datetime = datetime.combine(forecast_date, datetime.min.time())
            end_datetime = datetime.combine(forecast_date, datetime.max.time())
            forecasts = await self.get_by_date_range(
                tenant_id, start_datetime, end_datetime
            )
            # Apply the optional product filter, which get_by_date_range does not handle
            if inventory_product_id:
                forecasts = [f for f in forecasts if f.inventory_product_id == inventory_product_id]
            return forecasts
except Exception as e:
logger.error("Failed to get forecasts for date",
tenant_id=tenant_id,
forecast_date=forecast_date,
error=str(e))
raise DatabaseError(f"Failed to get forecasts for date: {str(e)}")
async def get_forecast_accuracy_metrics(
self,
tenant_id: str,
inventory_product_id: str = None,
days_back: int = 30
) -> Dict[str, Any]:
"""Get forecast accuracy metrics"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
# Build base query conditions
conditions = ["tenant_id = :tenant_id", "forecast_date >= :cutoff_date"]
params = {
"tenant_id": tenant_id,
"cutoff_date": cutoff_date
}
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
query_text = f"""
SELECT
COUNT(*) as total_forecasts,
AVG(predicted_demand) as avg_predicted_demand,
MIN(predicted_demand) as min_predicted_demand,
MAX(predicted_demand) as max_predicted_demand,
AVG(confidence_upper - confidence_lower) as avg_confidence_interval,
AVG(processing_time_ms) as avg_processing_time_ms,
COUNT(DISTINCT inventory_product_id) as unique_products,
COUNT(DISTINCT model_id) as unique_models
FROM forecasts
WHERE {' AND '.join(conditions)}
"""
result = await self.session.execute(text(query_text), params)
row = result.fetchone()
if row and row.total_forecasts > 0:
return {
"total_forecasts": int(row.total_forecasts),
"avg_predicted_demand": float(row.avg_predicted_demand or 0),
"min_predicted_demand": float(row.min_predicted_demand or 0),
"max_predicted_demand": float(row.max_predicted_demand or 0),
"avg_confidence_interval": float(row.avg_confidence_interval or 0),
"avg_processing_time_ms": float(row.avg_processing_time_ms or 0),
"unique_products": int(row.unique_products or 0),
"unique_models": int(row.unique_models or 0),
"period_days": days_back
}
return {
"total_forecasts": 0,
"avg_predicted_demand": 0.0,
"min_predicted_demand": 0.0,
"max_predicted_demand": 0.0,
"avg_confidence_interval": 0.0,
"avg_processing_time_ms": 0.0,
"unique_products": 0,
"unique_models": 0,
"period_days": days_back
}
except Exception as e:
logger.error("Failed to get forecast accuracy metrics",
tenant_id=tenant_id,
error=str(e))
return {
"total_forecasts": 0,
"avg_predicted_demand": 0.0,
"min_predicted_demand": 0.0,
"max_predicted_demand": 0.0,
"avg_confidence_interval": 0.0,
"avg_processing_time_ms": 0.0,
"unique_products": 0,
"unique_models": 0,
"period_days": days_back
}
async def get_demand_trends(
self,
tenant_id: str,
inventory_product_id: str,
days_back: int = 30
) -> Dict[str, Any]:
"""Get demand trends for a product"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
query_text = """
SELECT
DATE(forecast_date) as date,
AVG(predicted_demand) as avg_demand,
MIN(predicted_demand) as min_demand,
MAX(predicted_demand) as max_demand,
COUNT(*) as forecast_count
FROM forecasts
WHERE tenant_id = :tenant_id
AND inventory_product_id = :inventory_product_id
AND forecast_date >= :cutoff_date
GROUP BY DATE(forecast_date)
ORDER BY date DESC
"""
result = await self.session.execute(text(query_text), {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"cutoff_date": cutoff_date
})
trends = []
for row in result.fetchall():
trends.append({
"date": row.date.isoformat() if row.date else None,
"avg_demand": float(row.avg_demand),
"min_demand": float(row.min_demand),
"max_demand": float(row.max_demand),
"forecast_count": int(row.forecast_count)
})
# Calculate overall trend direction
if len(trends) >= 2:
recent_avg = sum(t["avg_demand"] for t in trends[:7]) / min(7, len(trends))
older_avg = sum(t["avg_demand"] for t in trends[-7:]) / min(7, len(trends[-7:]))
trend_direction = "increasing" if recent_avg > older_avg else "decreasing"
else:
trend_direction = "stable"
return {
"inventory_product_id": inventory_product_id,
"period_days": days_back,
"trends": trends,
"trend_direction": trend_direction,
"total_data_points": len(trends)
}
except Exception as e:
logger.error("Failed to get demand trends",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
return {
"inventory_product_id": inventory_product_id,
"period_days": days_back,
"trends": [],
"trend_direction": "unknown",
"total_data_points": 0
}
async def get_model_usage_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get statistics about model usage"""
try:
# Get model usage counts
model_query = text("""
SELECT
model_id,
algorithm,
COUNT(*) as usage_count,
AVG(predicted_demand) as avg_prediction,
MAX(forecast_date) as last_used,
COUNT(DISTINCT inventory_product_id) as products_covered
FROM forecasts
WHERE tenant_id = :tenant_id
GROUP BY model_id, algorithm
ORDER BY usage_count DESC
""")
result = await self.session.execute(model_query, {"tenant_id": tenant_id})
model_stats = []
for row in result.fetchall():
model_stats.append({
"model_id": row.model_id,
"algorithm": row.algorithm,
"usage_count": int(row.usage_count),
"avg_prediction": float(row.avg_prediction),
"last_used": row.last_used.isoformat() if row.last_used else None,
"products_covered": int(row.products_covered)
})
# Get algorithm distribution
algorithm_query = text("""
SELECT algorithm, COUNT(*) as count
FROM forecasts
WHERE tenant_id = :tenant_id
GROUP BY algorithm
""")
algorithm_result = await self.session.execute(algorithm_query, {"tenant_id": tenant_id})
algorithm_distribution = {row.algorithm: row.count for row in algorithm_result.fetchall()}
return {
"model_statistics": model_stats,
"algorithm_distribution": algorithm_distribution,
"total_unique_models": len(model_stats)
}
except Exception as e:
logger.error("Failed to get model usage statistics",
tenant_id=tenant_id,
error=str(e))
return {
"model_statistics": [],
"algorithm_distribution": {},
"total_unique_models": 0
}
async def cleanup_old_forecasts(self, days_old: int = 90) -> int:
"""Clean up old forecasts"""
return await self.cleanup_old_records(days_old=days_old)
async def get_forecast_summary(self, tenant_id: str) -> Dict[str, Any]:
"""Get comprehensive forecast summary for a tenant"""
try:
# Get basic statistics
basic_stats = await self.get_statistics_by_tenant(tenant_id)
# Get accuracy metrics
accuracy_metrics = await self.get_forecast_accuracy_metrics(tenant_id)
# Get model usage
model_usage = await self.get_model_usage_statistics(tenant_id)
# Get recent activity
recent_forecasts = await self.get_recent_records(tenant_id, hours=24)
return {
"tenant_id": tenant_id,
"basic_statistics": basic_stats,
"accuracy_metrics": accuracy_metrics,
"model_usage": model_usage,
"recent_activity": {
"forecasts_last_24h": len(recent_forecasts),
"latest_forecast": recent_forecasts[0].forecast_date.isoformat() if recent_forecasts else None
}
}
except Exception as e:
logger.error("Failed to get forecast summary",
tenant_id=tenant_id,
error=str(e))
return {"error": f"Failed to get forecast summary: {str(e)}"}
async def get_forecasts_by_date(
self,
tenant_id: str,
forecast_date: date,
inventory_product_id: str = None
) -> List[Forecast]:
"""
Get all forecasts for a specific date.
Used for forecast validation against actual sales.
Args:
tenant_id: Tenant UUID
forecast_date: Date to get forecasts for
inventory_product_id: Optional product filter
Returns:
List of forecasts for the date
"""
try:
query = select(Forecast).where(
and_(
Forecast.tenant_id == tenant_id,
func.date(Forecast.forecast_date) == forecast_date
)
)
if inventory_product_id:
query = query.where(Forecast.inventory_product_id == inventory_product_id)
result = await self.session.execute(query)
forecasts = result.scalars().all()
logger.info("Retrieved forecasts by date",
tenant_id=tenant_id,
forecast_date=forecast_date.isoformat(),
count=len(forecasts))
return list(forecasts)
except Exception as e:
logger.error("Failed to get forecasts by date",
tenant_id=tenant_id,
forecast_date=forecast_date.isoformat(),
error=str(e))
raise DatabaseError(f"Failed to get forecasts: {str(e)}")
async def bulk_create_forecasts(self, forecasts_data: List[Dict[str, Any]]) -> List[Forecast]:
"""Bulk create multiple forecasts"""
try:
created_forecasts = []
for forecast_data in forecasts_data:
# Validate each forecast
validation_result = self._validate_forecast_data(
forecast_data,
["tenant_id", "inventory_product_id", "location", "forecast_date",
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
)
if not validation_result["is_valid"]:
logger.warning("Skipping invalid forecast data",
errors=validation_result["errors"],
data=forecast_data)
continue
forecast = await self.create(forecast_data)
created_forecasts.append(forecast)
logger.info("Bulk created forecasts",
requested_count=len(forecasts_data),
created_count=len(created_forecasts))
return created_forecasts
except Exception as e:
logger.error("Failed to bulk create forecasts",
requested_count=len(forecasts_data),
error=str(e))
raise DatabaseError(f"Bulk forecast creation failed: {str(e)}")
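
A hedged sketch of the create-or-update path implemented by create_forecast; the session passed in and the field values are illustrative.

from datetime import datetime, timezone

from sqlalchemy.ext.asyncio import AsyncSession

from app.repositories.forecast_repository import ForecastRepository

async def upsert_demo(session: AsyncSession) -> None:
    repo = ForecastRepository(session)
    forecast_data = {
        "tenant_id": "tenant-123",
        "inventory_product_id": "prod-456",
        "location": "main",
        "forecast_date": datetime(2026, 1, 22, tzinfo=timezone.utc),
        "predicted_demand": 120.0,
        "confidence_lower": 95.0,
        "confidence_upper": 150.0,
        "model_id": "model-789",
    }
    # First call inserts (defaults: confidence_level=0.8, algorithm="prophet").
    first = await repo.create_forecast(forecast_data)

    # A later call for the same (tenant, product, date, location) hits the unique
    # constraint and is turned into an update of the existing row.
    forecast_data["predicted_demand"] = 135.0
    second = await repo.create_forecast(forecast_data)
    # second refers to the same row as first, updated in place by the duplicate handler
    print(first.id, second.id, second.predicted_demand)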


@@ -0,0 +1,214 @@
# services/forecasting/app/repositories/forecasting_alert_repository.py
"""
Forecasting Alert Repository
Data access layer for forecasting-specific alert detection and analysis
"""
from typing import List, Dict, Any
from uuid import UUID
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
import structlog
logger = structlog.get_logger()
class ForecastingAlertRepository:
"""Repository for forecasting alert data access"""
def __init__(self, session: AsyncSession):
self.session = session
async def get_weekend_demand_surges(self, tenant_id: UUID) -> List[Dict[str, Any]]:
"""
Get predicted weekend demand surges
Returns forecasts showing significant growth over previous weeks
"""
try:
query = text("""
WITH weekend_forecast AS (
SELECT
f.tenant_id,
f.inventory_product_id,
f.product_name,
f.predicted_demand,
f.forecast_date,
LAG(f.predicted_demand, 7) OVER (
PARTITION BY f.tenant_id, f.inventory_product_id
ORDER BY f.forecast_date
) as prev_week_demand,
AVG(f.predicted_demand) OVER (
PARTITION BY f.tenant_id, f.inventory_product_id
ORDER BY f.forecast_date
ROWS BETWEEN 6 PRECEDING AND CURRENT ROW
) as avg_weekly_demand
FROM forecasts f
WHERE f.forecast_date >= CURRENT_DATE + INTERVAL '1 day'
AND f.forecast_date <= CURRENT_DATE + INTERVAL '3 days'
AND EXTRACT(DOW FROM f.forecast_date) IN (6, 0)
AND f.tenant_id = :tenant_id
),
surge_analysis AS (
SELECT *,
CASE
WHEN prev_week_demand > 0 THEN
(predicted_demand - prev_week_demand) / prev_week_demand * 100
ELSE 0
END as growth_percentage,
CASE
WHEN avg_weekly_demand > 0 THEN
(predicted_demand - avg_weekly_demand) / avg_weekly_demand * 100
ELSE 0
END as avg_growth_percentage
FROM weekend_forecast
)
SELECT * FROM surge_analysis
WHERE growth_percentage > 50 OR avg_growth_percentage > 50
ORDER BY growth_percentage DESC
""")
result = await self.session.execute(query, {"tenant_id": tenant_id})
return [dict(row._mapping) for row in result.fetchall()]
except Exception as e:
logger.error("Failed to get weekend demand surges", error=str(e), tenant_id=str(tenant_id))
raise
async def get_weather_impact_forecasts(self, tenant_id: UUID) -> List[Dict[str, Any]]:
"""
Get weather impact on demand forecasts
Returns forecasts with rain or significant demand changes
"""
try:
query = text("""
WITH weather_impact AS (
SELECT
f.tenant_id,
f.inventory_product_id,
f.product_name,
f.predicted_demand,
f.forecast_date,
f.weather_precipitation,
f.weather_temperature,
f.traffic_volume,
AVG(f.predicted_demand) OVER (
PARTITION BY f.tenant_id, f.inventory_product_id
ORDER BY f.forecast_date
ROWS BETWEEN 6 PRECEDING AND CURRENT ROW
) as avg_demand
FROM forecasts f
WHERE f.forecast_date >= CURRENT_DATE + INTERVAL '1 day'
AND f.forecast_date <= CURRENT_DATE + INTERVAL '2 days'
AND f.tenant_id = :tenant_id
),
rain_impact AS (
SELECT *,
CASE
WHEN weather_precipitation > 2.0 THEN true
ELSE false
END as rain_forecast,
CASE
WHEN traffic_volume < 80 THEN true
ELSE false
END as low_traffic_expected,
                    CASE
                        WHEN avg_demand > 0 THEN
                            (predicted_demand - avg_demand) / avg_demand * 100
                        ELSE 0
                    END as demand_change
FROM weather_impact
)
SELECT * FROM rain_impact
WHERE rain_forecast = true OR demand_change < -15
ORDER BY demand_change ASC
""")
result = await self.session.execute(query, {"tenant_id": tenant_id})
return [dict(row._mapping) for row in result.fetchall()]
except Exception as e:
logger.error("Failed to get weather impact forecasts", error=str(e), tenant_id=str(tenant_id))
raise
async def get_holiday_demand_spikes(self, tenant_id: UUID) -> List[Dict[str, Any]]:
"""
Get historical holiday demand spike analysis
Returns products with significant holiday demand increases
"""
try:
query = text("""
WITH holiday_demand AS (
SELECT
f.tenant_id,
f.inventory_product_id,
f.product_name,
AVG(CASE WHEN f.is_holiday = true THEN f.predicted_demand END) as avg_holiday_demand,
AVG(CASE WHEN f.is_holiday = false THEN f.predicted_demand END) as avg_normal_demand,
COUNT(*) as forecast_count
FROM forecasts f
WHERE f.created_at > CURRENT_DATE - INTERVAL '365 days'
AND f.tenant_id = :tenant_id
GROUP BY f.tenant_id, f.inventory_product_id, f.product_name
HAVING COUNT(*) >= 10
),
demand_spike_analysis AS (
SELECT *,
CASE
WHEN avg_normal_demand > 0 THEN
(avg_holiday_demand - avg_normal_demand) / avg_normal_demand * 100
ELSE 0
END as spike_percentage
FROM holiday_demand
)
SELECT * FROM demand_spike_analysis
WHERE spike_percentage > 25
ORDER BY spike_percentage DESC
""")
result = await self.session.execute(query, {"tenant_id": tenant_id})
return [dict(row._mapping) for row in result.fetchall()]
except Exception as e:
logger.error("Failed to get holiday demand spikes", error=str(e), tenant_id=str(tenant_id))
raise
async def get_demand_pattern_analysis(self, tenant_id: UUID) -> List[Dict[str, Any]]:
"""
Get weekly demand pattern analysis for optimization
Returns products with significant demand variations
"""
try:
query = text("""
WITH weekly_patterns AS (
SELECT
f.tenant_id,
f.inventory_product_id,
f.product_name,
EXTRACT(DOW FROM f.forecast_date) as day_of_week,
AVG(f.predicted_demand) as avg_demand,
STDDEV(f.predicted_demand) as demand_variance,
COUNT(*) as data_points
FROM forecasts f
WHERE f.created_at > CURRENT_DATE - INTERVAL '60 days'
AND f.tenant_id = :tenant_id
GROUP BY f.tenant_id, f.inventory_product_id, f.product_name, EXTRACT(DOW FROM f.forecast_date)
HAVING COUNT(*) >= 5
),
pattern_analysis AS (
SELECT
tenant_id, inventory_product_id, product_name,
MAX(avg_demand) as peak_demand,
MIN(avg_demand) as min_demand,
AVG(avg_demand) as overall_avg,
MAX(avg_demand) - MIN(avg_demand) as demand_range
FROM weekly_patterns
GROUP BY tenant_id, inventory_product_id, product_name
)
SELECT * FROM pattern_analysis
WHERE demand_range > overall_avg * 0.3
AND peak_demand > overall_avg * 1.5
ORDER BY demand_range DESC
""")
result = await self.session.execute(query, {"tenant_id": tenant_id})
return [dict(row._mapping) for row in result.fetchall()]
except Exception as e:
logger.error("Failed to get demand pattern analysis", error=str(e), tenant_id=str(tenant_id))
raise
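
A short consumption sketch for these alert queries; the tenant UUID is illustrative.

from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from app.repositories.forecasting_alert_repository import ForecastingAlertRepository

async def weekend_surge_demo(session: AsyncSession) -> None:
    repo = ForecastingAlertRepository(session)
    tenant_id = UUID("00000000-0000-0000-0000-000000000001")  # illustrative
    surges = await repo.get_weekend_demand_surges(tenant_id)
    for surge in surges:
        # Rows are plain dicts mirroring the surge_analysis columns above.
        print(surge["product_name"], round(surge["growth_percentage"], 1))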


@@ -0,0 +1,271 @@
"""
Performance Metric Repository
Repository for model performance metrics in forecasting service
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta, timezone
import structlog
from .base import ForecastingBaseRepository
from app.models.predictions import ModelPerformanceMetric
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PerformanceMetricRepository(ForecastingBaseRepository):
"""Repository for model performance metrics operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
# Performance metrics are stable, longer cache time (15 minutes)
super().__init__(ModelPerformanceMetric, session, cache_ttl)
async def create_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
"""Create a new performance metric"""
try:
# Validate metric data
validation_result = self._validate_forecast_data(
metric_data,
["model_id", "tenant_id", "inventory_product_id", "evaluation_date"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid metric data: {validation_result['errors']}")
metric = await self.create(metric_data)
logger.info("Performance metric created",
metric_id=metric.id,
model_id=metric.model_id,
tenant_id=metric.tenant_id,
inventory_product_id=metric.inventory_product_id)
return metric
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create performance metric",
model_id=metric_data.get("model_id"),
error=str(e))
raise DatabaseError(f"Failed to create metric: {str(e)}")
async def get_metrics_by_model(
self,
model_id: str,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get all metrics for a model"""
try:
return await self.get_multi(
filters={"model_id": model_id},
skip=skip,
limit=limit,
order_by="evaluation_date",
order_desc=True
)
except Exception as e:
logger.error("Failed to get metrics by model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get metrics: {str(e)}")
async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
"""Get the latest performance metric for a model"""
try:
metrics = await self.get_multi(
filters={"model_id": model_id},
limit=1,
order_by="evaluation_date",
order_desc=True
)
return metrics[0] if metrics else None
except Exception as e:
logger.error("Failed to get latest metric for model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get latest metric: {str(e)}")
async def get_performance_trends(
self,
tenant_id: str,
inventory_product_id: str = None,
days: int = 30
) -> Dict[str, Any]:
"""Get performance trends over time"""
try:
start_date = datetime.now(timezone.utc) - timedelta(days=days)
conditions = [
"tenant_id = :tenant_id",
"evaluation_date >= :start_date"
]
params = {
"tenant_id": tenant_id,
"start_date": start_date
}
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
query_text = f"""
SELECT
DATE(evaluation_date) as date,
inventory_product_id,
AVG(mae) as avg_mae,
AVG(mape) as avg_mape,
AVG(rmse) as avg_rmse,
AVG(accuracy_score) as avg_accuracy,
COUNT(*) as measurement_count
FROM model_performance_metrics
WHERE {' AND '.join(conditions)}
GROUP BY DATE(evaluation_date), inventory_product_id
ORDER BY date DESC, inventory_product_id
"""
result = await self.session.execute(text(query_text), params)
trends = []
for row in result.fetchall():
trends.append({
"date": row.date.isoformat() if row.date else None,
"inventory_product_id": row.inventory_product_id,
"metrics": {
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
"avg_rmse": float(row.avg_rmse) if row.avg_rmse else None,
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
},
"measurement_count": int(row.measurement_count)
})
return {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"period_days": days,
"trends": trends,
"total_measurements": len(trends)
}
except Exception as e:
logger.error("Failed to get performance trends",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
return {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"period_days": days,
"trends": [],
"total_measurements": 0
}
async def cleanup_old_metrics(self, days_old: int = 180) -> int:
"""Clean up old performance metrics"""
return await self.cleanup_old_records(days_old=days_old)
async def bulk_create_metrics(self, metrics: List[ModelPerformanceMetric]) -> int:
"""
Bulk insert performance metrics for validation
Args:
metrics: List of ModelPerformanceMetric objects to insert
Returns:
Number of metrics created
"""
try:
if not metrics:
return 0
self.session.add_all(metrics)
await self.session.flush()
logger.info(
"Bulk created performance metrics",
count=len(metrics)
)
return len(metrics)
except Exception as e:
logger.error(
"Failed to bulk create performance metrics",
count=len(metrics),
error=str(e)
)
raise DatabaseError(f"Failed to bulk create metrics: {str(e)}")
async def get_metrics_by_date_range(
self,
tenant_id: str,
start_date: datetime,
end_date: datetime,
inventory_product_id: Optional[str] = None
) -> List[ModelPerformanceMetric]:
"""
Get performance metrics for a date range
Args:
tenant_id: Tenant identifier
start_date: Start of date range
end_date: End of date range
inventory_product_id: Optional product filter
Returns:
List of performance metrics
"""
try:
            # Build custom query for the date range; the optional product filter is appended below
query_text = """
SELECT *
FROM model_performance_metrics
WHERE tenant_id = :tenant_id
AND evaluation_date >= :start_date
AND evaluation_date <= :end_date
"""
params = {
"tenant_id": tenant_id,
"start_date": start_date,
"end_date": end_date
}
if inventory_product_id:
query_text += " AND inventory_product_id = :inventory_product_id"
params["inventory_product_id"] = inventory_product_id
query_text += " ORDER BY evaluation_date DESC"
result = await self.session.execute(text(query_text), params)
rows = result.fetchall()
# Convert rows to ModelPerformanceMetric objects
metrics = []
for row in rows:
metric = ModelPerformanceMetric()
for column in row._mapping.keys():
setattr(metric, column, row._mapping[column])
metrics.append(metric)
return metrics
except Exception as e:
logger.error(
"Failed to get metrics by date range",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to get metrics: {str(e)}")
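
A hedged sketch of reading trend data back from this repository; identifiers are illustrative.

from sqlalchemy.ext.asyncio import AsyncSession

from app.repositories.performance_metric_repository import PerformanceMetricRepository

async def trends_demo(session: AsyncSession) -> None:
    repo = PerformanceMetricRepository(session)
    trends = await repo.get_performance_trends(
        tenant_id="tenant-123",
        inventory_product_id="prod-456",
        days=30,
    )
    # Each trend point carries per-day averages of MAE/MAPE/RMSE/accuracy.
    for point in trends["trends"]:
        print(point["date"], point["metrics"]["avg_mape"])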


@@ -0,0 +1,388 @@
"""
Prediction Batch Repository
Repository for prediction batch operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta, timezone
import structlog
from .base import ForecastingBaseRepository
from app.models.forecasts import PredictionBatch
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PredictionBatchRepository(ForecastingBaseRepository):
"""Repository for prediction batch operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Batch operations change frequently, shorter cache time (5 minutes)
super().__init__(PredictionBatch, session, cache_ttl)
async def create_batch(self, batch_data: Dict[str, Any]) -> PredictionBatch:
"""Create a new prediction batch"""
try:
# Validate batch data
validation_result = self._validate_forecast_data(
batch_data,
["tenant_id", "batch_name"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid batch data: {validation_result['errors']}")
# Set default values
if "status" not in batch_data:
batch_data["status"] = "pending"
if "forecast_days" not in batch_data:
batch_data["forecast_days"] = 7
if "business_type" not in batch_data:
batch_data["business_type"] = "individual"
batch = await self.create(batch_data)
logger.info("Prediction batch created",
batch_id=batch.id,
tenant_id=batch.tenant_id,
batch_name=batch.batch_name)
return batch
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create prediction batch",
tenant_id=batch_data.get("tenant_id"),
error=str(e))
raise DatabaseError(f"Failed to create batch: {str(e)}")
async def update_batch_progress(
self,
batch_id: str,
completed_products: int = None,
failed_products: int = None,
total_products: int = None,
status: str = None
) -> Optional[PredictionBatch]:
"""Update batch progress"""
try:
update_data = {}
if completed_products is not None:
update_data["completed_products"] = completed_products
if failed_products is not None:
update_data["failed_products"] = failed_products
if total_products is not None:
update_data["total_products"] = total_products
if status:
update_data["status"] = status
if status in ["completed", "failed"]:
update_data["completed_at"] = datetime.now(timezone.utc)
if not update_data:
return await self.get_by_id(batch_id)
updated_batch = await self.update(batch_id, update_data)
logger.debug("Batch progress updated",
batch_id=batch_id,
status=status,
completed=completed_products)
return updated_batch
except Exception as e:
logger.error("Failed to update batch progress",
batch_id=batch_id,
error=str(e))
raise DatabaseError(f"Failed to update batch: {str(e)}")
async def complete_batch(
self,
batch_id: str,
processing_time_ms: int = None
) -> Optional[PredictionBatch]:
"""Mark batch as completed"""
try:
update_data = {
"status": "completed",
"completed_at": datetime.now(timezone.utc)
}
if processing_time_ms:
update_data["processing_time_ms"] = processing_time_ms
updated_batch = await self.update(batch_id, update_data)
logger.info("Batch completed",
batch_id=batch_id,
processing_time_ms=processing_time_ms)
return updated_batch
except Exception as e:
logger.error("Failed to complete batch",
batch_id=batch_id,
error=str(e))
raise DatabaseError(f"Failed to complete batch: {str(e)}")
async def fail_batch(
self,
batch_id: str,
error_message: str,
processing_time_ms: int = None
) -> Optional[PredictionBatch]:
"""Mark batch as failed"""
try:
update_data = {
"status": "failed",
"completed_at": datetime.now(timezone.utc),
"error_message": error_message
}
if processing_time_ms:
update_data["processing_time_ms"] = processing_time_ms
updated_batch = await self.update(batch_id, update_data)
logger.error("Batch failed",
batch_id=batch_id,
error_message=error_message)
return updated_batch
except Exception as e:
logger.error("Failed to mark batch as failed",
batch_id=batch_id,
error=str(e))
raise DatabaseError(f"Failed to fail batch: {str(e)}")
async def cancel_batch(
self,
batch_id: str,
cancelled_by: str = None
) -> Optional[PredictionBatch]:
"""Cancel a batch"""
try:
batch = await self.get_by_id(batch_id)
if not batch:
return None
if batch.status in ["completed", "failed"]:
logger.warning("Cannot cancel finished batch",
batch_id=batch_id,
status=batch.status)
return batch
update_data = {
"status": "cancelled",
"completed_at": datetime.now(timezone.utc),
"cancelled_by": cancelled_by,
"error_message": f"Cancelled by {cancelled_by}" if cancelled_by else "Cancelled"
}
updated_batch = await self.update(batch_id, update_data)
logger.info("Batch cancelled",
batch_id=batch_id,
cancelled_by=cancelled_by)
return updated_batch
except Exception as e:
logger.error("Failed to cancel batch",
batch_id=batch_id,
error=str(e))
raise DatabaseError(f"Failed to cancel batch: {str(e)}")
async def get_active_batches(self, tenant_id: str = None) -> List[PredictionBatch]:
"""Get currently active (pending/processing) batches"""
try:
            # Raw SQL is used so both 'pending' and 'processing' statuses can be matched
            if tenant_id:
query_text = """
SELECT * FROM prediction_batches
WHERE status IN ('pending', 'processing')
AND tenant_id = :tenant_id
ORDER BY requested_at DESC
"""
params = {"tenant_id": tenant_id}
else:
query_text = """
SELECT * FROM prediction_batches
WHERE status IN ('pending', 'processing')
ORDER BY requested_at DESC
"""
params = {}
result = await self.session.execute(text(query_text), params)
batches = []
for row in result.fetchall():
record_dict = dict(row._mapping)
batch = self.model(**record_dict)
batches.append(batch)
return batches
except Exception as e:
logger.error("Failed to get active batches",
tenant_id=tenant_id,
error=str(e))
return []
async def get_batch_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get batch processing statistics"""
try:
base_filter = "WHERE 1=1"
params = {}
if tenant_id:
base_filter = "WHERE tenant_id = :tenant_id"
params["tenant_id"] = tenant_id
# Get counts by status
status_query = text(f"""
SELECT
status,
COUNT(*) as count,
AVG(CASE WHEN processing_time_ms IS NOT NULL THEN processing_time_ms END) as avg_processing_time_ms
FROM prediction_batches
{base_filter}
GROUP BY status
""")
result = await self.session.execute(status_query, params)
status_stats = {}
total_batches = 0
avg_processing_times = {}
for row in result.fetchall():
status_stats[row.status] = row.count
total_batches += row.count
if row.avg_processing_time_ms:
avg_processing_times[row.status] = float(row.avg_processing_time_ms)
# Get recent activity (batches in last 7 days)
seven_days_ago = datetime.now(timezone.utc) - timedelta(days=7)
recent_query = text(f"""
SELECT COUNT(*) as count
FROM prediction_batches
{base_filter}
AND requested_at >= :seven_days_ago
""")
recent_result = await self.session.execute(recent_query, {
**params,
"seven_days_ago": seven_days_ago
})
recent_batches = recent_result.scalar() or 0
# Calculate success rate
completed = status_stats.get("completed", 0)
failed = status_stats.get("failed", 0)
cancelled = status_stats.get("cancelled", 0)
finished_batches = completed + failed + cancelled
success_rate = (completed / finished_batches * 100) if finished_batches > 0 else 0
return {
"total_batches": total_batches,
"batches_by_status": status_stats,
"success_rate": round(success_rate, 2),
"recent_batches_7d": recent_batches,
"avg_processing_times_ms": avg_processing_times
}
except Exception as e:
logger.error("Failed to get batch statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_batches": 0,
"batches_by_status": {},
"success_rate": 0.0,
"recent_batches_7d": 0,
"avg_processing_times_ms": {}
}
async def cleanup_old_batches(self, days_old: int = 30) -> int:
"""Clean up old completed/failed batches"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
query_text = """
DELETE FROM prediction_batches
WHERE status IN ('completed', 'failed', 'cancelled')
AND completed_at < :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info("Cleaned up old prediction batches",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old batches",
error=str(e))
raise DatabaseError(f"Batch cleanup failed: {str(e)}")
async def get_batch_details(self, batch_id: str) -> Dict[str, Any]:
"""Get detailed batch information"""
try:
batch = await self.get_by_id(batch_id)
if not batch:
return {"error": "Batch not found"}
# Calculate completion percentage
completion_percentage = 0
            if batch.total_products:
                completion_percentage = ((batch.completed_products or 0) / batch.total_products) * 100
# Calculate elapsed time
elapsed_time_ms = 0
if batch.completed_at:
elapsed_time_ms = int((batch.completed_at - batch.requested_at).total_seconds() * 1000)
elif batch.status in ["pending", "processing"]:
elapsed_time_ms = int((datetime.now(timezone.utc) - batch.requested_at).total_seconds() * 1000)
return {
"batch_id": str(batch.id),
"tenant_id": str(batch.tenant_id),
"batch_name": batch.batch_name,
"status": batch.status,
"progress": {
"total_products": batch.total_products,
"completed_products": batch.completed_products,
"failed_products": batch.failed_products,
"completion_percentage": round(completion_percentage, 2)
},
"timing": {
"requested_at": batch.requested_at.isoformat(),
"completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
"elapsed_time_ms": elapsed_time_ms,
"processing_time_ms": batch.processing_time_ms
},
"configuration": {
"forecast_days": batch.forecast_days,
"business_type": batch.business_type
},
"error_message": batch.error_message,
"cancelled_by": batch.cancelled_by
}
except Exception as e:
logger.error("Failed to get batch details",
batch_id=batch_id,
error=str(e))
return {"error": f"Failed to get batch details: {str(e)}"}


@@ -0,0 +1,302 @@
"""
Prediction Cache Repository
Repository for prediction cache operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta, timezone
import structlog
import hashlib
from .base import ForecastingBaseRepository
from app.models.predictions import PredictionCache
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PredictionCacheRepository(ForecastingBaseRepository):
"""Repository for prediction cache operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 60):
# Cache entries change very frequently, short cache time (1 minute)
super().__init__(PredictionCache, session, cache_ttl)
def _generate_cache_key(
self,
tenant_id: str,
inventory_product_id: str,
location: str,
forecast_date: datetime
) -> str:
"""Generate cache key for prediction"""
key_data = f"{tenant_id}:{inventory_product_id}:{location}:{forecast_date.isoformat()}"
return hashlib.md5(key_data.encode()).hexdigest()
async def cache_prediction(
self,
tenant_id: str,
inventory_product_id: str,
location: str,
forecast_date: datetime,
predicted_demand: float,
confidence_lower: float,
confidence_upper: float,
model_id: str,
expires_in_hours: int = 24
) -> PredictionCache:
"""Cache a prediction result"""
try:
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
expires_at = datetime.now(timezone.utc) + timedelta(hours=expires_in_hours)
cache_data = {
"cache_key": cache_key,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"location": location,
"forecast_date": forecast_date,
"predicted_demand": predicted_demand,
"confidence_lower": confidence_lower,
"confidence_upper": confidence_upper,
"model_id": model_id,
"expires_at": expires_at,
"hit_count": 0
}
# Try to update existing cache entry first
existing_cache = await self.get_by_field("cache_key", cache_key)
if existing_cache:
cache_entry = await self.update(existing_cache.id, cache_data)
logger.debug("Updated cache entry", cache_key=cache_key)
else:
cache_entry = await self.create(cache_data)
logger.debug("Created cache entry", cache_key=cache_key)
return cache_entry
except Exception as e:
logger.error("Failed to cache prediction",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
raise DatabaseError(f"Failed to cache prediction: {str(e)}")
async def get_cached_prediction(
self,
tenant_id: str,
inventory_product_id: str,
location: str,
forecast_date: datetime
) -> Optional[PredictionCache]:
"""Get cached prediction if valid"""
try:
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
cache_entry = await self.get_by_field("cache_key", cache_key)
if not cache_entry:
logger.debug("Cache miss", cache_key=cache_key)
return None
# Check if cache entry has expired
if cache_entry.expires_at < datetime.now(timezone.utc):
logger.debug("Cache expired", cache_key=cache_key)
await self.delete(cache_entry.id)
return None
# Increment hit count
await self.update(cache_entry.id, {"hit_count": cache_entry.hit_count + 1})
logger.debug("Cache hit",
cache_key=cache_key,
hit_count=cache_entry.hit_count + 1)
return cache_entry
except Exception as e:
logger.error("Failed to get cached prediction",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
return None
async def invalidate_cache(
self,
tenant_id: str,
inventory_product_id: str = None,
location: str = None
) -> int:
"""Invalidate cache entries"""
try:
conditions = ["tenant_id = :tenant_id"]
params = {"tenant_id": tenant_id}
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
if location:
conditions.append("location = :location")
params["location"] = location
query_text = f"""
DELETE FROM prediction_cache
WHERE {' AND '.join(conditions)}
"""
result = await self.session.execute(text(query_text), params)
invalidated_count = result.rowcount
logger.info("Cache invalidated",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
location=location,
invalidated_count=invalidated_count)
return invalidated_count
except Exception as e:
logger.error("Failed to invalidate cache",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Cache invalidation failed: {str(e)}")
async def cleanup_expired_cache(self) -> int:
"""Clean up expired cache entries"""
try:
query_text = """
DELETE FROM prediction_cache
WHERE expires_at < :now
"""
result = await self.session.execute(text(query_text), {"now": datetime.now(timezone.utc)})
deleted_count = result.rowcount
logger.info("Cleaned up expired cache entries",
deleted_count=deleted_count)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup expired cache",
error=str(e))
raise DatabaseError(f"Cache cleanup failed: {str(e)}")
async def get_cache_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get cache performance statistics"""
try:
base_filter = "WHERE 1=1"
params = {}
if tenant_id:
base_filter = "WHERE tenant_id = :tenant_id"
params["tenant_id"] = tenant_id
# Get cache statistics
stats_query = text(f"""
SELECT
COUNT(*) as total_entries,
COUNT(CASE WHEN expires_at > :now THEN 1 END) as active_entries,
COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_entries,
SUM(hit_count) as total_hits,
AVG(hit_count) as avg_hits_per_entry,
MAX(hit_count) as max_hits,
COUNT(DISTINCT inventory_product_id) as unique_products
FROM prediction_cache
{base_filter}
""")
params["now"] = datetime.now(timezone.utc)
result = await self.session.execute(stats_query, params)
row = result.fetchone()
if row:
return {
"total_entries": int(row.total_entries or 0),
"active_entries": int(row.active_entries or 0),
"expired_entries": int(row.expired_entries or 0),
"total_hits": int(row.total_hits or 0),
"avg_hits_per_entry": float(row.avg_hits_per_entry or 0),
"max_hits": int(row.max_hits or 0),
"unique_products": int(row.unique_products or 0),
"cache_hit_ratio": round((row.total_hits / max(row.total_entries, 1)), 2)
}
return {
"total_entries": 0,
"active_entries": 0,
"expired_entries": 0,
"total_hits": 0,
"avg_hits_per_entry": 0.0,
"max_hits": 0,
"unique_products": 0,
"cache_hit_ratio": 0.0
}
except Exception as e:
logger.error("Failed to get cache statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_entries": 0,
"active_entries": 0,
"expired_entries": 0,
"total_hits": 0,
"avg_hits_per_entry": 0.0,
"max_hits": 0,
"unique_products": 0,
"cache_hit_ratio": 0.0
}
async def get_most_accessed_predictions(
self,
tenant_id: str = None,
limit: int = 10
) -> List[Dict[str, Any]]:
"""Get most frequently accessed cached predictions"""
try:
base_filter = "WHERE hit_count > 0"
params = {"limit": limit}
if tenant_id:
base_filter = "WHERE tenant_id = :tenant_id AND hit_count > 0"
params["tenant_id"] = tenant_id
query_text = f"""
SELECT
inventory_product_id,
location,
hit_count,
predicted_demand,
created_at,
expires_at
FROM prediction_cache
{base_filter}
ORDER BY hit_count DESC
LIMIT :limit
"""
result = await self.session.execute(text(query_text), params)
popular_predictions = []
for row in result.fetchall():
popular_predictions.append({
"inventory_product_id": row.inventory_product_id,
"location": row.location,
"hit_count": int(row.hit_count),
"predicted_demand": float(row.predicted_demand),
"created_at": row.created_at.isoformat() if row.created_at else None,
"expires_at": row.expires_at.isoformat() if row.expires_at else None
})
return popular_predictions
except Exception as e:
logger.error("Failed to get most accessed predictions",
tenant_id=tenant_id,
error=str(e))
return []
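
A hedged round-trip sketch for the prediction cache above; the session and field values are illustrative.

from datetime import datetime, timezone

from sqlalchemy.ext.asyncio import AsyncSession

from app.repositories.prediction_cache_repository import PredictionCacheRepository

async def cache_demo(session: AsyncSession) -> None:
    repo = PredictionCacheRepository(session)
    forecast_date = datetime(2026, 1, 22, tzinfo=timezone.utc)

    await repo.cache_prediction(
        tenant_id="tenant-123",
        inventory_product_id="prod-456",
        location="main",
        forecast_date=forecast_date,
        predicted_demand=120.0,
        confidence_lower=95.0,
        confidence_upper=150.0,
        model_id="model-789",
        expires_in_hours=6,
    )

    cached = await repo.get_cached_prediction("tenant-123", "prod-456", "main", forecast_date)
    # Returns None on a miss or after expiry; otherwise hit_count is incremented.
    if cached is not None:
        print(cached.predicted_demand, cached.hit_count)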