Files
bakery-ia/services/training/app/repositories/model_repository.py
2025-08-08 09:08:41 +02:00

346 lines
13 KiB
Python

"""
Model Repository
Repository for trained model operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import TrainedModel
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
logger = structlog.get_logger()
class ModelRepository(TrainingBaseRepository):
    """Repository for trained model operations.

    Builds on the generic helpers provided by ``TrainingBaseRepository``
    (``create``, ``update``, ``get_multi``, ``get_by_id``, ``count``,
    ``get_by_tenant_id``, ``_validate_training_data``) and adds queries that
    are specific to ``TrainedModel`` rows, such as production promotion,
    archival and per-tenant statistics.

    NOTE(review): several methods issue raw ``UPDATE`` statements through
    ``self.session``; none of them commit, so transaction handling is
    presumably the caller's responsibility -- confirm against the base class.
    """

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 600):
        # Models are relatively stable, longer cache time (10 minutes)
        super().__init__(TrainedModel, session, cache_ttl)

    async def create_model(self, model_data: Dict[str, Any]) -> TrainedModel:
        """Create a new trained model with validation.

        If ``model_data`` marks the new model as production and an active
        production model already exists for the same tenant and product, the
        existing one has ``is_production`` cleared before the new row is
        created.

        Args:
            model_data: Column values for the new model. Validated against the
                keys ``tenant_id``, ``product_name``, ``model_path`` and
                ``job_id`` (presumably the required fields -- the exact check
                lives in ``_validate_training_data``).

        Returns:
            The newly created ``TrainedModel``.

        Raises:
            ValidationError: If validation reports the data as invalid.
            DuplicateRecordError: Propagated unchanged from the lower layers.
            DatabaseError: For any other failure.
        """
        try:
            # Validate model data
            validation_result = self._validate_training_data(
                model_data,
                ["tenant_id", "product_name", "model_path", "job_id"]
            )
            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid model data: {validation_result['errors']}")

            # Check for duplicate active models for same tenant+product
            existing_model = await self.get_active_model_for_product(
                model_data["tenant_id"],
                model_data["product_name"]
            )

            # Only demote the previous model when the new one takes over production
            if existing_model and model_data.get("is_production", False):
                logger.info("Deactivating previous production model",
                            previous_model_id=existing_model.id,
                            tenant_id=model_data["tenant_id"],
                            product_name=model_data["product_name"])
                await self.update(existing_model.id, {"is_production": False})

            # Create new model
            model = await self.create(model_data)
            logger.info("Trained model created successfully",
                        model_id=model.id,
                        tenant_id=model.tenant_id,
                        product_name=model.product_name,
                        model_type=model.model_type)
            return model
        except (ValidationError, DuplicateRecordError):
            # Domain errors are meaningful to callers; do not wrap them
            raise
        except Exception as e:
            logger.error("Failed to create trained model",
                         tenant_id=model_data.get("tenant_id"),
                         product_name=model_data.get("product_name"),
                         error=str(e))
            raise DatabaseError(f"Failed to create model: {str(e)}") from e

    async def get_model_by_tenant_and_product(
        self,
        tenant_id: str,
        product_name: str
    ) -> List[TrainedModel]:
        """Get all models for a tenant and product, newest first.

        Args:
            tenant_id: Tenant whose models are requested.
            product_name: Product the models were trained for.

        Returns:
            Matching models ordered by ``created_at`` descending.

        Raises:
            DatabaseError: If the lookup fails.
        """
        try:
            return await self.get_multi(
                filters={
                    "tenant_id": tenant_id,
                    "product_name": product_name
                },
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get models by tenant and product",
                         tenant_id=tenant_id,
                         product_name=product_name,
                         error=str(e))
            raise DatabaseError(f"Failed to get models: {str(e)}") from e

    async def get_active_model_for_product(
        self,
        tenant_id: str,
        product_name: str
    ) -> Optional[TrainedModel]:
        """Get the active production model for a product.

        Args:
            tenant_id: Tenant whose model is requested.
            product_name: Product the model was trained for.

        Returns:
            The most recently created model that is both active and flagged
            as production, or ``None`` when there is none.

        Raises:
            DatabaseError: If the lookup fails.
        """
        try:
            models = await self.get_multi(
                filters={
                    "tenant_id": tenant_id,
                    "product_name": product_name,
                    "is_active": True,
                    "is_production": True
                },
                order_by="created_at",
                order_desc=True,
                limit=1
            )
            return models[0] if models else None
        except Exception as e:
            logger.error("Failed to get active model for product",
                         tenant_id=tenant_id,
                         product_name=product_name,
                         error=str(e))
            raise DatabaseError(f"Failed to get active model: {str(e)}") from e

    async def get_models_by_tenant(
        self,
        tenant_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[TrainedModel]:
        """Get all models for a tenant.

        Thin pagination wrapper that delegates to ``get_by_tenant_id``.

        Args:
            tenant_id: Tenant whose models are requested.
            skip: Number of rows to skip.
            limit: Maximum number of rows to return.
        """
        return await self.get_by_tenant_id(tenant_id, skip=skip, limit=limit)

    async def promote_to_production(self, model_id: str) -> Optional[TrainedModel]:
        """Promote a model to production.

        Clears ``is_production`` on every other model of the same tenant and
        product, then sets it on this model and refreshes ``last_used_at``.

        Args:
            model_id: Identifier of the model to promote.

        Returns:
            The value returned by ``update`` for the promoted model.

        Raises:
            DatabaseError: On any failure, including when the model does not
                exist (the internal ``ValueError`` is wrapped).
        """
        try:
            # Get the model first
            model = await self.get_by_id(model_id)
            if not model:
                raise ValueError(f"Model {model_id} not found")

            # Deactivate other production models for the same tenant+product
            await self._deactivate_other_production_models(
                model.tenant_id,
                model.product_name,
                model_id
            )

            # Promote this model
            updated_model = await self.update(model_id, {
                "is_production": True,
                "last_used_at": datetime.utcnow()
            })
            logger.info("Model promoted to production",
                        model_id=model_id,
                        tenant_id=model.tenant_id,
                        product_name=model.product_name)
            return updated_model
        except Exception as e:
            logger.error("Failed to promote model to production",
                         model_id=model_id,
                         error=str(e))
            raise DatabaseError(f"Failed to promote model: {str(e)}") from e

    async def update_model_usage(self, model_id: str) -> Optional[TrainedModel]:
        """Update model last used timestamp.

        Best-effort: failures are logged and swallowed.

        Args:
            model_id: Identifier of the model that was used.

        Returns:
            The value returned by ``update``, or ``None`` if the update failed.
        """
        try:
            return await self.update(model_id, {
                "last_used_at": datetime.utcnow()
            })
        except Exception as e:
            logger.error("Failed to update model usage",
                         model_id=model_id,
                         error=str(e))
            # Don't raise here - usage update is not critical
            return None

    async def archive_old_models(self, tenant_id: str, days_old: int = 90) -> int:
        """Archive old non-production models.

        Sets ``is_active = false`` on the tenant's active, non-production
        models created more than ``days_old`` days before now (naive UTC).

        Args:
            tenant_id: Tenant whose models are archived.
            days_old: Minimum age in days for a model to be archived.

        Returns:
            The ``rowcount`` reported for the ``UPDATE`` statement.

        Raises:
            DatabaseError: If the statement fails.
        """
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)
            query = text("""
                UPDATE trained_models
                SET is_active = false
                WHERE tenant_id = :tenant_id
                AND is_production = false
                AND created_at < :cutoff_date
                AND is_active = true
            """)
            result = await self.session.execute(query, {
                "tenant_id": tenant_id,
                "cutoff_date": cutoff_date
            })
            archived_count = result.rowcount
            logger.info("Archived old models",
                        tenant_id=tenant_id,
                        archived_count=archived_count,
                        days_old=days_old)
            return archived_count
        except Exception as e:
            logger.error("Failed to archive old models",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Model archival failed: {str(e)}") from e

    async def get_model_statistics(self, tenant_id: str) -> Dict[str, Any]:
        """Get model statistics for a tenant.

        Best-effort: on any failure the error is logged and a zeroed
        statistics dictionary with the same keys is returned.

        Args:
            tenant_id: Tenant whose statistics are requested.

        Returns:
            A dict with the keys ``total_models``, ``active_models``,
            ``inactive_models``, ``production_models``, ``models_by_product``
            (active model count per product name) and ``recent_models_30d``
            (models created in the last 30 days).
        """
        try:
            # Get basic counts
            total_models = await self.count(filters={"tenant_id": tenant_id})
            active_models = await self.count(filters={
                "tenant_id": tenant_id,
                "is_active": True
            })
            production_models = await self.count(filters={
                "tenant_id": tenant_id,
                "is_production": True
            })

            # Get models by product using raw query
            product_query = text("""
                SELECT product_name, COUNT(*) as count
                FROM trained_models
                WHERE tenant_id = :tenant_id
                AND is_active = true
                GROUP BY product_name
                ORDER BY count DESC
            """)
            result = await self.session.execute(product_query, {"tenant_id": tenant_id})
            # Unpack rows positionally instead of using ``row.count``: on
            # SQLAlchemy 1.4+ a Row is a Sequence, so attribute access to
            # ``count`` yields the sequence method, not the "count" column.
            product_stats = {
                product_name: model_count
                for product_name, model_count in result.fetchall()
            }

            # Recent activity (models created in last 30 days)
            thirty_days_ago = datetime.utcnow() - timedelta(days=30)
            recent_models_query = text("""
                SELECT COUNT(*) as count
                FROM trained_models
                WHERE tenant_id = :tenant_id
                AND created_at >= :thirty_days_ago
            """)
            recent_result = await self.session.execute(
                recent_models_query,
                {"tenant_id": tenant_id, "thirty_days_ago": thirty_days_ago}
            )
            recent_models = recent_result.scalar() or 0

            return {
                "total_models": total_models,
                "active_models": active_models,
                "inactive_models": total_models - active_models,
                "production_models": production_models,
                "models_by_product": product_stats,
                "recent_models_30d": recent_models
            }
        except Exception as e:
            logger.error("Failed to get model statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            # Statistics are informational; degrade to zeros rather than fail
            return {
                "total_models": 0,
                "active_models": 0,
                "inactive_models": 0,
                "production_models": 0,
                "models_by_product": {},
                "recent_models_30d": 0
            }

    async def _deactivate_other_production_models(
        self,
        tenant_id: str,
        product_name: str,
        exclude_model_id: str
    ) -> int:
        """Deactivate other production models for the same tenant+product.

        Args:
            tenant_id: Tenant owning the models.
            product_name: Product the models were trained for.
            exclude_model_id: Model that must keep its production flag.

        Returns:
            The ``rowcount`` reported for the ``UPDATE`` statement.

        Raises:
            DatabaseError: If the statement fails.
        """
        try:
            query = text("""
                UPDATE trained_models
                SET is_production = false
                WHERE tenant_id = :tenant_id
                AND product_name = :product_name
                AND id != :exclude_model_id
                AND is_production = true
            """)
            result = await self.session.execute(query, {
                "tenant_id": tenant_id,
                "product_name": product_name,
                "exclude_model_id": exclude_model_id
            })
            return result.rowcount
        except Exception as e:
            logger.error("Failed to deactivate other production models",
                         tenant_id=tenant_id,
                         product_name=product_name,
                         error=str(e))
            raise DatabaseError(f"Failed to deactivate models: {str(e)}") from e

    async def get_model_performance_summary(self, model_id: str) -> Dict[str, Any]:
        """Get performance summary for a model.

        Best-effort: failures are logged and an empty dict is returned.

        Args:
            model_id: Identifier of the model to summarise.

        Returns:
            A dict with identification fields plus the nested sections
            ``metrics``, ``training_info``, ``status`` and ``features``;
            datetimes are rendered with ``isoformat()`` (``None`` when unset).
            An empty dict when the model is not found or the lookup fails.
        """
        try:
            model = await self.get_by_id(model_id)
            if not model:
                return {}

            return {
                "model_id": model.id,
                "tenant_id": model.tenant_id,
                "product_name": model.product_name,
                "model_type": model.model_type,
                "metrics": {
                    "mape": model.mape,
                    "mae": model.mae,
                    "rmse": model.rmse,
                    "r2_score": model.r2_score
                },
                "training_info": {
                    "training_samples": model.training_samples,
                    "training_start_date": model.training_start_date.isoformat() if model.training_start_date else None,
                    "training_end_date": model.training_end_date.isoformat() if model.training_end_date else None,
                    "data_quality_score": model.data_quality_score
                },
                "status": {
                    "is_active": model.is_active,
                    "is_production": model.is_production,
                    "created_at": model.created_at.isoformat() if model.created_at else None,
                    "last_used_at": model.last_used_at.isoformat() if model.last_used_at else None
                },
                "features": {
                    "hyperparameters": model.hyperparameters,
                    "features_used": model.features_used
                }
            }
        except Exception as e:
            logger.error("Failed to get model performance summary",
                         model_id=model_id,
                         error=str(e))
            return {}