Files
bakery-ia/services/training/app/services/training_service.py
Urtzi Alfaro cb80a93c4b Few fixes
2025-07-17 14:34:24 +02:00

384 lines
14 KiB
Python

"""
Training service business logic
"""
import asyncio
import logging
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, and_
import httpx
import uuid
import json
from app.core.config import settings
from app.models.training import TrainingJob, TrainedModel, TrainingLog
from app.schemas.training import TrainingRequest, TrainingJobResponse, TrainedModelResponse
from app.ml.trainer import MLTrainer
from app.services.messaging import message_publisher
from shared.messaging.events import TrainingStartedEvent, TrainingCompletedEvent, TrainingFailedEvent
logger = logging.getLogger(__name__)
class TrainingService:
    """Business logic for demand-forecast training jobs.

    Orchestrates the full job lifecycle: creation and queuing, background
    execution with incremental progress updates, model validation and
    persistence, and publishing of training lifecycle events to the
    message bus.
    """

    def __init__(self):
        self.ml_trainer = MLTrainer()
        # asyncio.create_task only keeps a weak reference to its task; hold
        # strong references here so in-flight training jobs are not
        # garbage-collected before they finish.
        self._background_tasks: set = set()

    async def start_training(self, request: TrainingRequest, user_data: dict, db: AsyncSession) -> TrainingJobResponse:
        """Create, persist and queue a new training job for the caller's tenant.

        Args:
            request: Training parameters (history window length, product filter).
            user_data: Decoded auth payload; must contain "tenant_id".
            db: Async SQLAlchemy session.

        Returns:
            TrainingJobResponse describing the freshly queued job.

        Raises:
            ValueError: If the user has no tenant, or a job is already
                queued/running for that tenant.
        """
        tenant_id = user_data.get("tenant_id")
        if not tenant_id:
            raise ValueError("User must be associated with a tenant")

        # Enforce a single concurrent training job per tenant.
        existing_job = await self._get_running_job(tenant_id, db)
        if existing_job:
            raise ValueError("Training job already running for this tenant")

        # Single timestamp so the training window is exactly
        # `training_days` long (two utcnow() calls would drift slightly).
        now = datetime.utcnow()
        training_job = TrainingJob(
            tenant_id=tenant_id,
            status="queued",
            progress=0,
            current_step="Queued for training",
            requested_by=user_data.get("user_id"),
            training_data_from=now - timedelta(days=request.training_days),
            training_data_to=now
        )
        db.add(training_job)
        await db.commit()
        await db.refresh(training_job)

        # Run the training pipeline in the background, keeping a strong
        # reference so the task cannot be garbage-collected mid-flight.
        # NOTE(review): the request-scoped session `db` is shared with the
        # background task; if the caller's session is closed when the request
        # ends this will fail — consider opening a fresh session inside
        # _execute_training. TODO confirm session lifetime.
        task = asyncio.create_task(self._execute_training(training_job.id, request, db))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

        # Publish training started event
        await message_publisher.publish_event(
            "training_events",
            "training.started",
            TrainingStartedEvent(
                event_id=str(uuid.uuid4()),
                service_name="training-service",
                timestamp=datetime.utcnow(),
                data={
                    "job_id": str(training_job.id),
                    "tenant_id": tenant_id,
                    "requested_by": user_data.get("user_id"),
                    "training_days": request.training_days
                }
            ).__dict__
        )

        logger.info(f"Training job started: {training_job.id} for tenant: {tenant_id}")
        return self._to_job_response(training_job)

    async def get_training_status(self, job_id: str, user_data: dict, db: AsyncSession) -> TrainingJobResponse:
        """Return the status of one training job, scoped to the caller's tenant.

        Raises:
            ValueError: If no job with that id exists for the tenant.
        """
        tenant_id = user_data.get("tenant_id")
        result = await db.execute(
            select(TrainingJob).where(
                and_(
                    TrainingJob.id == job_id,
                    TrainingJob.tenant_id == tenant_id
                )
            )
        )
        job = result.scalar_one_or_none()
        if not job:
            raise ValueError("Training job not found")
        return self._to_job_response(job)

    async def get_trained_models(self, user_data: dict, db: AsyncSession) -> List[TrainedModelResponse]:
        """Return the tenant's currently active trained models, newest first."""
        tenant_id = user_data.get("tenant_id")
        result = await db.execute(
            select(TrainedModel).where(
                and_(
                    TrainedModel.tenant_id == tenant_id,
                    TrainedModel.is_active.is_(True)
                )
            ).order_by(TrainedModel.created_at.desc())
        )
        models = result.scalars().all()
        return [
            TrainedModelResponse(
                id=str(model.id),
                product_name=model.product_name,
                model_type=model.model_type,
                model_version=model.model_version,
                mape=model.mape,
                rmse=model.rmse,
                mae=model.mae,
                r2_score=model.r2_score,
                training_samples=model.training_samples,
                features_used=model.features_used,
                is_active=model.is_active,
                created_at=model.created_at,
                last_used_at=model.last_used_at
            )
            for model in models
        ]

    async def get_training_jobs(self, user_data: dict, limit: int, offset: int, db: AsyncSession) -> List[TrainingJobResponse]:
        """Return a page of the tenant's training jobs, newest first.

        Args:
            user_data: Decoded auth payload; must contain "tenant_id".
            limit: Maximum number of jobs to return.
            offset: Number of jobs to skip (pagination).
            db: Async SQLAlchemy session.
        """
        tenant_id = user_data.get("tenant_id")
        result = await db.execute(
            select(TrainingJob).where(
                TrainingJob.tenant_id == tenant_id
            ).order_by(TrainingJob.created_at.desc())
            .limit(limit)
            .offset(offset)
        )
        jobs = result.scalars().all()
        return [self._to_job_response(job) for job in jobs]

    @staticmethod
    def _to_job_response(job: TrainingJob) -> TrainingJobResponse:
        """Map a TrainingJob ORM row to its API response schema (single place)."""
        return TrainingJobResponse(
            id=str(job.id),
            tenant_id=str(job.tenant_id),
            status=job.status,
            progress=job.progress,
            current_step=job.current_step,
            started_at=job.started_at,
            completed_at=job.completed_at,
            duration_seconds=job.duration_seconds,
            models_trained=job.models_trained,
            metrics=job.metrics,
            error_message=job.error_message
        )

    async def _get_running_job(self, tenant_id: str, db: AsyncSession) -> Optional[TrainingJob]:
        """Return one queued/running job for the tenant, or None.

        Uses limit(1) + first() rather than scalar_one_or_none() so a dirty
        state with more than one active job returns a job instead of raising
        MultipleResultsFound.
        """
        result = await db.execute(
            select(TrainingJob).where(
                and_(
                    TrainingJob.tenant_id == tenant_id,
                    TrainingJob.status.in_(["queued", "running"])
                )
            ).limit(1)
        )
        return result.scalars().first()

    async def _execute_training(self, job_id: str, request: TrainingRequest, db: AsyncSession):
        """Run the full training pipeline for a job, updating progress as it goes.

        Publishes training.completed / training.failed events. Exceptions are
        caught and recorded on the job rather than propagated, since this runs
        as a detached background task.
        """
        start_time = datetime.utcnow()
        try:
            await self._update_job_status(job_id, "running", 0, "Starting training...", db)

            await self._update_job_status(job_id, "running", 10, "Fetching training data...", db)
            training_data = await self._get_training_data(job_id, request, db)

            await self._update_job_status(job_id, "running", 30, "Training models...", db)
            models_result = await self.ml_trainer.train_models(training_data, job_id, db)

            await self._update_job_status(job_id, "running", 80, "Validating models...", db)
            validation_result = await self.ml_trainer.validate_models(models_result, db)

            await self._update_job_status(job_id, "running", 90, "Saving models...", db)
            await self._save_trained_models(job_id, models_result, validation_result, db)

            duration = int((datetime.utcnow() - start_time).total_seconds())
            await self._complete_job(job_id, models_result, validation_result, duration, db)

            await message_publisher.publish_event(
                "training_events",
                "training.completed",
                TrainingCompletedEvent(
                    event_id=str(uuid.uuid4()),
                    service_name="training-service",
                    timestamp=datetime.utcnow(),
                    data={
                        "job_id": str(job_id),
                        "models_trained": len(models_result),
                        "duration_seconds": duration
                    }
                ).__dict__
            )
            logger.info(f"Training job completed: {job_id}")
        except Exception as e:
            logger.error(f"Training job failed: {job_id} - {e}")
            # Record the failure, including error_message so the API can
            # surface it (the schema exposes it but it was never written).
            await self._update_job_status(
                job_id, "failed", 0, f"Training failed: {str(e)}", db,
                error_message=str(e)
            )
            await message_publisher.publish_event(
                "training_events",
                "training.failed",
                TrainingFailedEvent(
                    event_id=str(uuid.uuid4()),
                    service_name="training-service",
                    timestamp=datetime.utcnow(),
                    data={
                        "job_id": str(job_id),
                        "error": str(e)
                    }
                ).__dict__
            )

    async def _update_job_status(self, job_id: str, status: str, progress: int, current_step: str, db: AsyncSession, error_message: Optional[str] = None):
        """Persist a job's status/progress; optionally record an error message.

        Args:
            error_message: When provided (failure path), stored on the job so
                the API can report why the run failed.
        """
        values: Dict[str, Any] = {
            "status": status,
            "progress": progress,
            "current_step": current_step,
            "updated_at": datetime.utcnow()
        }
        if error_message is not None:
            values["error_message"] = error_message
        await db.execute(
            update(TrainingJob)
            .where(TrainingJob.id == job_id)
            .values(**values)
        )
        await db.commit()

    async def _get_training_data(self, job_id: str, request: TrainingRequest, db: AsyncSession) -> Dict[str, Any]:
        """Fetch the job's training-window data from the data service.

        Raises:
            RuntimeError: On a non-200 response from the data service.
        """
        # Load the job row to get the tenant and training window.
        result = await db.execute(
            select(TrainingJob).where(TrainingJob.id == job_id)
        )
        job = result.scalar_one()
        try:
            # Explicit timeout so a hung data service cannot stall the
            # background job forever (default would be httpx's 5s; be generous
            # since training windows can be large).
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.get(
                    f"{settings.DATA_SERVICE_URL}/training-data/{job.tenant_id}",
                    params={
                        "from_date": job.training_data_from.isoformat(),
                        "to_date": job.training_data_to.isoformat(),
                        "products": request.products
                    }
                )
            if response.status_code == 200:
                return response.json()
            raise RuntimeError(f"Failed to get training data: {response.status_code}")
        except Exception as e:
            logger.error(f"Error getting training data: {e}")
            raise

    async def _save_trained_models(self, job_id: str, models_result: Dict[str, Any], validation_result: Dict[str, Any], db: AsyncSession):
        """Persist newly trained models, deactivating the tenant's previous ones."""
        result = await db.execute(
            select(TrainingJob).where(TrainingJob.id == job_id)
        )
        job = result.scalar_one()

        # Deactivate old models so only this run's models remain active.
        await db.execute(
            update(TrainedModel)
            .where(TrainedModel.tenant_id == job.tenant_id)
            .values(is_active=False)
        )

        for product_name, model_data in models_result.items():
            validation_data = validation_result.get(product_name, {})
            db.add(TrainedModel(
                tenant_id=job.tenant_id,
                training_job_id=job_id,
                product_name=product_name,
                model_type=model_data.get("type", "prophet"),
                model_version="1.0",
                model_path=model_data.get("path"),
                mape=validation_data.get("mape"),
                rmse=validation_data.get("rmse"),
                mae=validation_data.get("mae"),
                r2_score=validation_data.get("r2_score"),
                training_samples=model_data.get("training_samples"),
                features_used=model_data.get("features", []),
                hyperparameters=model_data.get("hyperparameters", {}),
                is_active=True
            ))
        await db.commit()

    async def _complete_job(self, job_id: str, models_result: Dict[str, Any], validation_result: Dict[str, Any], duration: int, db: AsyncSession):
        """Mark a job completed and store its aggregate metrics."""
        # NOTE(review): products missing "mape" count as 0 in the average,
        # biasing it downward — acceptable only if validation always reports
        # mape; confirm against MLTrainer.validate_models.
        metrics = {
            "models_trained": len(models_result),
            "average_mape": sum(v.get("mape", 0) for v in validation_result.values()) / len(validation_result) if validation_result else 0,
            "training_duration": duration,
            "validation_results": validation_result
        }
        await db.execute(
            update(TrainingJob)
            .where(TrainingJob.id == job_id)
            .values(
                status="completed",
                progress=100,
                current_step="Training completed successfully",
                completed_at=datetime.utcnow(),
                duration_seconds=duration,
                models_trained=models_result,
                metrics=metrics,
                products_count=len(models_result)
            )
        )
        await db.commit()