Few fixes
This commit is contained in:
@@ -3,7 +3,7 @@ Training service configuration
|
||||
"""
|
||||
|
||||
import os
|
||||
from pydantic import BaseSettings
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings"""
|
||||
|
||||
0
services/training/app/models/__init__.py
Normal file
0
services/training/app/models/__init__.py
Normal file
101
services/training/app/models/training.py
Normal file
101
services/training/app/models/training.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Training models
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Text, JSON, Boolean, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
class TrainingJob(Base):
|
||||
"""Training job model"""
|
||||
__tablename__ = "training_jobs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
status = Column(String(20), nullable=False, default="queued") # queued, running, completed, failed
|
||||
progress = Column(Integer, default=0)
|
||||
current_step = Column(String(200))
|
||||
requested_by = Column(UUID(as_uuid=True), nullable=False)
|
||||
|
||||
# Timing
|
||||
started_at = Column(DateTime, default=datetime.utcnow)
|
||||
completed_at = Column(DateTime)
|
||||
duration_seconds = Column(Integer)
|
||||
|
||||
# Results
|
||||
models_trained = Column(JSON)
|
||||
metrics = Column(JSON)
|
||||
error_message = Column(Text)
|
||||
|
||||
# Metadata
|
||||
training_data_from = Column(DateTime)
|
||||
training_data_to = Column(DateTime)
|
||||
total_data_points = Column(Integer)
|
||||
products_count = Column(Integer)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TrainingJob(id={self.id}, tenant_id={self.tenant_id}, status={self.status})>"
|
||||
|
||||
class TrainedModel(Base):
|
||||
"""Trained model information"""
|
||||
__tablename__ = "trained_models"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
training_job_id = Column(UUID(as_uuid=True), nullable=False)
|
||||
|
||||
# Model details
|
||||
product_name = Column(String(100), nullable=False)
|
||||
model_type = Column(String(50), nullable=False, default="prophet")
|
||||
model_version = Column(String(20), nullable=False)
|
||||
model_path = Column(String(500)) # Path to saved model file
|
||||
|
||||
# Performance metrics
|
||||
mape = Column(Float) # Mean Absolute Percentage Error
|
||||
rmse = Column(Float) # Root Mean Square Error
|
||||
mae = Column(Float) # Mean Absolute Error
|
||||
r2_score = Column(Float) # R-squared score
|
||||
|
||||
# Training details
|
||||
training_samples = Column(Integer)
|
||||
validation_samples = Column(Integer)
|
||||
features_used = Column(JSON)
|
||||
hyperparameters = Column(JSON)
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
last_used_at = Column(DateTime)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TrainedModel(id={self.id}, product={self.product_name}, tenant={self.tenant_id})>"
|
||||
|
||||
class TrainingLog(Base):
|
||||
"""Training log entries"""
|
||||
__tablename__ = "training_logs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
training_job_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
level = Column(String(10), nullable=False) # DEBUG, INFO, WARNING, ERROR
|
||||
message = Column(Text, nullable=False)
|
||||
step = Column(String(100))
|
||||
progress = Column(Integer)
|
||||
|
||||
# Additional data
|
||||
execution_time = Column(Float) # Time taken for this step
|
||||
memory_usage = Column(Float) # Memory usage in MB
|
||||
metadata = Column(JSON) # Additional metadata
|
||||
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TrainingLog(id={self.id}, level={self.level})>"
|
||||
@@ -4,47 +4,3 @@ Messaging service for training service
|
||||
|
||||
from shared.messaging.rabbitmq import RabbitMQClient
|
||||
from app.core.config import settings
|
||||
|
||||
# Global message publisher
|
||||
message_publisher = RabbitMQClient(settings.RABBITMQ_URL)
|
||||
|
||||
|
||||
# services/training/Dockerfile
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
g++ \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy shared libraries
|
||||
COPY --from=shared /shared /app/shared
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Create model storage directory
|
||||
RUN mkdir -p /app/models
|
||||
|
||||
# Add shared libraries to Python path
|
||||
ENV PYTHONPATH="/app:/app/shared:$PYTHONPATH"
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
# Run application
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
384
services/training/app/services/training_service.py
Normal file
384
services/training/app/services/training_service.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Training service business logic
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, and_
|
||||
import httpx
|
||||
import uuid
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.training import TrainingJob, TrainedModel, TrainingLog
|
||||
from app.schemas.training import TrainingRequest, TrainingJobResponse, TrainedModelResponse
|
||||
from app.ml.trainer import MLTrainer
|
||||
from app.services.messaging import message_publisher
|
||||
from shared.messaging.events import TrainingStartedEvent, TrainingCompletedEvent, TrainingFailedEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TrainingService:
|
||||
"""Training service business logic"""
|
||||
|
||||
def __init__(self):
|
||||
self.ml_trainer = MLTrainer()
|
||||
|
||||
async def start_training(self, request: TrainingRequest, user_data: dict, db: AsyncSession) -> TrainingJobResponse:
|
||||
"""Start a new training job"""
|
||||
|
||||
tenant_id = user_data.get("tenant_id")
|
||||
if not tenant_id:
|
||||
raise ValueError("User must be associated with a tenant")
|
||||
|
||||
# Check if there's already a running job for this tenant
|
||||
existing_job = await self._get_running_job(tenant_id, db)
|
||||
if existing_job:
|
||||
raise ValueError("Training job already running for this tenant")
|
||||
|
||||
# Create training job
|
||||
training_job = TrainingJob(
|
||||
tenant_id=tenant_id,
|
||||
status="queued",
|
||||
progress=0,
|
||||
current_step="Queued for training",
|
||||
requested_by=user_data.get("user_id"),
|
||||
training_data_from=datetime.utcnow() - timedelta(days=request.training_days),
|
||||
training_data_to=datetime.utcnow()
|
||||
)
|
||||
|
||||
db.add(training_job)
|
||||
await db.commit()
|
||||
await db.refresh(training_job)
|
||||
|
||||
# Start training in background
|
||||
asyncio.create_task(self._execute_training(training_job.id, request, db))
|
||||
|
||||
# Publish training started event
|
||||
await message_publisher.publish_event(
|
||||
"training_events",
|
||||
"training.started",
|
||||
TrainingStartedEvent(
|
||||
event_id=str(uuid.uuid4()),
|
||||
service_name="training-service",
|
||||
timestamp=datetime.utcnow(),
|
||||
data={
|
||||
"job_id": str(training_job.id),
|
||||
"tenant_id": tenant_id,
|
||||
"requested_by": user_data.get("user_id"),
|
||||
"training_days": request.training_days
|
||||
}
|
||||
).__dict__
|
||||
)
|
||||
|
||||
logger.info(f"Training job started: {training_job.id} for tenant: {tenant_id}")
|
||||
|
||||
return TrainingJobResponse(
|
||||
id=str(training_job.id),
|
||||
tenant_id=tenant_id,
|
||||
status=training_job.status,
|
||||
progress=training_job.progress,
|
||||
current_step=training_job.current_step,
|
||||
started_at=training_job.started_at,
|
||||
completed_at=training_job.completed_at,
|
||||
duration_seconds=training_job.duration_seconds,
|
||||
models_trained=training_job.models_trained,
|
||||
metrics=training_job.metrics,
|
||||
error_message=training_job.error_message
|
||||
)
|
||||
|
||||
async def get_training_status(self, job_id: str, user_data: dict, db: AsyncSession) -> TrainingJobResponse:
|
||||
"""Get training job status"""
|
||||
|
||||
tenant_id = user_data.get("tenant_id")
|
||||
|
||||
result = await db.execute(
|
||||
select(TrainingJob).where(
|
||||
and_(
|
||||
TrainingJob.id == job_id,
|
||||
TrainingJob.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
job = result.scalar_one_or_none()
|
||||
if not job:
|
||||
raise ValueError("Training job not found")
|
||||
|
||||
return TrainingJobResponse(
|
||||
id=str(job.id),
|
||||
tenant_id=str(job.tenant_id),
|
||||
status=job.status,
|
||||
progress=job.progress,
|
||||
current_step=job.current_step,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
duration_seconds=job.duration_seconds,
|
||||
models_trained=job.models_trained,
|
||||
metrics=job.metrics,
|
||||
error_message=job.error_message
|
||||
)
|
||||
|
||||
async def get_trained_models(self, user_data: dict, db: AsyncSession) -> List[TrainedModelResponse]:
|
||||
"""Get trained models for tenant"""
|
||||
|
||||
tenant_id = user_data.get("tenant_id")
|
||||
|
||||
result = await db.execute(
|
||||
select(TrainedModel).where(
|
||||
and_(
|
||||
TrainedModel.tenant_id == tenant_id,
|
||||
TrainedModel.is_active == True
|
||||
)
|
||||
).order_by(TrainedModel.created_at.desc())
|
||||
)
|
||||
|
||||
models = result.scalars().all()
|
||||
|
||||
return [
|
||||
TrainedModelResponse(
|
||||
id=str(model.id),
|
||||
product_name=model.product_name,
|
||||
model_type=model.model_type,
|
||||
model_version=model.model_version,
|
||||
mape=model.mape,
|
||||
rmse=model.rmse,
|
||||
mae=model.mae,
|
||||
r2_score=model.r2_score,
|
||||
training_samples=model.training_samples,
|
||||
features_used=model.features_used,
|
||||
is_active=model.is_active,
|
||||
created_at=model.created_at,
|
||||
last_used_at=model.last_used_at
|
||||
)
|
||||
for model in models
|
||||
]
|
||||
|
||||
async def get_training_jobs(self, user_data: dict, limit: int, offset: int, db: AsyncSession) -> List[TrainingJobResponse]:
|
||||
"""Get training jobs for tenant"""
|
||||
|
||||
tenant_id = user_data.get("tenant_id")
|
||||
|
||||
result = await db.execute(
|
||||
select(TrainingJob).where(
|
||||
TrainingJob.tenant_id == tenant_id
|
||||
).order_by(TrainingJob.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
jobs = result.scalars().all()
|
||||
|
||||
return [
|
||||
TrainingJobResponse(
|
||||
id=str(job.id),
|
||||
tenant_id=str(job.tenant_id),
|
||||
status=job.status,
|
||||
progress=job.progress,
|
||||
current_step=job.current_step,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
duration_seconds=job.duration_seconds,
|
||||
models_trained=job.models_trained,
|
||||
metrics=job.metrics,
|
||||
error_message=job.error_message
|
||||
)
|
||||
for job in jobs
|
||||
]
|
||||
|
||||
async def _get_running_job(self, tenant_id: str, db: AsyncSession) -> Optional[TrainingJob]:
|
||||
"""Get running training job for tenant"""
|
||||
|
||||
result = await db.execute(
|
||||
select(TrainingJob).where(
|
||||
and_(
|
||||
TrainingJob.tenant_id == tenant_id,
|
||||
TrainingJob.status.in_(["queued", "running"])
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _execute_training(self, job_id: str, request: TrainingRequest, db: AsyncSession):
|
||||
"""Execute training job"""
|
||||
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
# Update job status
|
||||
await self._update_job_status(job_id, "running", 0, "Starting training...", db)
|
||||
|
||||
# Get training data
|
||||
await self._update_job_status(job_id, "running", 10, "Fetching training data...", db)
|
||||
training_data = await self._get_training_data(job_id, request, db)
|
||||
|
||||
# Train models
|
||||
await self._update_job_status(job_id, "running", 30, "Training models...", db)
|
||||
models_result = await self.ml_trainer.train_models(training_data, job_id, db)
|
||||
|
||||
# Validate models
|
||||
await self._update_job_status(job_id, "running", 80, "Validating models...", db)
|
||||
validation_result = await self.ml_trainer.validate_models(models_result, db)
|
||||
|
||||
# Save models
|
||||
await self._update_job_status(job_id, "running", 90, "Saving models...", db)
|
||||
await self._save_trained_models(job_id, models_result, validation_result, db)
|
||||
|
||||
# Complete job
|
||||
duration = int((datetime.utcnow() - start_time).total_seconds())
|
||||
await self._complete_job(job_id, models_result, validation_result, duration, db)
|
||||
|
||||
# Publish completion event
|
||||
await message_publisher.publish_event(
|
||||
"training_events",
|
||||
"training.completed",
|
||||
TrainingCompletedEvent(
|
||||
event_id=str(uuid.uuid4()),
|
||||
service_name="training-service",
|
||||
timestamp=datetime.utcnow(),
|
||||
data={
|
||||
"job_id": str(job_id),
|
||||
"models_trained": len(models_result),
|
||||
"duration_seconds": duration
|
||||
}
|
||||
).__dict__
|
||||
)
|
||||
|
||||
logger.info(f"Training job completed: {job_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training job failed: {job_id} - {e}")
|
||||
|
||||
# Update job as failed
|
||||
await self._update_job_status(job_id, "failed", 0, f"Training failed: {str(e)}", db)
|
||||
|
||||
# Publish failure event
|
||||
await message_publisher.publish_event(
|
||||
"training_events",
|
||||
"training.failed",
|
||||
TrainingFailedEvent(
|
||||
event_id=str(uuid.uuid4()),
|
||||
service_name="training-service",
|
||||
timestamp=datetime.utcnow(),
|
||||
data={
|
||||
"job_id": str(job_id),
|
||||
"error": str(e)
|
||||
}
|
||||
).__dict__
|
||||
)
|
||||
|
||||
async def _update_job_status(self, job_id: str, status: str, progress: int, current_step: str, db: AsyncSession):
|
||||
"""Update training job status"""
|
||||
|
||||
await db.execute(
|
||||
update(TrainingJob)
|
||||
.where(TrainingJob.id == job_id)
|
||||
.values(
|
||||
status=status,
|
||||
progress=progress,
|
||||
current_step=current_step,
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def _get_training_data(self, job_id: str, request: TrainingRequest, db: AsyncSession) -> Dict[str, Any]:
|
||||
"""Get training data from data service"""
|
||||
|
||||
# Get job details
|
||||
result = await db.execute(
|
||||
select(TrainingJob).where(TrainingJob.id == job_id)
|
||||
)
|
||||
job = result.scalar_one()
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{settings.DATA_SERVICE_URL}/training-data/{job.tenant_id}",
|
||||
params={
|
||||
"from_date": job.training_data_from.isoformat(),
|
||||
"to_date": job.training_data_to.isoformat(),
|
||||
"products": request.products
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
raise Exception(f"Failed to get training data: {response.status_code}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting training data: {e}")
|
||||
raise
|
||||
|
||||
async def _save_trained_models(self, job_id: str, models_result: Dict[str, Any], validation_result: Dict[str, Any], db: AsyncSession):
|
||||
"""Save trained models to database"""
|
||||
|
||||
# Get job details
|
||||
result = await db.execute(
|
||||
select(TrainingJob).where(TrainingJob.id == job_id)
|
||||
)
|
||||
job = result.scalar_one()
|
||||
|
||||
# Deactivate old models
|
||||
await db.execute(
|
||||
update(TrainedModel)
|
||||
.where(TrainedModel.tenant_id == job.tenant_id)
|
||||
.values(is_active=False)
|
||||
)
|
||||
|
||||
# Save new models
|
||||
for product_name, model_data in models_result.items():
|
||||
validation_data = validation_result.get(product_name, {})
|
||||
|
||||
trained_model = TrainedModel(
|
||||
tenant_id=job.tenant_id,
|
||||
training_job_id=job_id,
|
||||
product_name=product_name,
|
||||
model_type=model_data.get("type", "prophet"),
|
||||
model_version="1.0",
|
||||
model_path=model_data.get("path"),
|
||||
mape=validation_data.get("mape"),
|
||||
rmse=validation_data.get("rmse"),
|
||||
mae=validation_data.get("mae"),
|
||||
r2_score=validation_data.get("r2_score"),
|
||||
training_samples=model_data.get("training_samples"),
|
||||
features_used=model_data.get("features", []),
|
||||
hyperparameters=model_data.get("hyperparameters", {}),
|
||||
is_active=True
|
||||
)
|
||||
|
||||
db.add(trained_model)
|
||||
|
||||
await db.commit()
|
||||
|
||||
async def _complete_job(self, job_id: str, models_result: Dict[str, Any], validation_result: Dict[str, Any], duration: int, db: AsyncSession):
|
||||
"""Complete training job"""
|
||||
|
||||
# Calculate metrics
|
||||
metrics = {
|
||||
"models_trained": len(models_result),
|
||||
"average_mape": sum(v.get("mape", 0) for v in validation_result.values()) / len(validation_result) if validation_result else 0,
|
||||
"training_duration": duration,
|
||||
"validation_results": validation_result
|
||||
}
|
||||
|
||||
await db.execute(
|
||||
update(TrainingJob)
|
||||
.where(TrainingJob.id == job_id)
|
||||
.values(
|
||||
status="completed",
|
||||
progress=100,
|
||||
current_step="Training completed successfully",
|
||||
completed_at=datetime.utcnow(),
|
||||
duration_seconds=duration,
|
||||
models_trained=models_result,
|
||||
metrics=metrics,
|
||||
products_count=len(models_result)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
Reference in New Issue
Block a user