""" Base Repository for Forecasting Service Service-specific repository base class with forecasting utilities """ from typing import Optional, List, Dict, Any, Type from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import text from datetime import datetime, date, timedelta, timezone import structlog from shared.database.repository import BaseRepository from shared.database.exceptions import DatabaseError logger = structlog.get_logger() class ForecastingBaseRepository(BaseRepository): """Base repository for forecasting service with common forecasting operations""" def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600): # Forecasting data benefits from medium cache time (10 minutes) super().__init__(model, session, cache_ttl) async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List: """Get records by tenant ID""" if hasattr(self.model, 'tenant_id'): return await self.get_multi( skip=skip, limit=limit, filters={"tenant_id": tenant_id}, order_by="created_at", order_desc=True ) return await self.get_multi(skip=skip, limit=limit) async def get_by_inventory_product_id( self, tenant_id: str, inventory_product_id: str, skip: int = 0, limit: int = 100 ) -> List: """Get records by tenant and inventory product""" if hasattr(self.model, 'inventory_product_id'): return await self.get_multi( skip=skip, limit=limit, filters={ "tenant_id": tenant_id, "inventory_product_id": inventory_product_id }, order_by="created_at", order_desc=True ) return await self.get_by_tenant_id(tenant_id, skip, limit) async def get_by_date_range( self, tenant_id: str, start_date: datetime, end_date: datetime, skip: int = 0, limit: int = 100 ) -> List: """Get records within date range for a tenant""" if not hasattr(self.model, 'forecast_date') and not hasattr(self.model, 'created_at'): logger.warning(f"Model {self.model.__name__} has no date field for filtering") return [] try: table_name = self.model.__tablename__ date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at" query_text = f""" SELECT * FROM {table_name} WHERE tenant_id = :tenant_id AND {date_field} >= :start_date AND {date_field} <= :end_date ORDER BY {date_field} DESC LIMIT :limit OFFSET :skip """ result = await self.session.execute(text(query_text), { "tenant_id": tenant_id, "start_date": start_date, "end_date": end_date, "limit": limit, "skip": skip }) # Convert rows to model objects records = [] for row in result.fetchall(): record_dict = dict(row._mapping) record = self.model(**record_dict) records.append(record) return records except Exception as e: logger.error("Failed to get records by date range", model=self.model.__name__, tenant_id=tenant_id, error=str(e)) raise DatabaseError(f"Date range query failed: {str(e)}") async def get_recent_records( self, tenant_id: str, hours: int = 24, skip: int = 0, limit: int = 100 ) -> List: """Get recent records for a tenant""" cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) return await self.get_by_date_range( tenant_id, cutoff_time, datetime.now(timezone.utc), skip, limit ) async def cleanup_old_records(self, days_old: int = 90) -> int: """Clean up old forecasting records""" try: cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old) table_name = self.model.__tablename__ # Use created_at or forecast_date for cleanup date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at" query_text = f""" DELETE FROM {table_name} WHERE {date_field} < :cutoff_date """ 
    async def cleanup_old_records(self, days_old: int = 90) -> int:
        """Clean up old forecasting records"""
        try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
            table_name = self.model.__tablename__

            # Use forecast_date (preferred) or created_at for cleanup
            date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"

            query_text = f"""
                DELETE FROM {table_name}
                WHERE {date_field} < :cutoff_date
            """

            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount

            logger.info(f"Cleaned up old {self.model.__name__} records",
                        deleted_count=deleted_count,
                        days_old=days_old)
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup old records",
                         model=self.model.__name__,
                         error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}")

    async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
        """Get statistics for a tenant"""
        try:
            table_name = self.model.__tablename__

            # Get basic counts
            total_records = await self.count(filters={"tenant_id": tenant_id})

            # Get recent activity (records in the last 7 days)
            seven_days_ago = datetime.now(timezone.utc) - timedelta(days=7)
            recent_records = len(await self.get_by_date_range(
                tenant_id,
                seven_days_ago,
                datetime.now(timezone.utc),
                limit=1000
            ))

            # Get records by product if applicable
            product_stats = {}
            if hasattr(self.model, 'inventory_product_id'):
                product_query = text(f"""
                    SELECT inventory_product_id, COUNT(*) as count
                    FROM {table_name}
                    WHERE tenant_id = :tenant_id
                    GROUP BY inventory_product_id
                    ORDER BY count DESC
                """)
                result = await self.session.execute(product_query, {"tenant_id": tenant_id})
                product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}

            return {
                "total_records": total_records,
                "recent_records_7d": recent_records,
                "records_by_product": product_stats
            }
        except Exception as e:
            logger.error("Failed to get tenant statistics",
                         model=self.model.__name__,
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_records": 0,
                "recent_records_7d": 0,
                "records_by_product": {}
            }
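    # Reference for the return shape of get_statistics_by_tenant (the values
    # below are illustrative only):
    #
    #     {
    #         "total_records": 1240,
    #         "recent_records_7d": 87,
    #         "records_by_product": {"prod-abc": 640, "prod-def": 600},
    #     }
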
    def _validate_forecast_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
        """Validate forecasting-related data"""
        errors = []

        for field in required_fields:
            if field not in data or data[field] is None:
                errors.append(f"Missing required field: {field}")

        # Validate tenant_id format if present
        if "tenant_id" in data and data["tenant_id"]:
            tenant_id = data["tenant_id"]
            if not isinstance(tenant_id, str) or len(tenant_id) < 1:
                errors.append("Invalid tenant_id format")

        # Validate inventory_product_id if present
        if "inventory_product_id" in data and data["inventory_product_id"]:
            inventory_product_id = data["inventory_product_id"]
            if not isinstance(inventory_product_id, str) or len(inventory_product_id) < 1:
                errors.append("Invalid inventory_product_id format")

        # Validate dates if present - accept datetime objects, date objects, and date strings
        date_fields = ["forecast_date", "created_at", "evaluation_date", "expires_at"]
        for field in date_fields:
            if field in data and data[field]:
                field_value = data[field]
                field_type = type(field_value).__name__

                if isinstance(field_value, (datetime, date)):
                    logger.debug(f"Date field {field} is valid {field_type}", field_value=str(field_value))
                    continue  # Already a datetime or date, valid
                elif isinstance(field_value, str):
                    # Try to parse the string date
                    try:
                        from dateutil.parser import parse
                        parse(field_value)  # Just validate, don't convert yet
                        logger.debug(f"Date field {field} is valid string", field_value=field_value)
                    except (ValueError, TypeError) as e:
                        logger.error(f"Date parsing failed for {field}", field_value=field_value, error=str(e))
                        errors.append(f"Invalid {field} format - must be datetime or valid date string")
                else:
                    logger.error(f"Date field {field} has invalid type {field_type}", field_value=str(field_value))
                    errors.append(f"Invalid {field} format - must be datetime or valid date string")

        # Validate numeric fields
        numeric_fields = [
            "predicted_demand", "confidence_lower", "confidence_upper",
            "mae", "mape", "rmse", "accuracy_score"
        ]
        for field in numeric_fields:
            if field in data and data[field] is not None:
                try:
                    float(data[field])
                except (ValueError, TypeError):
                    errors.append(f"Invalid {field} format - must be numeric")

        return {
            "is_valid": len(errors) == 0,
            "errors": errors
        }
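
# Validation usage sketch (illustrative only): this base class never calls
# _validate_forecast_data itself; presumably concrete repositories invoke it
# before writes. The payload and the error handling below are assumptions,
# not part of this module:
#
#     result = self._validate_forecast_data(
#         {
#             "tenant_id": "tenant-123",
#             "forecast_date": "2024-06-01",
#             "predicted_demand": 42.0,
#         },
#         required_fields=["tenant_id", "forecast_date", "predicted_demand"],
#     )
#     if not result["is_valid"]:
#         raise DatabaseError(f"Invalid forecast payload: {result['errors']}")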