Add all the code for training service
This commit is contained in:
362
services/training/tests/conftest.py
Normal file
362
services/training/tests/conftest.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# services/training/tests/conftest.py
|
||||
"""
|
||||
Pytest configuration and fixtures for training service tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Generator
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from httpx import AsyncClient
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Add app to Python path
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from app.main import app
|
||||
from app.core.database import Base, get_db
|
||||
from app.core.config import settings
|
||||
from app.models.training import ModelTrainingLog, TrainedModel
|
||||
from app.ml.trainer import BakeryMLTrainer
|
||||
from app.ml.prophet_manager import BakeryProphetManager
|
||||
from app.ml.data_processor import BakeryDataProcessor
|
||||
|
||||
# Test database URL
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///./test_training.db"
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
|
||||
"""Create an instance of the default event loop for the test session."""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def test_engine():
|
||||
"""Create test database engine"""
|
||||
engine = create_async_engine(TEST_DATABASE_URL, echo=False)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
# Cleanup
|
||||
await engine.dispose()
|
||||
|
||||
# Remove test database file
|
||||
try:
|
||||
os.remove("./test_training.db")
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
@pytest.fixture
|
||||
async def test_db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create test database session"""
|
||||
async_session = sessionmaker(
|
||||
test_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
@pytest.fixture
|
||||
def override_get_db(test_db_session):
|
||||
"""Override the get_db dependency"""
|
||||
async def _override_get_db():
|
||||
yield test_db_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
@pytest.fixture
|
||||
async def test_client(override_get_db) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create test HTTP client"""
|
||||
async with AsyncClient(app=app, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
@pytest.fixture
|
||||
def sync_test_client() -> Generator[TestClient, None, None]:
|
||||
"""Create synchronous test client for simple tests"""
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_messaging():
|
||||
"""Mock messaging for tests"""
|
||||
with patch('app.services.messaging.setup_messaging') as mock_setup, \
|
||||
patch('app.services.messaging.cleanup_messaging') as mock_cleanup, \
|
||||
patch('app.services.messaging.publish_job_started') as mock_start, \
|
||||
patch('app.services.messaging.publish_job_completed') as mock_complete, \
|
||||
patch('app.services.messaging.publish_job_failed') as mock_failed:
|
||||
|
||||
mock_setup.return_value = AsyncMock()
|
||||
mock_cleanup.return_value = AsyncMock()
|
||||
mock_start.return_value = AsyncMock(return_value=True)
|
||||
mock_complete.return_value = AsyncMock(return_value=True)
|
||||
mock_failed.return_value = AsyncMock(return_value=True)
|
||||
|
||||
yield {
|
||||
'setup': mock_setup,
|
||||
'cleanup': mock_cleanup,
|
||||
'start': mock_start,
|
||||
'complete': mock_complete,
|
||||
'failed': mock_failed
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_data_service():
|
||||
"""Mock external data service responses"""
|
||||
mock_sales_data = [
|
||||
{
|
||||
"date": "2024-01-01",
|
||||
"product_name": "Pan Integral",
|
||||
"quantity": 45
|
||||
},
|
||||
{
|
||||
"date": "2024-01-02",
|
||||
"product_name": "Pan Integral",
|
||||
"quantity": 52
|
||||
}
|
||||
]
|
||||
|
||||
mock_weather_data = [
|
||||
{
|
||||
"date": "2024-01-01",
|
||||
"temperature": 15.2,
|
||||
"precipitation": 0.0,
|
||||
"humidity": 65
|
||||
},
|
||||
{
|
||||
"date": "2024-01-02",
|
||||
"temperature": 18.1,
|
||||
"precipitation": 2.5,
|
||||
"humidity": 72
|
||||
}
|
||||
]
|
||||
|
||||
mock_traffic_data = [
|
||||
{
|
||||
"date": "2024-01-01",
|
||||
"traffic_volume": 120
|
||||
},
|
||||
{
|
||||
"date": "2024-01-02",
|
||||
"traffic_volume": 95
|
||||
}
|
||||
]
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"sales": mock_sales_data,
|
||||
"weather": mock_weather_data,
|
||||
"traffic": mock_traffic_data
|
||||
}
|
||||
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
yield {
|
||||
'sales': mock_sales_data,
|
||||
'weather': mock_weather_data,
|
||||
'traffic': mock_traffic_data
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ml_trainer():
|
||||
"""Mock ML trainer for testing"""
|
||||
with patch('app.ml.trainer.BakeryMLTrainer') as mock_trainer_class:
|
||||
mock_trainer = Mock(spec=BakeryMLTrainer)
|
||||
|
||||
# Mock training results
|
||||
mock_training_results = {
|
||||
"job_id": "test-job-123",
|
||||
"tenant_id": "test-tenant",
|
||||
"status": "completed",
|
||||
"products_trained": 1,
|
||||
"products_failed": 0,
|
||||
"total_products": 1,
|
||||
"training_results": {
|
||||
"Pan Integral": {
|
||||
"status": "success",
|
||||
"model_info": {
|
||||
"model_id": "test-model-123",
|
||||
"model_path": "/test/models/test-model-123.pkl",
|
||||
"type": "prophet",
|
||||
"training_samples": 100,
|
||||
"features": ["temperature", "humidity"],
|
||||
"training_metrics": {
|
||||
"mae": 5.2,
|
||||
"rmse": 7.8,
|
||||
"mape": 12.5,
|
||||
"r2_score": 0.85
|
||||
},
|
||||
"data_period": {
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-01-31"
|
||||
}
|
||||
},
|
||||
"data_points": 100
|
||||
}
|
||||
},
|
||||
"summary": {
|
||||
"success_rate": 100.0,
|
||||
"total_products": 1,
|
||||
"successful_products": 1,
|
||||
"failed_products": 0
|
||||
}
|
||||
}
|
||||
|
||||
mock_trainer.train_tenant_models.return_value = AsyncMock(return_value=mock_training_results)
|
||||
mock_trainer.train_single_product.return_value = AsyncMock(return_value={
|
||||
"status": "success",
|
||||
"model_info": mock_training_results["training_results"]["Pan Integral"]["model_info"]
|
||||
})
|
||||
|
||||
mock_trainer_class.return_value = mock_trainer
|
||||
yield mock_trainer
|
||||
|
||||
@pytest.fixture
|
||||
def sample_training_job() -> dict:
|
||||
"""Sample training job data"""
|
||||
return {
|
||||
"job_id": "test-job-123",
|
||||
"tenant_id": "test-tenant",
|
||||
"status": "pending",
|
||||
"progress": 0,
|
||||
"current_step": "Initializing",
|
||||
"config": {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_trained_model() -> dict:
|
||||
"""Sample trained model data"""
|
||||
return {
|
||||
"model_id": "test-model-123",
|
||||
"tenant_id": "test-tenant",
|
||||
"product_name": "Pan Integral",
|
||||
"model_type": "prophet",
|
||||
"model_path": "/test/models/test-model-123.pkl",
|
||||
"version": 1,
|
||||
"training_samples": 100,
|
||||
"features": ["temperature", "humidity", "traffic_volume"],
|
||||
"hyperparameters": {
|
||||
"seasonality_mode": "additive",
|
||||
"daily_seasonality": True,
|
||||
"weekly_seasonality": True
|
||||
},
|
||||
"training_metrics": {
|
||||
"mae": 5.2,
|
||||
"rmse": 7.8,
|
||||
"mape": 12.5,
|
||||
"r2_score": 0.85
|
||||
},
|
||||
"is_active": True
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
async def training_job_in_db(test_db_session, sample_training_job):
|
||||
"""Create a training job in the test database"""
|
||||
training_log = ModelTrainingLog(**sample_training_job)
|
||||
test_db_session.add(training_log)
|
||||
await test_db_session.commit()
|
||||
await test_db_session.refresh(training_log)
|
||||
return training_log
|
||||
|
||||
@pytest.fixture
|
||||
async def trained_model_in_db(test_db_session, sample_trained_model):
|
||||
"""Create a trained model in the test database"""
|
||||
from datetime import datetime
|
||||
|
||||
model_data = sample_trained_model.copy()
|
||||
model_data.update({
|
||||
"data_period_start": datetime(2024, 1, 1),
|
||||
"data_period_end": datetime(2024, 1, 31),
|
||||
"created_at": datetime.now()
|
||||
})
|
||||
|
||||
trained_model = TrainedModel(**model_data)
|
||||
test_db_session.add(trained_model)
|
||||
await test_db_session.commit()
|
||||
await test_db_session.refresh(trained_model)
|
||||
return trained_model
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prophet_manager():
|
||||
"""Mock Prophet manager for testing"""
|
||||
with patch('app.ml.prophet_manager.BakeryProphetManager') as mock_manager_class:
|
||||
mock_manager = Mock(spec=BakeryProphetManager)
|
||||
|
||||
mock_model_info = {
|
||||
"model_id": "test-model-123",
|
||||
"model_path": "/test/models/test-model-123.pkl",
|
||||
"type": "prophet",
|
||||
"training_samples": 100,
|
||||
"features": ["temperature", "humidity"],
|
||||
"training_metrics": {
|
||||
"mae": 5.2,
|
||||
"rmse": 7.8,
|
||||
"mape": 12.5,
|
||||
"r2_score": 0.85
|
||||
}
|
||||
}
|
||||
|
||||
mock_manager.train_bakery_model.return_value = AsyncMock(return_value=mock_model_info)
|
||||
mock_manager.generate_forecast.return_value = AsyncMock()
|
||||
|
||||
mock_manager_class.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_data_processor():
|
||||
"""Mock data processor for testing"""
|
||||
import pandas as pd
|
||||
|
||||
with patch('app.ml.data_processor.BakeryDataProcessor') as mock_processor_class:
|
||||
mock_processor = Mock(spec=BakeryDataProcessor)
|
||||
|
||||
# Mock processed data
|
||||
mock_processed_data = pd.DataFrame({
|
||||
'ds': pd.date_range('2024-01-01', periods=30, freq='D'),
|
||||
'y': [45 + i for i in range(30)],
|
||||
'temperature': [15.0 + (i % 10) for i in range(30)],
|
||||
'humidity': [60.0 + (i % 20) for i in range(30)]
|
||||
})
|
||||
|
||||
mock_processor.prepare_training_data.return_value = AsyncMock(return_value=mock_processed_data)
|
||||
mock_processor.prepare_prediction_features.return_value = AsyncMock(return_value=mock_processed_data)
|
||||
|
||||
mock_processor_class.return_value = mock_processor
|
||||
yield mock_processor
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth():
|
||||
"""Mock authentication for tests"""
|
||||
with patch('shared.auth.decorators.require_auth') as mock_auth:
|
||||
mock_auth.return_value = lambda func: func # Pass through without auth
|
||||
yield mock_auth
|
||||
|
||||
# Helper functions for tests
|
||||
def assert_training_job_structure(job_data: dict):
|
||||
"""Assert that training job data has correct structure"""
|
||||
required_fields = ["job_id", "status", "tenant_id", "created_at"]
|
||||
for field in required_fields:
|
||||
assert field in job_data, f"Missing required field: {field}"
|
||||
|
||||
def assert_model_structure(model_data: dict):
|
||||
"""Assert that model data has correct structure"""
|
||||
required_fields = ["model_id", "model_type", "training_samples", "features"]
|
||||
for field in required_fields:
|
||||
assert field in model_data, f"Missing required field: {field}"
|
||||
Reference in New Issue
Block a user