Add all the code for training service
This commit is contained in:
467
services/training/tests/test_messaging.py
Normal file
467
services/training/tests/test_messaging.py
Normal file
@@ -0,0 +1,467 @@
|
||||
# services/training/tests/test_messaging.py
|
||||
"""
|
||||
Tests for training service messaging functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
import json
|
||||
|
||||
from app.services import messaging
|
||||
|
||||
|
||||
class TestTrainingMessaging:
|
||||
"""Test training service messaging functions"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_publisher(self):
|
||||
"""Mock the RabbitMQ publisher"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
mock_pub.connect = AsyncMock(return_value=True)
|
||||
mock_pub.disconnect = AsyncMock(return_value=None)
|
||||
yield mock_pub
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_messaging_success(self, mock_publisher):
|
||||
"""Test successful messaging setup"""
|
||||
await messaging.setup_messaging()
|
||||
|
||||
mock_publisher.connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_messaging_failure(self, mock_publisher):
|
||||
"""Test messaging setup failure"""
|
||||
mock_publisher.connect.return_value = False
|
||||
|
||||
await messaging.setup_messaging()
|
||||
|
||||
mock_publisher.connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_messaging(self, mock_publisher):
|
||||
"""Test messaging cleanup"""
|
||||
await messaging.cleanup_messaging()
|
||||
|
||||
mock_publisher.disconnect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_started(self, mock_publisher):
|
||||
"""Test publishing job started event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
config = {"include_weather": True}
|
||||
|
||||
result = await messaging.publish_job_started(job_id, tenant_id, config)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
# Check call arguments
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["exchange_name"] == "training.events"
|
||||
assert call_args[1]["routing_key"] == "training.started"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["service_name"] == "training-service"
|
||||
assert event_data["data"]["job_id"] == job_id
|
||||
assert event_data["data"]["tenant_id"] == tenant_id
|
||||
assert event_data["data"]["config"] == config
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_progress(self, mock_publisher):
|
||||
"""Test publishing job progress event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
progress = 50
|
||||
step = "Training models"
|
||||
|
||||
result = await messaging.publish_job_progress(job_id, tenant_id, progress, step)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.progress"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["progress"] == progress
|
||||
assert event_data["data"]["current_step"] == step
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_completed(self, mock_publisher):
|
||||
"""Test publishing job completed event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
results = {
|
||||
"products_trained": 3,
|
||||
"summary": {"success_rate": 100.0}
|
||||
}
|
||||
|
||||
result = await messaging.publish_job_completed(job_id, tenant_id, results)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.completed"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["results"] == results
|
||||
assert event_data["data"]["models_trained"] == 3
|
||||
assert event_data["data"]["success_rate"] == 100.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_failed(self, mock_publisher):
|
||||
"""Test publishing job failed event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
error = "Data service unavailable"
|
||||
|
||||
result = await messaging.publish_job_failed(job_id, tenant_id, error)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.failed"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["error"] == error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_cancelled(self, mock_publisher):
|
||||
"""Test publishing job cancelled event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
result = await messaging.publish_job_cancelled(job_id, tenant_id)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.cancelled"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_product_training_started(self, mock_publisher):
|
||||
"""Test publishing product training started event"""
|
||||
job_id = "test-product-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
|
||||
result = await messaging.publish_product_training_started(job_id, tenant_id, product_name)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.product.started"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["product_name"] == product_name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_product_training_completed(self, mock_publisher):
|
||||
"""Test publishing product training completed event"""
|
||||
job_id = "test-product-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
model_id = "test-model-123"
|
||||
|
||||
result = await messaging.publish_product_training_completed(
|
||||
job_id, tenant_id, product_name, model_id
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.product.completed"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["model_id"] == model_id
|
||||
assert event_data["data"]["product_name"] == product_name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_model_trained(self, mock_publisher):
|
||||
"""Test publishing model trained event"""
|
||||
model_id = "test-model-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
metrics = {"mae": 5.2, "rmse": 7.8, "mape": 12.5}
|
||||
|
||||
result = await messaging.publish_model_trained(model_id, tenant_id, product_name, metrics)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.model.trained"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["training_metrics"] == metrics
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_model_updated(self, mock_publisher):
|
||||
"""Test publishing model updated event"""
|
||||
model_id = "test-model-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
version = 2
|
||||
|
||||
result = await messaging.publish_model_updated(model_id, tenant_id, product_name, version)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.model.updated"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["version"] == version
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_model_validated(self, mock_publisher):
|
||||
"""Test publishing model validated event"""
|
||||
model_id = "test-model-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
validation_results = {"is_valid": True, "accuracy": 0.95}
|
||||
|
||||
result = await messaging.publish_model_validated(
|
||||
model_id, tenant_id, product_name, validation_results
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.model.validated"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["validation_results"] == validation_results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_model_saved(self, mock_publisher):
|
||||
"""Test publishing model saved event"""
|
||||
model_id = "test-model-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
model_path = "/models/test-model-123.pkl"
|
||||
|
||||
result = await messaging.publish_model_saved(model_id, tenant_id, product_name, model_path)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.model.saved"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["model_path"] == model_path
|
||||
|
||||
|
||||
class TestMessagingErrorHandling:
|
||||
"""Test error handling in messaging"""
|
||||
|
||||
@pytest.fixture
|
||||
def failing_publisher(self):
|
||||
"""Mock publisher that fails"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=False)
|
||||
mock_pub.connect = AsyncMock(return_value=False)
|
||||
yield mock_pub
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_event_failure(self, failing_publisher):
|
||||
"""Test handling of publish event failure"""
|
||||
result = await messaging.publish_job_started("job-123", "tenant-123", {})
|
||||
|
||||
assert result is False
|
||||
failing_publisher.publish_event.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_messaging_connection_failure(self, failing_publisher):
|
||||
"""Test setup with connection failure"""
|
||||
await messaging.setup_messaging()
|
||||
|
||||
failing_publisher.connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_with_exception(self):
|
||||
"""Test publishing with exception"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event.side_effect = Exception("Connection lost")
|
||||
|
||||
result = await messaging.publish_job_started("job-123", "tenant-123", {})
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestMessagingIntegration:
|
||||
"""Test messaging integration with shared components"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_structure_consistency(self):
|
||||
"""Test that events follow consistent structure"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Test different event types
|
||||
await messaging.publish_job_started("job-123", "tenant-123", {})
|
||||
await messaging.publish_job_completed("job-123", "tenant-123", {})
|
||||
await messaging.publish_model_trained("model-123", "tenant-123", "Pan", {})
|
||||
|
||||
# Verify all calls have consistent structure
|
||||
assert mock_pub.publish_event.call_count == 3
|
||||
|
||||
for call in mock_pub.publish_event.call_args_list:
|
||||
event_data = call[1]["event_data"]
|
||||
|
||||
# All events should have these fields
|
||||
assert "service_name" in event_data
|
||||
assert "event_type" in event_data
|
||||
assert "data" in event_data
|
||||
assert event_data["service_name"] == "training-service"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shared_event_classes_usage(self):
|
||||
"""Test that shared event classes are used properly"""
|
||||
with patch('shared.messaging.events.TrainingStartedEvent') as mock_event_class:
|
||||
mock_event = Mock()
|
||||
mock_event.to_dict.return_value = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.started",
|
||||
"data": {"job_id": "test-job"}
|
||||
}
|
||||
mock_event_class.return_value = mock_event
|
||||
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
await messaging.publish_job_started("test-job", "test-tenant", {})
|
||||
|
||||
# Verify shared event class was used
|
||||
mock_event_class.assert_called_once()
|
||||
mock_event.to_dict.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routing_key_consistency(self):
|
||||
"""Test that routing keys follow consistent patterns"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Test various event types
|
||||
events_and_keys = [
|
||||
(messaging.publish_job_started, "training.started"),
|
||||
(messaging.publish_job_progress, "training.progress"),
|
||||
(messaging.publish_job_completed, "training.completed"),
|
||||
(messaging.publish_job_failed, "training.failed"),
|
||||
(messaging.publish_job_cancelled, "training.cancelled"),
|
||||
(messaging.publish_product_training_started, "training.product.started"),
|
||||
(messaging.publish_product_training_completed, "training.product.completed"),
|
||||
(messaging.publish_model_trained, "training.model.trained"),
|
||||
(messaging.publish_model_updated, "training.model.updated"),
|
||||
(messaging.publish_model_validated, "training.model.validated"),
|
||||
(messaging.publish_model_saved, "training.model.saved")
|
||||
]
|
||||
|
||||
for event_func, expected_key in events_and_keys:
|
||||
mock_pub.reset_mock()
|
||||
|
||||
# Call event function with appropriate parameters
|
||||
if "progress" in expected_key:
|
||||
await event_func("job-123", "tenant-123", 50, "step")
|
||||
elif "model" in expected_key and "trained" in expected_key:
|
||||
await event_func("model-123", "tenant-123", "product", {})
|
||||
elif "model" in expected_key and "updated" in expected_key:
|
||||
await event_func("model-123", "tenant-123", "product", 1)
|
||||
elif "model" in expected_key and "validated" in expected_key:
|
||||
await event_func("model-123", "tenant-123", "product", {})
|
||||
elif "model" in expected_key and "saved" in expected_key:
|
||||
await event_func("model-123", "tenant-123", "product", "/path")
|
||||
elif "product" in expected_key and "completed" in expected_key:
|
||||
await event_func("job-123", "tenant-123", "product", "model-123")
|
||||
elif "product" in expected_key:
|
||||
await event_func("job-123", "tenant-123", "product")
|
||||
elif "failed" in expected_key:
|
||||
await event_func("job-123", "tenant-123", "error")
|
||||
elif "cancelled" in expected_key:
|
||||
await event_func("job-123", "tenant-123")
|
||||
else:
|
||||
await event_func("job-123", "tenant-123", {})
|
||||
|
||||
# Verify routing key
|
||||
call_args = mock_pub.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == expected_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_consistency(self):
|
||||
"""Test that all events use the same exchange"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Test multiple events
|
||||
await messaging.publish_job_started("job-123", "tenant-123", {})
|
||||
await messaging.publish_model_trained("model-123", "tenant-123", "product", {})
|
||||
await messaging.publish_product_training_started("job-123", "tenant-123", "product")
|
||||
|
||||
# Verify all use same exchange
|
||||
for call in mock_pub.publish_event.call_args_list:
|
||||
assert call[1]["exchange_name"] == "training.events"
|
||||
|
||||
|
||||
class TestMessagingPerformance:
|
||||
"""Test messaging performance and reliability"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_publishing(self):
|
||||
"""Test concurrent event publishing"""
|
||||
import asyncio
|
||||
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Create multiple concurrent publishing tasks
|
||||
tasks = []
|
||||
for i in range(10):
|
||||
task = messaging.publish_job_progress(f"job-{i}", "tenant-123", i * 10, f"step-{i}")
|
||||
tasks.append(task)
|
||||
|
||||
# Execute all tasks concurrently
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Verify all succeeded
|
||||
assert all(results)
|
||||
assert mock_pub.publish_event.call_count == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_event_data(self):
|
||||
"""Test publishing events with large data payloads"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Create large config data
|
||||
large_config = {
|
||||
"products": [f"Product-{i}" for i in range(1000)],
|
||||
"features": [f"feature-{i}" for i in range(100)],
|
||||
"hyperparameters": {f"param-{i}": i for i in range(50)}
|
||||
}
|
||||
|
||||
result = await messaging.publish_job_started("job-123", "tenant-123", large_config)
|
||||
|
||||
assert result is True
|
||||
mock_pub.publish_event.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_sequential_publishing(self):
|
||||
"""Test rapid sequential event publishing"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Publish many events in sequence
|
||||
for i in range(100):
|
||||
await messaging.publish_job_progress("job-123", "tenant-123", i, f"step-{i}")
|
||||
|
||||
assert mock_pub.publish_event.call_count == 100
|
||||
Reference in New Issue
Block a user