Files
bakery-ia/services/training/app/models/training.py

153 lines
6.0 KiB
Python
Raw Normal View History

2025-07-19 16:59:37 +02:00
# services/training/app/models/training.py
2025-07-17 14:34:24 +02:00
"""
Database models for the training service.
"""
2025-07-19 16:59:37 +02:00
from sqlalchemy import Column, Integer, BigInteger, String, DateTime, Text, Boolean, JSON, Float
from sqlalchemy.dialects.postgresql import UUID, ARRAY
2025-07-25 20:01:37 +02:00
from shared.database.base import Base
2025-07-17 14:34:24 +02:00
from datetime import datetime
import uuid
2025-07-27 10:01:37 +02:00
2025-07-19 16:59:37 +02:00
class ModelTrainingLog(Base):
    """
    Execution log for a single training job.

    One row per job (``job_id`` is unique). The row carries the job's
    lifecycle state, progress, timing, the configuration it was launched
    with, and its results or error message. It takes over the job-tracking
    role previously filled by Celery task state.
    """

    __tablename__ = "model_training_logs"

    # Identity and ownership
    id = Column(Integer, index=True, primary_key=True)
    job_id = Column(String(255), nullable=False, index=True, unique=True)
    tenant_id = Column(UUID(as_uuid=True), index=True, nullable=False)

    # Lifecycle. Expected status values: pending, running, completed,
    # failed, cancelled (plain string column; not enforced by the schema).
    status = Column(String(50), default="pending", nullable=False)
    # Completion percentage, expected range 0-100.
    progress = Column(Integer, default=0)
    # Human-readable label of the step currently executing.
    current_step = Column(String(500), default="")

    # Timing. NOTE(review): datetime.now yields naive server-local time;
    # confirm all readers/writers assume the same zone (UTC recommended).
    start_time = Column(DateTime, default=datetime.now)
    end_time = Column(DateTime, nullable=True)

    # Payloads: launch configuration, training output, failure details.
    config = Column(JSON, nullable=True)
    results = Column(JSON, nullable=True)
    error_message = Column(Text, nullable=True)

    # Row bookkeeping; updated_at is refreshed via onupdate whenever
    # SQLAlchemy issues an UPDATE for the row.
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, onupdate=datetime.now, default=datetime.now)


class TrainedModel(Base):
    """
    Registry of trained models.

    One row per trained model (``model_id`` is unique), scoped to a tenant
    and a product. It records the model kind, where the serialized model is
    stored, how it was trained and on which data window, and an
    ``is_active`` flag.
    """

    __tablename__ = "trained_models"

    # Identity and ownership
    id = Column(Integer, index=True, primary_key=True)
    model_id = Column(String(255), nullable=False, index=True, unique=True)
    tenant_id = Column(UUID(as_uuid=True), index=True, nullable=False)
    product_name = Column(String(255), nullable=False, index=True)

    # What the model is and where its serialized form lives.
    # model_type examples: prophet, arima.
    model_type = Column(String(50), default="prophet", nullable=False)
    model_path = Column(String(1000), nullable=False)
    version = Column(Integer, default=1, nullable=False)

    # How it was trained
    training_samples = Column(Integer, default=0, nullable=False)
    features = Column(ARRAY(String), nullable=True)  # names of features used
    hyperparameters = Column(JSON, nullable=True)
    training_metrics = Column(JSON, nullable=True)  # training performance metrics

    # Date range of the data used for training
    data_period_start = Column(DateTime, nullable=True)
    data_period_end = Column(DateTime, nullable=True)

    # Status and row bookkeeping
    is_active = Column(Boolean, index=True, default=True)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, onupdate=datetime.now, default=datetime.now)


class ModelPerformanceMetric(Base):
    """
    History of evaluation results for trained models.

    Each row is one evaluation of a model (referenced by ``model_id``) for a
    tenant and product. It holds the error metrics computed, the window and
    sample count the evaluation covered, and when it was measured.
    ``model_id`` is not unique here, so a model may have many rows.
    """

    __tablename__ = "model_performance_metrics"

    # Identity and ownership
    id = Column(Integer, index=True, primary_key=True)
    model_id = Column(String(255), nullable=False, index=True)
    tenant_id = Column(UUID(as_uuid=True), index=True, nullable=False)
    product_name = Column(String(255), nullable=False, index=True)

    # Regression error metrics; each may be NULL if it was not computed.
    mae = Column(Float, nullable=True)  # mean absolute error
    mse = Column(Float, nullable=True)  # mean squared error
    rmse = Column(Float, nullable=True)  # root mean squared error
    mape = Column(Float, nullable=True)  # mean absolute percentage error
    r2_score = Column(Float, nullable=True)  # coefficient of determination (R^2)

    # Supplementary scores
    accuracy_percentage = Column(Float, nullable=True)
    prediction_confidence = Column(Float, nullable=True)

    # Evaluation window and size
    evaluation_period_start = Column(DateTime, nullable=True)
    evaluation_period_end = Column(DateTime, nullable=True)
    evaluation_samples = Column(Integer, nullable=True)

    # Timestamps
    measured_at = Column(DateTime, default=datetime.now)
    created_at = Column(DateTime, default=datetime.now)


class TrainingJobQueue(Base):
    """
    Queue of training jobs and their scheduling state.

    One row per job (``job_id`` is unique). Each row holds the job type,
    priority, configuration, scheduling times, status and retry budget.
    """

    __tablename__ = "training_job_queue"

    # Identity and ownership
    id = Column(Integer, index=True, primary_key=True)
    job_id = Column(String(255), nullable=False, index=True, unique=True)
    tenant_id = Column(UUID(as_uuid=True), index=True, nullable=False)

    # What to run. Expected job_type values: full_training, single_product,
    # evaluation. A larger priority number means a higher priority.
    job_type = Column(String(50), nullable=False)
    priority = Column(Integer, default=1)
    config = Column(JSON, nullable=True)

    # When it runs
    scheduled_at = Column(DateTime, nullable=True)
    started_at = Column(DateTime, nullable=True)
    estimated_duration_minutes = Column(Integer, nullable=True)

    # State and retry accounting. Expected status values: queued, running,
    # completed, failed (plain string column; not enforced by the schema).
    status = Column(String(50), default="queued", nullable=False)
    retry_count = Column(Integer, default=0)
    max_retries = Column(Integer, default=3)

    # Row bookkeeping
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, onupdate=datetime.now, default=datetime.now)


class ModelArtifact(Base):
    """
    Table to track model files and artifacts.

    Each row describes one stored file associated with a model (referenced
    by ``model_id``) for a tenant. It records the artifact kind, where the
    file lives, its size and compression, and an optional checksum for
    integrity verification.
    """

    __tablename__ = "model_artifacts"

    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), index=True, nullable=False)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)

    # Artifact information
    artifact_type = Column(String(50), nullable=False)  # model_file, metadata, training_data, etc.
    file_path = Column(String(1000), nullable=False)
    # BigInteger (64-bit) rather than Integer: PostgreSQL INTEGER is 32-bit
    # signed and tops out at 2,147,483,647 bytes (~2 GiB). Large artifacts
    # such as training-data dumps can exceed that, which would make the
    # insert fail with an out-of-range error.
    # NOTE: existing databases need a migration (ALTER COLUMN ... TYPE BIGINT)
    # for the wider type to take effect there.
    file_size_bytes = Column(BigInteger, nullable=True)
    checksum = Column(String(255), nullable=True)  # for file integrity checks

    # Storage information
    storage_location = Column(String(100), nullable=False, default="local")  # local, s3, gcs, etc.
    compression = Column(String(50), nullable=True)  # gzip, lz4, etc.

    # Metadata
    created_at = Column(DateTime, default=datetime.now)
    expires_at = Column(DateTime, nullable=True)  # intended for automatic cleanup