Fix generating pytest for training service
This commit is contained in:
@@ -18,6 +18,7 @@ from unittest.mock import Mock, AsyncMock, patch
|
||||
from typing import Dict, List, Any, Generator
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from app.models.training import ModelTrainingLog, TrainedModel
|
||||
|
||||
# Configure pytest-asyncio
|
||||
pytestmark = pytest.mark.asyncio
|
||||
@@ -213,16 +214,14 @@ async def test_app():
|
||||
from app.main import app
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_client(test_app):
|
||||
"""Test client for API testing"""
|
||||
from httpx import AsyncClient
|
||||
def test_client(test_app):
|
||||
"""Create test client for API testing - SYNC VERSION"""
|
||||
from httpx import Client
|
||||
|
||||
async with AsyncClient(app=test_app, base_url="http://test") as client:
|
||||
with Client(app=test_app, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers():
|
||||
"""Mock authentication headers"""
|
||||
@@ -452,7 +451,7 @@ def setup_test_environment():
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup environment
|
||||
# Cleanup environment - FIXED: removed (scope="session")
|
||||
test_vars = [
|
||||
'ENVIRONMENT', 'LOG_LEVEL', 'MODEL_STORAGE_PATH',
|
||||
'MAX_TRAINING_TIME_MINUTES', 'MIN_TRAINING_DATA_DAYS',
|
||||
@@ -461,7 +460,8 @@ def setup_test_environment():
|
||||
]
|
||||
|
||||
for var in test_vars:
|
||||
os.environ.pop(var, None)(scope="session")
|
||||
os.environ.pop(var, None) # FIXED: removed the erroneous (scope="session")
|
||||
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for the test session."""
|
||||
loop = asyncio.new_event_loop()
|
||||
@@ -514,41 +514,60 @@ def pytest_collection_modifyitems(config, items):
|
||||
# TEST DATABASE FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture
|
||||
@pytest_asyncio.fixture
|
||||
async def test_db_session():
|
||||
"""Mock database session for testing"""
|
||||
mock_session = AsyncMock()
|
||||
"""Create async test database session"""
|
||||
from app.core.database import database_manager
|
||||
|
||||
# Mock common database operations
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.rollback = AsyncMock()
|
||||
mock_session.refresh = AsyncMock()
|
||||
mock_session.close = AsyncMock()
|
||||
mock_session.execute = AsyncMock()
|
||||
mock_session.scalar = AsyncMock()
|
||||
|
||||
return mock_session
|
||||
|
||||
async with database_manager.async_session_local() as session:
|
||||
yield session
|
||||
|
||||
@pytest.fixture
|
||||
def training_job_in_db():
|
||||
"""Mock training job already in database"""
|
||||
from app.models.training import ModelTrainingLog
|
||||
def training_job_in_db(test_db_session):
|
||||
"""Create a training job in database for testing"""
|
||||
from app.models.training import ModelTrainingLog # Add this import
|
||||
from datetime import datetime
|
||||
|
||||
job = ModelTrainingLog(
|
||||
job_id="test_job_123",
|
||||
tenant_id="test_tenant",
|
||||
job_id="test-job-123",
|
||||
tenant_id="test-tenant",
|
||||
status="running",
|
||||
progress=50,
|
||||
current_step="Training model for Pan Integral",
|
||||
config={"include_weather": True, "include_traffic": True},
|
||||
started_at=datetime.now(),
|
||||
logs=["Started training", "Processing data"]
|
||||
current_step="Training models",
|
||||
start_time=datetime.now(), # Use start_time, not started_at
|
||||
config={"include_weather": True},
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
|
||||
test_db_session.add(job)
|
||||
test_db_session.commit()
|
||||
test_db_session.refresh(job)
|
||||
return job
|
||||
|
||||
@pytest.fixture
|
||||
def trained_model_in_db(test_db_session):
|
||||
"""Create a trained model in database for testing"""
|
||||
from app.models.training import TrainedModel # Add this import
|
||||
from datetime import datetime
|
||||
|
||||
model = TrainedModel(
|
||||
model_id="test-model-123",
|
||||
tenant_id="test-tenant",
|
||||
product_name="Pan Integral",
|
||||
model_type="prophet",
|
||||
model_path="/tmp/test_model.pkl",
|
||||
version=1,
|
||||
training_samples=100,
|
||||
features=["temperature", "humidity"],
|
||||
hyperparameters={"seasonality_mode": "additive"},
|
||||
training_metrics={"mae": 2.5, "mse": 8.3},
|
||||
is_active=True,
|
||||
created_at=datetime.now()
|
||||
)
|
||||
test_db_session.add(model)
|
||||
test_db_session.commit()
|
||||
test_db_session.refresh(model)
|
||||
return model
|
||||
|
||||
# ================================================================
|
||||
# SAMPLE DATA FIXTURES
|
||||
@@ -843,6 +862,24 @@ def mock_data_processor():
|
||||
|
||||
yield mock_instance
|
||||
|
||||
@pytest.fixture
|
||||
def mock_data_service():
|
||||
"""Mock data service for testing"""
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
mock_service = Mock()
|
||||
mock_service.get_sales_data = AsyncMock(return_value=[
|
||||
{"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45},
|
||||
{"date": "2024-01-02", "product_name": "Pan Integral", "quantity": 38}
|
||||
])
|
||||
mock_service.get_weather_data = AsyncMock(return_value=[
|
||||
{"date": "2024-01-01", "temperature": 20.5, "humidity": 65}
|
||||
])
|
||||
mock_service.get_traffic_data = AsyncMock(return_value=[
|
||||
{"date": "2024-01-01", "traffic_index": 0.7}
|
||||
])
|
||||
|
||||
return mock_service
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prophet_manager():
|
||||
|
||||
Reference in New Issue
Block a user