Files
bakery-ia/services/training/app/models/training.py
Urtzi Alfaro cb80a93c4b Few fixes
2025-07-17 14:34:24 +02:00

101 lines
3.5 KiB
Python

"""
Training models
"""
from sqlalchemy import Column, String, Integer, DateTime, Text, JSON, Boolean, Float
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime
import uuid
from shared.database.base import Base
class TrainingJob(Base):
    """ORM mapping for one row of the ``training_jobs`` table.

    A row carries the job's status and progress, its timing, its result
    payloads, and descriptive fields about the data that was used.
    """

    __tablename__ = "training_jobs"

    # --- Identity and ownership -------------------------------------------
    id = Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    tenant_id = Column(UUID(as_uuid=True), index=True, nullable=False)

    # --- Lifecycle --------------------------------------------------------
    # Allowed values: queued, running, completed, failed
    status = Column(String(20), default="queued", nullable=False)
    # NOTE(review): presumably a 0-100 percentage -- confirm with the writer.
    progress = Column(Integer, default=0)
    current_step = Column(String(200))
    requested_by = Column(UUID(as_uuid=True), nullable=False)

    # --- Timing -----------------------------------------------------------
    # Defaults to the (naive) UTC time at insert when no value is supplied.
    started_at = Column(DateTime, default=datetime.utcnow)
    completed_at = Column(DateTime)
    duration_seconds = Column(Integer)

    # --- Results ----------------------------------------------------------
    models_trained = Column(JSON)
    metrics = Column(JSON)
    error_message = Column(Text)

    # --- Metadata about the training data ---------------------------------
    training_data_from = Column(DateTime)
    training_data_to = Column(DateTime)
    total_data_points = Column(Integer)
    products_count = Column(Integer)

    # --- Row bookkeeping --------------------------------------------------
    created_at = Column(DateTime, default=datetime.utcnow)
    # Refreshed with the current naive UTC time on every UPDATE.
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )

    def __repr__(self):
        """Return a short debug string with id, tenant and status."""
        return (
            f"<TrainingJob(id={self.id}, "
            f"tenant_id={self.tenant_id}, "
            f"status={self.status})>"
        )
class TrainedModel(Base):
    """ORM mapping for one row of the ``trained_models`` table.

    A row describes a single trained model: which tenant and training job
    it belongs to, what it models, its stored metrics, and whether it is
    flagged active.
    """

    __tablename__ = "trained_models"

    # --- Identity and ownership -------------------------------------------
    id = Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    tenant_id = Column(UUID(as_uuid=True), index=True, nullable=False)
    # Plain UUID column; no ForeignKey constraint is declared here.
    training_job_id = Column(UUID(as_uuid=True), nullable=False)

    # --- Model details ----------------------------------------------------
    product_name = Column(String(100), nullable=False)
    model_type = Column(String(50), default="prophet", nullable=False)
    model_version = Column(String(20), nullable=False)
    # Path to the saved model file
    model_path = Column(String(500))

    # --- Performance metrics ----------------------------------------------
    mape = Column(Float)  # Mean Absolute Percentage Error
    rmse = Column(Float)  # Root Mean Square Error
    mae = Column(Float)  # Mean Absolute Error
    r2_score = Column(Float)  # R-squared score

    # --- Training details -------------------------------------------------
    training_samples = Column(Integer)
    validation_samples = Column(Integer)
    features_used = Column(JSON)
    hyperparameters = Column(JSON)

    # --- Status -----------------------------------------------------------
    is_active = Column(Boolean, default=True)
    last_used_at = Column(DateTime)

    # --- Row bookkeeping --------------------------------------------------
    created_at = Column(DateTime, default=datetime.utcnow)
    # Refreshed with the current naive UTC time on every UPDATE.
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )

    def __repr__(self):
        """Return a short debug string with id, product and tenant."""
        return (
            f"<TrainedModel(id={self.id}, "
            f"product={self.product_name}, "
            f"tenant={self.tenant_id})>"
        )
class TrainingLog(Base):
    """ORM mapping for one row of the ``training_logs`` table.

    A row is a single log entry attached to a training job: a level, a
    message, and optional step, progress and resource measurements.
    """

    __tablename__ = "training_logs"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Plain UUID column (indexed); no ForeignKey constraint is declared here.
    training_job_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    level = Column(String(10), nullable=False)  # DEBUG, INFO, WARNING, ERROR
    message = Column(Text, nullable=False)
    step = Column(String(100))
    progress = Column(Integer)

    # Additional data
    execution_time = Column(Float)  # Time taken for this step
    memory_usage = Column(Float)  # Memory usage in MB
    # ``metadata`` is a reserved attribute name on Declarative classes (it
    # holds the MetaData collection), and declaring a column under that name
    # raises InvalidRequestError at class-definition time. The Python
    # attribute is therefore ``log_metadata``; the database column is still
    # called "metadata", so the table schema is unchanged.
    log_metadata = Column("metadata", JSON)  # Additional metadata

    created_at = Column(DateTime, default=datetime.utcnow)

    def __repr__(self):
        """Return a short debug string with id and level."""
        return f"<TrainingLog(id={self.id}, level={self.level})>"