REFACTOR - Database logic
services/forecasting/app/repositories/base.py (new file, 253 additions)
@@ -0,0 +1,253 @@
"""
Base Repository for Forecasting Service
Service-specific repository base class with forecasting utilities
"""

from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, date, timedelta
import structlog

from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError

logger = structlog.get_logger()


class ForecastingBaseRepository(BaseRepository):
    """Base repository for forecasting service with common forecasting operations"""

    def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
        # Forecasting data benefits from medium cache time (10 minutes)
        super().__init__(model, session, cache_ttl)

    async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
        """Get records by tenant ID"""
        if hasattr(self.model, 'tenant_id'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"tenant_id": tenant_id},
                order_by="created_at",
                order_desc=True
            )
        return await self.get_multi(skip=skip, limit=limit)

    async def get_by_product_name(
        self,
        tenant_id: str,
        product_name: str,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Get records by tenant and product"""
        if hasattr(self.model, 'product_name'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={
                    "tenant_id": tenant_id,
                    "product_name": product_name
                },
                order_by="created_at",
                order_desc=True
            )
        return await self.get_by_tenant_id(tenant_id, skip, limit)

    async def get_by_date_range(
        self,
        tenant_id: str,
        start_date: datetime,
        end_date: datetime,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Get records within date range for a tenant"""
        if not hasattr(self.model, 'forecast_date') and not hasattr(self.model, 'created_at'):
            logger.warning(f"Model {self.model.__name__} has no date field for filtering")
            return []

        try:
            table_name = self.model.__tablename__
            date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"

            query_text = f"""
                SELECT * FROM {table_name}
                WHERE tenant_id = :tenant_id
                AND {date_field} >= :start_date
                AND {date_field} <= :end_date
                ORDER BY {date_field} DESC
                LIMIT :limit OFFSET :skip
            """

            result = await self.session.execute(text(query_text), {
                "tenant_id": tenant_id,
                "start_date": start_date,
                "end_date": end_date,
                "limit": limit,
                "skip": skip
            })

            # Convert raw rows to model objects (instances built this way are
            # transient, i.e. not attached to the session)
            records = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                record = self.model(**record_dict)
                records.append(record)

            return records

        except Exception as e:
            logger.error("Failed to get records by date range",
                         model=self.model.__name__,
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Date range query failed: {str(e)}") from e

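    # Illustrative call of the date-range helper above (a sketch only; the tenant ID
    # and dates are placeholder values, not part of this file's API):
    #
    #     forecasts = await repo.get_by_date_range(
    #         tenant_id="tenant-123",
    #         start_date=datetime(2025, 1, 1),
    #         end_date=datetime(2025, 1, 31),
    #         limit=500,
    #     )
    #
    # Because the helper builds fresh model instances from raw rows, callers should
    # treat the results as read-only snapshots rather than ORM-tracked objects.
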
    async def get_recent_records(
        self,
        tenant_id: str,
        hours: int = 24,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Get recent records for a tenant"""
        cutoff_time = datetime.utcnow() - timedelta(hours=hours)
        return await self.get_by_date_range(
            tenant_id, cutoff_time, datetime.utcnow(), skip, limit
        )

    async def cleanup_old_records(self, days_old: int = 90) -> int:
        """Clean up old forecasting records"""
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)
            table_name = self.model.__tablename__

            # Use forecast_date for cleanup when available, otherwise fall back to created_at
            date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"

            query_text = f"""
                DELETE FROM {table_name}
                WHERE {date_field} < :cutoff_date
            """

            # The DELETE runs on the current session; this method does not commit
            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount

            logger.info(f"Cleaned up old {self.model.__name__} records",
                        deleted_count=deleted_count,
                        days_old=days_old)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old records",
                         model=self.model.__name__,
                         error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e

    async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
        """Get statistics for a tenant"""
        try:
            table_name = self.model.__tablename__

            # Get basic counts
            total_records = await self.count(filters={"tenant_id": tenant_id})

            # Get recent activity (records in last 7 days)
            seven_days_ago = datetime.utcnow() - timedelta(days=7)
            recent_records = len(await self.get_by_date_range(
                tenant_id, seven_days_ago, datetime.utcnow(), limit=1000
            ))

            # Get records by product if applicable
            product_stats = {}
            if hasattr(self.model, 'product_name'):
                product_query = text(f"""
                    SELECT product_name, COUNT(*) as count
                    FROM {table_name}
                    WHERE tenant_id = :tenant_id
                    GROUP BY product_name
                    ORDER BY count DESC
                """)

                result = await self.session.execute(product_query, {"tenant_id": tenant_id})
                product_stats = {row.product_name: row.count for row in result.fetchall()}

            return {
                "total_records": total_records,
                "recent_records_7d": recent_records,
                "records_by_product": product_stats
            }

        except Exception as e:
            logger.error("Failed to get tenant statistics",
                         model=self.model.__name__,
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_records": 0,
                "recent_records_7d": 0,
                "records_by_product": {}
            }

    def _validate_forecast_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
        """Validate forecasting-related data"""
        errors = []

        for field in required_fields:
            if field not in data or not data[field]:
                errors.append(f"Missing required field: {field}")

        # Validate tenant_id format if present
        if "tenant_id" in data and data["tenant_id"]:
            tenant_id = data["tenant_id"]
            if not isinstance(tenant_id, str) or len(tenant_id) < 1:
                errors.append("Invalid tenant_id format")

        # Validate product_name if present
        if "product_name" in data and data["product_name"]:
            product_name = data["product_name"]
            if not isinstance(product_name, str) or len(product_name) < 1:
                errors.append("Invalid product_name format")

        # Validate dates if present - accept datetime objects, date objects, and date strings
        date_fields = ["forecast_date", "created_at", "evaluation_date", "expires_at"]
        for field in date_fields:
            if field in data and data[field]:
                field_value = data[field]
                field_type = type(field_value).__name__

                if isinstance(field_value, (datetime, date)):
                    logger.debug(f"Date field {field} is valid {field_type}", field_value=str(field_value))
                    continue  # Already a datetime or date, valid
                elif isinstance(field_value, str):
                    # Try to parse the string date
                    try:
                        from dateutil.parser import parse
                        parse(field_value)  # Just validate, don't convert yet
                        logger.debug(f"Date field {field} is valid string", field_value=field_value)
                    except (ValueError, TypeError) as e:
                        logger.error(f"Date parsing failed for {field}", field_value=field_value, error=str(e))
                        errors.append(f"Invalid {field} format - must be datetime or valid date string")
                else:
                    logger.error(f"Date field {field} has invalid type {field_type}", field_value=str(field_value))
                    errors.append(f"Invalid {field} format - must be datetime or valid date string")

        # Validate numeric fields
        numeric_fields = [
            "predicted_demand", "confidence_lower", "confidence_upper",
            "mae", "mape", "rmse", "accuracy_score"
        ]
        for field in numeric_fields:
            if field in data and data[field] is not None:
                try:
                    float(data[field])
                except (ValueError, TypeError):
                    errors.append(f"Invalid {field} format - must be numeric")

        return {
            "is_valid": len(errors) == 0,
            "errors": errors
        }
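A minimal sketch of how a concrete repository in this service might build on ForecastingBaseRepository. The DemandForecast model, its import path, and the create() method inherited from the shared BaseRepository are assumptions for illustration; only get_multi(), count(), and the helpers defined in this file are confirmed by the commit.

from sqlalchemy.ext.asyncio import AsyncSession

from app.models.demand import DemandForecast  # hypothetical model, for illustration only
from app.repositories.base import ForecastingBaseRepository


class DemandForecastRepository(ForecastingBaseRepository):
    """Example concrete repository reusing the forecasting base helpers."""

    def __init__(self, session: AsyncSession):
        # 10-minute cache TTL matches the base class default
        super().__init__(DemandForecast, session, cache_ttl=600)

    async def create_validated(self, data: dict) -> DemandForecast:
        # Run the shared validation helper before persisting
        validation = self._validate_forecast_data(
            data,
            required_fields=["tenant_id", "product_name", "forecast_date"],
        )
        if not validation["is_valid"]:
            raise ValueError(f"Invalid forecast payload: {validation['errors']}")
        # create() is assumed to be provided by shared.database.repository.BaseRepository
        return await self.create(data)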