234 lines
7.3 KiB
Python
234 lines
7.3 KiB
Python
|
|
"""
|
||
|
|
Distributed Locking Mechanisms
|
||
|
|
Prevents concurrent training jobs for the same product
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import time
|
||
|
|
from typing import Optional
|
||
|
|
import logging
|
||
|
|
from contextlib import asynccontextmanager
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
from sqlalchemy import text
|
||
|
|
from datetime import datetime, timezone, timedelta
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class LockAcquisitionError(Exception):
    """Signals that a distributed lock could not be obtained.

    The lock classes in this module raise it from ``acquire`` when the lock
    is still unavailable once the configured timeout has elapsed.
    """
|
||
|
|
|
||
|
|
|
||
|
|
class DatabaseLock:
    """
    Database-based distributed lock using PostgreSQL advisory locks.

    The lock name is mapped to a stable integer key which is passed to
    ``pg_try_advisory_lock`` / ``pg_advisory_unlock`` through the session
    handed to :meth:`acquire`, so every service instance that uses the same
    name contends for the same key.
    """

    # Seconds to sleep between two non-blocking acquisition attempts.
    _RETRY_INTERVAL = 0.1

    def __init__(self, lock_name: str, timeout: float = 30.0):
        """
        Initialize database lock.

        Args:
            lock_name: Unique identifier for the lock
            timeout: Maximum seconds to wait for lock acquisition
        """
        self.lock_name = lock_name
        self.timeout = timeout
        self.lock_id = self._hash_lock_name(lock_name)

    def _hash_lock_name(self, name: str) -> int:
        """
        Convert lock name to an integer ID for a PostgreSQL advisory lock.

        The built-in ``hash()`` must not be used here: for ``str`` it is
        salted per interpreter process, so two service instances would derive
        different IDs for the same name and never exclude each other. A
        cryptographic digest gives the same ID in every process.

        Args:
            name: Lock name to convert

        Returns:
            Non-negative integer below 2**31 derived from the name
        """
        import hashlib

        digest = hashlib.sha256(name.encode("utf-8")).digest()
        # Keep the historical key range: a non-negative 31-bit integer.
        return int.from_bytes(digest[:4], "big") % (2**31)

    @asynccontextmanager
    async def acquire(self, session: AsyncSession):
        """
        Acquire distributed lock as async context manager.

        Polls ``pg_try_advisory_lock`` until it succeeds or the timeout has
        elapsed; at least one attempt is always made, even with a zero
        timeout. The lock is released with ``pg_advisory_unlock`` when the
        ``async with`` body exits, whether normally or by exception.

        NOTE(review): PostgreSQL advisory locks of this kind belong to the
        database connection, so the unlock only works if ``session`` is still
        on the connection that took the lock - confirm callers do not
        commit/close the session inside the ``async with`` body.

        Args:
            session: Database session for lock operations

        Raises:
            LockAcquisitionError: If lock cannot be acquired within timeout
        """
        acquired = False
        # Monotonic clock: the deadline is unaffected by wall-clock changes.
        deadline = time.monotonic() + self.timeout

        try:
            while True:
                # Non-blocking attempt; returns true only if we got the lock.
                result = await session.execute(
                    text("SELECT pg_try_advisory_lock(:lock_id)"),
                    {"lock_id": self.lock_id},
                )
                acquired = bool(result.scalar())

                if acquired:
                    logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
                    break

                if time.monotonic() >= deadline:
                    break

                # Wait a bit before retrying
                await asyncio.sleep(self._RETRY_INTERVAL)

            if not acquired:
                raise LockAcquisitionError(
                    f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
                )

            yield

        finally:
            if acquired:
                result = await session.execute(
                    text("SELECT pg_advisory_unlock(:lock_id)"),
                    {"lock_id": self.lock_id},
                )
                # pg_advisory_unlock reports false when this connection does
                # not hold the lock; surface that instead of claiming success.
                if result.scalar():
                    logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")
                else:
                    logger.warning(
                        f"Lock {self.lock_name} (id={self.lock_id}) was not held "
                        f"by this connection at release"
                    )
|
||
|
|
|
||
|
|
|
||
|
|
class SimpleDatabaseLock:
    """
    Simple table-based distributed lock.

    Alternative to advisory locks, uses a dedicated ``distributed_locks``
    table: holding the lock means owning the row whose primary key is the
    lock name. Rows carry an expiry time so that a lock left behind by a
    crashed holder is eventually removed.
    """

    # Seconds to sleep between two acquisition attempts.
    _RETRY_INTERVAL = 0.5

    def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
        """
        Initialize simple database lock.

        Args:
            lock_name: Unique identifier for the lock
            timeout: Maximum seconds to wait for lock acquisition
            ttl: Time-to-live for stale lock cleanup (seconds)
        """
        self.lock_name = lock_name
        self.timeout = timeout
        self.ttl = ttl

    async def _ensure_lock_table(self, session: AsyncSession):
        """Create the locks table if it does not exist and commit."""
        create_table_sql = """
            CREATE TABLE IF NOT EXISTS distributed_locks (
                lock_name VARCHAR(255) PRIMARY KEY,
                acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
                acquired_by VARCHAR(255),
                expires_at TIMESTAMP WITH TIME ZONE NOT NULL
            )
        """
        await session.execute(text(create_table_sql))
        await session.commit()

    async def _cleanup_stale_locks(self, session: AsyncSession):
        """Delete every lock row whose expiry time has passed and commit."""
        cleanup_sql = """
            DELETE FROM distributed_locks
            WHERE expires_at < :now
        """
        await session.execute(
            text(cleanup_sql),
            {"now": datetime.now(timezone.utc)},
        )
        await session.commit()

    @asynccontextmanager
    async def acquire(self, session: AsyncSession, owner: str = "training-service"):
        """
        Acquire simple database lock.

        Repeatedly purges expired rows and tries to insert this lock's row
        until the insert wins or the timeout has elapsed; at least one attempt
        is always made. On exit from the ``async with`` body the row is
        deleted again, but only if it is still the row written by this
        acquisition.

        Note that this method commits (and, after a failed attempt, rolls
        back) on the given session.

        Args:
            session: Database session
            owner: Identifier for lock owner

        Raises:
            LockAcquisitionError: If lock cannot be acquired
        """
        await self._ensure_lock_table(session)

        # Timestamp written by the winning insert; None while not holding.
        acquired_at = None
        # Monotonic clock: the deadline is unaffected by wall-clock changes.
        deadline = time.monotonic() + self.timeout

        insert_sql = """
            INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
            VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
            ON CONFLICT (lock_name) DO NOTHING
            RETURNING lock_name
        """

        try:
            while True:
                # Purge before every attempt so that a lock which expires
                # while we are waiting can still be taken over.
                await self._cleanup_stale_locks(session)

                now = datetime.now(timezone.utc)
                expires_at = now + timedelta(seconds=self.ttl)

                try:
                    result = await session.execute(
                        text(insert_sql),
                        {
                            "lock_name": self.lock_name,
                            "acquired_at": now,
                            "acquired_by": owner,
                            "expires_at": expires_at,
                        },
                    )
                    # RETURNING yields a row only when our insert won; read
                    # it before commit instead of relying on rowcount.
                    inserted = result.first() is not None
                    await session.commit()
                except Exception as e:
                    # Best-effort: a failed attempt is retried until timeout.
                    logger.debug(f"Lock acquisition attempt failed: {e}")
                    await session.rollback()
                else:
                    if inserted:
                        acquired_at = now
                        logger.info(f"Acquired simple lock: {self.lock_name}")
                        break

                if time.monotonic() >= deadline:
                    break

                # Wait before retrying
                await asyncio.sleep(self._RETRY_INTERVAL)

            if acquired_at is None:
                raise LockAcquisitionError(
                    f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
                )

            yield

        finally:
            if acquired_at is not None:
                # Match on our own acquired_at: if the TTL ran out and another
                # holder re-created the row, their lock must not be deleted.
                delete_sql = """
                    DELETE FROM distributed_locks
                    WHERE lock_name = :lock_name
                      AND acquired_at = :acquired_at
                """
                result = await session.execute(
                    text(delete_sql),
                    {"lock_name": self.lock_name, "acquired_at": acquired_at},
                )
                await session.commit()
                if result.rowcount:
                    logger.info(f"Released simple lock: {self.lock_name}")
                else:
                    logger.warning(
                        f"Simple lock '{self.lock_name}' was no longer held at "
                        f"release (ttl={self.ttl}s may have expired)"
                    )
|
||
|
|
|
||
|
|
|
||
|
|
def get_training_lock(
    tenant_id: str, product_id: str, use_advisory: bool = True
) -> "DatabaseLock | SimpleDatabaseLock":
    """
    Get distributed lock for training a specific product.

    The lock name is ``training:<tenant_id>:<product_id>``, so each
    tenant/product pair gets its own lock. Both variants are created with a
    60 second acquisition timeout; the table-based one additionally gets a
    600 second TTL.

    Args:
        tenant_id: Tenant identifier
        product_id: Product identifier
        use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)

    Returns:
        A ``DatabaseLock`` when ``use_advisory`` is true, otherwise a
        ``SimpleDatabaseLock``
    """
    lock_name = f"training:{tenant_id}:{product_id}"

    if use_advisory:
        return DatabaseLock(lock_name, timeout=60.0)
    return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)
|