"""
|
|
Model Repository
|
|
Repository for trained model operations
|
|
"""
|
|
|
|
from typing import Optional, List, Dict, Any
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import text
from datetime import datetime, timezone, timedelta
import structlog

from .base import TrainingBaseRepository
from app.models.training import TrainedModel
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError

logger = structlog.get_logger()
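
# Usage sketch (illustrative only; the session factory and the ids shown below are
# assumptions, not part of this module):
#
#     async with async_session_factory() as session:
#         repo = ModelRepository(session)
#         model = await repo.create_model({
#             "tenant_id": "acme",
#             "inventory_product_id": "prod-123",
#             "model_path": "s3://models/acme/prod-123/v1",
#             "job_id": "job-42",
#         })
#         await repo.promote_to_production(str(model.id))
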
class ModelRepository(TrainingBaseRepository):
    """Repository for trained model operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 600):
        # Models are relatively stable, longer cache time (10 minutes)
        super().__init__(TrainedModel, session, cache_ttl)

    async def create_model(self, model_data: Dict[str, Any]) -> TrainedModel:
        """Create a new trained model with validation"""
        try:
            # Validate model data
            validation_result = self._validate_training_data(
                model_data,
                ["tenant_id", "inventory_product_id", "model_path", "job_id"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid model data: {validation_result['errors']}")

            # Check for an existing active production model for the same tenant+product
            existing_model = await self.get_active_model_for_product(
                model_data["tenant_id"],
                model_data["inventory_product_id"]
            )

            # If a production model already exists and the new model is flagged for
            # production, deactivate the previous one first
            if existing_model and model_data.get("is_production", False):
                logger.info("Deactivating previous production model",
                            previous_model_id=existing_model.id,
                            tenant_id=model_data["tenant_id"],
                            inventory_product_id=model_data["inventory_product_id"])
                await self.update(existing_model.id, {"is_production": False})

            # Create the new model
            model = await self.create(model_data)

            logger.info("Trained model created successfully",
                        model_id=model.id,
                        tenant_id=model.tenant_id,
                        inventory_product_id=str(model.inventory_product_id),
                        model_type=model.model_type)

            return model

        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to create trained model",
                         tenant_id=model_data.get("tenant_id"),
                         inventory_product_id=model_data.get("inventory_product_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create model: {str(e)}")

    async def get_model_by_tenant_and_product(
        self,
        tenant_id: str,
        inventory_product_id: str
    ) -> List[TrainedModel]:
        """Get all models for a tenant and product"""
        try:
            return await self.get_multi(
                filters={
                    "tenant_id": tenant_id,
                    "inventory_product_id": inventory_product_id
                },
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get models by tenant and product",
                         tenant_id=tenant_id,
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get models: {str(e)}")

    async def get_active_model_for_product(
        self,
        tenant_id: str,
        inventory_product_id: str
    ) -> Optional[TrainedModel]:
        """Get the active production model for a product"""
        try:
            models = await self.get_multi(
                filters={
                    "tenant_id": tenant_id,
                    "inventory_product_id": inventory_product_id,
                    "is_active": True,
                    "is_production": True
                },
                order_by="created_at",
                order_desc=True,
                limit=1
            )
            return models[0] if models else None
        except Exception as e:
            logger.error("Failed to get active model for product",
                         tenant_id=tenant_id,
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get active model: {str(e)}")

    async def get_models_by_tenant(
        self,
        tenant_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[TrainedModel]:
        """Get all models for a tenant"""
        return await self.get_by_tenant_id(tenant_id, skip=skip, limit=limit)

    async def promote_to_production(self, model_id: str) -> Optional[TrainedModel]:
        """Promote a model to production"""
        try:
            # Get the model first
            model = await self.get_by_id(model_id)
            if not model:
                raise ValueError(f"Model {model_id} not found")

            # Deactivate other production models for the same tenant+product
            await self._deactivate_other_production_models(
                model.tenant_id,
                str(model.inventory_product_id),
                model_id
            )

            # Promote this model
            updated_model = await self.update(model_id, {
                "is_production": True,
                "last_used_at": datetime.now(timezone.utc)
            })

            logger.info("Model promoted to production",
                        model_id=model_id,
                        tenant_id=model.tenant_id,
                        inventory_product_id=str(model.inventory_product_id))

            return updated_model

        except Exception as e:
            logger.error("Failed to promote model to production",
                         model_id=model_id,
                         error=str(e))
            raise DatabaseError(f"Failed to promote model: {str(e)}")

    async def update_model_usage(self, model_id: str) -> Optional[TrainedModel]:
        """Update model last used timestamp"""
        try:
            return await self.update(model_id, {
                "last_used_at": datetime.now(timezone.utc)
            })
        except Exception as e:
            logger.error("Failed to update model usage",
                         model_id=model_id,
                         error=str(e))
            # Don't raise here - usage update is not critical
            return None

    async def archive_old_models(self, tenant_id: str, days_old: int = 90) -> int:
        """Archive old non-production models"""
        try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)

            query = text("""
                UPDATE trained_models
                SET is_active = false
                WHERE tenant_id = :tenant_id
                AND is_production = false
                AND created_at < :cutoff_date
                AND is_active = true
            """)

            result = await self.session.execute(query, {
                "tenant_id": tenant_id,
                "cutoff_date": cutoff_date
            })

            archived_count = result.rowcount

            logger.info("Archived old models",
                        tenant_id=tenant_id,
                        archived_count=archived_count,
                        days_old=days_old)

            return archived_count

        except Exception as e:
            logger.error("Failed to archive old models",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Model archival failed: {str(e)}")

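    # Example (sketch): an assumed scheduled maintenance task could call
    #     archived = await repo.archive_old_models(tenant_id, days_old=90)
    # to retire stale non-production models; production models are never archived.
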
    async def get_model_statistics(self, tenant_id: str) -> Dict[str, Any]:
        """Get model statistics for a tenant"""
        try:
            # Get basic counts
            total_models = await self.count(filters={"tenant_id": tenant_id})
            active_models = await self.count(filters={
                "tenant_id": tenant_id,
                "is_active": True
            })
            production_models = await self.count(filters={
                "tenant_id": tenant_id,
                "is_production": True
            })

            # Get models by product using raw query
            product_query = text("""
                SELECT inventory_product_id, COUNT(*) as count
                FROM trained_models
                WHERE tenant_id = :tenant_id
                AND is_active = true
                GROUP BY inventory_product_id
                ORDER BY count DESC
            """)

            result = await self.session.execute(product_query, {"tenant_id": tenant_id})
            product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}

            # Recent activity (models created in last 30 days)
            thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
            recent_models_query = text("""
                SELECT COUNT(*) as count
                FROM trained_models
                WHERE tenant_id = :tenant_id
                AND created_at >= :thirty_days_ago
            """)

            recent_result = await self.session.execute(
                recent_models_query,
                {"tenant_id": tenant_id, "thirty_days_ago": thirty_days_ago}
            )
            recent_models = recent_result.scalar() or 0

            # Calculate average accuracy from model metrics
            accuracy_query = text("""
                SELECT AVG(mape) as average_mape, COUNT(*) as total_models_with_metrics
                FROM trained_models
                WHERE tenant_id = :tenant_id
                AND mape IS NOT NULL
                AND is_active = true
            """)

            accuracy_result = await self.session.execute(accuracy_query, {"tenant_id": tenant_id})
            accuracy_row = accuracy_result.fetchone()

            average_mape = accuracy_row.average_mape if accuracy_row and accuracy_row.average_mape else 0
            total_models_with_metrics = accuracy_row.total_models_with_metrics if accuracy_row else 0

            # Convert MAPE to an accuracy percentage (lower MAPE = higher accuracy).
            # Use 100 - MAPE as a simple conversion, capped to the 0-100 range.
            # Return None if no models have metrics (no data), rather than 0.
            if total_models_with_metrics == 0:
                average_accuracy = None
            else:
                average_accuracy = max(0.0, min(100.0, 100.0 - float(average_mape)))
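            # For example, an average MAPE of 12.5 yields an average accuracy of
            # 87.5, while an average MAPE above 100 clamps to an accuracy of 0.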

            return {
                "total_models": total_models,
                "active_models": active_models,
                "inactive_models": total_models - active_models,
                "production_models": production_models,
                "models_by_product": product_stats,
                "recent_models_30d": recent_models,
                "average_accuracy": average_accuracy,
                "total_models_with_metrics": total_models_with_metrics,
                "average_mape": float(average_mape) if average_mape > 0 else 0
            }

        except Exception as e:
            logger.error("Failed to get model statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_models": 0,
                "active_models": 0,
                "inactive_models": 0,
                "production_models": 0,
                "models_by_product": {},
                "recent_models_30d": 0,
                "average_accuracy": 0,
                "total_models_with_metrics": 0,
                "average_mape": 0
            }

    async def _deactivate_other_production_models(
        self,
        tenant_id: str,
        inventory_product_id: str,
        exclude_model_id: str
    ) -> int:
        """Deactivate other production models for the same tenant+product"""
        try:
            query = text("""
                UPDATE trained_models
                SET is_production = false
                WHERE tenant_id = :tenant_id
                AND inventory_product_id = :inventory_product_id
                AND id != :exclude_model_id
                AND is_production = true
            """)

            result = await self.session.execute(query, {
                "tenant_id": tenant_id,
                "inventory_product_id": inventory_product_id,
                "exclude_model_id": exclude_model_id
            })

            return result.rowcount

        except Exception as e:
            logger.error("Failed to deactivate other production models",
                         tenant_id=tenant_id,
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            raise DatabaseError(f"Failed to deactivate models: {str(e)}")

    async def get_model_performance_summary(self, model_id: str) -> Dict[str, Any]:
        """Get performance summary for a model"""
        try:
            model = await self.get_by_id(model_id)
            if not model:
                return {}

            return {
                "model_id": model.id,
                "tenant_id": model.tenant_id,
                "inventory_product_id": str(model.inventory_product_id),
                "model_type": model.model_type,
                "metrics": {
                    "mape": model.mape,
                    "mae": model.mae,
                    "rmse": model.rmse,
                    "r2_score": model.r2_score
                },
                "training_info": {
                    "training_samples": model.training_samples,
                    "training_start_date": model.training_start_date.isoformat() if model.training_start_date else None,
                    "training_end_date": model.training_end_date.isoformat() if model.training_end_date else None,
                    "data_quality_score": model.data_quality_score
                },
                "status": {
                    "is_active": model.is_active,
                    "is_production": model.is_production,
                    "created_at": model.created_at.isoformat() if model.created_at else None,
                    "last_used_at": model.last_used_at.isoformat() if model.last_used_at else None
                },
                "features": {
                    "hyperparameters": model.hyperparameters,
                    "features_used": model.features_used
                }
            }

        except Exception as e:
            logger.error("Failed to get model performance summary",
                         model_id=model_id,
                         error=str(e))
            return {}