# services/training/tests/test_service.py
"""
Tests for training service business logic layer
"""

from datetime import datetime, timedelta
from unittest.mock import AsyncMock, Mock, patch

import httpx
import pytest
from sqlalchemy import select

from app.services.training_service import TrainingService
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
from app.models.training import ModelTrainingLog, TrainedModel


class TestTrainingService:
    """Test the training service business logic"""

    @pytest.fixture
    def training_service(self, mock_ml_trainer):
        # `mock_ml_trainer` is requested only for its fixture side effect
        # (presumably it patches the ML trainer — confirm in conftest).
        return TrainingService()

    @pytest.mark.asyncio
    async def test_create_training_job_success(
        self, training_service, test_db_session
    ):
        """Test successful training job creation"""
        job_id = "test-job-123"
        tenant_id = "test-tenant"
        config = {"include_weather": True, "include_traffic": True}

        result = await training_service.create_training_job(
            db=test_db_session,
            tenant_id=tenant_id,
            job_id=job_id,
            config=config,
        )

        assert isinstance(result, ModelTrainingLog)
        assert result.job_id == job_id
        assert result.tenant_id == tenant_id
        assert result.status == "pending"
        assert result.progress == 0
        assert result.config == config

    @pytest.mark.asyncio
    async def test_create_single_product_job_success(
        self, training_service, test_db_session
    ):
        """Test successful single product job creation"""
        job_id = "test-product-job-123"
        tenant_id = "test-tenant"
        product_name = "Pan Integral"
        config = {"include_weather": True}

        result = await training_service.create_single_product_job(
            db=test_db_session,
            tenant_id=tenant_id,
            product_name=product_name,
            job_id=job_id,
            config=config,
        )

        assert isinstance(result, ModelTrainingLog)
        assert result.job_id == job_id
        assert result.tenant_id == tenant_id
        assert result.config["single_product"] == product_name
        assert f"Initializing training for {product_name}" in result.current_step

    @pytest.mark.asyncio
    async def test_get_job_status_existing(
        self, training_service, test_db_session, training_job_in_db
    ):
        """Test getting status of existing job"""
        result = await training_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id,
        )

        assert result is not None
        assert result.job_id == training_job_in_db.job_id
        assert result.status == training_job_in_db.status

    @pytest.mark.asyncio
    async def test_get_job_status_nonexistent(
        self, training_service, test_db_session
    ):
        """Test getting status of non-existent job"""
        result = await training_service.get_job_status(
            db=test_db_session,
            job_id="nonexistent-job",
            tenant_id="test-tenant",
        )

        assert result is None

    @pytest.mark.asyncio
    async def test_list_training_jobs(
        self, training_service, test_db_session, training_job_in_db
    ):
        """Test listing training jobs"""
        result = await training_service.list_training_jobs(
            db=test_db_session,
            tenant_id=training_job_in_db.tenant_id,
            limit=10,
        )

        assert isinstance(result, list)
        assert len(result) >= 1
        assert result[0].job_id == training_job_in_db.job_id

    @pytest.mark.asyncio
    async def test_list_training_jobs_with_filter(
        self, training_service, test_db_session, training_job_in_db
    ):
        """Test listing training jobs with status filter"""
        result = await training_service.list_training_jobs(
            db=test_db_session,
            tenant_id=training_job_in_db.tenant_id,
            limit=10,
            status_filter="pending",
        )

        assert isinstance(result, list)
        for job in result:
            assert job.status == "pending"

    @pytest.mark.asyncio
    async def test_cancel_training_job_success(
        self, training_service, test_db_session, training_job_in_db
    ):
        """Test successful job cancellation"""
        result = await training_service.cancel_training_job(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id,
        )

        assert result is True

        # Verify status was updated
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id,
        )
        assert updated_job.status == "cancelled"

    @pytest.mark.asyncio
    async def test_cancel_nonexistent_job(
        self, training_service, test_db_session
    ):
        """Test cancelling non-existent job"""
        result = await training_service.cancel_training_job(
            db=test_db_session,
            job_id="nonexistent-job",
            tenant_id="test-tenant",
        )

        assert result is False

    @pytest.mark.asyncio
    async def test_validate_training_data_valid(
        self, training_service, test_db_session, mock_data_service
    ):
        """Test validation with valid data"""
        config = {"min_data_points": 30}

        result = await training_service.validate_training_data(
            db=test_db_session,
            tenant_id="test-tenant",
            config=config,
        )

        assert isinstance(result, dict)
        assert "is_valid" in result
        assert "issues" in result
        assert "recommendations" in result
        assert "estimated_time_minutes" in result

    @pytest.mark.asyncio
    async def test_validate_training_data_no_data(
        self, training_service, test_db_session
    ):
        """Test validation with no data"""
        config = {"min_data_points": 30}

        # The patched coroutine must resolve to an empty list. Passing
        # `return_value=AsyncMock(...)` would make the await yield a mock
        # object instead of `[]`, so the no-data branch would never run.
        with patch(
            'app.services.training_service.TrainingService._fetch_sales_data',
            new_callable=AsyncMock,
            return_value=[],
        ):
            result = await training_service.validate_training_data(
                db=test_db_session,
                tenant_id="test-tenant",
                config=config,
            )

        assert result["is_valid"] is False
        assert "No sales data found" in result["issues"][0]

    @pytest.mark.asyncio
    async def test_update_job_status(
        self, training_service, test_db_session, training_job_in_db
    ):
        """Test updating job status"""
        await training_service._update_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            status="running",
            progress=50,
            current_step="Training models",
        )

        # Verify update
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id,
        )

        assert updated_job.status == "running"
        assert updated_job.progress == 50
        assert updated_job.current_step == "Training models"

    @pytest.mark.asyncio
    async def test_store_trained_models(
        self, training_service, test_db_session
    ):
        """Test storing trained models"""
        tenant_id = "test-tenant"

        training_results = {
            "training_results": {
                "Pan Integral": {
                    "status": "success",
                    "model_info": {
                        "model_id": "test-model-123",
                        "model_path": "/test/models/test-model-123.pkl",
                        "type": "prophet",
                        "training_samples": 100,
                        "features": ["temperature", "humidity"],
                        "hyperparameters": {"seasonality_mode": "additive"},
                        "training_metrics": {"mae": 5.2, "rmse": 7.8},
                        "data_period": {
                            "start_date": "2024-01-01T00:00:00",
                            "end_date": "2024-01-31T00:00:00",
                        },
                    },
                }
            }
        }

        await training_service._store_trained_models(
            db=test_db_session,
            tenant_id=tenant_id,
            training_results=training_results,
        )

        # Verify model was stored
        result = await test_db_session.execute(
            select(TrainedModel).where(
                TrainedModel.tenant_id == tenant_id,
                TrainedModel.product_name == "Pan Integral",
            )
        )

        stored_model = result.scalar_one_or_none()
        assert stored_model is not None
        assert stored_model.model_id == "test-model-123"
        assert stored_model.is_active is True

    @pytest.mark.asyncio
    async def test_get_training_logs(
        self, training_service, test_db_session, training_job_in_db
    ):
        """Test getting training logs"""
        result = await training_service.get_training_logs(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id,
        )

        assert isinstance(result, list)
        assert len(result) > 0

        # Check log content
        log_text = " ".join(result)
        assert training_job_in_db.job_id in log_text or "Job started" in log_text


class TestTrainingServiceDataFetching:
    """Test external data fetching functionality"""

    @pytest.fixture
    def training_service(self):
        return TrainingService()

    @pytest.mark.asyncio
    async def test_fetch_sales_data_success(self, training_service):
        """Test successful sales data fetching"""
        mock_request = Mock()
        mock_request.start_date = None
        mock_request.end_date = None

        mock_response_data = {
            "sales": [
                {"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}
            ]
        }

        with patch('httpx.AsyncClient') as mock_client:
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = mock_response_data

            # `async with httpx.AsyncClient() as client: await client.get(...)`
            mock_client.return_value.__aenter__.return_value.get.return_value = mock_response

            result = await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=mock_request,
            )

        assert result == mock_response_data["sales"]

    @pytest.mark.asyncio
    async def test_fetch_sales_data_error(self, training_service):
        """Test sales data fetching with API error"""
        mock_request = Mock()
        mock_request.start_date = None
        mock_request.end_date = None

        with patch('httpx.AsyncClient') as mock_client:
            mock_response = Mock()
            mock_response.status_code = 500

            mock_client.return_value.__aenter__.return_value.get.return_value = mock_response

            result = await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=mock_request,
            )

        assert result == []

    @pytest.mark.asyncio
    async def test_fetch_weather_data_success(self, training_service):
        """Test successful weather data fetching"""
        mock_request = Mock()
        mock_request.start_date = None
        mock_request.end_date = None

        mock_response_data = {
            "weather": [
                {"date": "2024-01-01", "temperature": 15.2, "precipitation": 0.0}
            ]
        }

        with patch('httpx.AsyncClient') as mock_client:
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = mock_response_data

            mock_client.return_value.__aenter__.return_value.get.return_value = mock_response

            result = await training_service._fetch_weather_data(
                tenant_id="test-tenant",
                request=mock_request,
            )

        assert result == mock_response_data["weather"]

    @pytest.mark.asyncio
    async def test_fetch_traffic_data_success(self, training_service):
        """Test successful traffic data fetching"""
        mock_request = Mock()
        mock_request.start_date = None
        mock_request.end_date = None

        mock_response_data = {
            "traffic": [
                {"date": "2024-01-01", "traffic_volume": 120}
            ]
        }

        with patch('httpx.AsyncClient') as mock_client:
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = mock_response_data

            mock_client.return_value.__aenter__.return_value.get.return_value = mock_response

            result = await training_service._fetch_traffic_data(
                tenant_id="test-tenant",
                request=mock_request,
            )

        assert result == mock_response_data["traffic"]

    @pytest.mark.asyncio
    async def test_fetch_data_with_date_filters(self, training_service):
        """Test data fetching with date filters"""
        mock_request = Mock()
        mock_request.start_date = datetime(2024, 1, 1)
        mock_request.end_date = datetime(2024, 1, 31)

        with patch('httpx.AsyncClient') as mock_client:
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = {"sales": []}

            mock_get = mock_client.return_value.__aenter__.return_value.get
            mock_get.return_value = mock_response

            await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=mock_request,
            )

            # Verify dates were passed in params (as ISO-8601 strings)
            params = mock_get.call_args.kwargs["params"]
            assert "start_date" in params
            assert "end_date" in params
            assert params["start_date"] == "2024-01-01T00:00:00"
            assert params["end_date"] == "2024-01-31T00:00:00"


class TestTrainingServiceExecution:
    """Test training execution workflow"""

    @pytest.fixture
    def training_service(self, mock_ml_trainer):
        # `mock_ml_trainer` is requested only for its fixture side effect
        # (presumably it patches the ML trainer — confirm in conftest).
        return TrainingService()

    @pytest.mark.asyncio
    async def test_execute_training_job_success(
        self, training_service, test_db_session, mock_messaging, mock_data_service
    ):
        """Test successful training job execution"""
        # Create job first
        job_id = "test-execution-job"
        await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id=job_id,
            config={"include_weather": True},
        )

        request = TrainingJobRequest(
            include_weather=True,
            include_traffic=True,
            min_data_points=30,
        )

        # The patched methods are async, so `patch` substitutes AsyncMocks and
        # each `return_value` below is what the corresponding await yields.
        with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch_sales, \
             patch('app.services.training_service.TrainingService._fetch_weather_data') as mock_fetch_weather, \
             patch('app.services.training_service.TrainingService._fetch_traffic_data') as mock_fetch_traffic, \
             patch('app.services.training_service.TrainingService._store_trained_models') as mock_store:

            mock_fetch_sales.return_value = [
                {"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}
            ]
            mock_fetch_weather.return_value = []
            mock_fetch_traffic.return_value = []
            mock_store.return_value = None

            await training_service.execute_training_job(
                db=test_db_session,
                job_id=job_id,
                tenant_id="test-tenant",
                request=request,
            )

        # Verify job was completed
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=job_id,
            tenant_id="test-tenant",
        )

        assert updated_job.status == "completed"
        assert updated_job.progress == 100

    @pytest.mark.asyncio
    async def test_execute_training_job_failure(
        self, training_service, test_db_session, mock_messaging
    ):
        """Test training job execution with failure"""
        # Create job first
        job_id = "test-failure-job"
        await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id=job_id,
            config={},
        )

        request = TrainingJobRequest(min_data_points=30)

        with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch:
            mock_fetch.side_effect = Exception("Data service unavailable")

            with pytest.raises(Exception):
                await training_service.execute_training_job(
                    db=test_db_session,
                    job_id=job_id,
                    tenant_id="test-tenant",
                    request=request,
                )

        # Verify job was marked as failed
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=job_id,
            tenant_id="test-tenant",
        )

        assert updated_job.status == "failed"
        assert "Data service unavailable" in updated_job.error_message

    @pytest.mark.asyncio
    async def test_execute_single_product_training_success(
        self, training_service, test_db_session, mock_messaging, mock_data_service
    ):
        """Test successful single product training execution"""
        job_id = "test-single-product-job"
        product_name = "Pan Integral"

        await training_service.create_single_product_job(
            db=test_db_session,
            tenant_id="test-tenant",
            product_name=product_name,
            job_id=job_id,
            config={},
        )

        request = SingleProductTrainingRequest(
            include_weather=True,
            include_traffic=False,
        )

        with patch('app.services.training_service.TrainingService._fetch_product_sales_data') as mock_fetch_sales, \
             patch('app.services.training_service.TrainingService._fetch_weather_data') as mock_fetch_weather, \
             patch('app.services.training_service.TrainingService._store_single_trained_model') as mock_store:

            mock_fetch_sales.return_value = [
                {"date": "2024-01-01", "product_name": product_name, "quantity": 45}
            ]
            mock_fetch_weather.return_value = []
            mock_store.return_value = None

            await training_service.execute_single_product_training(
                db=test_db_session,
                job_id=job_id,
                tenant_id="test-tenant",
                product_name=product_name,
                request=request,
            )

        # Verify job was completed
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=job_id,
            tenant_id="test-tenant",
        )

        assert updated_job.status == "completed"
        assert updated_job.progress == 100


class TestTrainingServiceEdgeCases:
    """Test edge cases and error conditions"""

    @pytest.fixture
    def training_service(self):
        return TrainingService()

    @pytest.mark.asyncio
    async def test_database_connection_failure(self, training_service):
        """Test handling of database connection failures"""
        # Use a session double whose operations fail deliberately. Setting a
        # side_effect on a patched AsyncSession *class* only fires when the
        # class itself is called, which the service (handed `db` directly)
        # never does, so the old test did not exercise this failure at all.
        db_error = Exception("Database connection failed")
        mock_session = Mock()
        mock_session.add.side_effect = db_error
        mock_session.commit = AsyncMock(side_effect=db_error)
        mock_session.flush = AsyncMock(side_effect=db_error)
        mock_session.refresh = AsyncMock(side_effect=db_error)
        mock_session.execute = AsyncMock(side_effect=db_error)
        # Rollback must succeed so an error handler can re-raise the original.
        mock_session.rollback = AsyncMock()

        with pytest.raises(Exception):
            await training_service.create_training_job(
                db=mock_session,
                tenant_id="test-tenant",
                job_id="test-job",
                config={},
            )

    @pytest.mark.asyncio
    async def test_external_service_timeout(self, training_service):
        """Test handling of external service timeouts"""
        mock_request = Mock()
        mock_request.start_date = None
        mock_request.end_date = None

        with patch('httpx.AsyncClient') as mock_client:
            mock_client.return_value.__aenter__.return_value.get.side_effect = (
                httpx.TimeoutException("Request timeout")
            )

            result = await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=mock_request,
            )

        # Should return empty list on timeout
        assert result == []

    @pytest.mark.asyncio
    async def test_concurrent_job_creation(self, training_service, test_db_session):
        """Test handling of concurrent job creation"""
        # This test would need more sophisticated setup for true concurrency
        # testing (a single AsyncSession must not be shared across concurrent
        # tasks). For now, just test that multiple jobs can be created.
        job_ids = ["concurrent-job-1", "concurrent-job-2", "concurrent-job-3"]

        jobs = []
        for job_id in job_ids:
            job = await training_service.create_training_job(
                db=test_db_session,
                tenant_id="test-tenant",
                job_id=job_id,
                config={},
            )
            jobs.append(job)

        assert len(jobs) == len(job_ids)
        for job, expected_id in zip(jobs, job_ids):
            assert job.job_id == expected_id

    @pytest.mark.asyncio
    async def test_malformed_config_handling(self, training_service, test_db_session):
        """Test handling of malformed configuration"""
        malformed_config = {
            "invalid_key": "invalid_value",
            "nested": {"data": None},
        }

        # Should not raise exception, just store the config as-is
        job = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id="malformed-config-job",
            config=malformed_config,
        )

        assert job.config == malformed_config