225 lines
8.0 KiB
Python
225 lines
8.0 KiB
Python
import pytest
|
|
import asyncio
|
|
import os
|
|
import sys
|
|
from typing import AsyncGenerator
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
import redis.asyncio as redis
|
|
|
|
# Put the directory one level above this file on the import path; the
# fixtures below import `app.*` and `shared.*` (presumably from there —
# confirm against the project layout).
sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir))


# ================================================================
# TEST DATABASE CONFIGURATION
# ================================================================

# In-memory SQLite through the async aiosqlite driver: fast, nothing on disk.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
|
|
|
@pytest.fixture(scope="session")
def event_loop():
    """Create an instance of the default event loop for the test session.

    A single loop is created from the current event-loop policy, shared by
    everything that requests ``event_loop`` for the whole session, and closed
    at session teardown.

    NOTE(review): this overrides the ``event_loop`` fixture, presumably the
    one provided by pytest-asyncio; newer pytest-asyncio releases deprecate
    such overrides — confirm against the installed plugin version.
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    # Closing an already-closed loop was a no-op before; keep it that way.
    if not loop.is_closed():
        try:
            # Finalize async generators still alive on this loop (as
            # asyncio.run() does) so their cleanup code actually runs.
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()
|
|
|
|
@pytest.fixture(scope="function")
async def test_engine():
    """Create a test database engine for each test function.

    Builds an async engine on ``TEST_DATABASE_URL``, creates every table
    registered on ``Base.metadata`` before the test, and drops them
    afterwards. The engine is always disposed, including when importing
    ``Base`` or creating the schema fails.
    """
    engine = create_async_engine(
        TEST_DATABASE_URL,
        echo=False,  # Set to True for SQL debugging
        future=True,
        pool_pre_ping=True,
    )
    try:
        # Import Base and metadata after engine creation to avoid circular imports
        from shared.database.base import Base

        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        yield engine

        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
    finally:
        # Release pooled connections even when setup above raised.
        await engine.dispose()
|
|
|
|
@pytest.fixture(scope="function")
async def test_db(test_engine) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to this test's engine.

    Sessions come from a ``sessionmaker`` configured with
    ``expire_on_commit=False``, so loaded attributes remain readable after a
    commit. Once the test finishes, whatever transaction is still open on the
    session is rolled back before the session is closed.
    """
    session_factory = sessionmaker(
        test_engine,
        expire_on_commit=False,
        class_=AsyncSession,
    )

    async with session_factory() as db_session:
        yield db_session
        # Discard uncommitted work the test left behind.
        await db_session.rollback()
|
|
|
|
@pytest.fixture(scope="function")
def client(test_db):
    """Yield a ``TestClient`` for the app with ``get_db`` overridden to ``test_db``.

    The requesting test is skipped when ``app.main`` or ``app.core.database``
    cannot be imported. Only those imports are guarded: an ``ImportError``
    raised later (for example while the client starts the app) is reported as
    a real error instead of being turned into a skip. All dependency overrides
    on the app are cleared at teardown, even if entering the client fails.
    """
    try:
        from app.main import app
        from app.core.database import get_db
    except ImportError as e:
        pytest.skip(f"Cannot import app modules: {e}. Ensure app.main and app.core.database are accessible.")

    def override_get_db():
        # test_db is already an AsyncSession yielded by the fixture
        # NOTE(review): test_db was created by an async fixture, while
        # TestClient drives the app from its own thread/event loop; using
        # this AsyncSession across loops may fail — TODO confirm, and consider
        # an httpx.AsyncClient with ASGITransport instead.
        yield test_db

    app.dependency_overrides[get_db] = override_get_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        # Clean up overrides
        app.dependency_overrides.clear()
|
|
|
|
@pytest.fixture(scope="function")
async def test_user(test_db):
    """Create a test user (existing@bakery.es) and return it.

    The payload is built through the ``UserRegistration`` schema and then
    handed to ``AuthService.create_user`` together with the ``test_db``
    session. The requesting test is skipped when either class cannot be
    imported. Only the imports are guarded, so an ``ImportError`` raised
    while the user is being created surfaces as a failure rather than a skip.
    """
    try:
        from app.services.auth_service import AuthService
        from app.schemas.auth import UserRegistration
    except ImportError:
        pytest.skip("AuthService not available")

    user_data = UserRegistration(
        email="existing@bakery.es",
        password="TestPassword123",
        full_name="Existing User",
    )

    return await AuthService.create_user(
        email=user_data.email,
        password=user_data.password,
        full_name=user_data.full_name,
        db=test_db,
    )
|
|
|
|
@pytest.fixture(scope="function")
async def test_redis_client():
    """Yield an ``AsyncMock`` standing in for a Redis client.

    The mock is spec'd against ``redis.Redis``, so touching an attribute the
    real client does not have raises ``AttributeError``. After the test, the
    mock's ``close()`` is awaited.
    """
    fake_client = AsyncMock(spec=redis.Redis)
    yield fake_client
    await fake_client.close()
|
|
|
|
# ================================================================
# TEST HELPERS
# ================================================================
|
|
|
|
import uuid # Moved from test_auth_comprehensive.py as it's a shared helper
|
|
|
|
def generate_random_user_data(prefix="test"):
    """Generate unique user registration data for testing.

    Each call takes the first 8 hex characters of a fresh ``uuid4`` and
    embeds them in every field, so repeated calls yield distinct users.

    Args:
        prefix: Prefix for the local part of the generated email address.

    Returns:
        dict with keys ``email``, ``password`` and ``full_name``.
    """
    unique_id = uuid.uuid4().hex[:8]
    return {
        "email": f"{prefix}_{unique_id}@bakery.es",
        # The hex id can consist solely of a-f, so a fixed digit is appended
        # to keep the password's character classes stable, in case the
        # password policy requires a digit.
        "password": f"StrongPwd{unique_id}1!",
        "full_name": f"Test User {unique_id}",
    }
|
|
|
|
# ================================================================
# PYTEST HOOKS
# ================================================================
|
|
|
|
def pytest_addoption(parser):
    """Register one boolean ``--<category>`` command-line flag per test category.

    Every flag is ``store_true`` and defaults to ``False``.
    """
    category_flags = (
        ("--integration", "run integration tests"),
        ("--api", "run API tests"),
        ("--security", "run security tests"),
        ("--performance", "run performance tests"),
        ("--slow", "run slow tests"),
        ("--auth", "run authentication tests"),
    )
    for flag, description in category_flags:
        parser.addoption(flag, action="store_true", default=False, help=description)
|
|
|
|
def pytest_configure(config):
    """Declare the custom category markers in the ``markers`` ini setting."""
    marker_descriptions = [
        "unit: marks tests as unit tests",
        "integration: marks tests as integration tests",
        "api: marks tests as API tests",
        "security: marks tests as security tests",
        "performance: marks tests as performance tests",
        "slow: marks tests as slow running",
        "auth: marks tests as authentication tests",
    ]
    for description in marker_descriptions:
        config.addinivalue_line("markers", description)
|
|
|
|
# Category markers that can be selected with a matching --<name> CLI flag.
_FILTER_MARKERS = ("integration", "api", "security", "performance", "slow", "auth")


def _apply_auto_markers(item):
    """Tag *item* with category markers inferred from its name, class and file."""
    test_name = item.name.lower()
    # Not every collected item is a Python test function; some may lack .cls.
    test_cls = getattr(item, "cls", None)
    cls_name = test_cls.__name__ if test_cls is not None else ""
    location = getattr(item, "path", None)
    if location is None:
        location = getattr(item, "fspath", "")
    location = str(location)

    if "test_api" in test_name or "API" in cls_name:
        item.add_marker("api")
    if "test_security" in test_name or "Security" in cls_name:
        item.add_marker("security")
    if "test_performance" in test_name or "Performance" in cls_name:
        item.add_marker("performance")
        item.add_marker("slow")
    # Authentication flows are integration tests.
    if ("integration" in test_name or "Integration" in cls_name
            or "flow" in test_name or "Flow" in cls_name):
        item.add_marker("integration")
    # Mark all tests in test_auth_comprehensive.py with 'auth'
    if "test_auth_comprehensive" in location:
        item.add_marker("auth")


def pytest_collection_modifyitems(config, items):
    """Auto-mark collected tests by category, then apply CLI category filters.

    Marking: each item gets ``api``/``security``/``performance`` + ``slow``/
    ``integration``/``auth`` markers based on substrings of its lower-cased
    test name, its class name, and its file path.

    Filtering: when none of ``--integration``, ``--api``, ``--security``,
    ``--performance``, ``--slow``, ``--auth`` is given, every collected test
    runs. When at least one is given, only tests carrying at least one of the
    selected markers are kept (markers are compared by name); the others are
    removed from *items* in place and reported through the
    ``pytest_deselected`` hook.
    """
    for item in items:
        _apply_auto_markers(item)

    selected = {name for name in _FILTER_MARKERS if config.getoption(f"--{name}")}
    if not selected:
        return  # No specific filter applied, run all collected tests

    kept = []
    deselected = []
    for item in items:
        # iter_markers() yields Mark objects; compare by name, not by object.
        item_markers = {mark.name for mark in item.iter_markers()}
        if item_markers.isdisjoint(selected):
            deselected.append(item)
        else:
            kept.append(item)

    if deselected:
        config.hook.pytest_deselected(items=deselected)
        items[:] = kept  # pytest reads the same list object back