Add all the code for training service

This commit is contained in:
Urtzi Alfaro
2025-07-19 16:59:37 +02:00
parent 42097202d2
commit f3071c00bd
21 changed files with 7504 additions and 764 deletions

View File

@@ -1,101 +1,154 @@
# services/training/app/models/training.py
"""
Training models - Fixed version
Database models for training service
"""
from sqlalchemy import Column, String, Integer, DateTime, Text, JSON, Boolean, Float
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, JSON, Float
from sqlalchemy.dialects.postgresql import UUID, ARRAY
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
import uuid
# NOTE(review): `Base` is imported from shared.database.base and then immediately
# rebound to a fresh declarative_base(), so the classes below subclass the local
# base rather than the shared one. The surrounding text looks like a flattened diff
# (see the `@@ -1,101 +1,154 @@` hunk header), so the import is presumably the
# removed line and the assignment the added one — confirm against the real file.
from shared.database.base import Base
Base = declarative_base()
# NOTE(review): only the header of TrainingJob (docstring and table name) appears
# here; the very next line opens ModelTrainingLog, and no columns are declared in
# between. Presumably this is the removed side of a diff whose +/- markers and
# indentation were stripped — confirm before treating it as live code.
class TrainingJob(Base):
"""Training job model"""
__tablename__ = "training_jobs"
# Declares the model mapped to the "model_training_logs" table.
# NOTE(review): several attributes below are assigned twice with different column
# definitions (`id`, `tenant_id`, `status`, `progress`, `current_step`,
# `error_message`, `created_at`, `updated_at`). Together with the missing
# indentation and the diff hunk header earlier in the text, this suggests removed
# and added diff lines are interleaved; the UUID / utcnow variants presumably
# belong to the older TrainingJob model — TODO confirm against the committed file.
class ModelTrainingLog(Base):
"""
Table to track training job execution and status.
Replaces the old Celery task tracking.
"""
__tablename__ = "model_training_logs"
# Variant 1: UUID primary key generated by uuid.uuid4, UUID tenant and requester ids,
# status defaulting to "queued".
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
status = Column(String(20), nullable=False, default="queued") # queued, running, completed, failed
progress = Column(Integer, default=0)
current_step = Column(String(200))
requested_by = Column(UUID(as_uuid=True), nullable=False)
# Variant 2: integer primary key plus a unique, indexed string job_id; tenant_id is
# a plain string and status defaults to "pending".
id = Column(Integer, primary_key=True, index=True)
job_id = Column(String(255), unique=True, index=True, nullable=False)
tenant_id = Column(String(255), index=True, nullable=False)
status = Column(String(50), nullable=False, default="pending") # pending, running, completed, failed, cancelled
progress = Column(Integer, default=0) # 0-100 percentage
current_step = Column(String(500), default="")
# Timing
started_at = Column(DateTime, default=datetime.utcnow)
completed_at = Column(DateTime)
duration_seconds = Column(Integer)
# Timestamps
# NOTE(review): this default uses datetime.now (naive local time) while started_at
# above uses datetime.utcnow — confirm which clock the service expects.
start_time = Column(DateTime, default=datetime.now)
end_time = Column(DateTime, nullable=True)
# Results
models_trained = Column(JSON)
metrics = Column(JSON)
error_message = Column(Text)
# Configuration and results
config = Column(JSON, nullable=True) # Training job configuration
results = Column(JSON, nullable=True) # Training results
error_message = Column(Text, nullable=True)
# Metadata
training_data_from = Column(DateTime)
training_data_to = Column(DateTime)
total_data_points = Column(Integer)
products_count = Column(Integer)
# Row bookkeeping, utcnow flavour; updated_at is refreshed through onupdate.
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Debug representation. NOTE(review): the label reads "TrainingJob", not
# "ModelTrainingLog" — presumably left over from the older model.
def __repr__(self):
return f"<TrainingJob(id={self.id}, tenant_id={self.tenant_id}, status={self.status})>"
# Row bookkeeping, datetime.now flavour (naive local time).
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
# Declares the model mapped to the "trained_models" table.
# NOTE(review): two docstrings and two sets of identity/model columns follow
# (`id`, `tenant_id`, `product_name`, `model_type`, `model_path` are each assigned
# twice). This looks like the old and new versions of the class interleaved by a
# flattened diff — TODO confirm which definitions are current.
class TrainedModel(Base):
"""Trained model information"""
"""
Table to store information about trained models.
"""
__tablename__ = "trained_models"
# Variant 1: UUID primary key and UUID references to the tenant and training job.
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
training_job_id = Column(UUID(as_uuid=True), nullable=False)
# Variant 2: integer primary key with a unique, indexed string model_id;
# tenant_id and product_name are indexed strings.
id = Column(Integer, primary_key=True, index=True)
model_id = Column(String(255), unique=True, index=True, nullable=False)
tenant_id = Column(String(255), index=True, nullable=False)
product_name = Column(String(255), index=True, nullable=False)
# Model details
product_name = Column(String(100), nullable=False)
model_type = Column(String(50), nullable=False, default="prophet")
model_version = Column(String(20), nullable=False)
model_path = Column(String(500)) # Path to saved model file
# Model information
# In this variant model_path is mandatory and the version is an integer starting at 1.
model_type = Column(String(50), nullable=False, default="prophet") # prophet, arima, etc.
model_path = Column(String(1000), nullable=False) # Path to stored model file
version = Column(Integer, nullable=False, default=1)
# Training information
training_samples = Column(Integer, nullable=False, default=0)
# ARRAY comes from sqlalchemy.dialects.postgresql, so this column is PostgreSQL-specific.
features = Column(ARRAY(String), nullable=True) # List of features used
hyperparameters = Column(JSON, nullable=True) # Model hyperparameters
training_metrics = Column(JSON, nullable=True) # Training performance metrics
# Data period information
data_period_start = Column(DateTime, nullable=True)
data_period_end = Column(DateTime, nullable=True)
# Status and metadata
is_active = Column(Boolean, default=True, index=True)
# NOTE(review): datetime.now produces naive local timestamps — confirm that is intended.
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
# Declares the model mapped to the "model_performance_metrics" table.
# NOTE(review): mape, rmse, mae and r2_score are each assigned twice below (first
# without, then with nullable=True), and the "Training details" group repeats names
# that TrainedModel also declares. Presumably the bare variants are removed diff
# lines from the older TrainedModel — TODO confirm against the committed file.
class ModelPerformanceMetric(Base):
"""
Table to track model performance over time.
"""
__tablename__ = "model_performance_metrics"
id = Column(Integer, primary_key=True, index=True)
# model_id is indexed but not unique here, and no ForeignKey is declared on it.
model_id = Column(String(255), index=True, nullable=False)
tenant_id = Column(String(255), index=True, nullable=False)
product_name = Column(String(255), index=True, nullable=False)
# Performance metrics
mape = Column(Float) # Mean Absolute Percentage Error
rmse = Column(Float) # Root Mean Square Error
mae = Column(Float) # Mean Absolute Error
r2_score = Column(Float) # R-squared score
mae = Column(Float, nullable=True) # Mean Absolute Error
mse = Column(Float, nullable=True) # Mean Squared Error
rmse = Column(Float, nullable=True) # Root Mean Squared Error
mape = Column(Float, nullable=True) # Mean Absolute Percentage Error
r2_score = Column(Float, nullable=True) # R-squared score
# Training details
training_samples = Column(Integer)
validation_samples = Column(Integer)
features_used = Column(JSON)
hyperparameters = Column(JSON)
# Additional metrics
accuracy_percentage = Column(Float, nullable=True)
prediction_confidence = Column(Float, nullable=True)
# Evaluation information
evaluation_period_start = Column(DateTime, nullable=True)
evaluation_period_end = Column(DateTime, nullable=True)
evaluation_samples = Column(Integer, nullable=True)
# Metadata
# Both timestamps default to datetime.now (naive local time) at insert.
measured_at = Column(DateTime, default=datetime.now)
created_at = Column(DateTime, default=datetime.now)
# Declares the model mapped to the "training_job_queue" table.
# NOTE(review): the "Status" group and the timestamps mix two flavours (is_active /
# last_used_at / utcnow versus status / retry counters / datetime.now), and the
# __repr__ below is labelled TrainedModel. Presumably those are removed diff lines
# from the older TrainedModel class — TODO confirm against the committed file.
class TrainingJobQueue(Base):
"""
Table to manage training job queue and scheduling.
"""
__tablename__ = "training_job_queue"
id = Column(Integer, primary_key=True, index=True)
job_id = Column(String(255), unique=True, index=True, nullable=False)
tenant_id = Column(String(255), index=True, nullable=False)
# Job configuration
job_type = Column(String(50), nullable=False) # full_training, single_product, evaluation
priority = Column(Integer, default=1) # Higher number = higher priority
config = Column(JSON, nullable=True)
# Scheduling information
scheduled_at = Column(DateTime, nullable=True)
started_at = Column(DateTime, nullable=True)
estimated_duration_minutes = Column(Integer, nullable=True)
# Status
is_active = Column(Boolean, default=True)
last_used_at = Column(DateTime)
status = Column(String(50), nullable=False, default="queued") # queued, running, completed, failed
# Retry bookkeeping: the counter starts at 0 and max_retries defaults to 3.
retry_count = Column(Integer, default=0)
max_retries = Column(Integer, default=3)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# NOTE(review): this representation reads self.product_name, which no line in this
# class span defines, and prints the label "TrainedModel".
def __repr__(self):
return f"<TrainedModel(id={self.id}, product={self.product_name}, tenant={self.tenant_id})>"
# Metadata
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
# NOTE(review): only the header of TrainingLog (docstring and table name) appears
# here; the next line opens ModelArtifact. Presumably this is the removed side of
# a flattened diff — confirm before treating it as live code.
class TrainingLog(Base):
"""Training log entries - FIXED: renamed metadata to log_metadata"""
__tablename__ = "training_logs"
# Declares the model mapped to the "model_artifacts" table.
# NOTE(review): `id` and `created_at` are each assigned twice, and the log-style
# columns (level, message, step, progress, execution_time, memory_usage,
# log_metadata) sit alongside the artifact columns. Presumably the log-style lines
# are removed diff lines from the older TrainingLog class — TODO confirm.
class ModelArtifact(Base):
"""
Table to track model files and artifacts.
"""
__tablename__ = "model_artifacts"
# Variant 1: UUID primary key and an indexed UUID reference to a training job.
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
training_job_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# Variant 2: integer primary key with indexed string model and tenant ids.
id = Column(Integer, primary_key=True, index=True)
model_id = Column(String(255), index=True, nullable=False)
tenant_id = Column(String(255), index=True, nullable=False)
level = Column(String(10), nullable=False) # DEBUG, INFO, WARNING, ERROR
message = Column(Text, nullable=False)
step = Column(String(100))
progress = Column(Integer)
# Artifact information
artifact_type = Column(String(50), nullable=False) # model_file, metadata, training_data, etc.
file_path = Column(String(1000), nullable=False)
# NOTE(review): Integer is presumably a 32-bit column on PostgreSQL, which would cap
# recorded sizes near 2 GiB — confirm whether larger artifacts are expected.
file_size_bytes = Column(Integer, nullable=True)
checksum = Column(String(255), nullable=True) # For file integrity
# Additional data
execution_time = Column(Float) # Time taken for this step
memory_usage = Column(Float) # Memory usage in MB
log_metadata = Column(JSON) # FIXED: renamed from 'metadata' to 'log_metadata'
# Storage information
storage_location = Column(String(100), nullable=False, default="local") # local, s3, gcs, etc.
compression = Column(String(50), nullable=True) # gzip, lz4, etc.
created_at = Column(DateTime, default=datetime.utcnow)
# NOTE(review): this representation prints the label "TrainingLog" and reads
# self.level, which is one of the log-style columns above.
def __repr__(self):
return f"<TrainingLog(id={self.id}, level={self.level})>"
# Metadata
created_at = Column(DateTime, default=datetime.now)
expires_at = Column(DateTime, nullable=True) # For automatic cleanup