Add all the code for training service
This commit is contained in:
263
services/training/tests/README.md
Normal file
263
services/training/tests/README.md
Normal file
@@ -0,0 +1,263 @@
|
||||
# Training Service - Complete Testing Suite
|
||||
|
||||
## 📁 Test Structure
|
||||
|
||||
```
|
||||
services/training/tests/
|
||||
├── conftest.py # Test configuration and fixtures
|
||||
├── test_api.py # API endpoint tests
|
||||
├── test_ml.py # ML component tests
|
||||
├── test_service.py # Service layer tests
|
||||
├── test_messaging.py # Messaging tests
|
||||
└── test_integration.py # Integration tests
|
||||
```
|
||||
|
||||
## 🧪 Test Coverage
|
||||
|
||||
### **1. API Tests (`test_api.py`)**
|
||||
- ✅ Health check endpoints (`/health`, `/health/ready`, `/health/live`)
|
||||
- ✅ Metrics endpoint (`/metrics`)
|
||||
- ✅ Training job creation and management
|
||||
- ✅ Single product training
|
||||
- ✅ Job status tracking and cancellation
|
||||
- ✅ Data validation endpoints
|
||||
- ✅ Error handling and edge cases
|
||||
- ✅ Authentication integration
|
||||
|
||||
**Key Test Classes:**
|
||||
- `TestTrainingAPI` - Basic API functionality
|
||||
- `TestTrainingJobsAPI` - Training job management
|
||||
- `TestSingleProductTrainingAPI` - Single product workflows
|
||||
- `TestErrorHandling` - Error scenarios
|
||||
- `TestAuthenticationIntegration` - Security tests
|
||||
|
||||
### **2. ML Component Tests (`test_ml.py`)**
|
||||
- ✅ Data processor functionality
|
||||
- ✅ Prophet manager operations
|
||||
- ✅ ML trainer orchestration
|
||||
- ✅ Feature engineering validation
|
||||
- ✅ Model training and validation
|
||||
|
||||
**Key Test Classes:**
|
||||
- `TestBakeryDataProcessor` - Data preparation and feature engineering
|
||||
- `TestBakeryProphetManager` - Prophet model management
|
||||
- `TestBakeryMLTrainer` - ML training orchestration
|
||||
- `TestIntegrationML` - ML component integration
|
||||
|
||||
**Key Features Tested:**
|
||||
- Spanish holiday detection
|
||||
- Temporal feature engineering
|
||||
- Weather and traffic data integration
|
||||
- Model validation and metrics
|
||||
- Data quality checks
|
||||
|
||||
### **3. Service Layer Tests (`test_service.py`)**
|
||||
- ✅ Training service business logic
|
||||
- ✅ Database operations
|
||||
- ✅ External service integration
|
||||
- ✅ Job lifecycle management
|
||||
- ✅ Error recovery and resilience
|
||||
|
||||
**Key Test Classes:**
|
||||
- `TestTrainingService` - Core business logic
|
||||
- `TestTrainingServiceDataFetching` - External API integration
|
||||
- `TestTrainingServiceExecution` - Training workflow execution
|
||||
- `TestTrainingServiceEdgeCases` - Edge cases and error conditions
|
||||
|
||||
### **4. Messaging Tests (`test_messaging.py`)**
|
||||
- ✅ Event publishing functionality
|
||||
- ✅ Message structure validation
|
||||
- ✅ Error handling in messaging
|
||||
- ✅ Integration with shared components
|
||||
|
||||
**Key Test Classes:**
|
||||
- `TestTrainingMessaging` - Basic messaging operations
|
||||
- `TestMessagingErrorHandling` - Error scenarios
|
||||
- `TestMessagingIntegration` - Shared component integration
|
||||
- `TestMessagingPerformance` - Performance and reliability
|
||||
|
||||
### **5. Integration Tests (`test_integration.py`)**
|
||||
- ✅ End-to-end workflow testing
|
||||
- ✅ Service interaction validation
|
||||
- ✅ Error handling across boundaries
|
||||
- ✅ Performance and scalability
|
||||
- ✅ Security and compliance
|
||||
|
||||
**Key Test Classes:**
|
||||
- `TestTrainingWorkflowIntegration` - Complete workflows
|
||||
- `TestServiceInteractionIntegration` - Cross-service communication
|
||||
- `TestErrorHandlingIntegration` - Error propagation
|
||||
- `TestPerformanceIntegration` - Performance characteristics
|
||||
- `TestSecurityIntegration` - Security validation
|
||||
- `TestRecoveryIntegration` - Recovery scenarios
|
||||
- `TestComplianceIntegration` - GDPR and audit compliance
|
||||
|
||||
## 🔧 Test Configuration (`conftest.py`)
|
||||
|
||||
### **Fixtures Provided:**
|
||||
- `test_engine` - Test database engine
|
||||
- `test_db_session` - Database session for tests
|
||||
- `test_client` - HTTP test client
|
||||
- `mock_messaging` - Mocked messaging system
|
||||
- `mock_data_service` - Mocked external data services
|
||||
- `mock_ml_trainer` - Mocked ML trainer
|
||||
- `mock_prophet_manager` - Mocked Prophet manager
|
||||
- `mock_data_processor` - Mocked data processor
|
||||
- `training_job_in_db` - Sample training job in database
|
||||
- `trained_model_in_db` - Sample trained model in database
|
||||
|
||||
### **Helper Functions:**
|
||||
- `assert_training_job_structure()` - Validate job data structure
|
||||
- `assert_model_structure()` - Validate model data structure
|
||||
|
||||
## 🚀 Running Tests
|
||||
|
||||
### **Run All Tests:**
|
||||
```bash
|
||||
cd services/training
|
||||
pytest tests/ -v
|
||||
```
|
||||
|
||||
### **Run Specific Test Categories:**
|
||||
```bash
|
||||
# API tests only
|
||||
pytest tests/test_api.py -v
|
||||
|
||||
# ML component tests
|
||||
pytest tests/test_ml.py -v
|
||||
|
||||
# Service layer tests
|
||||
pytest tests/test_service.py -v
|
||||
|
||||
# Messaging tests
|
||||
pytest tests/test_messaging.py -v
|
||||
|
||||
# Integration tests
|
||||
pytest tests/test_integration.py -v
|
||||
```
|
||||
|
||||
### **Run with Coverage:**
|
||||
```bash
|
||||
pytest tests/ --cov=app --cov-report=html --cov-report=term
|
||||
```
|
||||
|
||||
### **Run Performance Tests:**
|
||||
```bash
|
||||
pytest tests/test_integration.py::TestPerformanceIntegration -v
|
||||
```
|
||||
|
||||
### **Skip Slow Tests:**
|
||||
```bash
|
||||
pytest tests/ -v -m "not slow"
|
||||
```
|
||||
|
||||
## 📊 Test Scenarios Covered
|
||||
|
||||
### **Happy Path Scenarios:**
|
||||
- ✅ Complete training workflow (start → progress → completion)
|
||||
- ✅ Single product training
|
||||
- ✅ Data validation and preprocessing
|
||||
- ✅ Model training and storage
|
||||
- ✅ Event publishing and messaging
|
||||
- ✅ Job status tracking and cancellation
|
||||
|
||||
### **Error Scenarios:**
|
||||
- ✅ Database connection failures
|
||||
- ✅ External service unavailability
|
||||
- ✅ Invalid input data
|
||||
- ✅ ML training failures
|
||||
- ✅ Messaging system failures
|
||||
- ✅ Authentication and authorization errors
|
||||
|
||||
### **Edge Cases:**
|
||||
- ✅ Concurrent job execution
|
||||
- ✅ Large datasets
|
||||
- ✅ Malformed configurations
|
||||
- ✅ Network timeouts
|
||||
- ✅ Memory pressure scenarios
|
||||
- ✅ Rapid successive requests
|
||||
|
||||
### **Security Tests:**
|
||||
- ✅ Tenant isolation
|
||||
- ✅ Input validation
|
||||
- ✅ SQL injection protection
|
||||
- ✅ Authentication enforcement
|
||||
- ✅ Data access controls
|
||||
|
||||
### **Compliance Tests:**
|
||||
- ✅ Audit trail creation
|
||||
- ✅ Data retention policies
|
||||
- ✅ GDPR compliance features
|
||||
- ✅ Backward compatibility
|
||||
|
||||
## 🎯 Test Quality Metrics
|
||||
|
||||
### **Coverage Goals:**
|
||||
- **API Layer:** 95%+ coverage
|
||||
- **Service Layer:** 90%+ coverage
|
||||
- **ML Components:** 85%+ coverage
|
||||
- **Integration:** 80%+ coverage
|
||||
|
||||
### **Test Types Distribution:**
|
||||
- **Unit Tests:** ~60% (isolated component testing)
|
||||
- **Integration Tests:** ~30% (service interaction testing)
|
||||
- **End-to-End Tests:** ~10% (complete workflow testing)
|
||||
|
||||
### **Performance Benchmarks:**
|
||||
- All unit tests complete in <5 seconds
|
||||
- Integration tests complete in <30 seconds
|
||||
- End-to-end tests complete in <60 seconds
|
||||
|
||||
## 🔧 Mocking Strategy
|
||||
|
||||
### **External Dependencies Mocked:**
|
||||
- ✅ **Data Service:** HTTP calls mocked with realistic responses
|
||||
- ✅ **RabbitMQ:** Message publishing mocked for isolation
|
||||
- ✅ **Database:** SQLite in-memory for fast testing
|
||||
- ✅ **Prophet Models:** Training mocked for speed
|
||||
- ✅ **File System:** Model storage mocked
|
||||
|
||||
### **Real Components Tested:**
|
||||
- ✅ **FastAPI Application:** Real app instance
|
||||
- ✅ **Pydantic Validation:** Real validation logic
|
||||
- ✅ **SQLAlchemy ORM:** Real database operations
|
||||
- ✅ **Business Logic:** Real service layer code
|
||||
|
||||
## 🛡️ Continuous Integration
|
||||
|
||||
### **CI Pipeline Tests:**
|
||||
```yaml
|
||||
# Example CI configuration
|
||||
test_matrix:
|
||||
- python: "3.11"
|
||||
database: "postgresql"
|
||||
- python: "3.11"
|
||||
database: "sqlite"
|
||||
|
||||
test_commands:
|
||||
- pytest tests/ --cov=app --cov-fail-under=85
|
||||
- pytest tests/test_integration.py -m "not slow"
|
||||
- pytest tests/ --maxfail=1 --tb=short
|
||||
```
|
||||
|
||||
### **Quality Gates:**
|
||||
- ✅ All tests must pass
|
||||
- ✅ Coverage must be >85%
|
||||
- ✅ No critical security issues
|
||||
- ✅ Performance benchmarks met
|
||||
|
||||
## 📈 Test Maintenance
|
||||
|
||||
### **Regular Updates:**
|
||||
- ✅ Add tests for new features
|
||||
- ✅ Update mocks when APIs change
|
||||
- ✅ Review and update test data
|
||||
- ✅ Maintain realistic test scenarios
|
||||
|
||||
### **Monitoring:**
|
||||
- ✅ Test execution time tracking
|
||||
- ✅ Flaky test identification
|
||||
- ✅ Coverage trend monitoring
|
||||
- ✅ Test failure analysis
|
||||
|
||||
This comprehensive test suite ensures the training service is robust, reliable, and ready for production deployment! 🎉
|
||||
362
services/training/tests/conftest.py
Normal file
362
services/training/tests/conftest.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# services/training/tests/conftest.py
|
||||
"""
|
||||
Pytest configuration and fixtures for training service tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Generator
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from httpx import AsyncClient
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Add app to Python path
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from app.main import app
|
||||
from app.core.database import Base, get_db
|
||||
from app.core.config import settings
|
||||
from app.models.training import ModelTrainingLog, TrainedModel
|
||||
from app.ml.trainer import BakeryMLTrainer
|
||||
from app.ml.prophet_manager import BakeryProphetManager
|
||||
from app.ml.data_processor import BakeryDataProcessor
|
||||
|
||||
# Test database URL
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///./test_training.db"
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
|
||||
"""Create an instance of the default event loop for the test session."""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def test_engine():
|
||||
"""Create test database engine"""
|
||||
engine = create_async_engine(TEST_DATABASE_URL, echo=False)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
# Cleanup
|
||||
await engine.dispose()
|
||||
|
||||
# Remove test database file
|
||||
try:
|
||||
os.remove("./test_training.db")
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
@pytest.fixture
|
||||
async def test_db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create test database session"""
|
||||
async_session = sessionmaker(
|
||||
test_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
@pytest.fixture
|
||||
def override_get_db(test_db_session):
|
||||
"""Override the get_db dependency"""
|
||||
async def _override_get_db():
|
||||
yield test_db_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
@pytest.fixture
|
||||
async def test_client(override_get_db) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create test HTTP client"""
|
||||
async with AsyncClient(app=app, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
@pytest.fixture
|
||||
def sync_test_client() -> Generator[TestClient, None, None]:
|
||||
"""Create synchronous test client for simple tests"""
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_messaging():
|
||||
"""Mock messaging for tests"""
|
||||
with patch('app.services.messaging.setup_messaging') as mock_setup, \
|
||||
patch('app.services.messaging.cleanup_messaging') as mock_cleanup, \
|
||||
patch('app.services.messaging.publish_job_started') as mock_start, \
|
||||
patch('app.services.messaging.publish_job_completed') as mock_complete, \
|
||||
patch('app.services.messaging.publish_job_failed') as mock_failed:
|
||||
|
||||
mock_setup.return_value = AsyncMock()
|
||||
mock_cleanup.return_value = AsyncMock()
|
||||
mock_start.return_value = AsyncMock(return_value=True)
|
||||
mock_complete.return_value = AsyncMock(return_value=True)
|
||||
mock_failed.return_value = AsyncMock(return_value=True)
|
||||
|
||||
yield {
|
||||
'setup': mock_setup,
|
||||
'cleanup': mock_cleanup,
|
||||
'start': mock_start,
|
||||
'complete': mock_complete,
|
||||
'failed': mock_failed
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_data_service():
|
||||
"""Mock external data service responses"""
|
||||
mock_sales_data = [
|
||||
{
|
||||
"date": "2024-01-01",
|
||||
"product_name": "Pan Integral",
|
||||
"quantity": 45
|
||||
},
|
||||
{
|
||||
"date": "2024-01-02",
|
||||
"product_name": "Pan Integral",
|
||||
"quantity": 52
|
||||
}
|
||||
]
|
||||
|
||||
mock_weather_data = [
|
||||
{
|
||||
"date": "2024-01-01",
|
||||
"temperature": 15.2,
|
||||
"precipitation": 0.0,
|
||||
"humidity": 65
|
||||
},
|
||||
{
|
||||
"date": "2024-01-02",
|
||||
"temperature": 18.1,
|
||||
"precipitation": 2.5,
|
||||
"humidity": 72
|
||||
}
|
||||
]
|
||||
|
||||
mock_traffic_data = [
|
||||
{
|
||||
"date": "2024-01-01",
|
||||
"traffic_volume": 120
|
||||
},
|
||||
{
|
||||
"date": "2024-01-02",
|
||||
"traffic_volume": 95
|
||||
}
|
||||
]
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"sales": mock_sales_data,
|
||||
"weather": mock_weather_data,
|
||||
"traffic": mock_traffic_data
|
||||
}
|
||||
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
yield {
|
||||
'sales': mock_sales_data,
|
||||
'weather': mock_weather_data,
|
||||
'traffic': mock_traffic_data
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ml_trainer():
|
||||
"""Mock ML trainer for testing"""
|
||||
with patch('app.ml.trainer.BakeryMLTrainer') as mock_trainer_class:
|
||||
mock_trainer = Mock(spec=BakeryMLTrainer)
|
||||
|
||||
# Mock training results
|
||||
mock_training_results = {
|
||||
"job_id": "test-job-123",
|
||||
"tenant_id": "test-tenant",
|
||||
"status": "completed",
|
||||
"products_trained": 1,
|
||||
"products_failed": 0,
|
||||
"total_products": 1,
|
||||
"training_results": {
|
||||
"Pan Integral": {
|
||||
"status": "success",
|
||||
"model_info": {
|
||||
"model_id": "test-model-123",
|
||||
"model_path": "/test/models/test-model-123.pkl",
|
||||
"type": "prophet",
|
||||
"training_samples": 100,
|
||||
"features": ["temperature", "humidity"],
|
||||
"training_metrics": {
|
||||
"mae": 5.2,
|
||||
"rmse": 7.8,
|
||||
"mape": 12.5,
|
||||
"r2_score": 0.85
|
||||
},
|
||||
"data_period": {
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-01-31"
|
||||
}
|
||||
},
|
||||
"data_points": 100
|
||||
}
|
||||
},
|
||||
"summary": {
|
||||
"success_rate": 100.0,
|
||||
"total_products": 1,
|
||||
"successful_products": 1,
|
||||
"failed_products": 0
|
||||
}
|
||||
}
|
||||
|
||||
mock_trainer.train_tenant_models.return_value = AsyncMock(return_value=mock_training_results)
|
||||
mock_trainer.train_single_product.return_value = AsyncMock(return_value={
|
||||
"status": "success",
|
||||
"model_info": mock_training_results["training_results"]["Pan Integral"]["model_info"]
|
||||
})
|
||||
|
||||
mock_trainer_class.return_value = mock_trainer
|
||||
yield mock_trainer
|
||||
|
||||
@pytest.fixture
|
||||
def sample_training_job() -> dict:
|
||||
"""Sample training job data"""
|
||||
return {
|
||||
"job_id": "test-job-123",
|
||||
"tenant_id": "test-tenant",
|
||||
"status": "pending",
|
||||
"progress": 0,
|
||||
"current_step": "Initializing",
|
||||
"config": {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_trained_model() -> dict:
|
||||
"""Sample trained model data"""
|
||||
return {
|
||||
"model_id": "test-model-123",
|
||||
"tenant_id": "test-tenant",
|
||||
"product_name": "Pan Integral",
|
||||
"model_type": "prophet",
|
||||
"model_path": "/test/models/test-model-123.pkl",
|
||||
"version": 1,
|
||||
"training_samples": 100,
|
||||
"features": ["temperature", "humidity", "traffic_volume"],
|
||||
"hyperparameters": {
|
||||
"seasonality_mode": "additive",
|
||||
"daily_seasonality": True,
|
||||
"weekly_seasonality": True
|
||||
},
|
||||
"training_metrics": {
|
||||
"mae": 5.2,
|
||||
"rmse": 7.8,
|
||||
"mape": 12.5,
|
||||
"r2_score": 0.85
|
||||
},
|
||||
"is_active": True
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
async def training_job_in_db(test_db_session, sample_training_job):
|
||||
"""Create a training job in the test database"""
|
||||
training_log = ModelTrainingLog(**sample_training_job)
|
||||
test_db_session.add(training_log)
|
||||
await test_db_session.commit()
|
||||
await test_db_session.refresh(training_log)
|
||||
return training_log
|
||||
|
||||
@pytest.fixture
|
||||
async def trained_model_in_db(test_db_session, sample_trained_model):
|
||||
"""Create a trained model in the test database"""
|
||||
from datetime import datetime
|
||||
|
||||
model_data = sample_trained_model.copy()
|
||||
model_data.update({
|
||||
"data_period_start": datetime(2024, 1, 1),
|
||||
"data_period_end": datetime(2024, 1, 31),
|
||||
"created_at": datetime.now()
|
||||
})
|
||||
|
||||
trained_model = TrainedModel(**model_data)
|
||||
test_db_session.add(trained_model)
|
||||
await test_db_session.commit()
|
||||
await test_db_session.refresh(trained_model)
|
||||
return trained_model
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prophet_manager():
|
||||
"""Mock Prophet manager for testing"""
|
||||
with patch('app.ml.prophet_manager.BakeryProphetManager') as mock_manager_class:
|
||||
mock_manager = Mock(spec=BakeryProphetManager)
|
||||
|
||||
mock_model_info = {
|
||||
"model_id": "test-model-123",
|
||||
"model_path": "/test/models/test-model-123.pkl",
|
||||
"type": "prophet",
|
||||
"training_samples": 100,
|
||||
"features": ["temperature", "humidity"],
|
||||
"training_metrics": {
|
||||
"mae": 5.2,
|
||||
"rmse": 7.8,
|
||||
"mape": 12.5,
|
||||
"r2_score": 0.85
|
||||
}
|
||||
}
|
||||
|
||||
mock_manager.train_bakery_model.return_value = AsyncMock(return_value=mock_model_info)
|
||||
mock_manager.generate_forecast.return_value = AsyncMock()
|
||||
|
||||
mock_manager_class.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_data_processor():
|
||||
"""Mock data processor for testing"""
|
||||
import pandas as pd
|
||||
|
||||
with patch('app.ml.data_processor.BakeryDataProcessor') as mock_processor_class:
|
||||
mock_processor = Mock(spec=BakeryDataProcessor)
|
||||
|
||||
# Mock processed data
|
||||
mock_processed_data = pd.DataFrame({
|
||||
'ds': pd.date_range('2024-01-01', periods=30, freq='D'),
|
||||
'y': [45 + i for i in range(30)],
|
||||
'temperature': [15.0 + (i % 10) for i in range(30)],
|
||||
'humidity': [60.0 + (i % 20) for i in range(30)]
|
||||
})
|
||||
|
||||
mock_processor.prepare_training_data.return_value = AsyncMock(return_value=mock_processed_data)
|
||||
mock_processor.prepare_prediction_features.return_value = AsyncMock(return_value=mock_processed_data)
|
||||
|
||||
mock_processor_class.return_value = mock_processor
|
||||
yield mock_processor
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth():
|
||||
"""Mock authentication for tests"""
|
||||
with patch('shared.auth.decorators.require_auth') as mock_auth:
|
||||
mock_auth.return_value = lambda func: func # Pass through without auth
|
||||
yield mock_auth
|
||||
|
||||
# Helper functions for tests
|
||||
def assert_training_job_structure(job_data: dict):
|
||||
"""Assert that training job data has correct structure"""
|
||||
required_fields = ["job_id", "status", "tenant_id", "created_at"]
|
||||
for field in required_fields:
|
||||
assert field in job_data, f"Missing required field: {field}"
|
||||
|
||||
def assert_model_structure(model_data: dict):
|
||||
"""Assert that model data has correct structure"""
|
||||
required_fields = ["model_id", "model_type", "training_samples", "features"]
|
||||
for field in required_fields:
|
||||
assert field in model_data, f"Missing required field: {field}"
|
||||
686
services/training/tests/test_api.py
Normal file
686
services/training/tests/test_api.py
Normal file
@@ -0,0 +1,686 @@
|
||||
# services/training/tests/test_api.py
|
||||
"""
|
||||
Tests for training service API endpoints
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from fastapi import status
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.schemas.training import TrainingJobRequest
|
||||
|
||||
|
||||
class TestTrainingAPI:
|
||||
"""Test training API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, test_client: AsyncClient):
|
||||
"""Test health check endpoint"""
|
||||
response = await test_client.get("/health")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["service"] == "training-service"
|
||||
assert data["version"] == "1.0.0"
|
||||
assert "status" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_check_ready(self, test_client: AsyncClient):
|
||||
"""Test readiness check when service is ready"""
|
||||
# Mock app state as ready
|
||||
with patch('app.main.app.state.ready', True):
|
||||
response = await test_client.get("/health/ready")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["status"] == "ready"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_check_not_ready(self, test_client: AsyncClient):
|
||||
"""Test readiness check when service is not ready"""
|
||||
with patch('app.main.app.state.ready', False):
|
||||
response = await test_client.get("/health/ready")
|
||||
|
||||
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
data = response.json()
|
||||
assert data["status"] == "not_ready"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_liveness_check_healthy(self, test_client: AsyncClient):
|
||||
"""Test liveness check when service is healthy"""
|
||||
with patch('app.core.database.get_db_health', return_value=AsyncMock(return_value=True)):
|
||||
response = await test_client.get("/health/live")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["status"] == "alive"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_liveness_check_unhealthy(self, test_client: AsyncClient):
|
||||
"""Test liveness check when database is unhealthy"""
|
||||
with patch('app.core.database.get_db_health', return_value=AsyncMock(return_value=False)):
|
||||
response = await test_client.get("/health/live")
|
||||
|
||||
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
data = response.json()
|
||||
assert data["status"] == "unhealthy"
|
||||
assert data["reason"] == "database_unavailable"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_endpoint(self, test_client: AsyncClient):
|
||||
"""Test metrics endpoint"""
|
||||
response = await test_client.get("/metrics")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
expected_metrics = [
|
||||
"training_jobs_active",
|
||||
"training_jobs_completed",
|
||||
"training_jobs_failed",
|
||||
"models_trained_total",
|
||||
"uptime_seconds"
|
||||
]
|
||||
|
||||
for metric in expected_metrics:
|
||||
assert metric in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_root_endpoint(self, test_client: AsyncClient):
|
||||
"""Test root endpoint"""
|
||||
response = await test_client.get("/")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["service"] == "training-service"
|
||||
assert data["version"] == "1.0.0"
|
||||
assert "description" in data
|
||||
|
||||
|
||||
class TestTrainingJobsAPI:
|
||||
"""Test training jobs API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_training_job_success(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_ml_trainer,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test starting a training job successfully"""
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "job_id" in data
|
||||
assert data["status"] == "started"
|
||||
assert data["tenant_id"] == "test-tenant"
|
||||
assert "estimated_duration_minutes" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_training_job_validation_error(self, test_client: AsyncClient):
|
||||
"""Test starting training job with validation error"""
|
||||
request_data = {
|
||||
"seasonality_mode": "invalid_mode", # Invalid value
|
||||
"min_data_points": 5 # Too low
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_status_existing_job(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test getting status of existing training job"""
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["job_id"] == job_id
|
||||
assert data["status"] == "pending"
|
||||
assert "progress" in data
|
||||
assert "started_at" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_status_nonexistent_job(self, test_client: AsyncClient):
|
||||
"""Test getting status of non-existent training job"""
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get("/training/jobs/nonexistent-job/status")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_training_jobs(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test listing training jobs"""
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get("/training/jobs")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
|
||||
# Check first job structure
|
||||
job = data[0]
|
||||
assert "job_id" in job
|
||||
assert "status" in job
|
||||
assert "started_at" in job
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_training_jobs_with_status_filter(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test listing training jobs with status filter"""
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get("/training/jobs?status=pending")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert isinstance(data, list)
|
||||
# All jobs should have status "pending"
|
||||
for job in data:
|
||||
assert job["status"] == "pending"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_training_job_success(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db,
|
||||
mock_messaging
|
||||
):
|
||||
"""Test cancelling a training job successfully"""
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(f"/training/jobs/{job_id}/cancel")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "message" in data
|
||||
assert "cancelled" in data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_nonexistent_job(self, test_client: AsyncClient):
|
||||
"""Test cancelling a non-existent training job"""
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs/nonexistent-job/cancel")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_logs(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test getting training logs"""
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/logs")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "job_id" in data
|
||||
assert "logs" in data
|
||||
assert isinstance(data["logs"], list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_training_data_valid(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test validating valid training data"""
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/validate", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "is_valid" in data
|
||||
assert "issues" in data
|
||||
assert "recommendations" in data
|
||||
assert "estimated_training_time" in data
|
||||
|
||||
|
||||
class TestSingleProductTrainingAPI:
|
||||
"""Test single product training API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_single_product_success(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_ml_trainer,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test training a single product successfully"""
|
||||
product_name = "Pan Integral"
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "job_id" in data
|
||||
assert data["status"] == "started"
|
||||
assert data["tenant_id"] == "test-tenant"
|
||||
assert f"training started for {product_name}" in data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_single_product_validation_error(self, test_client: AsyncClient):
|
||||
"""Test single product training with validation error"""
|
||||
product_name = "Pan Integral"
|
||||
request_data = {
|
||||
"seasonality_mode": "invalid_mode" # Invalid value
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_single_product_special_characters(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_ml_trainer,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test training product with special characters in name"""
|
||||
product_name = "Pan Francés" # With accent
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "job_id" in data
|
||||
|
||||
|
||||
class TestModelsAPI:
|
||||
"""Test models API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_models(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
trained_model_in_db
|
||||
):
|
||||
"""Test listing trained models"""
|
||||
with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get("/models")
|
||||
|
||||
# This endpoint might not exist yet, so we expect either 200 or 404
|
||||
assert response.status_code in [status.HTTP_200_OK, status.HTTP_404_NOT_FOUND]
|
||||
|
||||
if response.status_code == status.HTTP_200_OK:
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_model_details(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
trained_model_in_db
|
||||
):
|
||||
"""Test getting model details"""
|
||||
model_id = trained_model_in_db.model_id
|
||||
|
||||
with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/models/{model_id}")
|
||||
|
||||
# This endpoint might not exist yet
|
||||
assert response.status_code in [
|
||||
status.HTTP_200_OK,
|
||||
status.HTTP_404_NOT_FOUND,
|
||||
status.HTTP_501_NOT_IMPLEMENTED
|
||||
]
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test error handling in API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_error_handling(self, test_client: AsyncClient):
|
||||
"""Test handling of database errors"""
|
||||
with patch('app.services.training_service.TrainingService.create_training_job') as mock_create:
|
||||
mock_create.side_effect = Exception("Database connection failed")
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_tenant_id(self, test_client: AsyncClient):
|
||||
"""Test handling when tenant ID is missing"""
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Don't mock get_current_tenant_id to simulate missing auth
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
# Should fail due to missing authentication
|
||||
assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_job_id_format(self, test_client: AsyncClient):
|
||||
"""Test handling of invalid job ID format"""
|
||||
invalid_job_id = "invalid-job-id-with-special-chars@#$"
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{invalid_job_id}/status")
|
||||
|
||||
# Should handle gracefully
|
||||
assert response.status_code in [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messaging_failure_handling(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test handling when messaging fails"""
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")), \
|
||||
patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
# Should still succeed even if messaging fails
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "job_id" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_payload(self, test_client: AsyncClient):
|
||||
"""Test handling of invalid JSON payload"""
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
"/training/jobs",
|
||||
content="invalid json {{{",
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_content_type(self, test_client: AsyncClient):
|
||||
"""Test handling of unsupported content type"""
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
"/training/jobs",
|
||||
content="some text data",
|
||||
headers={"Content-Type": "text/plain"}
|
||||
)
|
||||
|
||||
assert response.status_code in [
|
||||
status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
]
|
||||
|
||||
|
||||
class TestAuthenticationIntegration:
|
||||
"""Test authentication integration"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoints_require_auth(self, test_client: AsyncClient):
|
||||
"""Test that endpoints require authentication in production"""
|
||||
# This test would be more meaningful in a production environment
|
||||
# where authentication is actually enforced
|
||||
|
||||
endpoints_to_test = [
|
||||
("POST", "/training/jobs"),
|
||||
("GET", "/training/jobs"),
|
||||
("POST", "/training/products/Pan Integral"),
|
||||
("POST", "/training/validate")
|
||||
]
|
||||
|
||||
for method, endpoint in endpoints_to_test:
|
||||
if method == "POST":
|
||||
response = await test_client.post(endpoint, json={})
|
||||
else:
|
||||
response = await test_client.get(endpoint)
|
||||
|
||||
# In test environment with mocked auth, should work
|
||||
# In production, would require valid authentication
|
||||
assert response.status_code != status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tenant_isolation_in_api(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test tenant isolation at API level"""
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
# Try to access job with different tenant
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
# Should not find job for different tenant
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAPIValidation:
|
||||
"""Test API validation and input handling"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_training_request_validation(self, test_client: AsyncClient):
|
||||
"""Test comprehensive training request validation"""
|
||||
|
||||
# Test valid request
|
||||
valid_request = {
|
||||
"include_weather": True,
|
||||
"include_traffic": False,
|
||||
"min_data_points": 30,
|
||||
"seasonality_mode": "additive",
|
||||
"daily_seasonality": True,
|
||||
"weekly_seasonality": True,
|
||||
"yearly_seasonality": True
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=valid_request)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Test invalid seasonality mode
|
||||
invalid_request = valid_request.copy()
|
||||
invalid_request["seasonality_mode"] = "invalid_mode"
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=invalid_request)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
# Test invalid min_data_points
|
||||
invalid_request = valid_request.copy()
|
||||
invalid_request["min_data_points"] = 5 # Too low
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=invalid_request)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_product_request_validation(self, test_client: AsyncClient):
|
||||
"""Test single product training request validation"""
|
||||
|
||||
product_name = "Pan Integral"
|
||||
|
||||
# Test valid request
|
||||
valid_request = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"seasonality_mode": "multiplicative"
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=valid_request
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Test empty product name
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
"/training/products/",
|
||||
json=valid_request
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_parameter_validation(self, test_client: AsyncClient):
|
||||
"""Test query parameter validation"""
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
# Test valid limit parameter
|
||||
response = await test_client.get("/training/jobs?limit=5")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Test invalid limit parameter
|
||||
response = await test_client.get("/training/jobs?limit=invalid")
|
||||
assert response.status_code in [
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status.HTTP_400_BAD_REQUEST
|
||||
]
|
||||
|
||||
# Test negative limit
|
||||
response = await test_client.get("/training/jobs?limit=-1")
|
||||
assert response.status_code in [
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status.HTTP_400_BAD_REQUEST
|
||||
]
|
||||
|
||||
|
||||
class TestAPIPerformance:
|
||||
"""Test API performance characteristics"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests(self, test_client: AsyncClient):
|
||||
"""Test handling of concurrent requests"""
|
||||
import asyncio
|
||||
|
||||
# Create multiple concurrent requests
|
||||
tasks = []
|
||||
for i in range(10):
|
||||
with patch('app.api.training.get_current_tenant_id', return_value=f"tenant-{i}"):
|
||||
task = test_client.get("/health")
|
||||
tasks.append(task)
|
||||
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
# All requests should succeed
|
||||
for response in responses:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_payload_handling(self, test_client: AsyncClient):
|
||||
"""Test handling of large request payloads"""
|
||||
|
||||
# Create large request payload
|
||||
large_request = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30,
|
||||
"large_config": {f"key_{i}": f"value_{i}" for i in range(1000)}
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=large_request)
|
||||
|
||||
# Should handle large payload gracefully
|
||||
assert response.status_code in [
|
||||
status.HTTP_200_OK,
|
||||
status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_successive_requests(self, test_client: AsyncClient):
|
||||
"""Test rapid successive requests to same endpoint"""
|
||||
|
||||
# Make rapid requests
|
||||
responses = []
|
||||
for _ in range(20):
|
||||
response = await test_client.get("/health")
|
||||
responses.append(response)
|
||||
|
||||
# All should succeed
|
||||
for response in responses:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
848
services/training/tests/test_integration.py
Normal file
848
services/training/tests/test_integration.py
Normal file
@@ -0,0 +1,848 @@
|
||||
# services/training/tests/test_integration.py
|
||||
"""
|
||||
Integration tests for training service
|
||||
Tests complete workflows and service interactions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from httpx import AsyncClient
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.main import app
|
||||
from app.schemas.training import TrainingJobRequest
|
||||
|
||||
|
||||
class TestTrainingWorkflowIntegration:
|
||||
"""Test complete training workflows end-to-end"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_training_workflow(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
test_db_session,
|
||||
mock_messaging,
|
||||
mock_data_service,
|
||||
mock_ml_trainer
|
||||
):
|
||||
"""Test complete training workflow from API to completion"""
|
||||
|
||||
# Step 1: Start training job
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
job_data = response.json()
|
||||
job_id = job_data["job_id"]
|
||||
|
||||
# Step 2: Check initial status
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
assert response.status_code == 200
|
||||
status_data = response.json()
|
||||
assert status_data["status"] in ["pending", "started"]
|
||||
|
||||
# Step 3: Simulate background task completion
|
||||
# In real scenario, this would be handled by background tasks
|
||||
await asyncio.sleep(0.1) # Allow background task to start
|
||||
|
||||
# Step 4: Check completion status
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
# The job should exist in database even if not completed yet
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_product_training_workflow(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_data_service,
|
||||
mock_ml_trainer
|
||||
):
|
||||
"""Test single product training complete workflow"""
|
||||
|
||||
product_name = "Pan Integral"
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": False,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
# Start single product training
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
job_data = response.json()
|
||||
job_id = job_data["job_id"]
|
||||
assert f"training started for {product_name}" in job_data["message"].lower()
|
||||
|
||||
# Check job status
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
assert response.status_code == 200
|
||||
status_data = response.json()
|
||||
assert status_data["job_id"] == job_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_training_validation_workflow(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test training data validation workflow"""
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Validate training data
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/validate", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
validation_data = response.json()
|
||||
|
||||
assert "is_valid" in validation_data
|
||||
assert "issues" in validation_data
|
||||
assert "recommendations" in validation_data
|
||||
assert "estimated_training_time" in validation_data
|
||||
|
||||
# If validation passes, start actual training
|
||||
if validation_data["is_valid"]:
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_job_cancellation_workflow(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db,
|
||||
mock_messaging
|
||||
):
|
||||
"""Test job cancellation workflow"""
|
||||
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
# Check initial status
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
assert response.status_code == 200
|
||||
initial_status = response.json()
|
||||
assert initial_status["status"] == "pending"
|
||||
|
||||
# Cancel the job
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post(f"/training/jobs/{job_id}/cancel")
|
||||
|
||||
assert response.status_code == 200
|
||||
cancel_response = response.json()
|
||||
assert "cancelled" in cancel_response["message"].lower()
|
||||
|
||||
# Verify cancellation
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
assert response.status_code == 200
|
||||
final_status = response.json()
|
||||
assert final_status["status"] == "cancelled"
|
||||
|
||||
|
||||
class TestServiceInteractionIntegration:
|
||||
"""Test interactions between training service and external services"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_service_integration(self, training_service, mock_data_service):
|
||||
"""Test integration with data service"""
|
||||
from app.schemas.training import TrainingJobRequest
|
||||
|
||||
request = TrainingJobRequest(
|
||||
include_weather=True,
|
||||
include_traffic=True,
|
||||
min_data_points=30
|
||||
)
|
||||
|
||||
# Test sales data fetching
|
||||
sales_data = await training_service._fetch_sales_data("test-tenant", request)
|
||||
assert isinstance(sales_data, list)
|
||||
|
||||
# Test weather data fetching
|
||||
weather_data = await training_service._fetch_weather_data("test-tenant", request)
|
||||
assert isinstance(weather_data, list)
|
||||
|
||||
# Test traffic data fetching
|
||||
traffic_data = await training_service._fetch_traffic_data("test-tenant", request)
|
||||
assert isinstance(traffic_data, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messaging_integration(self, mock_messaging):
|
||||
"""Test integration with messaging system"""
|
||||
from app.services.messaging import (
|
||||
publish_job_started,
|
||||
publish_job_completed,
|
||||
publish_model_trained
|
||||
)
|
||||
|
||||
# Test various message types
|
||||
result1 = await publish_job_started("job-123", "tenant-123", {})
|
||||
result2 = await publish_job_completed("job-123", "tenant-123", {"status": "success"})
|
||||
result3 = await publish_model_trained("model-123", "tenant-123", "Pan Integral", {"mae": 5.0})
|
||||
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
assert result3 is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_integration(self, test_db_session, training_service):
|
||||
"""Test database operations integration"""
|
||||
|
||||
# Create a training job
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="integration-test-job",
|
||||
config={"test": True}
|
||||
)
|
||||
|
||||
assert job.job_id == "integration-test-job"
|
||||
|
||||
# Update job status
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="running",
|
||||
progress=50,
|
||||
current_step="Processing data"
|
||||
)
|
||||
|
||||
# Retrieve updated job
|
||||
updated_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert updated_job.status == "running"
|
||||
assert updated_job.progress == 50
|
||||
|
||||
|
||||
class TestErrorHandlingIntegration:
|
||||
"""Test error handling across service boundaries"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_service_failure_handling(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging
|
||||
):
|
||||
"""Test handling when data service is unavailable"""
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Mock data service failure
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get.side_effect = Exception("Service unavailable")
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
# Should still create job but might fail during execution
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messaging_failure_handling(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test handling when messaging fails"""
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Mock messaging failure
|
||||
with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")):
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
# Should still succeed even if messaging fails
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ml_training_failure_handling(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test handling when ML training fails"""
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Mock ML training failure
|
||||
with patch('app.ml.trainer.BakeryMLTrainer.train_tenant_models', side_effect=Exception("ML training failed")):
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
# Job should be created successfully
|
||||
assert response.status_code == 200
|
||||
|
||||
# Background task would handle the failure
|
||||
|
||||
|
||||
class TestPerformanceIntegration:
|
||||
"""Test performance characteristics of integrated workflows"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_training_jobs(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_data_service,
|
||||
mock_ml_trainer
|
||||
):
|
||||
"""Test handling multiple concurrent training jobs"""
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Start multiple jobs concurrently
|
||||
tasks = []
|
||||
for i in range(5):
|
||||
with patch('app.api.training.get_current_tenant_id', return_value=f"tenant-{i}"):
|
||||
task = test_client.post("/training/jobs", json=request_data)
|
||||
tasks.append(task)
|
||||
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
# All jobs should be created successfully
|
||||
for response in responses:
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "job_id" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_dataset_handling(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test handling of large datasets"""
|
||||
|
||||
# Simulate large dataset
|
||||
large_config = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 1000, # Large minimum
|
||||
"products": [f"Product-{i}" for i in range(100)] # Many products
|
||||
}
|
||||
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="large-dataset-job",
|
||||
config=large_config
|
||||
)
|
||||
|
||||
assert job.config == large_config
|
||||
assert job.job_id == "large-dataset-job"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_status_checks(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test rapid successive status checks"""
|
||||
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
# Make many rapid status requests
|
||||
tasks = []
|
||||
for _ in range(20):
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
task = test_client.get(f"/training/jobs/{job_id}/status")
|
||||
tasks.append(task)
|
||||
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
# All requests should succeed
|
||||
for response in responses:
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestSecurityIntegration:
|
||||
"""Test security aspects of service integration"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tenant_isolation(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db,
|
||||
mock_messaging
|
||||
):
|
||||
"""Test that tenants cannot access each other's jobs"""
|
||||
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
# Try to access job with different tenant ID
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
# Should not find the job (belongs to different tenant)
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_validation_integration(
|
||||
self,
|
||||
test_client: AsyncClient
|
||||
):
|
||||
"""Test input validation across API boundaries"""
|
||||
|
||||
# Test invalid seasonality mode
|
||||
invalid_request = {
|
||||
"seasonality_mode": "invalid_mode",
|
||||
"min_data_points": -5 # Invalid negative value
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=invalid_request)
|
||||
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sql_injection_protection(
|
||||
self,
|
||||
test_client: AsyncClient
|
||||
):
|
||||
"""Test protection against SQL injection attempts"""
|
||||
|
||||
# Try SQL injection in job ID
|
||||
malicious_job_id = "job'; DROP TABLE model_training_logs; --"
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{malicious_job_id}/status")
|
||||
|
||||
# Should return 404, not cause database error
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestRecoveryIntegration:
|
||||
"""Test recovery and resilience scenarios"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_restart_recovery(
|
||||
self,
|
||||
test_db_session,
|
||||
training_service,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test service recovery after restart"""
|
||||
|
||||
# Simulate service restart by creating new service instance
|
||||
new_training_service = training_service.__class__()
|
||||
|
||||
# Should be able to access existing jobs
|
||||
existing_job = await new_training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=training_job_in_db.job_id,
|
||||
tenant_id=training_job_in_db.tenant_id
|
||||
)
|
||||
|
||||
assert existing_job is not None
|
||||
assert existing_job.job_id == training_job_in_db.job_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_failure_recovery(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test recovery from partial failures"""
|
||||
|
||||
# Create job that might fail partway through
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="partial-failure-job",
|
||||
config={"simulate_failure": True}
|
||||
)
|
||||
|
||||
# Simulate partial progress
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="running",
|
||||
progress=50,
|
||||
current_step="Halfway through training"
|
||||
)
|
||||
|
||||
# Simulate failure
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="failed",
|
||||
progress=50,
|
||||
current_step="Training failed",
|
||||
error_message="Simulated failure"
|
||||
)
|
||||
|
||||
# Verify failure was recorded
|
||||
failed_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert failed_job.status == "failed"
|
||||
assert failed_job.error_message == "Simulated failure"
|
||||
assert failed_job.progress == 50
|
||||
|
||||
|
||||
class TestComplianceIntegration:
|
||||
"""Test compliance and audit requirements"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_trail_creation(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test that audit trail is properly created"""
|
||||
|
||||
# Create and update job
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="audit-test-job",
|
||||
config={"audit_test": True}
|
||||
)
|
||||
|
||||
# Multiple status updates
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="running",
|
||||
progress=25,
|
||||
current_step="Started processing"
|
||||
)
|
||||
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="running",
|
||||
progress=75,
|
||||
current_step="Almost complete"
|
||||
)
|
||||
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="completed",
|
||||
progress=100,
|
||||
current_step="Completed successfully"
|
||||
)
|
||||
|
||||
# Verify audit trail
|
||||
logs = await training_service.get_training_logs(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert logs is not None
|
||||
assert len(logs) > 0
|
||||
|
||||
# Check final status
|
||||
final_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert final_job.status == "completed"
|
||||
assert final_job.progress == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_retention_compliance(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test data retention and cleanup compliance"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Create old job (simulate old data)
|
||||
old_job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="old-job",
|
||||
config={"created_long_ago": True}
|
||||
)
|
||||
|
||||
# Manually set old timestamp
|
||||
from sqlalchemy import update
|
||||
from app.models.training import ModelTrainingLog
|
||||
|
||||
old_timestamp = datetime.now() - timedelta(days=400)
|
||||
await test_db_session.execute(
|
||||
update(ModelTrainingLog)
|
||||
.where(ModelTrainingLog.job_id == old_job.job_id)
|
||||
.values(start_time=old_timestamp, created_at=old_timestamp)
|
||||
)
|
||||
await test_db_session.commit()
|
||||
|
||||
# Verify old job exists
|
||||
retrieved_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=old_job.job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert retrieved_job is not None
|
||||
# In a real implementation, there would be cleanup procedures
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gdpr_compliance_features(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test GDPR compliance features"""
|
||||
|
||||
# Create job with tenant data
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="gdpr-test-tenant",
|
||||
job_id="gdpr-test-job",
|
||||
config={"gdpr_test": True}
|
||||
)
|
||||
|
||||
# Verify job is associated with tenant
|
||||
assert job.tenant_id == "gdpr-test-tenant"
|
||||
|
||||
# Test data access (right to access)
|
||||
tenant_jobs = await training_service.list_training_jobs(
|
||||
db=test_db_session,
|
||||
tenant_id="gdpr-test-tenant"
|
||||
)
|
||||
|
||||
assert len(tenant_jobs) >= 1
|
||||
assert any(job.job_id == "gdpr-test-job" for job in tenant_jobs)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestLongRunningIntegration:
|
||||
"""Test long-running integration scenarios (marked as slow)"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extended_training_simulation(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session,
|
||||
mock_messaging
|
||||
):
|
||||
"""Test extended training process simulation"""
|
||||
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="long-running-job",
|
||||
config={"extended_test": True}
|
||||
)
|
||||
|
||||
# Simulate progress over time
|
||||
progress_steps = [
|
||||
(10, "Initializing"),
|
||||
(25, "Loading data"),
|
||||
(50, "Training models"),
|
||||
(75, "Validating results"),
|
||||
(90, "Storing models"),
|
||||
(100, "Completed")
|
||||
]
|
||||
|
||||
for progress, step in progress_steps:
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="running" if progress < 100 else "completed",
|
||||
progress=progress,
|
||||
current_step=step
|
||||
)
|
||||
|
||||
# Small delay to simulate real progression
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Verify final state
|
||||
final_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert final_job.status == "completed"
|
||||
assert final_job.progress == 100
|
||||
assert final_job.current_step == "Completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_usage_stability(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test memory usage stability over many operations"""
|
||||
|
||||
# Create many jobs to test memory stability
|
||||
for i in range(50):
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id=f"tenant-{i % 5}", # 5 different tenants
|
||||
job_id=f"memory-test-job-{i}",
|
||||
config={"iteration": i}
|
||||
)
|
||||
|
||||
# Update status
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="completed",
|
||||
progress=100,
|
||||
current_step="Completed"
|
||||
)
|
||||
|
||||
# List jobs for each tenant
|
||||
for tenant_i in range(5):
|
||||
tenant_id = f"tenant-{tenant_i}"
|
||||
jobs = await training_service.list_training_jobs(
|
||||
db=test_db_session,
|
||||
tenant_id=tenant_id,
|
||||
limit=20
|
||||
)
|
||||
|
||||
# Should have 10 jobs per tenant (50 total / 5 tenants)
|
||||
assert len(jobs) == 10
|
||||
|
||||
|
||||
class TestBackwardCompatibility:
|
||||
"""Test backward compatibility with existing systems"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_config_handling(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test handling of legacy configuration formats"""
|
||||
|
||||
# Test with old-style configuration
|
||||
legacy_config = {
|
||||
"weather_enabled": True, # Old key
|
||||
"traffic_enabled": True, # Old key
|
||||
"minimum_samples": 30, # Old key
|
||||
"prophet_config": { # Old nested structure
|
||||
"seasonality": "additive"
|
||||
}
|
||||
}
|
||||
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="legacy-config-job",
|
||||
config=legacy_config
|
||||
)
|
||||
|
||||
assert job.config == legacy_config
|
||||
assert job.job_id == "legacy-config-job"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_version_compatibility(
|
||||
self,
|
||||
test_client: AsyncClient
|
||||
):
|
||||
"""Test API version compatibility"""
|
||||
|
||||
# Test with minimal request (old API style)
|
||||
minimal_request = {
|
||||
"include_weather": True
|
||||
}
|
||||
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=minimal_request)
|
||||
|
||||
# Should work with defaults for missing fields
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "job_id" in data
|
||||
|
||||
|
||||
# Utility functions for integration tests
|
||||
async def wait_for_condition(condition_func, timeout=5.0, interval=0.1):
|
||||
"""Wait for a condition to become true"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
if await condition_func():
|
||||
return True
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def assert_job_progression(job_updates):
|
||||
"""Assert that job updates show proper progression"""
|
||||
assert len(job_updates) > 0
|
||||
|
||||
# Check progress is non-decreasing
|
||||
for i in range(1, len(job_updates)):
|
||||
assert job_updates[i]["progress"] >= job_updates[i-1]["progress"]
|
||||
|
||||
# Check final status
|
||||
final_update = job_updates[-1]
|
||||
assert final_update["status"] in ["completed", "failed", "cancelled"]
|
||||
|
||||
|
||||
def assert_valid_job_structure(job_data):
|
||||
"""Assert job data has valid structure"""
|
||||
required_fields = ["job_id", "status", "tenant_id"]
|
||||
for field in required_fields:
|
||||
assert field in job_data
|
||||
|
||||
assert isinstance(job_data["progress"], int)
|
||||
assert 0 <= job_data["progress"] <= 100
|
||||
assert job_data["status"] in ["pending", "running", "completed", "failed", "cancelled"]
|
||||
467
services/training/tests/test_messaging.py
Normal file
467
services/training/tests/test_messaging.py
Normal file
@@ -0,0 +1,467 @@
|
||||
# services/training/tests/test_messaging.py
|
||||
"""
|
||||
Tests for training service messaging functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
import json
|
||||
|
||||
from app.services import messaging
|
||||
|
||||
|
||||
class TestTrainingMessaging:
|
||||
"""Test training service messaging functions"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_publisher(self):
|
||||
"""Mock the RabbitMQ publisher"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
mock_pub.connect = AsyncMock(return_value=True)
|
||||
mock_pub.disconnect = AsyncMock(return_value=None)
|
||||
yield mock_pub
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_messaging_success(self, mock_publisher):
|
||||
"""Test successful messaging setup"""
|
||||
await messaging.setup_messaging()
|
||||
|
||||
mock_publisher.connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_messaging_failure(self, mock_publisher):
|
||||
"""Test messaging setup failure"""
|
||||
mock_publisher.connect.return_value = False
|
||||
|
||||
await messaging.setup_messaging()
|
||||
|
||||
mock_publisher.connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_messaging(self, mock_publisher):
|
||||
"""Test messaging cleanup"""
|
||||
await messaging.cleanup_messaging()
|
||||
|
||||
mock_publisher.disconnect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_started(self, mock_publisher):
|
||||
"""Test publishing job started event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
config = {"include_weather": True}
|
||||
|
||||
result = await messaging.publish_job_started(job_id, tenant_id, config)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
# Check call arguments
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["exchange_name"] == "training.events"
|
||||
assert call_args[1]["routing_key"] == "training.started"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["service_name"] == "training-service"
|
||||
assert event_data["data"]["job_id"] == job_id
|
||||
assert event_data["data"]["tenant_id"] == tenant_id
|
||||
assert event_data["data"]["config"] == config
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_progress(self, mock_publisher):
|
||||
"""Test publishing job progress event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
progress = 50
|
||||
step = "Training models"
|
||||
|
||||
result = await messaging.publish_job_progress(job_id, tenant_id, progress, step)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.progress"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["progress"] == progress
|
||||
assert event_data["data"]["current_step"] == step
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_completed(self, mock_publisher):
|
||||
"""Test publishing job completed event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
results = {
|
||||
"products_trained": 3,
|
||||
"summary": {"success_rate": 100.0}
|
||||
}
|
||||
|
||||
result = await messaging.publish_job_completed(job_id, tenant_id, results)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.completed"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["results"] == results
|
||||
assert event_data["data"]["models_trained"] == 3
|
||||
assert event_data["data"]["success_rate"] == 100.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_failed(self, mock_publisher):
|
||||
"""Test publishing job failed event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
error = "Data service unavailable"
|
||||
|
||||
result = await messaging.publish_job_failed(job_id, tenant_id, error)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.failed"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["error"] == error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_cancelled(self, mock_publisher):
|
||||
"""Test publishing job cancelled event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
result = await messaging.publish_job_cancelled(job_id, tenant_id)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.cancelled"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_product_training_started(self, mock_publisher):
|
||||
"""Test publishing product training started event"""
|
||||
job_id = "test-product-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
|
||||
result = await messaging.publish_product_training_started(job_id, tenant_id, product_name)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.product.started"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["product_name"] == product_name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_product_training_completed(self, mock_publisher):
|
||||
"""Test publishing product training completed event"""
|
||||
job_id = "test-product-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
model_id = "test-model-123"
|
||||
|
||||
result = await messaging.publish_product_training_completed(
|
||||
job_id, tenant_id, product_name, model_id
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.product.completed"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["model_id"] == model_id
|
||||
assert event_data["data"]["product_name"] == product_name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_model_trained(self, mock_publisher):
|
||||
"""Test publishing model trained event"""
|
||||
model_id = "test-model-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
metrics = {"mae": 5.2, "rmse": 7.8, "mape": 12.5}
|
||||
|
||||
result = await messaging.publish_model_trained(model_id, tenant_id, product_name, metrics)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.model.trained"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["training_metrics"] == metrics
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_model_updated(self, mock_publisher):
|
||||
"""Test publishing model updated event"""
|
||||
model_id = "test-model-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
version = 2
|
||||
|
||||
result = await messaging.publish_model_updated(model_id, tenant_id, product_name, version)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.model.updated"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["version"] == version
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_model_validated(self, mock_publisher):
|
||||
"""Test publishing model validated event"""
|
||||
model_id = "test-model-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
validation_results = {"is_valid": True, "accuracy": 0.95}
|
||||
|
||||
result = await messaging.publish_model_validated(
|
||||
model_id, tenant_id, product_name, validation_results
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.model.validated"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["validation_results"] == validation_results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_model_saved(self, mock_publisher):
|
||||
"""Test publishing model saved event"""
|
||||
model_id = "test-model-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
model_path = "/models/test-model-123.pkl"
|
||||
|
||||
result = await messaging.publish_model_saved(model_id, tenant_id, product_name, model_path)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.model.saved"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["model_path"] == model_path
|
||||
|
||||
|
||||
class TestMessagingErrorHandling:
|
||||
"""Test error handling in messaging"""
|
||||
|
||||
@pytest.fixture
|
||||
def failing_publisher(self):
|
||||
"""Mock publisher that fails"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=False)
|
||||
mock_pub.connect = AsyncMock(return_value=False)
|
||||
yield mock_pub
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_event_failure(self, failing_publisher):
|
||||
"""Test handling of publish event failure"""
|
||||
result = await messaging.publish_job_started("job-123", "tenant-123", {})
|
||||
|
||||
assert result is False
|
||||
failing_publisher.publish_event.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_messaging_connection_failure(self, failing_publisher):
|
||||
"""Test setup with connection failure"""
|
||||
await messaging.setup_messaging()
|
||||
|
||||
failing_publisher.connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_with_exception(self):
|
||||
"""Test publishing with exception"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event.side_effect = Exception("Connection lost")
|
||||
|
||||
result = await messaging.publish_job_started("job-123", "tenant-123", {})
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestMessagingIntegration:
|
||||
"""Test messaging integration with shared components"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_structure_consistency(self):
|
||||
"""Test that events follow consistent structure"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Test different event types
|
||||
await messaging.publish_job_started("job-123", "tenant-123", {})
|
||||
await messaging.publish_job_completed("job-123", "tenant-123", {})
|
||||
await messaging.publish_model_trained("model-123", "tenant-123", "Pan", {})
|
||||
|
||||
# Verify all calls have consistent structure
|
||||
assert mock_pub.publish_event.call_count == 3
|
||||
|
||||
for call in mock_pub.publish_event.call_args_list:
|
||||
event_data = call[1]["event_data"]
|
||||
|
||||
# All events should have these fields
|
||||
assert "service_name" in event_data
|
||||
assert "event_type" in event_data
|
||||
assert "data" in event_data
|
||||
assert event_data["service_name"] == "training-service"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shared_event_classes_usage(self):
|
||||
"""Test that shared event classes are used properly"""
|
||||
with patch('shared.messaging.events.TrainingStartedEvent') as mock_event_class:
|
||||
mock_event = Mock()
|
||||
mock_event.to_dict.return_value = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.started",
|
||||
"data": {"job_id": "test-job"}
|
||||
}
|
||||
mock_event_class.return_value = mock_event
|
||||
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
await messaging.publish_job_started("test-job", "test-tenant", {})
|
||||
|
||||
# Verify shared event class was used
|
||||
mock_event_class.assert_called_once()
|
||||
mock_event.to_dict.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routing_key_consistency(self):
|
||||
"""Test that routing keys follow consistent patterns"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Test various event types
|
||||
events_and_keys = [
|
||||
(messaging.publish_job_started, "training.started"),
|
||||
(messaging.publish_job_progress, "training.progress"),
|
||||
(messaging.publish_job_completed, "training.completed"),
|
||||
(messaging.publish_job_failed, "training.failed"),
|
||||
(messaging.publish_job_cancelled, "training.cancelled"),
|
||||
(messaging.publish_product_training_started, "training.product.started"),
|
||||
(messaging.publish_product_training_completed, "training.product.completed"),
|
||||
(messaging.publish_model_trained, "training.model.trained"),
|
||||
(messaging.publish_model_updated, "training.model.updated"),
|
||||
(messaging.publish_model_validated, "training.model.validated"),
|
||||
(messaging.publish_model_saved, "training.model.saved")
|
||||
]
|
||||
|
||||
for event_func, expected_key in events_and_keys:
|
||||
mock_pub.reset_mock()
|
||||
|
||||
# Call event function with appropriate parameters
|
||||
if "progress" in expected_key:
|
||||
await event_func("job-123", "tenant-123", 50, "step")
|
||||
elif "model" in expected_key and "trained" in expected_key:
|
||||
await event_func("model-123", "tenant-123", "product", {})
|
||||
elif "model" in expected_key and "updated" in expected_key:
|
||||
await event_func("model-123", "tenant-123", "product", 1)
|
||||
elif "model" in expected_key and "validated" in expected_key:
|
||||
await event_func("model-123", "tenant-123", "product", {})
|
||||
elif "model" in expected_key and "saved" in expected_key:
|
||||
await event_func("model-123", "tenant-123", "product", "/path")
|
||||
elif "product" in expected_key and "completed" in expected_key:
|
||||
await event_func("job-123", "tenant-123", "product", "model-123")
|
||||
elif "product" in expected_key:
|
||||
await event_func("job-123", "tenant-123", "product")
|
||||
elif "failed" in expected_key:
|
||||
await event_func("job-123", "tenant-123", "error")
|
||||
elif "cancelled" in expected_key:
|
||||
await event_func("job-123", "tenant-123")
|
||||
else:
|
||||
await event_func("job-123", "tenant-123", {})
|
||||
|
||||
# Verify routing key
|
||||
call_args = mock_pub.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == expected_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_consistency(self):
|
||||
"""Test that all events use the same exchange"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Test multiple events
|
||||
await messaging.publish_job_started("job-123", "tenant-123", {})
|
||||
await messaging.publish_model_trained("model-123", "tenant-123", "product", {})
|
||||
await messaging.publish_product_training_started("job-123", "tenant-123", "product")
|
||||
|
||||
# Verify all use same exchange
|
||||
for call in mock_pub.publish_event.call_args_list:
|
||||
assert call[1]["exchange_name"] == "training.events"
|
||||
|
||||
|
||||
class TestMessagingPerformance:
|
||||
"""Test messaging performance and reliability"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_publishing(self):
|
||||
"""Test concurrent event publishing"""
|
||||
import asyncio
|
||||
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Create multiple concurrent publishing tasks
|
||||
tasks = []
|
||||
for i in range(10):
|
||||
task = messaging.publish_job_progress(f"job-{i}", "tenant-123", i * 10, f"step-{i}")
|
||||
tasks.append(task)
|
||||
|
||||
# Execute all tasks concurrently
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Verify all succeeded
|
||||
assert all(results)
|
||||
assert mock_pub.publish_event.call_count == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_event_data(self):
|
||||
"""Test publishing events with large data payloads"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Create large config data
|
||||
large_config = {
|
||||
"products": [f"Product-{i}" for i in range(1000)],
|
||||
"features": [f"feature-{i}" for i in range(100)],
|
||||
"hyperparameters": {f"param-{i}": i for i in range(50)}
|
||||
}
|
||||
|
||||
result = await messaging.publish_job_started("job-123", "tenant-123", large_config)
|
||||
|
||||
assert result is True
|
||||
mock_pub.publish_event.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_sequential_publishing(self):
|
||||
"""Test rapid sequential event publishing"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Publish many events in sequence
|
||||
for i in range(100):
|
||||
await messaging.publish_job_progress("job-123", "tenant-123", i, f"step-{i}")
|
||||
|
||||
assert mock_pub.publish_event.call_count == 100
|
||||
513
services/training/tests/test_ml.py
Normal file
513
services/training/tests/test_ml.py
Normal file
@@ -0,0 +1,513 @@
|
||||
# services/training/tests/test_ml.py
|
||||
"""
|
||||
Tests for ML components: trainer, prophet_manager, and data_processor
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from app.ml.trainer import BakeryMLTrainer
|
||||
from app.ml.prophet_manager import BakeryProphetManager
|
||||
from app.ml.data_processor import BakeryDataProcessor
|
||||
|
||||
|
||||
class TestBakeryDataProcessor:
|
||||
"""Test the data processor component"""
|
||||
|
||||
@pytest.fixture
|
||||
def data_processor(self):
|
||||
return BakeryDataProcessor()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_sales_data(self):
|
||||
"""Create sample sales data"""
|
||||
dates = pd.date_range('2024-01-01', periods=60, freq='D')
|
||||
return pd.DataFrame({
|
||||
'date': dates,
|
||||
'product_name': ['Pan Integral'] * 60,
|
||||
'quantity': [45 + np.random.randint(-10, 11) for _ in range(60)]
|
||||
})
|
||||
|
||||
@pytest.fixture
|
||||
def sample_weather_data(self):
|
||||
"""Create sample weather data"""
|
||||
dates = pd.date_range('2024-01-01', periods=60, freq='D')
|
||||
return pd.DataFrame({
|
||||
'date': dates,
|
||||
'temperature': [15 + 5 * np.sin(2 * np.pi * i / 365) + np.random.normal(0, 2) for i in range(60)],
|
||||
'precipitation': [max(0, np.random.exponential(1)) for _ in range(60)],
|
||||
'humidity': [60 + np.random.normal(0, 10) for _ in range(60)]
|
||||
})
|
||||
|
||||
@pytest.fixture
|
||||
def sample_traffic_data(self):
|
||||
"""Create sample traffic data"""
|
||||
dates = pd.date_range('2024-01-01', periods=60, freq='D')
|
||||
return pd.DataFrame({
|
||||
'date': dates,
|
||||
'traffic_volume': [100 + np.random.normal(0, 20) for _ in range(60)]
|
||||
})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_training_data_basic(
|
||||
self,
|
||||
data_processor,
|
||||
sample_sales_data,
|
||||
sample_weather_data,
|
||||
sample_traffic_data
|
||||
):
|
||||
"""Test basic data preparation"""
|
||||
result = await data_processor.prepare_training_data(
|
||||
sales_data=sample_sales_data,
|
||||
weather_data=sample_weather_data,
|
||||
traffic_data=sample_traffic_data,
|
||||
product_name="Pan Integral"
|
||||
)
|
||||
|
||||
# Check result structure
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert 'ds' in result.columns
|
||||
assert 'y' in result.columns
|
||||
assert len(result) > 0
|
||||
|
||||
# Check Prophet format
|
||||
assert result['ds'].dtype == 'datetime64[ns]'
|
||||
assert pd.api.types.is_numeric_dtype(result['y'])
|
||||
|
||||
# Check temporal features
|
||||
temporal_features = ['day_of_week', 'is_weekend', 'month', 'is_holiday']
|
||||
for feature in temporal_features:
|
||||
assert feature in result.columns
|
||||
|
||||
# Check weather features
|
||||
weather_features = ['temperature', 'precipitation', 'humidity']
|
||||
for feature in weather_features:
|
||||
assert feature in result.columns
|
||||
|
||||
# Check traffic features
|
||||
assert 'traffic_volume' in result.columns
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_training_data_empty_weather(
|
||||
self,
|
||||
data_processor,
|
||||
sample_sales_data
|
||||
):
|
||||
"""Test data preparation with empty weather data"""
|
||||
result = await data_processor.prepare_training_data(
|
||||
sales_data=sample_sales_data,
|
||||
weather_data=pd.DataFrame(),
|
||||
traffic_data=pd.DataFrame(),
|
||||
product_name="Pan Integral"
|
||||
)
|
||||
|
||||
# Should still work with default values
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert 'ds' in result.columns
|
||||
assert 'y' in result.columns
|
||||
|
||||
# Should have default weather values
|
||||
assert 'temperature' in result.columns
|
||||
assert result['temperature'].iloc[0] == 15.0 # Default value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_prediction_features(self, data_processor):
|
||||
"""Test preparation of prediction features"""
|
||||
future_dates = pd.date_range('2024-02-01', periods=7, freq='D')
|
||||
|
||||
weather_forecast = pd.DataFrame({
|
||||
'ds': future_dates,
|
||||
'temperature': [18.0] * 7,
|
||||
'precipitation': [0.0] * 7,
|
||||
'humidity': [65.0] * 7
|
||||
})
|
||||
|
||||
result = await data_processor.prepare_prediction_features(
|
||||
future_dates=future_dates,
|
||||
weather_forecast=weather_forecast,
|
||||
traffic_forecast=pd.DataFrame()
|
||||
)
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) == 7
|
||||
assert 'ds' in result.columns
|
||||
|
||||
# Check temporal features are added
|
||||
assert 'day_of_week' in result.columns
|
||||
assert 'is_weekend' in result.columns
|
||||
|
||||
# Check weather features
|
||||
assert 'temperature' in result.columns
|
||||
assert all(result['temperature'] == 18.0)
|
||||
|
||||
def test_add_temporal_features(self, data_processor):
|
||||
"""Test temporal feature engineering"""
|
||||
dates = pd.date_range('2024-01-01', periods=10, freq='D')
|
||||
df = pd.DataFrame({'date': dates})
|
||||
|
||||
result = data_processor._add_temporal_features(df)
|
||||
|
||||
# Check temporal features
|
||||
assert 'day_of_week' in result.columns
|
||||
assert 'is_weekend' in result.columns
|
||||
assert 'month' in result.columns
|
||||
assert 'season' in result.columns
|
||||
assert 'week_of_year' in result.columns
|
||||
assert 'quarter' in result.columns
|
||||
assert 'is_holiday' in result.columns
|
||||
assert 'is_school_holiday' in result.columns
|
||||
|
||||
# Check weekend detection
|
||||
# 2024-01-01 was a Monday (day_of_week = 0)
|
||||
assert result.iloc[0]['day_of_week'] == 0
|
||||
assert result.iloc[0]['is_weekend'] == 0
|
||||
|
||||
# 2024-01-06 was a Saturday (day_of_week = 5)
|
||||
assert result.iloc[5]['day_of_week'] == 5
|
||||
assert result.iloc[5]['is_weekend'] == 1
|
||||
|
||||
def test_spanish_holiday_detection(self, data_processor):
|
||||
"""Test Spanish holiday detection"""
|
||||
# Test known Spanish holidays
|
||||
new_year = datetime(2024, 1, 1)
|
||||
epiphany = datetime(2024, 1, 6)
|
||||
labour_day = datetime(2024, 5, 1)
|
||||
christmas = datetime(2024, 12, 25)
|
||||
|
||||
assert data_processor._is_spanish_holiday(new_year) == True
|
||||
assert data_processor._is_spanish_holiday(epiphany) == True
|
||||
assert data_processor._is_spanish_holiday(labour_day) == True
|
||||
assert data_processor._is_spanish_holiday(christmas) == True
|
||||
|
||||
# Test non-holiday
|
||||
regular_day = datetime(2024, 3, 15)
|
||||
assert data_processor._is_spanish_holiday(regular_day) == False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_training_data_insufficient_data(self, data_processor):
|
||||
"""Test handling of insufficient training data"""
|
||||
# Create very small dataset
|
||||
small_sales_data = pd.DataFrame({
|
||||
'date': pd.date_range('2024-01-01', periods=5, freq='D'),
|
||||
'product_name': ['Pan Integral'] * 5,
|
||||
'quantity': [45, 50, 48, 52, 49]
|
||||
})
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await data_processor.prepare_training_data(
|
||||
sales_data=small_sales_data,
|
||||
weather_data=pd.DataFrame(),
|
||||
traffic_data=pd.DataFrame(),
|
||||
product_name="Pan Integral"
|
||||
)
|
||||
|
||||
|
||||
class TestBakeryProphetManager:
|
||||
"""Test the Prophet manager component"""
|
||||
|
||||
@pytest.fixture
|
||||
def prophet_manager(self):
|
||||
with patch('app.ml.prophet_manager.settings.MODEL_STORAGE_PATH', '/tmp/test_models'):
|
||||
os.makedirs('/tmp/test_models', exist_ok=True)
|
||||
return BakeryProphetManager()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_prophet_data(self):
|
||||
"""Create sample data in Prophet format"""
|
||||
dates = pd.date_range('2024-01-01', periods=100, freq='D')
|
||||
return pd.DataFrame({
|
||||
'ds': dates,
|
||||
'y': [45 + 10 * np.sin(2 * np.pi * i / 7) + np.random.normal(0, 5) for i in range(100)],
|
||||
'temperature': [15 + 5 * np.sin(2 * np.pi * i / 365) for i in range(100)],
|
||||
'humidity': [60 + np.random.normal(0, 10) for _ in range(100)]
|
||||
})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_bakery_model_success(self, prophet_manager, sample_prophet_data):
|
||||
"""Test successful model training"""
|
||||
with patch('prophet.Prophet') as mock_prophet_class:
|
||||
mock_model = Mock()
|
||||
mock_model.fit.return_value = None
|
||||
mock_prophet_class.return_value = mock_model
|
||||
|
||||
with patch('joblib.dump') as mock_dump:
|
||||
result = await prophet_manager.train_bakery_model(
|
||||
tenant_id="test-tenant",
|
||||
product_name="Pan Integral",
|
||||
df=sample_prophet_data,
|
||||
job_id="test-job-123"
|
||||
)
|
||||
|
||||
# Check result structure
|
||||
assert isinstance(result, dict)
|
||||
assert 'model_id' in result
|
||||
assert 'model_path' in result
|
||||
assert 'type' in result
|
||||
assert result['type'] == 'prophet'
|
||||
assert 'training_samples' in result
|
||||
assert 'features' in result
|
||||
assert 'training_metrics' in result
|
||||
|
||||
# Check that model was fitted
|
||||
mock_model.fit.assert_called_once()
|
||||
mock_dump.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_training_data_valid(self, prophet_manager, sample_prophet_data):
|
||||
"""Test validation with valid data"""
|
||||
# Should not raise exception
|
||||
await prophet_manager._validate_training_data(sample_prophet_data, "Pan Integral")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_training_data_insufficient(self, prophet_manager):
|
||||
"""Test validation with insufficient data"""
|
||||
small_data = pd.DataFrame({
|
||||
'ds': pd.date_range('2024-01-01', periods=5, freq='D'),
|
||||
'y': [45, 50, 48, 52, 49]
|
||||
})
|
||||
|
||||
with pytest.raises(ValueError, match="Insufficient training data"):
|
||||
await prophet_manager._validate_training_data(small_data, "Pan Integral")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_training_data_missing_columns(self, prophet_manager):
|
||||
"""Test validation with missing required columns"""
|
||||
invalid_data = pd.DataFrame({
|
||||
'date': pd.date_range('2024-01-01', periods=50, freq='D'),
|
||||
'quantity': [45] * 50
|
||||
})
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required columns"):
|
||||
await prophet_manager._validate_training_data(invalid_data, "Pan Integral")
|
||||
|
||||
def test_get_spanish_holidays(self, prophet_manager):
|
||||
"""Test Spanish holidays creation"""
|
||||
holidays = prophet_manager._get_spanish_holidays()
|
||||
|
||||
if not holidays.empty:
|
||||
assert 'holiday' in holidays.columns
|
||||
assert 'ds' in holidays.columns
|
||||
|
||||
# Check some known holidays exist
|
||||
holiday_names = holidays['holiday'].unique()
|
||||
expected_holidays = ['new_year', 'christmas', 'may_day']
|
||||
|
||||
for holiday in expected_holidays:
|
||||
assert holiday in holiday_names
|
||||
|
||||
def test_extract_regressor_columns(self, prophet_manager, sample_prophet_data):
|
||||
"""Test regressor column extraction"""
|
||||
regressors = prophet_manager._extract_regressor_columns(sample_prophet_data)
|
||||
|
||||
assert isinstance(regressors, list)
|
||||
assert 'temperature' in regressors
|
||||
assert 'humidity' in regressors
|
||||
assert 'ds' not in regressors # Should be excluded
|
||||
assert 'y' not in regressors # Should be excluded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_forecast(self, prophet_manager):
|
||||
"""Test forecast generation"""
|
||||
# Create a temporary model file
|
||||
with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as temp_file:
|
||||
model_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Mock a saved model
|
||||
with patch('joblib.load') as mock_load:
|
||||
mock_model = Mock()
|
||||
mock_forecast = pd.DataFrame({
|
||||
'ds': pd.date_range('2024-02-01', periods=7, freq='D'),
|
||||
'yhat': [50.0] * 7,
|
||||
'yhat_lower': [45.0] * 7,
|
||||
'yhat_upper': [55.0] * 7
|
||||
})
|
||||
mock_model.predict.return_value = mock_forecast
|
||||
mock_load.return_value = mock_model
|
||||
|
||||
future_data = pd.DataFrame({
|
||||
'ds': pd.date_range('2024-02-01', periods=7, freq='D'),
|
||||
'temperature': [18.0] * 7,
|
||||
'humidity': [65.0] * 7
|
||||
})
|
||||
|
||||
result = await prophet_manager.generate_forecast(
|
||||
model_path=model_path,
|
||||
future_dates=future_data,
|
||||
regressor_columns=['temperature', 'humidity']
|
||||
)
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) == 7
|
||||
mock_model.predict.assert_called_once()
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
try:
|
||||
os.unlink(model_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
class TestBakeryMLTrainer:
|
||||
"""Test the ML trainer component"""
|
||||
|
||||
@pytest.fixture
|
||||
def ml_trainer(self, mock_prophet_manager, mock_data_processor):
|
||||
return BakeryMLTrainer()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_sales_data(self):
|
||||
"""Sample sales data for training"""
|
||||
return [
|
||||
{"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45},
|
||||
{"date": "2024-01-02", "product_name": "Pan Integral", "quantity": 50},
|
||||
{"date": "2024-01-03", "product_name": "Pan Integral", "quantity": 48},
|
||||
{"date": "2024-01-04", "product_name": "Croissant", "quantity": 25},
|
||||
{"date": "2024-01-05", "product_name": "Croissant", "quantity": 30}
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_tenant_models_success(
|
||||
self,
|
||||
ml_trainer,
|
||||
sample_sales_data,
|
||||
mock_prophet_manager,
|
||||
mock_data_processor
|
||||
):
|
||||
"""Test successful training of tenant models"""
|
||||
result = await ml_trainer.train_tenant_models(
|
||||
tenant_id="test-tenant",
|
||||
sales_data=sample_sales_data,
|
||||
weather_data=[],
|
||||
traffic_data=[],
|
||||
job_id="test-job-123"
|
||||
)
|
||||
|
||||
# Check result structure
|
||||
assert isinstance(result, dict)
|
||||
assert 'job_id' in result
|
||||
assert 'tenant_id' in result
|
||||
assert 'status' in result
|
||||
assert 'training_results' in result
|
||||
assert 'summary' in result
|
||||
|
||||
assert result['status'] == 'completed'
|
||||
assert result['tenant_id'] == 'test-tenant'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_single_product_success(
|
||||
self,
|
||||
ml_trainer,
|
||||
sample_sales_data,
|
||||
mock_prophet_manager,
|
||||
mock_data_processor
|
||||
):
|
||||
"""Test successful single product training"""
|
||||
product_sales = [item for item in sample_sales_data if item['product_name'] == 'Pan Integral']
|
||||
|
||||
result = await ml_trainer.train_single_product(
|
||||
tenant_id="test-tenant",
|
||||
product_name="Pan Integral",
|
||||
sales_data=product_sales,
|
||||
weather_data=[],
|
||||
traffic_data=[],
|
||||
job_id="test-job-123"
|
||||
)
|
||||
|
||||
# Check result structure
|
||||
assert isinstance(result, dict)
|
||||
assert 'job_id' in result
|
||||
assert 'tenant_id' in result
|
||||
assert 'product_name' in result
|
||||
assert 'status' in result
|
||||
assert 'model_info' in result
|
||||
|
||||
assert result['status'] == 'success'
|
||||
assert result['product_name'] == 'Pan Integral'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_single_product_no_data(self, ml_trainer):
|
||||
"""Test single product training with no data"""
|
||||
with pytest.raises(ValueError, match="No sales data found"):
|
||||
await ml_trainer.train_single_product(
|
||||
tenant_id="test-tenant",
|
||||
product_name="Nonexistent Product",
|
||||
sales_data=[],
|
||||
weather_data=[],
|
||||
traffic_data=[],
|
||||
job_id="test-job-123"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_input_data_valid(self, ml_trainer, sample_sales_data):
|
||||
"""Test input data validation with valid data"""
|
||||
df = pd.DataFrame(sample_sales_data)
|
||||
|
||||
# Should not raise exception
|
||||
await ml_trainer._validate_input_data(df, "test-tenant")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_input_data_empty(self, ml_trainer):
|
||||
"""Test input data validation with empty data"""
|
||||
empty_df = pd.DataFrame()
|
||||
|
||||
with pytest.raises(ValueError, match="No sales data provided"):
|
||||
await ml_trainer._validate_input_data(empty_df, "test-tenant")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_input_data_missing_columns(self, ml_trainer):
|
||||
"""Test input data validation with missing columns"""
|
||||
invalid_df = pd.DataFrame([
|
||||
{"invalid_column": "value1"},
|
||||
{"invalid_column": "value2"}
|
||||
])
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required columns"):
|
||||
await ml_trainer._validate_input_data(invalid_df, "test-tenant")
|
||||
|
||||
def test_calculate_training_summary(self, ml_trainer):
|
||||
"""Test training summary calculation"""
|
||||
training_results = {
|
||||
"Pan Integral": {
|
||||
"status": "success",
|
||||
"model_info": {"training_metrics": {"mae": 5.0, "rmse": 7.0}}
|
||||
},
|
||||
"Croissant": {
|
||||
"status": "error",
|
||||
"error_message": "Insufficient data"
|
||||
},
|
||||
"Baguette": {
|
||||
"status": "skipped",
|
||||
"reason": "insufficient_data"
|
||||
}
|
||||
}
|
||||
|
||||
summary = ml_trainer._calculate_training_summary(training_results)
|
||||
|
||||
assert summary['total_products'] == 3
|
||||
assert summary['successful_products'] == 1
|
||||
assert summary['failed_products'] == 1
|
||||
assert summary['skipped_products'] == 1
|
||||
assert summary['success_rate'] == 33.33 # 1/3 * 100
|
||||
|
||||
|
||||
class TestIntegrationML:
|
||||
"""Integration tests for ML components working together"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_training_flow(self):
|
||||
"""Test complete training flow from data to model"""
|
||||
# This test would require actual Prophet and data processing
|
||||
# Skip for now due to dependencies
|
||||
pytest.skip("Requires actual Prophet dependencies for integration test")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_pipeline_integration(self):
|
||||
"""Test data processor -> prophet manager integration"""
|
||||
pytest.skip("Requires actual dependencies for integration test")
|
||||
688
services/training/tests/test_service.py
Normal file
688
services/training/tests/test_service.py
Normal file
@@ -0,0 +1,688 @@
|
||||
# services/training/tests/test_service.py
|
||||
"""
|
||||
Tests for training service business logic layer
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from datetime import datetime, timedelta
|
||||
import httpx
|
||||
|
||||
from app.services.training_service import TrainingService
|
||||
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
|
||||
from app.models.training import ModelTrainingLog, TrainedModel
|
||||
|
||||
|
||||
class TestTrainingService:
|
||||
"""Test the training service business logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def training_service(self, mock_ml_trainer):
|
||||
return TrainingService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_training_job_success(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test successful training job creation"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
config = {"include_weather": True, "include_traffic": True}
|
||||
|
||||
result = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id,
|
||||
config=config
|
||||
)
|
||||
|
||||
assert isinstance(result, ModelTrainingLog)
|
||||
assert result.job_id == job_id
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.status == "pending"
|
||||
assert result.progress == 0
|
||||
assert result.config == config
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_single_product_job_success(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test successful single product job creation"""
|
||||
job_id = "test-product-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
config = {"include_weather": True}
|
||||
|
||||
result = await training_service.create_single_product_job(
|
||||
db=test_db_session,
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
job_id=job_id,
|
||||
config=config
|
||||
)
|
||||
|
||||
assert isinstance(result, ModelTrainingLog)
|
||||
assert result.job_id == job_id
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.config["single_product"] == product_name
|
||||
assert f"Initializing training for {product_name}" in result.current_step
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_job_status_existing(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test getting status of existing job"""
|
||||
result = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=training_job_in_db.job_id,
|
||||
tenant_id=training_job_in_db.tenant_id
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.job_id == training_job_in_db.job_id
|
||||
assert result.status == training_job_in_db.status
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_job_status_nonexistent(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test getting status of non-existent job"""
|
||||
result = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id="nonexistent-job",
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_training_jobs(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test listing training jobs"""
|
||||
result = await training_service.list_training_jobs(
|
||||
db=test_db_session,
|
||||
tenant_id=training_job_in_db.tenant_id,
|
||||
limit=10
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) >= 1
|
||||
assert result[0].job_id == training_job_in_db.job_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_training_jobs_with_filter(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test listing training jobs with status filter"""
|
||||
result = await training_service.list_training_jobs(
|
||||
db=test_db_session,
|
||||
tenant_id=training_job_in_db.tenant_id,
|
||||
limit=10,
|
||||
status_filter="pending"
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
for job in result:
|
||||
assert job.status == "pending"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_training_job_success(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test successful job cancellation"""
|
||||
result = await training_service.cancel_training_job(
|
||||
db=test_db_session,
|
||||
job_id=training_job_in_db.job_id,
|
||||
tenant_id=training_job_in_db.tenant_id
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify status was updated
|
||||
updated_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=training_job_in_db.job_id,
|
||||
tenant_id=training_job_in_db.tenant_id
|
||||
)
|
||||
assert updated_job.status == "cancelled"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_nonexistent_job(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test cancelling non-existent job"""
|
||||
result = await training_service.cancel_training_job(
|
||||
db=test_db_session,
|
||||
job_id="nonexistent-job",
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_training_data_valid(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test validation with valid data"""
|
||||
config = {"min_data_points": 30}
|
||||
|
||||
result = await training_service.validate_training_data(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
config=config
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "is_valid" in result
|
||||
assert "issues" in result
|
||||
assert "recommendations" in result
|
||||
assert "estimated_time_minutes" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_training_data_no_data(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test validation with no data"""
|
||||
config = {"min_data_points": 30}
|
||||
|
||||
with patch('app.services.training_service.TrainingService._fetch_sales_data', return_value=AsyncMock(return_value=[])):
|
||||
result = await training_service.validate_training_data(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
config=config
|
||||
)
|
||||
|
||||
assert result["is_valid"] is False
|
||||
assert "No sales data found" in result["issues"][0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job_status(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test updating job status"""
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=training_job_in_db.job_id,
|
||||
status="running",
|
||||
progress=50,
|
||||
current_step="Training models"
|
||||
)
|
||||
|
||||
# Verify update
|
||||
updated_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=training_job_in_db.job_id,
|
||||
tenant_id=training_job_in_db.tenant_id
|
||||
)
|
||||
|
||||
assert updated_job.status == "running"
|
||||
assert updated_job.progress == 50
|
||||
assert updated_job.current_step == "Training models"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_trained_models(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test storing trained models"""
|
||||
tenant_id = "test-tenant"
|
||||
training_results = {
|
||||
"training_results": {
|
||||
"Pan Integral": {
|
||||
"status": "success",
|
||||
"model_info": {
|
||||
"model_id": "test-model-123",
|
||||
"model_path": "/test/models/test-model-123.pkl",
|
||||
"type": "prophet",
|
||||
"training_samples": 100,
|
||||
"features": ["temperature", "humidity"],
|
||||
"hyperparameters": {"seasonality_mode": "additive"},
|
||||
"training_metrics": {"mae": 5.2, "rmse": 7.8},
|
||||
"data_period": {
|
||||
"start_date": "2024-01-01T00:00:00",
|
||||
"end_date": "2024-01-31T00:00:00"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await training_service._store_trained_models(
|
||||
db=test_db_session,
|
||||
tenant_id=tenant_id,
|
||||
training_results=training_results
|
||||
)
|
||||
|
||||
# Verify model was stored
|
||||
from sqlalchemy import select
|
||||
result = await test_db_session.execute(
|
||||
select(TrainedModel).where(
|
||||
TrainedModel.tenant_id == tenant_id,
|
||||
TrainedModel.product_name == "Pan Integral"
|
||||
)
|
||||
)
|
||||
|
||||
stored_model = result.scalar_one_or_none()
|
||||
assert stored_model is not None
|
||||
assert stored_model.model_id == "test-model-123"
|
||||
assert stored_model.is_active is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_logs(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test getting training logs"""
|
||||
result = await training_service.get_training_logs(
|
||||
db=test_db_session,
|
||||
job_id=training_job_in_db.job_id,
|
||||
tenant_id=training_job_in_db.tenant_id
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) > 0
|
||||
|
||||
# Check log content
|
||||
log_text = " ".join(result)
|
||||
assert training_job_in_db.job_id in log_text or "Job started" in log_text
|
||||
|
||||
|
||||
class TestTrainingServiceDataFetching:
|
||||
"""Test external data fetching functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def training_service(self):
|
||||
return TrainingService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_sales_data_success(self, training_service):
|
||||
"""Test successful sales data fetching"""
|
||||
mock_request = Mock()
|
||||
mock_request.start_date = None
|
||||
mock_request.end_date = None
|
||||
|
||||
mock_response_data = {
|
||||
"sales": [
|
||||
{"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}
|
||||
]
|
||||
}
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_response_data
|
||||
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
result = await training_service._fetch_sales_data(
|
||||
tenant_id="test-tenant",
|
||||
request=mock_request
|
||||
)
|
||||
|
||||
assert result == mock_response_data["sales"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_sales_data_error(self, training_service):
|
||||
"""Test sales data fetching with API error"""
|
||||
mock_request = Mock()
|
||||
mock_request.start_date = None
|
||||
mock_request.end_date = None
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 500
|
||||
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
result = await training_service._fetch_sales_data(
|
||||
tenant_id="test-tenant",
|
||||
request=mock_request
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_weather_data_success(self, training_service):
|
||||
"""Test successful weather data fetching"""
|
||||
mock_request = Mock()
|
||||
mock_request.start_date = None
|
||||
mock_request.end_date = None
|
||||
|
||||
mock_response_data = {
|
||||
"weather": [
|
||||
{"date": "2024-01-01", "temperature": 15.2, "precipitation": 0.0}
|
||||
]
|
||||
}
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_response_data
|
||||
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
result = await training_service._fetch_weather_data(
|
||||
tenant_id="test-tenant",
|
||||
request=mock_request
|
||||
)
|
||||
|
||||
assert result == mock_response_data["weather"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_traffic_data_success(self, training_service):
|
||||
"""Test successful traffic data fetching"""
|
||||
mock_request = Mock()
|
||||
mock_request.start_date = None
|
||||
mock_request.end_date = None
|
||||
|
||||
mock_response_data = {
|
||||
"traffic": [
|
||||
{"date": "2024-01-01", "traffic_volume": 120}
|
||||
]
|
||||
}
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_response_data
|
||||
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
result = await training_service._fetch_traffic_data(
|
||||
tenant_id="test-tenant",
|
||||
request=mock_request
|
||||
)
|
||||
|
||||
assert result == mock_response_data["traffic"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_data_with_date_filters(self, training_service):
|
||||
"""Test data fetching with date filters"""
|
||||
from datetime import datetime
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request.start_date = datetime(2024, 1, 1)
|
||||
mock_request.end_date = datetime(2024, 1, 31)
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"sales": []}
|
||||
|
||||
mock_get = mock_client.return_value.__aenter__.return_value.get
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
await training_service._fetch_sales_data(
|
||||
tenant_id="test-tenant",
|
||||
request=mock_request
|
||||
)
|
||||
|
||||
# Verify dates were passed in params
|
||||
call_args = mock_get.call_args
|
||||
params = call_args[1]["params"]
|
||||
assert "start_date" in params
|
||||
assert "end_date" in params
|
||||
assert params["start_date"] == "2024-01-01T00:00:00"
|
||||
assert params["end_date"] == "2024-01-31T00:00:00"
|
||||
|
||||
|
||||
class TestTrainingServiceExecution:
|
||||
"""Test training execution workflow"""
|
||||
|
||||
@pytest.fixture
|
||||
def training_service(self, mock_ml_trainer):
|
||||
return TrainingService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_training_job_success(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session,
|
||||
mock_messaging,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test successful training job execution"""
|
||||
# Create job first
|
||||
job_id = "test-execution-job"
|
||||
training_log = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id=job_id,
|
||||
config={"include_weather": True}
|
||||
)
|
||||
|
||||
request = TrainingJobRequest(
|
||||
include_weather=True,
|
||||
include_traffic=True,
|
||||
min_data_points=30
|
||||
)
|
||||
|
||||
with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch_sales, \
|
||||
patch('app.services.training_service.TrainingService._fetch_weather_data') as mock_fetch_weather, \
|
||||
patch('app.services.training_service.TrainingService._fetch_traffic_data') as mock_fetch_traffic, \
|
||||
patch('app.services.training_service.TrainingService._store_trained_models') as mock_store:
|
||||
|
||||
mock_fetch_sales.return_value = [{"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}]
|
||||
mock_fetch_weather.return_value = []
|
||||
mock_fetch_traffic.return_value = []
|
||||
mock_store.return_value = None
|
||||
|
||||
await training_service.execute_training_job(
|
||||
db=test_db_session,
|
||||
job_id=job_id,
|
||||
tenant_id="test-tenant",
|
||||
request=request
|
||||
)
|
||||
|
||||
# Verify job was completed
|
||||
updated_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert updated_job.status == "completed"
|
||||
assert updated_job.progress == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_training_job_failure(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session,
|
||||
mock_messaging
|
||||
):
|
||||
"""Test training job execution with failure"""
|
||||
# Create job first
|
||||
job_id = "test-failure-job"
|
||||
await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id=job_id,
|
||||
config={}
|
||||
)
|
||||
|
||||
request = TrainingJobRequest(min_data_points=30)
|
||||
|
||||
with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch:
|
||||
mock_fetch.side_effect = Exception("Data service unavailable")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await training_service.execute_training_job(
|
||||
db=test_db_session,
|
||||
job_id=job_id,
|
||||
tenant_id="test-tenant",
|
||||
request=request
|
||||
)
|
||||
|
||||
# Verify job was marked as failed
|
||||
updated_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert updated_job.status == "failed"
|
||||
assert "Data service unavailable" in updated_job.error_message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_single_product_training_success(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session,
|
||||
mock_messaging,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test successful single product training execution"""
|
||||
job_id = "test-single-product-job"
|
||||
product_name = "Pan Integral"
|
||||
|
||||
await training_service.create_single_product_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
product_name=product_name,
|
||||
job_id=job_id,
|
||||
config={}
|
||||
)
|
||||
|
||||
request = SingleProductTrainingRequest(
|
||||
include_weather=True,
|
||||
include_traffic=False
|
||||
)
|
||||
|
||||
with patch('app.services.training_service.TrainingService._fetch_product_sales_data') as mock_fetch_sales, \
|
||||
patch('app.services.training_service.TrainingService._fetch_weather_data') as mock_fetch_weather, \
|
||||
patch('app.services.training_service.TrainingService._store_single_trained_model') as mock_store:
|
||||
|
||||
mock_fetch_sales.return_value = [{"date": "2024-01-01", "product_name": product_name, "quantity": 45}]
|
||||
mock_fetch_weather.return_value = []
|
||||
mock_store.return_value = None
|
||||
|
||||
await training_service.execute_single_product_training(
|
||||
db=test_db_session,
|
||||
job_id=job_id,
|
||||
tenant_id="test-tenant",
|
||||
product_name=product_name,
|
||||
request=request
|
||||
)
|
||||
|
||||
# Verify job was completed
|
||||
updated_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert updated_job.status == "completed"
|
||||
assert updated_job.progress == 100
|
||||
|
||||
|
||||
class TestTrainingServiceEdgeCases:
|
||||
"""Test edge cases and error conditions"""
|
||||
|
||||
@pytest.fixture
|
||||
def training_service(self):
|
||||
return TrainingService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_connection_failure(self, training_service):
|
||||
"""Test handling of database connection failures"""
|
||||
with patch('sqlalchemy.ext.asyncio.AsyncSession') as mock_session:
|
||||
mock_session.side_effect = Exception("Database connection failed")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await training_service.create_training_job(
|
||||
db=mock_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="test-job",
|
||||
config={}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_external_service_timeout(self, training_service):
|
||||
"""Test handling of external service timeouts"""
|
||||
mock_request = Mock()
|
||||
mock_request.start_date = None
|
||||
mock_request.end_date = None
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get.side_effect = httpx.TimeoutException("Request timeout")
|
||||
|
||||
result = await training_service._fetch_sales_data(
|
||||
tenant_id="test-tenant",
|
||||
request=mock_request
|
||||
)
|
||||
|
||||
# Should return empty list on timeout
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_job_creation(self, training_service, test_db_session):
|
||||
"""Test handling of concurrent job creation"""
|
||||
# This test would need more sophisticated setup for true concurrency testing
|
||||
# For now, just test that multiple jobs can be created
|
||||
|
||||
job_ids = ["concurrent-job-1", "concurrent-job-2", "concurrent-job-3"]
|
||||
|
||||
jobs = []
|
||||
for job_id in job_ids:
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id=job_id,
|
||||
config={}
|
||||
)
|
||||
jobs.append(job)
|
||||
|
||||
assert len(jobs) == 3
|
||||
for i, job in enumerate(jobs):
|
||||
assert job.job_id == job_ids[i]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_malformed_config_handling(self, training_service, test_db_session):
|
||||
"""Test handling of malformed configuration"""
|
||||
malformed_config = {
|
||||
"invalid_key": "invalid_value",
|
||||
"nested": {"data": None}
|
||||
}
|
||||
|
||||
# Should not raise exception, just store the config as-is
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="malformed-config-job",
|
||||
config=malformed_config
|
||||
)
|
||||
|
||||
assert job.config == malformed_config
|
||||
Reference in New Issue
Block a user