# services/training/tests/conftest.py
"""
Pytest configuration and fixtures for training service tests
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
import asyncio
|
||
|
|
import os
|
||
|
|
from typing import AsyncGenerator, Generator
|
||
|
|
from unittest.mock import AsyncMock, Mock, patch
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||
|
|
from sqlalchemy.orm import sessionmaker
|
||
|
|
from httpx import AsyncClient
|
||
|
|
from fastapi.testclient import TestClient
|
||
|
|
|
||
|
|
# Add app to Python path
|
||
|
|
import sys
|
||
|
|
# Prepend the parent of this tests directory so the `app.*` imports below resolve
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||
|
|
|
||
|
|
from app.main import app
|
||
|
|
from app.core.database import Base, get_db
|
||
|
|
from app.core.config import settings
|
||
|
|
from app.models.training import ModelTrainingLog, TrainedModel
|
||
|
|
from app.ml.trainer import BakeryMLTrainer
|
||
|
|
from app.ml.prophet_manager import BakeryProphetManager
|
||
|
|
from app.ml.data_processor import BakeryDataProcessor
|
||
|
|
|
||
|
|
# Test database URL
|
||
|
|
# File-based SQLite reached through the aiosqlite async driver; the
# test_engine fixture deletes ./test_training.db during its teardown.
TEST_DATABASE_URL = "sqlite+aiosqlite:///./test_training.db"
|
||
|
|
|
||
|
|
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide a single event loop for the whole test session.

    The loop is obtained from the current event loop policy, yielded to
    the session, and closed when the session is torn down.
    """
    policy = asyncio.get_event_loop_policy()
    session_loop = policy.new_event_loop()
    yield session_loop
    session_loop.close()
|
||
|
|
|
||
|
|
@pytest.fixture(scope="session")
async def test_engine():
    """Create the async database engine shared by the test session.

    Yields:
        The engine bound to ``TEST_DATABASE_URL`` with every table of
        ``Base.metadata`` freshly created.

    On teardown the engine is disposed and the SQLite file is removed.
    """
    engine = create_async_engine(TEST_DATABASE_URL, echo=False)

    async with engine.begin() as conn:
        # A run interrupted before teardown leaves the database file (and
        # any committed rows) behind, and create_all alone never clears
        # existing tables. Drop first so every session starts empty.
        await conn.run_sync(Base.metadata.drop_all)
        await conn.run_sync(Base.metadata.create_all)

    yield engine

    # Cleanup
    await engine.dispose()

    # Remove test database file; it may already be gone, which is fine
    try:
        os.remove("./test_training.db")
    except FileNotFoundError:
        pass
|
||
|
|
|
||
|
|
@pytest.fixture
async def test_db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session bound to the shared test engine.

    Args:
        test_engine: Session-scoped async engine fixture.

    Yields:
        An ``AsyncSession`` created with ``expire_on_commit=False``.

    After the test, pending work is rolled back and every table is
    emptied so that rows committed during one test cannot leak into the
    next one.
    """
    session_factory = sessionmaker(
        test_engine, class_=AsyncSession, expire_on_commit=False
    )

    async with session_factory() as session:
        yield session
        await session.rollback()

    # rollback() cannot undo work that was already committed (the *_in_db
    # fixtures commit), and the engine lives for the whole session, so
    # purge the tables explicitly. Children are deleted before parents.
    async with test_engine.begin() as conn:
        for table in reversed(Base.metadata.sorted_tables):
            await conn.execute(table.delete())
|
||
|
|
|
||
|
|
@pytest.fixture
def override_get_db(test_db_session):
    """Route the application's ``get_db`` dependency to the test session.

    While the fixture is active, ``get_db`` resolves to a generator that
    yields ``test_db_session``. On teardown all dependency overrides
    registered on the app are cleared.
    """
    async def _session_provider():
        yield test_db_session

    overrides = app.dependency_overrides
    overrides[get_db] = _session_provider
    yield
    overrides.clear()
|
||
|
|
|
||
|
|
@pytest.fixture
async def test_client(override_get_db) -> AsyncGenerator[AsyncClient, None]:
    """Yield an async HTTP client that talks to the app in-process.

    Args:
        override_get_db: Requested only for its side effect of routing
            ``get_db`` to the test session.

    Yields:
        An ``httpx.AsyncClient`` with base URL ``http://test``.
    """
    # Local import keeps this block self-contained; httpx is already a
    # dependency of this file.
    from httpx import ASGITransport

    # The ``app=`` shortcut on AsyncClient is deprecated and removed in
    # recent httpx releases; an explicit ASGI transport works across
    # versions.
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client
|
||
|
|
|
||
|
|
@pytest.fixture
def sync_test_client() -> Generator[TestClient, None, None]:
    """Yield a synchronous ``TestClient`` for tests that need no async I/O.

    The client is entered as a context manager and exited on teardown.
    """
    client_context = TestClient(app)
    with client_context as http:
        yield http
|
||
|
|
|
||
|
|
@pytest.fixture
def mock_messaging():
    """Patch the messaging helpers with awaitable mocks.

    Yields:
        A dict with keys ``setup``, ``cleanup``, ``start``, ``complete``
        and ``failed`` mapping to the ``AsyncMock`` that replaces the
        corresponding function in ``app.services.messaging``. Awaiting
        any of the three publish mocks resolves to ``True``.

    NOTE(review): assumes the service awaits these functions - confirm
    against ``app.services.messaging``.
    """
    # new_callable=AsyncMock makes the patched function itself awaitable.
    # Setting ``return_value = AsyncMock(...)`` on the patch instead makes
    # the awaited result (or the plain call result) an AsyncMock object
    # rather than the intended value.
    with patch('app.services.messaging.setup_messaging', new_callable=AsyncMock) as mock_setup, \
         patch('app.services.messaging.cleanup_messaging', new_callable=AsyncMock) as mock_cleanup, \
         patch('app.services.messaging.publish_job_started', new_callable=AsyncMock) as mock_start, \
         patch('app.services.messaging.publish_job_completed', new_callable=AsyncMock) as mock_complete, \
         patch('app.services.messaging.publish_job_failed', new_callable=AsyncMock) as mock_failed:

        mock_start.return_value = True
        mock_complete.return_value = True
        mock_failed.return_value = True

        yield {
            'setup': mock_setup,
            'cleanup': mock_cleanup,
            'start': mock_start,
            'complete': mock_complete,
            'failed': mock_failed,
        }
|
||
|
|
|
||
|
|
@pytest.fixture
def mock_data_service():
    """Patch ``httpx.AsyncClient`` to return canned external data.

    While active, ``get`` on the client obtained by entering
    ``httpx.AsyncClient(...)`` as an async context manager returns a
    response mock with status 200 whose ``json()`` carries the sample
    sales, weather and traffic records.

    Yields:
        A dict with keys ``sales``, ``weather`` and ``traffic`` holding
        the same record lists that the mocked response returns.
    """
    sales_rows = [
        {"date": day, "product_name": "Pan Integral", "quantity": qty}
        for day, qty in (("2024-01-01", 45), ("2024-01-02", 52))
    ]
    weather_rows = [
        {"date": day, "temperature": temp, "precipitation": rain, "humidity": hum}
        for day, temp, rain, hum in (
            ("2024-01-01", 15.2, 0.0, 65),
            ("2024-01-02", 18.1, 2.5, 72),
        )
    ]
    traffic_rows = [
        {"date": day, "traffic_volume": volume}
        for day, volume in (("2024-01-01", 120), ("2024-01-02", 95))
    ]

    payload = {
        "sales": sales_rows,
        "weather": weather_rows,
        "traffic": traffic_rows,
    }

    with patch('httpx.AsyncClient') as client_class:
        response = Mock()
        response.status_code = 200
        response.json.return_value = payload

        # The object bound by ``async with httpx.AsyncClient(...) as c``
        entered_client = client_class.return_value.__aenter__.return_value
        entered_client.get.return_value = response

        # Hand out a separate dict so callers cannot alter the response body
        # mapping itself (the record lists are shared, as before)
        yield dict(payload)
|
||
|
|
|
||
|
|
@pytest.fixture
def mock_ml_trainer():
    """Patch ``app.ml.trainer.BakeryMLTrainer`` with a canned trainer.

    Yields:
        A ``Mock`` spec'd on ``BakeryMLTrainer``. Awaiting
        ``train_tenant_models(...)`` resolves to a completed single-product
        job result, and awaiting ``train_single_product(...)`` resolves to
        a success dict carrying the same ``model_info``. Instantiating the
        patched class returns this mock.

    NOTE(review): assumes both methods are awaited by the code under
    test - confirm against ``BakeryMLTrainer``.
    """
    with patch('app.ml.trainer.BakeryMLTrainer') as mock_trainer_class:
        mock_trainer = Mock(spec=BakeryMLTrainer)

        model_info = {
            "model_id": "test-model-123",
            "model_path": "/test/models/test-model-123.pkl",
            "type": "prophet",
            "training_samples": 100,
            "features": ["temperature", "humidity"],
            "training_metrics": {
                "mae": 5.2,
                "rmse": 7.8,
                "mape": 12.5,
                "r2_score": 0.85,
            },
            "data_period": {
                "start_date": "2024-01-01",
                "end_date": "2024-01-31",
            },
        }

        # Mock training results
        mock_training_results = {
            "job_id": "test-job-123",
            "tenant_id": "test-tenant",
            "status": "completed",
            "products_trained": 1,
            "products_failed": 0,
            "total_products": 1,
            "training_results": {
                "Pan Integral": {
                    "status": "success",
                    "model_info": model_info,
                    "data_points": 100,
                },
            },
            "summary": {
                "success_rate": 100.0,
                "total_products": 1,
                "successful_products": 1,
                "failed_products": 0,
            },
        }

        # Assign the AsyncMock as the method itself. Putting an AsyncMock
        # in ``return_value`` would make the awaited (or called) result an
        # AsyncMock object instead of the result dict.
        mock_trainer.train_tenant_models = AsyncMock(return_value=mock_training_results)
        mock_trainer.train_single_product = AsyncMock(return_value={
            "status": "success",
            "model_info": model_info,
        })

        mock_trainer_class.return_value = mock_trainer
        yield mock_trainer
|
||
|
|
|
||
|
|
@pytest.fixture
def sample_training_job() -> dict:
    """Return the field values describing a pending training job."""
    job_config = {
        "include_weather": True,
        "include_traffic": True,
        "min_data_points": 30,
    }
    job = {
        "job_id": "test-job-123",
        "tenant_id": "test-tenant",
        "status": "pending",
        "progress": 0,
        "current_step": "Initializing",
        "config": job_config,
    }
    return job
|
||
|
|
|
||
|
|
@pytest.fixture
def sample_trained_model() -> dict:
    """Return the field values describing an active Prophet model."""
    hyperparameters = {
        "seasonality_mode": "additive",
        "daily_seasonality": True,
        "weekly_seasonality": True,
    }
    metrics = {
        "mae": 5.2,
        "rmse": 7.8,
        "mape": 12.5,
        "r2_score": 0.85,
    }
    model = {
        "model_id": "test-model-123",
        "tenant_id": "test-tenant",
        "product_name": "Pan Integral",
        "model_type": "prophet",
        "model_path": "/test/models/test-model-123.pkl",
        "version": 1,
        "training_samples": 100,
        "features": ["temperature", "humidity", "traffic_volume"],
        "hyperparameters": hyperparameters,
        "training_metrics": metrics,
        "is_active": True,
    }
    return model
|
||
|
|
|
||
|
|
@pytest.fixture
async def training_job_in_db(test_db_session, sample_training_job):
    """Persist a ``ModelTrainingLog`` built from the sample job data.

    The row is added, committed and refreshed through ``test_db_session``
    before being returned.
    """
    session = test_db_session
    job_row = ModelTrainingLog(**sample_training_job)

    session.add(job_row)
    await session.commit()
    await session.refresh(job_row)

    return job_row
|
||
|
|
|
||
|
|
@pytest.fixture
async def trained_model_in_db(test_db_session, sample_trained_model):
    """Persist a ``TrainedModel`` built from the sample model data.

    The sample fields are extended with a January 2024 data period and a
    ``created_at`` timestamp taken at fixture time; the sample dict itself
    is left unmodified. The row is added, committed and refreshed through
    ``test_db_session`` before being returned.
    """
    from datetime import datetime

    model_fields = {
        **sample_trained_model,
        "data_period_start": datetime(2024, 1, 1),
        "data_period_end": datetime(2024, 1, 31),
        "created_at": datetime.now(),
    }

    session = test_db_session
    model_row = TrainedModel(**model_fields)

    session.add(model_row)
    await session.commit()
    await session.refresh(model_row)

    return model_row
|
||
|
|
|
||
|
|
@pytest.fixture
def mock_prophet_manager():
    """Patch ``app.ml.prophet_manager.BakeryProphetManager`` with a canned manager.

    Yields:
        A ``Mock`` spec'd on ``BakeryProphetManager``. Awaiting
        ``train_bakery_model(...)`` resolves to a sample model-info dict;
        ``generate_forecast`` is an ``AsyncMock`` with its default result.
        Instantiating the patched class returns this mock.

    NOTE(review): assumes both methods are awaited by the code under
    test - confirm against ``BakeryProphetManager``.
    """
    with patch('app.ml.prophet_manager.BakeryProphetManager') as mock_manager_class:
        mock_manager = Mock(spec=BakeryProphetManager)

        mock_model_info = {
            "model_id": "test-model-123",
            "model_path": "/test/models/test-model-123.pkl",
            "type": "prophet",
            "training_samples": 100,
            "features": ["temperature", "humidity"],
            "training_metrics": {
                "mae": 5.2,
                "rmse": 7.8,
                "mape": 12.5,
                "r2_score": 0.85,
            },
        }

        # Assign the AsyncMock as the method itself. Putting an AsyncMock
        # in ``return_value`` would make the awaited (or called) result an
        # AsyncMock object instead of the model info.
        mock_manager.train_bakery_model = AsyncMock(return_value=mock_model_info)
        mock_manager.generate_forecast = AsyncMock()

        mock_manager_class.return_value = mock_manager
        yield mock_manager
|
||
|
|
|
||
|
|
@pytest.fixture
def mock_data_processor():
    """Patch ``app.ml.data_processor.BakeryDataProcessor`` with a canned processor.

    Yields:
        A ``Mock`` spec'd on ``BakeryDataProcessor``. Awaiting
        ``prepare_training_data(...)`` or
        ``prepare_prediction_features(...)`` resolves to the same 30-row
        DataFrame with columns ``ds``, ``y``, ``temperature`` and
        ``humidity``. Instantiating the patched class returns this mock.

    NOTE(review): assumes both methods are awaited by the code under
    test - confirm against ``BakeryDataProcessor``.
    """
    import pandas as pd

    with patch('app.ml.data_processor.BakeryDataProcessor') as mock_processor_class:
        mock_processor = Mock(spec=BakeryDataProcessor)

        # Mock processed data: 30 consecutive days starting 2024-01-01
        mock_processed_data = pd.DataFrame({
            'ds': pd.date_range('2024-01-01', periods=30, freq='D'),
            'y': [45 + i for i in range(30)],
            'temperature': [15.0 + (i % 10) for i in range(30)],
            'humidity': [60.0 + (i % 20) for i in range(30)],
        })

        # Assign the AsyncMock as the method itself. Putting an AsyncMock
        # in ``return_value`` would make the awaited (or called) result an
        # AsyncMock object instead of the DataFrame.
        mock_processor.prepare_training_data = AsyncMock(return_value=mock_processed_data)
        mock_processor.prepare_prediction_features = AsyncMock(return_value=mock_processed_data)

        mock_processor_class.return_value = mock_processor
        yield mock_processor
|
||
|
|
|
||
|
|
@pytest.fixture
def mock_auth():
    """Patch ``shared.auth.decorators.require_auth`` with a no-op decorator factory.

    While active, calling the patched ``require_auth`` returns a decorator
    that hands the wrapped function back unchanged.

    NOTE(review): ``app.main`` is imported at the top of this file, so any
    ``require_auth`` decoration performed during that import has already
    run before this patch applies - confirm this fixture has the intended
    effect.
    """
    def _passthrough(func):
        # Pass through without auth
        return func

    with patch('shared.auth.decorators.require_auth') as require_auth_mock:
        require_auth_mock.return_value = _passthrough
        yield require_auth_mock
|
||
|
|
|
||
|
|
# Helper functions for tests
|
||
|
|
def assert_training_job_structure(job_data: dict):
    """Check that a training-job payload carries every mandatory key.

    Raises:
        AssertionError: naming the first of ``job_id``, ``status``,
            ``tenant_id`` or ``created_at`` that is absent from
            ``job_data``.
    """
    for key in ("job_id", "status", "tenant_id", "created_at"):
        assert key in job_data, f"Missing required field: {key}"
|
||
|
|
|
||
|
|
def assert_model_structure(model_data: dict):
    """Check that a model payload carries every mandatory key.

    Raises:
        AssertionError: naming the first of ``model_id``, ``model_type``,
            ``training_samples`` or ``features`` that is absent from
            ``model_data``.
    """
    for key in ("model_id", "model_type", "training_samples", "features"):
        assert key in model_data, f"Missing required field: {key}"
|