bakery-ia/services/forecasting/app/utils/distributed_lock.py
"""
Distributed Locking Mechanisms for Forecasting Service
Prevents concurrent forecast generation for the same product/date
"""
import asyncio
import hashlib
import logging
import time
from typing import Union
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timezone, timedelta
logger = logging.getLogger(__name__)
class LockAcquisitionError(Exception):
"""Raised when lock cannot be acquired"""
pass
class DatabaseLock:
"""
Database-based distributed lock using PostgreSQL advisory locks.
Works across multiple service instances.
"""
def __init__(self, lock_name: str, timeout: float = 30.0):
"""
Initialize database lock.
Args:
lock_name: Unique identifier for the lock
timeout: Maximum seconds to wait for lock acquisition
"""
self.lock_name = lock_name
self.timeout = timeout
self.lock_id = self._hash_lock_name(lock_name)
    def _hash_lock_name(self, name: str) -> int:
        """Convert lock name to a stable integer ID for PostgreSQL advisory locks"""
        # Python's built-in hash() is salted per process (PYTHONHASHSEED), so it
        # would differ between service instances and defeat cross-instance locking.
        # Use a stable digest folded into a positive 31-bit integer instead.
        digest = hashlib.sha256(name.encode("utf-8")).digest()
        return int.from_bytes(digest[:4], "big") % (2**31)
@asynccontextmanager
async def acquire(self, session: AsyncSession):
"""
Acquire distributed lock as async context manager.
Args:
session: Database session for lock operations
Raises:
LockAcquisitionError: If lock cannot be acquired within timeout
"""
acquired = False
start_time = time.time()
try:
# Try to acquire lock with timeout
while time.time() - start_time < self.timeout:
# Try non-blocking lock acquisition
result = await session.execute(
text("SELECT pg_try_advisory_lock(:lock_id)"),
{"lock_id": self.lock_id}
)
acquired = result.scalar()
if acquired:
logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
break
# Wait a bit before retrying
await asyncio.sleep(0.1)
if not acquired:
raise LockAcquisitionError(
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
)
yield
finally:
if acquired:
# Release lock
await session.execute(
text("SELECT pg_advisory_unlock(:lock_id)"),
{"lock_id": self.lock_id}
)
logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")
class SimpleDatabaseLock:
"""
Simple table-based distributed lock.
Alternative to advisory locks, uses a dedicated locks table.
"""
def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
"""
Initialize simple database lock.
Args:
lock_name: Unique identifier for the lock
timeout: Maximum seconds to wait for lock acquisition
ttl: Time-to-live for stale lock cleanup (seconds)
"""
self.lock_name = lock_name
self.timeout = timeout
self.ttl = ttl
async def _ensure_lock_table(self, session: AsyncSession):
"""Ensure locks table exists"""
create_table_sql = """
CREATE TABLE IF NOT EXISTS distributed_locks (
lock_name VARCHAR(255) PRIMARY KEY,
acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
acquired_by VARCHAR(255),
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
)
"""
await session.execute(text(create_table_sql))
await session.commit()
async def _cleanup_stale_locks(self, session: AsyncSession):
"""Remove expired locks"""
cleanup_sql = """
DELETE FROM distributed_locks
WHERE expires_at < :now
"""
await session.execute(
text(cleanup_sql),
{"now": datetime.now(timezone.utc)}
)
await session.commit()
@asynccontextmanager
async def acquire(self, session: AsyncSession, owner: str = "forecasting-service"):
"""
Acquire simple database lock.
Args:
session: Database session
owner: Identifier for lock owner
Raises:
LockAcquisitionError: If lock cannot be acquired
"""
await self._ensure_lock_table(session)
await self._cleanup_stale_locks(session)
acquired = False
start_time = time.time()
try:
# Try to acquire lock
while time.time() - start_time < self.timeout:
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=self.ttl)
try:
# Try to insert lock record
insert_sql = """
INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
ON CONFLICT (lock_name) DO NOTHING
RETURNING lock_name
"""
result = await session.execute(
text(insert_sql),
{
"lock_name": self.lock_name,
"acquired_at": now,
"acquired_by": owner,
"expires_at": expires_at
}
)
                    # RETURNING yields a row only when our INSERT actually won the
                    # race (no conflicting live lock), which is more portable than
                    # relying on rowcount for INSERT ... RETURNING.
                    inserted = result.scalar()
                    await session.commit()
                    if inserted is not None:
                        acquired = True
                        logger.info(f"Acquired simple lock: {self.lock_name}")
                        break
except Exception as e:
logger.debug(f"Lock acquisition attempt failed: {e}")
await session.rollback()
# Wait before retrying
await asyncio.sleep(0.5)
if not acquired:
raise LockAcquisitionError(
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
)
yield
finally:
if acquired:
# Release lock
delete_sql = """
DELETE FROM distributed_locks
WHERE lock_name = :lock_name
"""
await session.execute(
text(delete_sql),
{"lock_name": self.lock_name}
)
await session.commit()
logger.info(f"Released simple lock: {self.lock_name}")
def get_forecast_lock(
tenant_id: str,
product_id: str,
forecast_date: str,
use_advisory: bool = True
) -> Union[DatabaseLock, SimpleDatabaseLock]:
"""
Get distributed lock for generating a forecast for a specific product and date.
Args:
tenant_id: Tenant identifier
product_id: Product identifier
forecast_date: Forecast date (ISO format)
use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)
Returns:
Lock instance
"""
lock_name = f"forecast:{tenant_id}:{product_id}:{forecast_date}"
if use_advisory:
return DatabaseLock(lock_name, timeout=30.0)
else:
return SimpleDatabaseLock(lock_name, timeout=30.0, ttl=300.0)
def get_batch_forecast_lock(tenant_id: str, use_advisory: bool = True) -> Union[DatabaseLock, SimpleDatabaseLock]:
"""
Get distributed lock for batch forecast generation for a tenant.
Args:
tenant_id: Tenant identifier
use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)
Returns:
Lock instance
"""
lock_name = f"forecast_batch:{tenant_id}"
if use_advisory:
return DatabaseLock(lock_name, timeout=60.0)
else:
return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)
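
# --- Usage sketch (illustrative only) ---------------------------------------
# The factory helpers build namespaced lock names so that per-product forecasts
# and tenant-wide batch runs never collide. The identifiers and function name
# below are hypothetical placeholders.
async def _example_forecast_locking(session: AsyncSession) -> None:
    # Serialize generation of a single product/date forecast (advisory lock).
    single = get_forecast_lock("tenant-1", "product-42", "2025-11-16")
    async with single.acquire(session):
        ...  # generate one forecast

    # Serialize a whole batch run for the tenant (table-based lock, longer TTL).
    batch = get_batch_forecast_lock("tenant-1", use_advisory=False)
    async with batch.acquire(session, owner="batch-runner"):
        ...  # generate all forecasts for the tenant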