Files
bakery-ia/services/training/app/services/training_service.py
Urtzi Alfaro 4684235111 Fix data fetch
2025-07-27 19:30:42 +02:00

671 lines
28 KiB
Python

# services/training/app/services/training_service.py
"""
Training service business logic
Orchestrates ML training operations and manages job lifecycle
"""
import asyncio
import logging
import uuid
from collections import Counter
from collections.abc import Mapping
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from urllib.parse import quote

import httpx
from sqlalchemy import select, update, and_
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.training import ModelTrainingLog, TrainedModel
from app.ml.trainer import BakeryMLTrainer
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
from app.services.messaging import publish_job_completed, publish_job_failed
from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector
from app.services.data_client import DataServiceClient
# Module-wide logger, named after this module.
logger = logging.getLogger(name=__name__)
# Service-tagged metrics collector; TrainingService increments job counters on it.
metrics = MetricsCollector("training-service")
class TrainingService:
    """
    Main service class for managing ML training operations.

    Replaces the old Celery-based training system with a clean async
    implementation. Most coroutines take an ``AsyncSession`` from the caller;
    ``execute_training_job_simple`` opens its own session so it can run as a
    background task.
    """

    # Statuses that end a job's lifecycle; moving to any of them stamps end_time.
    _TERMINAL_STATUSES = frozenset({"completed", "failed", "cancelled"})

    def __init__(self):
        self.ml_trainer = BakeryMLTrainer()
        self.data_client = DataServiceClient()

    async def execute_training_job_simple(self, job_id: str, tenant_id_str: str, request: TrainingJobRequest):
        """Run a full training job in a background task using its own DB session.

        Args:
            job_id: ID of the ``ModelTrainingLog`` row to drive.
            tenant_id_str: Tenant identifier as a string.
            request: Training request forwarded to ``execute_training_job``.

        Raises:
            Exception: Whatever the training run raised, re-raised after a
                best-effort attempt to mark the job failed from a fresh session.
        """
        try:
            # Import database_manager locally to avoid circular imports
            from app.core.database import database_manager

            logger.info(f"Starting background training job {job_id} for tenant {tenant_id_str}")
            # Create a new session for the background task
            async with database_manager.async_session_local() as session:
                await self.execute_training_job(session, job_id, tenant_id_str, request)
                await session.commit()
        except Exception as e:
            # Keep the traceback: nobody awaits this task's result, so the log is
            # the only place the stack trace survives.
            logger.exception(f"Background training job {job_id} failed: {str(e)}")
            # Best effort: record the failure from a fresh session, because the
            # session used above may be in an unusable state.
            try:
                from app.core.database import database_manager

                async with database_manager.async_session_local() as error_session:
                    await self._update_job_status(
                        error_session, job_id, "failed", 0,
                        f"Training failed: {str(e)}", error_message=str(e)
                    )
                    await error_session.commit()
            except Exception as update_error:
                logger.error(f"Failed to update job status: {str(update_error)}")
            raise

    async def create_training_job(self,
                                  db: AsyncSession,
                                  tenant_id: str,
                                  job_id: str,
                                  config: Dict[str, Any]) -> ModelTrainingLog:
        """Create and commit a new ``pending`` training job record.

        Args:
            db: Session used to insert and commit the row.
            tenant_id: Tenant that owns the job.
            job_id: Identifier to assign to the job.
            config: Job configuration stored on the row.

        Returns:
            The persisted ``ModelTrainingLog`` (refreshed from the database).

        Raises:
            Exception: Any database error, after the session is rolled back.
        """
        try:
            training_log = ModelTrainingLog(
                job_id=job_id,
                tenant_id=tenant_id,
                status="pending",
                progress=0,
                current_step="Initializing training job",
                start_time=datetime.now(),
                config=config
            )
            db.add(training_log)
            await db.commit()
            await db.refresh(training_log)
            logger.info(f"Created training job {job_id} for tenant {tenant_id}")
            return training_log
        except Exception as e:
            logger.error(f"Failed to create training job: {str(e)}")
            await db.rollback()
            raise

    async def create_single_product_job(self,
                                        db: AsyncSession,
                                        tenant_id: str,
                                        product_name: str,
                                        job_id: str,
                                        config: Dict[str, Any]) -> ModelTrainingLog:
        """Create and commit a ``pending`` training job for a single product.

        The stored config is a copy of ``config`` with ``single_product`` set to
        ``product_name``. The caller's dict is left unmodified.

        Args:
            db: Session used to insert and commit the row.
            tenant_id: Tenant that owns the job.
            product_name: Product to train.
            job_id: Identifier to assign to the job.
            config: Base job configuration (``None`` is treated as empty).

        Returns:
            The persisted ``ModelTrainingLog`` (refreshed from the database).

        Raises:
            Exception: Any database error, after the session is rolled back.
        """
        try:
            # Copy instead of writing into the caller's dict.
            job_config = {**(config or {}), "single_product": product_name}
            training_log = ModelTrainingLog(
                job_id=job_id,
                tenant_id=tenant_id,
                status="pending",
                progress=0,
                current_step=f"Initializing training for {product_name}",
                start_time=datetime.now(),
                config=job_config
            )
            db.add(training_log)
            await db.commit()
            await db.refresh(training_log)
            logger.info(f"Created single product training job {job_id} for {product_name}")
            return training_log
        except Exception as e:
            logger.error(f"Failed to create single product training job: {str(e)}")
            await db.rollback()
            raise

    async def execute_training_job(self,
                                   db: AsyncSession,
                                   job_id: str,
                                   tenant_id: str,
                                   request: TrainingJobRequest):
        """Execute a complete multi-product training job.

        Fetches sales data, plus weather and/or traffic data when the request
        asks for them. Trains models via ``train_tenant_models``, stores them,
        and records progress on the job row throughout.

        Args:
            db: Session used for job-status updates and model storage.
            job_id: ID of the ``ModelTrainingLog`` row to update.
            tenant_id: Tenant whose data is trained on.
            request: Read for ``include_weather`` / ``include_traffic``.

        Raises:
            Exception: Any fetch, training or storage error. Before re-raising,
                the job is marked failed and a failure-event publish is attempted.
        """
        try:
            logger.info(f"Starting execution of training job {job_id}")

            await self._update_job_status(db, job_id, "running", 5, "Fetching training data")
            sales_data = await self.data_client.fetch_sales_data(tenant_id)

            # External data is optional and only fetched when requested.
            weather_data = []
            traffic_data = []
            if request.include_weather:
                await self._update_job_status(db, job_id, "running", 15, "Fetching weather data")
                weather_data = await self.data_client.fetch_weather_data(tenant_id)
            if request.include_traffic:
                await self._update_job_status(db, job_id, "running", 25, "Fetching traffic data")
                traffic_data = await self.data_client.fetch_traffic_data(tenant_id)

            await self._update_job_status(db, job_id, "running", 35, "Processing training data")
            training_results = await self.ml_trainer.train_tenant_models(
                tenant_id=tenant_id,
                sales_data=sales_data,
                weather_data=weather_data,
                traffic_data=traffic_data,
                job_id=job_id
            )

            await self._update_job_status(db, job_id, "running", 85, "Storing trained models")
            await self._store_trained_models(db, tenant_id, training_results)

            await self._update_job_status(
                db, job_id, "completed", 100, "Training completed successfully",
                results=training_results
            )
        except Exception as e:
            logger.error(f"Training job {job_id} failed: {str(e)}")
            await self._update_job_status(
                db, job_id, "failed", 0, f"Training failed: {str(e)}",
                error_message=str(e)
            )
            # A failing message bus must not replace the original training error.
            try:
                await publish_job_failed(job_id, tenant_id, str(e))
            except Exception as publish_error:
                logger.error(f"Failed to publish failure event for job {job_id}: {str(publish_error)}")
            metrics.increment_counter("training_jobs_failed")
            raise

        # Published outside the try above: the models are stored and the job is
        # already marked completed, so a notification failure must not flip it
        # to "failed".
        try:
            await publish_job_completed(job_id, tenant_id, training_results)
        except Exception as publish_error:
            logger.error(f"Failed to publish completion event for job {job_id}: {str(publish_error)}")

        logger.info(f"Training results {training_results}")
        logger.info(f"Training job {job_id} completed successfully")
        metrics.increment_counter("training_jobs_completed")

    async def execute_single_product_training(self,
                                              db: AsyncSession,
                                              job_id: str,
                                              tenant_id: str,
                                              product_name: str,
                                              request: SingleProductTrainingRequest):
        """Execute training for a single product.

        Args:
            db: Session used for job-status updates and model storage.
            job_id: ID of the ``ModelTrainingLog`` row to update.
            tenant_id: Tenant whose data is trained on.
            product_name: Product to train.
            request: Read for ``include_weather`` / ``include_traffic`` and
                forwarded to the data fetchers (which may read date bounds).

        Raises:
            Exception: Any fetch, training or storage error, after the job is
                marked failed.
        """
        try:
            logger.info(f"Starting single product training {job_id} for {product_name}")

            await self._update_job_status(db, job_id, "running", 10, f"Fetching data for {product_name}")
            sales_data = await self._fetch_product_sales_data(tenant_id, product_name, request)

            weather_data = []
            traffic_data = []
            if request.include_weather:
                await self._update_job_status(db, job_id, "running", 30, "Fetching weather data")
                weather_data = await self.data_client.fetch_weather_data(tenant_id, request)
            if request.include_traffic:
                await self._update_job_status(db, job_id, "running", 50, "Fetching traffic data")
                traffic_data = await self.data_client.fetch_traffic_data(tenant_id, request)

            await self._update_job_status(db, job_id, "running", 70, f"Training model for {product_name}")
            training_result = await self.ml_trainer.train_single_product(
                tenant_id=tenant_id,
                product_name=product_name,
                sales_data=sales_data,
                weather_data=weather_data,
                traffic_data=traffic_data,
                job_id=job_id
            )

            await self._update_job_status(db, job_id, "running", 90, "Storing trained model")
            await self._store_single_trained_model(db, tenant_id, product_name, training_result)

            await self._update_job_status(
                db, job_id, "completed", 100, f"Training completed for {product_name}",
                results=training_result
            )
            logger.info(f"Single product training {job_id} completed successfully")
            metrics.increment_counter("single_product_training_completed")
        except Exception as e:
            logger.error(f"Single product training {job_id} failed: {str(e)}")
            await self._update_job_status(
                db, job_id, "failed", 0, f"Training failed: {str(e)}",
                error_message=str(e)
            )
            metrics.increment_counter("single_product_training_failed")
            raise

    async def get_job_status(self,
                             db: AsyncSession,
                             job_id: str,
                             tenant_id: str) -> Optional[ModelTrainingLog]:
        """Return the job row matching ``job_id`` and ``tenant_id``.

        Returns:
            The matching ``ModelTrainingLog``, or ``None`` if there is no match
            or the query fails (the error is logged).
        """
        try:
            result = await db.execute(
                select(ModelTrainingLog).where(
                    and_(
                        ModelTrainingLog.job_id == job_id,
                        ModelTrainingLog.tenant_id == tenant_id
                    )
                )
            )
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error(f"Failed to get job status: {str(e)}")
            return None

    async def list_training_jobs(self,
                                 db: AsyncSession,
                                 tenant_id: str,
                                 limit: int = 10,
                                 status_filter: Optional[str] = None) -> List[ModelTrainingLog]:
        """List a tenant's training jobs, newest first.

        Args:
            db: Session used for the query.
            tenant_id: Tenant whose jobs are listed.
            limit: Maximum number of rows to return.
            status_filter: If given, only jobs with this exact status.

        Returns:
            Up to ``limit`` jobs ordered by ``start_time`` descending, or an
            empty list if the query fails (the error is logged).
        """
        try:
            query = select(ModelTrainingLog).where(ModelTrainingLog.tenant_id == tenant_id)
            if status_filter:
                query = query.where(ModelTrainingLog.status == status_filter)
            # Ordering and LIMIT are applied after all filters are in place.
            query = query.order_by(ModelTrainingLog.start_time.desc()).limit(limit)
            result = await db.execute(query)
            return list(result.scalars().all())
        except Exception as e:
            logger.error(f"Failed to list training jobs: {str(e)}")
            return []

    async def cancel_training_job(self,
                                  db: AsyncSession,
                                  job_id: str,
                                  tenant_id: str) -> bool:
        """Mark a pending or running job as cancelled.

        NOTE(review): this only updates the database row. A job that is already
        executing is not interrupted, and its next ``_update_job_status`` call
        (which filters on job_id only) will overwrite the "cancelled" status.

        Returns:
            True if a pending/running job was updated; False if no such job
            exists or the update failed (the session is rolled back).
        """
        try:
            result = await db.execute(
                update(ModelTrainingLog)
                .where(
                    and_(
                        ModelTrainingLog.job_id == job_id,
                        ModelTrainingLog.tenant_id == tenant_id,
                        ModelTrainingLog.status.in_(["pending", "running"])
                    )
                )
                .values(
                    status="cancelled",
                    end_time=datetime.now(),
                    current_step="Training cancelled by user"
                )
            )
            await db.commit()
            if result.rowcount > 0:
                logger.info(f"Cancelled training job {job_id}")
                return True
            logger.warning(f"Could not cancel training job {job_id} - not found or not cancellable")
            return False
        except Exception as e:
            logger.error(f"Failed to cancel training job: {str(e)}")
            await db.rollback()
            return False

    async def validate_training_data(self,
                                     db: AsyncSession,
                                     tenant_id: str,
                                     config: Dict[str, Any]) -> Dict[str, Any]:
        """Validate a sample of the tenant's sales data before starting a job.

        Fetches up to 1000 sales records, honoring ``start_date``/``end_date``
        in ``config`` if present. It then checks that every named product has at
        least ``config["min_data_points"]`` records (default 30).

        Args:
            db: Unused here; kept for interface compatibility.
            tenant_id: Tenant whose data is validated.
            config: Job configuration (date bounds, ``min_data_points``).

        Returns:
            A dict with ``is_valid``, ``issues``, ``recommendations`` and
            ``estimated_time_minutes``. On success it also contains
            ``products_analyzed`` and ``total_data_points``.
        """
        try:
            logger.info(f"Validating training data for tenant {tenant_id}")
            issues = []
            recommendations = []

            sales_data = await self._fetch_sales_data(tenant_id, config, limit=1000)
            if not sales_data:
                issues.append("No sales data found for tenant")
                return {
                    "is_valid": False,
                    "issues": issues,
                    "recommendations": ["Upload sales data before training"],
                    "estimated_time_minutes": 0
                }

            # Records per product. Rows without a product name are skipped so
            # they do not show up as an extra "None" product.
            product_counts = Counter(
                name
                for name in (item.get("product_name") for item in sales_data)
                if name
            )
            total_records = len(sales_data)

            if not product_counts:
                issues.append("Sales data contains no product names")
                recommendations.append("Ensure sales records include a product_name")

            min_points = config.get("min_data_points", 30)
            insufficient_products = [
                product for product, count in product_counts.items()
                if count < min_points
            ]
            if insufficient_products:
                issues.append(f"Insufficient data for products: {', '.join(insufficient_products)}")
                recommendations.append("Collect more historical data for these products")

            # Rough estimate: 2 minutes per trainable product, at least 5 minutes.
            valid_products = len(product_counts) - len(insufficient_products)
            estimated_time = max(5, valid_products * 2)

            return {
                "is_valid": not issues,
                "issues": issues,
                "recommendations": recommendations,
                "estimated_time_minutes": estimated_time,
                "products_analyzed": len(product_counts),
                "total_data_points": total_records
            }
        except Exception as e:
            logger.error(f"Failed to validate training data: {str(e)}")
            return {
                "is_valid": False,
                "issues": [f"Validation error: {str(e)}"],
                "recommendations": ["Check data service connectivity"],
                "estimated_time_minutes": 0
            }

    async def _update_job_status(self,
                                 db: AsyncSession,
                                 job_id: str,
                                 status: str,
                                 progress: int,
                                 current_step: str,
                                 results: Optional[Dict] = None,
                                 error_message: Optional[str] = None):
        """Best-effort update of a job row's status, progress and step, then commit.

        ``end_time`` is stamped when ``status`` is terminal (completed, failed
        or cancelled) or when an ``error_message`` is given. ``results`` and
        ``error_message`` are written only when truthy. Errors are logged and
        rolled back, never raised, so status bookkeeping cannot abort a job.
        """
        try:
            update_values = {
                "status": status,
                "progress": progress,
                "current_step": current_step
            }
            if status in self._TERMINAL_STATUSES or error_message:
                update_values["end_time"] = datetime.now()
            if results:
                update_values["results"] = results
            if error_message:
                update_values["error_message"] = error_message

            await db.execute(
                update(ModelTrainingLog)
                .where(ModelTrainingLog.job_id == job_id)
                .values(**update_values)
            )
            await db.commit()
        except Exception as e:
            logger.error(f"Failed to update job status: {str(e)}")
            await db.rollback()

    @staticmethod
    def _date_range_params(request: Any) -> Dict[str, str]:
        """Build optional ``start_date``/``end_date`` query params.

        ``request`` may be a mapping, such as a job config dict, or an object
        with ``start_date``/``end_date`` attributes, such as a request schema.
        Values with ``isoformat()`` (date/datetime) are serialized with it.
        Other truthy values, e.g. ISO strings from a JSON config, are passed as
        ``str(value)``. Missing or falsy values are omitted.
        """
        params: Dict[str, str] = {}
        for key in ("start_date", "end_date"):
            if isinstance(request, Mapping):
                value = request.get(key)
            else:
                value = getattr(request, key, None)
            if value:
                params[key] = value.isoformat() if hasattr(value, "isoformat") else str(value)
        return params

    async def _fetch_sales_data(self,
                                tenant_id: str,
                                request: Any,
                                limit: Optional[int] = None) -> List[Dict]:
        """Fetch a tenant's sales records from the data service.

        Args:
            tenant_id: Tenant whose sales are fetched (sent in path and header).
            request: Object or mapping that may carry ``start_date``/``end_date``.
            limit: Optional maximum number of records to request.

        Returns:
            The response's ``sales`` list. An empty list is returned on a
            non-200 response or on any error (both are logged).
        """
        try:
            params: Dict[str, Any] = self._date_range_params(request)
            if limit:
                params["limit"] = limit
            headers = {"X-Tenant-ID": tenant_id}

            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{settings.DATA_SERVICE_URL}/api/v1/tenants/{tenant_id}/sales",
                    params=params,
                    headers=headers,
                    timeout=30.0
                )
            if response.status_code == 200:
                return response.json().get("sales", [])
            logger.error(f"Failed to fetch sales data: {response.status_code}")
            return []
        except Exception as e:
            logger.error(f"Error fetching sales data: {str(e)}")
            return []

    async def _fetch_product_sales_data(self,
                                        tenant_id: str,
                                        product_name: str,
                                        request: Any) -> List[Dict]:
        """Fetch sales records for one product from the data service.

        ``product_name`` is percent-encoded in the URL path. Characters such as
        ``?``, ``#`` or spaces in a name therefore cannot corrupt the URL.

        Returns:
            The response's ``sales`` list. An empty list is returned on a
            non-200 response or on any error (both are logged).
        """
        try:
            params = {
                "tenant_id": tenant_id,
                "product_name": product_name,
                **self._date_range_params(request),
            }
            url = f"{settings.DATA_SERVICE_URL}/api/sales/product/{quote(str(product_name), safe='')}"

            async with httpx.AsyncClient() as client:
                response = await client.get(url, params=params, timeout=30.0)
            if response.status_code == 200:
                return response.json().get("sales", [])
            logger.error(f"Failed to fetch product sales data: {response.status_code}")
            return []
        except Exception as e:
            logger.error(f"Error fetching product sales data: {str(e)}")
            return []

    @staticmethod
    def _parse_period_date(value: Any) -> datetime:
        """Convert a data-period boundary to a ``datetime``.

        ``datetime`` values are returned unchanged and missing/empty values
        fall back to ``datetime.now()``. Anything else is parsed with
        ``datetime.fromisoformat(str(value))``, which raises ``ValueError`` on
        malformed input.
        """
        if isinstance(value, datetime):
            return value
        if not value:
            return datetime.now()
        return datetime.fromisoformat(str(value))

    def _build_trained_model(self,
                             tenant_id: str,
                             product_name: str,
                             model_info: Dict[str, Any]) -> TrainedModel:
        """Build an active ``TrainedModel`` row from a trainer's ``model_info`` dict."""
        data_period = model_info.get("data_period") or {}
        return TrainedModel(
            tenant_id=tenant_id,
            product_name=product_name,
            model_id=model_info.get("model_id"),
            model_type=model_info.get("type", "prophet"),
            model_path=model_info.get("model_path"),
            # NOTE(review): always 1; retraining does not bump the version.
            version=1,
            training_samples=model_info.get("training_samples", 0),
            features=model_info.get("features", []),
            hyperparameters=model_info.get("hyperparameters", {}),
            training_metrics=model_info.get("training_metrics", {}),
            data_period_start=self._parse_period_date(data_period.get("start_date")),
            data_period_end=self._parse_period_date(data_period.get("end_date")),
            created_at=datetime.now(),
            is_active=True
        )

    async def _store_trained_models(self,
                                    db: AsyncSession,
                                    tenant_id: str,
                                    training_results: Dict[str, Any]):
        """Persist every successful model from a multi-product training run.

        Reads ``training_results["training_results"]``, a dict keyed by product
        name. Only entries whose ``status`` is ``"success"`` are stored. The
        currently active models for those products are deactivated first, so
        only the new rows remain active.

        Raises:
            Exception: Any database or parsing error, after the session is
                rolled back.
        """
        try:
            models_to_store = [
                self._build_trained_model(tenant_id, product_name, result.get("model_info") or {})
                for product_name, result in training_results.get("training_results", {}).items()
                if result.get("status") == "success"
            ]

            if models_to_store:
                product_names = [model.product_name for model in models_to_store]
                await db.execute(
                    update(TrainedModel)
                    .where(
                        and_(
                            TrainedModel.tenant_id == tenant_id,
                            TrainedModel.product_name.in_(product_names),
                            TrainedModel.is_active == True  # noqa: E712 (SQL expression)
                        )
                    )
                    .values(is_active=False)
                )
                db.add_all(models_to_store)

            await db.commit()
            logger.info(f"Stored {len(models_to_store)} trained models for tenant {tenant_id}")
        except Exception as e:
            logger.error(f"Failed to store trained models: {str(e)}")
            await db.rollback()
            raise

    async def _store_single_trained_model(self,
                                          db: AsyncSession,
                                          tenant_id: str,
                                          product_name: str,
                                          training_result: Dict[str, Any]):
        """Persist the model from a single-product run if its status is ``"success"``.

        The product's currently active model is deactivated before the new row
        is added. Results with any other status are ignored.

        Raises:
            Exception: Any database or parsing error, after the session is
                rolled back.
        """
        try:
            if training_result.get("status") == "success":
                model_info = training_result.get("model_info") or {}

                await db.execute(
                    update(TrainedModel)
                    .where(
                        and_(
                            TrainedModel.tenant_id == tenant_id,
                            TrainedModel.product_name == product_name,
                            TrainedModel.is_active == True  # noqa: E712 (SQL expression)
                        )
                    )
                    .values(is_active=False)
                )
                db.add(self._build_trained_model(tenant_id, product_name, model_info))
                await db.commit()
                logger.info(f"Stored trained model for {product_name}")
        except Exception as e:
            logger.error(f"Failed to store trained model: {str(e)}")
            await db.rollback()
            raise

    async def get_training_logs(self,
                                db: AsyncSession,
                                job_id: str,
                                tenant_id: str) -> Optional[List[str]]:
        """Summarize a job's database row as human-readable log lines.

        Detailed per-step logs are not stored. The lines are built from the
        ``ModelTrainingLog`` fields: start time, status, progress and current
        step. End time, error and model counts are added when present.

        Returns:
            The list of lines, or ``None`` if the job is not found or the query
            fails (the error is logged).
        """
        try:
            result = await db.execute(
                select(ModelTrainingLog).where(
                    and_(
                        ModelTrainingLog.job_id == job_id,
                        ModelTrainingLog.tenant_id == tenant_id
                    )
                )
            )
            training_log = result.scalar_one_or_none()
            if not training_log:
                return None

            logs = [
                f"Job started at: {training_log.start_time}",
                f"Current status: {training_log.status}",
                f"Progress: {training_log.progress}%",
                f"Current step: {training_log.current_step}"
            ]
            if training_log.end_time:
                logs.append(f"Job completed at: {training_log.end_time}")
            if training_log.error_message:
                logs.append(f"Error: {training_log.error_message}")
            if training_log.results:
                results = training_log.results
                logs.append(f"Models trained: {results.get('products_trained', 0)}")
                logs.append(f"Models failed: {results.get('products_failed', 0)}")
            return logs
        except Exception as e:
            logger.error(f"Failed to get training logs: {str(e)}")
            return None