""" Sales Repository Repository for sales data operations with business-specific queries """ from typing import Optional, List, Dict, Any, Type from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, and_, or_, func, desc, asc, text from datetime import datetime, timezone import structlog from .base import DataBaseRepository from app.models.sales import SalesData from app.schemas.sales import SalesDataCreate, SalesDataResponse from shared.database.exceptions import DatabaseError, ValidationError logger = structlog.get_logger() class SalesRepository(DataBaseRepository[SalesData, SalesDataCreate, Dict]): """Repository for sales data operations""" def __init__(self, model_class: Type, session: AsyncSession, cache_ttl: Optional[int] = 300): super().__init__(model_class, session, cache_ttl) async def get_by_tenant_and_date_range( self, tenant_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, product_names: Optional[List[str]] = None, location_ids: Optional[List[str]] = None, skip: int = 0, limit: int = 100 ) -> List[SalesData]: """Get sales data filtered by tenant, date range, and optional filters""" try: query = select(self.model).where(self.model.tenant_id == tenant_id) # Add date range filter if start_date: start_date = self._ensure_utc_datetime(start_date) query = query.where(self.model.date >= start_date) if end_date: end_date = self._ensure_utc_datetime(end_date) query = query.where(self.model.date <= end_date) # Add product filter if product_names: query = query.where(self.model.product_name.in_(product_names)) # Add location filter if location_ids: query = query.where(self.model.location_id.in_(location_ids)) # Order by date descending (most recent first) query = query.order_by(desc(self.model.date)) # Apply pagination query = query.offset(skip).limit(limit) result = await self.session.execute(query) return result.scalars().all() except Exception as e: logger.error("Failed to get sales by tenant and date range", tenant_id=tenant_id, start_date=start_date, end_date=end_date, error=str(e)) raise DatabaseError(f"Failed to get sales data: {str(e)}") async def get_sales_aggregation( self, tenant_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, group_by: str = "daily", product_name: Optional[str] = None ) -> List[Dict[str, Any]]: """Get aggregated sales data for analytics""" try: # Determine date truncation based on group_by if group_by == "daily": date_trunc = "day" elif group_by == "weekly": date_trunc = "week" elif group_by == "monthly": date_trunc = "month" else: raise ValidationError(f"Invalid group_by value: {group_by}") # Build base query if self.session.bind.dialect.name == 'postgresql': query = text(""" SELECT DATE_TRUNC(:date_trunc, date) as period, product_name, COUNT(*) as record_count, SUM(quantity_sold) as total_quantity, SUM(revenue) as total_revenue, AVG(quantity_sold) as average_quantity, AVG(revenue) as average_revenue FROM sales_data WHERE tenant_id = :tenant_id """) else: # SQLite fallback query = text(""" SELECT DATE(date) as period, product_name, COUNT(*) as record_count, SUM(quantity_sold) as total_quantity, SUM(revenue) as total_revenue, AVG(quantity_sold) as average_quantity, AVG(revenue) as average_revenue FROM sales_data WHERE tenant_id = :tenant_id """) params = { "tenant_id": tenant_id, "date_trunc": date_trunc } # Add date filters if start_date: query = text(str(query) + " AND date >= :start_date") params["start_date"] = self._ensure_utc_datetime(start_date) if end_date: 
                sql += " AND date <= :end_date"
                params["end_date"] = self._ensure_utc_datetime(end_date)

            # Add product filter
            if product_name:
                sql += " AND product_name = :product_name"
                params["product_name"] = product_name

            # Add GROUP BY and ORDER BY
            sql += " GROUP BY period, product_name ORDER BY period DESC"

            result = await self.session.execute(text(sql), params)
            rows = result.fetchall()

            # Convert to list of dictionaries
            aggregations = []
            for row in rows:
                aggregations.append({
                    "period": group_by,
                    "date": row.period,
                    "product_name": row.product_name,
                    "record_count": row.record_count,
                    "total_quantity": row.total_quantity,
                    "total_revenue": float(row.total_revenue or 0),
                    "average_quantity": float(row.average_quantity or 0),
                    "average_revenue": float(row.average_revenue or 0)
                })

            return aggregations

        except Exception as e:
            logger.error("Failed to get sales aggregation",
                         tenant_id=tenant_id,
                         group_by=group_by,
                         error=str(e))
            raise DatabaseError(f"Sales aggregation failed: {str(e)}")

    async def get_top_products(
        self,
        tenant_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        limit: int = 10,
        by_metric: str = "revenue"
    ) -> List[Dict[str, Any]]:
        """Get top products by quantity or revenue"""
        try:
            if by_metric not in ["revenue", "quantity"]:
                raise ValidationError(f"Invalid metric: {by_metric}")

            # Choose the aggregation column (whitelisted above, safe to interpolate)
            metric_column = "revenue" if by_metric == "revenue" else "quantity_sold"

            query = text(f"""
                SELECT
                    product_name,
                    COUNT(*) as sale_count,
                    SUM(quantity_sold) as total_quantity,
                    SUM(revenue) as total_revenue,
                    AVG(revenue) as avg_revenue_per_sale
                FROM sales_data
                WHERE tenant_id = :tenant_id
                {('AND date >= :start_date' if start_date else '')}
                {('AND date <= :end_date' if end_date else '')}
                GROUP BY product_name
                ORDER BY SUM({metric_column}) DESC
                LIMIT :limit
            """)

            params = {"tenant_id": tenant_id, "limit": limit}
            if start_date:
                params["start_date"] = self._ensure_utc_datetime(start_date)
            if end_date:
                params["end_date"] = self._ensure_utc_datetime(end_date)

            result = await self.session.execute(query, params)
            rows = result.fetchall()

            products = []
            for row in rows:
                products.append({
                    "product_name": row.product_name,
                    "sale_count": row.sale_count,
                    "total_quantity": row.total_quantity,
                    "total_revenue": float(row.total_revenue or 0),
                    "avg_revenue_per_sale": float(row.avg_revenue_per_sale or 0),
                    "metric_used": by_metric
                })

            return products

        except Exception as e:
            logger.error("Failed to get top products",
                         tenant_id=tenant_id,
                         by_metric=by_metric,
                         error=str(e))
            raise DatabaseError(f"Top products query failed: {str(e)}")

    async def get_sales_by_location(
        self,
        tenant_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> List[Dict[str, Any]]:
        """Get sales statistics by location"""
        try:
            query = text("""
                SELECT
                    COALESCE(location_id, 'unknown') as location_id,
                    COUNT(*) as sale_count,
                    SUM(quantity_sold) as total_quantity,
                    SUM(revenue) as total_revenue,
                    AVG(revenue) as avg_revenue_per_sale
                FROM sales_data
                WHERE tenant_id = :tenant_id
                {date_filters}
                GROUP BY location_id
                ORDER BY SUM(revenue) DESC
            """.format(
                date_filters=(
                    "AND date >= :start_date" if start_date else ""
                ) + (
                    " AND date <= :end_date" if end_date else ""
                )
            ))

            params = {"tenant_id": tenant_id}
            if start_date:
                params["start_date"] = self._ensure_utc_datetime(start_date)
            if end_date:
                params["end_date"] = self._ensure_utc_datetime(end_date)

            result = await self.session.execute(query, params)
            rows = result.fetchall()

            locations = []
            for row in rows:
                locations.append({
"location_id": row.location_id, "sale_count": row.sale_count, "total_quantity": row.total_quantity, "total_revenue": float(row.total_revenue), "avg_revenue_per_sale": float(row.avg_revenue_per_sale) }) return locations except Exception as e: logger.error("Failed to get sales by location", tenant_id=tenant_id, error=str(e)) raise DatabaseError(f"Sales by location query failed: {str(e)}") async def create_bulk_sales( self, sales_records: List[Dict[str, Any]], tenant_id: str ) -> List[SalesData]: """Create multiple sales records in bulk""" try: # Ensure all records have tenant_id for record in sales_records: record["tenant_id"] = tenant_id # Ensure dates are timezone-aware if "date" in record and record["date"]: record["date"] = self._ensure_utc_datetime(record["date"]) return await self.bulk_create(sales_records) except Exception as e: logger.error("Failed to create bulk sales", tenant_id=tenant_id, record_count=len(sales_records), error=str(e)) raise DatabaseError(f"Bulk sales creation failed: {str(e)}") async def search_sales( self, tenant_id: str, search_term: str, skip: int = 0, limit: int = 100 ) -> List[SalesData]: """Search sales by product name or notes""" try: # Use the parent search method with sales-specific fields search_fields = ["product_name", "notes", "location_id"] # Filter by tenant first query = select(self.model).where( and_( self.model.tenant_id == tenant_id, or_( self.model.product_name.ilike(f"%{search_term}%"), self.model.notes.ilike(f"%{search_term}%") if hasattr(self.model, 'notes') else False, self.model.location_id.ilike(f"%{search_term}%") if hasattr(self.model, 'location_id') else False ) ) ).order_by(desc(self.model.date)).offset(skip).limit(limit) result = await self.session.execute(query) return result.scalars().all() except Exception as e: logger.error("Failed to search sales", tenant_id=tenant_id, search_term=search_term, error=str(e)) raise DatabaseError(f"Sales search failed: {str(e)}") async def get_sales_summary( self, tenant_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None ) -> Dict[str, Any]: """Get comprehensive sales summary for a tenant""" try: base_filters = {"tenant_id": tenant_id} # Build date filter for count date_query = select(func.count(self.model.id)).where(self.model.tenant_id == tenant_id) if start_date: date_query = date_query.where(self.model.date >= self._ensure_utc_datetime(start_date)) if end_date: date_query = date_query.where(self.model.date <= self._ensure_utc_datetime(end_date)) # Get basic counts total_result = await self.session.execute(date_query) total_sales = total_result.scalar() or 0 # Get revenue and quantity totals summary_query = text(""" SELECT COUNT(*) as total_records, SUM(quantity_sold) as total_quantity, SUM(revenue) as total_revenue, AVG(revenue) as avg_revenue, MIN(date) as earliest_sale, MAX(date) as latest_sale, COUNT(DISTINCT product_name) as unique_products, COUNT(DISTINCT location_id) as unique_locations FROM sales_data WHERE tenant_id = :tenant_id {date_filters} """.format( date_filters=( "AND date >= :start_date" if start_date else "" ) + ( " AND date <= :end_date" if end_date else "" ) )) params = {"tenant_id": tenant_id} if start_date: params["start_date"] = self._ensure_utc_datetime(start_date) if end_date: params["end_date"] = self._ensure_utc_datetime(end_date) result = await self.session.execute(summary_query, params) row = result.fetchone() if row: return { "tenant_id": tenant_id, "period_start": start_date, "period_end": end_date, "total_sales": row.total_records 
or 0, "total_quantity": row.total_quantity or 0, "total_revenue": float(row.total_revenue or 0), "average_revenue": float(row.avg_revenue or 0), "earliest_sale": row.earliest_sale, "latest_sale": row.latest_sale, "unique_products": row.unique_products or 0, "unique_locations": row.unique_locations or 0 } else: return { "tenant_id": tenant_id, "period_start": start_date, "period_end": end_date, "total_sales": 0, "total_quantity": 0, "total_revenue": 0.0, "average_revenue": 0.0, "earliest_sale": None, "latest_sale": None, "unique_products": 0, "unique_locations": 0 } except Exception as e: logger.error("Failed to get sales summary", tenant_id=tenant_id, error=str(e)) raise DatabaseError(f"Sales summary failed: {str(e)}") async def validate_sales_data(self, sales_data: Dict[str, Any]) -> Dict[str, Any]: """Validate sales data before insertion""" errors = [] warnings = [] try: # Check required fields required_fields = ["date", "product_name", "quantity_sold", "revenue"] for field in required_fields: if field not in sales_data or sales_data[field] is None: errors.append(f"Missing required field: {field}") # Validate data types and ranges if "quantity_sold" in sales_data: if not isinstance(sales_data["quantity_sold"], (int, float)) or sales_data["quantity_sold"] <= 0: errors.append("quantity_sold must be a positive number") if "revenue" in sales_data: if not isinstance(sales_data["revenue"], (int, float)) or sales_data["revenue"] <= 0: errors.append("revenue must be a positive number") # Validate string lengths if "product_name" in sales_data and len(str(sales_data["product_name"])) > 255: errors.append("product_name exceeds maximum length of 255 characters") # Check for suspicious data if "quantity_sold" in sales_data and "revenue" in sales_data: unit_price = sales_data["revenue"] / sales_data["quantity_sold"] if unit_price > 10000: # Arbitrary high price threshold warnings.append(f"Unusually high unit price: {unit_price:.2f}") elif unit_price < 0.01: # Very low price warnings.append(f"Unusually low unit price: {unit_price:.2f}") return { "is_valid": len(errors) == 0, "errors": errors, "warnings": warnings } except Exception as e: logger.error("Failed to validate sales data", error=str(e)) return { "is_valid": False, "errors": [f"Validation error: {str(e)}"], "warnings": [] } async def get_product_statistics(self, tenant_id: str) -> List[Dict[str, Any]]: """Get product statistics for tenant""" try: query = text(""" SELECT product_name, COUNT(*) as total_sales, SUM(quantity_sold) as total_quantity, SUM(revenue) as total_revenue, AVG(revenue) as avg_revenue, MIN(date) as first_sale, MAX(date) as last_sale FROM sales_data WHERE tenant_id = :tenant_id GROUP BY product_name ORDER BY SUM(revenue) DESC """) result = await self.session.execute(query, {"tenant_id": tenant_id}) rows = result.fetchall() products = [] for row in rows: products.append({ "product_name": row.product_name, "total_sales": int(row.total_sales or 0), "total_quantity": int(row.total_quantity or 0), "total_revenue": float(row.total_revenue or 0), "avg_revenue": float(row.avg_revenue or 0), "first_sale": row.first_sale.isoformat() if row.first_sale else None, "last_sale": row.last_sale.isoformat() if row.last_sale else None }) logger.debug(f"Found {len(products)} products for tenant {tenant_id}") return products except Exception as e: logger.error(f"Error getting product statistics: {str(e)}", tenant_id=tenant_id) return []