Add all the code for training service
This commit is contained in:
686
services/training/tests/test_api.py
Normal file
686
services/training/tests/test_api.py
Normal file
@@ -0,0 +1,686 @@
|
||||
# services/training/tests/test_api.py
|
||||
"""
|
||||
Tests for training service API endpoints
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from fastapi import status
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.schemas.training import TrainingJobRequest
|
||||
|
||||
|
||||
class TestTrainingAPI:
|
||||
"""Test training API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, test_client: AsyncClient):
|
||||
"""Test health check endpoint"""
|
||||
response = await test_client.get("/health")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["service"] == "training-service"
|
||||
assert data["version"] == "1.0.0"
|
||||
assert "status" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_check_ready(self, test_client: AsyncClient):
|
||||
"""Test readiness check when service is ready"""
|
||||
# Mock app state as ready
|
||||
with patch('app.main.app.state.ready', True):
|
||||
response = await test_client.get("/health/ready")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["status"] == "ready"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_check_not_ready(self, test_client: AsyncClient):
|
||||
"""Test readiness check when service is not ready"""
|
||||
with patch('app.main.app.state.ready', False):
|
||||
response = await test_client.get("/health/ready")
|
||||
|
||||
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
data = response.json()
|
||||
assert data["status"] == "not_ready"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_liveness_check_healthy(self, test_client: AsyncClient):
|
||||
"""Test liveness check when service is healthy"""
|
||||
with patch('app.core.database.get_db_health', return_value=AsyncMock(return_value=True)):
|
||||
response = await test_client.get("/health/live")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["status"] == "alive"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_liveness_check_unhealthy(self, test_client: AsyncClient):
|
||||
"""Test liveness check when database is unhealthy"""
|
||||
with patch('app.core.database.get_db_health', return_value=AsyncMock(return_value=False)):
|
||||
response = await test_client.get("/health/live")
|
||||
|
||||
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
data = response.json()
|
||||
assert data["status"] == "unhealthy"
|
||||
assert data["reason"] == "database_unavailable"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_endpoint(self, test_client: AsyncClient):
|
||||
"""Test metrics endpoint"""
|
||||
response = await test_client.get("/metrics")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
expected_metrics = [
|
||||
"training_jobs_active",
|
||||
"training_jobs_completed",
|
||||
"training_jobs_failed",
|
||||
"models_trained_total",
|
||||
"uptime_seconds"
|
||||
]
|
||||
|
||||
for metric in expected_metrics:
|
||||
assert metric in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_root_endpoint(self, test_client: AsyncClient):
|
||||
"""Test root endpoint"""
|
||||
response = await test_client.get("/")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["service"] == "training-service"
|
||||
assert data["version"] == "1.0.0"
|
||||
assert "description" in data
|
||||
|
||||
|
||||
class TestTrainingJobsAPI:
|
||||
"""Test training jobs API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_training_job_success(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_ml_trainer,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test starting a training job successfully"""
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "job_id" in data
|
||||
assert data["status"] == "started"
|
||||
assert data["tenant_id"] == "test-tenant"
|
||||
assert "estimated_duration_minutes" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_training_job_validation_error(self, test_client: AsyncClient):
|
||||
"""Test starting training job with validation error"""
|
||||
request_data = {
|
||||
"seasonality_mode": "invalid_mode", # Invalid value
|
||||
"min_data_points": 5 # Too low
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_status_existing_job(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test getting status of existing training job"""
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["job_id"] == job_id
|
||||
assert data["status"] == "pending"
|
||||
assert "progress" in data
|
||||
assert "started_at" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_status_nonexistent_job(self, test_client: AsyncClient):
|
||||
"""Test getting status of non-existent training job"""
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get("/training/jobs/nonexistent-job/status")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_training_jobs(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test listing training jobs"""
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get("/training/jobs")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
|
||||
# Check first job structure
|
||||
job = data[0]
|
||||
assert "job_id" in job
|
||||
assert "status" in job
|
||||
assert "started_at" in job
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_training_jobs_with_status_filter(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test listing training jobs with status filter"""
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get("/training/jobs?status=pending")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert isinstance(data, list)
|
||||
# All jobs should have status "pending"
|
||||
for job in data:
|
||||
assert job["status"] == "pending"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_training_job_success(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db,
|
||||
mock_messaging
|
||||
):
|
||||
"""Test cancelling a training job successfully"""
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(f"/training/jobs/{job_id}/cancel")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "message" in data
|
||||
assert "cancelled" in data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_nonexistent_job(self, test_client: AsyncClient):
|
||||
"""Test cancelling a non-existent training job"""
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs/nonexistent-job/cancel")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_logs(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test getting training logs"""
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/logs")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "job_id" in data
|
||||
assert "logs" in data
|
||||
assert isinstance(data["logs"], list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_training_data_valid(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test validating valid training data"""
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/validate", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "is_valid" in data
|
||||
assert "issues" in data
|
||||
assert "recommendations" in data
|
||||
assert "estimated_training_time" in data
|
||||
|
||||
|
||||
class TestSingleProductTrainingAPI:
|
||||
"""Test single product training API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_single_product_success(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_ml_trainer,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test training a single product successfully"""
|
||||
product_name = "Pan Integral"
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "job_id" in data
|
||||
assert data["status"] == "started"
|
||||
assert data["tenant_id"] == "test-tenant"
|
||||
assert f"training started for {product_name}" in data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_single_product_validation_error(self, test_client: AsyncClient):
|
||||
"""Test single product training with validation error"""
|
||||
product_name = "Pan Integral"
|
||||
request_data = {
|
||||
"seasonality_mode": "invalid_mode" # Invalid value
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_single_product_special_characters(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_ml_trainer,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test training product with special characters in name"""
|
||||
product_name = "Pan Francés" # With accent
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "job_id" in data
|
||||
|
||||
|
||||
class TestModelsAPI:
|
||||
"""Test models API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_models(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
trained_model_in_db
|
||||
):
|
||||
"""Test listing trained models"""
|
||||
with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get("/models")
|
||||
|
||||
# This endpoint might not exist yet, so we expect either 200 or 404
|
||||
assert response.status_code in [status.HTTP_200_OK, status.HTTP_404_NOT_FOUND]
|
||||
|
||||
if response.status_code == status.HTTP_200_OK:
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_model_details(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
trained_model_in_db
|
||||
):
|
||||
"""Test getting model details"""
|
||||
model_id = trained_model_in_db.model_id
|
||||
|
||||
with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/models/{model_id}")
|
||||
|
||||
# This endpoint might not exist yet
|
||||
assert response.status_code in [
|
||||
status.HTTP_200_OK,
|
||||
status.HTTP_404_NOT_FOUND,
|
||||
status.HTTP_501_NOT_IMPLEMENTED
|
||||
]
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test error handling in API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_error_handling(self, test_client: AsyncClient):
|
||||
"""Test handling of database errors"""
|
||||
with patch('app.services.training_service.TrainingService.create_training_job') as mock_create:
|
||||
mock_create.side_effect = Exception("Database connection failed")
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_tenant_id(self, test_client: AsyncClient):
|
||||
"""Test handling when tenant ID is missing"""
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Don't mock get_current_tenant_id to simulate missing auth
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
# Should fail due to missing authentication
|
||||
assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_job_id_format(self, test_client: AsyncClient):
|
||||
"""Test handling of invalid job ID format"""
|
||||
invalid_job_id = "invalid-job-id-with-special-chars@#$"
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{invalid_job_id}/status")
|
||||
|
||||
# Should handle gracefully
|
||||
assert response.status_code in [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messaging_failure_handling(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test handling when messaging fails"""
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")), \
|
||||
patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
# Should still succeed even if messaging fails
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "job_id" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_payload(self, test_client: AsyncClient):
|
||||
"""Test handling of invalid JSON payload"""
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
"/training/jobs",
|
||||
content="invalid json {{{",
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_content_type(self, test_client: AsyncClient):
|
||||
"""Test handling of unsupported content type"""
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
"/training/jobs",
|
||||
content="some text data",
|
||||
headers={"Content-Type": "text/plain"}
|
||||
)
|
||||
|
||||
assert response.status_code in [
|
||||
status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
]
|
||||
|
||||
|
||||
class TestAuthenticationIntegration:
|
||||
"""Test authentication integration"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoints_require_auth(self, test_client: AsyncClient):
|
||||
"""Test that endpoints require authentication in production"""
|
||||
# This test would be more meaningful in a production environment
|
||||
# where authentication is actually enforced
|
||||
|
||||
endpoints_to_test = [
|
||||
("POST", "/training/jobs"),
|
||||
("GET", "/training/jobs"),
|
||||
("POST", "/training/products/Pan Integral"),
|
||||
("POST", "/training/validate")
|
||||
]
|
||||
|
||||
for method, endpoint in endpoints_to_test:
|
||||
if method == "POST":
|
||||
response = await test_client.post(endpoint, json={})
|
||||
else:
|
||||
response = await test_client.get(endpoint)
|
||||
|
||||
# In test environment with mocked auth, should work
|
||||
# In production, would require valid authentication
|
||||
assert response.status_code != status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tenant_isolation_in_api(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test tenant isolation at API level"""
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
# Try to access job with different tenant
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
# Should not find job for different tenant
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAPIValidation:
|
||||
"""Test API validation and input handling"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_training_request_validation(self, test_client: AsyncClient):
|
||||
"""Test comprehensive training request validation"""
|
||||
|
||||
# Test valid request
|
||||
valid_request = {
|
||||
"include_weather": True,
|
||||
"include_traffic": False,
|
||||
"min_data_points": 30,
|
||||
"seasonality_mode": "additive",
|
||||
"daily_seasonality": True,
|
||||
"weekly_seasonality": True,
|
||||
"yearly_seasonality": True
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=valid_request)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Test invalid seasonality mode
|
||||
invalid_request = valid_request.copy()
|
||||
invalid_request["seasonality_mode"] = "invalid_mode"
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=invalid_request)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
# Test invalid min_data_points
|
||||
invalid_request = valid_request.copy()
|
||||
invalid_request["min_data_points"] = 5 # Too low
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=invalid_request)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_product_request_validation(self, test_client: AsyncClient):
|
||||
"""Test single product training request validation"""
|
||||
|
||||
product_name = "Pan Integral"
|
||||
|
||||
# Test valid request
|
||||
valid_request = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"seasonality_mode": "multiplicative"
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=valid_request
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Test empty product name
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
"/training/products/",
|
||||
json=valid_request
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_parameter_validation(self, test_client: AsyncClient):
|
||||
"""Test query parameter validation"""
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
# Test valid limit parameter
|
||||
response = await test_client.get("/training/jobs?limit=5")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Test invalid limit parameter
|
||||
response = await test_client.get("/training/jobs?limit=invalid")
|
||||
assert response.status_code in [
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status.HTTP_400_BAD_REQUEST
|
||||
]
|
||||
|
||||
# Test negative limit
|
||||
response = await test_client.get("/training/jobs?limit=-1")
|
||||
assert response.status_code in [
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status.HTTP_400_BAD_REQUEST
|
||||
]
|
||||
|
||||
|
||||
class TestAPIPerformance:
|
||||
"""Test API performance characteristics"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests(self, test_client: AsyncClient):
|
||||
"""Test handling of concurrent requests"""
|
||||
import asyncio
|
||||
|
||||
# Create multiple concurrent requests
|
||||
tasks = []
|
||||
for i in range(10):
|
||||
with patch('app.api.training.get_current_tenant_id', return_value=f"tenant-{i}"):
|
||||
task = test_client.get("/health")
|
||||
tasks.append(task)
|
||||
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
# All requests should succeed
|
||||
for response in responses:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_payload_handling(self, test_client: AsyncClient):
|
||||
"""Test handling of large request payloads"""
|
||||
|
||||
# Create large request payload
|
||||
large_request = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30,
|
||||
"large_config": {f"key_{i}": f"value_{i}" for i in range(1000)}
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=large_request)
|
||||
|
||||
# Should handle large payload gracefully
|
||||
assert response.status_code in [
|
||||
status.HTTP_200_OK,
|
||||
status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_successive_requests(self, test_client: AsyncClient):
|
||||
"""Test rapid successive requests to same endpoint"""
|
||||
|
||||
# Make rapid requests
|
||||
responses = []
|
||||
for _ in range(20):
|
||||
response = await test_client.get("/health")
|
||||
responses.append(response)
|
||||
|
||||
# All should succeed
|
||||
for response in responses:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
Reference in New Issue
Block a user