# services/training/app/services/training_service.py
"""
Training service business logic.

Orchestrates ML training operations and manages the training job lifecycle.
"""

import logging
from datetime import datetime
from typing import Any, Dict, List, Optional

import httpx
from sqlalchemy import and_, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.ml.trainer import BakeryMLTrainer
from app.models.training import ModelTrainingLog, TrainedModel
from app.schemas.training import SingleProductTrainingRequest, TrainingJobRequest
from app.services.data_client import DataServiceClient
from app.services.messaging import publish_job_completed, publish_job_failed
from shared.monitoring.metrics import MetricsCollector

logger = logging.getLogger(__name__)
metrics = MetricsCollector("training-service")


class TrainingService:
    """
    Main service class for managing ML training operations.

    Replaces the old Celery-based training system with a clean async
    implementation.
    """

    def __init__(self):
        self.ml_trainer = BakeryMLTrainer()
        self.data_client = DataServiceClient()
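
    # Illustrative wiring (an assumption, not part of this module): the HTTP
    # layer typically creates the job record first, then hands the
    # long-running work to FastAPI's BackgroundTasks. `router`, `get_db`, and
    # the route path below are hypothetical names.
    #
    #   @router.post("/tenants/{tenant_id}/training")
    #   async def start_training(tenant_id: str,
    #                            request: TrainingJobRequest,
    #                            background_tasks: BackgroundTasks,
    #                            db: AsyncSession = Depends(get_db)):
    #       service = TrainingService()
    #       job_id = str(uuid.uuid4())  # assumes `import uuid`
    #       await service.create_training_job(db, tenant_id, job_id, request.dict())
    #       background_tasks.add_task(
    #           service.execute_training_job_simple, job_id, tenant_id, request
    #       )
    #       return {"job_id": job_id, "status": "pending"}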

    async def execute_training_job_simple(self, job_id: str, tenant_id_str: str, request: TrainingJobRequest):
        """Simple wrapper that creates its own database session."""
        try:
            # Import database_manager locally to avoid circular imports
            from app.core.database import database_manager

            logger.info(f"Starting background training job {job_id} for tenant {tenant_id_str}")

            # Create new session for background task
            async with database_manager.async_session_local() as session:
                await self.execute_training_job(session, job_id, tenant_id_str, request)
                await session.commit()

        except Exception as e:
            logger.error(f"Background training job {job_id} failed: {str(e)}")

            # Try to update job status to failed
            try:
                from app.core.database import database_manager
                async with database_manager.async_session_local() as error_session:
                    await self._update_job_status(
                        error_session, job_id, "failed", 0,
                        f"Training failed: {str(e)}", error_message=str(e)
                    )
                    await error_session.commit()
            except Exception as update_error:
                logger.error(f"Failed to update job status: {str(update_error)}")

            raise

    async def create_training_job(self,
                                  db: AsyncSession,
                                  tenant_id: str,
                                  job_id: str,
                                  config: Dict[str, Any]) -> ModelTrainingLog:
        """Create a new training job record."""
        try:
            training_log = ModelTrainingLog(
                job_id=job_id,
                tenant_id=tenant_id,
                status="pending",
                progress=0,
                current_step="Initializing training job",
                start_time=datetime.now(),
                config=config
            )

            db.add(training_log)
            await db.commit()
            await db.refresh(training_log)

            logger.info(f"Created training job {job_id} for tenant {tenant_id}")
            return training_log

        except Exception as e:
            logger.error(f"Failed to create training job: {str(e)}")
            await db.rollback()
            raise

    async def create_single_product_job(self,
                                        db: AsyncSession,
                                        tenant_id: str,
                                        product_name: str,
                                        job_id: str,
                                        config: Dict[str, Any]) -> ModelTrainingLog:
        """Create a training job for a single product."""
        try:
            config["single_product"] = product_name

            training_log = ModelTrainingLog(
                job_id=job_id,
                tenant_id=tenant_id,
                status="pending",
                progress=0,
                current_step=f"Initializing training for {product_name}",
                start_time=datetime.now(),
                config=config
            )

            db.add(training_log)
            await db.commit()
            await db.refresh(training_log)

            logger.info(f"Created single product training job {job_id} for {product_name}")
            return training_log

        except Exception as e:
            logger.error(f"Failed to create single product training job: {str(e)}")
            await db.rollback()
            raise

    async def execute_training_job(self,
                                   db: AsyncSession,
                                   job_id: str,
                                   tenant_id: str,
                                   request: TrainingJobRequest):
        """Execute a complete training job."""
        try:
            logger.info(f"Starting execution of training job {job_id}")

            # Update job status to running
            await self._update_job_status(db, job_id, "running", 5, "Fetching training data")

            # Fetch sales data from data service
            sales_data = await self.data_client.fetch_sales_data(tenant_id)

            # Fetch external data if requested
            weather_data = []
            traffic_data = []

            if request.include_weather:
                await self._update_job_status(db, job_id, "running", 15, "Fetching weather data")
                weather_data = await self.data_client.fetch_weather_data(tenant_id)

            if request.include_traffic:
                await self._update_job_status(db, job_id, "running", 25, "Fetching traffic data")
                traffic_data = await self.data_client.fetch_traffic_data(tenant_id)

            # Execute ML training
            await self._update_job_status(db, job_id, "running", 35, "Processing training data")

            training_results = await self.ml_trainer.train_tenant_models(
                tenant_id=tenant_id,
                sales_data=sales_data,
                weather_data=weather_data,
                traffic_data=traffic_data,
                job_id=job_id
            )

            await self._update_job_status(db, job_id, "running", 85, "Storing trained models")

            # Store trained models in database
            await self._store_trained_models(db, tenant_id, training_results)

            await self._update_job_status(
                db, job_id, "completed", 100, "Training completed successfully",
                results=training_results
            )

            # Publish completion event
            await publish_job_completed(job_id, tenant_id, training_results)

            logger.info(f"Training results: {training_results}")
            logger.info(f"Training job {job_id} completed successfully")
            metrics.increment_counter("training_jobs_completed")

        except Exception as e:
            logger.error(f"Training job {job_id} failed: {str(e)}")
            await self._update_job_status(
                db, job_id, "failed", 0, f"Training failed: {str(e)}",
                error_message=str(e)
            )

            # Publish failure event
            await publish_job_failed(job_id, tenant_id, str(e))

            metrics.increment_counter("training_jobs_failed")
            raise
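
    # Shape assumed for `training_results` above (inferred from how
    # `_store_trained_models` and `get_training_logs` consume it; an
    # illustrative sketch, not a guaranteed schema):
    #
    #   {
    #       "products_trained": 3,
    #       "products_failed": 0,
    #       "training_results": {
    #           "croissant": {
    #               "status": "success",
    #               "model_info": {
    #                   "model_id": "...",
    #                   "type": "prophet",
    #                   "model_path": "...",
    #                   "training_samples": 365,
    #                   "features": ["..."],
    #                   "hyperparameters": {},
    #                   "training_metrics": {},
    #                   "data_period": {"start_date": "...", "end_date": "..."}
    #               }
    #           }
    #       }
    #   }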

    async def execute_single_product_training(self,
                                              db: AsyncSession,
                                              job_id: str,
                                              tenant_id: str,
                                              product_name: str,
                                              request: SingleProductTrainingRequest):
        """Execute training for a single product."""
        try:
            logger.info(f"Starting single product training {job_id} for {product_name}")

            # Update job status
            await self._update_job_status(db, job_id, "running", 10, f"Fetching data for {product_name}")

            # Fetch data
            sales_data = await self._fetch_product_sales_data(tenant_id, product_name, request)
            weather_data = []
            traffic_data = []

            if request.include_weather:
                await self._update_job_status(db, job_id, "running", 30, "Fetching weather data")
                weather_data = await self.data_client.fetch_weather_data(tenant_id, request)

            if request.include_traffic:
                await self._update_job_status(db, job_id, "running", 50, "Fetching traffic data")
                traffic_data = await self.data_client.fetch_traffic_data(tenant_id, request)

            # Execute training
            await self._update_job_status(db, job_id, "running", 70, f"Training model for {product_name}")

            training_result = await self.ml_trainer.train_single_product(
                tenant_id=tenant_id,
                product_name=product_name,
                sales_data=sales_data,
                weather_data=weather_data,
                traffic_data=traffic_data,
                job_id=job_id
            )

            # Store model
            await self._update_job_status(db, job_id, "running", 90, "Storing trained model")
            await self._store_single_trained_model(db, tenant_id, product_name, training_result)

            await self._update_job_status(
                db, job_id, "completed", 100, f"Training completed for {product_name}",
                results=training_result
            )

            logger.info(f"Single product training {job_id} completed successfully")
            metrics.increment_counter("single_product_training_completed")

        except Exception as e:
            logger.error(f"Single product training {job_id} failed: {str(e)}")
            await self._update_job_status(
                db, job_id, "failed", 0, f"Training failed: {str(e)}",
                error_message=str(e)
            )
            metrics.increment_counter("single_product_training_failed")
            raise

    async def get_job_status(self,
                             db: AsyncSession,
                             job_id: str,
                             tenant_id: str) -> Optional[ModelTrainingLog]:
        """Get training job status."""
        try:
            result = await db.execute(
                select(ModelTrainingLog).where(
                    and_(
                        ModelTrainingLog.job_id == job_id,
                        ModelTrainingLog.tenant_id == tenant_id
                    )
                )
            )
            return result.scalar_one_or_none()

        except Exception as e:
            logger.error(f"Failed to get job status: {str(e)}")
            return None

    async def list_training_jobs(self,
                                 db: AsyncSession,
                                 tenant_id: str,
                                 limit: int = 10,
                                 status_filter: Optional[str] = None) -> List[ModelTrainingLog]:
        """List training jobs for a tenant, newest first."""
        try:
            query = select(ModelTrainingLog).where(
                ModelTrainingLog.tenant_id == tenant_id
            )

            # Apply the status filter before ordering and limiting, so the
            # limit applies to the filtered result set
            if status_filter:
                query = query.where(ModelTrainingLog.status == status_filter)

            query = query.order_by(ModelTrainingLog.start_time.desc()).limit(limit)

            result = await db.execute(query)
            return result.scalars().all()

        except Exception as e:
            logger.error(f"Failed to list training jobs: {str(e)}")
            return []

    async def cancel_training_job(self,
                                  db: AsyncSession,
                                  job_id: str,
                                  tenant_id: str) -> bool:
        """Cancel a pending or running training job."""
        try:
            result = await db.execute(
                update(ModelTrainingLog)
                .where(
                    and_(
                        ModelTrainingLog.job_id == job_id,
                        ModelTrainingLog.tenant_id == tenant_id,
                        ModelTrainingLog.status.in_(["pending", "running"])
                    )
                )
                .values(
                    status="cancelled",
                    end_time=datetime.now(),
                    current_step="Training cancelled by user"
                )
            )

            await db.commit()

            if result.rowcount > 0:
                logger.info(f"Cancelled training job {job_id}")
                return True
            else:
                logger.warning(f"Could not cancel training job {job_id} - not found or not cancellable")
                return False

        except Exception as e:
            logger.error(f"Failed to cancel training job: {str(e)}")
            await db.rollback()
            return False

    async def validate_training_data(self,
                                     db: AsyncSession,
                                     tenant_id: str,
                                     config: Dict[str, Any]) -> Dict[str, Any]:
        """Validate training data before starting a job."""
        try:
            logger.info(f"Validating training data for tenant {tenant_id}")

            issues = []
            recommendations = []

            # Fetch a sample of sales data to validate
            sales_data = await self._fetch_sales_data(tenant_id, config, limit=1000)

            if not sales_data:
                issues.append("No sales data found for tenant")
                return {
                    "is_valid": False,
                    "issues": issues,
                    "recommendations": ["Upload sales data before training"],
                    "estimated_time_minutes": 0
                }

            # Analyze data quality (ignore records with no product name)
            products = {item.get("product_name") for item in sales_data if item.get("product_name")}
            total_records = len(sales_data)

            # Check for sufficient data per product
            product_counts = {}
            for item in sales_data:
                product = item.get("product_name")
                if product:
                    product_counts[product] = product_counts.get(product, 0) + 1

            insufficient_products = [
                product for product, count in product_counts.items()
                if count < config.get("min_data_points", 30)
            ]

            if insufficient_products:
                issues.append(f"Insufficient data for products: {', '.join(insufficient_products)}")
                recommendations.append("Collect more historical data for these products")

            # Estimate training time
            valid_products = len(products) - len(insufficient_products)
            estimated_time = max(5, valid_products * 2)  # 2 minutes per product minimum
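            # e.g. 12 products with enough data -> max(5, 12 * 2) = 24 minutes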

            is_valid = len(issues) == 0

            return {
                "is_valid": is_valid,
                "issues": issues,
                "recommendations": recommendations,
                "estimated_time_minutes": estimated_time,
                "products_analyzed": len(products),
                "total_data_points": total_records
            }

        except Exception as e:
            logger.error(f"Failed to validate training data: {str(e)}")
            return {
                "is_valid": False,
                "issues": [f"Validation error: {str(e)}"],
                "recommendations": ["Check data service connectivity"],
                "estimated_time_minutes": 0
            }

    async def _update_job_status(self,
                                 db: AsyncSession,
                                 job_id: str,
                                 status: str,
                                 progress: int,
                                 current_step: str,
                                 results: Optional[Dict] = None,
                                 error_message: Optional[str] = None):
        """Update training job status."""
        try:
            update_values = {
                "status": status,
                "progress": progress,
                "current_step": current_step
            }

            if status == "completed":
                update_values["end_time"] = datetime.now()

            if results:
                update_values["results"] = results

            if error_message:
                update_values["error_message"] = error_message
                update_values["end_time"] = datetime.now()

            await db.execute(
                update(ModelTrainingLog)
                .where(ModelTrainingLog.job_id == job_id)
                .values(**update_values)
            )

            await db.commit()

        except Exception as e:
            logger.error(f"Failed to update job status: {str(e)}")
            await db.rollback()

    async def _fetch_sales_data(self,
                                tenant_id: str,
                                request: Any,
                                limit: Optional[int] = None) -> List[Dict]:
        """Fetch sales data from the data service."""
        try:
            # Call data service to get sales data
            async with httpx.AsyncClient() as client:
                params = {}
                headers = {
                    "X-Tenant-ID": tenant_id
                }

                if hasattr(request, 'start_date') and request.start_date:
                    params["start_date"] = request.start_date.isoformat()

                if hasattr(request, 'end_date') and request.end_date:
                    params["end_date"] = request.end_date.isoformat()

                if limit:
                    params["limit"] = limit

                response = await client.get(
                    f"{settings.DATA_SERVICE_URL}/api/v1/tenants/{tenant_id}/sales",
                    params=params,
                    headers=headers,
                    timeout=30.0
                )

                if response.status_code == 200:
                    return response.json().get("sales", [])
                else:
                    logger.error(f"Failed to fetch sales data: {response.status_code}")
                    return []

        except Exception as e:
            logger.error(f"Error fetching sales data: {str(e)}")
            return []
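
    # Sales records returned by the data service are assumed to look roughly
    # like the record below. Only "product_name" is relied on in this module
    # (see validate_training_data); the other keys are illustrative:
    #
    #   {"product_name": "croissant", "date": "2024-01-15", "quantity": 42}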

    async def _fetch_product_sales_data(self,
                                        tenant_id: str,
                                        product_name: str,
                                        request: Any) -> List[Dict]:
        """Fetch sales data for a specific product."""
        try:
            async with httpx.AsyncClient() as client:
                params = {
                    "tenant_id": tenant_id,
                    "product_name": product_name
                }

                if hasattr(request, 'start_date') and request.start_date:
                    params["start_date"] = request.start_date.isoformat()

                if hasattr(request, 'end_date') and request.end_date:
                    params["end_date"] = request.end_date.isoformat()

                response = await client.get(
                    f"{settings.DATA_SERVICE_URL}/api/sales/product/{product_name}",
                    params=params,
                    timeout=30.0
                )

                if response.status_code == 200:
                    return response.json().get("sales", [])
                else:
                    logger.error(f"Failed to fetch product sales data: {response.status_code}")
                    return []

        except Exception as e:
            logger.error(f"Error fetching product sales data: {str(e)}")
            return []

    async def _store_trained_models(self,
                                    db: AsyncSession,
                                    tenant_id: str,
                                    training_results: Dict[str, Any]):
        """Store trained models in the database."""
        try:
            models_to_store = []

            for product_name, result in training_results.get("training_results", {}).items():
                if result.get("status") == "success":
                    model_info = result.get("model_info", {})

                    trained_model = TrainedModel(
                        tenant_id=tenant_id,
                        product_name=product_name,
                        model_id=model_info.get("model_id"),
                        model_type=model_info.get("type", "prophet"),
                        model_path=model_info.get("model_path"),
                        version=1,  # Start with version 1
                        training_samples=model_info.get("training_samples", 0),
                        features=model_info.get("features", []),
                        hyperparameters=model_info.get("hyperparameters", {}),
                        training_metrics=model_info.get("training_metrics", {}),
                        data_period_start=datetime.fromisoformat(
                            model_info.get("data_period", {}).get("start_date", datetime.now().isoformat())
                        ),
                        data_period_end=datetime.fromisoformat(
                            model_info.get("data_period", {}).get("end_date", datetime.now().isoformat())
                        ),
                        created_at=datetime.now(),
                        is_active=True
                    )

                    models_to_store.append(trained_model)

            # Deactivate old models for these products
            if models_to_store:
                product_names = [model.product_name for model in models_to_store]

                await db.execute(
                    update(TrainedModel)
                    .where(
                        and_(
                            TrainedModel.tenant_id == tenant_id,
                            TrainedModel.product_name.in_(product_names),
                            TrainedModel.is_active.is_(True)
                        )
                    )
                    .values(is_active=False)
                )

            # Add new models
            db.add_all(models_to_store)
            await db.commit()

            logger.info(f"Stored {len(models_to_store)} trained models for tenant {tenant_id}")

        except Exception as e:
            logger.error(f"Failed to store trained models: {str(e)}")
            await db.rollback()
            raise

    async def _store_single_trained_model(self,
                                          db: AsyncSession,
                                          tenant_id: str,
                                          product_name: str,
                                          training_result: Dict[str, Any]):
        """Store a single trained model."""
        try:
            if training_result.get("status") == "success":
                model_info = training_result.get("model_info", {})

                # Deactivate old model for this product
                await db.execute(
                    update(TrainedModel)
                    .where(
                        and_(
                            TrainedModel.tenant_id == tenant_id,
                            TrainedModel.product_name == product_name,
                            TrainedModel.is_active.is_(True)
                        )
                    )
                    .values(is_active=False)
                )

                # Create new model record
                trained_model = TrainedModel(
                    tenant_id=tenant_id,
                    product_name=product_name,
                    model_id=model_info.get("model_id"),
                    model_type=model_info.get("type", "prophet"),
                    model_path=model_info.get("model_path"),
                    version=1,
                    training_samples=model_info.get("training_samples", 0),
                    features=model_info.get("features", []),
                    hyperparameters=model_info.get("hyperparameters", {}),
                    training_metrics=model_info.get("training_metrics", {}),
                    data_period_start=datetime.fromisoformat(
                        model_info.get("data_period", {}).get("start_date", datetime.now().isoformat())
                    ),
                    data_period_end=datetime.fromisoformat(
                        model_info.get("data_period", {}).get("end_date", datetime.now().isoformat())
                    ),
                    created_at=datetime.now(),
                    is_active=True
                )

                db.add(trained_model)
                await db.commit()

                logger.info(f"Stored trained model for {product_name}")

        except Exception as e:
            logger.error(f"Failed to store trained model: {str(e)}")
            await db.rollback()
            raise

    async def get_training_logs(self,
                                db: AsyncSession,
                                job_id: str,
                                tenant_id: str) -> Optional[List[str]]:
        """Get detailed training logs for a job."""
        try:
            # For now, return basic log information from the database
            # In a production system, you might store detailed logs separately
            result = await db.execute(
                select(ModelTrainingLog).where(
                    and_(
                        ModelTrainingLog.job_id == job_id,
                        ModelTrainingLog.tenant_id == tenant_id
                    )
                )
            )

            training_log = result.scalar_one_or_none()

            if training_log:
                logs = [
                    f"Job started at: {training_log.start_time}",
                    f"Current status: {training_log.status}",
                    f"Progress: {training_log.progress}%",
                    f"Current step: {training_log.current_step}"
                ]

                if training_log.end_time:
                    logs.append(f"Job completed at: {training_log.end_time}")

                if training_log.error_message:
                    logs.append(f"Error: {training_log.error_message}")

                if training_log.results:
                    results = training_log.results
                    logs.append(f"Models trained: {results.get('products_trained', 0)}")
                    logs.append(f"Models failed: {results.get('products_failed', 0)}")

                return logs

            return None

        except Exception as e:
            logger.error(f"Failed to get training logs: {str(e)}")
            return None
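
# A module-level instance is a common way to share the service across API
# routers (an assumption about the wider application, not required by this
# file):
#
#   training_service = TrainingService()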