# services/training/tests/test_integration.py """ Integration tests for training service Tests complete workflows and service interactions """ import pytest import asyncio from unittest.mock import AsyncMock, Mock, patch from httpx import AsyncClient from datetime import datetime, timedelta from app.main import app from app.schemas.training import TrainingJobRequest class TestTrainingWorkflowIntegration: """Test complete training workflows end-to-end""" @pytest.mark.asyncio async def test_complete_training_workflow( self, test_client: AsyncClient, test_db_session, mock_messaging, mock_data_service, mock_ml_trainer ): """Test complete training workflow from API to completion""" # Step 1: Start training job request_data = { "include_weather": True, "include_traffic": True, "min_data_points": 30, "seasonality_mode": "additive" } with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.post("/training/jobs", json=request_data) assert response.status_code == 200 job_data = response.json() job_id = job_data["job_id"] # Step 2: Check initial status with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.get(f"/training/jobs/{job_id}/status") assert response.status_code == 200 status_data = response.json() assert status_data["status"] in ["pending", "started"] # Step 3: Simulate background task completion # In real scenario, this would be handled by background tasks await asyncio.sleep(0.1) # Allow background task to start # Step 4: Check completion status with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.get(f"/training/jobs/{job_id}/status") # The job should exist in database even if not completed yet assert response.status_code == 200 @pytest.mark.asyncio async def test_single_product_training_workflow( self, test_client: AsyncClient, mock_messaging, mock_data_service, mock_ml_trainer ): """Test single product training complete workflow""" product_name = "Pan Integral" request_data = { "include_weather": True, "include_traffic": False, "seasonality_mode": "additive" } # Start single product training with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.post( f"/training/products/{product_name}", json=request_data ) assert response.status_code == 200 job_data = response.json() job_id = job_data["job_id"] assert f"training started for {product_name}" in job_data["message"].lower() # Check job status with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.get(f"/training/jobs/{job_id}/status") assert response.status_code == 200 status_data = response.json() assert status_data["job_id"] == job_id @pytest.mark.asyncio async def test_training_validation_workflow( self, test_client: AsyncClient, mock_data_service ): """Test training data validation workflow""" request_data = { "include_weather": True, "include_traffic": True, "min_data_points": 30 } # Validate training data with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.post("/training/validate", json=request_data) assert response.status_code == 200 validation_data = response.json() assert "is_valid" in validation_data assert "issues" in validation_data assert "recommendations" in validation_data assert "estimated_training_time" in validation_data # If validation passes, start actual training if validation_data["is_valid"]: with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.post("/training/jobs", json=request_data) assert response.status_code == 200 @pytest.mark.asyncio async def test_job_cancellation_workflow( self, test_client: AsyncClient, training_job_in_db, mock_messaging ): """Test job cancellation workflow""" job_id = training_job_in_db.job_id # Check initial status with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.get(f"/training/jobs/{job_id}/status") assert response.status_code == 200 initial_status = response.json() assert initial_status["status"] == "pending" # Cancel the job with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.post(f"/training/jobs/{job_id}/cancel") assert response.status_code == 200 cancel_response = response.json() assert "cancelled" in cancel_response["message"].lower() # Verify cancellation with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.get(f"/training/jobs/{job_id}/status") assert response.status_code == 200 final_status = response.json() assert final_status["status"] == "cancelled" class TestServiceInteractionIntegration: """Test interactions between training service and external services""" @pytest.mark.asyncio async def test_data_service_integration(self, training_service, mock_data_service): """Test integration with data service""" from app.schemas.training import TrainingJobRequest request = TrainingJobRequest( include_weather=True, include_traffic=True, min_data_points=30 ) # Test sales data fetching sales_data = await training_service._fetch_sales_data("test-tenant", request) assert isinstance(sales_data, list) # Test weather data fetching weather_data = await training_service._fetch_weather_data("test-tenant", request) assert isinstance(weather_data, list) # Test traffic data fetching traffic_data = await training_service._fetch_traffic_data("test-tenant", request) assert isinstance(traffic_data, list) @pytest.mark.asyncio async def test_messaging_integration(self, mock_messaging): """Test integration with messaging system""" from app.services.messaging import ( publish_job_started, publish_job_completed, publish_model_trained ) # Test various message types result1 = await publish_job_started("job-123", "tenant-123", {}) result2 = await publish_job_completed("job-123", "tenant-123", {"status": "success"}) result3 = await publish_model_trained("model-123", "tenant-123", "Pan Integral", {"mae": 5.0}) assert result1 is True assert result2 is True assert result3 is True @pytest.mark.asyncio async def test_database_integration(self, test_db_session, training_service): """Test database operations integration""" # Create a training job job = await training_service.create_training_job( db=test_db_session, tenant_id="test-tenant", job_id="integration-test-job", config={"test": True} ) assert job.job_id == "integration-test-job" # Update job status await training_service._update_job_status( db=test_db_session, job_id=job.job_id, status="running", progress=50, current_step="Processing data" ) # Retrieve updated job updated_job = await training_service.get_job_status( db=test_db_session, job_id=job.job_id, tenant_id="test-tenant" ) assert updated_job.status == "running" assert updated_job.progress == 50 class TestErrorHandlingIntegration: """Test error handling across service boundaries""" @pytest.mark.asyncio async def test_data_service_failure_handling( self, test_client: AsyncClient, mock_messaging ): """Test handling when data service is unavailable""" request_data = { "include_weather": True, "include_traffic": True, "min_data_points": 30 } # Mock data service failure with patch('httpx.AsyncClient') as mock_client: mock_client.return_value.__aenter__.return_value.get.side_effect = Exception("Service unavailable") with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.post("/training/jobs", json=request_data) # Should still create job but might fail during execution assert response.status_code == 200 @pytest.mark.asyncio async def test_messaging_failure_handling( self, test_client: AsyncClient, mock_data_service ): """Test handling when messaging fails""" request_data = { "include_weather": True, "include_traffic": True, "min_data_points": 30 } # Mock messaging failure with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")): with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.post("/training/jobs", json=request_data) # Should still succeed even if messaging fails assert response.status_code == 200 @pytest.mark.asyncio async def test_ml_training_failure_handling( self, test_client: AsyncClient, mock_messaging, mock_data_service ): """Test handling when ML training fails""" request_data = { "include_weather": True, "include_traffic": True, "min_data_points": 30 } # Mock ML training failure with patch('app.ml.trainer.BakeryMLTrainer.train_tenant_models', side_effect=Exception("ML training failed")): with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.post("/training/jobs", json=request_data) # Job should be created successfully assert response.status_code == 200 # Background task would handle the failure class TestPerformanceIntegration: """Test performance characteristics of integrated workflows""" @pytest.mark.asyncio async def test_concurrent_training_jobs( self, test_client: AsyncClient, mock_messaging, mock_data_service, mock_ml_trainer ): """Test handling multiple concurrent training jobs""" request_data = { "include_weather": True, "include_traffic": True, "min_data_points": 30 } # Start multiple jobs concurrently tasks = [] for i in range(5): with patch('app.api.training.get_current_tenant_id', return_value=f"tenant-{i}"): task = test_client.post("/training/jobs", json=request_data) tasks.append(task) responses = await asyncio.gather(*tasks) # All jobs should be created successfully for response in responses: assert response.status_code == 200 data = response.json() assert "job_id" in data @pytest.mark.asyncio async def test_large_dataset_handling( self, training_service, test_db_session ): """Test handling of large datasets""" # Simulate large dataset large_config = { "include_weather": True, "include_traffic": True, "min_data_points": 1000, # Large minimum "products": [f"Product-{i}" for i in range(100)] # Many products } job = await training_service.create_training_job( db=test_db_session, tenant_id="test-tenant", job_id="large-dataset-job", config=large_config ) assert job.config == large_config assert job.job_id == "large-dataset-job" @pytest.mark.asyncio async def test_rapid_status_checks( self, test_client: AsyncClient, training_job_in_db ): """Test rapid successive status checks""" job_id = training_job_in_db.job_id # Make many rapid status requests tasks = [] for _ in range(20): with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): task = test_client.get(f"/training/jobs/{job_id}/status") tasks.append(task) responses = await asyncio.gather(*tasks) # All requests should succeed for response in responses: assert response.status_code == 200 class TestSecurityIntegration: """Test security aspects of service integration""" @pytest.mark.asyncio async def test_tenant_isolation( self, test_client: AsyncClient, training_job_in_db, mock_messaging ): """Test that tenants cannot access each other's jobs""" job_id = training_job_in_db.job_id # Try to access job with different tenant ID with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"): response = await test_client.get(f"/training/jobs/{job_id}/status") # Should not find the job (belongs to different tenant) assert response.status_code == 404 @pytest.mark.asyncio async def test_input_validation_integration( self, test_client: AsyncClient ): """Test input validation across API boundaries""" # Test invalid seasonality mode invalid_request = { "seasonality_mode": "invalid_mode", "min_data_points": -5 # Invalid negative value } with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.post("/training/jobs", json=invalid_request) assert response.status_code == 422 # Validation error @pytest.mark.asyncio async def test_sql_injection_protection( self, test_client: AsyncClient ): """Test protection against SQL injection attempts""" # Try SQL injection in job ID malicious_job_id = "job'; DROP TABLE model_training_logs; --" with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.get(f"/training/jobs/{malicious_job_id}/status") # Should return 404, not cause database error assert response.status_code == 404 class TestRecoveryIntegration: """Test recovery and resilience scenarios""" @pytest.mark.asyncio async def test_service_restart_recovery( self, test_db_session, training_service, training_job_in_db ): """Test service recovery after restart""" # Simulate service restart by creating new service instance new_training_service = training_service.__class__() # Should be able to access existing jobs existing_job = await new_training_service.get_job_status( db=test_db_session, job_id=training_job_in_db.job_id, tenant_id=training_job_in_db.tenant_id ) assert existing_job is not None assert existing_job.job_id == training_job_in_db.job_id @pytest.mark.asyncio async def test_partial_failure_recovery( self, training_service, test_db_session ): """Test recovery from partial failures""" # Create job that might fail partway through job = await training_service.create_training_job( db=test_db_session, tenant_id="test-tenant", job_id="partial-failure-job", config={"simulate_failure": True} ) # Simulate partial progress await training_service._update_job_status( db=test_db_session, job_id=job.job_id, status="running", progress=50, current_step="Halfway through training" ) # Simulate failure await training_service._update_job_status( db=test_db_session, job_id=job.job_id, status="failed", progress=50, current_step="Training failed", error_message="Simulated failure" ) # Verify failure was recorded failed_job = await training_service.get_job_status( db=test_db_session, job_id=job.job_id, tenant_id="test-tenant" ) assert failed_job.status == "failed" assert failed_job.error_message == "Simulated failure" assert failed_job.progress == 50 class TestComplianceIntegration: """Test compliance and audit requirements""" @pytest.mark.asyncio async def test_audit_trail_creation( self, training_service, test_db_session ): """Test that audit trail is properly created""" # Create and update job job = await training_service.create_training_job( db=test_db_session, tenant_id="test-tenant", job_id="audit-test-job", config={"audit_test": True} ) # Multiple status updates await training_service._update_job_status( db=test_db_session, job_id=job.job_id, status="running", progress=25, current_step="Started processing" ) await training_service._update_job_status( db=test_db_session, job_id=job.job_id, status="running", progress=75, current_step="Almost complete" ) await training_service._update_job_status( db=test_db_session, job_id=job.job_id, status="completed", progress=100, current_step="Completed successfully" ) # Verify audit trail logs = await training_service.get_training_logs( db=test_db_session, job_id=job.job_id, tenant_id="test-tenant" ) assert logs is not None assert len(logs) > 0 # Check final status final_job = await training_service.get_job_status( db=test_db_session, job_id=job.job_id, tenant_id="test-tenant" ) assert final_job.status == "completed" assert final_job.progress == 100 @pytest.mark.asyncio async def test_data_retention_compliance( self, training_service, test_db_session ): """Test data retention and cleanup compliance""" from datetime import datetime, timedelta # Create old job (simulate old data) old_job = await training_service.create_training_job( db=test_db_session, tenant_id="test-tenant", job_id="old-job", config={"created_long_ago": True} ) # Manually set old timestamp from sqlalchemy import update from app.models.training import ModelTrainingLog old_timestamp = datetime.now() - timedelta(days=400) await test_db_session.execute( update(ModelTrainingLog) .where(ModelTrainingLog.job_id == old_job.job_id) .values(start_time=old_timestamp, created_at=old_timestamp) ) await test_db_session.commit() # Verify old job exists retrieved_job = await training_service.get_job_status( db=test_db_session, job_id=old_job.job_id, tenant_id="test-tenant" ) assert retrieved_job is not None # In a real implementation, there would be cleanup procedures @pytest.mark.asyncio async def test_gdpr_compliance_features( self, training_service, test_db_session ): """Test GDPR compliance features""" # Create job with tenant data job = await training_service.create_training_job( db=test_db_session, tenant_id="gdpr-test-tenant", job_id="gdpr-test-job", config={"gdpr_test": True} ) # Verify job is associated with tenant assert job.tenant_id == "gdpr-test-tenant" # Test data access (right to access) tenant_jobs = await training_service.list_training_jobs( db=test_db_session, tenant_id="gdpr-test-tenant" ) assert len(tenant_jobs) >= 1 assert any(job.job_id == "gdpr-test-job" for job in tenant_jobs) @pytest.mark.slow class TestLongRunningIntegration: """Test long-running integration scenarios (marked as slow)""" @pytest.mark.asyncio async def test_extended_training_simulation( self, training_service, test_db_session, mock_messaging ): """Test extended training process simulation""" job = await training_service.create_training_job( db=test_db_session, tenant_id="test-tenant", job_id="long-running-job", config={"extended_test": True} ) # Simulate progress over time progress_steps = [ (10, "Initializing"), (25, "Loading data"), (50, "Training models"), (75, "Validating results"), (90, "Storing models"), (100, "Completed") ] for progress, step in progress_steps: await training_service._update_job_status( db=test_db_session, job_id=job.job_id, status="running" if progress < 100 else "completed", progress=progress, current_step=step ) # Small delay to simulate real progression await asyncio.sleep(0.01) # Verify final state final_job = await training_service.get_job_status( db=test_db_session, job_id=job.job_id, tenant_id="test-tenant" ) assert final_job.status == "completed" assert final_job.progress == 100 assert final_job.current_step == "Completed" @pytest.mark.asyncio async def test_memory_usage_stability( self, training_service, test_db_session ): """Test memory usage stability over many operations""" # Create many jobs to test memory stability for i in range(50): job = await training_service.create_training_job( db=test_db_session, tenant_id=f"tenant-{i % 5}", # 5 different tenants job_id=f"memory-test-job-{i}", config={"iteration": i} ) # Update status await training_service._update_job_status( db=test_db_session, job_id=job.job_id, status="completed", progress=100, current_step="Completed" ) # List jobs for each tenant for tenant_i in range(5): tenant_id = f"tenant-{tenant_i}" jobs = await training_service.list_training_jobs( db=test_db_session, tenant_id=tenant_id, limit=20 ) # Should have 10 jobs per tenant (50 total / 5 tenants) assert len(jobs) == 10 class TestBackwardCompatibility: """Test backward compatibility with existing systems""" @pytest.mark.asyncio async def test_legacy_config_handling( self, training_service, test_db_session ): """Test handling of legacy configuration formats""" # Test with old-style configuration legacy_config = { "weather_enabled": True, # Old key "traffic_enabled": True, # Old key "minimum_samples": 30, # Old key "prophet_config": { # Old nested structure "seasonality": "additive" } } job = await training_service.create_training_job( db=test_db_session, tenant_id="test-tenant", job_id="legacy-config-job", config=legacy_config ) assert job.config == legacy_config assert job.job_id == "legacy-config-job" @pytest.mark.asyncio async def test_api_version_compatibility( self, test_client: AsyncClient ): """Test API version compatibility""" # Test with minimal request (old API style) minimal_request = { "include_weather": True } with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"): response = await test_client.post("/training/jobs", json=minimal_request) # Should work with defaults for missing fields assert response.status_code == 200 data = response.json() assert "job_id" in data # Utility functions for integration tests async def wait_for_condition(condition_func, timeout=5.0, interval=0.1): """Wait for a condition to become true""" import time start_time = time.time() while time.time() - start_time < timeout: if await condition_func(): return True await asyncio.sleep(interval) return False def assert_job_progression(job_updates): """Assert that job updates show proper progression""" assert len(job_updates) > 0 # Check progress is non-decreasing for i in range(1, len(job_updates)): assert job_updates[i]["progress"] >= job_updates[i-1]["progress"] # Check final status final_update = job_updates[-1] assert final_update["status"] in ["completed", "failed", "cancelled"] def assert_valid_job_structure(job_data): """Assert job data has valid structure""" required_fields = ["job_id", "status", "tenant_id"] for field in required_fields: assert field in job_data assert isinstance(job_data["progress"], int) assert 0 <= job_data["progress"] <= 100 assert job_data["status"] in ["pending", "running", "completed", "failed", "cancelled"]