101 lines
3.6 KiB
Python
101 lines
3.6 KiB
Python
"""
Training models - Fixed version
"""
|
|
|
|
from sqlalchemy import Column, String, Integer, DateTime, Text, JSON, Boolean, Float
|
|
from sqlalchemy.dialects.postgresql import UUID
|
|
from datetime import datetime
|
|
import uuid
|
|
|
|
from shared.database.base import Base
|
|
|
|
class TrainingJob(Base):
    """ORM mapping for one row of the ``training_jobs`` table.

    A row carries the owning tenant and requester, the job's status and
    progress, timing information, result payloads, and metadata about the
    training data.
    """

    __tablename__ = "training_jobs"

    # --- Identity, ownership and state ---
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    # Expected values: queued, running, completed, failed
    status = Column(String(20), nullable=False, default="queued")
    progress = Column(Integer, default=0)
    current_step = Column(String(200))
    requested_by = Column(UUID(as_uuid=True), nullable=False)

    # --- Timing ---
    # NOTE(review): started_at is filled with the insert-time UTC timestamp
    # even though status defaults to "queued" — confirm this is intended.
    started_at = Column(DateTime, default=datetime.utcnow)
    completed_at = Column(DateTime)
    duration_seconds = Column(Integer)

    # --- Results ---
    models_trained = Column(JSON)
    metrics = Column(JSON)
    error_message = Column(Text)

    # --- Training-data metadata ---
    training_data_from = Column(DateTime)
    training_data_to = Column(DateTime)
    total_data_points = Column(Integer)
    products_count = Column(Integer)

    # --- Row bookkeeping (naive UTC timestamps from datetime.utcnow) ---
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def __repr__(self):
        """Return a short debug string with the id, tenant id and status."""
        return (
            f"<TrainingJob(id={self.id}, tenant_id={self.tenant_id}, "
            f"status={self.status})>"
        )
|
|
|
|
class TrainedModel(Base):
    """ORM mapping for one row of the ``trained_models`` table.

    Each row describes a single trained model: which tenant and training job
    it belongs to, what kind of model it is, where it is stored, its
    evaluation metrics, and how it was trained.
    """

    __tablename__ = "trained_models"

    # --- Identity and ownership ---
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    # Plain UUID column; no ForeignKey constraint is declared here.
    training_job_id = Column(UUID(as_uuid=True), nullable=False)

    # --- Model details ---
    product_name = Column(String(100), nullable=False)
    model_type = Column(String(50), nullable=False, default="prophet")
    model_version = Column(String(20), nullable=False)
    # Path to the saved model file
    model_path = Column(String(500))

    # --- Performance metrics ---
    # Mean Absolute Percentage Error
    mape = Column(Float)
    # Root Mean Square Error
    rmse = Column(Float)
    # Mean Absolute Error
    mae = Column(Float)
    # R-squared score
    r2_score = Column(Float)

    # --- Training details ---
    training_samples = Column(Integer)
    validation_samples = Column(Integer)
    features_used = Column(JSON)
    hyperparameters = Column(JSON)

    # --- Status ---
    is_active = Column(Boolean, default=True)
    last_used_at = Column(DateTime)

    # --- Row bookkeeping (naive UTC timestamps from datetime.utcnow) ---
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def __repr__(self):
        """Return a short debug string with the id, product name and tenant id."""
        return (
            f"<TrainedModel(id={self.id}, product={self.product_name}, "
            f"tenant={self.tenant_id})>"
        )
|
|
|
|
class TrainingLog(Base):
    """ORM mapping for one row of the ``training_logs`` table.

    FIXED: the JSON payload column is named ``log_metadata`` (renamed from
    ``metadata``).
    """

    __tablename__ = "training_logs"

    # --- Identity ---
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    training_job_id = Column(UUID(as_uuid=True), nullable=False, index=True)

    # --- Log entry ---
    # Expected values: DEBUG, INFO, WARNING, ERROR
    level = Column(String(10), nullable=False)
    message = Column(Text, nullable=False)
    step = Column(String(100))
    progress = Column(Integer)

    # --- Additional data ---
    # Time taken for this step
    execution_time = Column(Float)
    # Memory usage in MB
    memory_usage = Column(Float)
    # Renamed from 'metadata': that attribute name is reserved on SQLAlchemy
    # declarative classes, so the column must use a different name.
    log_metadata = Column(JSON)

    # --- Row bookkeeping (naive UTC timestamp from datetime.utcnow) ---
    created_at = Column(DateTime, default=datetime.utcnow)

    def __repr__(self):
        """Return a short debug string with the id and log level."""
        return "".join(
            (
                f"<TrainingLog(id={self.id}, ",
                f"level={self.level})>",
            )
        )