Add pytest tests to auth 1

This commit is contained in:
Urtzi Alfaro
2025-07-20 14:10:10 +02:00
parent 351f673318
commit 76a06d5c9e
2 changed files with 1089 additions and 1551 deletions

View File

@@ -1,11 +1,3 @@
# ================================================================
# services/auth/tests/conftest.py
# Pytest configuration and shared fixtures for auth service tests
# ================================================================
"""
Shared test configuration and fixtures for authentication service tests
"""
import pytest import pytest
import asyncio import asyncio
import os import os
@@ -43,24 +35,14 @@ async def test_engine():
future=True, future=True,
pool_pre_ping=True pool_pre_ping=True
) )
# Import Base and metadata after engine creation to avoid circular imports
try: from shared.database.base import Base
# Import models and base here to avoid import issues async with engine.begin() as conn:
from shared.database.base import Base await conn.run_sync(Base.metadata.create_all)
# Create all tables yield engine
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()
yield engine
# Cleanup
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
except ImportError:
# If shared.database.base is not available, create a mock
yield engine
finally:
await engine.dispose()
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def test_db(test_engine) -> AsyncGenerator[AsyncSession, None]: async def test_db(test_engine) -> AsyncGenerator[AsyncSession, None]:
@@ -73,6 +55,7 @@ async def test_db(test_engine) -> AsyncGenerator[AsyncSession, None]:
async with async_session() as session: async with async_session() as session:
yield session yield session
await session.rollback() # Rollback after each test to ensure a clean state
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def client(test_db): def client(test_db):
@@ -82,7 +65,8 @@ def client(test_db):
from app.core.database import get_db from app.core.database import get_db
def override_get_db(): def override_get_db():
return test_db # test_db is already an AsyncSession yielded by the fixture
yield test_db
app.dependency_overrides[get_db] = override_get_db app.dependency_overrides[get_db] = override_get_db
@@ -92,396 +76,81 @@ def client(test_db):
# Clean up overrides # Clean up overrides
app.dependency_overrides.clear() app.dependency_overrides.clear()
except ImportError as e: except ImportError as e:
pytest.skip(f"Cannot import app modules: {e}") pytest.skip(f"Cannot import app modules: {e}. Ensure app.main and app.core.database are accessible.")
# ================================================================ @pytest.fixture(scope="function")
# MOCK FIXTURES async def test_user(test_db):
# ================================================================
@pytest.fixture
def mock_redis():
"""Mock Redis client for testing rate limiting and session management"""
redis_mock = AsyncMock()
# Default return values for common operations
redis_mock.get.return_value = None
redis_mock.incr.return_value = 1
redis_mock.expire.return_value = True
redis_mock.delete.return_value = True
redis_mock.setex.return_value = True
redis_mock.exists.return_value = False
return redis_mock
@pytest.fixture
def mock_rabbitmq():
"""Mock RabbitMQ for testing event publishing"""
rabbitmq_mock = AsyncMock()
# Mock publisher methods
rabbitmq_mock.publish.return_value = True
rabbitmq_mock.connect.return_value = True
rabbitmq_mock.disconnect.return_value = True
return rabbitmq_mock
@pytest.fixture
def mock_external_services():
"""Mock external service calls (tenant service, etc.)"""
with patch('httpx.AsyncClient') as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json.return_value = {"tenants": []}
mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value.post.return_value = mock_response
yield mock_client
# ================================================================
# DATA FIXTURES
# ================================================================
@pytest.fixture
def valid_user_data():
"""Valid user registration data"""
return {
"email": "test@bakery.es",
"password": "TestPassword123",
"full_name": "Test User"
}
@pytest.fixture
def valid_user_data_list():
"""List of valid user data for multiple users"""
return [
{
"email": f"test{i}@bakery.es",
"password": "TestPassword123",
"full_name": f"Test User {i}"
}
for i in range(1, 6)
]
@pytest.fixture
def weak_password_data():
"""User data with various weak passwords"""
return [
{"email": "weak1@bakery.es", "password": "123", "full_name": "Weak 1"},
{"email": "weak2@bakery.es", "password": "password", "full_name": "Weak 2"},
{"email": "weak3@bakery.es", "password": "PASSWORD123", "full_name": "Weak 3"},
{"email": "weak4@bakery.es", "password": "testpassword", "full_name": "Weak 4"},
]
@pytest.fixture
def invalid_email_data():
"""User data with invalid email formats"""
return [
{"email": "invalid", "password": "TestPassword123", "full_name": "Invalid 1"},
{"email": "@bakery.es", "password": "TestPassword123", "full_name": "Invalid 2"},
{"email": "test@", "password": "TestPassword123", "full_name": "Invalid 3"},
{"email": "test..test@bakery.es", "password": "TestPassword123", "full_name": "Invalid 4"},
]
# ================================================================
# USER FIXTURES
# ================================================================
@pytest.fixture
async def test_user(test_db, valid_user_data):
"""Create a test user in the database""" """Create a test user in the database"""
try: try:
from app.services.auth_service import AuthService from app.services.auth_service import AuthService
from app.schemas.auth import UserRegistration
user_data = UserRegistration(
email="existing@bakery.es",
password="TestPassword123",
full_name="Existing User"
)
user = await AuthService.create_user( user = await AuthService.create_user(
email=valid_user_data["email"], email=user_data.email,
password=valid_user_data["password"], password=user_data.password,
full_name=valid_user_data["full_name"], full_name=user_data.full_name,
db=test_db db=test_db
) )
return user return user
except ImportError: except ImportError:
pytest.skip("AuthService not available") pytest.skip("AuthService not available")
@pytest.fixture @pytest.fixture(scope="function")
async def test_users(test_db, valid_user_data_list): async def test_redis_client():
"""Create multiple test users in the database""" """Create a test Redis client"""
try: # Use a mock Redis client for testing
from app.services.auth_service import AuthService mock_redis = AsyncMock(spec=redis.Redis)
yield mock_redis
await mock_redis.close()
users = [] # ================================================================\
for user_data in valid_user_data_list: # TEST HELPERS
user = await AuthService.create_user( # ================================================================\
email=user_data["email"],
password=user_data["password"],
full_name=user_data["full_name"],
db=test_db
)
users.append(user)
return users
except ImportError:
pytest.skip("AuthService not available")
@pytest.fixture import uuid # Moved from test_auth_comprehensive.py as it's a shared helper
async def authenticated_user(client, valid_user_data):
"""Create an authenticated user and return user info, tokens, and headers"""
# Register user
register_response = client.post("/auth/register", json=valid_user_data)
assert register_response.status_code == 200
# Login user
login_data = {
"email": valid_user_data["email"],
"password": valid_user_data["password"]
}
login_response = client.post("/auth/login", json=login_data)
assert login_response.status_code == 200
token_data = login_response.json()
def generate_random_user_data(prefix="test"):
"""Generates unique user data for testing."""
unique_id = uuid.uuid4().hex[:8]
return { return {
"user": register_response.json(), "email": f"{prefix}_{unique_id}@bakery.es",
"tokens": token_data, "password": f"StrongPwd{unique_id}!",
"access_token": token_data["access_token"], "full_name": f"Test User {unique_id}"
"refresh_token": token_data["refresh_token"],
"headers": {"Authorization": f"Bearer {token_data['access_token']}"}
} }
# ================================================================ # ================================================================\
# CONFIGURATION FIXTURES # PYTEST HOOKS
# ================================================================ # ================================================================\
@pytest.fixture def pytest_addoption(parser):
def test_settings(): """Add custom options to pytest"""
"""Test-specific settings override""" parser.addoption(
try: "--integration", action="store_true", default=False, help="run integration tests"
from app.core.config import settings )
parser.addoption(
original_settings = {} "--api", action="store_true", default=False, help="run API tests"
)
# Store original values parser.addoption(
test_overrides = { "--security", action="store_true", default=False, help="run security tests"
'JWT_ACCESS_TOKEN_EXPIRE_MINUTES': 30, )
'JWT_REFRESH_TOKEN_EXPIRE_DAYS': 7, parser.addoption(
'PASSWORD_MIN_LENGTH': 8, "--performance", action="store_true", default=False, help="run performance tests"
'PASSWORD_REQUIRE_UPPERCASE': True, )
'PASSWORD_REQUIRE_LOWERCASE': True, parser.addoption(
'PASSWORD_REQUIRE_NUMBERS': True, "--slow", action="store_true", default=False, help="run slow tests"
'PASSWORD_REQUIRE_SYMBOLS': False, )
'MAX_LOGIN_ATTEMPTS': 5, parser.addoption(
'LOCKOUT_DURATION_MINUTES': 30, "--auth", action="store_true", default=False, help="run authentication tests"
'BCRYPT_ROUNDS': 4, # Lower for faster tests )
}
for key, value in test_overrides.items():
if hasattr(settings, key):
original_settings[key] = getattr(settings, key)
setattr(settings, key, value)
yield settings
# Restore original values
for key, value in original_settings.items():
setattr(settings, key, value)
except ImportError:
pytest.skip("Settings not available")
# ================================================================
# PATCHING FIXTURES
# ================================================================
@pytest.fixture
def patch_redis(mock_redis):
"""Patch Redis client for all tests"""
with patch('app.core.security.redis_client', mock_redis):
yield mock_redis
@pytest.fixture
def patch_messaging(mock_rabbitmq):
"""Patch messaging system for all tests"""
with patch('app.services.messaging.publisher', mock_rabbitmq):
yield mock_rabbitmq
@pytest.fixture
def patch_external_apis(mock_external_services):
"""Patch external API calls"""
yield mock_external_services
# ================================================================
# UTILITY FIXTURES
# ================================================================
@pytest.fixture
def auth_headers():
"""Factory for creating authorization headers"""
def _create_headers(token):
return {"Authorization": f"Bearer {token}"}
return _create_headers
@pytest.fixture
def password_generator():
"""Generate passwords with different characteristics"""
def _generate(
length=12,
include_upper=True,
include_lower=True,
include_numbers=True,
include_symbols=False
):
import random
import string
chars = ""
password = ""
if include_lower:
chars += string.ascii_lowercase
password += random.choice(string.ascii_lowercase)
if include_upper:
chars += string.ascii_uppercase
password += random.choice(string.ascii_uppercase)
if include_numbers:
chars += string.digits
password += random.choice(string.digits)
if include_symbols:
chars += "!@#$%^&*"
password += random.choice("!@#$%^&*")
# Fill remaining length
remaining = length - len(password)
if remaining > 0:
password += ''.join(random.choice(chars) for _ in range(remaining))
# Shuffle the password
password_list = list(password)
random.shuffle(password_list)
return ''.join(password_list)
return _generate
# ================================================================
# PERFORMANCE TESTING FIXTURES
# ================================================================
@pytest.fixture
def performance_timer():
"""Timer utility for performance testing"""
import time
class Timer:
def __init__(self):
self.start_time = None
self.end_time = None
def start(self):
self.start_time = time.time()
def stop(self):
self.end_time = time.time()
@property
def elapsed(self):
if self.start_time and self.end_time:
return self.end_time - self.start_time
return None
def __enter__(self):
self.start()
return self
def __exit__(self, *args):
self.stop()
return Timer
# ================================================================
# DATABASE UTILITY FIXTURES
# ================================================================
@pytest.fixture
async def db_utils(test_db):
"""Database utility functions for testing"""
class DBUtils:
def __init__(self, db):
self.db = db
async def count_users(self):
try:
from sqlalchemy import select, func
from app.models.users import User
result = await self.db.execute(select(func.count(User.id)))
return result.scalar()
except ImportError:
return 0
async def get_user_by_email(self, email):
try:
from sqlalchemy import select
from app.models.users import User
result = await self.db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none()
except ImportError:
return None
async def count_refresh_tokens(self):
try:
from sqlalchemy import select, func
from app.models.users import RefreshToken
result = await self.db.execute(select(func.count(RefreshToken.id)))
return result.scalar()
except ImportError:
return 0
async def clear_all_data(self):
try:
from app.models.users import User, RefreshToken
await self.db.execute(RefreshToken.__table__.delete())
await self.db.execute(User.__table__.delete())
await self.db.commit()
except ImportError:
pass
return DBUtils(test_db)
# ================================================================
# LOGGING FIXTURES
# ================================================================
@pytest.fixture
def capture_logs():
"""Capture logs for testing"""
import logging
from io import StringIO
log_capture = StringIO()
handler = logging.StreamHandler(log_capture)
handler.setLevel(logging.DEBUG)
# Add handler to auth service loggers
loggers = [
logging.getLogger('app.services.auth_service'),
logging.getLogger('app.core.security'),
logging.getLogger('app.api.auth'),
]
for logger in loggers:
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
yield log_capture
# Clean up
for logger in loggers:
logger.removeHandler(handler)
# ================================================================
# TEST MARKERS AND CONFIGURATION
# ================================================================
def pytest_configure(config): def pytest_configure(config):
"""Configure pytest with custom markers""" """Configure pytest markers"""
config.addinivalue_line( config.addinivalue_line(
"markers", "unit: marks tests as unit tests" "markers", "unit: marks tests as unit tests"
) )
@@ -522,7 +191,35 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(pytest.mark.integration) item.add_marker(pytest.mark.integration)
if "Flow" in str(item.cls) or "flow" in item.name.lower(): if "Flow" in str(item.cls) or "flow" in item.name.lower():
item.add_marker(pytest.mark.integration) item.add_marker(pytest.mark.integration) # Authentication flows are integration tests
if "auth" in item.name.lower() or "Auth" in str(item.cls): # Mark all tests in test_auth_comprehensive.py with 'auth'
if "test_auth_comprehensive" in str(item.fspath):
item.add_marker(pytest.mark.auth) item.add_marker(pytest.mark.auth)
# Filtering logic for command line options
if not any([config.getoption("--integration"), config.getoption("--api"),
config.getoption("--security"), config.getoption("--performance"),
config.getoption("--slow"), config.getoption("--auth")]):
return # No specific filter applied, run all collected tests
skip_markers = []
if not config.getoption("--integration"):
skip_markers.append(pytest.mark.integration)
if not config.getoption("--api"):
skip_markers.append(pytest.mark.api)
if not config.getoption("--security"):
skip_markers.append(pytest.mark.security)
if not config.getoption("--performance"):
skip_markers.append(pytest.mark.performance)
if not config.getoption("--slow"):
skip_markers.append(pytest.mark.slow)
if not config.getoption("--auth"):
skip_markers.append(pytest.mark.auth)
# Remove tests with any of the skip markers
if skip_markers:
for item in list(items): # Iterate over a copy to allow modification
if any(marker in item.iter_markers() for marker in skip_markers):
items.remove(item)
item.add_marker(pytest.mark.skip(reason="filtered by command line option"))

File diff suppressed because it is too large Load Diff