Files
bakery-ia/services/training/tests/test_api.py
2025-07-19 16:59:37 +02:00

686 lines
25 KiB
Python

# services/training/tests/test_api.py
"""
Tests for training service API endpoints
"""
import pytest
from unittest.mock import AsyncMock, patch
from fastapi import status
from httpx import AsyncClient
from app.schemas.training import TrainingJobRequest
class TestTrainingAPI:
    """Tests for the service-level endpoints: health, readiness, liveness, metrics and root."""

    @pytest.mark.asyncio
    async def test_health_check(self, test_client: AsyncClient):
        """GET /health returns 200 with the service name, version and a status field."""
        response = await test_client.get("/health")
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["service"] == "training-service"
        assert data["version"] == "1.0.0"
        assert "status" in data

    @pytest.mark.asyncio
    async def test_readiness_check_ready(self, test_client: AsyncClient):
        """GET /health/ready returns 200 / "ready" when app.state.ready is True."""
        # Force the readiness flag on the application state for this request only.
        with patch('app.main.app.state.ready', True):
            response = await test_client.get("/health/ready")
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["status"] == "ready"

    @pytest.mark.asyncio
    async def test_readiness_check_not_ready(self, test_client: AsyncClient):
        """GET /health/ready returns 503 / "not_ready" when app.state.ready is False."""
        with patch('app.main.app.state.ready', False):
            response = await test_client.get("/health/ready")
        assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
        data = response.json()
        assert data["status"] == "not_ready"

    @pytest.mark.asyncio
    async def test_liveness_check_healthy(self, test_client: AsyncClient):
        """GET /health/live returns 200 / "alive" when the database health probe reports True."""
        # Replace the probe with an AsyncMock so that *awaiting* it yields the flag itself.
        # (patch(..., return_value=AsyncMock(...)) would make the awaited result an
        # AsyncMock instance, which is always truthy.)
        # NOTE(review): patching app.core.database only works if the endpoint looks
        # get_db_health up through that module at call time — confirm.
        with patch('app.core.database.get_db_health', new=AsyncMock(return_value=True)):
            response = await test_client.get("/health/live")
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["status"] == "alive"

    @pytest.mark.asyncio
    async def test_liveness_check_unhealthy(self, test_client: AsyncClient):
        """GET /health/live returns 503 / "unhealthy" when the database health probe reports False."""
        # The awaited probe must really return False; a truthy AsyncMock instance
        # here would make this branch unreachable.
        with patch('app.core.database.get_db_health', new=AsyncMock(return_value=False)):
            response = await test_client.get("/health/live")
        assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
        data = response.json()
        assert data["status"] == "unhealthy"
        assert data["reason"] == "database_unavailable"

    @pytest.mark.asyncio
    async def test_metrics_endpoint(self, test_client: AsyncClient):
        """GET /metrics returns 200 and exposes every expected metric key."""
        response = await test_client.get("/metrics")
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        expected_metrics = [
            "training_jobs_active",
            "training_jobs_completed",
            "training_jobs_failed",
            "models_trained_total",
            "uptime_seconds",
        ]
        for metric in expected_metrics:
            assert metric in data

    @pytest.mark.asyncio
    async def test_root_endpoint(self, test_client: AsyncClient):
        """GET / returns 200 with the service name, version and a description."""
        response = await test_client.get("/")
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["service"] == "training-service"
        assert data["version"] == "1.0.0"
        assert "description" in data
class TestTrainingJobsAPI:
    """Tests for the /training/jobs and /training/validate endpoints.

    NOTE(review): the current tenant is simulated by patching
    ``app.api.training.get_current_tenant_id``. That only takes effect if the
    endpoints look the name up at call time; if they bind it through
    ``Depends(get_current_tenant_id)``, ``app.dependency_overrides`` is needed
    instead — confirm against the ``test_client`` fixture.
    """

    @pytest.mark.asyncio
    async def test_start_training_job_success(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_ml_trainer,
        mock_data_service
    ):
        """POST /training/jobs with a valid request starts a job for the current tenant."""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
            "seasonality_mode": "additive",
        }
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post("/training/jobs", json=request_data)
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "job_id" in data
        assert data["status"] == "started"
        assert data["tenant_id"] == "test-tenant"
        assert "estimated_duration_minutes" in data

    @pytest.mark.asyncio
    async def test_start_training_job_validation_error(self, test_client: AsyncClient):
        """POST /training/jobs with invalid field values is rejected with 422."""
        request_data = {
            "seasonality_mode": "invalid_mode",  # Invalid value
            "min_data_points": 5,  # Too low
        }
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post("/training/jobs", json=request_data)
        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    @pytest.mark.asyncio
    async def test_get_training_status_existing_job(
        self,
        test_client: AsyncClient,
        training_job_in_db
    ):
        """GET /training/jobs/{job_id}/status returns the stored job's status and progress."""
        job_id = training_job_in_db.job_id
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get(f"/training/jobs/{job_id}/status")
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["job_id"] == job_id
        assert data["status"] == "pending"
        assert "progress" in data
        assert "started_at" in data

    @pytest.mark.asyncio
    async def test_get_training_status_nonexistent_job(self, test_client: AsyncClient):
        """Requesting the status of an unknown job id returns 404."""
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get("/training/jobs/nonexistent-job/status")
        assert response.status_code == status.HTTP_404_NOT_FOUND

    @pytest.mark.asyncio
    async def test_list_training_jobs(
        self,
        test_client: AsyncClient,
        training_job_in_db
    ):
        """GET /training/jobs returns a non-empty list of job summaries."""
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get("/training/jobs")
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert isinstance(data, list)
        assert len(data) >= 1
        # Check first job structure
        job = data[0]
        assert "job_id" in job
        assert "status" in job
        assert "started_at" in job

    @pytest.mark.asyncio
    async def test_list_training_jobs_with_status_filter(
        self,
        test_client: AsyncClient,
        training_job_in_db
    ):
        """GET /training/jobs?status=pending returns only pending jobs, including the fixture job."""
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get("/training/jobs?status=pending")
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert isinstance(data, list)
        # The fixture job is pending for this tenant, so the filtered list must not be
        # empty; otherwise the loop below would pass vacuously.
        assert len(data) >= 1
        for job in data:
            assert job["status"] == "pending"

    @pytest.mark.asyncio
    async def test_cancel_training_job_success(
        self,
        test_client: AsyncClient,
        training_job_in_db,
        mock_messaging
    ):
        """POST /training/jobs/{job_id}/cancel returns a message mentioning cancellation."""
        job_id = training_job_in_db.job_id
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post(f"/training/jobs/{job_id}/cancel")
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "message" in data
        assert "cancelled" in data["message"].lower()

    @pytest.mark.asyncio
    async def test_cancel_nonexistent_job(self, test_client: AsyncClient):
        """Cancelling an unknown job id returns 404."""
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post("/training/jobs/nonexistent-job/cancel")
        assert response.status_code == status.HTTP_404_NOT_FOUND

    @pytest.mark.asyncio
    async def test_get_training_logs(
        self,
        test_client: AsyncClient,
        training_job_in_db
    ):
        """GET /training/jobs/{job_id}/logs returns the job id and a list of log entries."""
        job_id = training_job_in_db.job_id
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get(f"/training/jobs/{job_id}/logs")
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "job_id" in data
        assert "logs" in data
        assert isinstance(data["logs"], list)

    @pytest.mark.asyncio
    async def test_validate_training_data_valid(
        self,
        test_client: AsyncClient,
        mock_data_service
    ):
        """POST /training/validate returns a validation report with issues and recommendations."""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
        }
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post("/training/validate", json=request_data)
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "is_valid" in data
        assert "issues" in data
        assert "recommendations" in data
        assert "estimated_training_time" in data
class TestSingleProductTrainingAPI:
    """Tests for POST /training/products/{product_name} (single-product training)."""

    @pytest.mark.asyncio
    async def test_train_single_product_success(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_ml_trainer,
        mock_data_service
    ):
        """A valid single-product request starts a job whose message names the product."""
        product_name = "Pan Integral"
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "seasonality_mode": "additive",
        }
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post(
                f"/training/products/{product_name}",
                json=request_data
            )
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "job_id" in data
        assert data["status"] == "started"
        assert data["tenant_id"] == "test-tenant"
        # Compare case-insensitively on BOTH sides: product_name contains capitals,
        # so an un-lowercased needle could never occur in the lowercased message.
        expected_fragment = f"training started for {product_name}".lower()
        assert expected_fragment in data["message"].lower()

    @pytest.mark.asyncio
    async def test_train_single_product_validation_error(self, test_client: AsyncClient):
        """An invalid seasonality_mode for a single product is rejected with 422."""
        product_name = "Pan Integral"
        request_data = {
            "seasonality_mode": "invalid_mode",  # Invalid value
        }
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post(
                f"/training/products/{product_name}",
                json=request_data
            )
        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    @pytest.mark.asyncio
    async def test_train_single_product_special_characters(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_ml_trainer,
        mock_data_service
    ):
        """A product name containing a non-ASCII character (accent) is accepted."""
        product_name = "Pan Francés"  # With accent
        request_data = {
            "include_weather": True,
            "seasonality_mode": "additive",
        }
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post(
                f"/training/products/{product_name}",
                json=request_data
            )
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "job_id" in data
class TestModelsAPI:
    """Tests for the trained-model endpoints, which may not be implemented yet."""

    @pytest.mark.asyncio
    async def test_list_models(self, test_client: AsyncClient, trained_model_in_db):
        """GET /models returns a list, or 404 while the endpoint does not exist."""
        with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get("/models")

        # A missing endpoint (404) is tolerated; only a 200 body is inspected.
        tolerated = (status.HTTP_200_OK, status.HTTP_404_NOT_FOUND)
        assert response.status_code in tolerated
        if response.status_code != status.HTTP_200_OK:
            return
        assert isinstance(response.json(), list)

    @pytest.mark.asyncio
    async def test_get_model_details(self, test_client: AsyncClient, trained_model_in_db):
        """GET /models/{model_id} answers 200, or 404/501 while unimplemented."""
        url = f"/models/{trained_model_in_db.model_id}"
        with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get(url)

        tolerated = (
            status.HTTP_200_OK,
            status.HTTP_404_NOT_FOUND,
            status.HTTP_501_NOT_IMPLEMENTED,
        )
        assert response.status_code in tolerated
class TestErrorHandling:
    """Tests for how the training API reacts to failures and malformed requests."""

    @pytest.mark.asyncio
    async def test_database_error_handling(self, test_client: AsyncClient):
        """A failure inside TrainingService.create_training_job surfaces as a 500 response."""
        # NOTE(review): expecting a 500 response (rather than a raised exception) requires
        # the app or the test_client transport to convert unhandled errors — confirm.
        with patch('app.services.training_service.TrainingService.create_training_job') as mock_create:
            mock_create.side_effect = Exception("Database connection failed")
            request_data = {
                "include_weather": True,
                "include_traffic": True,
                "min_data_points": 30,
            }
            with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
                response = await test_client.post("/training/jobs", json=request_data)
        assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    @pytest.mark.asyncio
    async def test_missing_tenant_id(self, test_client: AsyncClient):
        """Without a patched tenant, starting a job is refused with 401 or 403."""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
        }
        # Don't mock get_current_tenant_id to simulate missing auth
        response = await test_client.post("/training/jobs", json=request_data)
        # Should fail due to missing authentication
        assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]

    @pytest.mark.asyncio
    async def test_invalid_job_id_format(self, test_client: AsyncClient):
        """A job id containing URL-special characters reaches the status endpoint and yields 404/400."""
        from urllib.parse import quote

        invalid_job_id = "invalid-job-id-with-special-chars@#$"
        # Percent-encode the id before putting it in the path. An unescaped '#' starts a
        # URL fragment, so the client would send only the part before it and drop
        # "/status", meaning the status endpoint would never be exercised.
        encoded_job_id = quote(invalid_job_id, safe="")
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.get(f"/training/jobs/{encoded_job_id}/status")
        # Should handle gracefully
        assert response.status_code in [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]

    @pytest.mark.asyncio
    async def test_messaging_failure_handling(
        self,
        test_client: AsyncClient,
        mock_data_service
    ):
        """Starting a job still succeeds when publishing the job-started message raises."""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
        }
        # NOTE(review): this patch only takes effect if the endpoint calls
        # publish_job_started through the app.services.messaging module — confirm.
        with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")), \
                patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post("/training/jobs", json=request_data)
        # Should still succeed even if messaging fails
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "job_id" in data

    @pytest.mark.asyncio
    async def test_invalid_json_payload(self, test_client: AsyncClient):
        """A syntactically invalid JSON body is rejected with 422."""
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post(
                "/training/jobs",
                content="invalid json {{{",
                headers={"Content-Type": "application/json"}
            )
        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    @pytest.mark.asyncio
    async def test_unsupported_content_type(self, test_client: AsyncClient):
        """A text/plain body is rejected with 415 or 422."""
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post(
                "/training/jobs",
                content="some text data",
                headers={"Content-Type": "text/plain"}
            )
        assert response.status_code in [
            status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            status.HTTP_422_UNPROCESSABLE_ENTITY,
        ]
class TestAuthenticationIntegration:
    """Checks how the training API behaves around authentication and tenant boundaries."""

    @pytest.mark.asyncio
    async def test_endpoints_require_auth(self, test_client: AsyncClient):
        """Call each protected endpoint without patching the tenant and check none returns 500.

        Real authentication enforcement is only meaningful in a production-like
        environment; in this test environment the check only guards against
        server errors.
        """
        protected_routes = [
            ("POST", "/training/jobs"),
            ("GET", "/training/jobs"),
            ("POST", "/training/products/Pan Integral"),
            ("POST", "/training/validate"),
        ]

        for verb, path in protected_routes:
            if verb != "POST":
                response = await test_client.get(path)
            else:
                response = await test_client.post(path, json={})
            # Whatever the auth setup, the call must not crash the server.
            assert response.status_code != status.HTTP_500_INTERNAL_SERVER_ERROR

    @pytest.mark.asyncio
    async def test_tenant_isolation_in_api(self, test_client: AsyncClient, training_job_in_db):
        """A stored job is reported as 404 when queried as a different tenant."""
        status_url = f"/training/jobs/{training_job_in_db.job_id}/status"
        with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"):
            response = await test_client.get(status_url)
        assert response.status_code == status.HTTP_404_NOT_FOUND
class TestAPIValidation:
    """Tests for request-body and query-parameter validation on the training API."""

    @pytest.mark.asyncio
    async def test_training_request_validation(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_ml_trainer,
        mock_data_service
    ):
        """A fully specified request is accepted; a bad seasonality_mode or a too-low
        min_data_points is rejected with 422.

        The mock_* fixtures are requested because the valid request actually starts
        a job, matching the other job-start success tests in this module.
        """
        valid_request = {
            "include_weather": True,
            "include_traffic": False,
            "min_data_points": 30,
            "seasonality_mode": "additive",
            "daily_seasonality": True,
            "weekly_seasonality": True,
            "yearly_seasonality": True,
        }
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post("/training/jobs", json=valid_request)
        assert response.status_code == status.HTTP_200_OK

        # Each invalid variant changes exactly one field of the valid request,
        # so a 422 can be attributed to that field.
        invalid_overrides = [
            {"seasonality_mode": "invalid_mode"},
            {"min_data_points": 5},  # Too low
        ]
        for override in invalid_overrides:
            invalid_request = {**valid_request, **override}
            with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
                response = await test_client.post("/training/jobs", json=invalid_request)
            assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    @pytest.mark.asyncio
    async def test_single_product_request_validation(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_ml_trainer,
        mock_data_service
    ):
        """A valid single-product request is accepted; an empty product name gives 404."""
        product_name = "Pan Integral"
        valid_request = {
            "include_weather": True,
            "include_traffic": True,
            "seasonality_mode": "multiplicative",
        }
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post(
                f"/training/products/{product_name}",
                json=valid_request
            )
        assert response.status_code == status.HTTP_200_OK

        # Empty product name: the path no longer matches the /{product_name} route.
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post(
                "/training/products/",
                json=valid_request
            )
        assert response.status_code == status.HTTP_404_NOT_FOUND

    @pytest.mark.asyncio
    async def test_query_parameter_validation(self, test_client: AsyncClient):
        """GET /training/jobs accepts a numeric limit and rejects non-numeric or negative ones."""
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            # Test valid limit parameter
            response = await test_client.get("/training/jobs?limit=5")
            assert response.status_code == status.HTTP_200_OK

            # Test invalid limit parameter
            response = await test_client.get("/training/jobs?limit=invalid")
            assert response.status_code in [
                status.HTTP_422_UNPROCESSABLE_ENTITY,
                status.HTTP_400_BAD_REQUEST,
            ]

            # Test negative limit
            response = await test_client.get("/training/jobs?limit=-1")
            assert response.status_code in [
                status.HTTP_422_UNPROCESSABLE_ENTITY,
                status.HTTP_400_BAD_REQUEST,
            ]
class TestAPIPerformance:
    """Smoke tests for concurrency, payload size and request bursts."""

    @pytest.mark.asyncio
    async def test_concurrent_requests(self, test_client: AsyncClient):
        """Ten concurrent GET /health requests all return 200."""
        import asyncio

        # Create the coroutines up front and run them together with gather().
        # A patch() context wrapped around coroutine *creation* would already be
        # exited by the time gather() executes the requests, so no per-request
        # tenant patching is attempted here.
        pending_requests = [test_client.get("/health") for _ in range(10)]
        responses = await asyncio.gather(*pending_requests)

        # All requests should succeed
        for response in responses:
            assert response.status_code == status.HTTP_200_OK

    @pytest.mark.asyncio
    async def test_large_payload_handling(self, test_client: AsyncClient):
        """A job request with a 1000-entry extra config is accepted (200) or refused as too large (413)."""
        large_request = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
            "large_config": {f"key_{i}": f"value_{i}" for i in range(1000)},
        }
        with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
            response = await test_client.post("/training/jobs", json=large_request)
        # Should handle large payload gracefully
        assert response.status_code in [
            status.HTTP_200_OK,
            status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
        ]

    @pytest.mark.asyncio
    async def test_rapid_successive_requests(self, test_client: AsyncClient):
        """Twenty back-to-back GET /health requests all return 200."""
        responses = []
        for _ in range(20):
            response = await test_client.get("/health")
            responses.append(response)

        # All should succeed
        for response in responses:
            assert response.status_code == status.HTTP_200_OK