Add all the code for training service
This commit is contained in:
@@ -1,101 +1,154 @@
|
||||
# services/training/app/models/training.py
|
||||
"""
|
||||
Training models - Fixed version
|
||||
Database models for training service
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Text, JSON, Boolean, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, JSON, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID, ARRAY
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
Base = declarative_base()
|
||||
|
||||
class TrainingJob(Base):
|
||||
"""Training job model"""
|
||||
__tablename__ = "training_jobs"
|
||||
class ModelTrainingLog(Base):
|
||||
"""
|
||||
Table to track training job execution and status.
|
||||
Replaces the old Celery task tracking.
|
||||
"""
|
||||
__tablename__ = "model_training_logs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
status = Column(String(20), nullable=False, default="queued") # queued, running, completed, failed
|
||||
progress = Column(Integer, default=0)
|
||||
current_step = Column(String(200))
|
||||
requested_by = Column(UUID(as_uuid=True), nullable=False)
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
job_id = Column(String(255), unique=True, index=True, nullable=False)
|
||||
tenant_id = Column(String(255), index=True, nullable=False)
|
||||
status = Column(String(50), nullable=False, default="pending") # pending, running, completed, failed, cancelled
|
||||
progress = Column(Integer, default=0) # 0-100 percentage
|
||||
current_step = Column(String(500), default="")
|
||||
|
||||
# Timing
|
||||
started_at = Column(DateTime, default=datetime.utcnow)
|
||||
completed_at = Column(DateTime)
|
||||
duration_seconds = Column(Integer)
|
||||
# Timestamps
|
||||
start_time = Column(DateTime, default=datetime.now)
|
||||
end_time = Column(DateTime, nullable=True)
|
||||
|
||||
# Results
|
||||
models_trained = Column(JSON)
|
||||
metrics = Column(JSON)
|
||||
error_message = Column(Text)
|
||||
# Configuration and results
|
||||
config = Column(JSON, nullable=True) # Training job configuration
|
||||
results = Column(JSON, nullable=True) # Training results
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
# Metadata
|
||||
training_data_from = Column(DateTime)
|
||||
training_data_to = Column(DateTime)
|
||||
total_data_points = Column(Integer)
|
||||
products_count = Column(Integer)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TrainingJob(id={self.id}, tenant_id={self.tenant_id}, status={self.status})>"
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
|
||||
|
||||
class TrainedModel(Base):
|
||||
"""Trained model information"""
|
||||
"""
|
||||
Table to store information about trained models.
|
||||
"""
|
||||
__tablename__ = "trained_models"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
training_job_id = Column(UUID(as_uuid=True), nullable=False)
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
model_id = Column(String(255), unique=True, index=True, nullable=False)
|
||||
tenant_id = Column(String(255), index=True, nullable=False)
|
||||
product_name = Column(String(255), index=True, nullable=False)
|
||||
|
||||
# Model details
|
||||
product_name = Column(String(100), nullable=False)
|
||||
model_type = Column(String(50), nullable=False, default="prophet")
|
||||
model_version = Column(String(20), nullable=False)
|
||||
model_path = Column(String(500)) # Path to saved model file
|
||||
# Model information
|
||||
model_type = Column(String(50), nullable=False, default="prophet") # prophet, arima, etc.
|
||||
model_path = Column(String(1000), nullable=False) # Path to stored model file
|
||||
version = Column(Integer, nullable=False, default=1)
|
||||
|
||||
# Training information
|
||||
training_samples = Column(Integer, nullable=False, default=0)
|
||||
features = Column(ARRAY(String), nullable=True) # List of features used
|
||||
hyperparameters = Column(JSON, nullable=True) # Model hyperparameters
|
||||
training_metrics = Column(JSON, nullable=True) # Training performance metrics
|
||||
|
||||
# Data period information
|
||||
data_period_start = Column(DateTime, nullable=True)
|
||||
data_period_end = Column(DateTime, nullable=True)
|
||||
|
||||
# Status and metadata
|
||||
is_active = Column(Boolean, default=True, index=True)
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
|
||||
|
||||
class ModelPerformanceMetric(Base):
|
||||
"""
|
||||
Table to track model performance over time.
|
||||
"""
|
||||
__tablename__ = "model_performance_metrics"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
model_id = Column(String(255), index=True, nullable=False)
|
||||
tenant_id = Column(String(255), index=True, nullable=False)
|
||||
product_name = Column(String(255), index=True, nullable=False)
|
||||
|
||||
# Performance metrics
|
||||
mape = Column(Float) # Mean Absolute Percentage Error
|
||||
rmse = Column(Float) # Root Mean Square Error
|
||||
mae = Column(Float) # Mean Absolute Error
|
||||
r2_score = Column(Float) # R-squared score
|
||||
mae = Column(Float, nullable=True) # Mean Absolute Error
|
||||
mse = Column(Float, nullable=True) # Mean Squared Error
|
||||
rmse = Column(Float, nullable=True) # Root Mean Squared Error
|
||||
mape = Column(Float, nullable=True) # Mean Absolute Percentage Error
|
||||
r2_score = Column(Float, nullable=True) # R-squared score
|
||||
|
||||
# Training details
|
||||
training_samples = Column(Integer)
|
||||
validation_samples = Column(Integer)
|
||||
features_used = Column(JSON)
|
||||
hyperparameters = Column(JSON)
|
||||
# Additional metrics
|
||||
accuracy_percentage = Column(Float, nullable=True)
|
||||
prediction_confidence = Column(Float, nullable=True)
|
||||
|
||||
# Evaluation information
|
||||
evaluation_period_start = Column(DateTime, nullable=True)
|
||||
evaluation_period_end = Column(DateTime, nullable=True)
|
||||
evaluation_samples = Column(Integer, nullable=True)
|
||||
|
||||
# Metadata
|
||||
measured_at = Column(DateTime, default=datetime.now)
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
|
||||
class TrainingJobQueue(Base):
|
||||
"""
|
||||
Table to manage training job queue and scheduling.
|
||||
"""
|
||||
__tablename__ = "training_job_queue"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
job_id = Column(String(255), unique=True, index=True, nullable=False)
|
||||
tenant_id = Column(String(255), index=True, nullable=False)
|
||||
|
||||
# Job configuration
|
||||
job_type = Column(String(50), nullable=False) # full_training, single_product, evaluation
|
||||
priority = Column(Integer, default=1) # Higher number = higher priority
|
||||
config = Column(JSON, nullable=True)
|
||||
|
||||
# Scheduling information
|
||||
scheduled_at = Column(DateTime, nullable=True)
|
||||
started_at = Column(DateTime, nullable=True)
|
||||
estimated_duration_minutes = Column(Integer, nullable=True)
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
last_used_at = Column(DateTime)
|
||||
status = Column(String(50), nullable=False, default="queued") # queued, running, completed, failed
|
||||
retry_count = Column(Integer, default=0)
|
||||
max_retries = Column(Integer, default=3)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TrainedModel(id={self.id}, product={self.product_name}, tenant={self.tenant_id})>"
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
|
||||
|
||||
class TrainingLog(Base):
|
||||
"""Training log entries - FIXED: renamed metadata to log_metadata"""
|
||||
__tablename__ = "training_logs"
|
||||
class ModelArtifact(Base):
|
||||
"""
|
||||
Table to track model files and artifacts.
|
||||
"""
|
||||
__tablename__ = "model_artifacts"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
training_job_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
model_id = Column(String(255), index=True, nullable=False)
|
||||
tenant_id = Column(String(255), index=True, nullable=False)
|
||||
|
||||
level = Column(String(10), nullable=False) # DEBUG, INFO, WARNING, ERROR
|
||||
message = Column(Text, nullable=False)
|
||||
step = Column(String(100))
|
||||
progress = Column(Integer)
|
||||
# Artifact information
|
||||
artifact_type = Column(String(50), nullable=False) # model_file, metadata, training_data, etc.
|
||||
file_path = Column(String(1000), nullable=False)
|
||||
file_size_bytes = Column(Integer, nullable=True)
|
||||
checksum = Column(String(255), nullable=True) # For file integrity
|
||||
|
||||
# Additional data
|
||||
execution_time = Column(Float) # Time taken for this step
|
||||
memory_usage = Column(Float) # Memory usage in MB
|
||||
log_metadata = Column(JSON) # FIXED: renamed from 'metadata' to 'log_metadata'
|
||||
# Storage information
|
||||
storage_location = Column(String(100), nullable=False, default="local") # local, s3, gcs, etc.
|
||||
compression = Column(String(50), nullable=True) # gzip, lz4, etc.
|
||||
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TrainingLog(id={self.id}, level={self.level})>"
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
expires_at = Column(DateTime, nullable=True) # For automatic cleanup
|
||||
Reference in New Issue
Block a user