Few fixes

This commit is contained in:
Urtzi Alfaro
2025-07-17 14:34:24 +02:00
parent 5bb3e93da4
commit cb80a93c4b
36 changed files with 1512 additions and 141 deletions

View File

@@ -3,7 +3,7 @@ Training service configuration
"""
import os
from pydantic import BaseSettings
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""Application settings"""

View File

View File

@@ -0,0 +1,101 @@
"""
Training models
"""
from sqlalchemy import Column, String, Integer, DateTime, Text, JSON, Boolean, Float
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime
import uuid
from shared.database.base import Base
class TrainingJob(Base):
"""Training job model"""
__tablename__ = "training_jobs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
status = Column(String(20), nullable=False, default="queued") # queued, running, completed, failed
progress = Column(Integer, default=0)
current_step = Column(String(200))
requested_by = Column(UUID(as_uuid=True), nullable=False)
# Timing
started_at = Column(DateTime, default=datetime.utcnow)
completed_at = Column(DateTime)
duration_seconds = Column(Integer)
# Results
models_trained = Column(JSON)
metrics = Column(JSON)
error_message = Column(Text)
# Metadata
training_data_from = Column(DateTime)
training_data_to = Column(DateTime)
total_data_points = Column(Integer)
products_count = Column(Integer)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
def __repr__(self):
return f"<TrainingJob(id={self.id}, tenant_id={self.tenant_id}, status={self.status})>"
class TrainedModel(Base):
"""Trained model information"""
__tablename__ = "trained_models"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
training_job_id = Column(UUID(as_uuid=True), nullable=False)
# Model details
product_name = Column(String(100), nullable=False)
model_type = Column(String(50), nullable=False, default="prophet")
model_version = Column(String(20), nullable=False)
model_path = Column(String(500)) # Path to saved model file
# Performance metrics
mape = Column(Float) # Mean Absolute Percentage Error
rmse = Column(Float) # Root Mean Square Error
mae = Column(Float) # Mean Absolute Error
r2_score = Column(Float) # R-squared score
# Training details
training_samples = Column(Integer)
validation_samples = Column(Integer)
features_used = Column(JSON)
hyperparameters = Column(JSON)
# Status
is_active = Column(Boolean, default=True)
last_used_at = Column(DateTime)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
def __repr__(self):
return f"<TrainedModel(id={self.id}, product={self.product_name}, tenant={self.tenant_id})>"
class TrainingLog(Base):
"""Training log entries"""
__tablename__ = "training_logs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
training_job_id = Column(UUID(as_uuid=True), nullable=False, index=True)
level = Column(String(10), nullable=False) # DEBUG, INFO, WARNING, ERROR
message = Column(Text, nullable=False)
step = Column(String(100))
progress = Column(Integer)
# Additional data
execution_time = Column(Float) # Time taken for this step
memory_usage = Column(Float) # Memory usage in MB
metadata = Column(JSON) # Additional metadata
created_at = Column(DateTime, default=datetime.utcnow)
def __repr__(self):
return f"<TrainingLog(id={self.id}, level={self.level})>"

View File

@@ -4,47 +4,3 @@ Messaging service for training service
from shared.messaging.rabbitmq import RabbitMQClient
from app.core.config import settings
# Global message publisher
message_publisher = RabbitMQClient(settings.RABBITMQ_URL)
# services/training/Dockerfile
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
g++ \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements
COPY requirements.txt .
# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt
# Copy shared libraries
COPY --from=shared /shared /app/shared
# Copy application code
COPY . .
# Create model storage directory
RUN mkdir -p /app/models
# Add shared libraries to Python path
ENV PYTHONPATH="/app:/app/shared:$PYTHONPATH"
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Run application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -0,0 +1,384 @@
"""
Training service business logic
"""
import asyncio
import logging
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, and_
import httpx
import uuid
import json
from app.core.config import settings
from app.models.training import TrainingJob, TrainedModel, TrainingLog
from app.schemas.training import TrainingRequest, TrainingJobResponse, TrainedModelResponse
from app.ml.trainer import MLTrainer
from app.services.messaging import message_publisher
from shared.messaging.events import TrainingStartedEvent, TrainingCompletedEvent, TrainingFailedEvent
logger = logging.getLogger(__name__)
class TrainingService:
"""Training service business logic"""
def __init__(self):
self.ml_trainer = MLTrainer()
async def start_training(self, request: TrainingRequest, user_data: dict, db: AsyncSession) -> TrainingJobResponse:
"""Start a new training job"""
tenant_id = user_data.get("tenant_id")
if not tenant_id:
raise ValueError("User must be associated with a tenant")
# Check if there's already a running job for this tenant
existing_job = await self._get_running_job(tenant_id, db)
if existing_job:
raise ValueError("Training job already running for this tenant")
# Create training job
training_job = TrainingJob(
tenant_id=tenant_id,
status="queued",
progress=0,
current_step="Queued for training",
requested_by=user_data.get("user_id"),
training_data_from=datetime.utcnow() - timedelta(days=request.training_days),
training_data_to=datetime.utcnow()
)
db.add(training_job)
await db.commit()
await db.refresh(training_job)
# Start training in background
asyncio.create_task(self._execute_training(training_job.id, request, db))
# Publish training started event
await message_publisher.publish_event(
"training_events",
"training.started",
TrainingStartedEvent(
event_id=str(uuid.uuid4()),
service_name="training-service",
timestamp=datetime.utcnow(),
data={
"job_id": str(training_job.id),
"tenant_id": tenant_id,
"requested_by": user_data.get("user_id"),
"training_days": request.training_days
}
).__dict__
)
logger.info(f"Training job started: {training_job.id} for tenant: {tenant_id}")
return TrainingJobResponse(
id=str(training_job.id),
tenant_id=tenant_id,
status=training_job.status,
progress=training_job.progress,
current_step=training_job.current_step,
started_at=training_job.started_at,
completed_at=training_job.completed_at,
duration_seconds=training_job.duration_seconds,
models_trained=training_job.models_trained,
metrics=training_job.metrics,
error_message=training_job.error_message
)
async def get_training_status(self, job_id: str, user_data: dict, db: AsyncSession) -> TrainingJobResponse:
"""Get training job status"""
tenant_id = user_data.get("tenant_id")
result = await db.execute(
select(TrainingJob).where(
and_(
TrainingJob.id == job_id,
TrainingJob.tenant_id == tenant_id
)
)
)
job = result.scalar_one_or_none()
if not job:
raise ValueError("Training job not found")
return TrainingJobResponse(
id=str(job.id),
tenant_id=str(job.tenant_id),
status=job.status,
progress=job.progress,
current_step=job.current_step,
started_at=job.started_at,
completed_at=job.completed_at,
duration_seconds=job.duration_seconds,
models_trained=job.models_trained,
metrics=job.metrics,
error_message=job.error_message
)
async def get_trained_models(self, user_data: dict, db: AsyncSession) -> List[TrainedModelResponse]:
"""Get trained models for tenant"""
tenant_id = user_data.get("tenant_id")
result = await db.execute(
select(TrainedModel).where(
and_(
TrainedModel.tenant_id == tenant_id,
TrainedModel.is_active == True
)
).order_by(TrainedModel.created_at.desc())
)
models = result.scalars().all()
return [
TrainedModelResponse(
id=str(model.id),
product_name=model.product_name,
model_type=model.model_type,
model_version=model.model_version,
mape=model.mape,
rmse=model.rmse,
mae=model.mae,
r2_score=model.r2_score,
training_samples=model.training_samples,
features_used=model.features_used,
is_active=model.is_active,
created_at=model.created_at,
last_used_at=model.last_used_at
)
for model in models
]
async def get_training_jobs(self, user_data: dict, limit: int, offset: int, db: AsyncSession) -> List[TrainingJobResponse]:
"""Get training jobs for tenant"""
tenant_id = user_data.get("tenant_id")
result = await db.execute(
select(TrainingJob).where(
TrainingJob.tenant_id == tenant_id
).order_by(TrainingJob.created_at.desc())
.limit(limit)
.offset(offset)
)
jobs = result.scalars().all()
return [
TrainingJobResponse(
id=str(job.id),
tenant_id=str(job.tenant_id),
status=job.status,
progress=job.progress,
current_step=job.current_step,
started_at=job.started_at,
completed_at=job.completed_at,
duration_seconds=job.duration_seconds,
models_trained=job.models_trained,
metrics=job.metrics,
error_message=job.error_message
)
for job in jobs
]
async def _get_running_job(self, tenant_id: str, db: AsyncSession) -> Optional[TrainingJob]:
"""Get running training job for tenant"""
result = await db.execute(
select(TrainingJob).where(
and_(
TrainingJob.tenant_id == tenant_id,
TrainingJob.status.in_(["queued", "running"])
)
)
)
return result.scalar_one_or_none()
async def _execute_training(self, job_id: str, request: TrainingRequest, db: AsyncSession):
"""Execute training job"""
start_time = datetime.utcnow()
try:
# Update job status
await self._update_job_status(job_id, "running", 0, "Starting training...", db)
# Get training data
await self._update_job_status(job_id, "running", 10, "Fetching training data...", db)
training_data = await self._get_training_data(job_id, request, db)
# Train models
await self._update_job_status(job_id, "running", 30, "Training models...", db)
models_result = await self.ml_trainer.train_models(training_data, job_id, db)
# Validate models
await self._update_job_status(job_id, "running", 80, "Validating models...", db)
validation_result = await self.ml_trainer.validate_models(models_result, db)
# Save models
await self._update_job_status(job_id, "running", 90, "Saving models...", db)
await self._save_trained_models(job_id, models_result, validation_result, db)
# Complete job
duration = int((datetime.utcnow() - start_time).total_seconds())
await self._complete_job(job_id, models_result, validation_result, duration, db)
# Publish completion event
await message_publisher.publish_event(
"training_events",
"training.completed",
TrainingCompletedEvent(
event_id=str(uuid.uuid4()),
service_name="training-service",
timestamp=datetime.utcnow(),
data={
"job_id": str(job_id),
"models_trained": len(models_result),
"duration_seconds": duration
}
).__dict__
)
logger.info(f"Training job completed: {job_id}")
except Exception as e:
logger.error(f"Training job failed: {job_id} - {e}")
# Update job as failed
await self._update_job_status(job_id, "failed", 0, f"Training failed: {str(e)}", db)
# Publish failure event
await message_publisher.publish_event(
"training_events",
"training.failed",
TrainingFailedEvent(
event_id=str(uuid.uuid4()),
service_name="training-service",
timestamp=datetime.utcnow(),
data={
"job_id": str(job_id),
"error": str(e)
}
).__dict__
)
async def _update_job_status(self, job_id: str, status: str, progress: int, current_step: str, db: AsyncSession):
"""Update training job status"""
await db.execute(
update(TrainingJob)
.where(TrainingJob.id == job_id)
.values(
status=status,
progress=progress,
current_step=current_step,
updated_at=datetime.utcnow()
)
)
await db.commit()
async def _get_training_data(self, job_id: str, request: TrainingRequest, db: AsyncSession) -> Dict[str, Any]:
"""Get training data from data service"""
# Get job details
result = await db.execute(
select(TrainingJob).where(TrainingJob.id == job_id)
)
job = result.scalar_one()
try:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{settings.DATA_SERVICE_URL}/training-data/{job.tenant_id}",
params={
"from_date": job.training_data_from.isoformat(),
"to_date": job.training_data_to.isoformat(),
"products": request.products
}
)
if response.status_code == 200:
return response.json()
else:
raise Exception(f"Failed to get training data: {response.status_code}")
except Exception as e:
logger.error(f"Error getting training data: {e}")
raise
async def _save_trained_models(self, job_id: str, models_result: Dict[str, Any], validation_result: Dict[str, Any], db: AsyncSession):
"""Save trained models to database"""
# Get job details
result = await db.execute(
select(TrainingJob).where(TrainingJob.id == job_id)
)
job = result.scalar_one()
# Deactivate old models
await db.execute(
update(TrainedModel)
.where(TrainedModel.tenant_id == job.tenant_id)
.values(is_active=False)
)
# Save new models
for product_name, model_data in models_result.items():
validation_data = validation_result.get(product_name, {})
trained_model = TrainedModel(
tenant_id=job.tenant_id,
training_job_id=job_id,
product_name=product_name,
model_type=model_data.get("type", "prophet"),
model_version="1.0",
model_path=model_data.get("path"),
mape=validation_data.get("mape"),
rmse=validation_data.get("rmse"),
mae=validation_data.get("mae"),
r2_score=validation_data.get("r2_score"),
training_samples=model_data.get("training_samples"),
features_used=model_data.get("features", []),
hyperparameters=model_data.get("hyperparameters", {}),
is_active=True
)
db.add(trained_model)
await db.commit()
async def _complete_job(self, job_id: str, models_result: Dict[str, Any], validation_result: Dict[str, Any], duration: int, db: AsyncSession):
"""Complete training job"""
# Calculate metrics
metrics = {
"models_trained": len(models_result),
"average_mape": sum(v.get("mape", 0) for v in validation_result.values()) / len(validation_result) if validation_result else 0,
"training_duration": duration,
"validation_results": validation_result
}
await db.execute(
update(TrainingJob)
.where(TrainingJob.id == job_id)
.values(
status="completed",
progress=100,
current_step="Training completed successfully",
completed_at=datetime.utcnow(),
duration_seconds=duration,
models_trained=models_result,
metrics=metrics,
products_count=len(models_result)
)
)
await db.commit()