REFACTOR - Database logic

This commit is contained in:
Urtzi Alfaro
2025-08-08 09:08:41 +02:00
parent 0154365bfc
commit 488bb3ef93
113 changed files with 22842 additions and 6503 deletions

View File

@@ -0,0 +1,12 @@
"""
Data Service Repositories
Repository implementations for data service
"""
from .base import DataBaseRepository
from .sales_repository import SalesRepository
__all__ = [
"DataBaseRepository",
"SalesRepository"
]

View File

@@ -0,0 +1,167 @@
"""
Base Repository for Data Service
Service-specific repository base class with data service utilities
"""
from typing import Optional, List, Dict, Any, Type, TypeVar, Generic
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime, timezone
import structlog
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
# Type variables for the data service repository
Model = TypeVar('Model')
CreateSchema = TypeVar('CreateSchema')
UpdateSchema = TypeVar('UpdateSchema')
class DataBaseRepository(BaseRepository[Model, CreateSchema, UpdateSchema], Generic[Model, CreateSchema, UpdateSchema]):
    """Base repository for the data service.

    Layers tenant-scoped helpers (filtering, counting, access validation,
    basic statistics, cleanup stubs) on top of the shared ``BaseRepository``.
    Subclasses supply the concrete SQLAlchemy model.
    """

    def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
        """Store the model/session pair and cache TTL (seconds) in the shared base."""
        super().__init__(model, session, cache_ttl)

    async def get_by_tenant_id(
        self,
        tenant_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Return up to ``limit`` records owned by ``tenant_id``, offset by ``skip``."""
        return await self.get_multi(
            skip=skip,
            limit=limit,
            filters={"tenant_id": tenant_id}
        )

    async def get_by_date_range(
        self,
        tenant_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Return tenant records ordered by ``date`` descending.

        NOTE(review): the range *bounds* are currently NOT applied — the shared
        ``get_multi`` only accepts equality filters, so ``start_date`` /
        ``end_date`` merely validate that the model has a ``date`` column.

        Raises:
            ValidationError: if dates are supplied for a model without a
                ``date`` field (re-wrapped below as ``DatabaseError``).
            DatabaseError: on any query failure.
        """
        try:
            filters = {"tenant_id": tenant_id}
            # Only validate the presence of a 'date' column; actual >=/<=
            # bounds need custom query support in the shared repository.
            if (start_date or end_date) and not hasattr(self.model, 'date'):
                raise ValidationError("Model does not have 'date' field for date filtering")
            # TODO: push start_date/end_date into a real range filter once
            # the shared repository supports it.
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters=filters,
                order_by="date",
                order_desc=True
            )
        except Exception as e:
            # Fix: the original passed an f-string with no placeholders as the
            # structlog event; use a plain event name plus keyword context.
            logger.error("Failed to get records by date range",
                         tenant_id=tenant_id,
                         start_date=start_date,
                         end_date=end_date,
                         error=str(e))
            raise DatabaseError(f"Date range query failed: {str(e)}")

    async def count_by_tenant(self, tenant_id: str) -> int:
        """Count records belonging to ``tenant_id``."""
        return await self.count(filters={"tenant_id": tenant_id})

    async def validate_tenant_access(self, tenant_id: str, record_id: Any) -> bool:
        """Return True when ``record_id`` exists and belongs to ``tenant_id``.

        Models without a ``tenant_id`` column are treated as tenant-agnostic
        and access is allowed. Any lookup failure is logged and reported as
        denied (False) rather than raised.
        """
        try:
            record = await self.get_by_id(record_id)
            if not record:
                return False
            # Compare as strings so UUID objects and string ids interoperate.
            if hasattr(record, 'tenant_id'):
                return str(record.tenant_id) == str(tenant_id)
            return True  # No tenant_id field: model is not tenant-scoped.
        except Exception as e:
            logger.error("Failed to validate tenant access",
                         tenant_id=tenant_id,
                         record_id=record_id,
                         error=str(e))
            return False

    async def get_tenant_stats(self, tenant_id: str) -> Dict[str, Any]:
        """Return basic per-tenant statistics.

        ``recent_records`` is currently always 0 — computing it needs a
        created_at range query that the shared repository does not expose yet.
        On failure a zeroed payload with an ``error`` key is returned instead
        of raising, so dashboards degrade gracefully.
        """
        try:
            total_records = await self.count_by_tenant(tenant_id)
            # TODO: count rows with created_at within a recent window once
            # range filters are available.
            recent_records = 0
            return {
                "tenant_id": tenant_id,
                "total_records": total_records,
                "recent_records": recent_records,
                "model_type": self.model.__name__
            }
        except Exception as e:
            logger.error("Failed to get tenant statistics",
                         tenant_id=tenant_id, error=str(e))
            return {
                "tenant_id": tenant_id,
                "total_records": 0,
                "recent_records": 0,
                "model_type": self.model.__name__,
                "error": str(e)
            }

    async def cleanup_old_records(
        self,
        tenant_id: str,
        days_old: int = 365,
        batch_size: int = 1000
    ) -> int:
        """Delete old records for a tenant. Currently a stub returning 0.

        Requires the model to expose ``created_at`` or ``date``; the actual
        batched delete still needs a raw-SQL implementation.

        Raises:
            DatabaseError: on unexpected failure.
        """
        try:
            if not hasattr(self.model, 'created_at') and not hasattr(self.model, 'date'):
                # Fix: keep the structlog event constant; move the model name
                # into structured context instead of an f-string.
                logger.warning("Model has no date field for cleanup",
                               model=self.model.__name__)
                return 0
            # TODO: implement batched deletion (batch_size rows at a time)
            # with raw SQL; returning 0 signals "nothing cleaned up".
            logger.info("Cleanup requested but not implemented",
                        model=self.model.__name__,
                        tenant_id=tenant_id,
                        days_old=days_old)
            return 0
        except Exception as e:
            logger.error("Failed to cleanup old records",
                         tenant_id=tenant_id,
                         days_old=days_old,
                         error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}")

    def _ensure_utc_datetime(self, dt: Optional[datetime]) -> Optional[datetime]:
        """Normalize ``dt`` to an aware UTC datetime (naive input is assumed UTC)."""
        if dt is None:
            return None
        if dt.tzinfo is None:
            # Naive datetimes are treated as already-UTC wall times.
            return dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc)

View File

@@ -0,0 +1,517 @@
"""
Sales Repository
Repository for sales data operations with business-specific queries
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func, desc, asc, text
from datetime import datetime, timezone
import structlog
from .base import DataBaseRepository
from app.models.sales import SalesData
from app.schemas.sales import SalesDataCreate, SalesDataResponse
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class SalesRepository(DataBaseRepository[SalesData, SalesDataCreate, Dict]):
    """Repository for sales data with business-specific analytics queries.

    Mixes ORM queries (listing, search) with raw-SQL aggregations over the
    ``sales_data`` table. All raw SQL uses bound parameters; the only value
    ever interpolated into SQL text (``metric_column``) is validated against
    a fixed whitelist first.
    """

    def __init__(self, model_class: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
        """Initialize with the SalesData model class, async session and cache TTL (seconds)."""
        super().__init__(model_class, session, cache_ttl)

    @staticmethod
    def _date_range_sql(start_date: Optional[datetime], end_date: Optional[datetime]) -> str:
        """Return a raw-SQL fragment bounding ``date`` via the bound params
        ``:start_date`` / ``:end_date`` (empty string when no bounds given).

        Shared by the aggregation queries below so the filter text is built
        in exactly one place.
        """
        fragment = ""
        if start_date:
            fragment += " AND date >= :start_date"
        if end_date:
            fragment += " AND date <= :end_date"
        return fragment

    def _date_range_params(
        self,
        start_date: Optional[datetime],
        end_date: Optional[datetime]
    ) -> Dict[str, Any]:
        """Return UTC-normalized bind params matching ``_date_range_sql``."""
        params: Dict[str, Any] = {}
        if start_date:
            params["start_date"] = self._ensure_utc_datetime(start_date)
        if end_date:
            params["end_date"] = self._ensure_utc_datetime(end_date)
        return params

    async def get_by_tenant_and_date_range(
        self,
        tenant_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        product_names: Optional[List[str]] = None,
        location_ids: Optional[List[str]] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[SalesData]:
        """Return sales rows for a tenant, newest first, with optional
        date-range, product-name and location filters plus pagination.

        Raises:
            DatabaseError: on any query failure.
        """
        try:
            query = select(self.model).where(self.model.tenant_id == tenant_id)
            # Normalize bounds to aware-UTC before comparing against the
            # stored timestamps.
            if start_date:
                query = query.where(self.model.date >= self._ensure_utc_datetime(start_date))
            if end_date:
                query = query.where(self.model.date <= self._ensure_utc_datetime(end_date))
            if product_names:
                query = query.where(self.model.product_name.in_(product_names))
            if location_ids:
                query = query.where(self.model.location_id.in_(location_ids))
            # Most recent sales first, then paginate.
            query = query.order_by(desc(self.model.date)).offset(skip).limit(limit)
            result = await self.session.execute(query)
            return result.scalars().all()
        except Exception as e:
            logger.error("Failed to get sales by tenant and date range",
                         tenant_id=tenant_id,
                         start_date=start_date,
                         end_date=end_date,
                         error=str(e))
            raise DatabaseError(f"Failed to get sales data: {str(e)}")

    async def get_sales_aggregation(
        self,
        tenant_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        group_by: str = "daily",
        product_name: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Return per-period, per-product sales aggregates.

        ``group_by`` must be ``daily``/``weekly``/``monthly``. On PostgreSQL
        the period comes from ``DATE_TRUNC``; the SQLite fallback always
        truncates to day regardless of ``group_by``.

        Raises:
            DatabaseError: wrapping ValidationError for a bad ``group_by``
                or any query failure.
        """
        try:
            # Map group_by onto the DATE_TRUNC granularity keyword.
            trunc_by_group = {"daily": "day", "weekly": "week", "monthly": "month"}
            if group_by not in trunc_by_group:
                raise ValidationError(f"Invalid group_by value: {group_by}")
            params: Dict[str, Any] = {"tenant_id": tenant_id}
            if self.session.bind.dialect.name == 'postgresql':
                sql = """
                    SELECT
                        DATE_TRUNC(:date_trunc, date) as period,
                        product_name,
                        COUNT(*) as record_count,
                        SUM(quantity_sold) as total_quantity,
                        SUM(revenue) as total_revenue,
                        AVG(quantity_sold) as average_quantity,
                        AVG(revenue) as average_revenue
                    FROM sales_data
                    WHERE tenant_id = :tenant_id
                """
                # Fix: only bind :date_trunc where the SQL declares it.
                params["date_trunc"] = trunc_by_group[group_by]
            else:
                # SQLite fallback: DATE() truncates to day only.
                sql = """
                    SELECT
                        DATE(date) as period,
                        product_name,
                        COUNT(*) as record_count,
                        SUM(quantity_sold) as total_quantity,
                        SUM(revenue) as total_revenue,
                        AVG(quantity_sold) as average_quantity,
                        AVG(revenue) as average_revenue
                    FROM sales_data
                    WHERE tenant_id = :tenant_id
                """
            sql += self._date_range_sql(start_date, end_date)
            params.update(self._date_range_params(start_date, end_date))
            if product_name:
                sql += " AND product_name = :product_name"
                params["product_name"] = product_name
            sql += " GROUP BY period, product_name ORDER BY period DESC"
            result = await self.session.execute(text(sql), params)
            # "period" echoes the requested grouping; "date" holds the
            # truncated period start from the database.
            return [
                {
                    "period": group_by,
                    "date": row.period,
                    "product_name": row.product_name,
                    "record_count": row.record_count,
                    "total_quantity": row.total_quantity,
                    "total_revenue": float(row.total_revenue),
                    "average_quantity": float(row.average_quantity),
                    "average_revenue": float(row.average_revenue)
                }
                for row in result.fetchall()
            ]
        except Exception as e:
            logger.error("Failed to get sales aggregation",
                         tenant_id=tenant_id,
                         group_by=group_by,
                         error=str(e))
            raise DatabaseError(f"Sales aggregation failed: {str(e)}")

    async def get_top_products(
        self,
        tenant_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        limit: int = 10,
        by_metric: str = "revenue"
    ) -> List[Dict[str, Any]]:
        """Return the top ``limit`` products ranked by ``by_metric``
        (``revenue`` or ``quantity``) within the optional date range.

        Raises:
            DatabaseError: wrapping ValidationError for a bad metric or any
                query failure.
        """
        try:
            if by_metric not in ["revenue", "quantity"]:
                raise ValidationError(f"Invalid metric: {by_metric}")
            # Whitelisted above, so interpolating the column name is safe.
            metric_column = "revenue" if by_metric == "revenue" else "quantity_sold"
            sql = f"""
                SELECT
                    product_name,
                    COUNT(*) as sale_count,
                    SUM(quantity_sold) as total_quantity,
                    SUM(revenue) as total_revenue,
                    AVG(revenue) as avg_revenue_per_sale
                FROM sales_data
                WHERE tenant_id = :tenant_id
                {self._date_range_sql(start_date, end_date)}
                GROUP BY product_name
                ORDER BY SUM({metric_column}) DESC
                LIMIT :limit
            """
            params: Dict[str, Any] = {"tenant_id": tenant_id, "limit": limit}
            params.update(self._date_range_params(start_date, end_date))
            result = await self.session.execute(text(sql), params)
            return [
                {
                    "product_name": row.product_name,
                    "sale_count": row.sale_count,
                    "total_quantity": row.total_quantity,
                    "total_revenue": float(row.total_revenue),
                    "avg_revenue_per_sale": float(row.avg_revenue_per_sale),
                    "metric_used": by_metric
                }
                for row in result.fetchall()
            ]
        except Exception as e:
            logger.error("Failed to get top products",
                         tenant_id=tenant_id,
                         by_metric=by_metric,
                         error=str(e))
            raise DatabaseError(f"Top products query failed: {str(e)}")

    async def get_sales_by_location(
        self,
        tenant_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> List[Dict[str, Any]]:
        """Return per-location sales totals, highest revenue first.

        Rows with a NULL ``location_id`` are bucketed as ``'unknown'``.

        Raises:
            DatabaseError: on any query failure.
        """
        try:
            sql = f"""
                SELECT
                    COALESCE(location_id, 'unknown') as location_id,
                    COUNT(*) as sale_count,
                    SUM(quantity_sold) as total_quantity,
                    SUM(revenue) as total_revenue,
                    AVG(revenue) as avg_revenue_per_sale
                FROM sales_data
                WHERE tenant_id = :tenant_id
                {self._date_range_sql(start_date, end_date)}
                GROUP BY location_id
                ORDER BY SUM(revenue) DESC
            """
            params: Dict[str, Any] = {"tenant_id": tenant_id}
            params.update(self._date_range_params(start_date, end_date))
            result = await self.session.execute(text(sql), params)
            return [
                {
                    "location_id": row.location_id,
                    "sale_count": row.sale_count,
                    "total_quantity": row.total_quantity,
                    "total_revenue": float(row.total_revenue),
                    "avg_revenue_per_sale": float(row.avg_revenue_per_sale)
                }
                for row in result.fetchall()
            ]
        except Exception as e:
            logger.error("Failed to get sales by location",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Sales by location query failed: {str(e)}")

    async def create_bulk_sales(
        self,
        sales_records: List[Dict[str, Any]],
        tenant_id: str
    ) -> List[SalesData]:
        """Create many sales records at once, stamping each with ``tenant_id``
        and normalizing any ``date`` value to aware UTC.

        Note: the input dicts are mutated in place before insertion.

        Raises:
            DatabaseError: on any insertion failure.
        """
        try:
            for record in sales_records:
                record["tenant_id"] = tenant_id
                if "date" in record and record["date"]:
                    record["date"] = self._ensure_utc_datetime(record["date"])
            return await self.bulk_create(sales_records)
        except Exception as e:
            logger.error("Failed to create bulk sales",
                         tenant_id=tenant_id,
                         record_count=len(sales_records),
                         error=str(e))
            raise DatabaseError(f"Bulk sales creation failed: {str(e)}")

    async def search_sales(
        self,
        tenant_id: str,
        search_term: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[SalesData]:
        """Case-insensitive substring search over product_name, notes and
        location_id for one tenant, newest first.

        Raises:
            DatabaseError: on any query failure.
        """
        try:
            pattern = f"%{search_term}%"
            # Columns that may be absent on the model degrade to a false
            # clause so the OR still evaluates.
            query = select(self.model).where(
                and_(
                    self.model.tenant_id == tenant_id,
                    or_(
                        self.model.product_name.ilike(pattern),
                        self.model.notes.ilike(pattern) if hasattr(self.model, 'notes') else False,
                        self.model.location_id.ilike(pattern) if hasattr(self.model, 'location_id') else False
                    )
                )
            ).order_by(desc(self.model.date)).offset(skip).limit(limit)
            result = await self.session.execute(query)
            return result.scalars().all()
        except Exception as e:
            logger.error("Failed to search sales",
                         tenant_id=tenant_id,
                         search_term=search_term,
                         error=str(e))
            raise DatabaseError(f"Sales search failed: {str(e)}")

    async def get_sales_summary(
        self,
        tenant_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Return a one-row summary (counts, totals, date span, distinct
        products/locations) for a tenant within the optional date range.

        Fix: the original issued an extra COUNT query whose result was never
        used; the single summary query below provides everything.

        Raises:
            DatabaseError: on any query failure.
        """
        try:
            sql = f"""
                SELECT
                    COUNT(*) as total_records,
                    SUM(quantity_sold) as total_quantity,
                    SUM(revenue) as total_revenue,
                    AVG(revenue) as avg_revenue,
                    MIN(date) as earliest_sale,
                    MAX(date) as latest_sale,
                    COUNT(DISTINCT product_name) as unique_products,
                    COUNT(DISTINCT location_id) as unique_locations
                FROM sales_data
                WHERE tenant_id = :tenant_id
                {self._date_range_sql(start_date, end_date)}
            """
            params: Dict[str, Any] = {"tenant_id": tenant_id}
            params.update(self._date_range_params(start_date, end_date))
            result = await self.session.execute(text(sql), params)
            row = result.fetchone()
            summary = {
                "tenant_id": tenant_id,
                "period_start": start_date,
                "period_end": end_date,
                "total_sales": 0,
                "total_quantity": 0,
                "total_revenue": 0.0,
                "average_revenue": 0.0,
                "earliest_sale": None,
                "latest_sale": None,
                "unique_products": 0,
                "unique_locations": 0
            }
            if row:
                # Aggregates are NULL when no rows matched; coalesce to zeros.
                summary.update({
                    "total_sales": row.total_records or 0,
                    "total_quantity": row.total_quantity or 0,
                    "total_revenue": float(row.total_revenue or 0),
                    "average_revenue": float(row.avg_revenue or 0),
                    "earliest_sale": row.earliest_sale,
                    "latest_sale": row.latest_sale,
                    "unique_products": row.unique_products or 0,
                    "unique_locations": row.unique_locations or 0
                })
            return summary
        except Exception as e:
            logger.error("Failed to get sales summary",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Sales summary failed: {str(e)}")

    async def validate_sales_data(self, sales_data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate one sales record before insertion.

        Returns a dict with ``is_valid``, ``errors`` and ``warnings``; never
        raises. Fixes over the original: booleans no longer pass the numeric
        checks (bool is an int subclass), and the unit-price sanity check is
        skipped unless both values are positive numbers — previously
        ``quantity_sold == 0`` triggered a ZeroDivisionError that surfaced as
        a generic validation failure.
        """
        errors: List[str] = []
        warnings: List[str] = []

        def _is_positive_number(value: Any) -> bool:
            # Explicitly reject bool: isinstance(True, int) is True.
            return isinstance(value, (int, float)) and not isinstance(value, bool) and value > 0

        try:
            # Required fields must be present and non-None.
            for field in ["date", "product_name", "quantity_sold", "revenue"]:
                if field not in sales_data or sales_data[field] is None:
                    errors.append(f"Missing required field: {field}")
            if "quantity_sold" in sales_data and not _is_positive_number(sales_data["quantity_sold"]):
                errors.append("quantity_sold must be a positive number")
            if "revenue" in sales_data and not _is_positive_number(sales_data["revenue"]):
                errors.append("revenue must be a positive number")
            if "product_name" in sales_data and len(str(sales_data["product_name"])) > 255:
                errors.append("product_name exceeds maximum length of 255 characters")
            # Unit-price sanity check — only meaningful (and safe to compute)
            # when both values are valid positive numbers.
            if _is_positive_number(sales_data.get("quantity_sold")) and _is_positive_number(sales_data.get("revenue")):
                unit_price = sales_data["revenue"] / sales_data["quantity_sold"]
                if unit_price > 10000:  # Arbitrary high-price threshold.
                    warnings.append(f"Unusually high unit price: {unit_price:.2f}")
                elif unit_price < 0.01:  # Suspiciously cheap.
                    warnings.append(f"Unusually low unit price: {unit_price:.2f}")
            return {
                "is_valid": len(errors) == 0,
                "errors": errors,
                "warnings": warnings
            }
        except Exception as e:
            logger.error("Failed to validate sales data", error=str(e))
            return {
                "is_valid": False,
                "errors": [f"Validation error: {str(e)}"],
                "warnings": []
            }

    async def get_product_statistics(self, tenant_id: str) -> List[Dict[str, Any]]:
        """Return per-product lifetime statistics for a tenant, highest
        revenue first. Returns an empty list (after logging) on failure.
        """
        try:
            sql = text("""
                SELECT
                    product_name,
                    COUNT(*) as total_sales,
                    SUM(quantity_sold) as total_quantity,
                    SUM(revenue) as total_revenue,
                    AVG(revenue) as avg_revenue,
                    MIN(date) as first_sale,
                    MAX(date) as last_sale
                FROM sales_data
                WHERE tenant_id = :tenant_id
                GROUP BY product_name
                ORDER BY SUM(revenue) DESC
            """)
            result = await self.session.execute(sql, {"tenant_id": tenant_id})
            products = [
                {
                    "product_name": row.product_name,
                    "total_sales": int(row.total_sales or 0),
                    "total_quantity": int(row.total_quantity or 0),
                    "total_revenue": float(row.total_revenue or 0),
                    "avg_revenue": float(row.avg_revenue or 0),
                    "first_sale": row.first_sale.isoformat() if row.first_sale else None,
                    "last_sale": row.last_sale.isoformat() if row.last_sale else None
                }
                for row in result.fetchall()
            ]
            # Fix: constant structlog event with structured context instead
            # of an f-string message.
            logger.debug("Fetched product statistics",
                         tenant_id=tenant_id,
                         product_count=len(products))
            return products
        except Exception as e:
            logger.error("Error getting product statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return []