848 lines
28 KiB
Python
848 lines
28 KiB
Python
# services/training/tests/test_integration.py
|
|
"""
|
|
Integration tests for training service
|
|
Tests complete workflows and service interactions
|
|
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
from httpx import AsyncClient
|
|
from datetime import datetime, timedelta
|
|
|
|
from app.main import app
|
|
from app.schemas.training import TrainingJobRequest
|
|
|
|
|
|
class TestTrainingWorkflowIntegration:
    """Test complete training workflows end-to-end

    Every request in this class is sent while
    ``app.api.training.get_current_tenant_id`` is patched to return
    ``"test-tenant"``.
    """

    @pytest.mark.asyncio
    async def test_complete_training_workflow(
        self,
        test_client: AsyncClient,
        test_db_session,
        mock_messaging,
        mock_data_service,
        mock_ml_trainer
    ):
        """Test complete training workflow from API to completion

        Starts a job through ``POST /training/jobs`` and then reads its
        status endpoint twice, once immediately and once after a short pause.
        """

        # Step 1: Start training job
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
            "seasonality_mode": "additive"
        }

        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post("/training/jobs", json=request_data)

        assert response.status_code == 200
        job_data = response.json()
        job_id = job_data["job_id"]

        # Step 2: Check initial status
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get(f"/training/jobs/{job_id}/status")

        assert response.status_code == 200
        status_data = response.json()
        assert status_data["status"] in ["pending", "started"]

        # Step 3: Simulate background task completion
        # In real scenario, this would be handled by background tasks
        await asyncio.sleep(0.1)  # Allow background task to start

        # Step 4: Check completion status
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get(f"/training/jobs/{job_id}/status")

        # The job should exist in database even if not completed yet
        assert response.status_code == 200

    @pytest.mark.asyncio
    async def test_single_product_training_workflow(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_data_service,
        mock_ml_trainer
    ):
        """Test single product training complete workflow

        Posts to ``/training/products/{product_name}``, checks the
        confirmation message, then reads the job status by its id.
        """

        product_name = "Pan Integral"
        request_data = {
            "include_weather": True,
            "include_traffic": False,
            "seasonality_mode": "additive"
        }

        # Start single product training
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post(
                f"/training/products/{product_name}",
                json=request_data
            )

        assert response.status_code == 200
        job_data = response.json()
        job_id = job_data["job_id"]

        # Lower-case BOTH sides of the comparison: the product name contains
        # upper-case letters, so searching for it verbatim inside a
        # lower-cased message could never match.
        expected_fragment = f"training started for {product_name}".lower()
        assert expected_fragment in job_data["message"].lower()

        # Check job status
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get(f"/training/jobs/{job_id}/status")

        assert response.status_code == 200
        status_data = response.json()
        assert status_data["job_id"] == job_id

    @pytest.mark.asyncio
    async def test_training_validation_workflow(
        self,
        test_client: AsyncClient,
        mock_data_service
    ):
        """Test training data validation workflow

        Calls ``POST /training/validate`` and checks the response keys; a
        training job is only started when the validation reports
        ``is_valid``.
        """

        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30
        }

        # Validate training data
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post("/training/validate", json=request_data)

        assert response.status_code == 200
        validation_data = response.json()

        assert "is_valid" in validation_data
        assert "issues" in validation_data
        assert "recommendations" in validation_data
        assert "estimated_training_time" in validation_data

        # If validation passes, start actual training
        if validation_data["is_valid"]:
            with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
                response = await test_client.post("/training/jobs", json=request_data)

            assert response.status_code == 200

    @pytest.mark.asyncio
    async def test_job_cancellation_workflow(
        self,
        test_client: AsyncClient,
        training_job_in_db,
        mock_messaging
    ):
        """Test job cancellation workflow

        Expects the job from the ``training_job_in_db`` fixture to be
        ``pending`` at first and ``cancelled`` after the cancel endpoint is
        called.
        """

        job_id = training_job_in_db.job_id

        # Check initial status
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get(f"/training/jobs/{job_id}/status")

        assert response.status_code == 200
        initial_status = response.json()
        assert initial_status["status"] == "pending"

        # Cancel the job
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post(f"/training/jobs/{job_id}/cancel")

        assert response.status_code == 200
        cancel_response = response.json()
        assert "cancelled" in cancel_response["message"].lower()

        # Verify cancellation
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get(f"/training/jobs/{job_id}/status")

        assert response.status_code == 200
        final_status = response.json()
        assert final_status["status"] == "cancelled"
|
|
|
|
|
|
class TestServiceInteractionIntegration:
    """Exercise the training service together with its external collaborators"""

    @pytest.mark.asyncio
    async def test_data_service_integration(self, training_service, mock_data_service):
        """Each data fetcher (sales, weather, traffic) should hand back a list"""
        from app.schemas.training import TrainingJobRequest

        tenant = "test-tenant"
        job_request = TrainingJobRequest(
            include_weather=True,
            include_traffic=True,
            min_data_points=30,
        )

        # Sales
        sales_rows = await training_service._fetch_sales_data(tenant, job_request)
        assert isinstance(sales_rows, list)

        # Weather
        weather_rows = await training_service._fetch_weather_data(tenant, job_request)
        assert isinstance(weather_rows, list)

        # Traffic
        traffic_rows = await training_service._fetch_traffic_data(tenant, job_request)
        assert isinstance(traffic_rows, list)

    @pytest.mark.asyncio
    async def test_messaging_integration(self, mock_messaging):
        """Each publish helper should report success by returning ``True``"""
        from app.services.messaging import (
            publish_job_started,
            publish_job_completed,
            publish_model_trained
        )

        # Publish one message of every kind first, then check all outcomes
        started_ok = await publish_job_started("job-123", "tenant-123", {})
        completed_ok = await publish_job_completed(
            "job-123", "tenant-123", {"status": "success"}
        )
        trained_ok = await publish_model_trained(
            "model-123", "tenant-123", "Pan Integral", {"mae": 5.0}
        )

        assert started_ok is True
        assert completed_ok is True
        assert trained_ok is True

    @pytest.mark.asyncio
    async def test_database_integration(self, test_db_session, training_service):
        """Create a job, update its status and read the update back"""

        # Create
        created = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id="integration-test-job",
            config={"test": True},
        )

        assert created.job_id == "integration-test-job"

        # Update
        await training_service._update_job_status(
            db=test_db_session,
            job_id=created.job_id,
            status="running",
            progress=50,
            current_step="Processing data",
        )

        # Read back
        refreshed = await training_service.get_job_status(
            db=test_db_session,
            job_id=created.job_id,
            tenant_id="test-tenant",
        )

        assert refreshed.status == "running"
        assert refreshed.progress == 50
|
|
|
|
|
|
class TestErrorHandlingIntegration:
    """Check that job creation still answers 200 when a collaborator fails"""

    @pytest.mark.asyncio
    async def test_data_service_failure_handling(
        self,
        test_client: AsyncClient,
        mock_messaging
    ):
        """Job creation while ``httpx.AsyncClient`` GET calls raise"""

        payload = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
        }
        tenant_patch = patch(
            'app.api.training.get_current_tenant_id', return_value="test-tenant"
        )

        # Replace the HTTP client class so every GET made through it raises
        with patch('httpx.AsyncClient') as client_cls:
            failing_get = client_cls.return_value.__aenter__.return_value.get
            failing_get.side_effect = Exception("Service unavailable")

            with tenant_patch:
                response = await test_client.post("/training/jobs", json=payload)

        # The job is still created; it might fail later during execution
        assert response.status_code == 200

    @pytest.mark.asyncio
    async def test_messaging_failure_handling(
        self,
        test_client: AsyncClient,
        mock_data_service
    ):
        """Job creation while ``publish_job_started`` raises"""

        payload = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
        }

        # Both patches are active for the duration of the request
        with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")), \
                patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post("/training/jobs", json=payload)

        # A messaging failure must not turn into a failed request
        assert response.status_code == 200

    @pytest.mark.asyncio
    async def test_ml_training_failure_handling(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_data_service
    ):
        """Job creation while ``BakeryMLTrainer.train_tenant_models`` raises"""

        payload = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
        }

        # Both patches are active for the duration of the request
        with patch('app.ml.trainer.BakeryMLTrainer.train_tenant_models', side_effect=Exception("ML training failed")), \
                patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post("/training/jobs", json=payload)

        # Creating the job succeeds
        assert response.status_code == 200

        # Background task would handle the failure
|
|
|
|
|
|
class TestPerformanceIntegration:
    """Test performance characteristics of integrated workflows"""

    @pytest.mark.asyncio
    async def test_concurrent_training_jobs(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_data_service,
        mock_ml_trainer
    ):
        """Test handling multiple concurrent training jobs

        Five ``POST /training/jobs`` requests are run together with
        ``asyncio.gather``; each must answer 200 with a ``job_id``.
        """
        import itertools

        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30
        }

        # The coroutines returned by ``test_client.post`` do no work until
        # they are awaited, so the patch has to stay active around the
        # ``gather`` call. Patching only while the coroutines are *created*
        # would leave the requests running unpatched.
        tenant_ids = itertools.cycle([f"tenant-{i}" for i in range(5)])

        def next_tenant_id(*args, **kwargs):
            # Hands out tenant-0 .. tenant-4 in turn, one per lookup.
            # NOTE(review): assumes one tenant lookup per request; which
            # request receives which tenant id is not guaranteed.
            return next(tenant_ids)

        # Start multiple jobs concurrently
        with patch('app.api.training.get_current_tenant_id', side_effect=next_tenant_id):
            responses = await asyncio.gather(
                *(
                    test_client.post("/training/jobs", json=request_data)
                    for _ in range(5)
                )
            )

        # All jobs should be created successfully
        for response in responses:
            assert response.status_code == 200
            data = response.json()
            assert "job_id" in data

    @pytest.mark.asyncio
    async def test_large_dataset_handling(
        self,
        training_service,
        test_db_session
    ):
        """Test handling of large datasets

        Creates a job whose config lists 100 products and checks that the
        config is returned on the job unchanged.
        """

        # Simulate large dataset
        large_config = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 1000,  # Large minimum
            "products": [f"Product-{i}" for i in range(100)]  # Many products
        }

        job = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id="large-dataset-job",
            config=large_config
        )

        assert job.config == large_config
        assert job.job_id == "large-dataset-job"

    @pytest.mark.asyncio
    async def test_rapid_status_checks(
        self,
        test_client: AsyncClient,
        training_job_in_db
    ):
        """Test rapid successive status checks

        Twenty status requests for the same job are run together with
        ``asyncio.gather``; each must answer 200.
        """

        job_id = training_job_in_db.job_id

        # Make many rapid status requests. The patch wraps the ``gather``
        # call because the requests only execute once they are awaited.
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            responses = await asyncio.gather(
                *(
                    test_client.get(f"/training/jobs/{job_id}/status")
                    for _ in range(20)
                )
            )

        # All requests should succeed
        for response in responses:
            assert response.status_code == 200
|
|
|
|
|
|
class TestSecurityIntegration:
    """Security-oriented checks on the training API"""

    @pytest.mark.asyncio
    async def test_tenant_isolation(
        self,
        test_client: AsyncClient,
        training_job_in_db,
        mock_messaging
    ):
        """A job must not be visible when requested under another tenant id"""

        status_url = f"/training/jobs/{training_job_in_db.job_id}/status"
        foreign_tenant = patch(
            'app.api.training.get_current_tenant_id', return_value="different-tenant"
        )

        # Request the job while identified as a different tenant
        with foreign_tenant:
            response = await test_client.get(status_url)

        # The job belongs to another tenant, so it must look missing
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_input_validation_integration(
        self,
        test_client: AsyncClient
    ):
        """An invalid request body must be rejected with status 422"""

        # Unknown seasonality mode plus a negative minimum
        bad_payload = {
            "seasonality_mode": "invalid_mode",
            "min_data_points": -5,  # Invalid negative value
        }
        tenant_patch = patch(
            'app.api.training.get_current_tenant_id', return_value="test-tenant"
        )

        with tenant_patch:
            response = await test_client.post("/training/jobs", json=bad_payload)

        assert response.status_code == 422  # Validation error

    @pytest.mark.asyncio
    async def test_sql_injection_protection(
        self,
        test_client: AsyncClient
    ):
        """A job id containing SQL must be answered with a plain 404"""

        # SQL injection attempt placed in the job id path segment
        malicious_job_id = "job'; DROP TABLE model_training_logs; --"
        status_url = f"/training/jobs/{malicious_job_id}/status"

        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get(status_url)

        # Expect "not found" rather than a database error
        assert response.status_code == 404
|
|
|
|
|
|
class TestRecoveryIntegration:
    """Recovery and resilience scenarios for the training service"""

    @pytest.mark.asyncio
    async def test_service_restart_recovery(
        self,
        test_db_session,
        training_service,
        training_job_in_db
    ):
        """A freshly constructed service instance can still read an existing job"""

        # A new instance of the same class stands in for a restarted service
        service_cls = training_service.__class__
        restarted_service = service_cls()

        # The job already in the database must be reachable through it
        found = await restarted_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id,
        )

        assert found is not None
        assert found.job_id == training_job_in_db.job_id

    @pytest.mark.asyncio
    async def test_partial_failure_recovery(
        self,
        training_service,
        test_db_session
    ):
        """A job that fails halfway keeps its status, error message and progress"""

        # Job that will be marked as failed partway through
        job = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id="partial-failure-job",
            config={"simulate_failure": True},
        )

        # Record partial progress
        halfway = dict(
            status="running",
            progress=50,
            current_step="Halfway through training",
        )
        await training_service._update_job_status(
            db=test_db_session,
            job_id=job.job_id,
            **halfway,
        )

        # Record the failure
        failure = dict(
            status="failed",
            progress=50,
            current_step="Training failed",
            error_message="Simulated failure",
        )
        await training_service._update_job_status(
            db=test_db_session,
            job_id=job.job_id,
            **failure,
        )

        # The stored job must reflect the failure
        stored = await training_service.get_job_status(
            db=test_db_session,
            job_id=job.job_id,
            tenant_id="test-tenant",
        )

        assert stored.status == "failed"
        assert stored.error_message == "Simulated failure"
        assert stored.progress == 50
|
|
|
|
|
|
class TestComplianceIntegration:
    """Compliance and audit related checks"""

    @pytest.mark.asyncio
    async def test_audit_trail_creation(
        self,
        training_service,
        test_db_session
    ):
        """Status updates on a job should leave training logs behind"""

        # Create the job that will be updated
        job = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id="audit-test-job",
            config={"audit_test": True},
        )

        # Apply several status updates in order
        update_sequence = (
            ("running", 25, "Started processing"),
            ("running", 75, "Almost complete"),
            ("completed", 100, "Completed successfully"),
        )
        for new_status, new_progress, step_label in update_sequence:
            await training_service._update_job_status(
                db=test_db_session,
                job_id=job.job_id,
                status=new_status,
                progress=new_progress,
                current_step=step_label,
            )

        # The audit trail must exist and be non-empty
        logs = await training_service.get_training_logs(
            db=test_db_session,
            job_id=job.job_id,
            tenant_id="test-tenant",
        )

        assert logs is not None
        assert len(logs) > 0

        # The job must end in its last recorded state
        finished = await training_service.get_job_status(
            db=test_db_session,
            job_id=job.job_id,
            tenant_id="test-tenant",
        )

        assert finished.status == "completed"
        assert finished.progress == 100

    @pytest.mark.asyncio
    async def test_data_retention_compliance(
        self,
        training_service,
        test_db_session
    ):
        """A job back-dated by 400 days is still retrievable"""

        from datetime import datetime, timedelta

        # Create the job that will be aged artificially
        aged_job = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id="old-job",
            config={"created_long_ago": True},
        )

        # Overwrite its timestamps directly in the database
        from sqlalchemy import update
        from app.models.training import ModelTrainingLog

        backdated = datetime.now() - timedelta(days=400)
        backdate_stmt = (
            update(ModelTrainingLog)
            .where(ModelTrainingLog.job_id == aged_job.job_id)
            .values(start_time=backdated, created_at=backdated)
        )
        await test_db_session.execute(backdate_stmt)
        await test_db_session.commit()

        # The aged job must still be there
        fetched = await training_service.get_job_status(
            db=test_db_session,
            job_id=aged_job.job_id,
            tenant_id="test-tenant",
        )

        assert fetched is not None
        # In a real implementation, there would be cleanup procedures

    @pytest.mark.asyncio
    async def test_gdpr_compliance_features(
        self,
        training_service,
        test_db_session
    ):
        """A tenant can list the jobs that were created under its id"""

        # Create a job owned by the tenant
        created = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="gdpr-test-tenant",
            job_id="gdpr-test-job",
            config={"gdpr_test": True},
        )

        # The job carries the tenant id it was created with
        assert created.tenant_id == "gdpr-test-tenant"

        # Right to access: list everything stored for the tenant
        tenant_jobs = await training_service.list_training_jobs(
            db=test_db_session,
            tenant_id="gdpr-test-tenant",
        )

        assert len(tenant_jobs) >= 1
        assert any(listed.job_id == "gdpr-test-job" for listed in tenant_jobs)
|
|
|
|
|
|
@pytest.mark.slow
class TestLongRunningIntegration:
    """Longer-running integration scenarios, tagged with the ``slow`` marker"""

    @pytest.mark.asyncio
    async def test_extended_training_simulation(
        self,
        training_service,
        test_db_session,
        mock_messaging
    ):
        """Walk a job through six progress steps and check where it ends up"""

        job = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id="long-running-job",
            config={"extended_test": True},
        )

        # (progress, step label) pairs applied in order
        timeline = [
            (10, "Initializing"),
            (25, "Loading data"),
            (50, "Training models"),
            (75, "Validating results"),
            (90, "Storing models"),
            (100, "Completed"),
        ]

        for percent, label in timeline:
            # Only the step that reaches 100 marks the job as completed
            if percent < 100:
                state = "running"
            else:
                state = "completed"

            await training_service._update_job_status(
                db=test_db_session,
                job_id=job.job_id,
                status=state,
                progress=percent,
                current_step=label,
            )

            # Brief pause between steps
            await asyncio.sleep(0.01)

        # The job must reflect the last step
        finished = await training_service.get_job_status(
            db=test_db_session,
            job_id=job.job_id,
            tenant_id="test-tenant",
        )

        assert finished.status == "completed"
        assert finished.progress == 100
        assert finished.current_step == "Completed"

    @pytest.mark.asyncio
    async def test_memory_usage_stability(
        self,
        training_service,
        test_db_session
    ):
        """Create and complete 50 jobs over 5 tenants, then list them per tenant"""

        # Create the jobs, spreading them round-robin over five tenants
        for index in range(50):
            owner = f"tenant-{index % 5}"  # 5 different tenants
            job = await training_service.create_training_job(
                db=test_db_session,
                tenant_id=owner,
                job_id=f"memory-test-job-{index}",
                config={"iteration": index},
            )

            # Mark the job as completed right away
            await training_service._update_job_status(
                db=test_db_session,
                job_id=job.job_id,
                status="completed",
                progress=100,
                current_step="Completed",
            )

        # List the jobs of every tenant
        for tenant_number in range(5):
            listed = await training_service.list_training_jobs(
                db=test_db_session,
                tenant_id=f"tenant-{tenant_number}",
                limit=20,
            )

            # 50 jobs over 5 tenants gives 10 jobs each
            assert len(listed) == 10
|
|
|
|
|
|
class TestBackwardCompatibility:
    """Compatibility checks for older config and request formats"""

    @pytest.mark.asyncio
    async def test_legacy_config_handling(
        self,
        training_service,
        test_db_session
    ):
        """A config using the old key names is stored on the job unchanged"""

        # Old-style configuration
        prophet_section = {"seasonality": "additive"}  # Old nested structure
        legacy_config = {
            "weather_enabled": True,  # Old key
            "traffic_enabled": True,  # Old key
            "minimum_samples": 30,  # Old key
            "prophet_config": prophet_section,
        }

        created = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id="legacy-config-job",
            config=legacy_config,
        )

        assert created.config == legacy_config
        assert created.job_id == "legacy-config-job"

    @pytest.mark.asyncio
    async def test_api_version_compatibility(
        self,
        test_client: AsyncClient
    ):
        """A request carrying only ``include_weather`` is still accepted"""

        # Minimal request body (old API style)
        minimal_request = {"include_weather": True}
        tenant_patch = patch(
            'app.api.training.get_current_tenant_id', return_value="test-tenant"
        )

        with tenant_patch:
            response = await test_client.post("/training/jobs", json=minimal_request)

        # The missing fields are expected to fall back to defaults
        assert response.status_code == 200
        body = response.json()
        assert "job_id" in body
|
|
|
|
|
|
# Utility functions for integration tests
|
|
async def wait_for_condition(condition_func, timeout=5.0, interval=0.1):
    """Wait for a condition to become true

    Args:
        condition_func: Zero-argument callable returning an awaitable; the
            awaited result is treated as the condition's truth value.
        timeout: Maximum number of seconds to keep polling.
        interval: Number of seconds to sleep between two polls.

    Returns:
        True as soon as the condition evaluates truthy, False once the
        timeout has elapsed without that happening.
    """
    import time

    # A monotonic clock is immune to wall-clock adjustments during the wait
    deadline = time.monotonic() + timeout

    while True:
        # Poll first, so the condition is checked at least once (even with a
        # zero timeout) and once more after the final sleep.
        if await condition_func():
            return True

        if time.monotonic() >= deadline:
            return False

        await asyncio.sleep(interval)
|
|
|
|
|
|
def assert_job_progression(job_updates):
    """Assert that job updates show proper progression

    Args:
        job_updates: Ordered sequence of mappings, each with a ``"progress"``
            entry; the last one must also have a ``"status"`` entry.

    Raises:
        AssertionError: If the sequence is empty, if progress decreases from
            one update to the next, or if the last status is not one of
            ``completed``, ``failed`` or ``cancelled``.
    """
    assert len(job_updates) > 0, "expected at least one job update"

    # Check progress is non-decreasing by comparing each neighbouring pair
    for earlier, later in zip(job_updates, job_updates[1:]):
        assert later["progress"] >= earlier["progress"], (
            f"progress went backwards: {earlier['progress']} -> {later['progress']}"
        )

    # Check final status
    final_update = job_updates[-1]
    assert final_update["status"] in ["completed", "failed", "cancelled"], (
        f"unexpected final status: {final_update['status']!r}"
    )
|
|
|
|
|
|
def assert_valid_job_structure(job_data):
    """Assert job data has valid structure

    Args:
        job_data: Mapping describing a job.

    Raises:
        AssertionError: If ``job_id``, ``status``, ``tenant_id`` or
            ``progress`` is missing, if ``progress`` is not an int within
            0..100, or if ``status`` is not a known job status.
    """
    # "progress" is listed here because it is read below; a missing key should
    # be reported as an assertion failure rather than escape as a KeyError.
    required_fields = ["job_id", "status", "tenant_id", "progress"]
    for field in required_fields:
        assert field in job_data, f"missing required field: {field!r}"

    assert isinstance(job_data["progress"], int)
    assert 0 <= job_data["progress"] <= 100
    assert job_data["status"] in ["pending", "running", "completed", "failed", "cancelled"]