Files
2025-11-14 20:27:39 +01:00

254 lines
11 KiB
Python

# services/training/app/models/training.py
"""
Database models for training service
"""
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, JSON, Float
from sqlalchemy.dialects.postgresql import UUID, ARRAY
from shared.database.base import Base
from datetime import datetime, timezone
import uuid
class ModelTrainingLog(Base):
    """
    Table to track training job execution and status.

    Replaces the old Celery task tracking. Each row describes one training
    job (``job_id`` is unique) belonging to one tenant: its lifecycle status,
    progress, configuration, results and any error message.
    """
    __tablename__ = "model_training_logs"

    # NOTE(review): PostgreSQL already indexes a primary key, so
    # ``index=True`` probably creates a redundant second index — confirm
    # against the migrations before dropping it.
    id = Column(Integer, primary_key=True, index=True)
    # Job identifier; unique, so there is at most one log row per job.
    job_id = Column(String(255), unique=True, index=True, nullable=False)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    # Lifecycle state: pending, running, completed, failed, cancelled.
    status = Column(String(50), nullable=False, default="pending")
    progress = Column(Integer, default=0)  # 0-100 percentage
    # Label of the step currently executing (presumably written by the job
    # runner); empty string until set.
    current_step = Column(String(500), default="")

    # Timestamps (timezone-aware; Python-side defaults produce UTC)
    start_time = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    end_time = Column(DateTime(timezone=True), nullable=True)  # NULL until the job ends

    # Configuration and results
    config = Column(JSON, nullable=True)  # Training job configuration
    results = Column(JSON, nullable=True)  # Training results
    error_message = Column(Text, nullable=True)

    # Metadata. ``onupdate`` refreshes updated_at only on UPDATEs issued
    # through SQLAlchemy, not on raw SQL updates.
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
class ModelPerformanceMetric(Base):
    """
    Table to track model performance over time.

    Each row is one measurement of a model (``model_id``) for a tenant's
    inventory product: regression error metrics, accuracy/confidence, and
    the evaluation window and sample count they were computed on. Every
    metric column is nullable, so a row may carry only a subset of them.
    """
    __tablename__ = "model_performance_metrics"

    # NOTE(review): PostgreSQL already indexes a primary key, so
    # ``index=True`` probably creates a redundant index — confirm first.
    id = Column(Integer, primary_key=True, index=True)
    # NOTE(review): stored as a free-form String(255) with no foreign key,
    # while TrainedModel.id in this module is a UUID — presumably this holds
    # its string form; confirm.
    model_id = Column(String(255), index=True, nullable=False)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    inventory_product_id = Column(UUID(as_uuid=True), index=True, nullable=False)

    # Performance metrics
    mae = Column(Float, nullable=True)  # Mean Absolute Error
    mse = Column(Float, nullable=True)  # Mean Squared Error
    rmse = Column(Float, nullable=True)  # Root Mean Squared Error
    mape = Column(Float, nullable=True)  # Mean Absolute Percentage Error
    r2_score = Column(Float, nullable=True)  # R-squared score

    # Additional metrics (units/scale not defined here — presumably percent
    # for accuracy_percentage; confirm with the writer of these rows)
    accuracy_percentage = Column(Float, nullable=True)
    prediction_confidence = Column(Float, nullable=True)

    # Evaluation information
    # NOTE(review): the two period columns are naive ``DateTime`` (TIMESTAMP
    # WITHOUT TIME ZONE), unlike the timezone-aware columns elsewhere in this
    # module. Depending on the DB driver, aware datetimes may be rejected or
    # lose their offset. Switching to ``timezone=True`` requires a migration;
    # confirm the intended zone (presumably UTC).
    evaluation_period_start = Column(DateTime, nullable=True)
    evaluation_period_end = Column(DateTime, nullable=True)
    evaluation_samples = Column(Integer, nullable=True)  # number of samples evaluated

    # Metadata (timezone-aware; Python-side defaults produce UTC)
    measured_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
class TrainingJobQueue(Base):
    """
    Table to manage training job queue and scheduling.

    Each row is one queued training job for a tenant (``job_id`` is unique),
    holding its type, priority, configuration, scheduling timestamps,
    status and retry counters.
    """
    __tablename__ = "training_job_queue"

    # NOTE(review): PostgreSQL already indexes a primary key, so
    # ``index=True`` probably creates a redundant index — confirm first.
    id = Column(Integer, primary_key=True, index=True)
    job_id = Column(String(255), unique=True, index=True, nullable=False)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)

    # Job configuration
    job_type = Column(String(50), nullable=False)  # full_training, single_product, evaluation
    priority = Column(Integer, default=1)  # Higher number = higher priority
    config = Column(JSON, nullable=True)

    # Scheduling information
    # NOTE(review): scheduled_at / started_at are naive ``DateTime``
    # (TIMESTAMP WITHOUT TIME ZONE) while created_at / updated_at below are
    # timezone-aware. Presumably UTC is intended — confirm. Switching to
    # ``timezone=True`` requires a migration.
    scheduled_at = Column(DateTime, nullable=True)
    started_at = Column(DateTime, nullable=True)
    estimated_duration_minutes = Column(Integer, nullable=True)

    # Status
    # Documented values: queued, running, completed, failed.
    # NOTE(review): ``cancelled_by`` below implies jobs can be cancelled;
    # confirm whether a "cancelled" status value is also written.
    status = Column(String(50), nullable=False, default="queued")
    retry_count = Column(Integer, default=0)
    max_retries = Column(Integer, default=3)

    # Metadata. ``onupdate`` refreshes updated_at only on UPDATEs issued
    # through SQLAlchemy, not on raw SQL updates.
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
    # Presumably the identifier of whoever cancelled the job; NULL otherwise.
    cancelled_by = Column(String, nullable=True)
class ModelArtifact(Base):
    """
    Table to track model files and artifacts.

    Each row records one stored artifact belonging to a model (``model_id``)
    and tenant: its type, path, size, checksum, storage backend and
    compression, plus an optional expiry timestamp.
    """
    __tablename__ = "model_artifacts"

    # NOTE(review): PostgreSQL already indexes a primary key, so
    # ``index=True`` probably creates a redundant index — confirm first.
    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), index=True, nullable=False)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)

    # Artifact information
    artifact_type = Column(String(50), nullable=False)  # model_file, metadata, training_data, etc.
    file_path = Column(String(1000), nullable=False)
    # NOTE(review): ``Integer`` maps to a 32-bit PostgreSQL INTEGER (max
    # 2,147,483,647, ~2 GiB), so larger artifacts cannot be recorded.
    # Consider ``BigInteger``; that change requires a migration.
    file_size_bytes = Column(Integer, nullable=True)
    checksum = Column(String(255), nullable=True)  # For file integrity (hash algorithm not specified here)

    # Storage information
    storage_location = Column(String(100), nullable=False, default="local")  # local, s3, gcs, etc.
    compression = Column(String(50), nullable=True)  # gzip, lz4, etc.

    # Metadata
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    # Intended for automatic cleanup; no cleanup logic lives in this model.
    expires_at = Column(DateTime(timezone=True), nullable=True)
class TrainedModel(Base):
    """
    Table of trained forecasting models, one row per trained model.

    Each row points at the serialized model on disk (``model_path``) and
    stores its training metrics, hyperparameters, feature list and
    normalization parameters, plus lifecycle flags and timestamps.
    """
    __tablename__ = "trained_models"

    # Primary identification (UUID, generated Python-side by uuid.uuid4 at insert)
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    inventory_product_id = Column(UUID(as_uuid=True), nullable=False, index=True)

    # Model information
    model_type = Column(String, default="prophet_optimized")
    model_version = Column(String, default="1.0")
    job_id = Column(String, nullable=False)  # training job that produced this model

    # File storage
    model_path = Column(String, nullable=False)  # Path to the .pkl file
    metadata_path = Column(String)  # Path to metadata JSON

    # Training metrics
    mape = Column(Float)
    mae = Column(Float)
    rmse = Column(Float)
    r2_score = Column(Float)
    training_samples = Column(Integer)

    # Hyperparameters and features
    hyperparameters = Column(JSON)  # Store optimized parameters
    features_used = Column(JSON)  # List of regressor columns
    normalization_params = Column(JSON)  # Store feature normalization parameters for consistent predictions
    product_category = Column(String, nullable=True)  # Product category for category-specific forecasting

    # Model status
    is_active = Column(Boolean, default=True)
    is_production = Column(Boolean, default=False)

    # Timestamps (timezone-aware; Python-side defaults produce UTC).
    # ``onupdate`` refreshes updated_at only on SQLAlchemy-issued UPDATEs.
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
    last_used_at = Column(DateTime(timezone=True))

    # Training data info
    training_start_date = Column(DateTime(timezone=True))
    training_end_date = Column(DateTime(timezone=True))
    data_quality_score = Column(Float)

    # Additional metadata
    notes = Column(Text)
    created_by = Column(String)  # User who triggered training

    def to_dict(self):
        """
        Serialize this record into a JSON-friendly dict.

        UUID columns are rendered as strings and datetime columns as ISO-8601
        strings. Unset values are returned as ``None`` rather than the string
        ``"None"``. For example, ``id`` is unset before the row is flushed,
        because its ``uuid.uuid4`` default has not run yet.

        ``model_id`` duplicates ``id``. ``features`` duplicates
        ``features_used`` for frontend compatibility (ModelDetailsModal
        expects ``features``). ``job_id``, ``metadata_path``,
        ``normalization_params``, ``notes`` and ``created_by`` are not
        included.

        Returns:
            dict: the serialized record.
        """

        def _as_str(value):
            # str(None) would produce the misleading string "None".
            return str(value) if value is not None else None

        def _as_iso(value):
            return value.isoformat() if value is not None else None

        return {
            "id": _as_str(self.id),
            "model_id": _as_str(self.id),
            "tenant_id": _as_str(self.tenant_id),
            "inventory_product_id": _as_str(self.inventory_product_id),
            "model_type": self.model_type,
            "model_version": self.model_version,
            "model_path": self.model_path,
            "mape": self.mape,
            "mae": self.mae,
            "rmse": self.rmse,
            "r2_score": self.r2_score,
            "training_samples": self.training_samples,
            "hyperparameters": self.hyperparameters,
            "features_used": self.features_used,
            "features": self.features_used,  # Alias for frontend compatibility (ModelDetailsModal expects 'features')
            "product_category": self.product_category,
            "is_active": self.is_active,
            "is_production": self.is_production,
            "created_at": _as_iso(self.created_at),
            "updated_at": _as_iso(self.updated_at),
            "last_used_at": _as_iso(self.last_used_at),
            "training_start_date": _as_iso(self.training_start_date),
            "training_end_date": _as_iso(self.training_end_date),
            "data_quality_score": self.data_quality_score,
        }
class TrainingPerformanceMetrics(Base):
    """
    Table to track historical training performance for time estimation.

    Stores aggregated metrics from completed training jobs: product counts
    (total / successful / failed), total duration, average time per product
    (the key input for estimating future jobs) and optional per-phase
    durations.
    """
    __tablename__ = "training_performance_metrics"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    job_id = Column(String(255), nullable=False, index=True)

    # Training job statistics
    total_products = Column(Integer, nullable=False)
    successful_products = Column(Integer, nullable=False)
    failed_products = Column(Integer, nullable=False)

    # Time metrics (seconds)
    total_duration_seconds = Column(Float, nullable=False)
    avg_time_per_product = Column(Float, nullable=False)  # Key metric for estimation
    data_analysis_time_seconds = Column(Float, nullable=True)
    training_time_seconds = Column(Float, nullable=True)
    finalization_time_seconds = Column(Float, nullable=True)

    # Job metadata (timezone-aware; Python-side defaults produce UTC)
    completed_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))

    def __repr__(self):
        """Return a debug representation summarizing tenant, job and timing."""
        # avg_time_per_product is None on a new instance that has not been
        # populated yet. Formatting None with ":.2f" raises TypeError, and
        # __repr__ must never raise because debuggers and loggers call it.
        avg = self.avg_time_per_product
        avg_text = f"{avg:.2f}s" if avg is not None else "None"
        return (
            f"<TrainingPerformanceMetrics("
            f"tenant_id={self.tenant_id}, "
            f"job_id={self.job_id}, "
            f"total_products={self.total_products}, "
            f"avg_time_per_product={avg_text}"
            f")>"
        )

    def to_dict(self):
        """
        Serialize this record into a JSON-friendly dict.

        UUID columns are rendered as strings and datetime columns as ISO-8601
        strings. Unset values are returned as ``None`` rather than the string
        ``"None"``. For example, ``id`` is unset before the row is flushed,
        because its ``uuid.uuid4`` default has not run yet.

        Returns:
            dict: the serialized record.
        """

        def _as_str(value):
            # str(None) would produce the misleading string "None".
            return str(value) if value is not None else None

        def _as_iso(value):
            return value.isoformat() if value is not None else None

        return {
            "id": _as_str(self.id),
            "tenant_id": _as_str(self.tenant_id),
            "job_id": self.job_id,
            "total_products": self.total_products,
            "successful_products": self.successful_products,
            "failed_products": self.failed_products,
            "total_duration_seconds": self.total_duration_seconds,
            "avg_time_per_product": self.avg_time_per_product,
            "data_analysis_time_seconds": self.data_analysis_time_seconds,
            "training_time_seconds": self.training_time_seconds,
            "finalization_time_seconds": self.finalization_time_seconds,
            "completed_at": _as_iso(self.completed_at),
            "created_at": _as_iso(self.created_at),
        }