# services/sales/app/repositories/sales_repository.py
"""
Sales Repository using Repository Pattern
"""

from typing import List, Optional, Dict, Any
from uuid import UUID, uuid4
from datetime import datetime

from sqlalchemy import select, func, and_, or_, desc, asc
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.ext.asyncio import AsyncSession
import structlog

from app.models.sales import SalesData
from app.schemas.sales import SalesDataCreate, SalesDataUpdate, SalesDataQuery
from shared.database.repository import BaseRepository

logger = structlog.get_logger()


def _apply_date_range(stmt, start_date: Optional[datetime], end_date: Optional[datetime]):
    """Return ``stmt`` restricted to ``SalesData.date`` within the given bounds.

    Each bound is inclusive and is skipped when falsy (``None``).
    """
    if start_date:
        stmt = stmt.where(SalesData.date >= start_date)
    if end_date:
        stmt = stmt.where(SalesData.date <= end_date)
    return stmt


def _sortable_column(field_name: Optional[str]):
    """Return the mapped ``SalesData`` column attribute called ``field_name``.

    Returns ``None`` when ``field_name`` is empty or does not name a mapped
    column. A plain ``hasattr`` check is not enough here: it also accepts
    ``metadata``, dunder attributes, methods and relationships, none of which
    can be used in ``ORDER BY``.
    """
    if not field_name:
        return None
    mapped_columns = sa_inspect(SalesData).column_attrs
    if not any(prop.key == field_name for prop in mapped_columns):
        return None
    return getattr(SalesData, field_name)


class SalesRepository(BaseRepository[SalesData, SalesDataCreate, SalesDataUpdate]):
    """Repository for sales data operations"""

    def __init__(self, session: AsyncSession):
        """Bind the repository to ``session`` for the ``SalesData`` model."""
        super().__init__(SalesData, session)

    async def create_sales_record(self, sales_data: SalesDataCreate, tenant_id: UUID) -> SalesData:
        """Create a new sales record for ``tenant_id``.

        The payload is dumped to a dict, stamped with ``tenant_id`` and passed
        to ``self.create``. When the payload carries a date but no
        ``is_weekend`` value, the flag is derived from the date
        (Saturday/Sunday -> ``True``).

        Raises:
            Exception: anything raised while creating the record is logged
                and re-raised unchanged.
        """
        try:
            create_data = sales_data.model_dump()
            create_data['tenant_id'] = tenant_id

            # Calculate weekend flag if not provided
            if sales_data.date and create_data.get('is_weekend') is None:
                create_data['is_weekend'] = sales_data.date.weekday() >= 5

            record = await self.create(create_data)

            logger.info(
                "Created sales record",
                record_id=record.id,
                inventory_product_id=record.inventory_product_id,
                quantity=record.quantity_sold,
                tenant_id=tenant_id
            )
            return record

        except Exception as e:
            logger.error("Failed to create sales record", error=str(e), tenant_id=tenant_id)
            raise

    async def get_by_tenant(
        self,
        tenant_id: UUID,
        query_params: Optional[SalesDataQuery] = None
    ) -> List[SalesData]:
        """Get sales records by tenant with optional filtering.

        With ``query_params`` the result is filtered, ordered and paginated
        (``offset`` / ``limit``) as requested. An ``order_by`` value that does
        not name a mapped column is ignored and the default ordering (newest
        date first) is used instead. Without ``query_params`` the newest
        10000 records are returned.

        Raises:
            Exception: query failures are logged and re-raised unchanged.
        """
        try:
            stmt = select(SalesData).where(SalesData.tenant_id == tenant_id)

            if query_params:
                stmt = _apply_date_range(stmt, query_params.start_date, query_params.end_date)

                # NOTE: product_name and product_category are not filterable
                # here; filtering on either requires the inventory service.
                if hasattr(query_params, 'inventory_product_id') and query_params.inventory_product_id:
                    stmt = stmt.where(SalesData.inventory_product_id == query_params.inventory_product_id)
                if query_params.location_id:
                    stmt = stmt.where(SalesData.location_id == query_params.location_id)
                if query_params.sales_channel:
                    stmt = stmt.where(SalesData.sales_channel == query_params.sales_channel)
                if query_params.source:
                    stmt = stmt.where(SalesData.source == query_params.source)
                # Explicit None check: False is a meaningful filter value.
                if query_params.is_validated is not None:
                    stmt = stmt.where(SalesData.is_validated == query_params.is_validated)

                # Apply ordering (only mapped columns are accepted)
                order_col = _sortable_column(query_params.order_by)
                if order_col is None:
                    stmt = stmt.order_by(desc(SalesData.date))
                elif query_params.order_direction == 'asc':
                    stmt = stmt.order_by(asc(order_col))
                else:
                    stmt = stmt.order_by(desc(order_col))

                # Apply pagination
                stmt = stmt.offset(query_params.offset).limit(query_params.limit)
            else:
                # Default ordering with safety limit for direct repository calls
                # Note: API calls always provide query_params, so this only applies to direct usage
                stmt = stmt.order_by(desc(SalesData.date)).limit(10000)

            result = await self.session.execute(stmt)
            records = list(result.scalars().all())

            logger.info(
                "Retrieved sales records",
                count=len(records),
                tenant_id=tenant_id
            )
            return records

        except Exception as e:
            logger.error("Failed to get sales records", error=str(e), tenant_id=tenant_id)
            raise

    async def get_by_inventory_product(
        self,
        tenant_id: UUID,
        inventory_product_id: UUID,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> List[SalesData]:
        """Get sales records for a specific inventory product.

        Results are limited to ``tenant_id``, optionally bounded by an
        inclusive date range, and ordered newest date first.

        Raises:
            Exception: query failures are logged and re-raised unchanged.
        """
        try:
            stmt = select(SalesData).where(
                and_(
                    SalesData.tenant_id == tenant_id,
                    SalesData.inventory_product_id == inventory_product_id
                )
            )
            stmt = _apply_date_range(stmt, start_date, end_date)
            stmt = stmt.order_by(desc(SalesData.date))

            result = await self.session.execute(stmt)
            return list(result.scalars().all())

        except Exception as e:
            logger.error(
                "Failed to get product sales",
                error=str(e),
                tenant_id=tenant_id,
                inventory_product_id=inventory_product_id
            )
            raise

    async def get_analytics(
        self,
        tenant_id: UUID,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Get sales analytics for a tenant.

        Runs three aggregate queries over the tenant's sales within the
        optional inclusive date range and returns a dict with the keys
        ``total_revenue``, ``total_quantity``, ``total_transactions``,
        ``average_transaction_value``, ``top_products`` (up to 10 entries,
        highest revenue first) and ``sales_by_channel`` (keyed by channel).
        Missing aggregates are reported as ``0``.

        Raises:
            Exception: query failures are logged and re-raised unchanged.
        """
        try:
            tenant_filter = SalesData.tenant_id == tenant_id

            # Total revenue and quantity
            summary_query = _apply_date_range(
                select(
                    func.sum(SalesData.revenue).label('total_revenue'),
                    func.sum(SalesData.quantity_sold).label('total_quantity'),
                    func.count().label('total_transactions'),
                    func.avg(SalesData.revenue).label('avg_transaction_value')
                ).where(tenant_filter),
                start_date,
                end_date,
            )
            summary = (await self.session.execute(summary_query)).first()

            # Top products, grouped by inventory_product_id (formerly product_name)
            top_products_query = _apply_date_range(
                select(
                    SalesData.inventory_product_id,
                    func.sum(SalesData.revenue).label('revenue'),
                    func.sum(SalesData.quantity_sold).label('quantity')
                ).where(tenant_filter),
                start_date,
                end_date,
            ).group_by(
                SalesData.inventory_product_id
            ).order_by(
                desc(func.sum(SalesData.revenue))
            ).limit(10)

            top_products_result = await self.session.execute(top_products_query)
            top_products = [
                {
                    'inventory_product_id': str(row.inventory_product_id),
                    'revenue': float(row.revenue) if row.revenue else 0,
                    'quantity': row.quantity or 0
                }
                for row in top_products_result
            ]

            # Sales by channel
            channel_query = _apply_date_range(
                select(
                    SalesData.sales_channel,
                    func.sum(SalesData.revenue).label('revenue'),
                    func.count().label('transactions')
                ).where(tenant_filter),
                start_date,
                end_date,
            ).group_by(SalesData.sales_channel)

            channel_result = await self.session.execute(channel_query)
            sales_by_channel = {
                row.sales_channel: {
                    'revenue': float(row.revenue) if row.revenue else 0,
                    'transactions': row.transactions or 0
                }
                for row in channel_result
            }

            return {
                'total_revenue': float(summary.total_revenue) if summary.total_revenue else 0,
                'total_quantity': summary.total_quantity or 0,
                'total_transactions': summary.total_transactions or 0,
                'average_transaction_value': (
                    float(summary.avg_transaction_value) if summary.avg_transaction_value else 0
                ),
                'top_products': top_products,
                'sales_by_channel': sales_by_channel
            }

        except Exception as e:
            logger.error("Failed to get sales analytics", error=str(e), tenant_id=tenant_id)
            raise

    async def get_product_categories(self, tenant_id: UUID) -> List[str]:
        """Get distinct product categories for a tenant.

        Always returns an empty list: the ``product_category`` field was
        removed and categories are now managed via the inventory service.
        A warning is logged on every call so remaining callers can be found.
        """
        try:
            # TODO: query categories from the inventory service instead.
            # For now, return empty list to avoid breaking existing code
            logger.warning("get_product_categories called but product_category field was removed",
                           tenant_id=tenant_id)
            categories: List[str] = []
            return categories

        except Exception as e:
            logger.error("Failed to get product categories", error=str(e), tenant_id=tenant_id)
            raise

    async def validate_record(self, record_id: UUID, validation_notes: Optional[str] = None) -> SalesData:
        """Mark a sales record as validated.

        Sets ``is_validated`` to ``True`` and stores ``validation_notes``
        (which may be ``None``) via ``self.update``.

        Raises:
            ValueError: if no record with ``record_id`` exists.
            Exception: any other failure is logged and re-raised unchanged.
        """
        # NOTE(review): the lookup is by record_id only and is not scoped to
        # a tenant; confirm callers verify ownership before calling this.
        try:
            record = await self.get_by_id(record_id)
            if not record:
                raise ValueError(f"Sales record {record_id} not found")

            update_data = {
                'is_validated': True,
                'validation_notes': validation_notes
            }
            updated_record = await self.update(record_id, update_data)

            logger.info("Validated sales record", record_id=record_id)
            return updated_record

        except Exception as e:
            logger.error("Failed to validate sales record", error=str(e), record_id=record_id)
            raise

    async def create_sales_records_bulk(
        self,
        sales_data_list: List[SalesDataCreate],
        tenant_id: UUID
    ) -> int:
        """Bulk insert sales records for performance optimization.

        Builds one ``SalesData`` per payload (with a fresh UUID and a weekend
        flag derived from the date), adds them to the session in one call and
        flushes. The session is flushed here but not committed.

        Returns:
            The number of records added.

        Raises:
            Exception: failures are logged and re-raised unchanged.
        """
        try:
            records = [
                SalesData(
                    id=uuid4(),
                    tenant_id=tenant_id,
                    date=sales_data.date,
                    inventory_product_id=sales_data.inventory_product_id,
                    quantity_sold=sales_data.quantity_sold,
                    unit_price=sales_data.unit_price,
                    revenue=sales_data.revenue,
                    location_id=sales_data.location_id,
                    sales_channel=sales_data.sales_channel,
                    source=sales_data.source,
                    # Saturday/Sunday -> True; records without a date are not weekend
                    is_weekend=sales_data.date.weekday() >= 5 if sales_data.date else False,
                    is_validated=getattr(sales_data, 'is_validated', False)
                )
                for sales_data in sales_data_list
            ]

            self.session.add_all(records)
            await self.session.flush()

            logger.info(
                "Bulk created sales records",
                count=len(records),
                tenant_id=tenant_id
            )
            return len(records)

        except Exception as e:
            logger.error("Failed to bulk create sales records", error=str(e), tenant_id=tenant_id)
            raise