REFACTOR - Database logic
This commit is contained in:
@@ -6,7 +6,7 @@ Database models for training service
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, JSON, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID, ARRAY
|
||||
from shared.database.base import Base
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
|
||||
@@ -25,8 +25,8 @@ class ModelTrainingLog(Base):
|
||||
current_step = Column(String(500), default="")
|
||||
|
||||
# Timestamps
|
||||
start_time = Column(DateTime, default=datetime.now)
|
||||
end_time = Column(DateTime, nullable=True)
|
||||
start_time = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
end_time = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Configuration and results
|
||||
config = Column(JSON, nullable=True) # Training job configuration
|
||||
@@ -34,8 +34,8 @@ class ModelTrainingLog(Base):
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
class ModelPerformanceMetric(Base):
|
||||
"""
|
||||
@@ -65,8 +65,8 @@ class ModelPerformanceMetric(Base):
|
||||
evaluation_samples = Column(Integer, nullable=True)
|
||||
|
||||
# Metadata
|
||||
measured_at = Column(DateTime, default=datetime.now)
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
measured_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
class TrainingJobQueue(Base):
|
||||
"""
|
||||
@@ -94,8 +94,8 @@ class TrainingJobQueue(Base):
|
||||
max_retries = Column(Integer, default=3)
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
cancelled_by = Column(String, nullable=True)
|
||||
|
||||
class ModelArtifact(Base):
|
||||
@@ -119,15 +119,15 @@ class ModelArtifact(Base):
|
||||
compression = Column(String(50), nullable=True) # gzip, lz4, etc.
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
expires_at = Column(DateTime, nullable=True) # For automatic cleanup
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
expires_at = Column(DateTime(timezone=True), nullable=True) # For automatic cleanup
|
||||
|
||||
class TrainedModel(Base):
|
||||
__tablename__ = "trained_models"
|
||||
|
||||
# Primary identification
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
tenant_id = Column(String, nullable=False, index=True)
|
||||
# Primary identification: id and tenant_id are native UUID columns (as_uuid=True); to_dict() renders them as strings
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
product_name = Column(String, nullable=False, index=True)
|
||||
|
||||
# Model information
|
||||
@@ -154,13 +154,14 @@ class TrainedModel(Base):
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_production = Column(Boolean, default=False)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
last_used_at = Column(DateTime)
|
||||
# Timestamps: all timezone-aware; created_at and updated_at default to the current UTC time, and updated_at is refreshed on every update
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
last_used_at = Column(DateTime(timezone=True))
|
||||
|
||||
# Training data info
|
||||
training_start_date = Column(DateTime)
|
||||
training_end_date = Column(DateTime)
|
||||
training_start_date = Column(DateTime(timezone=True))
|
||||
training_end_date = Column(DateTime(timezone=True))
|
||||
data_quality_score = Column(Float)
|
||||
|
||||
# Additional metadata
|
||||
@@ -169,9 +170,9 @@ class TrainedModel(Base):
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"model_id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"id": str(self.id),
|
||||
"model_id": str(self.id),
|
||||
"tenant_id": str(self.tenant_id),
|
||||
"product_name": self.product_name,
|
||||
"model_type": self.model_type,
|
||||
"model_version": self.model_version,
|
||||
@@ -186,6 +187,7 @@ class TrainedModel(Base):
|
||||
"is_active": self.is_active,
|
||||
"is_production": self.is_production,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
|
||||
"training_start_date": self.training_start_date.isoformat() if self.training_start_date else None,
|
||||
"training_end_date": self.training_end_date.isoformat() if self.training_end_date else None,
|
||||
|
||||
@@ -1,80 +1,11 @@
|
||||
# services/training/app/models/training_models.py
|
||||
"""
|
||||
Database models for trained ML models
|
||||
Legacy file - TrainedModel has been moved to training.py
|
||||
This file is deprecated and should be removed after migration.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Boolean, JSON
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
# Import the actual model from the correct location
|
||||
from .training import TrainedModel
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class TrainedModel(Base):
    """ORM mapping for the ``trained_models`` table.

    Each row describes one trained model: who it belongs to, where its
    files are stored, the metrics recorded for it, and its status flags.
    """

    __tablename__ = "trained_models"

    # --- Identity -----------------------------------------------------
    # The primary key is stored as a string; a fresh UUID4 is generated
    # and stringified whenever no id is supplied.
    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    tenant_id = Column(String, nullable=False, index=True)
    product_name = Column(String, nullable=False, index=True)

    # --- Model description --------------------------------------------
    model_type = Column(String, default="prophet_optimized")
    model_version = Column(String, default="1.0")
    job_id = Column(String, nullable=False)

    # --- Artifact locations -------------------------------------------
    model_path = Column(String, nullable=False)  # Path to the .pkl file
    metadata_path = Column(String)  # Path to metadata JSON

    # --- Training metrics ---------------------------------------------
    mape = Column(Float)
    mae = Column(Float)
    rmse = Column(Float)
    r2_score = Column(Float)
    training_samples = Column(Integer)

    # --- Hyperparameters and features ---------------------------------
    hyperparameters = Column(JSON)  # Store optimized parameters
    features_used = Column(JSON)  # List of regressor columns

    # --- Status flags -------------------------------------------------
    is_active = Column(Boolean, default=True)
    is_production = Column(Boolean, default=False)

    # --- Timestamps ---------------------------------------------------
    # NOTE(review): datetime.utcnow returns a naive datetime and the
    # columns are plain DateTime, so stored values carry no tzinfo.
    created_at = Column(DateTime, default=datetime.utcnow)
    last_used_at = Column(DateTime)

    # --- Training data window -----------------------------------------
    training_start_date = Column(DateTime)
    training_end_date = Column(DateTime)
    data_quality_score = Column(Float)

    # --- Free-form metadata -------------------------------------------
    notes = Column(Text)
    created_by = Column(String)  # User who triggered training

    def to_dict(self):
        """Return a plain-dict view of the model record.

        Scalar columns are copied as-is. The four datetime columns are
        rendered with ``isoformat()``, or ``None`` when the value is
        unset. ``job_id``, ``metadata_path``, ``notes`` and
        ``created_by`` are not part of the result.
        """

        def _iso(moment):
            # Falsy (unset) timestamps serialize to None.
            return moment.isoformat() if moment else None

        passthrough = (
            "id",
            "tenant_id",
            "product_name",
            "model_type",
            "model_version",
            "model_path",
            "mape",
            "mae",
            "rmse",
            "r2_score",
            "training_samples",
            "hyperparameters",
            "features_used",
            "is_active",
            "is_production",
        )
        temporal = (
            "created_at",
            "last_used_at",
            "training_start_date",
            "training_end_date",
        )

        # Insertion order below reproduces the original key order.
        payload = {column: getattr(self, column) for column in passthrough}
        for column in temporal:
            payload[column] = _iso(getattr(self, column))
        payload["data_quality_score"] = self.data_quality_score
        return payload
|
||||
# For backward compatibility, re-export the model
|
||||
__all__ = ["TrainedModel"]
|
||||
Reference in New Issue
Block a user