Refactor external service and improve WebSocket training
@@ -3,32 +3,14 @@ Training Service Layer

Business logic services for ML training and model management
"""

from .training_service import TrainingService
from .training_service import EnhancedTrainingService
from .training_orchestrator import TrainingDataOrchestrator
from .date_alignment_service import DateAlignmentService
from .data_client import DataClient
from .messaging import (
    publish_job_progress,
    publish_data_validation_started,
    publish_data_validation_completed,
    publish_job_step_completed,
    publish_job_completed,
    publish_job_failed,
    TrainingStatusPublisher
)

__all__ = [
    "TrainingService",
    "EnhancedTrainingService",
    "TrainingDataOrchestrator",
    "TrainingDataOrchestrator",
    "DateAlignmentService",
    "DataClient",
    "publish_job_progress",
    "publish_data_validation_started",
    "publish_data_validation_completed",
    "publish_job_step_completed",
    "publish_job_completed",
    "publish_job_failed",
    "TrainingStatusPublisher"
    "DataClient"
]
@@ -1,16 +1,20 @@

# services/training/app/services/data_client.py
"""
Training Service Data Client
Migrated to use shared service clients - much simpler now!
Migrated to use shared service clients with timeout configuration
"""

import structlog
from typing import Dict, Any, List, Optional
from datetime import datetime
import httpx

# Import the shared clients
from shared.clients import get_sales_client, get_external_client, get_service_clients
from app.core.config import settings
from app.core import constants as const
from app.utils.circuit_breaker import circuit_breaker_registry, CircuitBreakerError
from app.utils.retry import with_retry, HTTP_RETRY_STRATEGY, EXTERNAL_SERVICE_RETRY_STRATEGY

logger = structlog.get_logger()
@@ -21,21 +25,103 @@ class DataClient:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Get the new specialized clients
|
||||
# Get the new specialized clients with timeout configuration
|
||||
self.sales_client = get_sales_client(settings, "training")
|
||||
self.external_client = get_external_client(settings, "training")
|
||||
|
||||
|
||||
# Configure timeouts for HTTP clients
|
||||
self._configure_timeouts()
|
||||
|
||||
# Initialize circuit breakers for external services
|
||||
self._init_circuit_breakers()
|
||||
|
||||
# Check if the new method is available for stored traffic data
|
||||
if hasattr(self.external_client, 'get_stored_traffic_data_for_training'):
|
||||
self.supports_stored_traffic_data = True
|
||||
|
||||
def _configure_timeouts(self):
|
||||
"""Configure appropriate timeouts for HTTP clients"""
|
||||
timeout = httpx.Timeout(
|
||||
connect=const.HTTP_TIMEOUT_DEFAULT,
|
||||
read=const.HTTP_TIMEOUT_LONG_RUNNING,
|
||||
write=const.HTTP_TIMEOUT_DEFAULT,
|
||||
pool=const.HTTP_TIMEOUT_DEFAULT
|
||||
)
|
||||
|
||||
# Apply timeout to clients if they have httpx clients
|
||||
if hasattr(self.sales_client, 'client') and isinstance(self.sales_client.client, httpx.AsyncClient):
|
||||
self.sales_client.client.timeout = timeout
|
||||
|
||||
if hasattr(self.external_client, 'client') and isinstance(self.external_client.client, httpx.AsyncClient):
|
||||
self.external_client.client.timeout = timeout
|
||||
else:
|
||||
self.supports_stored_traffic_data = False
|
||||
logger.warning("Stored traffic data method not available in external client")
|
||||
|
||||
# Or alternatively, get all clients at once:
|
||||
# self.clients = get_service_clients(settings, "training")
|
||||
# Then use: self.clients.sales.get_sales_data(...) and self.clients.external.get_weather_forecast(...)
|
||||
|
||||
def _init_circuit_breakers(self):
|
||||
"""Initialize circuit breakers for external service calls"""
|
||||
# Sales service circuit breaker
|
||||
self.sales_cb = circuit_breaker_registry.get_or_create(
|
||||
name="sales_service",
|
||||
failure_threshold=5,
|
||||
recovery_timeout=60.0,
|
||||
expected_exception=Exception
|
||||
)
|
||||
|
||||
# Weather service circuit breaker
|
||||
self.weather_cb = circuit_breaker_registry.get_or_create(
|
||||
name="weather_service",
|
||||
failure_threshold=3, # Weather is optional, fail faster
|
||||
recovery_timeout=30.0,
|
||||
expected_exception=Exception
|
||||
)
|
||||
|
||||
# Traffic service circuit breaker
|
||||
self.traffic_cb = circuit_breaker_registry.get_or_create(
|
||||
name="traffic_service",
|
||||
failure_threshold=3, # Traffic is optional, fail faster
|
||||
recovery_timeout=30.0,
|
||||
expected_exception=Exception
|
||||
)
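For context on the parameters just configured, here is a minimal, self-contained sketch of the behaviour a failure_threshold / recovery_timeout circuit breaker implies: consecutive failures open the circuit, and after the recovery timeout one trial call is let through again. This is an illustration of the pattern only, not the implementation in app.utils.circuit_breaker.

import time
from typing import Any, Awaitable, Callable, Optional

class SimpleCircuitBreaker:
    """Toy breaker illustrating failure_threshold / recovery_timeout semantics."""

    def __init__(self, failure_threshold: int, recovery_timeout: float):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failures = 0
        self.opened_at: Optional[float] = None

    async def call(self, func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
        # While open, fail fast until recovery_timeout has elapsed.
        if self.opened_at is not None:
            if time.monotonic() - self.opened_at < self.recovery_timeout:
                raise RuntimeError("circuit open - failing fast")
            self.opened_at = None  # half-open: allow one trial call
        try:
            result = await func(*args, **kwargs)
        except Exception:
            self.failures += 1
            if self.failures >= self.failure_threshold:
                self.opened_at = time.monotonic()  # trip the breaker
            raise
        self.failures = 0  # a success closes the breaker again
        return result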
|
||||
|
||||
@with_retry(max_attempts=3, initial_delay=1.0, max_delay=10.0)
|
||||
async def _fetch_sales_data_internal(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
product_id: Optional[str] = None,
|
||||
fetch_all: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Internal method to fetch sales data with automatic retry"""
|
||||
if fetch_all:
|
||||
sales_data = await self.sales_client.get_all_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id,
|
||||
aggregation="daily",
|
||||
page_size=1000,
|
||||
max_pages=100
|
||||
)
|
||||
else:
|
||||
sales_data = await self.sales_client.get_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id,
|
||||
aggregation="daily"
|
||||
)
|
||||
sales_data = sales_data or []
|
||||
|
||||
if sales_data:
|
||||
logger.info(f"Fetched {len(sales_data)} sales records",
|
||||
tenant_id=tenant_id, product_id=product_id, fetch_all=fetch_all)
|
||||
return sales_data
|
||||
else:
|
||||
logger.error("No sales data returned", tenant_id=tenant_id)
|
||||
raise ValueError(f"No sales data available for tenant {tenant_id}")
|
||||
|
||||
async def fetch_sales_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -45,50 +131,21 @@ class DataClient:
|
||||
fetch_all: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch sales data for training
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
start_date: Start date in ISO format
|
||||
end_date: End date in ISO format
|
||||
product_id: Optional product filter
|
||||
fetch_all: If True, fetches ALL records using pagination (original behavior)
|
||||
If False, fetches limited records (standard API response)
|
||||
Fetch sales data for training with circuit breaker protection
|
||||
"""
|
||||
try:
|
||||
if fetch_all:
|
||||
# Use paginated method to get ALL records (original behavior)
|
||||
sales_data = await self.sales_client.get_all_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id,
|
||||
aggregation="daily",
|
||||
page_size=1000, # Comply with API limit
|
||||
max_pages=100 # Safety limit (500k records max)
|
||||
)
|
||||
else:
|
||||
# Use standard method for limited results
|
||||
sales_data = await self.sales_client.get_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id,
|
||||
aggregation="daily"
|
||||
)
|
||||
sales_data = sales_data or []
|
||||
|
||||
if sales_data:
|
||||
logger.info(f"Fetched {len(sales_data)} sales records",
|
||||
tenant_id=tenant_id, product_id=product_id, fetch_all=fetch_all)
|
||||
return sales_data
|
||||
else:
|
||||
logger.warning("No sales data returned", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
return await self.sales_cb.call(
|
||||
self._fetch_sales_data_internal,
|
||||
tenant_id, start_date, end_date, product_id, fetch_all
|
||||
)
|
||||
except CircuitBreakerError as e:
|
||||
logger.error(f"Sales service circuit breaker open: {e}")
|
||||
raise RuntimeError(f"Sales service unavailable: {str(e)}")
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching sales data: {e}", tenant_id=tenant_id)
|
||||
return []
|
||||
raise RuntimeError(f"Failed to fetch sales data: {str(e)}")
|
||||
|
||||
async def fetch_weather_data(
|
||||
self,
|
||||
@@ -112,15 +169,15 @@ class DataClient:
|
||||
)
|
||||
|
||||
if weather_data:
|
||||
logger.info(f"Fetched {len(weather_data)} weather records",
|
||||
logger.info(f"Fetched {len(weather_data)} weather records",
|
||||
tenant_id=tenant_id)
|
||||
return weather_data
|
||||
else:
|
||||
logger.warning("No weather data returned", tenant_id=tenant_id)
|
||||
logger.warning("No weather data returned, will use synthetic data", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching weather data: {e}", tenant_id=tenant_id)
|
||||
logger.warning(f"Error fetching weather data, will use synthetic data: {e}", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
async def fetch_traffic_data_unified(
|
||||
@@ -264,34 +321,93 @@ class DataClient:
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
end_date: str,
|
||||
sales_data: List[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate data quality before training
|
||||
Validate data quality before training with comprehensive checks
|
||||
"""
|
||||
try:
|
||||
# Note: validate_data_quality may need to be implemented in one of the new services
|
||||
# validation_result = await self.sales_client.validate_data_quality(
|
||||
# tenant_id=tenant_id,
|
||||
# start_date=start_date,
|
||||
# end_date=end_date
|
||||
# )
|
||||
|
||||
# Temporary implementation - assume data is valid for now
|
||||
validation_result = {"is_valid": True, "message": "Validation temporarily disabled"}
|
||||
|
||||
if validation_result:
|
||||
logger.info("Data validation completed",
|
||||
tenant_id=tenant_id,
|
||||
is_valid=validation_result.get("is_valid", False))
|
||||
return validation_result
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
# If sales data provided, validate it directly
|
||||
if sales_data is not None:
|
||||
if not sales_data or len(sales_data) == 0:
|
||||
errors.append("No sales data available for the specified period")
|
||||
return {"is_valid": False, "errors": errors, "warnings": warnings}
|
||||
|
||||
# Check minimum data points
|
||||
if len(sales_data) < 30:
|
||||
errors.append(f"Insufficient data points: {len(sales_data)} (minimum 30 required)")
|
||||
elif len(sales_data) < 90:
|
||||
warnings.append(f"Limited data points: {len(sales_data)} (recommended 90+)")
|
||||
|
||||
# Check for required fields
|
||||
required_fields = ['date', 'inventory_product_id']
|
||||
for record in sales_data[:5]: # Sample check
|
||||
missing = [f for f in required_fields if f not in record or record[f] is None]
|
||||
if missing:
|
||||
errors.append(f"Missing required fields: {missing}")
|
||||
break
|
||||
|
||||
# Check for data quality issues
|
||||
zero_count = sum(1 for r in sales_data if r.get('quantity', 0) == 0)
|
||||
zero_ratio = zero_count / len(sales_data)
|
||||
if zero_ratio > 0.9:
|
||||
errors.append(f"Too many zero values: {zero_ratio:.1%} of records")
|
||||
elif zero_ratio > 0.7:
|
||||
warnings.append(f"High zero value ratio: {zero_ratio:.1%}")
|
||||
|
||||
# Check product diversity
|
||||
unique_products = set(r.get('inventory_product_id') for r in sales_data if r.get('inventory_product_id'))
|
||||
if len(unique_products) == 0:
|
||||
errors.append("No valid product IDs found in sales data")
|
||||
elif len(unique_products) == 1:
|
||||
warnings.append("Only one product found - consider adding more products")
|
||||
|
||||
else:
|
||||
logger.warning("Data validation failed", tenant_id=tenant_id)
|
||||
return {"is_valid": False, "errors": ["Validation service unavailable"]}
|
||||
|
||||
# Fetch data for validation
|
||||
sales_data = await self.fetch_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
fetch_all=False
|
||||
)
|
||||
|
||||
if not sales_data:
|
||||
errors.append("Unable to fetch sales data for validation")
|
||||
return {"is_valid": False, "errors": errors, "warnings": warnings}
|
||||
|
||||
# Recursive call with fetched data
|
||||
return await self.validate_data_quality(
|
||||
tenant_id, start_date, end_date, sales_data
|
||||
)
|
||||
|
||||
is_valid = len(errors) == 0
|
||||
result = {
|
||||
"is_valid": is_valid,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
"data_points": len(sales_data) if sales_data else 0,
|
||||
"unique_products": len(unique_products) if sales_data else 0
|
||||
}
|
||||
|
||||
if is_valid:
|
||||
logger.info("Data validation passed",
|
||||
tenant_id=tenant_id,
|
||||
data_points=result["data_points"],
|
||||
warnings_count=len(warnings))
|
||||
else:
|
||||
logger.error("Data validation failed",
|
||||
tenant_id=tenant_id,
|
||||
errors=errors)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating data: {e}", tenant_id=tenant_id)
|
||||
return {"is_valid": False, "errors": [str(e)]}
|
||||
raise ValueError(f"Data validation failed: {str(e)}")
|
||||
|
||||
# Global instance - same as before, but much simpler implementation
|
||||
data_client = DataClient()
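A hedged usage sketch of the client defined above: fetch sales data behind the circuit breaker and retry logic, then pass the same records into validate_data_quality so nothing is fetched twice. The error types mirror what the new code raises; the tenant id and date range are placeholders.

from app.services.data_client import data_client

async def prepare_training_inputs() -> dict:
    # Placeholder tenant and date range.
    tenant_id, start, end = "tenant-123", "2024-01-01", "2024-06-30"

    # Raises RuntimeError if the circuit breaker is open or retries are exhausted,
    # and ValueError if no sales data exists for the tenant.
    sales = await data_client.fetch_sales_data(
        tenant_id=tenant_id, start_date=start, end_date=end, fetch_all=True
    )

    # Reuse the already-fetched records so validation does not refetch them.
    report = await data_client.validate_data_quality(
        tenant_id=tenant_id, start_date=start, end_date=end, sales_data=sales
    )
    if not report["is_valid"]:
        raise ValueError(f"Training blocked by validation errors: {report['errors']}")
    return {"sales": sales, "validation": report}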
|
||||
@@ -1,9 +1,9 @@
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from app.utils.timezone_utils import ensure_timezone_aware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -84,31 +84,25 @@ class DateAlignmentService:
|
||||
requested_end: Optional[datetime]
|
||||
) -> DateRange:
|
||||
"""Determine the base date range for training."""
|
||||
|
||||
# ✅ FIX: Ensure all datetimes are timezone-aware for comparison
|
||||
def ensure_timezone_aware(dt: datetime) -> datetime:
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
# Use explicit dates if provided
|
||||
if requested_start and requested_end:
|
||||
requested_start = ensure_timezone_aware(requested_start)
|
||||
requested_end = ensure_timezone_aware(requested_end)
|
||||
|
||||
|
||||
if requested_end <= requested_start:
|
||||
raise ValueError("End date must be after start date")
|
||||
return DateRange(requested_start, requested_end, DataSourceType.BAKERY_SALES)
|
||||
|
||||
|
||||
# Otherwise, use the user's sales data range as the foundation
|
||||
start_date = ensure_timezone_aware(requested_start or user_sales_range.start)
|
||||
end_date = ensure_timezone_aware(requested_end or user_sales_range.end)
|
||||
|
||||
|
||||
# Ensure we don't exceed maximum training range
|
||||
if (end_date - start_date).days > self.MAX_TRAINING_RANGE_DAYS:
|
||||
start_date = end_date - timedelta(days=self.MAX_TRAINING_RANGE_DAYS)
|
||||
logger.warning(f"Limiting training range to {self.MAX_TRAINING_RANGE_DAYS} days")
|
||||
|
||||
|
||||
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
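A small, hedged illustration of the two behaviours handled above: naive datetimes are coerced to UTC before comparison, and over-long ranges are clamped to MAX_TRAINING_RANGE_DAYS (the 730-day value below is for the example only, not the service's actual constant).

from datetime import datetime, timedelta, timezone

MAX_TRAINING_RANGE_DAYS = 730  # example value only

def ensure_timezone_aware(dt: datetime) -> datetime:
    return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt

start = ensure_timezone_aware(datetime(2021, 1, 1))                      # naive -> UTC
end = ensure_timezone_aware(datetime(2024, 1, 1, tzinfo=timezone.utc))   # already aware
if (end - start).days > MAX_TRAINING_RANGE_DAYS:
    start = end - timedelta(days=MAX_TRAINING_RANGE_DAYS)                # clamp to the cap
assert (end - start).days == MAX_TRAINING_RANGE_DAYS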
|
||||
|
||||
def _apply_data_source_constraints(self, base_range: DateRange) -> AlignedDateRange:
|
||||
|
||||
@@ -1,603 +0,0 @@
|
||||
# services/training/app/services/messaging.py
|
||||
"""
|
||||
Enhanced training service messaging - Complete status publishing implementation
|
||||
Uses shared RabbitMQ infrastructure with comprehensive progress tracking
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from shared.messaging.rabbitmq import RabbitMQClient
|
||||
from shared.messaging.events import (
|
||||
TrainingStartedEvent,
|
||||
TrainingCompletedEvent,
|
||||
TrainingFailedEvent
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Single global instance
|
||||
training_publisher = RabbitMQClient(settings.RABBITMQ_URL, "training-service")
|
||||
|
||||
async def setup_messaging():
|
||||
"""Initialize messaging for training service"""
|
||||
success = await training_publisher.connect()
|
||||
if success:
|
||||
logger.info("Training service messaging initialized")
|
||||
else:
|
||||
logger.warning("Training service messaging failed to initialize")
|
||||
|
||||
async def cleanup_messaging():
|
||||
"""Cleanup messaging for training service"""
|
||||
await training_publisher.disconnect()
|
||||
logger.info("Training service messaging cleaned up")
|
||||
|
||||
def serialize_for_json(obj: Any) -> Any:
|
||||
"""
|
||||
Convert numpy types and other non-JSON serializable objects to JSON-compatible types
|
||||
"""
|
||||
if isinstance(obj, np.integer):
|
||||
return int(obj)
|
||||
elif isinstance(obj, np.floating):
|
||||
return float(obj)
|
||||
elif isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
elif isinstance(obj, np.bool_):
|
||||
return bool(obj)
|
||||
elif isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
elif isinstance(obj, dict):
|
||||
return {key: serialize_for_json(value) for key, value in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return [serialize_for_json(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
def safe_json_serialize(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Recursively clean data dictionary for JSON serialization
|
||||
"""
|
||||
return serialize_for_json(data)
|
||||
|
||||
async def setup_websocket_message_routing():
|
||||
"""Set up message routing for WebSocket connections"""
|
||||
try:
|
||||
# This will be called from the WebSocket endpoint
|
||||
# to set up the consumer for a specific job
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set up WebSocket message routing: {e}")
|
||||
|
||||
# =========================================
|
||||
# ENHANCED TRAINING JOB STATUS EVENTS
|
||||
# =========================================
|
||||
|
||||
async def publish_job_started(job_id: str, tenant_id: str, config: Dict[str, Any]) -> bool:
|
||||
"""Publish training job started event"""
|
||||
event = TrainingStartedEvent(
|
||||
service_name="training-service",
|
||||
data={
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"config": config,
|
||||
"started_at": datetime.now().isoformat(),
|
||||
"estimated_duration_minutes": config.get("estimated_duration_minutes", 15)
|
||||
}
|
||||
)
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.started",
|
||||
event_data=event.to_dict()
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Published job started event", job_id=job_id, tenant_id=tenant_id)
|
||||
else:
|
||||
logger.error(f"Failed to publish job started event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
async def publish_job_progress(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
progress: int,
|
||||
step: str,
|
||||
current_product: Optional[str] = None,
|
||||
products_completed: int = 0,
|
||||
products_total: int = 0,
|
||||
estimated_time_remaining_minutes: Optional[int] = None,
|
||||
step_details: Optional[str] = None
|
||||
) -> bool:
|
||||
"""Publish detailed training job progress event with safe serialization"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.progress",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"progress": min(max(int(progress), 0), 100), # Ensure int, not numpy.int64
|
||||
"current_step": step,
|
||||
"current_product": current_product,
|
||||
"products_completed": int(products_completed), # Convert numpy types
|
||||
"products_total": int(products_total),
|
||||
"estimated_time_remaining_minutes": int(estimated_time_remaining_minutes) if estimated_time_remaining_minutes else None,
|
||||
"step_details": step_details
|
||||
}
|
||||
}
|
||||
|
||||
# Clean the entire event data
|
||||
clean_event_data = safe_json_serialize(event_data)
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.progress",
|
||||
event_data=clean_event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Published progress update",
|
||||
job_id=job_id,
|
||||
progress=progress,
|
||||
step=step,
|
||||
current_product=current_product)
|
||||
else:
|
||||
logger.error(f"Failed to publish progress update", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
async def publish_job_step_completed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
step_name: str,
|
||||
step_result: Dict[str, Any],
|
||||
progress: int
|
||||
) -> bool:
|
||||
"""Publish when a major training step is completed"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.step.completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"step_name": step_name,
|
||||
"step_result": step_result,
|
||||
"progress": progress,
|
||||
"completed_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.step.completed",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
async def publish_job_completed(job_id: str, tenant_id: str, results: Dict[str, Any]) -> bool:
|
||||
"""Publish training job completed event with safe JSON serialization"""
|
||||
|
||||
# Clean the results data before creating the event
|
||||
clean_results = safe_json_serialize(results)
|
||||
|
||||
event = TrainingCompletedEvent(
|
||||
service_name="training-service",
|
||||
data={
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"results": clean_results, # Now safe for JSON
|
||||
"models_trained": clean_results.get("successful_trainings", 0),
|
||||
"success_rate": clean_results.get("success_rate", 0),
|
||||
"total_duration_seconds": clean_results.get("overall_training_time_seconds", 0),
|
||||
"completed_at": datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.completed",
|
||||
event_data=event.to_dict()
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Published job completed event",
|
||||
job_id=job_id,
|
||||
models_trained=clean_results.get("successful_trainings", 0))
|
||||
else:
|
||||
logger.error(f"Failed to publish job completed event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
async def publish_job_failed(job_id: str, tenant_id: str, error: str, error_details: Optional[Dict] = None) -> bool:
|
||||
"""Publish training job failed event"""
|
||||
event = TrainingFailedEvent(
|
||||
service_name="training-service",
|
||||
data={
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"error": error,
|
||||
"error_details": error_details or {},
|
||||
"failed_at": datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.failed",
|
||||
event_data=event.to_dict()
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Published job failed event", job_id=job_id, error=error)
|
||||
else:
|
||||
logger.error(f"Failed to publish job failed event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
async def publish_job_cancelled(job_id: str, tenant_id: str, reason: str = "User requested") -> bool:
|
||||
"""Publish training job cancelled event"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.cancelled",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"reason": reason,
|
||||
"cancelled_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.cancelled",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
# =========================================
|
||||
# PRODUCT-LEVEL TRAINING EVENTS
|
||||
# =========================================
|
||||
|
||||
async def publish_product_training_started(job_id: str, tenant_id: str, inventory_product_id: str) -> bool:
|
||||
"""Publish single product training started event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.product.started",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.product.started",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"started_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
async def publish_product_training_completed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
model_id: str,
|
||||
metrics: Optional[Dict[str, float]] = None
|
||||
) -> bool:
|
||||
"""Publish single product training completed event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.product.completed",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.product.completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"model_id": model_id,
|
||||
"metrics": metrics or {},
|
||||
"completed_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
async def publish_product_training_failed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
error: str
|
||||
) -> bool:
|
||||
"""Publish single product training failed event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.product.failed",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.product.failed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"error": error,
|
||||
"failed_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# =========================================
|
||||
# MODEL LIFECYCLE EVENTS
|
||||
# =========================================
|
||||
|
||||
async def publish_model_trained(model_id: str, tenant_id: str, inventory_product_id: str, metrics: Dict[str, float]) -> bool:
|
||||
"""Publish model trained event with safe metric serialization"""
|
||||
|
||||
# Clean metrics to ensure JSON serialization
|
||||
clean_metrics = safe_json_serialize(metrics) if metrics else {}
|
||||
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.model.trained",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"model_id": model_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"training_metrics": clean_metrics, # Now safe for JSON
|
||||
"trained_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.model.trained",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
|
||||
async def publish_model_validated(model_id: str, tenant_id: str, inventory_product_id: str, validation_results: Dict[str, Any]) -> bool:
|
||||
"""Publish model validation event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.model.validated",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.model.validated",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"model_id": model_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"validation_results": validation_results,
|
||||
"validated_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
async def publish_model_saved(model_id: str, tenant_id: str, inventory_product_id: str, model_path: str) -> bool:
|
||||
"""Publish model saved event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.model.saved",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.model.saved",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"model_id": model_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"model_path": model_path,
|
||||
"saved_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# =========================================
|
||||
# DATA PROCESSING EVENTS
|
||||
# =========================================
|
||||
|
||||
async def publish_data_validation_started(job_id: str, tenant_id: str, products: List[str]) -> bool:
|
||||
"""Publish data validation started event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.data.validation.started",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.data.validation.started",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"products": products,
|
||||
"started_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
async def publish_data_validation_completed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
validation_results: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Publish data validation completed event"""
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.data.validation.completed",
|
||||
event_data={
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.data.validation.completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"validation_results": validation_results,
|
||||
"completed_at": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def publish_models_deleted_event(tenant_id: str, deletion_stats: Dict[str, Any]):
|
||||
"""Publish models deletion event to message queue"""
|
||||
try:
|
||||
await training_publisher.publish_event(
|
||||
exchange="training_events",
|
||||
routing_key="training.tenant.models.deleted",
|
||||
message={
|
||||
"event_type": "tenant_models_deleted",
|
||||
"tenant_id": tenant_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"deletion_stats": deletion_stats
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to publish models deletion event", error=str(e))
|
||||
|
||||
|
||||
# =========================================
|
||||
# UTILITY FUNCTIONS FOR BATCH PUBLISHING
|
||||
# =========================================
|
||||
|
||||
async def publish_batch_status_update(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
updates: List[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Publish multiple status updates as a batch"""
|
||||
batch_event = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.batch.update",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"updates": updates,
|
||||
"batch_size": len(updates)
|
||||
}
|
||||
}
|
||||
|
||||
return await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.batch.update",
|
||||
event_data=batch_event
|
||||
)
|
||||
|
||||
# =========================================
|
||||
# HELPER FUNCTIONS FOR TRAINING INTEGRATION
|
||||
# =========================================
|
||||
|
||||
class TrainingStatusPublisher:
|
||||
"""Helper class to manage training status publishing throughout the training process"""
|
||||
|
||||
def __init__(self, job_id: str, tenant_id: str):
|
||||
self.job_id = job_id
|
||||
self.tenant_id = tenant_id
|
||||
self.start_time = datetime.now()
|
||||
self.products_total = 0
|
||||
self.products_completed = 0
|
||||
|
||||
async def job_started(self, config: Dict[str, Any], products_total: int = 0):
|
||||
"""Publish job started with initial configuration"""
|
||||
self.products_total = products_total
|
||||
|
||||
# Clean config data
|
||||
clean_config = safe_json_serialize(config)
|
||||
|
||||
await publish_job_started(self.job_id, self.tenant_id, clean_config)
|
||||
|
||||
async def progress_update(
|
||||
self,
|
||||
progress: int,
|
||||
step: str,
|
||||
current_product: Optional[str] = None,
|
||||
step_details: Optional[str] = None
|
||||
):
|
||||
"""Publish progress update with improved time estimates"""
|
||||
elapsed_minutes = (datetime.now() - self.start_time).total_seconds() / 60
|
||||
|
||||
# Improved estimation based on training phases
|
||||
estimated_remaining = self._calculate_smart_time_remaining(progress, elapsed_minutes, step)
|
||||
|
||||
await publish_job_progress(
|
||||
job_id=self.job_id,
|
||||
tenant_id=self.tenant_id,
|
||||
progress=int(progress),
|
||||
step=step,
|
||||
current_product=current_product,
|
||||
products_completed=int(self.products_completed),
|
||||
products_total=int(self.products_total),
|
||||
estimated_time_remaining_minutes=int(estimated_remaining) if estimated_remaining else None,
|
||||
step_details=step_details
|
||||
)
|
||||
|
||||
def _calculate_smart_time_remaining(self, progress: int, elapsed_minutes: float, step: str) -> Optional[int]:
|
||||
"""Calculate estimated time remaining using phase-based estimation"""
|
||||
|
||||
# Define expected time distribution for each phase
|
||||
phase_durations = {
|
||||
"data_validation": 1.0, # 1 minute
|
||||
"feature_engineering": 2.0, # 2 minutes
|
||||
"model_training": 8.0, # 8 minutes (bulk of time)
|
||||
"model_validation": 1.0 # 1 minute
|
||||
}
|
||||
|
||||
total_expected_minutes = sum(phase_durations.values()) # 12 minutes
|
||||
|
||||
# Calculate progress through phases
|
||||
if progress <= 10: # data_validation phase
|
||||
remaining_in_phase = phase_durations["data_validation"] * (1 - (progress / 10))
|
||||
remaining_after_phase = sum(list(phase_durations.values())[1:])
|
||||
return int(remaining_in_phase + remaining_after_phase)
|
||||
|
||||
elif progress <= 20: # feature_engineering phase
|
||||
remaining_in_phase = phase_durations["feature_engineering"] * (1 - ((progress - 10) / 10))
|
||||
remaining_after_phase = sum(list(phase_durations.values())[2:])
|
||||
return int(remaining_in_phase + remaining_after_phase)
|
||||
|
||||
elif progress <= 90: # model_training phase (biggest chunk)
|
||||
remaining_in_phase = phase_durations["model_training"] * (1 - ((progress - 20) / 70))
|
||||
remaining_after_phase = phase_durations["model_validation"]
|
||||
return int(remaining_in_phase + remaining_after_phase)
|
||||
|
||||
elif progress <= 100: # model_validation phase
|
||||
remaining_in_phase = phase_durations["model_validation"] * (1 - ((progress - 90) / 10))
|
||||
return int(remaining_in_phase)
|
||||
|
||||
return 0
|
||||
|
||||
async def product_completed(self, inventory_product_id: str, model_id: str, metrics: Optional[Dict] = None):
|
||||
"""Mark a product as completed and update progress"""
|
||||
self.products_completed += 1
|
||||
|
||||
# Clean metrics before publishing
|
||||
clean_metrics = safe_json_serialize(metrics) if metrics else None
|
||||
|
||||
await publish_product_training_completed(
|
||||
self.job_id, self.tenant_id, inventory_product_id, model_id, clean_metrics
|
||||
)
|
||||
|
||||
# Update overall progress
|
||||
if self.products_total > 0:
|
||||
progress = int((self.products_completed / self.products_total) * 90) # Save 10% for final steps
|
||||
await self.progress_update(
|
||||
progress=progress,
|
||||
step=f"Completed training for {inventory_product_id}",
|
||||
current_product=None
|
||||
)
|
||||
|
||||
async def job_completed(self, results: Dict[str, Any]):
|
||||
"""Publish job completion with clean data"""
|
||||
clean_results = safe_json_serialize(results)
|
||||
await publish_job_completed(self.job_id, self.tenant_id, clean_results)
|
||||
|
||||
async def job_failed(self, error: str, error_details: Optional[Dict] = None):
|
||||
"""Publish job failure with clean error details"""
|
||||
clean_error_details = safe_json_serialize(error_details) if error_details else None
|
||||
await publish_job_failed(self.job_id, self.tenant_id, error, clean_error_details)
|
||||
|
||||
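Before its removal, the messaging module above was usually driven through TrainingStatusPublisher; the condensed sketch below shows that flow for reference (train_one is a hypothetical stand-in for the per-product training call). The new progress_tracker and training_events modules that follow replace this pattern.

# Condensed, illustrative driver for the removed TrainingStatusPublisher.
publisher = TrainingStatusPublisher(job_id="job-1", tenant_id="tenant-123")

async def run_job(products, config, train_one):
    await publisher.job_started(config, products_total=len(products))
    try:
        for product_id in products:
            await publisher.progress_update(
                progress=20, step="model_training", current_product=product_id
            )
            model_id = await train_one(product_id)  # hypothetical training call
            await publisher.product_completed(product_id, model_id)
        await publisher.job_completed({"successful_trainings": len(products)})
    except Exception as exc:
        await publisher.job_failed(str(exc))
        raise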
services/training/app/services/progress_tracker.py (new file, 78 lines)
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Training Progress Tracker
|
||||
Manages progress calculation for parallel product training (20-80% range)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import structlog
|
||||
from typing import Optional
|
||||
|
||||
from app.services.training_events import publish_product_training_completed
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ParallelProductProgressTracker:
|
||||
"""
|
||||
Tracks parallel product training progress and emits events.
|
||||
|
||||
For N products training in parallel:
|
||||
- Each product completion contributes 60/N% to overall progress
|
||||
- Progress range: 20% (after data analysis) to 80% (before completion)
|
||||
- Thread-safe for concurrent product trainings
|
||||
"""
|
||||
|
||||
def __init__(self, job_id: str, tenant_id: str, total_products: int):
|
||||
self.job_id = job_id
|
||||
self.tenant_id = tenant_id
|
||||
self.total_products = total_products
|
||||
self.products_completed = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Calculate progress increment per product
|
||||
# 60% of total progress (from 20% to 80%) divided by number of products
|
||||
self.progress_per_product = 60 / total_products if total_products > 0 else 0
|
||||
|
||||
logger.info("ParallelProductProgressTracker initialized",
|
||||
job_id=job_id,
|
||||
total_products=total_products,
|
||||
progress_per_product=f"{self.progress_per_product:.2f}%")
|
||||
|
||||
async def mark_product_completed(self, product_name: str) -> int:
|
||||
"""
|
||||
Mark a product as completed and publish event.
|
||||
Returns the current overall progress percentage.
|
||||
"""
|
||||
async with self._lock:
|
||||
self.products_completed += 1
|
||||
current_progress = self.products_completed
|
||||
|
||||
# Publish product completion event
|
||||
await publish_product_training_completed(
|
||||
job_id=self.job_id,
|
||||
tenant_id=self.tenant_id,
|
||||
product_name=product_name,
|
||||
products_completed=current_progress,
|
||||
total_products=self.total_products
|
||||
)
|
||||
|
||||
# Calculate overall progress (20% base + progress from completed products)
|
||||
# This calculation is done on the frontend/consumer side based on the event data
|
||||
overall_progress = 20 + int((current_progress / self.total_products) * 60)
|
||||
|
||||
logger.info("Product training completed",
|
||||
job_id=self.job_id,
|
||||
product_name=product_name,
|
||||
products_completed=current_progress,
|
||||
total_products=self.total_products,
|
||||
overall_progress=overall_progress)
|
||||
|
||||
return overall_progress
|
||||
|
||||
def get_progress(self) -> dict:
|
||||
"""Get current progress summary"""
|
||||
return {
|
||||
"products_completed": self.products_completed,
|
||||
"total_products": self.total_products,
|
||||
"progress_percentage": 20 + int((self.products_completed / self.total_products) * 60)
|
||||
}
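A quick worked example of the progress math above: with four products, each completion adds 15 points on top of the 20% data-analysis baseline, so progress steps through 35, 50, 65 and 80. The snippet reproduces the same formula standalone.

# Standalone check of the tracker's formula: 20% base + 60% spread over products.
def overall_progress(products_completed: int, total_products: int) -> int:
    return 20 + int((products_completed / total_products) * 60)

assert [overall_progress(n, 4) for n in range(5)] == [20, 35, 50, 65, 80]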
|
||||
services/training/app/services/training_events.py (new file, 238 lines)
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
Training Progress Events Publisher
|
||||
Simple, clean event publisher for the 4 main training steps
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from shared.messaging.rabbitmq import RabbitMQClient
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Single global publisher instance
|
||||
training_publisher = RabbitMQClient(settings.RABBITMQ_URL, "training-service")
|
||||
|
||||
|
||||
async def setup_messaging():
|
||||
"""Initialize messaging"""
|
||||
success = await training_publisher.connect()
|
||||
if success:
|
||||
logger.info("Training messaging initialized")
|
||||
else:
|
||||
logger.warning("Training messaging failed to initialize")
|
||||
return success
|
||||
|
||||
|
||||
async def cleanup_messaging():
|
||||
"""Cleanup messaging"""
|
||||
await training_publisher.disconnect()
|
||||
logger.info("Training messaging cleaned up")
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 4 MAIN TRAINING PROGRESS EVENTS
|
||||
# ==========================================
|
||||
|
||||
async def publish_training_started(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
total_products: int
|
||||
) -> bool:
|
||||
"""
|
||||
Event 1: Training Started (0% progress)
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.started",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"progress": 0,
|
||||
"current_step": "Training Started",
|
||||
"step_details": f"Starting training for {total_products} products",
|
||||
"total_products": total_products
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.started",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published training started event",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
total_products=total_products)
|
||||
else:
|
||||
logger.error("Failed to publish training started event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def publish_data_analysis(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
analysis_details: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Event 2: Data Analysis (20% progress)
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.progress",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"progress": 20,
|
||||
"current_step": "Data Analysis",
|
||||
"step_details": analysis_details or "Analyzing sales, weather, and traffic data"
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.progress",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published data analysis event",
|
||||
job_id=job_id,
|
||||
progress=20)
|
||||
else:
|
||||
logger.error("Failed to publish data analysis event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def publish_product_training_completed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
products_completed: int,
|
||||
total_products: int
|
||||
) -> bool:
|
||||
"""
|
||||
Event 3: Product Training Completed (contributes to 20-80% progress)
|
||||
|
||||
This event is published each time a product training completes.
|
||||
The frontend/consumer will calculate the progress as:
|
||||
progress = 20 + (products_completed / total_products) * 60
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.product.completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"products_completed": products_completed,
|
||||
"total_products": total_products,
|
||||
"current_step": "Model Training",
|
||||
"step_details": f"Completed training for {product_name} ({products_completed}/{total_products})"
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.product.completed",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published product training completed event",
|
||||
job_id=job_id,
|
||||
product_name=product_name,
|
||||
products_completed=products_completed,
|
||||
total_products=total_products)
|
||||
else:
|
||||
logger.error("Failed to publish product training completed event",
|
||||
job_id=job_id)
|
||||
|
||||
return success
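On the consumer side (for example the WebSocket handler relaying these events), the docstring above notes that progress is derived from the payload rather than carried in it; a hedged sketch of that calculation, using the field names from the event's data block:

# Illustrative consumer-side helper: derive overall progress from the event payload.
def progress_from_event(event: dict) -> int:
    data = event["data"]
    completed = data["products_completed"]
    total = max(data["total_products"], 1)  # guard against division by zero
    return 20 + int((completed / total) * 60)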
|
||||
|
||||
|
||||
async def publish_training_completed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
successful_trainings: int,
|
||||
failed_trainings: int,
|
||||
total_duration_seconds: float
|
||||
) -> bool:
|
||||
"""
|
||||
Event 4: Training Completed (100% progress)
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.completed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"progress": 100,
|
||||
"current_step": "Training Completed",
|
||||
"step_details": f"Training completed: {successful_trainings} successful, {failed_trainings} failed",
|
||||
"successful_trainings": successful_trainings,
|
||||
"failed_trainings": failed_trainings,
|
||||
"total_duration_seconds": total_duration_seconds
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.completed",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published training completed event",
|
||||
job_id=job_id,
|
||||
successful_trainings=successful_trainings,
|
||||
failed_trainings=failed_trainings)
|
||||
else:
|
||||
logger.error("Failed to publish training completed event", job_id=job_id)
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def publish_training_failed(
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
error_message: str
|
||||
) -> bool:
|
||||
"""
|
||||
Event: Training Failed
|
||||
"""
|
||||
event_data = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.failed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"current_step": "Training Failed",
|
||||
"error_message": error_message
|
||||
}
|
||||
}
|
||||
|
||||
success = await training_publisher.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.failed",
|
||||
event_data=event_data
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Published training failed event",
|
||||
job_id=job_id,
|
||||
error=error_message)
|
||||
else:
|
||||
logger.error("Failed to publish training failed event", job_id=job_id)
|
||||
|
||||
return success
|
||||
@@ -16,13 +16,7 @@ import pandas as pd
|
||||
from app.services.data_client import DataClient
|
||||
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
|
||||
|
||||
from app.services.messaging import (
|
||||
publish_job_progress,
|
||||
publish_data_validation_started,
|
||||
publish_data_validation_completed,
|
||||
publish_job_step_completed,
|
||||
publish_job_failed
|
||||
)
|
||||
from app.services.training_events import publish_training_failed
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@@ -76,7 +70,6 @@ class TrainingDataOrchestrator:
|
||||
# Step 1: Fetch and validate sales data (unified approach)
|
||||
sales_data = await self.data_client.fetch_sales_data(tenant_id, fetch_all=True)
|
||||
|
||||
# Pre-flight validation moved here to eliminate duplicate fetching
|
||||
if not sales_data or len(sales_data) == 0:
|
||||
error_msg = f"No sales data available for tenant {tenant_id}. Please import sales data before starting training."
|
||||
logger.error("Training aborted - no sales data", tenant_id=tenant_id, job_id=job_id)
|
||||
@@ -172,7 +165,8 @@ class TrainingDataOrchestrator:
|
||||
return training_dataset
|
||||
|
||||
except Exception as e:
|
||||
publish_job_failed(job_id, tenant_id, str(e))
|
||||
if job_id and tenant_id:
|
||||
await publish_training_failed(job_id, tenant_id, str(e))
|
||||
logger.error(f"Training data preparation failed: {str(e)}")
|
||||
raise ValueError(f"Failed to prepare training data: {str(e)}")
|
||||
|
||||
@@ -472,30 +466,18 @@ class TrainingDataOrchestrator:
|
||||
logger.warning(f"Enhanced traffic data collection failed: {e}")
|
||||
return []
|
||||
|
||||
# Keep original method for backwards compatibility
|
||||
async def _collect_traffic_data_with_timeout(
|
||||
self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
tenant_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Legacy traffic data collection method - redirects to enhanced version"""
|
||||
return await self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)
|
||||
|
||||
def _log_enhanced_traffic_data_storage(self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
def _log_enhanced_traffic_data_storage(self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
record_count: int,
|
||||
traffic_data: List[Dict[str, Any]]):
|
||||
"""Enhanced logging for traffic data storage with detailed metadata"""
|
||||
# Analyze the stored data for additional insights
|
||||
cities_detected = set()
|
||||
has_pedestrian_data = 0
|
||||
data_sources = set()
|
||||
districts_covered = set()
|
||||
|
||||
|
||||
for record in traffic_data:
|
||||
if 'city' in record and record['city']:
|
||||
cities_detected.add(record['city'])
|
||||
@@ -505,7 +487,7 @@ class TrainingDataOrchestrator:
|
||||
data_sources.add(record['source'])
|
||||
if 'district' in record and record['district']:
|
||||
districts_covered.add(record['district'])
|
||||
|
||||
|
||||
logger.info(
|
||||
"Enhanced traffic data stored for re-training",
|
||||
location=f"{lat:.4f},{lon:.4f}",
|
||||
@@ -516,20 +498,9 @@ class TrainingDataOrchestrator:
|
||||
data_sources=list(data_sources),
|
||||
districts_covered=list(districts_covered),
|
||||
storage_timestamp=datetime.now().isoformat(),
|
||||
purpose="enhanced_model_training_and_retraining",
|
||||
architecture_version="2.0_abstracted"
|
||||
purpose="model_training_and_retraining"
|
||||
)
|
||||
|
||||
def _log_traffic_data_storage(self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
record_count: int):
|
||||
"""Legacy logging method - redirects to enhanced version"""
|
||||
# Create minimal traffic data structure for enhanced logging
|
||||
minimal_traffic_data = [{"city": "madrid", "source": "legacy"}] * min(record_count, 1)
|
||||
self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, record_count, minimal_traffic_data)
|
||||
|
||||
def _validate_weather_data(self, weather_data: List[Dict[str, Any]]) -> bool:
|
||||
"""Validate weather data quality"""
|
||||
if not weather_data:
|
||||
|
||||
@@ -13,10 +13,9 @@ import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from app.ml.trainer import BakeryMLTrainer
|
||||
from app.ml.trainer import EnhancedBakeryMLTrainer
|
||||
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
|
||||
from app.services.training_orchestrator import TrainingDataOrchestrator
|
||||
from app.services.messaging import TrainingStatusPublisher
|
||||
|
||||
# Import repositories
|
||||
from app.repositories import (
|
||||
@@ -119,7 +118,7 @@ class EnhancedTrainingService:
|
||||
self.artifact_repo = ArtifactRepository(session)
|
||||
|
||||
# Initialize training components
|
||||
self.trainer = BakeryMLTrainer(database_manager=self.database_manager)
|
||||
self.trainer = EnhancedBakeryMLTrainer(database_manager=self.database_manager)
|
||||
self.date_alignment_service = DateAlignmentService()
|
||||
self.orchestrator = TrainingDataOrchestrator(
|
||||
date_alignment_service=self.date_alignment_service
|
||||
@@ -164,10 +163,8 @@ class EnhancedTrainingService:
|
||||
# Get session and initialize repositories
|
||||
async with self.database_manager.get_session() as session:
|
||||
await self._init_repositories(session)
|
||||
|
||||
|
||||
try:
|
||||
# Pre-flight check moved to orchestrator to eliminate duplicate sales data fetching
|
||||
|
||||
# Check if training log already exists, create if not
|
||||
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
||||
|
||||
@@ -187,21 +184,12 @@ class EnhancedTrainingService:
|
||||
}
|
||||
training_log = await self.training_log_repo.create_training_log(log_data)
|
||||
|
||||
# Initialize status publisher
|
||||
status_publisher = TrainingStatusPublisher(job_id, tenant_id)
|
||||
|
||||
await status_publisher.progress_update(
|
||||
progress=10,
|
||||
step="data_validation",
|
||||
step_details="Data"
|
||||
)
|
||||
|
||||
# Step 1: Prepare training dataset (includes sales data validation)
|
||||
logger.info("Step 1: Preparing and aligning training data (with validation)")
|
||||
await self.training_log_repo.update_log_progress(
|
||||
job_id, 10, "data_validation", "running"
|
||||
)
|
||||
|
||||
|
||||
# Orchestrator now handles sales data validation to eliminate duplicate fetching
|
||||
training_dataset = await self.orchestrator.prepare_training_data(
|
||||
tenant_id=tenant_id,
|
||||
@@ -210,11 +198,11 @@ class EnhancedTrainingService:
|
||||
requested_end=requested_end,
|
||||
job_id=job_id
|
||||
)
|
||||
|
||||
|
||||
# Log the results from orchestrator's unified sales data fetch
|
||||
logger.info(f"Sales data validation completed: {len(training_dataset.sales_data)} records",
|
||||
logger.info(f"Sales data validation completed: {len(training_dataset.sales_data)} records",
|
||||
tenant_id=tenant_id, job_id=job_id)
|
||||
|
||||
|
||||
await self.training_log_repo.update_log_progress(
|
||||
job_id, 30, "data_preparation_complete", "running"
|
||||
)
|
||||
@@ -224,15 +212,15 @@ class EnhancedTrainingService:
|
||||
await self.training_log_repo.update_log_progress(
|
||||
job_id, 40, "ml_training", "running"
|
||||
)
|
||||
|
||||
|
||||
training_results = await self.trainer.train_tenant_models(
|
||||
tenant_id=tenant_id,
|
||||
training_dataset=training_dataset,
|
||||
job_id=job_id
|
||||
)
|
||||
|
||||
|
||||
await self.training_log_repo.update_log_progress(
|
||||
job_id, 80, "training_complete", "running"
|
||||
job_id, 85, "training_complete", "running"
|
||||
)
|
||||
|
||||
# Step 3: Store model records using repository
|
||||
@@ -240,19 +228,21 @@ class EnhancedTrainingService:
|
||||
logger.debug("Training results structure",
|
||||
keys=list(training_results.keys()) if isinstance(training_results, dict) else "not_dict",
|
||||
training_results_type=type(training_results).__name__)
|
||||
|
||||
stored_models = await self._store_trained_models(
|
||||
tenant_id, job_id, training_results
|
||||
)
|
||||
|
||||
|
||||
await self.training_log_repo.update_log_progress(
|
||||
job_id, 90, "storing_models", "running"
|
||||
job_id, 92, "storing_models", "running"
|
||||
)
|
||||
|
||||
|
||||
# Step 4: Create performance metrics
|
||||
|
||||
await self._create_performance_metrics(
|
||||
tenant_id, stored_models, training_results
|
||||
)
|
||||
|
||||
|
||||
# Step 5: Complete training log
|
||||
final_result = {
|
||||
"job_id": job_id,
|
||||
@@ -308,11 +298,11 @@ class EnhancedTrainingService:
|
||||
await self.training_log_repo.complete_training_log(
|
||||
job_id, results=json_safe_result
|
||||
)
|
||||
|
||||
|
||||
logger.info("Enhanced training job completed successfully",
|
||||
job_id=job_id,
|
||||
models_created=len(stored_models))
|
||||
|
||||
|
||||
return self._create_detailed_training_response(final_result)
|
||||
|
||||
except Exception as e:
|
||||
@@ -460,7 +450,7 @@ class EnhancedTrainingService:
|
||||
async def get_training_status(self, job_id: str) -> Dict[str, Any]:
|
||||
"""Get training job status using repository"""
|
||||
try:
|
||||
async with self.database_manager.get_session()() as session:
|
||||
async with self.database_manager.get_session() as session:
|
||||
await self._init_repositories(session)
|
||||
|
||||
log = await self.training_log_repo.get_log_by_job_id(job_id)
|
||||
@@ -761,8 +751,4 @@ class EnhancedTrainingService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create detailed response", error=str(e))
|
||||
return final_result
|
||||
|
||||
|
||||
# Legacy compatibility alias
|
||||
TrainingService = EnhancedTrainingService
|
||||
return final_result
|
||||