Improve training code
@@ -1,721 +1,303 @@
# services/training/app/services/training_service.py

"""
Training service business logic

Orchestrates ML training operations and manages the job lifecycle.
Main Training Service - coordinates the complete training process.
This is the entry point from the API layer.
"""

from typing import Dict, List, Any, Optional
import logging
from datetime import datetime, timedelta
import asyncio
import uuid

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, and_
import httpx

from app.models.training import ModelTrainingLog, TrainedModel
from app.ml.trainer import BakeryMLTrainer
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
from app.services.messaging import publish_job_completed, publish_job_failed
from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector
from app.services.data_client import DataServiceClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
from app.services.training_orchestrator import TrainingDataOrchestrator

from app.core.database import get_db_session

logger = logging.getLogger(__name__)
metrics = MetricsCollector("training-service")

class TrainingService:
    """
    Main service class for managing ML training operations.
    Replaces the old Celery-based training system with a clean async implementation.
    Coordinates the complete training pipeline - the entry point from the API layer,
    handling business logic and orchestration.
    """

    def __init__(self):
        self.ml_trainer = BakeryMLTrainer()
        self.data_client = DataServiceClient()

    async def _determine_sales_date_range(self, sales_data: List[Dict]) -> tuple[datetime, datetime]:
        """Determine start and end dates from sales data with validation"""
        if not sales_data:
            raise ValueError("No sales data available to determine date range")

        dates = []
        for record in sales_data:
            if 'date' in record:
                try:
                    if isinstance(record['date'], str):
                        # Handle various date string formats
                        date_str = record['date'].replace('Z', '+00:00')
                        if 'T' in date_str:
                            parsed_date = datetime.fromisoformat(date_str)
                        else:
                            # Handle date-only strings
                            parsed_date = datetime.strptime(date_str, '%Y-%m-%d')
                        dates.append(parsed_date)
                    elif isinstance(record['date'], datetime):
                        dates.append(record['date'])
                except (ValueError, AttributeError) as e:
                    logger.warning(f"Invalid date format in record: {record['date']} - {e}")
                    continue

        if not dates:
            raise ValueError("No valid dates found in sales data")

        start_date = min(dates)
        end_date = max(dates)

        # Validate and adjust date range for external APIs
        start_date, end_date = self._adjust_date_range_for_apis(start_date, end_date)

        logger.info(f"Determined and adjusted sales date range: {start_date} to {end_date}")
        return start_date, end_date

    def _adjust_date_range_for_apis(self, start_date: datetime, end_date: datetime) -> tuple[datetime, datetime]:
        """Adjust date range to comply with external API limits"""

        # Weather and traffic APIs have a 90-day limit
        MAX_DAYS = 90

        # Calculate current range
        current_range = (end_date - start_date).days

        if current_range > MAX_DAYS:
            logger.warning(f"Date range ({current_range} days) exceeds API limit ({MAX_DAYS} days). Adjusting...")

            # Keep the most recent data
            start_date = end_date - timedelta(days=MAX_DAYS)
            logger.info(f"Adjusted start_date to {start_date} to fit within {MAX_DAYS} day limit")

        # Ensure dates are not in the future
        now = datetime.now()
        if end_date > now:
            end_date = now.replace(hour=0, minute=0, second=0, microsecond=0)
            logger.info(f"Adjusted end_date to {end_date} (cannot be in future)")

        if start_date > now:
            start_date = now.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=30)
            logger.info(f"Adjusted start_date to {start_date} (was in future)")

        # Ensure start_date is before end_date
        if start_date >= end_date:
            start_date = end_date - timedelta(days=30)  # Default to 30 days of data
            logger.warning(f"start_date was not before end_date. Adjusted start_date to {start_date}")

        return start_date, end_date

    def __init__(self, db_session: AsyncSession = None):
        self.db_session = db_session
        self.trainer = BakeryMLTrainer(db_session=db_session)  # Pass DB session
        self.date_alignment_service = DateAlignmentService()
        self.orchestrator = TrainingDataOrchestrator(
            date_alignment_service=self.date_alignment_service
        )

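    # Illustrative wiring (a sketch, not part of the service): the API layer is
    # expected to build the service with a per-request AsyncSession and then
    # start a job, for example:
    #
    #     async with database_manager.async_session_local() as session:
    #         service = TrainingService(db_session=session)
    #         result = await service.start_training_job(tenant_id="tenant-123")
    #
    # The session factory name mirrors execute_training_job_simple() below; the
    # tenant id is a placeholder.
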
    async def start_training_job(
        self,
        tenant_id: str,
        bakery_location: tuple[float, float] = (40.4168, -3.7038),  # Default Madrid
        requested_start: Optional[datetime] = None,
        requested_end: Optional[datetime] = None,
        job_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Start a complete training job for a tenant.

        Args:
            tenant_id: Tenant identifier
            sales_data: Historical sales data
            bakery_location: Bakery coordinates (lat, lon)
            weather_data: Optional weather data
            traffic_data: Optional traffic data
            requested_start: Optional explicit start date
            requested_end: Optional explicit end date
            job_id: Optional job identifier

        Returns:
            Training job results
        """
        if not job_id:
            job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        logger.info(f"Starting training job {job_id} for tenant {tenant_id}")

    async def execute_training_job_simple(self, job_id: str, tenant_id_str: str, request: TrainingJobRequest):
        """Simple wrapper that creates its own database session"""
        try:
            # Import database_manager locally to avoid circular imports
            from app.core.database import database_manager

            logger.info(f"Starting background training job {job_id} for tenant {tenant_id_str}")

            # Create new session for background task
            async with database_manager.async_session_local() as session:
                await self.execute_training_job(session, job_id, tenant_id_str, request)
                await session.commit()

        except Exception as e:
            logger.error(f"Background training job {job_id} failed: {str(e)}")

            # Try to update job status to failed
            try:
                from app.core.database import database_manager
                async with database_manager.async_session_local() as error_session:
                    await self._update_job_status(
                        error_session, job_id, "failed", 0,
                        f"Training failed: {str(e)}", error_message=str(e)
                    )
                    await error_session.commit()
            except Exception as update_error:
                logger.error(f"Failed to update job status: {str(update_error)}")

            raise

    async def create_training_job(self,
                                  db: AsyncSession,
                                  tenant_id: str,
                                  job_id: str,
                                  config: Dict[str, Any]) -> ModelTrainingLog:
        """Create a new training job record"""
        try:
            training_log = ModelTrainingLog(
                job_id=job_id,
                # Step 1: Prepare training dataset with date alignment and orchestration
                logger.info("Step 1: Preparing and aligning training data")
                training_dataset = await self.orchestrator.prepare_training_data(
                tenant_id=tenant_id,
                status="pending",
                progress=0,
                current_step="Initializing training job",
                start_time=datetime.now(),
                config=config
            )

            db.add(training_log)
            await db.commit()
            await db.refresh(training_log)

            logger.info(f"Created training job {job_id} for tenant {tenant_id}")
            return training_log

        except Exception as e:
            logger.error(f"Failed to create training job: {str(e)}")
            await db.rollback()
            raise

    async def create_single_product_job(self,
                                        db: AsyncSession,
                                        tenant_id: str,
                                        product_name: str,
                                        job_id: str,
                                        config: Dict[str, Any]) -> ModelTrainingLog:
        """Create a training job for a single product"""
        try:
            config["single_product"] = product_name

            training_log = ModelTrainingLog(
                job_id=job_id,
                tenant_id=tenant_id,
                status="pending",
                progress=0,
                current_step=f"Initializing training for {product_name}",
                start_time=datetime.now(),
                config=config
            )

            db.add(training_log)
            await db.commit()
            await db.refresh(training_log)

            logger.info(f"Created single product training job {job_id} for {product_name}")
            return training_log

        except Exception as e:
            logger.error(f"Failed to create single product training job: {str(e)}")
            await db.rollback()
            raise

    async def execute_training_job(self,
                                   db: AsyncSession,
                                   job_id: str,
                                   tenant_id: str,
                                   request: TrainingJobRequest):
        """Execute a complete training job"""
        try:
            logger.info(f"Starting execution of training job {job_id}")

            # Update job status to running
            await self._update_job_status(db, job_id, "running", 5, "Fetching training data")

            # Fetch sales data from data service
            sales_data = await self.data_client.fetch_sales_data(tenant_id)

            if not sales_data:
                raise ValueError("No sales data found for training")

            # Determine date range from sales data
            start_date, end_date = await self._determine_sales_date_range(sales_data)

            # Convert dates to ISO format strings for API calls
            start_date_str = start_date.isoformat()
            end_date_str = end_date.isoformat()

            logger.info(f"Using date range for external APIs: {start_date_str} to {end_date_str}")

            # Fetch external data if requested using the sales date range
            weather_data = []
            traffic_data = []

            await self._update_job_status(db, job_id, "running", 15, "Fetching weather data")
            try:
                weather_data = await self.data_client.fetch_weather_data(
                    tenant_id=tenant_id,
                    start_date=start_date_str,
                    end_date=end_date_str,
                    latitude=40.4168,  # Madrid coordinates
                    longitude=-3.7038
                )
                logger.info(f"Fetched {len(weather_data)} weather records")
            except Exception as e:
                logger.warning(f"Failed to fetch weather data: {e}. Continuing without weather data.")
                weather_data = []

            await self._update_job_status(db, job_id, "running", 25, "Fetching traffic data")
            try:
                traffic_data = await self.data_client.fetch_traffic_data(
                    tenant_id=tenant_id,
                    start_date=start_date_str,
                    end_date=end_date_str,
                    latitude=40.4168,
                    longitude=-3.7038
                )
                logger.info(f"Fetched {len(traffic_data)} traffic records")
            except Exception as e:
                logger.warning(f"Failed to fetch traffic data: {e}. Continuing without traffic data.")
                traffic_data = []

            # Execute ML training
            await self._update_job_status(db, job_id, "running", 35, "Processing training data")

            training_results = await self.ml_trainer.train_tenant_models(
                tenant_id=tenant_id,
                sales_data=sales_data,
                weather_data=weather_data,
                traffic_data=traffic_data,
                bakery_location=bakery_location,
                requested_start=requested_start,
                requested_end=requested_end,
                job_id=job_id
            )

            await self._update_job_status(db, job_id, "running", 85, "Storing trained models")

            # Store trained models in database
            await self._store_trained_models(db, tenant_id, training_results)

            await self._update_job_status(
                db, job_id, "completed", 100, "Training completed successfully",
                results=training_results
            )

            # Step 2: Execute ML training pipeline
            logger.info("Step 2: Starting ML training pipeline")
            training_results = await self.trainer.train_tenant_models(
                tenant_id=tenant_id,
                training_dataset=training_dataset,
                job_id=job_id
            )

            # Publish completion event
            await publish_job_completed(job_id, tenant_id, training_results)

            # Step 3: Compile final results
            final_result = {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "status": "completed",
                "training_results": training_results,
                "data_summary": {
                    "sales_records": len(training_dataset.sales_data),
                    "weather_records": len(training_dataset.weather_data),
                    "traffic_records": len(training_dataset.traffic_data),
                    "date_range": {
                        "start": training_dataset.date_range.start.isoformat(),
                        "end": training_dataset.date_range.end.isoformat()
                    },
                    "data_sources_used": [source.value for source in training_dataset.date_range.available_sources],
                    "constraints_applied": training_dataset.date_range.constraints
                },
                "completed_at": datetime.now().isoformat()
            }

            logger.info(f"Training results {training_results}")
            logger.info(f"Training job {job_id} completed successfully")
            metrics.increment_counter("training_jobs_completed")
            return final_result

        except Exception as e:
            logger.error(f"Training job {job_id} failed: {str(e)}")
            await self._update_job_status(
                db, job_id, "failed", 0, f"Training failed: {str(e)}",
                error_message=str(e)
            )

            # Publish failure event
            await publish_job_failed(job_id, tenant_id, str(e))

            metrics.increment_counter("training_jobs_failed")
            raise
            return {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "status": "failed",
                "error_message": str(e),
                "failed_at": datetime.now().isoformat()
            }

    async def execute_single_product_training(self,
                                              db: AsyncSession,
                                              job_id: str,
                                              tenant_id: str,
                                              product_name: str,
                                              request: SingleProductTrainingRequest):
        """Execute training for a single product"""

    async def start_single_product_training(
        self,
        tenant_id: str,
        product_name: str,
        sales_data: List[Dict[str, Any]],
        bakery_location: tuple[float, float] = (40.4168, -3.7038),
        weather_data: Optional[List[Dict[str, Any]]] = None,
        traffic_data: Optional[List[Dict[str, Any]]] = None,
        job_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Train a model for a single product.

        Args:
            tenant_id: Tenant identifier
            product_name: Product name
            sales_data: Historical sales data
            bakery_location: Bakery coordinates
            weather_data: Optional weather data
            traffic_data: Optional traffic data
            job_id: Optional job identifier

        Returns:
            Single product training result
        """
        if not job_id:
            job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"

        logger.info(f"Starting single product training {job_id} for {product_name}")

        try:
            logger.info(f"Starting single product training {job_id} for {product_name}")
            # Filter sales data for the specific product
            product_sales = [
                record for record in sales_data
                if record.get('product_name') == product_name
            ]

            # Update job status
            await self._update_job_status(db, job_id, "running", 10, f"Fetching data for {product_name}")
            if not product_sales:
                raise ValueError(f"No sales data found for product: {product_name}")

            # Fetch data
            sales_data = await self._fetch_product_sales_data(tenant_id, product_name, request)
            weather_data = []
            traffic_data = []

            if request.include_weather:
                await self._update_job_status(db, job_id, "running", 30, "Fetching weather data")
                weather_data = await self.data_client.fetch_weather_data(tenant_id, request)

            if request.include_traffic:
                await self._update_job_status(db, job_id, "running", 50, "Fetching traffic data")
                traffic_data = await self.data_client.fetch_traffic_data(tenant_id, request)

            # Execute training
            await self._update_job_status(db, job_id, "running", 70, f"Training model for {product_name}")

            training_result = await self.ml_trainer.train_single_product(
            # Use the same pipeline but for single product
            return await self.start_training_job(
                tenant_id=tenant_id,
                product_name=product_name,
                sales_data=sales_data,
                sales_data=product_sales,
                bakery_location=bakery_location,
                weather_data=weather_data,
                traffic_data=traffic_data,
                job_id=job_id
            )

            # Store model
            await self._update_job_status(db, job_id, "running", 90, "Storing trained model")
            await self._store_single_trained_model(db, tenant_id, product_name, training_result)

            await self._update_job_status(
                db, job_id, "completed", 100, f"Training completed for {product_name}",
                results=training_result
            )

            logger.info(f"Single product training {job_id} completed successfully")
            metrics.increment_counter("single_product_training_completed")

        except Exception as e:
            logger.error(f"Single product training {job_id} failed: {str(e)}")
            await self._update_job_status(
                db, job_id, "failed", 0, f"Training failed: {str(e)}",
                error_message=str(e)
            )
            metrics.increment_counter("single_product_training_failed")
            raise
            return {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "product_name": product_name,
                "status": "failed",
                "error_message": str(e),
                "failed_at": datetime.now().isoformat()
            }

    async def get_job_status(self,
                             db: AsyncSession,
                             job_id: str,
                             tenant_id: str) -> Optional[ModelTrainingLog]:
        """Get training job status"""
        try:
            result = await db.execute(
                select(ModelTrainingLog).where(
                    and_(
                        ModelTrainingLog.job_id == job_id,
                        ModelTrainingLog.tenant_id == tenant_id
                    )
                )
            )
            return result.scalar_one_or_none()

        except Exception as e:
            logger.error(f"Failed to get job status: {str(e)}")
            return None

    async def validate_training_data(
        self,
        tenant_id: str,
        sales_data: List[Dict[str, Any]],
        products: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Validate training data quality before starting training.

        Args:
            tenant_id: Tenant identifier
            sales_data: Sales data to validate
            products: Optional list of specific products to validate

    async def list_training_jobs(self,
                                 db: AsyncSession,
                                 tenant_id: str,
                                 limit: int = 10,
                                 status_filter: Optional[str] = None) -> List[ModelTrainingLog]:
        """List training jobs for a tenant"""
        try:
            query = select(ModelTrainingLog).where(
                ModelTrainingLog.tenant_id == tenant_id
            ).order_by(ModelTrainingLog.start_time.desc()).limit(limit)

            if status_filter:
                query = query.where(ModelTrainingLog.status == status_filter)

            result = await db.execute(query)
            return result.scalars().all()

        except Exception as e:
            logger.error(f"Failed to list training jobs: {str(e)}")
            return []

    async def cancel_training_job(self,
                                  db: AsyncSession,
                                  job_id: str,
                                  tenant_id: str) -> bool:
        """Cancel a training job"""
        try:
            result = await db.execute(
                update(ModelTrainingLog)
                .where(
                    and_(
                        ModelTrainingLog.job_id == job_id,
                        ModelTrainingLog.tenant_id == tenant_id,
                        ModelTrainingLog.status.in_(["pending", "running"])
                    )
                )
                .values(
                    status="cancelled",
                    end_time=datetime.now(),
                    current_step="Training cancelled by user"
                )
            )

            await db.commit()

            if result.rowcount > 0:
                logger.info(f"Cancelled training job {job_id}")
                return True
            else:
                logger.warning(f"Could not cancel training job {job_id} - not found or not cancellable")
                return False

        except Exception as e:
            logger.error(f"Failed to cancel training job: {str(e)}")
            await db.rollback()
            return False

    async def validate_training_data(self,
                                     db: AsyncSession,
                                     tenant_id: str,
                                     config: Dict[str, Any]) -> Dict[str, Any]:
        """Validate training data before starting a job

        Returns:
            Validation results
        """
        try:
            logger.info(f"Validating training data for tenant {tenant_id}")

            issues = []
            recommendations = []

            # Fetch a sample of sales data to validate
            sales_data = await self._fetch_sales_data(tenant_id, config, limit=1000)

            # Extract sales date range for validation
            if not sales_data:
                issues.append("No sales data found for tenant")
                return {
                    "is_valid": False,
                    "issues": issues,
                    "recommendations": ["Upload sales data before training"],
                    "estimated_time_minutes": 0,
                    "valid": False,
                    "errors": ["No sales data provided"],
                    "warnings": []
                }

            # Analyze data quality
            products = set(item.get("product_name") for item in sales_data)
            total_records = len(sales_data)

            # Check for sufficient data per product
            product_counts = {}
            for item in sales_data:
                product = item.get("product_name")
                if product:
                    product_counts[product] = product_counts.get(product, 0) + 1

            insufficient_products = [
                product for product, count in product_counts.items()
                if count < config.get("min_data_points", 30)
            ]

            if insufficient_products:
                issues.append(f"Insufficient data for products: {', '.join(insufficient_products)}")
                recommendations.append("Collect more historical data for these products")

            # Estimate training time
            valid_products = len(products) - len(insufficient_products)
            estimated_time = max(5, valid_products * 2)  # 2 minutes per product minimum

            is_valid = len(issues) == 0

            return {
                "is_valid": is_valid,
                "issues": issues,
                "recommendations": recommendations,
                "estimated_time_minutes": estimated_time,
                "products_analyzed": len(products),
                "total_data_points": total_records
            }

        except Exception as e:
            logger.error(f"Failed to validate training data: {str(e)}")
            return {
                "is_valid": False,
                "issues": [f"Validation error: {str(e)}"],
                "recommendations": ["Check data service connectivity"],
                "estimated_time_minutes": 0
            }

    async def _update_job_status(self,
                                 db: AsyncSession,
                                 job_id: str,
                                 status: str,
                                 progress: int,
                                 current_step: str,
                                 results: Optional[Dict] = None,
                                 error_message: Optional[str] = None):
        """Update training job status"""
        try:
            update_values = {
                "status": status,
                "progress": progress,
                "current_step": current_step
            }

            if status == "completed":
                update_values["end_time"] = datetime.now()

            if results:
                update_values["results"] = results

            if error_message:
                update_values["error_message"] = error_message
                update_values["end_time"] = datetime.now()

            await db.execute(
                update(ModelTrainingLog)
                .where(ModelTrainingLog.job_id == job_id)
                .values(**update_values)
            )

            # Create a mock training dataset to validate
            mock_dataset = await self.orchestrator.prepare_training_data(
                tenant_id=tenant_id,
                sales_data=sales_data,
                bakery_location=(40.4168, -3.7038),  # Default Madrid
                job_id=f"validation_{uuid.uuid4().hex[:8]}"
            )

            await db.commit()
            # Validate the dataset
            validation_results = self.orchestrator.validate_training_data_quality(mock_dataset)

            # Add product-specific information
            unique_products = list(set(record.get('product_name', 'unknown') for record in sales_data))
            product_data_points = {}

            for record in sales_data:
                product = record.get('product_name', 'unknown')
                product_data_points[product] = product_data_points.get(product, 0) + 1

            validation_results.update({
                "products_found": unique_products,
                "product_data_points": product_data_points,
                "total_records": len(sales_data),
                "date_range_info": {
                    "start": mock_dataset.date_range.start.isoformat(),
                    "end": mock_dataset.date_range.end.isoformat(),
                    "duration_days": (mock_dataset.date_range.end - mock_dataset.date_range.start).days
                }
            })

            return validation_results

        except Exception as e:
            logger.error(f"Failed to update job status: {str(e)}")
            await db.rollback()
            logger.error(f"Training data validation failed: {str(e)}")
            return {
                "valid": False,
                "errors": [f"Validation failed: {str(e)}"],
                "warnings": []
            }

    async def _store_trained_models(self,
                                    db: AsyncSession,
                                    tenant_id: str,
                                    training_results: Dict[str, Any]):
        """Store trained models in database"""

    async def get_training_recommendations(
        self,
        tenant_id: str,
        sales_data: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Get training recommendations based on data analysis.

        Args:
            tenant_id: Tenant identifier
            sales_data: Historical sales data

        Returns:
            Training recommendations
        """
        try:
            models_to_store = []
            logger.info(f"Generating training recommendations for tenant {tenant_id}")

            for product_name, result in training_results.get("training_results", {}).items():
                if result.get("status") == "success":
                    model_info = result.get("model_info", {})

                    trained_model = TrainedModel(
                        tenant_id=tenant_id,
                        product_name=product_name,
                        model_id=model_info.get("model_id"),
                        model_type=model_info.get("type", "prophet"),
                        model_path=model_info.get("model_path"),
                        version=1,  # Start with version 1
                        training_samples=model_info.get("training_samples", 0),
                        features=model_info.get("features", []),
                        hyperparameters=model_info.get("hyperparameters", {}),
                        training_metrics=model_info.get("training_metrics", {}),
                        data_period_start=datetime.fromisoformat(
                            model_info.get("data_period", {}).get("start_date", datetime.now().isoformat())
                        ),
                        data_period_end=datetime.fromisoformat(
                            model_info.get("data_period", {}).get("end_date", datetime.now().isoformat())
                        ),
                        created_at=datetime.now(),
                        is_active=True
                    )

                    models_to_store.append(trained_model)

            # Analyze the data
            validation_results = await self.validate_training_data(tenant_id, sales_data)

            # Deactivate old models for these products
            if models_to_store:
                product_names = [model.product_name for model in models_to_store]

                await db.execute(
                    update(TrainedModel)
                    .where(
                        and_(
                            TrainedModel.tenant_id == tenant_id,
                            TrainedModel.product_name.in_(product_names),
                            TrainedModel.is_active == True
                        )
                    )
                    .values(is_active=False)
                )

                # Add new models
                db.add_all(models_to_store)
                await db.commit()

                logger.info(f"Stored {len(models_to_store)} trained models for tenant {tenant_id}")

            recommendations = {
                "should_retrain": True,
                "reasons": [],
                "recommended_products": [],
                "optimal_config": {
                    "include_weather": True,
                    "include_traffic": True,
                    "min_data_points": 30,
                    "hyperparameter_optimization": True
                }
            }

            # Analyze data quality and provide recommendations
            if validation_results.get("data_quality_score", 0) >= 80:
                recommendations["reasons"].append("High quality data detected")
            else:
                recommendations["reasons"].append("Data quality could be improved")

            # Recommend products with sufficient data
            product_data_points = validation_results.get("product_data_points", {})
            for product, points in product_data_points.items():
                if points >= 30:  # Minimum viable data points
                    recommendations["recommended_products"].append(product)

            if len(recommendations["recommended_products"]) == 0:
                recommendations["should_retrain"] = False
                recommendations["reasons"].append("Insufficient data for reliable training")

            return recommendations

        except Exception as e:
            logger.error(f"Failed to store trained models: {str(e)}")
            await db.rollback()
            raise

    async def _store_single_trained_model(self,
                                          db: AsyncSession,
                                          tenant_id: str,
                                          product_name: str,
                                          training_result: Dict[str, Any]):
        """Store a single trained model"""
        try:
            if training_result.get("status") == "success":
                model_info = training_result.get("model_info", {})

                # Deactivate old model for this product
                await db.execute(
                    update(TrainedModel)
                    .where(
                        and_(
                            TrainedModel.tenant_id == tenant_id,
                            TrainedModel.product_name == product_name,
                            TrainedModel.is_active == True
                        )
                    )
                    .values(is_active=False)
                )

                # Create new model record
                trained_model = TrainedModel(
                    tenant_id=tenant_id,
                    product_name=product_name,
                    model_id=model_info.get("model_id"),
                    model_type=model_info.get("type", "prophet"),
                    model_path=model_info.get("model_path"),
                    version=1,
                    training_samples=model_info.get("training_samples", 0),
                    features=model_info.get("features", []),
                    hyperparameters=model_info.get("hyperparameters", {}),
                    training_metrics=model_info.get("training_metrics", {}),
                    data_period_start=datetime.fromisoformat(
                        model_info.get("data_period", {}).get("start_date", datetime.now().isoformat())
                    ),
                    data_period_end=datetime.fromisoformat(
                        model_info.get("data_period", {}).get("end_date", datetime.now().isoformat())
                    ),
                    created_at=datetime.now(),
                    is_active=True
                )

                db.add(trained_model)
                await db.commit()

                logger.info(f"Stored trained model for {product_name}")

        except Exception as e:
            logger.error(f"Failed to store trained model: {str(e)}")
            await db.rollback()
            raise

    async def get_training_logs(self,
                                db: AsyncSession,
                                job_id: str,
                                tenant_id: str) -> Optional[List[str]]:
        """Get detailed training logs for a job"""
        try:
            # For now, return basic log information from the database
            # In a production system, you might store detailed logs separately
            result = await db.execute(
                select(ModelTrainingLog).where(
                    and_(
                        ModelTrainingLog.job_id == job_id,
                        ModelTrainingLog.tenant_id == tenant_id
                    )
                )
            )

            training_log = result.scalar_one_or_none()

            if training_log:
                logs = [
                    f"Job started at: {training_log.start_time}",
                    f"Current status: {training_log.status}",
                    f"Progress: {training_log.progress}%",
                    f"Current step: {training_log.current_step}"
                ]

                if training_log.end_time:
                    logs.append(f"Job completed at: {training_log.end_time}")

                if training_log.error_message:
                    logs.append(f"Error: {training_log.error_message}")

                if training_log.results:
                    results = training_log.results
                    logs.append(f"Models trained: {results.get('products_trained', 0)}")
                    logs.append(f"Models failed: {results.get('products_failed', 0)}")

                return logs

            return None

        except Exception as e:
            logger.error(f"Failed to get training logs: {str(e)}")
            return None

    async def _determine_sales_date_range(self, sales_data: List[Dict]) -> tuple[datetime, datetime]:
        """Determine start and end dates from sales data"""
        if not sales_data:
            raise ValueError("No sales data available to determine date range")

        dates = []
        for record in sales_data:
            if 'date' in record:
                if isinstance(record['date'], str):
                    dates.append(datetime.fromisoformat(record['date'].replace('Z', '+00:00')))
                elif isinstance(record['date'], datetime):
                    dates.append(record['date'])

        if not dates:
            raise ValueError("No valid dates found in sales data")

        start_date = min(dates)
        end_date = max(dates)

        logger.info(f"Determined sales date range: {start_date} to {end_date}")
        return start_date, end_date

        logger.error(f"Failed to generate training recommendations: {str(e)}")
        return {
            "should_retrain": False,
            "reasons": [f"Error analyzing data: {str(e)}"],
            "recommended_products": [],
            "optimal_config": {}
        }
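
# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the service): run standalone
# as a smoke test of the orchestrator-based flow above. The tenant id is a
# placeholder, and passing db_session=None relies on whatever fallback mode
# BakeryMLTrainer supports; this is illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":  # pragma: no cover
    async def _demo() -> None:
        # Build the service, then let start_training_job prepare data and train.
        service = TrainingService(db_session=None)
        result = await service.start_training_job(tenant_id="demo-tenant")
        logger.info(f"Demo training finished with status: {result.get('status')}")

    asyncio.run(_demo())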