# services/training/tests/test_api.py
"""
Tests for training service API endpoints
"""

import asyncio
from unittest.mock import AsyncMock, patch

import pytest
from fastapi import status
from httpx import AsyncClient

from app.schemas.training import TrainingJobRequest

# Dotted path of the tenant dependency that most tests in this module patch.
_TENANT_DEP = "shared.auth.decorators.get_current_tenant_id_dep"
# Tenant id the fixtures/tests in this module expect by default.
_TEST_TENANT = "test-tenant"


def _as_tenant(tenant_id: str = _TEST_TENANT, target: str = _TENANT_DEP):
    """Return a ``patch`` context manager that makes *target* return *tenant_id*.

    Args:
        tenant_id: Value the patched callable returns.
        target: Dotted path of the callable to patch.

    NOTE(review): FastAPI captures dependency callables when routes are
    declared, so replacing the module attribute may not change what
    ``Depends(...)`` resolves. If tenant resolution is not already overridden
    in conftest, ``app.dependency_overrides`` is the reliable mechanism —
    confirm before relying on these patches for tenant-specific assertions.
    """
    return patch(target, return_value=tenant_id)


class TestTrainingAPI:
    """Test training API endpoints"""

    @pytest.mark.asyncio
    async def test_health_check(self, test_client: AsyncClient):
        """Test health check endpoint"""
        response = await test_client.get("/health")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["service"] == "training-service"
        assert data["version"] == "1.0.0"
        assert "status" in data

    @pytest.mark.asyncio
    async def test_readiness_check_ready(self, test_client: AsyncClient, monkeypatch):
        """Test readiness check when service is ready"""
        from app.main import app

        # app.state stores attributes outside __dict__, so
        # patch.object(..., create=True) would delete a pre-existing ``ready``
        # on exit; monkeypatch restores the previous value on teardown.
        monkeypatch.setattr(app.state, "ready", True, raising=False)
        response = await test_client.get("/health/ready")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["status"] == "ready"

    @pytest.mark.asyncio
    async def test_readiness_check_not_ready(self, test_client: AsyncClient, monkeypatch):
        """Test readiness check when service is not ready"""
        from app.main import app

        # Same mechanism as the "ready" test so both restore state identically.
        monkeypatch.setattr(app.state, "ready", False, raising=False)
        response = await test_client.get("/health/ready")

        assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
        data = response.json()
        assert data["status"] == "not_ready"

    @pytest.mark.asyncio
    async def test_liveness_check_healthy(self, test_client: AsyncClient):
        """Test liveness check when service is healthy"""
        # assumes get_db_health is an ``async def`` looked up via
        # app.core.database by the endpoint — confirm. new_callable=AsyncMock
        # makes ``await get_db_health()`` yield the bool itself.
        with patch("app.core.database.get_db_health", new_callable=AsyncMock, return_value=True):
            response = await test_client.get("/health/live")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["status"] == "alive"

    @pytest.mark.asyncio
    async def test_liveness_check_unhealthy(self, test_client: AsyncClient):
        """Test liveness check when database is unhealthy"""
        # A MagicMock returning an AsyncMock instance is always truthy, so the
        # health check must be an AsyncMock that resolves to False.
        with patch("app.core.database.get_db_health", new_callable=AsyncMock, return_value=False):
            response = await test_client.get("/health/live")

        assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
        data = response.json()
        assert data["status"] == "unhealthy"
        assert data["reason"] == "database_unavailable"

    @pytest.mark.asyncio
    async def test_metrics_endpoint(self, test_client: AsyncClient):
        """Test metrics endpoint"""
        response = await test_client.get("/metrics")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()

        expected_metrics = [
            "training_jobs_active",
            "training_jobs_completed",
            "training_jobs_failed",
            "models_trained_total",
            "uptime_seconds",
        ]
        for metric in expected_metrics:
            assert metric in data

    @pytest.mark.asyncio
    async def test_root_endpoint(self, test_client: AsyncClient):
        """Test root endpoint"""
        response = await test_client.get("/")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["service"] == "training-service"
        assert data["version"] == "1.0.0"
        assert "description" in data


class TestTrainingJobsAPI:
    """Test training jobs API endpoints"""

    @pytest.mark.asyncio
    async def test_start_training_job_success(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_ml_trainer,
        mock_data_service,
    ):
        """Test starting a training job successfully"""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
            "seasonality_mode": "additive",
        }

        with _as_tenant():
            response = await test_client.post("/training/jobs", json=request_data)

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "job_id" in data
        assert data["status"] == "started"
        assert data["tenant_id"] == _TEST_TENANT
        assert "estimated_duration_minutes" in data

    @pytest.mark.asyncio
    async def test_start_training_job_validation_error(self, test_client: AsyncClient):
        """Test starting training job with validation error"""
        request_data = {
            "seasonality_mode": "invalid_mode",  # Invalid value
            "min_data_points": 5,  # Too low
        }

        with _as_tenant():
            response = await test_client.post("/training/jobs", json=request_data)

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    @pytest.mark.asyncio
    async def test_get_training_status_existing_job(
        self,
        test_client: AsyncClient,
        training_job_in_db,
    ):
        """Test getting status of existing training job"""
        job_id = training_job_in_db.job_id

        with _as_tenant():
            response = await test_client.get(f"/training/jobs/{job_id}/status")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["job_id"] == job_id
        assert data["status"] == "pending"
        assert "progress" in data
        assert "started_at" in data

    @pytest.mark.asyncio
    async def test_get_training_status_nonexistent_job(self, test_client: AsyncClient):
        """Test getting status of non-existent training job"""
        with _as_tenant():
            response = await test_client.get("/training/jobs/nonexistent-job/status")

        assert response.status_code == status.HTTP_404_NOT_FOUND

    @pytest.mark.asyncio
    async def test_list_training_jobs(
        self,
        test_client: AsyncClient,
        training_job_in_db,
    ):
        """Test listing training jobs"""
        with _as_tenant():
            response = await test_client.get("/training/jobs")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert isinstance(data, list)
        assert len(data) >= 1

        # Check first job structure
        job = data[0]
        assert "job_id" in job
        assert "status" in job
        assert "started_at" in job

    @pytest.mark.asyncio
    async def test_list_training_jobs_with_status_filter(
        self,
        test_client: AsyncClient,
        training_job_in_db,
    ):
        """Test listing training jobs with status filter"""
        with _as_tenant():
            response = await test_client.get("/training/jobs?status=pending")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert isinstance(data, list)
        # Without this the loop below would pass vacuously on an empty list;
        # the fixture job is "pending" (see test_get_training_status_existing_job).
        assert data, "expected the pending fixture job to be listed"

        # All jobs should have status "pending"
        for job in data:
            assert job["status"] == "pending"

    @pytest.mark.asyncio
    async def test_cancel_training_job_success(
        self,
        test_client: AsyncClient,
        training_job_in_db,
        mock_messaging,
    ):
        """Test cancelling a training job successfully"""
        job_id = training_job_in_db.job_id

        with _as_tenant():
            response = await test_client.post(f"/training/jobs/{job_id}/cancel")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "message" in data
        assert "cancelled" in data["message"].lower()

    @pytest.mark.asyncio
    async def test_cancel_nonexistent_job(self, test_client: AsyncClient):
        """Test cancelling a non-existent training job"""
        with _as_tenant():
            response = await test_client.post("/training/jobs/nonexistent-job/cancel")

        assert response.status_code == status.HTTP_404_NOT_FOUND

    @pytest.mark.asyncio
    async def test_get_training_logs(
        self,
        test_client: AsyncClient,
        training_job_in_db,
    ):
        """Test getting training logs"""
        job_id = training_job_in_db.job_id

        with _as_tenant():
            response = await test_client.get(f"/training/jobs/{job_id}/logs")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "job_id" in data
        assert "logs" in data
        assert isinstance(data["logs"], list)

    @pytest.mark.asyncio
    async def test_validate_training_data_valid(
        self,
        test_client: AsyncClient,
        mock_data_service,
    ):
        """Test validating valid training data"""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
        }

        with _as_tenant():
            response = await test_client.post("/training/validate", json=request_data)

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "is_valid" in data
        assert "issues" in data
        assert "recommendations" in data
        assert "estimated_training_time" in data


class TestSingleProductTrainingAPI:
    """Test single product training API endpoints"""

    @pytest.mark.asyncio
    async def test_train_single_product_success(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_ml_trainer,
        mock_data_service,
    ):
        """Test training a single product successfully"""
        product_name = "Pan Integral"
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "seasonality_mode": "additive",
        }

        with _as_tenant():
            response = await test_client.post(
                f"/training/products/{product_name}",
                json=request_data,
            )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "job_id" in data
        assert data["status"] == "started"
        assert data["tenant_id"] == _TEST_TENANT
        assert f"training started for {product_name}" in data["message"].lower()

    @pytest.mark.asyncio
    async def test_train_single_product_validation_error(self, test_client: AsyncClient):
        """Test single product training with validation error"""
        product_name = "Pan Integral"
        request_data = {
            "seasonality_mode": "invalid_mode",  # Invalid value
        }

        with _as_tenant():
            response = await test_client.post(
                f"/training/products/{product_name}",
                json=request_data,
            )

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    @pytest.mark.asyncio
    async def test_train_single_product_special_characters(
        self,
        test_client: AsyncClient,
        mock_messaging,
        mock_ml_trainer,
        mock_data_service,
    ):
        """Test training product with special characters in name"""
        product_name = "Pan Francés"  # With accent
        request_data = {
            "include_weather": True,
            "seasonality_mode": "additive",
        }

        with _as_tenant():
            response = await test_client.post(
                f"/training/products/{product_name}",
                json=request_data,
            )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "job_id" in data


class TestModelsAPI:
    """Test models API endpoints"""

    @pytest.mark.asyncio
    async def test_list_models(
        self,
        test_client: AsyncClient,
        trained_model_in_db,
    ):
        """Test listing trained models"""
        with _as_tenant(target="app.api.models.get_current_tenant_id"):
            response = await test_client.get("/models")

        # This endpoint might not exist yet, so we expect either 200 or 404
        assert response.status_code in [status.HTTP_200_OK, status.HTTP_404_NOT_FOUND]

        if response.status_code == status.HTTP_200_OK:
            data = response.json()
            assert isinstance(data, list)

    @pytest.mark.asyncio
    async def test_get_model_details(
        self,
        test_client: AsyncClient,
        trained_model_in_db,
    ):
        """Test getting model details"""
        model_id = trained_model_in_db.model_id

        with _as_tenant(target="app.api.models.get_current_tenant_id"):
            response = await test_client.get(f"/models/{model_id}")

        # This endpoint might not exist yet
        assert response.status_code in [
            status.HTTP_200_OK,
            status.HTTP_404_NOT_FOUND,
            status.HTTP_501_NOT_IMPLEMENTED,
        ]


class TestErrorHandling:
    """Test error handling in API endpoints"""

    @pytest.mark.asyncio
    async def test_database_error_handling(self, test_client: AsyncClient):
        """Test handling of database errors"""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
        }
        service_down = patch(
            "app.services.training_service.TrainingService.create_training_job",
            side_effect=Exception("Database connection failed"),
        )

        with service_down, _as_tenant():
            response = await test_client.post("/training/jobs", json=request_data)

        assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    @pytest.mark.asyncio
    async def test_missing_tenant_id(self, test_client: AsyncClient):
        """Test handling when tenant ID is missing"""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
        }

        # Don't mock get_current_tenant_id to simulate missing auth
        response = await test_client.post("/training/jobs", json=request_data)

        # Should fail due to missing authentication
        assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]

    @pytest.mark.asyncio
    async def test_invalid_job_id_format(self, test_client: AsyncClient):
        """Test handling of invalid job ID format"""
        invalid_job_id = "invalid-job-id-with-special-chars@#$"

        with _as_tenant():
            response = await test_client.get(f"/training/jobs/{invalid_job_id}/status")

        # Should handle gracefully
        assert response.status_code in [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]

    @pytest.mark.asyncio
    async def test_messaging_failure_handling(
        self,
        test_client: AsyncClient,
        mock_data_service,
    ):
        """Test handling when messaging fails"""
        request_data = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
        }
        # NOTE(review): only effective if the API looks publish_job_started up
        # through app.services.messaging rather than a from-import — confirm.
        messaging_down = patch(
            "app.services.messaging.publish_job_started",
            side_effect=Exception("Messaging failed"),
        )

        with messaging_down, _as_tenant():
            response = await test_client.post("/training/jobs", json=request_data)

        # Should still succeed even if messaging fails
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert "job_id" in data

    @pytest.mark.asyncio
    async def test_invalid_json_payload(self, test_client: AsyncClient):
        """Test handling of invalid JSON payload"""
        with _as_tenant():
            response = await test_client.post(
                "/training/jobs",
                content="invalid json {{{",
                headers={"Content-Type": "application/json"},
            )

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    @pytest.mark.asyncio
    async def test_unsupported_content_type(self, test_client: AsyncClient):
        """Test handling of unsupported content type"""
        with _as_tenant():
            response = await test_client.post(
                "/training/jobs",
                content="some text data",
                headers={"Content-Type": "text/plain"},
            )

        assert response.status_code in [
            status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            status.HTTP_422_UNPROCESSABLE_ENTITY,
        ]


class TestAuthenticationIntegration:
    """Test authentication integration"""

    @pytest.mark.asyncio
    async def test_endpoints_require_auth(self, test_client: AsyncClient):
        """Test that endpoints require authentication in production"""
        # This test would be more meaningful in a production environment
        # where authentication is actually enforced
        endpoints_to_test = [
            ("POST", "/training/jobs"),
            ("GET", "/training/jobs"),
            ("POST", "/training/products/Pan Integral"),
            ("POST", "/training/validate"),
        ]

        for method, endpoint in endpoints_to_test:
            if method == "POST":
                response = await test_client.post(endpoint, json={})
            else:
                response = await test_client.get(endpoint)

            # In test environment with mocked auth, should work
            # In production, would require valid authentication
            assert response.status_code != status.HTTP_500_INTERNAL_SERVER_ERROR

    @pytest.mark.asyncio
    async def test_tenant_isolation_in_api(
        self,
        test_client: AsyncClient,
        training_job_in_db,
    ):
        """Test tenant isolation at API level"""
        job_id = training_job_in_db.job_id

        # Try to access job with different tenant.
        # NOTE(review): every other training-jobs test sets the tenant via
        # shared.auth.decorators.get_current_tenant_id_dep; confirm this target
        # is the one the status endpoint actually consults.
        with _as_tenant("different-tenant", target="app.api.training.get_current_tenant_id"):
            response = await test_client.get(f"/training/jobs/{job_id}/status")

        # Should not find job for different tenant
        assert response.status_code == status.HTTP_404_NOT_FOUND


class TestAPIValidation:
    """Test API validation and input handling"""

    @pytest.mark.asyncio
    async def test_training_request_validation(self, test_client: AsyncClient):
        """Test comprehensive training request validation"""
        # Test valid request
        valid_request = {
            "include_weather": True,
            "include_traffic": False,
            "min_data_points": 30,
            "seasonality_mode": "additive",
            "daily_seasonality": True,
            "weekly_seasonality": True,
            "yearly_seasonality": True,
        }

        with _as_tenant():
            response = await test_client.post("/training/jobs", json=valid_request)

        assert response.status_code == status.HTTP_200_OK

        # Test invalid seasonality mode
        invalid_request = valid_request.copy()
        invalid_request["seasonality_mode"] = "invalid_mode"

        with _as_tenant():
            response = await test_client.post("/training/jobs", json=invalid_request)

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

        # Test invalid min_data_points
        invalid_request = valid_request.copy()
        invalid_request["min_data_points"] = 5  # Too low

        with _as_tenant():
            response = await test_client.post("/training/jobs", json=invalid_request)

        assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

    @pytest.mark.asyncio
    async def test_single_product_request_validation(self, test_client: AsyncClient):
        """Test single product training request validation"""
        product_name = "Pan Integral"

        # Test valid request
        valid_request = {
            "include_weather": True,
            "include_traffic": True,
            "seasonality_mode": "multiplicative",
        }

        with _as_tenant():
            response = await test_client.post(
                f"/training/products/{product_name}",
                json=valid_request,
            )

        assert response.status_code == status.HTTP_200_OK

        # Test empty product name
        with _as_tenant():
            response = await test_client.post(
                "/training/products/",
                json=valid_request,
            )

        assert response.status_code == status.HTTP_404_NOT_FOUND

    @pytest.mark.asyncio
    async def test_query_parameter_validation(self, test_client: AsyncClient):
        """Test query parameter validation"""
        with _as_tenant():
            # Test valid limit parameter
            response = await test_client.get("/training/jobs?limit=5")
            assert response.status_code == status.HTTP_200_OK

            # Test invalid limit parameter
            response = await test_client.get("/training/jobs?limit=invalid")
            assert response.status_code in [
                status.HTTP_422_UNPROCESSABLE_ENTITY,
                status.HTTP_400_BAD_REQUEST,
            ]

            # Test negative limit
            response = await test_client.get("/training/jobs?limit=-1")
            assert response.status_code in [
                status.HTTP_422_UNPROCESSABLE_ENTITY,
                status.HTTP_400_BAD_REQUEST,
            ]


class TestAPIPerformance:
    """Test API performance characteristics"""

    @pytest.mark.asyncio
    async def test_concurrent_requests(self, test_client: AsyncClient):
        """Test handling of concurrent requests"""
        # /health needs no tenant. The original per-request patch exited
        # before gather() ran the coroutines, so it never applied anyway.
        responses = await asyncio.gather(
            *(test_client.get("/health") for _ in range(10))
        )

        # All requests should succeed
        for response in responses:
            assert response.status_code == status.HTTP_200_OK

    @pytest.mark.asyncio
    async def test_large_payload_handling(self, test_client: AsyncClient):
        """Test handling of large request payloads"""
        # Create large request payload
        large_request = {
            "include_weather": True,
            "include_traffic": True,
            "min_data_points": 30,
            "large_config": {f"key_{i}": f"value_{i}" for i in range(1000)},
        }

        with _as_tenant():
            response = await test_client.post("/training/jobs", json=large_request)

        # Should handle large payload gracefully
        assert response.status_code in [
            status.HTTP_200_OK,
            status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
        ]

    @pytest.mark.asyncio
    async def test_rapid_successive_requests(self, test_client: AsyncClient):
        """Test rapid successive requests to same endpoint"""
        # Make rapid requests
        responses = []
        for _ in range(20):
            response = await test_client.get("/health")
            responses.append(response)

        # All should succeed
        for response in responses:
            assert response.status_code == status.HTTP_200_OK