Add all the code for training service

This commit is contained in:
Urtzi Alfaro
2025-07-19 16:59:37 +02:00
parent 42097202d2
commit f3071c00bd
21 changed files with 7504 additions and 764 deletions

View File

@@ -0,0 +1,362 @@
# services/training/tests/conftest.py
"""
Pytest configuration and fixtures for training service tests
"""
import pytest
import asyncio
import os
from typing import AsyncGenerator, Generator
from unittest.mock import AsyncMock, Mock, patch
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from httpx import AsyncClient
from fastapi.testclient import TestClient
# Add app to Python path
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from app.main import app
from app.core.database import Base, get_db
from app.core.config import settings
from app.models.training import ModelTrainingLog, TrainedModel
from app.ml.trainer import BakeryMLTrainer
from app.ml.prophet_manager import BakeryProphetManager
from app.ml.data_processor import BakeryDataProcessor
# Test database URL
TEST_DATABASE_URL = "sqlite+aiosqlite:///./test_training.db"
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
"""Create an instance of the default event loop for the test session."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
async def test_engine():
"""Create test database engine"""
engine = create_async_engine(TEST_DATABASE_URL, echo=False)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
# Cleanup
await engine.dispose()
# Remove test database file
try:
os.remove("./test_training.db")
except FileNotFoundError:
pass
@pytest.fixture
async def test_db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
"""Create test database session"""
async_session = sessionmaker(
test_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session() as session:
yield session
await session.rollback()
@pytest.fixture
def override_get_db(test_db_session):
"""Override the get_db dependency"""
async def _override_get_db():
yield test_db_session
app.dependency_overrides[get_db] = _override_get_db
yield
app.dependency_overrides.clear()
@pytest.fixture
async def test_client(override_get_db) -> AsyncGenerator[AsyncClient, None]:
"""Create test HTTP client"""
async with AsyncClient(app=app, base_url="http://test") as client:
yield client
@pytest.fixture
def sync_test_client() -> Generator[TestClient, None, None]:
"""Create synchronous test client for simple tests"""
with TestClient(app) as client:
yield client
@pytest.fixture
def mock_messaging():
"""Mock messaging for tests"""
with patch('app.services.messaging.setup_messaging') as mock_setup, \
patch('app.services.messaging.cleanup_messaging') as mock_cleanup, \
patch('app.services.messaging.publish_job_started') as mock_start, \
patch('app.services.messaging.publish_job_completed') as mock_complete, \
patch('app.services.messaging.publish_job_failed') as mock_failed:
mock_setup.return_value = AsyncMock()
mock_cleanup.return_value = AsyncMock()
mock_start.return_value = AsyncMock(return_value=True)
mock_complete.return_value = AsyncMock(return_value=True)
mock_failed.return_value = AsyncMock(return_value=True)
yield {
'setup': mock_setup,
'cleanup': mock_cleanup,
'start': mock_start,
'complete': mock_complete,
'failed': mock_failed
}
@pytest.fixture
def mock_data_service():
"""Mock external data service responses"""
mock_sales_data = [
{
"date": "2024-01-01",
"product_name": "Pan Integral",
"quantity": 45
},
{
"date": "2024-01-02",
"product_name": "Pan Integral",
"quantity": 52
}
]
mock_weather_data = [
{
"date": "2024-01-01",
"temperature": 15.2,
"precipitation": 0.0,
"humidity": 65
},
{
"date": "2024-01-02",
"temperature": 18.1,
"precipitation": 2.5,
"humidity": 72
}
]
mock_traffic_data = [
{
"date": "2024-01-01",
"traffic_volume": 120
},
{
"date": "2024-01-02",
"traffic_volume": 95
}
]
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"sales": mock_sales_data,
"weather": mock_weather_data,
"traffic": mock_traffic_data
}
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
yield {
'sales': mock_sales_data,
'weather': mock_weather_data,
'traffic': mock_traffic_data
}
@pytest.fixture
def mock_ml_trainer():
"""Mock ML trainer for testing"""
with patch('app.ml.trainer.BakeryMLTrainer') as mock_trainer_class:
mock_trainer = Mock(spec=BakeryMLTrainer)
# Mock training results
mock_training_results = {
"job_id": "test-job-123",
"tenant_id": "test-tenant",
"status": "completed",
"products_trained": 1,
"products_failed": 0,
"total_products": 1,
"training_results": {
"Pan Integral": {
"status": "success",
"model_info": {
"model_id": "test-model-123",
"model_path": "/test/models/test-model-123.pkl",
"type": "prophet",
"training_samples": 100,
"features": ["temperature", "humidity"],
"training_metrics": {
"mae": 5.2,
"rmse": 7.8,
"mape": 12.5,
"r2_score": 0.85
},
"data_period": {
"start_date": "2024-01-01",
"end_date": "2024-01-31"
}
},
"data_points": 100
}
},
"summary": {
"success_rate": 100.0,
"total_products": 1,
"successful_products": 1,
"failed_products": 0
}
}
mock_trainer.train_tenant_models.return_value = AsyncMock(return_value=mock_training_results)
mock_trainer.train_single_product.return_value = AsyncMock(return_value={
"status": "success",
"model_info": mock_training_results["training_results"]["Pan Integral"]["model_info"]
})
mock_trainer_class.return_value = mock_trainer
yield mock_trainer
@pytest.fixture
def sample_training_job() -> dict:
"""Sample training job data"""
return {
"job_id": "test-job-123",
"tenant_id": "test-tenant",
"status": "pending",
"progress": 0,
"current_step": "Initializing",
"config": {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
}
@pytest.fixture
def sample_trained_model() -> dict:
"""Sample trained model data"""
return {
"model_id": "test-model-123",
"tenant_id": "test-tenant",
"product_name": "Pan Integral",
"model_type": "prophet",
"model_path": "/test/models/test-model-123.pkl",
"version": 1,
"training_samples": 100,
"features": ["temperature", "humidity", "traffic_volume"],
"hyperparameters": {
"seasonality_mode": "additive",
"daily_seasonality": True,
"weekly_seasonality": True
},
"training_metrics": {
"mae": 5.2,
"rmse": 7.8,
"mape": 12.5,
"r2_score": 0.85
},
"is_active": True
}
@pytest.fixture
async def training_job_in_db(test_db_session, sample_training_job):
"""Create a training job in the test database"""
training_log = ModelTrainingLog(**sample_training_job)
test_db_session.add(training_log)
await test_db_session.commit()
await test_db_session.refresh(training_log)
return training_log
@pytest.fixture
async def trained_model_in_db(test_db_session, sample_trained_model):
"""Create a trained model in the test database"""
from datetime import datetime
model_data = sample_trained_model.copy()
model_data.update({
"data_period_start": datetime(2024, 1, 1),
"data_period_end": datetime(2024, 1, 31),
"created_at": datetime.now()
})
trained_model = TrainedModel(**model_data)
test_db_session.add(trained_model)
await test_db_session.commit()
await test_db_session.refresh(trained_model)
return trained_model
@pytest.fixture
def mock_prophet_manager():
"""Mock Prophet manager for testing"""
with patch('app.ml.prophet_manager.BakeryProphetManager') as mock_manager_class:
mock_manager = Mock(spec=BakeryProphetManager)
mock_model_info = {
"model_id": "test-model-123",
"model_path": "/test/models/test-model-123.pkl",
"type": "prophet",
"training_samples": 100,
"features": ["temperature", "humidity"],
"training_metrics": {
"mae": 5.2,
"rmse": 7.8,
"mape": 12.5,
"r2_score": 0.85
}
}
mock_manager.train_bakery_model.return_value = AsyncMock(return_value=mock_model_info)
mock_manager.generate_forecast.return_value = AsyncMock()
mock_manager_class.return_value = mock_manager
yield mock_manager
@pytest.fixture
def mock_data_processor():
"""Mock data processor for testing"""
import pandas as pd
with patch('app.ml.data_processor.BakeryDataProcessor') as mock_processor_class:
mock_processor = Mock(spec=BakeryDataProcessor)
# Mock processed data
mock_processed_data = pd.DataFrame({
'ds': pd.date_range('2024-01-01', periods=30, freq='D'),
'y': [45 + i for i in range(30)],
'temperature': [15.0 + (i % 10) for i in range(30)],
'humidity': [60.0 + (i % 20) for i in range(30)]
})
mock_processor.prepare_training_data.return_value = AsyncMock(return_value=mock_processed_data)
mock_processor.prepare_prediction_features.return_value = AsyncMock(return_value=mock_processed_data)
mock_processor_class.return_value = mock_processor
yield mock_processor
@pytest.fixture
def mock_auth():
"""Mock authentication for tests"""
with patch('shared.auth.decorators.require_auth') as mock_auth:
mock_auth.return_value = lambda func: func # Pass through without auth
yield mock_auth
# Helper functions for tests
def assert_training_job_structure(job_data: dict):
"""Assert that training job data has correct structure"""
required_fields = ["job_id", "status", "tenant_id", "created_at"]
for field in required_fields:
assert field in job_data, f"Missing required field: {field}"
def assert_model_structure(model_data: dict):
"""Assert that model data has correct structure"""
required_fields = ["model_id", "model_type", "training_samples", "features"]
for field in required_fields:
assert field in model_data, f"Missing required field: {field}"