# services/training/tests/test_integration.py
"""
Integration tests for training service
Tests complete workflows and service interactions
"""
import pytest
import asyncio
from unittest.mock import patch
from httpx import AsyncClient
from datetime import datetime, timedelta
from app.main import app
from app.schemas.training import TrainingJobRequest
class TestTrainingWorkflowIntegration:
"""Test complete training workflows end-to-end"""
@pytest.mark.asyncio
async def test_complete_training_workflow(
self,
test_client: AsyncClient,
test_db_session,
mock_messaging,
mock_data_service,
mock_ml_trainer
):
"""Test complete training workflow from API to completion"""
# Step 1: Start training job
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30,
"seasonality_mode": "additive"
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
assert response.status_code == 200
job_data = response.json()
job_id = job_data["job_id"]
# Step 2: Check initial status
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == 200
status_data = response.json()
assert status_data["status"] in ["pending", "started"]
# Step 3: Simulate background task completion
        # In a real scenario, this would be handled by background tasks
await asyncio.sleep(0.1) # Allow background task to start
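        # A fixed sleep is timing-sensitive. A sketch of a more robust wait,
        # using the wait_for_condition helper defined at the bottom of this
        # module (assumes the status endpoint eventually reports a terminal state):
        #
        #     async def job_done():
        #         resp = await test_client.get(f"/training/jobs/{job_id}/status")
        #         return resp.json()["status"] in ("completed", "failed")
        #     assert await wait_for_condition(job_done, timeout=10.0)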
# Step 4: Check completion status
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
# The job should exist in database even if not completed yet
assert response.status_code == 200
@pytest.mark.asyncio
async def test_single_product_training_workflow(
self,
test_client: AsyncClient,
mock_messaging,
mock_data_service,
mock_ml_trainer
):
"""Test single product training complete workflow"""
product_name = "Pan Integral"
request_data = {
"include_weather": True,
"include_traffic": False,
"seasonality_mode": "additive"
}
# Start single product training
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.post(
f"/training/products/{product_name}",
json=request_data
)
assert response.status_code == 200
job_data = response.json()
job_id = job_data["job_id"]
assert f"training started for {product_name}" in job_data["message"].lower()
# Check job status
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == 200
status_data = response.json()
assert status_data["job_id"] == job_id
@pytest.mark.asyncio
async def test_training_validation_workflow(
self,
test_client: AsyncClient,
mock_data_service
):
"""Test training data validation workflow"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
# Validate training data
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.post("/training/validate", json=request_data)
assert response.status_code == 200
validation_data = response.json()
assert "is_valid" in validation_data
assert "issues" in validation_data
assert "recommendations" in validation_data
assert "estimated_training_time" in validation_data
# If validation passes, start actual training
if validation_data["is_valid"]:
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_job_cancellation_workflow(
self,
test_client: AsyncClient,
training_job_in_db,
mock_messaging
):
"""Test job cancellation workflow"""
job_id = training_job_in_db.job_id
# Check initial status
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == 200
initial_status = response.json()
assert initial_status["status"] == "pending"
# Cancel the job
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.post(f"/training/jobs/{job_id}/cancel")
assert response.status_code == 200
cancel_response = response.json()
assert "cancelled" in cancel_response["message"].lower()
# Verify cancellation
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == 200
final_status = response.json()
assert final_status["status"] == "cancelled"
class TestServiceInteractionIntegration:
"""Test interactions between training service and external services"""
@pytest.mark.asyncio
async def test_data_service_integration(self, training_service, mock_data_service):
"""Test integration with data service"""
request = TrainingJobRequest(
include_weather=True,
include_traffic=True,
min_data_points=30
)
# Test sales data fetching
sales_data = await training_service._fetch_sales_data("test-tenant", request)
assert isinstance(sales_data, list)
# Test weather data fetching
weather_data = await training_service._fetch_weather_data("test-tenant", request)
assert isinstance(weather_data, list)
# Test traffic data fetching
traffic_data = await training_service._fetch_traffic_data("test-tenant", request)
assert isinstance(traffic_data, list)
@pytest.mark.asyncio
async def test_messaging_integration(self, mock_messaging):
"""Test integration with messaging system"""
from app.services.messaging import (
publish_job_started,
publish_job_completed,
publish_model_trained
)
# Test various message types
result1 = await publish_job_started("job-123", "tenant-123", {})
result2 = await publish_job_completed("job-123", "tenant-123", {"status": "success"})
result3 = await publish_model_trained("model-123", "tenant-123", "Pan Integral", {"mae": 5.0})
assert result1 is True
assert result2 is True
assert result3 is True
@pytest.mark.asyncio
async def test_database_integration(self, test_db_session, training_service):
"""Test database operations integration"""
# Create a training job
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="integration-test-job",
config={"test": True}
)
assert job.job_id == "integration-test-job"
# Update job status
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="running",
progress=50,
current_step="Processing data"
)
# Retrieve updated job
updated_job = await training_service.get_job_status(
db=test_db_session,
job_id=job.job_id,
tenant_id="test-tenant"
)
assert updated_job.status == "running"
assert updated_job.progress == 50
class TestErrorHandlingIntegration:
"""Test error handling across service boundaries"""
@pytest.mark.asyncio
async def test_data_service_failure_handling(
self,
test_client: AsyncClient,
mock_messaging
):
"""Test handling when data service is unavailable"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
        # Mock data service failure: patching httpx.AsyncClient only affects
        # clients created after this point (the app's outbound calls), so the
        # already-instantiated test_client keeps working.
        with patch('httpx.AsyncClient') as mock_client:
            mock_client.return_value.__aenter__.return_value.get.side_effect = Exception("Service unavailable")
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
# Should still create job but might fail during execution
assert response.status_code == 200
@pytest.mark.asyncio
async def test_messaging_failure_handling(
self,
test_client: AsyncClient,
mock_data_service
):
"""Test handling when messaging fails"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
# Mock messaging failure
with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")):
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
# Should still succeed even if messaging fails
assert response.status_code == 200
@pytest.mark.asyncio
async def test_ml_training_failure_handling(
self,
test_client: AsyncClient,
mock_messaging,
mock_data_service
):
"""Test handling when ML training fails"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
# Mock ML training failure
with patch('app.ml.trainer.BakeryMLTrainer.train_tenant_models', side_effect=Exception("ML training failed")):
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
# Job should be created successfully
assert response.status_code == 200
# Background task would handle the failure
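                # Sketch of what the terminal state could look like with a real
                # background runner (hypothetical; whether error_message is
                # exposed by the status endpoint is an assumption):
                #
                #     job_id = response.json()["job_id"]
                #     resp = await test_client.get(f"/training/jobs/{job_id}/status")
                #     assert resp.json()["status"] == "failed"
                #     assert "ML training failed" in resp.json()["error_message"]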
class TestPerformanceIntegration:
"""Test performance characteristics of integrated workflows"""
@pytest.mark.asyncio
async def test_concurrent_training_jobs(
self,
test_client: AsyncClient,
mock_messaging,
mock_data_service,
mock_ml_trainer
):
"""Test handling multiple concurrent training jobs"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
        # Start multiple jobs concurrently. unittest.mock.patch is process-global,
        # so per-tenant patches cannot overlap, and the patch must stay active
        # while the coroutines run, not just while they are created. A single
        # patch with a side_effect list hands out one tenant ID per dependency
        # call (assumes the dependency is resolved once per request).
        tenant_ids = [f"tenant-{i}" for i in range(5)]
        with patch('app.api.training.get_current_tenant_id', side_effect=tenant_ids):
            tasks = [
                test_client.post("/training/jobs", json=request_data)
                for _ in range(5)
            ]
            responses = await asyncio.gather(*tasks)
# All jobs should be created successfully
for response in responses:
assert response.status_code == 200
data = response.json()
assert "job_id" in data
@pytest.mark.asyncio
async def test_large_dataset_handling(
self,
training_service,
test_db_session
):
"""Test handling of large datasets"""
# Simulate large dataset
large_config = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 1000, # Large minimum
"products": [f"Product-{i}" for i in range(100)] # Many products
}
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="large-dataset-job",
config=large_config
)
assert job.config == large_config
assert job.job_id == "large-dataset-job"
@pytest.mark.asyncio
async def test_rapid_status_checks(
self,
test_client: AsyncClient,
training_job_in_db
):
"""Test rapid successive status checks"""
job_id = training_job_in_db.job_id
        # Make many rapid status requests; the tenant patch must stay active
        # while the coroutines actually run, not just while they are created.
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            tasks = [
                test_client.get(f"/training/jobs/{job_id}/status")
                for _ in range(20)
            ]
            responses = await asyncio.gather(*tasks)
# All requests should succeed
for response in responses:
assert response.status_code == 200
class TestSecurityIntegration:
"""Test security aspects of service integration"""
@pytest.mark.asyncio
async def test_tenant_isolation(
self,
test_client: AsyncClient,
training_job_in_db,
mock_messaging
):
"""Test that tenants cannot access each other's jobs"""
job_id = training_job_in_db.job_id
# Try to access job with different tenant ID
with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
# Should not find the job (belongs to different tenant)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_input_validation_integration(
self,
test_client: AsyncClient
):
"""Test input validation across API boundaries"""
# Test invalid seasonality mode
invalid_request = {
"seasonality_mode": "invalid_mode",
"min_data_points": -5 # Invalid negative value
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=invalid_request)
assert response.status_code == 422 # Validation error
@pytest.mark.asyncio
async def test_sql_injection_protection(
self,
test_client: AsyncClient
):
"""Test protection against SQL injection attempts"""
# Try SQL injection in job ID
malicious_job_id = "job'; DROP TABLE model_training_logs; --"
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{malicious_job_id}/status")
# Should return 404, not cause database error
assert response.status_code == 404
class TestRecoveryIntegration:
"""Test recovery and resilience scenarios"""
@pytest.mark.asyncio
async def test_service_restart_recovery(
self,
test_db_session,
training_service,
training_job_in_db
):
"""Test service recovery after restart"""
# Simulate service restart by creating new service instance
new_training_service = training_service.__class__()
# Should be able to access existing jobs
existing_job = await new_training_service.get_job_status(
db=test_db_session,
job_id=training_job_in_db.job_id,
tenant_id=training_job_in_db.tenant_id
)
assert existing_job is not None
assert existing_job.job_id == training_job_in_db.job_id
@pytest.mark.asyncio
async def test_partial_failure_recovery(
self,
training_service,
test_db_session
):
"""Test recovery from partial failures"""
# Create job that might fail partway through
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="partial-failure-job",
config={"simulate_failure": True}
)
# Simulate partial progress
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="running",
progress=50,
current_step="Halfway through training"
)
# Simulate failure
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="failed",
progress=50,
current_step="Training failed",
error_message="Simulated failure"
)
# Verify failure was recorded
failed_job = await training_service.get_job_status(
db=test_db_session,
job_id=job.job_id,
tenant_id="test-tenant"
)
assert failed_job.status == "failed"
assert failed_job.error_message == "Simulated failure"
assert failed_job.progress == 50
class TestComplianceIntegration:
"""Test compliance and audit requirements"""
@pytest.mark.asyncio
async def test_audit_trail_creation(
self,
training_service,
test_db_session
):
"""Test that audit trail is properly created"""
# Create and update job
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="audit-test-job",
config={"audit_test": True}
)
# Multiple status updates
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="running",
progress=25,
current_step="Started processing"
)
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="running",
progress=75,
current_step="Almost complete"
)
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="completed",
progress=100,
current_step="Completed successfully"
)
# Verify audit trail
logs = await training_service.get_training_logs(
db=test_db_session,
job_id=job.job_id,
tenant_id="test-tenant"
)
assert logs is not None
assert len(logs) > 0
# Check final status
final_job = await training_service.get_job_status(
db=test_db_session,
job_id=job.job_id,
tenant_id="test-tenant"
)
assert final_job.status == "completed"
assert final_job.progress == 100
@pytest.mark.asyncio
async def test_data_retention_compliance(
self,
training_service,
test_db_session
):
"""Test data retention and cleanup compliance"""
# Create old job (simulate old data)
old_job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="old-job",
config={"created_long_ago": True}
)
# Manually set old timestamp
from sqlalchemy import update
from app.models.training import ModelTrainingLog
old_timestamp = datetime.now() - timedelta(days=400)
await test_db_session.execute(
update(ModelTrainingLog)
.where(ModelTrainingLog.job_id == old_job.job_id)
.values(start_time=old_timestamp, created_at=old_timestamp)
)
await test_db_session.commit()
# Verify old job exists
retrieved_job = await training_service.get_job_status(
db=test_db_session,
job_id=old_job.job_id,
tenant_id="test-tenant"
)
assert retrieved_job is not None
# In a real implementation, there would be cleanup procedures
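        # A minimal sketch of what such a cleanup could look like (hypothetical;
        # this service exposes no retention API today):
        #
        #     from sqlalchemy import delete
        #     cutoff = datetime.now() - timedelta(days=365)
        #     await test_db_session.execute(
        #         delete(ModelTrainingLog).where(ModelTrainingLog.created_at < cutoff)
        #     )
        #     await test_db_session.commit()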
@pytest.mark.asyncio
async def test_gdpr_compliance_features(
self,
training_service,
test_db_session
):
"""Test GDPR compliance features"""
# Create job with tenant data
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="gdpr-test-tenant",
job_id="gdpr-test-job",
config={"gdpr_test": True}
)
# Verify job is associated with tenant
assert job.tenant_id == "gdpr-test-tenant"
# Test data access (right to access)
tenant_jobs = await training_service.list_training_jobs(
db=test_db_session,
tenant_id="gdpr-test-tenant"
)
assert len(tenant_jobs) >= 1
assert any(job.job_id == "gdpr-test-job" for job in tenant_jobs)
@pytest.mark.slow
class TestLongRunningIntegration:
"""Test long-running integration scenarios (marked as slow)"""
@pytest.mark.asyncio
async def test_extended_training_simulation(
self,
training_service,
test_db_session,
mock_messaging
):
"""Test extended training process simulation"""
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="long-running-job",
config={"extended_test": True}
)
# Simulate progress over time
progress_steps = [
(10, "Initializing"),
(25, "Loading data"),
(50, "Training models"),
(75, "Validating results"),
(90, "Storing models"),
(100, "Completed")
]
for progress, step in progress_steps:
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="running" if progress < 100 else "completed",
progress=progress,
current_step=step
)
# Small delay to simulate real progression
await asyncio.sleep(0.01)
# Verify final state
final_job = await training_service.get_job_status(
db=test_db_session,
job_id=job.job_id,
tenant_id="test-tenant"
)
assert final_job.status == "completed"
assert final_job.progress == 100
assert final_job.current_step == "Completed"
@pytest.mark.asyncio
async def test_memory_usage_stability(
self,
training_service,
test_db_session
):
"""Test memory usage stability over many operations"""
# Create many jobs to test memory stability
for i in range(50):
job = await training_service.create_training_job(
db=test_db_session,
tenant_id=f"tenant-{i % 5}", # 5 different tenants
job_id=f"memory-test-job-{i}",
config={"iteration": i}
)
# Update status
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="completed",
progress=100,
current_step="Completed"
)
# List jobs for each tenant
for tenant_i in range(5):
tenant_id = f"tenant-{tenant_i}"
jobs = await training_service.list_training_jobs(
db=test_db_session,
tenant_id=tenant_id,
limit=20
)
# Should have 10 jobs per tenant (50 total / 5 tenants)
assert len(jobs) == 10
class TestBackwardCompatibility:
"""Test backward compatibility with existing systems"""
@pytest.mark.asyncio
async def test_legacy_config_handling(
self,
training_service,
test_db_session
):
"""Test handling of legacy configuration formats"""
# Test with old-style configuration
legacy_config = {
"weather_enabled": True, # Old key
"traffic_enabled": True, # Old key
"minimum_samples": 30, # Old key
"prophet_config": { # Old nested structure
"seasonality": "additive"
}
}
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="legacy-config-job",
config=legacy_config
)
assert job.config == legacy_config
assert job.job_id == "legacy-config-job"
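        # A sketch of a legacy-to-current key translation (hypothetical; the
        # service currently stores the config verbatim):
        #
        #     KEY_MAP = {
        #         "weather_enabled": "include_weather",
        #         "traffic_enabled": "include_traffic",
        #         "minimum_samples": "min_data_points",
        #     }
        #     migrated = {KEY_MAP.get(k, k): v for k, v in legacy_config.items()}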
@pytest.mark.asyncio
async def test_api_version_compatibility(
self,
test_client: AsyncClient
):
"""Test API version compatibility"""
# Test with minimal request (old API style)
minimal_request = {
"include_weather": True
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=minimal_request)
# Should work with defaults for missing fields
assert response.status_code == 200
data = response.json()
assert "job_id" in data
# Utility functions for integration tests
async def wait_for_condition(condition_func, timeout=5.0, interval=0.1):
"""Wait for a condition to become true"""
import time
start_time = time.time()
while time.time() - start_time < timeout:
if await condition_func():
return True
await asyncio.sleep(interval)
return False
def assert_job_progression(job_updates):
"""Assert that job updates show proper progression"""
assert len(job_updates) > 0
# Check progress is non-decreasing
for i in range(1, len(job_updates)):
assert job_updates[i]["progress"] >= job_updates[i-1]["progress"]
# Check final status
final_update = job_updates[-1]
assert final_update["status"] in ["completed", "failed", "cancelled"]
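# Example usage of assert_job_progression (a sketch; updates as collected by a
# status-polling loop):
#
#     assert_job_progression([
#         {"status": "running", "progress": 25},
#         {"status": "running", "progress": 75},
#         {"status": "completed", "progress": 100},
#     ])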
def assert_valid_job_structure(job_data):
    """Assert job data has valid structure"""
    # "progress" is required here so the range check below fails with a clear
    # assertion instead of a KeyError.
    required_fields = ["job_id", "status", "tenant_id", "progress"]
    for field in required_fields:
        assert field in job_data
    assert isinstance(job_data["progress"], int)
    assert 0 <= job_data["progress"] <= 100
    assert job_data["status"] in ["pending", "running", "completed", "failed", "cancelled"]