Fix generating pytest for training service

This commit is contained in:
Urtzi Alfaro
2025-07-25 14:10:27 +02:00
parent 0dc12f4b93
commit e2b85162f0
14 changed files with 151 additions and 7448 deletions

View File

@@ -29,7 +29,8 @@ class TestTrainingAPI:
async def test_readiness_check_ready(self, test_client: AsyncClient):
"""Test readiness check when service is ready"""
# Mock app state as ready
with patch('app.main.app.state.ready', True):
from app.main import app # Add import at top
with patch.object(app.state, 'ready', True, create=True):
response = await test_client.get("/health/ready")
assert response.status_code == status.HTTP_200_OK
@@ -117,7 +118,7 @@ class TestTrainingJobsAPI:
"seasonality_mode": "additive"
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
assert response.status_code == status.HTTP_200_OK
@@ -136,7 +137,7 @@ class TestTrainingJobsAPI:
"min_data_points": 5 # Too low
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -150,7 +151,7 @@ class TestTrainingJobsAPI:
"""Test getting status of existing training job"""
job_id = training_job_in_db.job_id
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == status.HTTP_200_OK
@@ -164,7 +165,7 @@ class TestTrainingJobsAPI:
@pytest.mark.asyncio
async def test_get_training_status_nonexistent_job(self, test_client: AsyncClient):
"""Test getting status of non-existent training job"""
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get("/training/jobs/nonexistent-job/status")
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -176,7 +177,7 @@ class TestTrainingJobsAPI:
training_job_in_db
):
"""Test listing training jobs"""
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get("/training/jobs")
assert response.status_code == status.HTTP_200_OK
@@ -198,7 +199,7 @@ class TestTrainingJobsAPI:
training_job_in_db
):
"""Test listing training jobs with status filter"""
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get("/training/jobs?status=pending")
assert response.status_code == status.HTTP_200_OK
@@ -219,7 +220,7 @@ class TestTrainingJobsAPI:
"""Test cancelling a training job successfully"""
job_id = training_job_in_db.job_id
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(f"/training/jobs/{job_id}/cancel")
assert response.status_code == status.HTTP_200_OK
@@ -230,7 +231,7 @@ class TestTrainingJobsAPI:
@pytest.mark.asyncio
async def test_cancel_nonexistent_job(self, test_client: AsyncClient):
"""Test cancelling a non-existent training job"""
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs/nonexistent-job/cancel")
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -244,7 +245,7 @@ class TestTrainingJobsAPI:
"""Test getting training logs"""
job_id = training_job_in_db.job_id
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/logs")
assert response.status_code == status.HTTP_200_OK
@@ -267,7 +268,7 @@ class TestTrainingJobsAPI:
"min_data_points": 30
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/validate", json=request_data)
assert response.status_code == status.HTTP_200_OK
@@ -298,7 +299,7 @@ class TestSingleProductTrainingAPI:
"seasonality_mode": "additive"
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
f"/training/products/{product_name}",
json=request_data
@@ -320,7 +321,7 @@ class TestSingleProductTrainingAPI:
"seasonality_mode": "invalid_mode" # Invalid value
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
f"/training/products/{product_name}",
json=request_data
@@ -343,7 +344,7 @@ class TestSingleProductTrainingAPI:
"seasonality_mode": "additive"
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
f"/training/products/{product_name}",
json=request_data
@@ -409,7 +410,7 @@ class TestErrorHandling:
"min_data_points": 30
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -434,7 +435,7 @@ class TestErrorHandling:
"""Test handling of invalid job ID format"""
invalid_job_id = "invalid-job-id-with-special-chars@#$"
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{invalid_job_id}/status")
# Should handle gracefully
@@ -454,7 +455,7 @@ class TestErrorHandling:
}
with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")), \
patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
@@ -466,7 +467,7 @@ class TestErrorHandling:
@pytest.mark.asyncio
async def test_invalid_json_payload(self, test_client: AsyncClient):
"""Test handling of invalid JSON payload"""
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
"/training/jobs",
content="invalid json {{{",
@@ -478,7 +479,7 @@ class TestErrorHandling:
@pytest.mark.asyncio
async def test_unsupported_content_type(self, test_client: AsyncClient):
"""Test handling of unsupported content type"""
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
"/training/jobs",
content="some text data",
@@ -552,7 +553,7 @@ class TestAPIValidation:
"yearly_seasonality": True
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=valid_request)
assert response.status_code == status.HTTP_200_OK
@@ -561,7 +562,7 @@ class TestAPIValidation:
invalid_request = valid_request.copy()
invalid_request["seasonality_mode"] = "invalid_mode"
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=invalid_request)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -570,7 +571,7 @@ class TestAPIValidation:
invalid_request = valid_request.copy()
invalid_request["min_data_points"] = 5 # Too low
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=invalid_request)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -588,7 +589,7 @@ class TestAPIValidation:
"seasonality_mode": "multiplicative"
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
f"/training/products/{product_name}",
json=valid_request
@@ -597,7 +598,7 @@ class TestAPIValidation:
assert response.status_code == status.HTTP_200_OK
# Test empty product name
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
"/training/products/",
json=valid_request
@@ -609,7 +610,7 @@ class TestAPIValidation:
async def test_query_parameter_validation(self, test_client: AsyncClient):
"""Test query parameter validation"""
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
# Test valid limit parameter
response = await test_client.get("/training/jobs?limit=5")
assert response.status_code == status.HTTP_200_OK
@@ -662,7 +663,7 @@ class TestAPIPerformance:
"large_config": {f"key_{i}": f"value_{i}" for i in range(1000)}
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=large_request)
# Should handle large payload gracefully