"""
|
|
Training service business logic
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, Any, List, Optional
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, update, and_
|
|
import httpx
|
|
import uuid
|
|
import json
|
|
|
|
from app.core.config import settings
|
|
from app.models.training import TrainingJob, TrainedModel, TrainingLog
|
|
from app.schemas.training import TrainingRequest, TrainingJobResponse, TrainedModelResponse
|
|
from app.ml.trainer import MLTrainer
|
|
from app.services.messaging import message_publisher
|
|
from shared.messaging.events import TrainingStartedEvent, TrainingCompletedEvent, TrainingFailedEvent
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|

class TrainingService:
    """Training service business logic"""

    def __init__(self):
        self.ml_trainer = MLTrainer()

    async def start_training(self, request: TrainingRequest, user_data: dict, db: AsyncSession) -> TrainingJobResponse:
        """Start a new training job"""

        tenant_id = user_data.get("tenant_id")
        if not tenant_id:
            raise ValueError("User must be associated with a tenant")

        # Check if there's already a running job for this tenant
        existing_job = await self._get_running_job(tenant_id, db)
        if existing_job:
            raise ValueError("Training job already running for this tenant")

        # Create training job
        training_job = TrainingJob(
            tenant_id=tenant_id,
            status="queued",
            progress=0,
            current_step="Queued for training",
            requested_by=user_data.get("user_id"),
            training_data_from=datetime.utcnow() - timedelta(days=request.training_days),
            training_data_to=datetime.utcnow()
        )

        db.add(training_job)
        await db.commit()
        await db.refresh(training_job)

        # Start training in background
        asyncio.create_task(self._execute_training(training_job.id, request, db))
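        # NOTE: the task handle is not retained and the request-scoped AsyncSession is shared
        # with the background task; keeping a reference to the task and giving it its own
        # session would be more robust (AsyncSession is not safe for concurrent use).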

        # Publish training started event
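        # (the event's attribute dict, via __dict__, is what gets published; the publisher is
        # assumed to handle serialization of the datetime timestamp)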
        await message_publisher.publish_event(
            "training_events",
            "training.started",
            TrainingStartedEvent(
                event_id=str(uuid.uuid4()),
                service_name="training-service",
                timestamp=datetime.utcnow(),
                data={
                    "job_id": str(training_job.id),
                    "tenant_id": tenant_id,
                    "requested_by": user_data.get("user_id"),
                    "training_days": request.training_days
                }
            ).__dict__
        )

        logger.info(f"Training job started: {training_job.id} for tenant: {tenant_id}")

        return TrainingJobResponse(
            id=str(training_job.id),
            tenant_id=tenant_id,
            status=training_job.status,
            progress=training_job.progress,
            current_step=training_job.current_step,
            started_at=training_job.started_at,
            completed_at=training_job.completed_at,
            duration_seconds=training_job.duration_seconds,
            models_trained=training_job.models_trained,
            metrics=training_job.metrics,
            error_message=training_job.error_message
        )

    async def get_training_status(self, job_id: str, user_data: dict, db: AsyncSession) -> TrainingJobResponse:
        """Get training job status"""

        tenant_id = user_data.get("tenant_id")

        result = await db.execute(
            select(TrainingJob).where(
                and_(
                    TrainingJob.id == job_id,
                    TrainingJob.tenant_id == tenant_id
                )
            )
        )

        job = result.scalar_one_or_none()
        if not job:
            raise ValueError("Training job not found")

        return TrainingJobResponse(
            id=str(job.id),
            tenant_id=str(job.tenant_id),
            status=job.status,
            progress=job.progress,
            current_step=job.current_step,
            started_at=job.started_at,
            completed_at=job.completed_at,
            duration_seconds=job.duration_seconds,
            models_trained=job.models_trained,
            metrics=job.metrics,
            error_message=job.error_message
        )

    async def get_trained_models(self, user_data: dict, db: AsyncSession) -> List[TrainedModelResponse]:
        """Get trained models for tenant"""

        tenant_id = user_data.get("tenant_id")

        result = await db.execute(
            select(TrainedModel).where(
                and_(
                    TrainedModel.tenant_id == tenant_id,
                    TrainedModel.is_active == True
                )
            ).order_by(TrainedModel.created_at.desc())
        )

        models = result.scalars().all()

        return [
            TrainedModelResponse(
                id=str(model.id),
                product_name=model.product_name,
                model_type=model.model_type,
                model_version=model.model_version,
                mape=model.mape,
                rmse=model.rmse,
                mae=model.mae,
                r2_score=model.r2_score,
                training_samples=model.training_samples,
                features_used=model.features_used,
                is_active=model.is_active,
                created_at=model.created_at,
                last_used_at=model.last_used_at
            )
            for model in models
        ]

    async def get_training_jobs(self, user_data: dict, limit: int, offset: int, db: AsyncSession) -> List[TrainingJobResponse]:
        """Get training jobs for tenant"""

        tenant_id = user_data.get("tenant_id")

        result = await db.execute(
            select(TrainingJob).where(
                TrainingJob.tenant_id == tenant_id
            ).order_by(TrainingJob.created_at.desc())
            .limit(limit)
            .offset(offset)
        )

        jobs = result.scalars().all()

        return [
            TrainingJobResponse(
                id=str(job.id),
                tenant_id=str(job.tenant_id),
                status=job.status,
                progress=job.progress,
                current_step=job.current_step,
                started_at=job.started_at,
                completed_at=job.completed_at,
                duration_seconds=job.duration_seconds,
                models_trained=job.models_trained,
                metrics=job.metrics,
                error_message=job.error_message
            )
            for job in jobs
        ]

    async def _get_running_job(self, tenant_id: str, db: AsyncSession) -> Optional[TrainingJob]:
        """Get running training job for tenant"""

        result = await db.execute(
            select(TrainingJob).where(
                and_(
                    TrainingJob.tenant_id == tenant_id,
                    TrainingJob.status.in_(["queued", "running"])
                )
            )
        )
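
        # scalar_one_or_none() raises MultipleResultsFound if more than one queued/running job
        # exists for the tenant; the guard in start_training normally prevents that.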
        return result.scalar_one_or_none()

    async def _execute_training(self, job_id: str, request: TrainingRequest, db: AsyncSession):
        """Execute training job"""

        start_time = datetime.utcnow()

        try:
            # Update job status
            await self._update_job_status(job_id, "running", 0, "Starting training...", db)

            # Get training data
            await self._update_job_status(job_id, "running", 10, "Fetching training data...", db)
            training_data = await self._get_training_data(job_id, request, db)

            # Train models
            await self._update_job_status(job_id, "running", 30, "Training models...", db)
            models_result = await self.ml_trainer.train_models(training_data, job_id, db)

            # Validate models
            await self._update_job_status(job_id, "running", 80, "Validating models...", db)
            validation_result = await self.ml_trainer.validate_models(models_result, db)

            # Save models
            await self._update_job_status(job_id, "running", 90, "Saving models...", db)
            await self._save_trained_models(job_id, models_result, validation_result, db)

            # Complete job
            duration = int((datetime.utcnow() - start_time).total_seconds())
            await self._complete_job(job_id, models_result, validation_result, duration, db)

            # Publish completion event
            await message_publisher.publish_event(
                "training_events",
                "training.completed",
                TrainingCompletedEvent(
                    event_id=str(uuid.uuid4()),
                    service_name="training-service",
                    timestamp=datetime.utcnow(),
                    data={
                        "job_id": str(job_id),
                        "models_trained": len(models_result),
                        "duration_seconds": duration
                    }
                ).__dict__
            )

            logger.info(f"Training job completed: {job_id}")

        except Exception as e:
            logger.error(f"Training job failed: {job_id} - {e}")

            # Update job as failed
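            # (progress resets to 0; completed_at and duration_seconds stay unset for failed runs)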
            await self._update_job_status(job_id, "failed", 0, f"Training failed: {str(e)}", db)

            # Publish failure event
            await message_publisher.publish_event(
                "training_events",
                "training.failed",
                TrainingFailedEvent(
                    event_id=str(uuid.uuid4()),
                    service_name="training-service",
                    timestamp=datetime.utcnow(),
                    data={
                        "job_id": str(job_id),
                        "error": str(e)
                    }
                ).__dict__
            )

    async def _update_job_status(self, job_id: str, status: str, progress: int, current_step: str, db: AsyncSession):
        """Update training job status"""

        await db.execute(
            update(TrainingJob)
            .where(TrainingJob.id == job_id)
            .values(
                status=status,
                progress=progress,
                current_step=current_step,
                updated_at=datetime.utcnow()
            )
        )
        await db.commit()

    async def _get_training_data(self, job_id: str, request: TrainingRequest, db: AsyncSession) -> Dict[str, Any]:
        """Get training data from data service"""

        # Get job details
        result = await db.execute(
            select(TrainingJob).where(TrainingJob.id == job_id)
        )
        job = result.scalar_one()

        try:
            async with httpx.AsyncClient() as client:
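                # No timeout is passed, so httpx's default (5 seconds) applies; a long-running
                # export on the data service may need an explicit timeout= here.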
                response = await client.get(
                    f"{settings.DATA_SERVICE_URL}/training-data/{job.tenant_id}",
                    params={
                        "from_date": job.training_data_from.isoformat(),
                        "to_date": job.training_data_to.isoformat(),
                        "products": request.products
                    }
                )

                if response.status_code == 200:
                    return response.json()
                else:
                    raise Exception(f"Failed to get training data: {response.status_code}")

        except Exception as e:
            logger.error(f"Error getting training data: {e}")
            raise

    async def _save_trained_models(self, job_id: str, models_result: Dict[str, Any], validation_result: Dict[str, Any], db: AsyncSession):
        """Save trained models to database"""

        # Get job details
        result = await db.execute(
            select(TrainingJob).where(TrainingJob.id == job_id)
        )
        job = result.scalar_one()

        # Deactivate old models
        await db.execute(
            update(TrainedModel)
            .where(TrainedModel.tenant_id == job.tenant_id)
            .values(is_active=False)
        )
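        # The deactivation above is not committed on its own; it shares a transaction with the
        # inserts below, and everything is committed together at the end of this method.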

        # Save new models
        for product_name, model_data in models_result.items():
            validation_data = validation_result.get(product_name, {})

            trained_model = TrainedModel(
                tenant_id=job.tenant_id,
                training_job_id=job_id,
                product_name=product_name,
                model_type=model_data.get("type", "prophet"),
                model_version="1.0",
                model_path=model_data.get("path"),
                mape=validation_data.get("mape"),
                rmse=validation_data.get("rmse"),
                mae=validation_data.get("mae"),
                r2_score=validation_data.get("r2_score"),
                training_samples=model_data.get("training_samples"),
                features_used=model_data.get("features", []),
                hyperparameters=model_data.get("hyperparameters", {}),
                is_active=True
            )

            db.add(trained_model)

        await db.commit()

    async def _complete_job(self, job_id: str, models_result: Dict[str, Any], validation_result: Dict[str, Any], duration: int, db: AsyncSession):
        """Complete training job"""

        # Calculate metrics
        metrics = {
            "models_trained": len(models_result),
            "average_mape": (
                sum(v.get("mape", 0) for v in validation_result.values()) / len(validation_result)
                if validation_result else 0
            ),
            "training_duration": duration,
            "validation_results": validation_result
        }
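        # Products without a "mape" in their validation results count as 0 in average_mape
        # above, which pulls the average down; filtering those out first may be preferable.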

        await db.execute(
            update(TrainingJob)
            .where(TrainingJob.id == job_id)
            .values(
                status="completed",
                progress=100,
                current_step="Training completed successfully",
                completed_at=datetime.utcnow(),
                duration_seconds=duration,
                models_trained=models_result,
                metrics=metrics,
                products_count=len(models_result)
            )
        )
        await db.commit()
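

# Illustrative usage only (assumed wiring): the FastAPI route below and the get_current_user /
# get_db dependencies are not defined in this module; they sketch how the service is typically
# called from an HTTP layer.
#
#     from fastapi import APIRouter, Depends
#
#     router = APIRouter()
#     service = TrainingService()
#
#     @router.post("/training/start", response_model=TrainingJobResponse)
#     async def start_training(request: TrainingRequest, user=Depends(get_current_user), db=Depends(get_db)):
#         return await service.start_training(request, user, db)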