Fix generating pytest for training service

This commit is contained in:
Urtzi Alfaro
2025-07-25 14:10:27 +02:00
parent 0dc12f4b93
commit e2b85162f0
14 changed files with 151 additions and 7448 deletions

View File

@@ -36,7 +36,7 @@ class TestTrainingWorkflowIntegration:
"seasonality_mode": "additive"
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
assert response.status_code == 200
@@ -44,7 +44,7 @@ class TestTrainingWorkflowIntegration:
job_id = job_data["job_id"]
# Step 2: Check initial status
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == 200
@@ -56,7 +56,7 @@ class TestTrainingWorkflowIntegration:
await asyncio.sleep(0.1) # Allow background task to start
# Step 4: Check completion status
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
# The job should exist in database even if not completed yet
@@ -80,7 +80,7 @@ class TestTrainingWorkflowIntegration:
}
# Start single product training
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
f"/training/products/{product_name}",
json=request_data
@@ -92,7 +92,7 @@ class TestTrainingWorkflowIntegration:
assert f"training started for {product_name}" in job_data["message"].lower()
# Check job status
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == 200
@@ -114,7 +114,7 @@ class TestTrainingWorkflowIntegration:
}
# Validate training data
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/validate", json=request_data)
assert response.status_code == 200
@@ -127,7 +127,7 @@ class TestTrainingWorkflowIntegration:
# If validation passes, start actual training
if validation_data["is_valid"]:
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
assert response.status_code == 200
@@ -144,7 +144,7 @@ class TestTrainingWorkflowIntegration:
job_id = training_job_in_db.job_id
# Check initial status
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == 200
@@ -152,7 +152,7 @@ class TestTrainingWorkflowIntegration:
assert initial_status["status"] == "pending"
# Cancel the job
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(f"/training/jobs/{job_id}/cancel")
assert response.status_code == 200
@@ -160,7 +160,7 @@ class TestTrainingWorkflowIntegration:
assert "cancelled" in cancel_response["message"].lower()
# Verify cancellation
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == 200
@@ -267,7 +267,7 @@ class TestErrorHandlingIntegration:
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get.side_effect = Exception("Service unavailable")
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
# Should still create job but might fail during execution
@@ -289,7 +289,7 @@ class TestErrorHandlingIntegration:
# Mock messaging failure
with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")):
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
# Should still succeed even if messaging fails
@@ -312,7 +312,7 @@ class TestErrorHandlingIntegration:
# Mock ML training failure
with patch('app.ml.trainer.BakeryMLTrainer.train_tenant_models', side_effect=Exception("ML training failed")):
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
# Job should be created successfully
@@ -394,7 +394,7 @@ class TestPerformanceIntegration:
# Make many rapid status requests
tasks = []
for _ in range(20):
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
task = test_client.get(f"/training/jobs/{job_id}/status")
tasks.append(task)
@@ -439,7 +439,7 @@ class TestSecurityIntegration:
"min_data_points": -5 # Invalid negative value
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=invalid_request)
assert response.status_code == 422 # Validation error
@@ -454,7 +454,7 @@ class TestSecurityIntegration:
# Try SQL injection in job ID
malicious_job_id = "job'; DROP TABLE model_training_logs; --"
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{malicious_job_id}/status")
# Should return 404, not cause database error
@@ -801,7 +801,7 @@ class TestBackwardCompatibility:
"include_weather": True
}
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=minimal_request)
# Should work with defaults for missing fields