REFACTOR external service client and improve WebSocket training updates

This commit is contained in:
Urtzi Alfaro
2025-10-09 14:11:02 +02:00
parent 7c72f83c51
commit 3c689b4f98
111 changed files with 13289 additions and 2374 deletions

View File

@@ -3,32 +3,14 @@ Training Service Layer
Business logic services for ML training and model management
"""
from .training_service import TrainingService
from .training_service import EnhancedTrainingService
from .training_orchestrator import TrainingDataOrchestrator
from .date_alignment_service import DateAlignmentService
from .data_client import DataClient
from .messaging import (
publish_job_progress,
publish_data_validation_started,
publish_data_validation_completed,
publish_job_step_completed,
publish_job_completed,
publish_job_failed,
TrainingStatusPublisher
)
__all__ = [
"TrainingService",
"EnhancedTrainingService",
"TrainingDataOrchestrator",
"TrainingDataOrchestrator",
"DateAlignmentService",
"DataClient",
"publish_job_progress",
"publish_data_validation_started",
"publish_data_validation_completed",
"publish_job_step_completed",
"publish_job_completed",
"publish_job_failed",
"TrainingStatusPublisher"
"DataClient"
]

View File

@@ -1,16 +1,20 @@
# services/training/app/services/data_client.py
"""
Training Service Data Client
Migrated to use shared service clients - much simpler now!
Migrated to use shared service clients with timeout configuration
"""
import structlog
from typing import Dict, Any, List, Optional
from datetime import datetime
import httpx
# Import the shared clients
from shared.clients import get_sales_client, get_external_client, get_service_clients
from app.core.config import settings
from app.core import constants as const
from app.utils.circuit_breaker import circuit_breaker_registry, CircuitBreakerError
from app.utils.retry import with_retry, HTTP_RETRY_STRATEGY, EXTERNAL_SERVICE_RETRY_STRATEGY
logger = structlog.get_logger()
@@ -21,21 +25,103 @@ class DataClient:
"""
def __init__(self):
# Get the new specialized clients
# Get the new specialized clients with timeout configuration
self.sales_client = get_sales_client(settings, "training")
self.external_client = get_external_client(settings, "training")
# Configure timeouts for HTTP clients
self._configure_timeouts()
# Initialize circuit breakers for external services
self._init_circuit_breakers()
# Check if the new method is available for stored traffic data
if hasattr(self.external_client, 'get_stored_traffic_data_for_training'):
self.supports_stored_traffic_data = True
else:
self.supports_stored_traffic_data = False
logger.warning("Stored traffic data method not available in external client")
# Or alternatively, get all clients at once:
# self.clients = get_service_clients(settings, "training")
# Then use: self.clients.sales.get_sales_data(...) and self.clients.external.get_weather_forecast(...)
def _configure_timeouts(self):
"""Configure appropriate timeouts for HTTP clients"""
timeout = httpx.Timeout(
connect=const.HTTP_TIMEOUT_DEFAULT,
read=const.HTTP_TIMEOUT_LONG_RUNNING,
write=const.HTTP_TIMEOUT_DEFAULT,
pool=const.HTTP_TIMEOUT_DEFAULT
)
# Apply timeout to clients if they have httpx clients
if hasattr(self.sales_client, 'client') and isinstance(self.sales_client.client, httpx.AsyncClient):
self.sales_client.client.timeout = timeout
if hasattr(self.external_client, 'client') and isinstance(self.external_client.client, httpx.AsyncClient):
self.external_client.client.timeout = timeout
def _init_circuit_breakers(self):
"""Initialize circuit breakers for external service calls"""
# Sales service circuit breaker
self.sales_cb = circuit_breaker_registry.get_or_create(
name="sales_service",
failure_threshold=5,
recovery_timeout=60.0,
expected_exception=Exception
)
# Weather service circuit breaker
self.weather_cb = circuit_breaker_registry.get_or_create(
name="weather_service",
failure_threshold=3, # Weather is optional, fail faster
recovery_timeout=30.0,
expected_exception=Exception
)
# Traffic service circuit breaker
self.traffic_cb = circuit_breaker_registry.get_or_create(
name="traffic_service",
failure_threshold=3, # Traffic is optional, fail faster
recovery_timeout=30.0,
expected_exception=Exception
)
@with_retry(max_attempts=3, initial_delay=1.0, max_delay=10.0)
async def _fetch_sales_data_internal(
self,
tenant_id: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
product_id: Optional[str] = None,
fetch_all: bool = True
) -> List[Dict[str, Any]]:
"""Internal method to fetch sales data with automatic retry"""
if fetch_all:
sales_data = await self.sales_client.get_all_sales_data(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
product_id=product_id,
aggregation="daily",
page_size=1000,
max_pages=100
)
else:
sales_data = await self.sales_client.get_sales_data(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
product_id=product_id,
aggregation="daily"
)
sales_data = sales_data or []
if sales_data:
logger.info(f"Fetched {len(sales_data)} sales records",
tenant_id=tenant_id, product_id=product_id, fetch_all=fetch_all)
return sales_data
else:
logger.error("No sales data returned", tenant_id=tenant_id)
raise ValueError(f"No sales data available for tenant {tenant_id}")
async def fetch_sales_data(
self,
tenant_id: str,
@@ -45,50 +131,21 @@ class DataClient:
fetch_all: bool = True
) -> List[Dict[str, Any]]:
"""
Fetch sales data for training
Args:
tenant_id: Tenant identifier
start_date: Start date in ISO format
end_date: End date in ISO format
product_id: Optional product filter
fetch_all: If True, fetches ALL records using pagination (original behavior)
If False, fetches limited records (standard API response)
Fetch sales data for training with circuit breaker protection
"""
try:
if fetch_all:
# Use paginated method to get ALL records (original behavior)
sales_data = await self.sales_client.get_all_sales_data(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
product_id=product_id,
aggregation="daily",
page_size=1000, # Comply with API limit
max_pages=100 # Safety limit (100k records max)
)
else:
# Use standard method for limited results
sales_data = await self.sales_client.get_sales_data(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
product_id=product_id,
aggregation="daily"
)
sales_data = sales_data or []
if sales_data:
logger.info(f"Fetched {len(sales_data)} sales records",
tenant_id=tenant_id, product_id=product_id, fetch_all=fetch_all)
return sales_data
else:
logger.warning("No sales data returned", tenant_id=tenant_id)
return []
return await self.sales_cb.call(
self._fetch_sales_data_internal,
tenant_id, start_date, end_date, product_id, fetch_all
)
except CircuitBreakerError as e:
logger.error(f"Sales service circuit breaker open: {e}")
raise RuntimeError(f"Sales service unavailable: {str(e)}")
except ValueError:
raise
except Exception as e:
logger.error(f"Error fetching sales data: {e}", tenant_id=tenant_id)
return []
raise RuntimeError(f"Failed to fetch sales data: {str(e)}")
async def fetch_weather_data(
self,
@@ -112,15 +169,15 @@ class DataClient:
)
if weather_data:
logger.info(f"Fetched {len(weather_data)} weather records",
tenant_id=tenant_id)
return weather_data
else:
logger.warning("No weather data returned", tenant_id=tenant_id)
logger.warning("No weather data returned, will use synthetic data", tenant_id=tenant_id)
return []
except Exception as e:
logger.error(f"Error fetching weather data: {e}", tenant_id=tenant_id)
logger.warning(f"Error fetching weather data, will use synthetic data: {e}", tenant_id=tenant_id)
return []
async def fetch_traffic_data_unified(
@@ -264,34 +321,93 @@ class DataClient:
self,
tenant_id: str,
start_date: str,
end_date: str
end_date: str,
sales_data: List[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Validate data quality before training
Validate data quality before training with comprehensive checks
"""
try:
# Note: validate_data_quality may need to be implemented in one of the new services
# validation_result = await self.sales_client.validate_data_quality(
# tenant_id=tenant_id,
# start_date=start_date,
# end_date=end_date
# )
# Temporary implementation - assume data is valid for now
validation_result = {"is_valid": True, "message": "Validation temporarily disabled"}
if validation_result:
logger.info("Data validation completed",
tenant_id=tenant_id,
is_valid=validation_result.get("is_valid", False))
return validation_result
errors = []
warnings = []
# If sales data provided, validate it directly
if sales_data is not None:
if not sales_data or len(sales_data) == 0:
errors.append("No sales data available for the specified period")
return {"is_valid": False, "errors": errors, "warnings": warnings}
# Check minimum data points
if len(sales_data) < 30:
errors.append(f"Insufficient data points: {len(sales_data)} (minimum 30 required)")
elif len(sales_data) < 90:
warnings.append(f"Limited data points: {len(sales_data)} (recommended 90+)")
# Check for required fields
required_fields = ['date', 'inventory_product_id']
for record in sales_data[:5]: # Sample check
missing = [f for f in required_fields if f not in record or record[f] is None]
if missing:
errors.append(f"Missing required fields: {missing}")
break
# Check for data quality issues
zero_count = sum(1 for r in sales_data if r.get('quantity', 0) == 0)
zero_ratio = zero_count / len(sales_data)
if zero_ratio > 0.9:
errors.append(f"Too many zero values: {zero_ratio:.1%} of records")
elif zero_ratio > 0.7:
warnings.append(f"High zero value ratio: {zero_ratio:.1%}")
# Check product diversity
unique_products = set(r.get('inventory_product_id') for r in sales_data if r.get('inventory_product_id'))
if len(unique_products) == 0:
errors.append("No valid product IDs found in sales data")
elif len(unique_products) == 1:
warnings.append("Only one product found - consider adding more products")
else:
logger.warning("Data validation failed", tenant_id=tenant_id)
return {"is_valid": False, "errors": ["Validation service unavailable"]}
# Fetch data for validation
sales_data = await self.fetch_sales_data(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
fetch_all=False
)
if not sales_data:
errors.append("Unable to fetch sales data for validation")
return {"is_valid": False, "errors": errors, "warnings": warnings}
# Recursive call with fetched data
return await self.validate_data_quality(
tenant_id, start_date, end_date, sales_data
)
is_valid = len(errors) == 0
result = {
"is_valid": is_valid,
"errors": errors,
"warnings": warnings,
"data_points": len(sales_data) if sales_data else 0,
"unique_products": len(unique_products) if sales_data else 0
}
if is_valid:
logger.info("Data validation passed",
tenant_id=tenant_id,
data_points=result["data_points"],
warnings_count=len(warnings))
else:
logger.error("Data validation failed",
tenant_id=tenant_id,
errors=errors)
return result
except Exception as e:
logger.error(f"Error validating data: {e}", tenant_id=tenant_id)
return {"is_valid": False, "errors": [str(e)]}
raise ValueError(f"Data validation failed: {str(e)}")
# Global shared instance used across the training service
data_client = DataClient()
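For illustration, a caller-side sketch of how the reworked DataClient is meant to be used: validate first, then fetch behind the circuit breaker. With failure_threshold=5 and recovery_timeout=60.0 the sales breaker opens after five consecutive failures and retries after a minute. The tenant ID and date range below are hypothetical; the error types follow the code above (RuntimeError when the breaker is open or the fetch ultimately fails, ValueError when no data exists).

import asyncio
from app.services.data_client import data_client

async def prepare_sales_data() -> list:
    # Hypothetical tenant and training window, for illustration only
    tenant_id = "tenant-demo"
    start_date, end_date = "2025-01-01", "2025-06-30"

    # Step 1: quality checks (minimum data points, required fields, zero ratio, product diversity)
    report = await data_client.validate_data_quality(
        tenant_id=tenant_id, start_date=start_date, end_date=end_date
    )
    if not report["is_valid"]:
        raise ValueError(f"Cannot start training: {report['errors']}")

    # Step 2: paginated fetch, retried up to 3 times and guarded by the sales circuit breaker
    try:
        return await data_client.fetch_sales_data(
            tenant_id=tenant_id, start_date=start_date, end_date=end_date, fetch_all=True
        )
    except RuntimeError as exc:
        # Breaker open (sales service down) or fetch failed after retries
        raise RuntimeError(f"Training data unavailable: {exc}")

# asyncio.run(prepare_sales_data())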

View File

@@ -1,9 +1,9 @@
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
import logging
from datetime import datetime, timedelta, timezone
from app.utils.timezone_utils import ensure_timezone_aware
logger = logging.getLogger(__name__)
@@ -84,31 +84,25 @@ class DateAlignmentService:
requested_end: Optional[datetime]
) -> DateRange:
"""Determine the base date range for training."""
# ✅ FIX: Ensure all datetimes are timezone-aware for comparison
def ensure_timezone_aware(dt: datetime) -> datetime:
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt
# Use explicit dates if provided
if requested_start and requested_end:
requested_start = ensure_timezone_aware(requested_start)
requested_end = ensure_timezone_aware(requested_end)
if requested_end <= requested_start:
raise ValueError("End date must be after start date")
return DateRange(requested_start, requested_end, DataSourceType.BAKERY_SALES)
# Otherwise, use the user's sales data range as the foundation
start_date = ensure_timezone_aware(requested_start or user_sales_range.start)
end_date = ensure_timezone_aware(requested_end or user_sales_range.end)
# Ensure we don't exceed maximum training range
if (end_date - start_date).days > self.MAX_TRAINING_RANGE_DAYS:
start_date = end_date - timedelta(days=self.MAX_TRAINING_RANGE_DAYS)
logger.warning(f"Limiting training range to {self.MAX_TRAINING_RANGE_DAYS} days")
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
def _apply_data_source_constraints(self, base_range: DateRange) -> AlignedDateRange:
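The inline ensure_timezone_aware helper removed above is now imported from app.utils.timezone_utils (see the import at the top of the file). Assuming the shared helper keeps the semantics of the deleted inline version, the normalization and the MAX_TRAINING_RANGE_DAYS cap behave roughly as in this sketch; the 730-day limit is illustrative and not taken from the code.

from datetime import datetime, timedelta, timezone

def ensure_timezone_aware(dt: datetime) -> datetime:
    # Naive datetimes are treated as UTC; aware datetimes pass through unchanged
    return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt

MAX_TRAINING_RANGE_DAYS = 730  # illustrative value only

start = ensure_timezone_aware(datetime(2023, 1, 1))                      # naive -> UTC-aware
end = ensure_timezone_aware(datetime(2025, 10, 1, tzinfo=timezone.utc))  # already aware, unchanged

# Cap overly long windows, keeping the end date fixed (as _determine_base_range does)
if (end - start).days > MAX_TRAINING_RANGE_DAYS:
    start = end - timedelta(days=MAX_TRAINING_RANGE_DAYS)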

View File

@@ -1,603 +0,0 @@
# services/training/app/services/messaging.py
"""
Enhanced training service messaging - Complete status publishing implementation
Uses shared RabbitMQ infrastructure with comprehensive progress tracking
"""
import structlog
from typing import Dict, Any, Optional, List
from datetime import datetime
from shared.messaging.rabbitmq import RabbitMQClient
from shared.messaging.events import (
TrainingStartedEvent,
TrainingCompletedEvent,
TrainingFailedEvent
)
from app.core.config import settings
import json
import numpy as np
logger = structlog.get_logger()
# Single global instance
training_publisher = RabbitMQClient(settings.RABBITMQ_URL, "training-service")
async def setup_messaging():
"""Initialize messaging for training service"""
success = await training_publisher.connect()
if success:
logger.info("Training service messaging initialized")
else:
logger.warning("Training service messaging failed to initialize")
async def cleanup_messaging():
"""Cleanup messaging for training service"""
await training_publisher.disconnect()
logger.info("Training service messaging cleaned up")
def serialize_for_json(obj: Any) -> Any:
"""
Convert numpy types and other non-JSON serializable objects to JSON-compatible types
"""
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, np.bool_):
return bool(obj)
elif isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, dict):
return {key: serialize_for_json(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)):
return [serialize_for_json(item) for item in obj]
else:
return obj
def safe_json_serialize(data: Dict[str, Any]) -> Dict[str, Any]:
"""
Recursively clean data dictionary for JSON serialization
"""
return serialize_for_json(data)
async def setup_websocket_message_routing():
"""Set up message routing for WebSocket connections"""
try:
# This will be called from the WebSocket endpoint
# to set up the consumer for a specific job
pass
except Exception as e:
logger.error(f"Failed to set up WebSocket message routing: {e}")
# =========================================
# ENHANCED TRAINING JOB STATUS EVENTS
# =========================================
async def publish_job_started(job_id: str, tenant_id: str, config: Dict[str, Any]) -> bool:
"""Publish training job started event"""
event = TrainingStartedEvent(
service_name="training-service",
data={
"job_id": job_id,
"tenant_id": tenant_id,
"config": config,
"started_at": datetime.now().isoformat(),
"estimated_duration_minutes": config.get("estimated_duration_minutes", 15)
}
)
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.started",
event_data=event.to_dict()
)
if success:
logger.info(f"Published job started event", job_id=job_id, tenant_id=tenant_id)
else:
logger.error(f"Failed to publish job started event", job_id=job_id)
return success
async def publish_job_progress(
job_id: str,
tenant_id: str,
progress: int,
step: str,
current_product: Optional[str] = None,
products_completed: int = 0,
products_total: int = 0,
estimated_time_remaining_minutes: Optional[int] = None,
step_details: Optional[str] = None
) -> bool:
"""Publish detailed training job progress event with safe serialization"""
event_data = {
"service_name": "training-service",
"event_type": "training.progress",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"progress": min(max(int(progress), 0), 100), # Ensure int, not numpy.int64
"current_step": step,
"current_product": current_product,
"products_completed": int(products_completed), # Convert numpy types
"products_total": int(products_total),
"estimated_time_remaining_minutes": int(estimated_time_remaining_minutes) if estimated_time_remaining_minutes else None,
"step_details": step_details
}
}
# Clean the entire event data
clean_event_data = safe_json_serialize(event_data)
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.progress",
event_data=clean_event_data
)
if success:
logger.info(f"Published progress update",
job_id=job_id,
progress=progress,
step=step,
current_product=current_product)
else:
logger.error(f"Failed to publish progress update", job_id=job_id)
return success
async def publish_job_step_completed(
job_id: str,
tenant_id: str,
step_name: str,
step_result: Dict[str, Any],
progress: int
) -> bool:
"""Publish when a major training step is completed"""
event_data = {
"service_name": "training-service",
"event_type": "training.step.completed",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"step_name": step_name,
"step_result": step_result,
"progress": progress,
"completed_at": datetime.now().isoformat()
}
}
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.step.completed",
event_data=event_data
)
async def publish_job_completed(job_id: str, tenant_id: str, results: Dict[str, Any]) -> bool:
"""Publish training job completed event with safe JSON serialization"""
# Clean the results data before creating the event
clean_results = safe_json_serialize(results)
event = TrainingCompletedEvent(
service_name="training-service",
data={
"job_id": job_id,
"tenant_id": tenant_id,
"results": clean_results, # Now safe for JSON
"models_trained": clean_results.get("successful_trainings", 0),
"success_rate": clean_results.get("success_rate", 0),
"total_duration_seconds": clean_results.get("overall_training_time_seconds", 0),
"completed_at": datetime.now().isoformat()
}
)
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.completed",
event_data=event.to_dict()
)
if success:
logger.info(f"Published job completed event",
job_id=job_id,
models_trained=clean_results.get("successful_trainings", 0))
else:
logger.error(f"Failed to publish job completed event", job_id=job_id)
return success
async def publish_job_failed(job_id: str, tenant_id: str, error: str, error_details: Optional[Dict] = None) -> bool:
"""Publish training job failed event"""
event = TrainingFailedEvent(
service_name="training-service",
data={
"job_id": job_id,
"tenant_id": tenant_id,
"error": error,
"error_details": error_details or {},
"failed_at": datetime.now().isoformat()
}
)
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.failed",
event_data=event.to_dict()
)
if success:
logger.info(f"Published job failed event", job_id=job_id, error=error)
else:
logger.error(f"Failed to publish job failed event", job_id=job_id)
return success
async def publish_job_cancelled(job_id: str, tenant_id: str, reason: str = "User requested") -> bool:
"""Publish training job cancelled event"""
event_data = {
"service_name": "training-service",
"event_type": "training.cancelled",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"reason": reason,
"cancelled_at": datetime.now().isoformat()
}
}
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.cancelled",
event_data=event_data
)
# =========================================
# PRODUCT-LEVEL TRAINING EVENTS
# =========================================
async def publish_product_training_started(job_id: str, tenant_id: str, inventory_product_id: str) -> bool:
"""Publish single product training started event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.product.started",
event_data={
"service_name": "training-service",
"event_type": "training.product.started",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"started_at": datetime.now().isoformat()
}
}
)
async def publish_product_training_completed(
job_id: str,
tenant_id: str,
inventory_product_id: str,
model_id: str,
metrics: Optional[Dict[str, float]] = None
) -> bool:
"""Publish single product training completed event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.product.completed",
event_data={
"service_name": "training-service",
"event_type": "training.product.completed",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"model_id": model_id,
"metrics": metrics or {},
"completed_at": datetime.now().isoformat()
}
}
)
async def publish_product_training_failed(
job_id: str,
tenant_id: str,
inventory_product_id: str,
error: str
) -> bool:
"""Publish single product training failed event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.product.failed",
event_data={
"service_name": "training-service",
"event_type": "training.product.failed",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"error": error,
"failed_at": datetime.now().isoformat()
}
}
)
# =========================================
# MODEL LIFECYCLE EVENTS
# =========================================
async def publish_model_trained(model_id: str, tenant_id: str, inventory_product_id: str, metrics: Dict[str, float]) -> bool:
"""Publish model trained event with safe metric serialization"""
# Clean metrics to ensure JSON serialization
clean_metrics = safe_json_serialize(metrics) if metrics else {}
event_data = {
"service_name": "training-service",
"event_type": "training.model.trained",
"timestamp": datetime.now().isoformat(),
"data": {
"model_id": model_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"training_metrics": clean_metrics, # Now safe for JSON
"trained_at": datetime.now().isoformat()
}
}
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.model.trained",
event_data=event_data
)
async def publish_model_validated(model_id: str, tenant_id: str, inventory_product_id: str, validation_results: Dict[str, Any]) -> bool:
"""Publish model validation event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.model.validated",
event_data={
"service_name": "training-service",
"event_type": "training.model.validated",
"timestamp": datetime.now().isoformat(),
"data": {
"model_id": model_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"validation_results": validation_results,
"validated_at": datetime.now().isoformat()
}
}
)
async def publish_model_saved(model_id: str, tenant_id: str, inventory_product_id: str, model_path: str) -> bool:
"""Publish model saved event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.model.saved",
event_data={
"service_name": "training-service",
"event_type": "training.model.saved",
"timestamp": datetime.now().isoformat(),
"data": {
"model_id": model_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"model_path": model_path,
"saved_at": datetime.now().isoformat()
}
}
)
# =========================================
# DATA PROCESSING EVENTS
# =========================================
async def publish_data_validation_started(job_id: str, tenant_id: str, products: List[str]) -> bool:
"""Publish data validation started event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.data.validation.started",
event_data={
"service_name": "training-service",
"event_type": "training.data.validation.started",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"products": products,
"started_at": datetime.now().isoformat()
}
}
)
async def publish_data_validation_completed(
job_id: str,
tenant_id: str,
validation_results: Dict[str, Any]
) -> bool:
"""Publish data validation completed event"""
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.data.validation.completed",
event_data={
"service_name": "training-service",
"event_type": "training.data.validation.completed",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"validation_results": validation_results,
"completed_at": datetime.now().isoformat()
}
}
)
async def publish_models_deleted_event(tenant_id: str, deletion_stats: Dict[str, Any]):
"""Publish models deletion event to message queue"""
try:
await training_publisher.publish_event(
exchange="training_events",
routing_key="training.tenant.models.deleted",
message={
"event_type": "tenant_models_deleted",
"tenant_id": tenant_id,
"timestamp": datetime.utcnow().isoformat(),
"deletion_stats": deletion_stats
}
)
except Exception as e:
logger.error("Failed to publish models deletion event", error=str(e))
# =========================================
# UTILITY FUNCTIONS FOR BATCH PUBLISHING
# =========================================
async def publish_batch_status_update(
job_id: str,
tenant_id: str,
updates: List[Dict[str, Any]]
) -> bool:
"""Publish multiple status updates as a batch"""
batch_event = {
"service_name": "training-service",
"event_type": "training.batch.update",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"updates": updates,
"batch_size": len(updates)
}
}
return await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.batch.update",
event_data=batch_event
)
# =========================================
# HELPER FUNCTIONS FOR TRAINING INTEGRATION
# =========================================
class TrainingStatusPublisher:
"""Helper class to manage training status publishing throughout the training process"""
def __init__(self, job_id: str, tenant_id: str):
self.job_id = job_id
self.tenant_id = tenant_id
self.start_time = datetime.now()
self.products_total = 0
self.products_completed = 0
async def job_started(self, config: Dict[str, Any], products_total: int = 0):
"""Publish job started with initial configuration"""
self.products_total = products_total
# Clean config data
clean_config = safe_json_serialize(config)
await publish_job_started(self.job_id, self.tenant_id, clean_config)
async def progress_update(
self,
progress: int,
step: str,
current_product: Optional[str] = None,
step_details: Optional[str] = None
):
"""Publish progress update with improved time estimates"""
elapsed_minutes = (datetime.now() - self.start_time).total_seconds() / 60
# Improved estimation based on training phases
estimated_remaining = self._calculate_smart_time_remaining(progress, elapsed_minutes, step)
await publish_job_progress(
job_id=self.job_id,
tenant_id=self.tenant_id,
progress=int(progress),
step=step,
current_product=current_product,
products_completed=int(self.products_completed),
products_total=int(self.products_total),
estimated_time_remaining_minutes=int(estimated_remaining) if estimated_remaining else None,
step_details=step_details
)
def _calculate_smart_time_remaining(self, progress: int, elapsed_minutes: float, step: str) -> Optional[int]:
"""Calculate estimated time remaining using phase-based estimation"""
# Define expected time distribution for each phase
phase_durations = {
"data_validation": 1.0, # 1 minute
"feature_engineering": 2.0, # 2 minutes
"model_training": 8.0, # 8 minutes (bulk of time)
"model_validation": 1.0 # 1 minute
}
total_expected_minutes = sum(phase_durations.values()) # 12 minutes
# Calculate progress through phases
if progress <= 10: # data_validation phase
remaining_in_phase = phase_durations["data_validation"] * (1 - (progress / 10))
remaining_after_phase = sum(list(phase_durations.values())[1:])
return int(remaining_in_phase + remaining_after_phase)
elif progress <= 20: # feature_engineering phase
remaining_in_phase = phase_durations["feature_engineering"] * (1 - ((progress - 10) / 10))
remaining_after_phase = sum(list(phase_durations.values())[2:])
return int(remaining_in_phase + remaining_after_phase)
elif progress <= 90: # model_training phase (biggest chunk)
remaining_in_phase = phase_durations["model_training"] * (1 - ((progress - 20) / 70))
remaining_after_phase = phase_durations["model_validation"]
return int(remaining_in_phase + remaining_after_phase)
elif progress <= 100: # model_validation phase
remaining_in_phase = phase_durations["model_validation"] * (1 - ((progress - 90) / 10))
return int(remaining_in_phase)
return 0
async def product_completed(self, inventory_product_id: str, model_id: str, metrics: Optional[Dict] = None):
"""Mark a product as completed and update progress"""
self.products_completed += 1
# Clean metrics before publishing
clean_metrics = safe_json_serialize(metrics) if metrics else None
await publish_product_training_completed(
self.job_id, self.tenant_id, inventory_product_id, model_id, clean_metrics
)
# Update overall progress
if self.products_total > 0:
progress = int((self.products_completed / self.products_total) * 90) # Save 10% for final steps
await self.progress_update(
progress=progress,
step=f"Completed training for {inventory_product_id}",
current_product=None
)
async def job_completed(self, results: Dict[str, Any]):
"""Publish job completion with clean data"""
clean_results = safe_json_serialize(results)
await publish_job_completed(self.job_id, self.tenant_id, clean_results)
async def job_failed(self, error: str, error_details: Optional[Dict] = None):
"""Publish job failure with clean error details"""
clean_error_details = safe_json_serialize(error_details) if error_details else None
await publish_job_failed(self.job_id, self.tenant_id, error, clean_error_details)

View File

@@ -0,0 +1,78 @@
"""
Training Progress Tracker
Manages progress calculation for parallel product training (20-80% range)
"""
import asyncio
import structlog
from typing import Optional
from app.services.training_events import publish_product_training_completed
logger = structlog.get_logger()
class ParallelProductProgressTracker:
"""
Tracks parallel product training progress and emits events.
For N products training in parallel:
- Each product completion contributes 60/N% to overall progress
- Progress range: 20% (after data analysis) to 80% (before completion)
- Thread-safe for concurrent product trainings
"""
def __init__(self, job_id: str, tenant_id: str, total_products: int):
self.job_id = job_id
self.tenant_id = tenant_id
self.total_products = total_products
self.products_completed = 0
self._lock = asyncio.Lock()
# Calculate progress increment per product
# 60% of total progress (from 20% to 80%) divided by number of products
self.progress_per_product = 60 / total_products if total_products > 0 else 0
logger.info("ParallelProductProgressTracker initialized",
job_id=job_id,
total_products=total_products,
progress_per_product=f"{self.progress_per_product:.2f}%")
async def mark_product_completed(self, product_name: str) -> int:
"""
Mark a product as completed and publish event.
Returns the current overall progress percentage.
"""
async with self._lock:
self.products_completed += 1
current_progress = self.products_completed
# Publish product completion event
await publish_product_training_completed(
job_id=self.job_id,
tenant_id=self.tenant_id,
product_name=product_name,
products_completed=current_progress,
total_products=self.total_products
)
# Calculate overall progress (20% base + share of the 60% training band)
# The frontend/consumer mirrors this calculation from the published event data
overall_progress = 20 + int((current_progress / self.total_products) * 60)
logger.info("Product training completed",
job_id=self.job_id,
product_name=product_name,
products_completed=current_progress,
total_products=self.total_products,
overall_progress=overall_progress)
return overall_progress
def get_progress(self) -> dict:
"""Get current progress summary"""
return {
"products_completed": self.products_completed,
"total_products": self.total_products,
"progress_percentage": 20 + int((self.products_completed / self.total_products) * 60)
}
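A usage sketch for the tracker (product names, job/tenant IDs, and the train_product coroutine are hypothetical; the module path is assumed). With four products each completion contributes 15 percentage points, so the third completion reports 20 + int(3/4 * 60) = 65.

import asyncio
from app.services.progress_tracker import ParallelProductProgressTracker  # assumed module path

async def train_product(name: str) -> None:
    await asyncio.sleep(0.1)  # stand-in for the real per-product training

async def run_job() -> None:
    products = ["croissant", "baguette", "sourdough", "brioche"]  # hypothetical
    tracker = ParallelProductProgressTracker("job-123", "tenant-demo", total_products=len(products))

    async def train_and_report(name: str) -> int:
        await train_product(name)
        # Publishes training.product.completed and returns the 20-80% overall progress
        return await tracker.mark_product_completed(name)

    await asyncio.gather(*(train_and_report(p) for p in products))
    print(tracker.get_progress())  # {'products_completed': 4, 'total_products': 4, 'progress_percentage': 80}

# asyncio.run(run_job())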

View File

@@ -0,0 +1,238 @@
"""
Training Progress Events Publisher
Simple, clean event publisher for the 4 main training steps
"""
import structlog
from datetime import datetime
from typing import Dict, Any, Optional
from shared.messaging.rabbitmq import RabbitMQClient
from app.core.config import settings
logger = structlog.get_logger()
# Single global publisher instance
training_publisher = RabbitMQClient(settings.RABBITMQ_URL, "training-service")
async def setup_messaging():
"""Initialize messaging"""
success = await training_publisher.connect()
if success:
logger.info("Training messaging initialized")
else:
logger.warning("Training messaging failed to initialize")
return success
async def cleanup_messaging():
"""Cleanup messaging"""
await training_publisher.disconnect()
logger.info("Training messaging cleaned up")
# ==========================================
# 4 MAIN TRAINING PROGRESS EVENTS
# ==========================================
async def publish_training_started(
job_id: str,
tenant_id: str,
total_products: int
) -> bool:
"""
Event 1: Training Started (0% progress)
"""
event_data = {
"service_name": "training-service",
"event_type": "training.started",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"progress": 0,
"current_step": "Training Started",
"step_details": f"Starting training for {total_products} products",
"total_products": total_products
}
}
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.started",
event_data=event_data
)
if success:
logger.info("Published training started event",
job_id=job_id,
tenant_id=tenant_id,
total_products=total_products)
else:
logger.error("Failed to publish training started event", job_id=job_id)
return success
async def publish_data_analysis(
job_id: str,
tenant_id: str,
analysis_details: Optional[str] = None
) -> bool:
"""
Event 2: Data Analysis (20% progress)
"""
event_data = {
"service_name": "training-service",
"event_type": "training.progress",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"progress": 20,
"current_step": "Data Analysis",
"step_details": analysis_details or "Analyzing sales, weather, and traffic data"
}
}
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.progress",
event_data=event_data
)
if success:
logger.info("Published data analysis event",
job_id=job_id,
progress=20)
else:
logger.error("Failed to publish data analysis event", job_id=job_id)
return success
async def publish_product_training_completed(
job_id: str,
tenant_id: str,
product_name: str,
products_completed: int,
total_products: int
) -> bool:
"""
Event 3: Product Training Completed (contributes to 20-80% progress)
This event is published each time a product training completes.
The frontend/consumer will calculate the progress as:
progress = 20 + (products_completed / total_products) * 60
"""
event_data = {
"service_name": "training-service",
"event_type": "training.product.completed",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"product_name": product_name,
"products_completed": products_completed,
"total_products": total_products,
"current_step": "Model Training",
"step_details": f"Completed training for {product_name} ({products_completed}/{total_products})"
}
}
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.product.completed",
event_data=event_data
)
if success:
logger.info("Published product training completed event",
job_id=job_id,
product_name=product_name,
products_completed=products_completed,
total_products=total_products)
else:
logger.error("Failed to publish product training completed event",
job_id=job_id)
return success
async def publish_training_completed(
job_id: str,
tenant_id: str,
successful_trainings: int,
failed_trainings: int,
total_duration_seconds: float
) -> bool:
"""
Event 4: Training Completed (100% progress)
"""
event_data = {
"service_name": "training-service",
"event_type": "training.completed",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"progress": 100,
"current_step": "Training Completed",
"step_details": f"Training completed: {successful_trainings} successful, {failed_trainings} failed",
"successful_trainings": successful_trainings,
"failed_trainings": failed_trainings,
"total_duration_seconds": total_duration_seconds
}
}
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.completed",
event_data=event_data
)
if success:
logger.info("Published training completed event",
job_id=job_id,
successful_trainings=successful_trainings,
failed_trainings=failed_trainings)
else:
logger.error("Failed to publish training completed event", job_id=job_id)
return success
async def publish_training_failed(
job_id: str,
tenant_id: str,
error_message: str
) -> bool:
"""
Event: Training Failed
"""
event_data = {
"service_name": "training-service",
"event_type": "training.failed",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"current_step": "Training Failed",
"error_message": error_message
}
}
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.failed",
event_data=event_data
)
if success:
logger.info("Published training failed event",
job_id=job_id,
error=error_message)
else:
logger.error("Failed to publish training failed event", job_id=job_id)
return success
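On the consuming side (for example the WebSocket gateway relaying status to the UI), the progress bar can be derived from these payloads using the formula documented in publish_product_training_completed. A minimal sketch; the event types and field names are the ones emitted above:

from typing import Any, Dict

def progress_from_event(event: Dict[str, Any]) -> int:
    """Map a training event payload to a 0-100 progress value."""
    data = event.get("data", {})
    event_type = event.get("event_type", "")

    if event_type == "training.started":
        return 0
    if event_type == "training.progress":            # data analysis publishes progress=20
        return int(data.get("progress", 0))
    if event_type == "training.product.completed":   # 20% base + share of the 60% training band
        done = int(data.get("products_completed", 0))
        total = max(int(data.get("total_products", 1)), 1)
        return 20 + int((done / total) * 60)
    if event_type == "training.completed":
        return 100
    return 0  # training.failed and unknown events carry no progress value

# progress_from_event({"event_type": "training.product.completed",
#                      "data": {"products_completed": 3, "total_products": 4}})  # -> 65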

View File

@@ -16,13 +16,7 @@ import pandas as pd
from app.services.data_client import DataClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
from app.services.messaging import (
publish_job_progress,
publish_data_validation_started,
publish_data_validation_completed,
publish_job_step_completed,
publish_job_failed
)
from app.services.training_events import publish_training_failed
logger = structlog.get_logger()
@@ -76,7 +70,6 @@ class TrainingDataOrchestrator:
# Step 1: Fetch and validate sales data (unified approach)
sales_data = await self.data_client.fetch_sales_data(tenant_id, fetch_all=True)
# Pre-flight validation moved here to eliminate duplicate fetching
if not sales_data or len(sales_data) == 0:
error_msg = f"No sales data available for tenant {tenant_id}. Please import sales data before starting training."
logger.error("Training aborted - no sales data", tenant_id=tenant_id, job_id=job_id)
@@ -172,7 +165,8 @@ class TrainingDataOrchestrator:
return training_dataset
except Exception as e:
publish_job_failed(job_id, tenant_id, str(e))
if job_id and tenant_id:
await publish_training_failed(job_id, tenant_id, str(e))
logger.error(f"Training data preparation failed: {str(e)}")
raise ValueError(f"Failed to prepare training data: {str(e)}")
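The old except block called publish_job_failed(...) without awaiting it, so the coroutine was created but never executed and no failure event reached the queue; the new code awaits publish_training_failed, and only when job_id and tenant_id are set. A minimal illustration of the difference, using a stand-in publisher:

import asyncio

async def publish_failed(job_id: str, error: str) -> bool:
    # Stand-in for publish_training_failed
    print(f"published failure for {job_id}: {error}")
    return True

async def old_behaviour() -> None:
    publish_failed("job-123", "boom")        # coroutine never awaited: nothing is published,
                                             # Python warns "coroutine ... was never awaited"

async def new_behaviour() -> None:
    await publish_failed("job-123", "boom")  # event is actually published

# asyncio.run(new_behaviour())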
@@ -472,30 +466,18 @@ class TrainingDataOrchestrator:
logger.warning(f"Enhanced traffic data collection failed: {e}")
return []
# Keep original method for backwards compatibility
async def _collect_traffic_data_with_timeout(
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange,
tenant_id: str
) -> List[Dict[str, Any]]:
"""Legacy traffic data collection method - redirects to enhanced version"""
return await self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)
def _log_enhanced_traffic_data_storage(self,
lat: float,
lon: float,
aligned_range: AlignedDateRange,
record_count: int,
traffic_data: List[Dict[str, Any]]):
"""Enhanced logging for traffic data storage with detailed metadata"""
# Analyze the stored data for additional insights
cities_detected = set()
has_pedestrian_data = 0
data_sources = set()
districts_covered = set()
for record in traffic_data:
if 'city' in record and record['city']:
cities_detected.add(record['city'])
@@ -505,7 +487,7 @@ class TrainingDataOrchestrator:
data_sources.add(record['source'])
if 'district' in record and record['district']:
districts_covered.add(record['district'])
logger.info(
"Enhanced traffic data stored for re-training",
location=f"{lat:.4f},{lon:.4f}",
@@ -516,20 +498,9 @@ class TrainingDataOrchestrator:
data_sources=list(data_sources),
districts_covered=list(districts_covered),
storage_timestamp=datetime.now().isoformat(),
purpose="enhanced_model_training_and_retraining",
architecture_version="2.0_abstracted"
purpose="model_training_and_retraining"
)
def _log_traffic_data_storage(self,
lat: float,
lon: float,
aligned_range: AlignedDateRange,
record_count: int):
"""Legacy logging method - redirects to enhanced version"""
# Create minimal traffic data structure for enhanced logging
minimal_traffic_data = [{"city": "madrid", "source": "legacy"}] * min(record_count, 1)
self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, record_count, minimal_traffic_data)
def _validate_weather_data(self, weather_data: List[Dict[str, Any]]) -> bool:
"""Validate weather data quality"""
if not weather_data:

View File

@@ -13,10 +13,9 @@ import json
import numpy as np
import pandas as pd
from app.ml.trainer import BakeryMLTrainer
from app.ml.trainer import EnhancedBakeryMLTrainer
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
from app.services.training_orchestrator import TrainingDataOrchestrator
from app.services.messaging import TrainingStatusPublisher
# Import repositories
from app.repositories import (
@@ -119,7 +118,7 @@ class EnhancedTrainingService:
self.artifact_repo = ArtifactRepository(session)
# Initialize training components
self.trainer = BakeryMLTrainer(database_manager=self.database_manager)
self.trainer = EnhancedBakeryMLTrainer(database_manager=self.database_manager)
self.date_alignment_service = DateAlignmentService()
self.orchestrator = TrainingDataOrchestrator(
date_alignment_service=self.date_alignment_service
@@ -164,10 +163,8 @@ class EnhancedTrainingService:
# Get session and initialize repositories
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
try:
# Pre-flight check moved to orchestrator to eliminate duplicate sales data fetching
# Check if training log already exists, create if not
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
@@ -187,21 +184,12 @@ class EnhancedTrainingService:
}
training_log = await self.training_log_repo.create_training_log(log_data)
# Initialize status publisher
status_publisher = TrainingStatusPublisher(job_id, tenant_id)
await status_publisher.progress_update(
progress=10,
step="data_validation",
step_details="Data"
)
# Step 1: Prepare training dataset (includes sales data validation)
logger.info("Step 1: Preparing and aligning training data (with validation)")
await self.training_log_repo.update_log_progress(
job_id, 10, "data_validation", "running"
)
# Orchestrator now handles sales data validation to eliminate duplicate fetching
training_dataset = await self.orchestrator.prepare_training_data(
tenant_id=tenant_id,
@@ -210,11 +198,11 @@ class EnhancedTrainingService:
requested_end=requested_end,
job_id=job_id
)
# Log the results from orchestrator's unified sales data fetch
logger.info(f"Sales data validation completed: {len(training_dataset.sales_data)} records",
tenant_id=tenant_id, job_id=job_id)
await self.training_log_repo.update_log_progress(
job_id, 30, "data_preparation_complete", "running"
)
@@ -224,15 +212,15 @@ class EnhancedTrainingService:
await self.training_log_repo.update_log_progress(
job_id, 40, "ml_training", "running"
)
training_results = await self.trainer.train_tenant_models(
tenant_id=tenant_id,
training_dataset=training_dataset,
job_id=job_id
)
await self.training_log_repo.update_log_progress(
job_id, 80, "training_complete", "running"
job_id, 85, "training_complete", "running"
)
# Step 3: Store model records using repository
@@ -240,19 +228,21 @@ class EnhancedTrainingService:
logger.debug("Training results structure",
keys=list(training_results.keys()) if isinstance(training_results, dict) else "not_dict",
training_results_type=type(training_results).__name__)
stored_models = await self._store_trained_models(
tenant_id, job_id, training_results
)
await self.training_log_repo.update_log_progress(
job_id, 90, "storing_models", "running"
job_id, 92, "storing_models", "running"
)
# Step 4: Create performance metrics
await self._create_performance_metrics(
tenant_id, stored_models, training_results
)
# Step 5: Complete training log
final_result = {
"job_id": job_id,
@@ -308,11 +298,11 @@ class EnhancedTrainingService:
await self.training_log_repo.complete_training_log(
job_id, results=json_safe_result
)
logger.info("Enhanced training job completed successfully",
job_id=job_id,
models_created=len(stored_models))
return self._create_detailed_training_response(final_result)
except Exception as e:
@@ -460,7 +450,7 @@ class EnhancedTrainingService:
async def get_training_status(self, job_id: str) -> Dict[str, Any]:
"""Get training job status using repository"""
try:
async with self.database_manager.get_session()() as session:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
log = await self.training_log_repo.get_log_by_job_id(job_id)
@@ -761,8 +751,4 @@ class EnhancedTrainingService:
except Exception as e:
logger.error("Failed to create detailed response", error=str(e))
return final_result
# Legacy compatibility alias
TrainingService = EnhancedTrainingService
return final_result
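Pieced together from the update_log_progress calls above, the training log now moves through roughly these milestones during a job (the 85 and 92 values replace the previous 80 and 90; step names are the ones used in the code):

# Approximate progress values written to the training log (per the calls shown above)
TRAINING_LOG_MILESTONES = {
    "data_validation": 10,
    "data_preparation_complete": 30,
    "ml_training": 40,
    "training_complete": 85,   # was 80 before this commit
    "storing_models": 92,      # was 90 before this commit
}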