"""
Model Repository
Repository for trained model operations
"""

from datetime import datetime, timezone, timedelta
from typing import Optional, List, Dict, Any

import structlog
from sqlalchemy import select, and_, text, desc
from sqlalchemy.ext.asyncio import AsyncSession

from .base import TrainingBaseRepository
from app.models.training import TrainedModel
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError

logger = structlog.get_logger()


class ModelRepository(TrainingBaseRepository):
    """Repository for trained model operations.

    Wraps the generic base-repository helpers (``create``, ``update``,
    ``get_multi``, ``count``, ...) with model-specific queries: production
    promotion, archival of stale models and per-tenant statistics.
    """

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 600):
        """Bind the repository to ``session``.

        Args:
            session: Async SQLAlchemy session used for all queries.
            cache_ttl: Cache lifetime in seconds forwarded to the base
                repository (default 600).
        """
        # Models are relatively stable, longer cache time (10 minutes)
        super().__init__(TrainedModel, session, cache_ttl)

    async def create_model(self, model_data: Dict[str, Any]) -> TrainedModel:
        """Create a new trained model after validating its required fields.

        If the new model is flagged ``is_production``, the current active
        production model for the same tenant+product (if one exists) is
        demoted (``is_production=False``) before the new row is created.

        Args:
            model_data: Column values for the new model. Must contain
                ``tenant_id``, ``inventory_product_id``, ``model_path`` and
                ``job_id``.

        Returns:
            The created ``TrainedModel``.

        Raises:
            ValidationError: If required fields fail validation.
            DuplicateRecordError: Re-raised unchanged if an underlying call raises it.
            DatabaseError: For any other failure.
        """
        try:
            # Validate model data
            validation_result = self._validate_training_data(
                model_data,
                ["tenant_id", "inventory_product_id", "model_path", "job_id"],
            )
            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid model data: {validation_result['errors']}")

            # Only an incoming production model can displace an existing one,
            # so the lookup is skipped entirely for non-production models.
            if model_data.get("is_production", False):
                existing_model = await self.get_active_model_for_product(
                    model_data["tenant_id"],
                    model_data["inventory_product_id"],
                )
                if existing_model:
                    logger.info("Deactivating previous production model",
                                previous_model_id=existing_model.id,
                                tenant_id=model_data["tenant_id"],
                                inventory_product_id=model_data["inventory_product_id"])
                    await self.update(existing_model.id, {"is_production": False})

            # Create new model
            model = await self.create(model_data)

            logger.info("Trained model created successfully",
                        model_id=model.id,
                        tenant_id=model.tenant_id,
                        inventory_product_id=str(model.inventory_product_id),
                        model_type=model.model_type)

            return model

        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to create trained model",
                         tenant_id=model_data.get("tenant_id"),
                         inventory_product_id=model_data.get("inventory_product_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create model: {str(e)}") from e

    async def get_model_by_tenant_and_product(
        self,
        tenant_id: str,
        inventory_product_id: str
    ) -> List[TrainedModel]:
        """Return every model for a tenant+product, newest first.

        Raises:
            DatabaseError: If the query fails.
        """
        try:
            return await self.get_multi(
                filters={
                    "tenant_id": tenant_id,
                    "inventory_product_id": inventory_product_id
                },
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get models by tenant and product",
                         tenant_id=tenant_id,
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get models: {str(e)}") from e

    async def get_active_model_for_product(
        self,
        tenant_id: str,
        inventory_product_id: str
    ) -> Optional[TrainedModel]:
        """Return the newest model that is both active and in production.

        Returns:
            The matching model, or ``None`` if there is none.

        Raises:
            DatabaseError: If the query fails.
        """
        try:
            models = await self.get_multi(
                filters={
                    "tenant_id": tenant_id,
                    "inventory_product_id": inventory_product_id,
                    "is_active": True,
                    "is_production": True
                },
                order_by="created_at",
                order_desc=True,
                limit=1
            )
            return models[0] if models else None
        except Exception as e:
            logger.error("Failed to get active model for product",
                         tenant_id=tenant_id,
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get active model: {str(e)}") from e

    async def get_models_by_tenant(
        self,
        tenant_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[TrainedModel]:
        """Return a page of models for a tenant (delegates to ``get_by_tenant_id``)."""
        return await self.get_by_tenant_id(tenant_id, skip=skip, limit=limit)

    async def promote_to_production(self, model_id: str) -> Optional[TrainedModel]:
        """Make ``model_id`` the sole production model for its tenant+product.

        Every other production model for the same tenant+product is demoted.
        This model then gets ``is_production=True`` and a fresh ``last_used_at``.

        Returns:
            The updated model as returned by the base ``update``.

        Raises:
            DatabaseError: On any failure, including when the model does not
                exist. The not-found ``ValueError`` is wrapped by the handler below.
        """
        try:
            # Get the model first
            model = await self.get_by_id(model_id)
            if not model:
                raise ValueError(f"Model {model_id} not found")

            # Deactivate other production models for the same tenant+product
            await self._deactivate_other_production_models(
                model.tenant_id,
                str(model.inventory_product_id),
                model_id
            )

            # Promote this model
            updated_model = await self.update(model_id, {
                "is_production": True,
                "last_used_at": datetime.now(timezone.utc)
            })

            logger.info("Model promoted to production",
                        model_id=model_id,
                        tenant_id=model.tenant_id,
                        inventory_product_id=str(model.inventory_product_id))

            return updated_model

        except Exception as e:
            logger.error("Failed to promote model to production",
                         model_id=model_id,
                         error=str(e))
            raise DatabaseError(f"Failed to promote model: {str(e)}") from e

    async def update_model_usage(self, model_id: str) -> Optional[TrainedModel]:
        """Stamp ``last_used_at`` with the current UTC time.

        This is best-effort: failures are logged and ``None`` is returned
        instead of raising.
        """
        try:
            return await self.update(model_id, {
                "last_used_at": datetime.now(timezone.utc)
            })
        except Exception as e:
            logger.error("Failed to update model usage",
                         model_id=model_id,
                         error=str(e))
            # Don't raise here - usage update is not critical
            return None

    async def archive_old_models(self, tenant_id: str, days_old: int = 90) -> int:
        """Deactivate non-production models created more than ``days_old`` days ago.

        Returns:
            Number of rows updated, as reported by the driver's ``rowcount``.

        Raises:
            DatabaseError: If the update fails.
        """
        try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)

            query = text("""
                UPDATE trained_models
                SET is_active = false
                WHERE tenant_id = :tenant_id
                AND is_production = false
                AND created_at < :cutoff_date
                AND is_active = true
            """)

            result = await self.session.execute(query, {
                "tenant_id": tenant_id,
                "cutoff_date": cutoff_date
            })

            archived_count = result.rowcount

            logger.info("Archived old models",
                        tenant_id=tenant_id,
                        archived_count=archived_count,
                        days_old=days_old)

            return archived_count

        except Exception as e:
            logger.error("Failed to archive old models",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Model archival failed: {str(e)}") from e

    async def get_model_statistics(self, tenant_id: str) -> Dict[str, Any]:
        """Summarise a tenant's models: counts, per-product breakdown and accuracy.

        ``average_accuracy`` is ``100 - average MAPE`` over active models that
        have a MAPE, clamped to [0, 100]. It is ``None`` when no such model exists.

        This is best-effort: on any failure the error is logged and an
        all-zero summary is returned instead of raising.
        """
        try:
            # Get basic counts
            total_models = await self.count(filters={"tenant_id": tenant_id})
            active_models = await self.count(filters={
                "tenant_id": tenant_id,
                "is_active": True
            })
            production_models = await self.count(filters={
                "tenant_id": tenant_id,
                "is_production": True
            })

            # Get models by product using raw query
            product_query = text("""
                SELECT inventory_product_id, COUNT(*) as count
                FROM trained_models
                WHERE tenant_id = :tenant_id
                AND is_active = true
                GROUP BY inventory_product_id
                ORDER BY count DESC
            """)

            result = await self.session.execute(product_query, {"tenant_id": tenant_id})
            product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}

            # Recent activity (models created in last 30 days)
            thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
            recent_models_query = text("""
                SELECT COUNT(*) as count
                FROM trained_models
                WHERE tenant_id = :tenant_id
                AND created_at >= :thirty_days_ago
            """)

            recent_result = await self.session.execute(
                recent_models_query,
                {"tenant_id": tenant_id, "thirty_days_ago": thirty_days_ago}
            )
            recent_models = recent_result.scalar() or 0

            # Calculate average accuracy from model metrics
            accuracy_query = text("""
                SELECT
                    AVG(mape) as average_mape,
                    COUNT(*) as total_models_with_metrics
                FROM trained_models
                WHERE tenant_id = :tenant_id
                AND mape IS NOT NULL
                AND is_active = true
            """)

            accuracy_result = await self.session.execute(accuracy_query, {"tenant_id": tenant_id})
            accuracy_row = accuracy_result.fetchone()

            average_mape = accuracy_row.average_mape if accuracy_row and accuracy_row.average_mape else 0
            total_models_with_metrics = accuracy_row.total_models_with_metrics if accuracy_row else 0

            # Convert MAPE to accuracy percentage (lower MAPE = higher accuracy)
            # using 100 - MAPE clamped to [0, 100].
            # Return None if no models have metrics (no data), rather than 0.
            if total_models_with_metrics == 0:
                average_accuracy = None
            else:
                # An average MAPE of exactly 0 is a perfect score and must map
                # to 100% accuracy, so no "> 0" guard is applied here.
                average_accuracy = max(0.0, min(100.0, 100.0 - float(average_mape)))

            return {
                "total_models": total_models,
                "active_models": active_models,
                "inactive_models": total_models - active_models,
                "production_models": production_models,
                "models_by_product": product_stats,
                "recent_models_30d": recent_models,
                "average_accuracy": average_accuracy,
                "total_models_with_metrics": total_models_with_metrics,
                "average_mape": float(average_mape) if average_mape > 0 else 0
            }

        except Exception as e:
            logger.error("Failed to get model statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            # Best-effort fallback so dashboards still render on query failure.
            return {
                "total_models": 0,
                "active_models": 0,
                "inactive_models": 0,
                "production_models": 0,
                "models_by_product": {},
                "recent_models_30d": 0,
                "average_accuracy": 0,
                "total_models_with_metrics": 0,
                "average_mape": 0
            }

    async def _deactivate_other_production_models(
        self,
        tenant_id: str,
        inventory_product_id: str,
        exclude_model_id: str
    ) -> int:
        """Demote every production model for tenant+product except ``exclude_model_id``.

        Returns:
            Number of rows updated, as reported by the driver's ``rowcount``.

        Raises:
            DatabaseError: If the update fails.
        """
        try:
            query = text("""
                UPDATE trained_models
                SET is_production = false
                WHERE tenant_id = :tenant_id
                AND inventory_product_id = :inventory_product_id
                AND id != :exclude_model_id
                AND is_production = true
            """)

            result = await self.session.execute(query, {
                "tenant_id": tenant_id,
                "inventory_product_id": inventory_product_id,
                "exclude_model_id": exclude_model_id
            })

            return result.rowcount

        except Exception as e:
            logger.error("Failed to deactivate other production models",
                         tenant_id=tenant_id,
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            raise DatabaseError(f"Failed to deactivate models: {str(e)}") from e

    async def get_model_performance_summary(self, model_id: str) -> Dict[str, Any]:
        """Return a nested dict of a model's metrics, training info, status and features.

        Returns ``{}`` when the model does not exist or when any error occurs.
        Errors are logged, not raised.
        """
        try:
            model = await self.get_by_id(model_id)
            if not model:
                return {}

            return {
                "model_id": model.id,
                "tenant_id": model.tenant_id,
                "inventory_product_id": str(model.inventory_product_id),
                "model_type": model.model_type,
                "metrics": {
                    "mape": model.mape,
                    "mae": model.mae,
                    "rmse": model.rmse,
                    "r2_score": model.r2_score
                },
                "training_info": {
                    "training_samples": model.training_samples,
                    "training_start_date": model.training_start_date.isoformat() if model.training_start_date else None,
                    "training_end_date": model.training_end_date.isoformat() if model.training_end_date else None,
                    "data_quality_score": model.data_quality_score
                },
                "status": {
                    "is_active": model.is_active,
                    "is_production": model.is_production,
                    "created_at": model.created_at.isoformat() if model.created_at else None,
                    "last_used_at": model.last_used_at.isoformat() if model.last_used_at else None
                },
                "features": {
                    "hyperparameters": model.hyperparameters,
                    "features_used": model.features_used
                }
            }

        except Exception as e:
            logger.error("Failed to get model performance summary",
                         model_id=model_id,
                         error=str(e))
            return {}