Files
bakery-ia/services/forecasting/app/repositories/base.py
2025-08-14 16:47:34 +02:00

253 lines
10 KiB
Python

"""
Base Repository for Forecasting Service
Service-specific repository base class with forecasting utilities
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, date, timedelta
import structlog
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class ForecastingBaseRepository(BaseRepository):
    """Base repository for the forecasting service.

    Extends the shared ``BaseRepository`` CRUD helpers with tenant-scoped
    lookups, date-range queries, retention cleanup, per-tenant statistics
    and validation of forecast payloads.
    """

    def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
        """Initialize the repository.

        Args:
            model: SQLAlchemy model class this repository manages.
            session: Async database session.
            cache_ttl: Cache time-to-live in seconds. Forecasting data
                benefits from a medium cache time, so default is 10 minutes.
        """
        super().__init__(model, session, cache_ttl)

    async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
        """Get records for a tenant, newest first.

        Falls back to an unfiltered listing when the model has no
        ``tenant_id`` column.
        """
        if hasattr(self.model, 'tenant_id'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"tenant_id": tenant_id},
                order_by="created_at",
                order_desc=True
            )
        return await self.get_multi(skip=skip, limit=limit)

    async def get_by_inventory_product_id(
        self,
        tenant_id: str,
        inventory_product_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Get records for a tenant filtered by inventory product, newest first.

        Falls back to a plain tenant lookup when the model has no
        ``inventory_product_id`` column.
        """
        if hasattr(self.model, 'inventory_product_id'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={
                    "tenant_id": tenant_id,
                    "inventory_product_id": inventory_product_id
                },
                order_by="created_at",
                order_desc=True
            )
        return await self.get_by_tenant_id(tenant_id, skip, limit)

    async def get_by_date_range(
        self,
        tenant_id: str,
        start_date: datetime,
        end_date: datetime,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Get a tenant's records within [start_date, end_date], newest first.

        Uses ``forecast_date`` when the model has one, otherwise
        ``created_at``. Returns an empty list if the model has neither.

        Raises:
            DatabaseError: If the underlying query fails.
        """
        if not hasattr(self.model, 'forecast_date') and not hasattr(self.model, 'created_at'):
            logger.warning(f"Model {self.model.__name__} has no date field for filtering")
            return []
        try:
            # Table/column names come from the model's own metadata, never
            # from user input, so interpolating them here is safe; all
            # values are bound parameters.
            table_name = self.model.__tablename__
            date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"
            query_text = f"""
                SELECT * FROM {table_name}
                WHERE tenant_id = :tenant_id
                AND {date_field} >= :start_date
                AND {date_field} <= :end_date
                ORDER BY {date_field} DESC
                LIMIT :limit OFFSET :skip
            """
            result = await self.session.execute(text(query_text), {
                "tenant_id": tenant_id,
                "start_date": start_date,
                "end_date": end_date,
                "limit": limit,
                "skip": skip
            })
            # Hydrate raw rows into model instances via the row mapping.
            return [self.model(**dict(row._mapping)) for row in result.fetchall()]
        except Exception as e:
            logger.error("Failed to get records by date range",
                        model=self.model.__name__,
                        tenant_id=tenant_id,
                        error=str(e))
            raise DatabaseError(f"Date range query failed: {str(e)}") from e

    async def get_recent_records(
        self,
        tenant_id: str,
        hours: int = 24,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Get a tenant's records from the last ``hours`` hours, newest first."""
        # Capture "now" once so the window bounds are consistent.
        now = datetime.utcnow()
        cutoff_time = now - timedelta(hours=hours)
        return await self.get_by_date_range(
            tenant_id, cutoff_time, now, skip, limit
        )

    async def cleanup_old_records(self, days_old: int = 90) -> int:
        """Delete records older than ``days_old`` days across all tenants.

        Uses ``forecast_date`` when available, otherwise ``created_at``.

        Returns:
            Number of rows deleted.

        Raises:
            DatabaseError: If the delete fails.
        """
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)
            # Identifiers come from the model's metadata, not user input.
            table_name = self.model.__tablename__
            date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"
            query_text = f"""
                DELETE FROM {table_name}
                WHERE {date_field} < :cutoff_date
            """
            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount
            logger.info(f"Cleaned up old {self.model.__name__} records",
                       deleted_count=deleted_count,
                       days_old=days_old)
            # NOTE(review): no commit here — presumably the caller/session
            # manager owns transaction boundaries; confirm against callers.
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup old records",
                        model=self.model.__name__,
                        error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e

    async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
        """Get summary statistics for a tenant.

        Returns a dict with ``total_records``, ``recent_records_7d``
        (capped at 1000 by the underlying query limit) and
        ``records_by_product`` (empty when the model has no
        ``inventory_product_id``). Best-effort: on any failure a zeroed
        dict is returned instead of raising.
        """
        try:
            # Basic counts.
            total_records = await self.count(filters={"tenant_id": tenant_id})
            # Recent activity (records in the last 7 days).
            seven_days_ago = datetime.utcnow() - timedelta(days=7)
            recent_records = len(await self.get_by_date_range(
                tenant_id, seven_days_ago, datetime.utcnow(), limit=1000
            ))
            # Per-product breakdown, if the model tracks products.
            product_stats = {}
            if hasattr(self.model, 'inventory_product_id'):
                table_name = self.model.__tablename__
                product_query = text(f"""
                    SELECT inventory_product_id, COUNT(*) as count
                    FROM {table_name}
                    WHERE tenant_id = :tenant_id
                    GROUP BY inventory_product_id
                    ORDER BY count DESC
                """)
                result = await self.session.execute(product_query, {"tenant_id": tenant_id})
                product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}
            return {
                "total_records": total_records,
                "recent_records_7d": recent_records,
                "records_by_product": product_stats
            }
        except Exception as e:
            # Deliberately swallow: statistics are non-critical, so return
            # a zeroed payload rather than failing the caller.
            logger.error("Failed to get tenant statistics",
                        model=self.model.__name__,
                        tenant_id=tenant_id,
                        error=str(e))
            return {
                "total_records": 0,
                "recent_records_7d": 0,
                "records_by_product": {}
            }

    def _validate_forecast_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
        """Validate a forecasting payload.

        Checks required fields are present and non-None, that id fields are
        non-empty strings, that known date fields are datetime/date objects
        or parseable date strings, and that known metric fields are numeric.

        Returns:
            ``{"is_valid": bool, "errors": [str, ...]}``
        """
        errors = []
        for field in required_fields:
            if field not in data or data[field] is None:
                errors.append(f"Missing required field: {field}")
        # Validate tenant_id format if present.
        if "tenant_id" in data and data["tenant_id"]:
            tenant_id = data["tenant_id"]
            if not isinstance(tenant_id, str) or len(tenant_id) < 1:
                errors.append("Invalid tenant_id format")
        # Validate inventory_product_id if present.
        if "inventory_product_id" in data and data["inventory_product_id"]:
            inventory_product_id = data["inventory_product_id"]
            if not isinstance(inventory_product_id, str) or len(inventory_product_id) < 1:
                errors.append("Invalid inventory_product_id format")
        # Validate dates if present - accept datetime objects, date objects,
        # and date strings. Import hoisted out of the loop so dateutil is
        # resolved once per call instead of once per field.
        from dateutil.parser import parse
        date_fields = ["forecast_date", "created_at", "evaluation_date", "expires_at"]
        for field in date_fields:
            if field in data and data[field]:
                field_value = data[field]
                field_type = type(field_value).__name__
                if isinstance(field_value, (datetime, date)):
                    logger.debug(f"Date field {field} is valid {field_type}", field_value=str(field_value))
                    continue  # Already a datetime or date, valid
                elif isinstance(field_value, str):
                    # Try to parse the string date (validate only, don't convert).
                    try:
                        parse(field_value)
                        logger.debug(f"Date field {field} is valid string", field_value=field_value)
                    except (ValueError, TypeError) as e:
                        logger.error(f"Date parsing failed for {field}", field_value=field_value, error=str(e))
                        errors.append(f"Invalid {field} format - must be datetime or valid date string")
                else:
                    logger.error(f"Date field {field} has invalid type {field_type}", field_value=str(field_value))
                    errors.append(f"Invalid {field} format - must be datetime or valid date string")
        # Validate numeric fields (anything float() accepts).
        numeric_fields = [
            "predicted_demand", "confidence_lower", "confidence_upper",
            "mae", "mape", "rmse", "accuracy_score"
        ]
        for field in numeric_fields:
            if field in data and data[field] is not None:
                try:
                    float(data[field])
                except (ValueError, TypeError):
                    errors.append(f"Invalid {field} format - must be numeric")
        return {
            "is_valid": len(errors) == 0,
            "errors": errors
        }