687 lines
25 KiB
Python
687 lines
25 KiB
Python
# services/training/tests/test_api.py
|
|
"""
|
|
Tests for training service API endpoints
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, patch
|
|
from fastapi import status
|
|
from httpx import AsyncClient
|
|
|
|
from app.schemas.training import TrainingJobRequest
|
|
|
|
|
|
class TestTrainingAPI:
    """Test service-level endpoints: health, readiness, liveness, metrics and root."""

    @pytest.mark.asyncio
    async def test_health_check(self, test_client: AsyncClient):
        """GET /health reports service name, version and a status field."""
        response = await test_client.get("/health")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["service"] == "training-service"
        assert data["version"] == "1.0.0"
        assert "status" in data

    @pytest.mark.asyncio
    async def test_readiness_check_ready(self, test_client: AsyncClient):
        """GET /health/ready returns 200 / "ready" when app.state.ready is True."""
        from app.main import app

        # create=True lets the patch succeed even if app.state has no `ready`
        # attribute at this point.
        with patch.object(app.state, "ready", True, create=True):
            response = await test_client.get("/health/ready")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["status"] == "ready"

    @pytest.mark.asyncio
    async def test_readiness_check_not_ready(self, test_client: AsyncClient):
        """GET /health/ready returns 503 / "not_ready" when app.state.ready is False."""
        from app.main import app

        # Same patching style as test_readiness_check_ready. A string target
        # without create=True raises AttributeError if `ready` is not yet set.
        with patch.object(app.state, "ready", False, create=True):
            response = await test_client.get("/health/ready")

        assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
        data = response.json()
        assert data["status"] == "not_ready"

    @pytest.mark.asyncio
    async def test_liveness_check_healthy(self, test_client: AsyncClient):
        """GET /health/live returns 200 / "alive" when the DB health check passes."""
        # new_callable=AsyncMock makes `await get_db_health()` evaluate to True.
        # Passing return_value=AsyncMock(...) instead makes the call return a
        # mock object, not the boolean.
        # NOTE(review): this patches the name in app.core.database; if the health
        # route imports get_db_health into its own module, patch it there instead.
        with patch(
            "app.core.database.get_db_health",
            new_callable=AsyncMock,
            return_value=True,
        ):
            response = await test_client.get("/health/live")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["status"] == "alive"

    @pytest.mark.asyncio
    async def test_liveness_check_unhealthy(self, test_client: AsyncClient):
        """GET /health/live returns 503 with reason "database_unavailable" when the DB is down."""
        # See test_liveness_check_healthy: the awaited result must be a real False.
        with patch(
            "app.core.database.get_db_health",
            new_callable=AsyncMock,
            return_value=False,
        ):
            response = await test_client.get("/health/live")

        assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
        data = response.json()
        assert data["status"] == "unhealthy"
        assert data["reason"] == "database_unavailable"

    @pytest.mark.asyncio
    async def test_metrics_endpoint(self, test_client: AsyncClient):
        """GET /metrics exposes every expected training metric key."""
        response = await test_client.get("/metrics")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()

        expected_metrics = (
            "training_jobs_active",
            "training_jobs_completed",
            "training_jobs_failed",
            "models_trained_total",
            "uptime_seconds",
        )
        # Collect all misses so a failure reports every absent metric at once.
        missing = [metric for metric in expected_metrics if metric not in data]
        assert not missing, f"metrics missing from /metrics response: {missing}"

    @pytest.mark.asyncio
    async def test_root_endpoint(self, test_client: AsyncClient):
        """GET / returns service name, version and a description."""
        response = await test_client.get("/")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["service"] == "training-service"
        assert data["version"] == "1.0.0"
        assert "description" in data
|
|
|
|
|
|
class TestTrainingJobsAPI:
    """Test the /training/jobs and /training/validate endpoints.

    NOTE(review): mock.patch on 'shared.auth.decorators.get_current_tenant_id_dep'
    only changes the tenant if the route looks that name up at request time.
    FastAPI's Depends() binds the function object when the route is declared,
    so app.dependency_overrides may be what actually supplies the tenant.
    Confirm against conftest.
    """

    @pytest.mark.asyncio
    async def test_start_training_job_success(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_ml_trainer,
        mock_data_service
    ):
        """POST /training/jobs with a valid payload starts a job for the tenant."""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
            "seasonality_mode": "additive"
        }

        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.post("/training/jobs", json=request_data)

        assert response.status_code == status.HTTP_200_OK
        data = response.json()

        assert "job_id" in data
        assert data["status"] == "started"
        assert data["tenant_id"] == "test-tenant"
        assert "estimated_duration_minutes" in data

    @pytest.mark.asyncio
    async def test_start_training_job_validation_error(self, test_client: AsyncClient):
        """POST /training/jobs rejects an invalid seasonality mode and too few data points."""
        request_data = {
            "seasonality_mode": "invalid_mode",  # not an accepted mode
            "min_data_points": 5  # too low
        }

        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.post("/training/jobs", json=request_data)

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    @pytest.mark.asyncio
    async def test_get_training_status_existing_job(
        self,
        test_client: AsyncClient,
        training_job_in_db
    ):
        """GET /training/jobs/{id}/status returns the stored job's status."""
        job_id = training_job_in_db.job_id

        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.get(f"/training/jobs/{job_id}/status")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()

        assert data["job_id"] == job_id
        assert data["status"] == "pending"
        assert "progress" in data
        assert "started_at" in data

    @pytest.mark.asyncio
    async def test_get_training_status_nonexistent_job(self, test_client: AsyncClient):
        """GET /training/jobs/{id}/status returns 404 for an unknown job."""
        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.get("/training/jobs/nonexistent-job/status")

        assert response.status_code == status.HTTP_404_NOT_FOUND

    @pytest.mark.asyncio
    async def test_list_training_jobs(
        self,
        test_client: AsyncClient,
        training_job_in_db
    ):
        """GET /training/jobs lists at least the fixture job, with the expected fields."""
        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.get("/training/jobs")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()

        assert isinstance(data, list)
        assert len(data) >= 1

        # Check first job structure
        job = data[0]
        assert "job_id" in job
        assert "status" in job
        assert "started_at" in job

    @pytest.mark.asyncio
    async def test_list_training_jobs_with_status_filter(
        self,
        test_client: AsyncClient,
        training_job_in_db
    ):
        """GET /training/jobs?status=pending returns only pending jobs."""
        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.get("/training/jobs?status=pending")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()

        assert isinstance(data, list)
        # The fixture job is "pending" (see test_get_training_status_existing_job),
        # so the filtered list must not be empty. Without this check, the loop
        # below passes vacuously when the filter returns [].
        assert len(data) >= 1
        for job in data:
            assert job["status"] == "pending"

    @pytest.mark.asyncio
    async def test_cancel_training_job_success(
        self,
        test_client: AsyncClient,
        training_job_in_db,
        mock_messaging
    ):
        """POST /training/jobs/{id}/cancel cancels an existing job."""
        job_id = training_job_in_db.job_id

        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.post(f"/training/jobs/{job_id}/cancel")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "message" in data
        assert "cancelled" in data["message"].lower()

    @pytest.mark.asyncio
    async def test_cancel_nonexistent_job(self, test_client: AsyncClient):
        """POST /training/jobs/{id}/cancel returns 404 for an unknown job."""
        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.post("/training/jobs/nonexistent-job/cancel")

        assert response.status_code == status.HTTP_404_NOT_FOUND

    @pytest.mark.asyncio
    async def test_get_training_logs(
        self,
        test_client: AsyncClient,
        training_job_in_db
    ):
        """GET /training/jobs/{id}/logs returns the job id and a list of log lines."""
        job_id = training_job_in_db.job_id

        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.get(f"/training/jobs/{job_id}/logs")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()

        assert "job_id" in data
        assert "logs" in data
        assert isinstance(data["logs"], list)

    @pytest.mark.asyncio
    async def test_validate_training_data_valid(
        self,
        test_client: AsyncClient,
        mock_data_service
    ):
        """POST /training/validate returns a validation report with all fields."""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30
        }

        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.post("/training/validate", json=request_data)

        assert response.status_code == status.HTTP_200_OK
        data = response.json()

        assert "is_valid" in data
        assert "issues" in data
        assert "recommendations" in data
        assert "estimated_training_time" in data
|
|
|
|
|
|
class TestSingleProductTrainingAPI:
    """Test the single-product training endpoint POST /training/products/{name}."""

    @pytest.mark.asyncio
    async def test_train_single_product_success(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_ml_trainer,
        mock_data_service
    ):
        """Training a single product starts a job and echoes the product in the message."""
        product_name = "Pan Integral"
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "seasonality_mode": "additive"
        }

        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.post(
                f"/training/products/{product_name}",
                json=request_data
            )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()

        assert "job_id" in data
        assert data["status"] == "started"
        assert data["tenant_id"] == "test-tenant"
        # Lower-case BOTH sides. product_name contains capitals, so comparing it
        # raw against message.lower() can never match.
        expected_fragment = f"training started for {product_name}".lower()
        assert expected_fragment in data["message"].lower()

    @pytest.mark.asyncio
    async def test_train_single_product_validation_error(self, test_client: AsyncClient):
        """An unknown seasonality mode is rejected with 422."""
        product_name = "Pan Integral"
        request_data = {
            "seasonality_mode": "invalid_mode"  # not an accepted mode
        }

        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.post(
                f"/training/products/{product_name}",
                json=request_data
            )

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    @pytest.mark.asyncio
    async def test_train_single_product_special_characters(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_ml_trainer,
        mock_data_service
    ):
        """A product name with non-ASCII characters (accent) is accepted."""
        product_name = "Pan Francés"  # contains an accented character
        request_data = {
            "include_weather": True,
            "seasonality_mode": "additive"
        }

        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.post(
                f"/training/products/{product_name}",
                json=request_data
            )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "job_id" in data
|
|
|
|
|
|
class TestModelsAPI:
    """Tests for the /models endpoints, which may not be implemented yet.

    NOTE(review): patch('app.api.models.get_current_tenant_id', ...) raises if
    that module or attribute does not exist. Confirm the target once the
    endpoints land.
    """

    @pytest.mark.asyncio
    async def test_list_models(
        self,
        test_client: AsyncClient,
        trained_model_in_db
    ):
        """GET /models returns a JSON list, or 404 while the route is absent."""
        with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
            listing = await test_client.get("/models")

        # A missing route (404) is tolerated for now.
        assert listing.status_code in (status.HTTP_200_OK, status.HTTP_404_NOT_FOUND)
        if listing.status_code != status.HTTP_200_OK:
            return

        assert isinstance(listing.json(), list)

    @pytest.mark.asyncio
    async def test_get_model_details(
        self,
        test_client: AsyncClient,
        trained_model_in_db
    ):
        """GET /models/{id} succeeds, 404s, or reports 501 while unimplemented."""
        model_id = trained_model_in_db.model_id

        with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
            details = await test_client.get(f"/models/{model_id}")

        tolerated_codes = (
            status.HTTP_200_OK,
            status.HTTP_404_NOT_FOUND,
            status.HTTP_501_NOT_IMPLEMENTED,
        )
        assert details.status_code in tolerated_codes
|
|
|
|
|
|
class TestErrorHandling:
    """Test how the API reacts to backend failures and malformed requests."""

    @pytest.mark.asyncio
    async def test_database_error_handling(self, test_client: AsyncClient):
        """A failure inside TrainingService.create_training_job surfaces as 500."""
        with patch('app.services.training_service.TrainingService.create_training_job') as mock_create:
            mock_create.side_effect = Exception("Database connection failed")

            request_data = {
                "include_weather": True,
                "include_traffic": True,
                "min_data_points": 30
            }

            with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
                response = await test_client.post("/training/jobs", json=request_data)

            assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    @pytest.mark.asyncio
    async def test_missing_tenant_id(self, test_client: AsyncClient):
        """Without the tenant dependency patched, the request is rejected as unauthenticated."""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30
        }

        # Don't mock the tenant dependency, to simulate missing auth.
        response = await test_client.post("/training/jobs", json=request_data)

        # Should fail due to missing authentication.
        assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]

    @pytest.mark.asyncio
    async def test_invalid_job_id_format(self, test_client: AsyncClient):
        """A job ID containing URL-special characters is rejected or not found."""
        from urllib.parse import quote

        invalid_job_id = "invalid-job-id-with-special-chars@#$"
        # Percent-encode the ID. An unescaped '#' starts a URL fragment, so the
        # raw f-string requested /training/jobs/invalid-job-id-with-special-chars@
        # and dropped "$/status", which means the status route was never exercised.
        encoded_job_id = quote(invalid_job_id, safe="")

        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.get(f"/training/jobs/{encoded_job_id}/status")

        # Should be handled gracefully.
        assert response.status_code in [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]

    @pytest.mark.asyncio
    async def test_messaging_failure_handling(
        self,
        test_client: AsyncClient,
        mock_data_service
    ):
        """Job creation still succeeds when publishing the job-started event fails."""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30
        }

        with patch(
            'app.services.messaging.publish_job_started',
            side_effect=Exception("Messaging failed"),
        ):
            with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
                response = await test_client.post("/training/jobs", json=request_data)

        # Should still succeed even if messaging fails.
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "job_id" in data

    @pytest.mark.asyncio
    async def test_invalid_json_payload(self, test_client: AsyncClient):
        """A body that is not valid JSON yields 422."""
        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.post(
                "/training/jobs",
                content="invalid json {{{",
                headers={"Content-Type": "application/json"}
            )

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    @pytest.mark.asyncio
    async def test_unsupported_content_type(self, test_client: AsyncClient):
        """A text/plain body is rejected with 415 or 422."""
        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.post(
                "/training/jobs",
                content="some text data",
                headers={"Content-Type": "text/plain"}
            )

        assert response.status_code in [
            status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            status.HTTP_422_UNPROCESSABLE_ENTITY
        ]
|
|
|
|
|
|
class TestAuthenticationIntegration:
    """Test authentication and tenant scoping at the API boundary."""

    @pytest.mark.asyncio
    async def test_endpoints_require_auth(self, test_client: AsyncClient):
        """Protected endpoints never crash with 500 when hit with an empty request.

        This test would be more meaningful in an environment where
        authentication is actually enforced.
        """
        endpoints_to_test = [
            ("POST", "/training/jobs"),
            ("GET", "/training/jobs"),
            ("POST", "/training/products/Pan Integral"),
            ("POST", "/training/validate")
        ]

        for method, endpoint in endpoints_to_test:
            if method == "POST":
                response = await test_client.post(endpoint, json={})
            else:
                response = await test_client.get(endpoint)

            # In the test environment with mocked auth this should work; in
            # production it would require valid authentication. Name the
            # offending endpoint so a failure inside the loop is diagnosable.
            assert response.status_code != status.HTTP_500_INTERNAL_SERVER_ERROR, (
                f"{method} {endpoint} returned 500"
            )

    @pytest.mark.asyncio
    async def test_tenant_isolation_in_api(
        self,
        test_client: AsyncClient,
        training_job_in_db
    ):
        """A job is invisible (404) to a tenant other than its owner."""
        job_id = training_job_in_db.job_id

        # Try to access the job as a different tenant.
        # NOTE(review): other training-endpoint tests patch
        # 'shared.auth.decorators.get_current_tenant_id_dep'. Confirm that
        # 'app.api.training.get_current_tenant_id' is the name the route
        # actually resolves, otherwise this patch has no effect.
        with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"):
            response = await test_client.get(f"/training/jobs/{job_id}/status")

        # Should not find the job for a different tenant.
        assert response.status_code == status.HTTP_404_NOT_FOUND
|
|
|
|
|
|
class TestAPIValidation:
    """Request-body, path and query-string validation for the training API."""

    @pytest.mark.asyncio
    async def test_training_request_validation(self, test_client: AsyncClient):
        """A well-formed job request is accepted; breaking any single field yields 422."""
        tenant_target = 'shared.auth.decorators.get_current_tenant_id_dep'

        baseline = {
            "include_weather": True,
            "include_traffic": False,
            "min_data_points": 30,
            "seasonality_mode": "additive",
            "daily_seasonality": True,
            "weekly_seasonality": True,
            "yearly_seasonality": True
        }

        with patch(tenant_target, return_value="test-tenant"):
            accepted = await test_client.post("/training/jobs", json=baseline)
        assert accepted.status_code == status.HTTP_200_OK

        # Corrupt exactly one field per request, keeping everything else valid.
        corruptions = (
            ("seasonality_mode", "invalid_mode"),  # not an accepted mode
            ("min_data_points", 5),                # too low
        )
        for field_name, bad_value in corruptions:
            broken = {**baseline, field_name: bad_value}
            with patch(tenant_target, return_value="test-tenant"):
                rejected = await test_client.post("/training/jobs", json=broken)
            assert rejected.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    @pytest.mark.asyncio
    async def test_single_product_request_validation(self, test_client: AsyncClient):
        """A valid single-product request succeeds; an empty product segment is 404."""
        tenant_target = 'shared.auth.decorators.get_current_tenant_id_dep'
        product_name = "Pan Integral"
        payload = {
            "include_weather": True,
            "include_traffic": True,
            "seasonality_mode": "multiplicative"
        }

        with patch(tenant_target, return_value="test-tenant"):
            accepted = await test_client.post(
                f"/training/products/{product_name}",
                json=payload
            )
        assert accepted.status_code == status.HTTP_200_OK

        # With no product name in the path, the request is expected to 404.
        with patch(tenant_target, return_value="test-tenant"):
            missing = await test_client.post(
                "/training/products/",
                json=payload
            )
        assert missing.status_code == status.HTTP_404_NOT_FOUND

    @pytest.mark.asyncio
    async def test_query_parameter_validation(self, test_client: AsyncClient):
        """`limit` accepts 5 and rejects a non-numeric or negative value."""
        rejection_codes = [
            status.HTTP_422_UNPROCESSABLE_ENTITY,
            status.HTTP_400_BAD_REQUEST
        ]

        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            good = await test_client.get("/training/jobs?limit=5")
            assert good.status_code == status.HTTP_200_OK

            for bad_limit in ("invalid", "-1"):
                bad = await test_client.get(f"/training/jobs?limit={bad_limit}")
                assert bad.status_code in rejection_codes
|
|
|
|
|
|
class TestAPIPerformance:
    """Smoke-level checks that the API copes with concurrency, bursts and large bodies."""

    @pytest.mark.asyncio
    async def test_concurrent_requests(self, test_client: AsyncClient):
        """Ten concurrent GET /health requests all return 200."""
        import asyncio

        # The coroutines only start executing inside gather(). Wrapping their
        # *creation* in patch('app.api.training.get_current_tenant_id', ...)
        # undid each patch before any request ran. The patch therefore never
        # influenced a request and could only fail (AttributeError), so it is
        # not used here.
        responses = await asyncio.gather(
            *(test_client.get("/health") for _ in range(10))
        )

        # All requests should succeed.
        for response in responses:
            assert response.status_code == status.HTTP_200_OK

    @pytest.mark.asyncio
    async def test_large_payload_handling(self, test_client: AsyncClient):
        """A request carrying a large extra field is accepted or rejected as too large."""
        large_request = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
            "large_config": {f"key_{i}": f"value_{i}" for i in range(1000)}
        }

        with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
            response = await test_client.post("/training/jobs", json=large_request)

        # Should handle the large payload gracefully.
        assert response.status_code in [
            status.HTTP_200_OK,
            status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
        ]

    @pytest.mark.asyncio
    async def test_rapid_successive_requests(self, test_client: AsyncClient):
        """Twenty back-to-back GET /health requests all return 200."""
        responses = [await test_client.get("/health") for _ in range(20)]

        # All should succeed.
        for response in responses:
            assert response.status_code == status.HTTP_200_OK