# bakery-ia/services/training/app/utils/distributed_lock.py

"""
Distributed Locking Mechanisms
Prevents concurrent training jobs for the same product.

HORIZONTAL SCALING FIX:
- Uses SHA256 for stable hash across all Python processes/pods
- Python's built-in hash() varies between processes due to hash randomization (Python 3.3+)
- This ensures all pods compute the same lock ID for the same lock name
"""
import asyncio
import time
import hashlib
from typing import Optional
import logging
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timezone, timedelta
# Module-level logger named after this module; used for lock acquire/release messages.
logger: logging.Logger = logging.getLogger(name=__name__)
class LockAcquisitionError(Exception):
    """Signals that a distributed lock could not be obtained within its timeout."""
class DatabaseLock:
    """
    Database-based distributed lock using PostgreSQL advisory locks.

    Works across multiple service instances because every instance derives the
    same integer lock ID from the same lock name (see ``_hash_lock_name``).

    NOTE(review): ``pg_try_advisory_lock`` takes a *session-level* advisory lock,
    which PostgreSQL ties to the database connection that executed it, and
    ``pg_advisory_unlock`` only releases it on that same connection. If the
    ``AsyncSession`` returns its connection to the pool between acquire and
    release (e.g. when the guarded code commits or rolls back), the unlock may
    run on a different connection and return false. The lock then stays held
    until the original connection is closed. This class logs that case; confirm
    the caller keeps one connection for the lock's lifetime.
    """

    def __init__(self, lock_name: str, timeout: float = 30.0):
        """
        Initialize database lock.

        Args:
            lock_name: Unique identifier for the lock
            timeout: Maximum seconds to wait for lock acquisition. At least one
                acquisition attempt is always made, even if this is <= 0.
        """
        self.lock_name = lock_name
        self.timeout = timeout
        self.lock_id = self._hash_lock_name(lock_name)

    def _hash_lock_name(self, name: str) -> int:
        """
        Convert lock name to integer ID for PostgreSQL advisory lock.

        CRITICAL: Uses SHA256 for stable hash across all Python processes/pods.
        Python's built-in hash() varies between processes due to hash randomization
        (PYTHONHASHSEED, enabled by default since Python 3.3), which would cause
        different pods to compute different lock IDs for the same lock name,
        defeating the purpose of distributed locking.
        """
        # Use SHA256 for stable, cross-process hash
        hash_bytes = hashlib.sha256(name.encode('utf-8')).digest()
        # Take first 4 bytes and convert to positive 31-bit integer
        # (PostgreSQL advisory locks use bigint, but we use 31-bit for safety)
        return int.from_bytes(hash_bytes[:4], 'big') % (2**31)

    async def _try_lock(self, session: AsyncSession) -> bool:
        """Make a single non-blocking attempt to take the advisory lock."""
        result = await session.execute(
            text("SELECT pg_try_advisory_lock(:lock_id)"),
            {"lock_id": self.lock_id}
        )
        return bool(result.scalar())

    async def _acquire_until_timeout(self, session: AsyncSession) -> None:
        """
        Poll for the lock until it is taken or ``self.timeout`` elapses.

        Raises:
            LockAcquisitionError: If the lock is still busy at the deadline.
        """
        # Monotonic clock: a wall-clock adjustment must not shorten or extend the wait.
        deadline = time.monotonic() + self.timeout
        while True:
            if await self._try_lock(session):
                logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
                return
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise LockAcquisitionError(
                    f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
                )
            # Wait a bit before retrying, without sleeping past the deadline
            await asyncio.sleep(min(0.1, remaining))

    async def _release(self, session: AsyncSession, *, suppress_errors: bool = False) -> None:
        """
        Release the advisory lock.

        Args:
            session: Session the lock was acquired on.
            suppress_errors: Log (instead of raise) a failing unlock statement.
                Used while another exception is already propagating, so the
                unlock failure does not replace the original error.
        """
        try:
            result = await session.execute(
                text("SELECT pg_advisory_unlock(:lock_id)"),
                {"lock_id": self.lock_id}
            )
        except Exception:
            if not suppress_errors:
                raise
            logger.exception(
                f"Failed to release lock: {self.lock_name} (id={self.lock_id}); "
                "it may stay held until its database connection is closed"
            )
            return
        if result.scalar():
            logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")
        else:
            # pg_advisory_unlock returns false when the current connection does
            # not hold the lock, i.e. it is probably still held on another one.
            logger.warning(
                f"Lock {self.lock_name} (id={self.lock_id}) was not held by this "
                "connection at release time; it may still be held elsewhere"
            )

    @asynccontextmanager
    async def acquire(self, session: AsyncSession):
        """
        Acquire distributed lock as async context manager.

        Args:
            session: Database session for lock operations

        Raises:
            LockAcquisitionError: If lock cannot be acquired within timeout
        """
        await self._acquire_until_timeout(session)
        try:
            yield
        except BaseException:
            # The guarded block failed; release best-effort so that its
            # exception (not a secondary unlock error) is what propagates.
            await self._release(session, suppress_errors=True)
            raise
        else:
            await self._release(session)
class SimpleDatabaseLock:
    """
    Simple table-based distributed lock.

    Alternative to advisory locks, uses a dedicated ``distributed_locks`` table.
    A lock is one row keyed by ``lock_name``. A row whose ``expires_at`` has
    passed is stale: it is deleted by ``_cleanup_stale_locks`` and can be taken
    over atomically by another acquirer. Release only deletes the row this
    instance wrote (matched on ``acquired_at`` and ``acquired_by``). An owner
    whose lock expired therefore cannot delete a lock someone else now holds.
    """

    def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
        """
        Initialize simple database lock.

        Args:
            lock_name: Unique identifier for the lock
            timeout: Maximum seconds to wait for lock acquisition. At least one
                acquisition attempt is always made, even if this is <= 0.
            ttl: Time-to-live for stale lock cleanup (seconds)
        """
        self.lock_name = lock_name
        self.timeout = timeout
        self.ttl = ttl

    async def _ensure_lock_table(self, session: AsyncSession):
        """Ensure locks table exists"""
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS distributed_locks (
            lock_name VARCHAR(255) PRIMARY KEY,
            acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
            acquired_by VARCHAR(255),
            expires_at TIMESTAMP WITH TIME ZONE NOT NULL
        )
        """
        await session.execute(text(create_table_sql))
        await session.commit()

    async def _cleanup_stale_locks(self, session: AsyncSession):
        """Remove expired locks"""
        cleanup_sql = """
        DELETE FROM distributed_locks
        WHERE expires_at < :now
        """
        await session.execute(
            text(cleanup_sql),
            {"now": datetime.now(timezone.utc)}
        )
        await session.commit()

    async def _try_insert(self, session: AsyncSession, owner: str, now: datetime) -> bool:
        """
        Make one attempt to write the lock row and commit it.

        Inserts a new row, or atomically takes over an existing row whose
        ``expires_at`` is already in the past. Returns True iff a row was
        written, i.e. the lock is now held by ``owner``.
        """
        # The DO UPDATE ... WHERE only fires for an expired row. RETURNING
        # yields a row only when it was inserted or taken over; a live lock
        # held by someone else yields nothing.
        upsert_sql = """
        INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
        VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
        ON CONFLICT (lock_name) DO UPDATE
        SET acquired_at = EXCLUDED.acquired_at,
            acquired_by = EXCLUDED.acquired_by,
            expires_at = EXCLUDED.expires_at
        WHERE distributed_locks.expires_at < EXCLUDED.acquired_at
        RETURNING lock_name
        """
        result = await session.execute(
            text(upsert_sql),
            {
                "lock_name": self.lock_name,
                "acquired_at": now,
                "acquired_by": owner,
                "expires_at": now + timedelta(seconds=self.ttl)
            }
        )
        # Read the RETURNING row before committing.
        row = result.first()
        await session.commit()
        return row is not None

    async def _acquire_until_timeout(self, session: AsyncSession, owner: str) -> datetime:
        """
        Retry ``_try_insert`` until it succeeds or ``self.timeout`` elapses.

        Returns:
            The ``acquired_at`` timestamp written to the lock row; it identifies
            this acquisition at release time.

        Raises:
            LockAcquisitionError: If the lock is still busy at the deadline.
                Chained to the last database error seen, if any.
        """
        # Monotonic clock: a wall-clock adjustment must not shorten or extend the wait.
        deadline = time.monotonic() + self.timeout
        last_error = None
        while True:
            now = datetime.now(timezone.utc)
            try:
                if await self._try_insert(session, owner, now):
                    logger.info(f"Acquired simple lock: {self.lock_name}")
                    return now
            except Exception as e:
                # Contention is handled by ON CONFLICT, so an exception here is
                # a genuine database problem: surface it rather than hide it.
                last_error = e
                logger.warning(f"Lock acquisition attempt failed: {e}")
                await session.rollback()
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise LockAcquisitionError(
                    f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
                ) from last_error
            # Wait before retrying, without sleeping past the deadline
            await asyncio.sleep(min(0.5, remaining))

    async def _delete_lock_row(self, session: AsyncSession, owner: str, acquired_at: datetime) -> int:
        """Delete this acquisition's lock row, commit, and return the deleted-row count."""
        delete_sql = """
        DELETE FROM distributed_locks
        WHERE lock_name = :lock_name
          AND acquired_at = :acquired_at
          AND acquired_by = :acquired_by
        """
        result = await session.execute(
            text(delete_sql),
            {
                "lock_name": self.lock_name,
                "acquired_at": acquired_at,
                "acquired_by": owner
            }
        )
        deleted = result.rowcount
        await session.commit()
        return deleted

    async def _release(
        self,
        session: AsyncSession,
        owner: str,
        acquired_at: datetime,
        *,
        suppress_errors: bool = False
    ) -> None:
        """
        Release this acquisition's lock row.

        Args:
            session: Session used for the release statements.
            owner: Owner string written at acquisition.
            acquired_at: Timestamp written at acquisition.
            suppress_errors: When True (an exception from the guarded block is
                already propagating), roll back and retry once on failure, then
                log instead of raising.
        """
        try:
            deleted = await self._delete_lock_row(session, owner, acquired_at)
        except Exception:
            if not suppress_errors:
                raise
            try:
                # The failed block may have left the transaction aborted. The
                # lock row itself was committed at acquisition, so rolling back
                # does not lose it.
                await session.rollback()
                deleted = await self._delete_lock_row(session, owner, acquired_at)
            except Exception:
                logger.exception(
                    f"Failed to release simple lock: {self.lock_name}; "
                    "it will become reclaimable once its TTL expires"
                )
                return
        if deleted:
            logger.info(f"Released simple lock: {self.lock_name}")
        else:
            logger.warning(
                f"Simple lock {self.lock_name} was no longer held at release time "
                "(it expired and may have been re-acquired by another owner)"
            )

    @asynccontextmanager
    async def acquire(self, session: AsyncSession, owner: str = "training-service"):
        """
        Acquire simple database lock.

        Args:
            session: Database session
            owner: Identifier for lock owner

        Raises:
            LockAcquisitionError: If lock cannot be acquired
        """
        await self._ensure_lock_table(session)
        await self._cleanup_stale_locks(session)
        acquired_at = await self._acquire_until_timeout(session, owner)
        try:
            yield
        except BaseException:
            # The guarded block failed; release best-effort so that its
            # exception (not a secondary release error) is what propagates.
            await self._release(session, owner, acquired_at, suppress_errors=True)
            raise
        else:
            await self._release(session, owner, acquired_at)
def get_training_lock(
    tenant_id: str,
    product_id: str,
    use_advisory: bool = True
) -> "DatabaseLock | SimpleDatabaseLock":
    """
    Get distributed lock for training a specific product.

    Both lock kinds expose ``acquire(session)`` as an async context manager and
    wait up to 60 seconds for the lock. The table-based lock uses a 600-second TTL.

    Args:
        tenant_id: Tenant identifier
        product_id: Product identifier
        use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)

    Returns:
        A ``DatabaseLock`` when ``use_advisory`` is true, otherwise a
        ``SimpleDatabaseLock``, named ``training:{tenant_id}:{product_id}``.
    """
    lock_name = f"training:{tenant_id}:{product_id}"
    if use_advisory:
        return DatabaseLock(lock_name, timeout=60.0)
    return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)