REFACTOR external service and improve websocket training
92
services/training/app/utils/__init__.py
Normal file
@@ -0,0 +1,92 @@
"""
Training Service Utilities
"""

from .timezone_utils import (
    ensure_timezone_aware,
    ensure_timezone_naive,
    normalize_datetime_to_utc,
    normalize_dataframe_datetime_column,
    prepare_prophet_datetime,
    safe_datetime_comparison,
    get_current_utc,
    convert_timestamp_to_datetime
)

from .circuit_breaker import (
    CircuitBreaker,
    CircuitBreakerError,
    CircuitState,
    circuit_breaker_registry
)

from .file_utils import (
    calculate_file_checksum,
    verify_file_checksum,
    get_file_size,
    ensure_directory_exists,
    safe_file_delete,
    get_file_metadata,
    ChecksummedFile
)

from .distributed_lock import (
    DatabaseLock,
    SimpleDatabaseLock,
    LockAcquisitionError,
    get_training_lock
)

from .retry import (
    RetryStrategy,
    RetryError,
    retry_async,
    with_retry,
    retry_with_timeout,
    AdaptiveRetryStrategy,
    TimeoutRetryStrategy,
    HTTP_RETRY_STRATEGY,
    DATABASE_RETRY_STRATEGY,
    EXTERNAL_SERVICE_RETRY_STRATEGY
)

__all__ = [
    # Timezone utilities
    'ensure_timezone_aware',
    'ensure_timezone_naive',
    'normalize_datetime_to_utc',
    'normalize_dataframe_datetime_column',
    'prepare_prophet_datetime',
    'safe_datetime_comparison',
    'get_current_utc',
    'convert_timestamp_to_datetime',
    # Circuit breaker
    'CircuitBreaker',
    'CircuitBreakerError',
    'CircuitState',
    'circuit_breaker_registry',
    # File utilities
    'calculate_file_checksum',
    'verify_file_checksum',
    'get_file_size',
    'ensure_directory_exists',
    'safe_file_delete',
    'get_file_metadata',
    'ChecksummedFile',
    # Distributed locking
    'DatabaseLock',
    'SimpleDatabaseLock',
    'LockAcquisitionError',
    'get_training_lock',
    # Retry mechanisms
    'RetryStrategy',
    'RetryError',
    'retry_async',
    'with_retry',
    'retry_with_timeout',
    'AdaptiveRetryStrategy',
    'TimeoutRetryStrategy',
    'HTTP_RETRY_STRATEGY',
    'DATABASE_RETRY_STRATEGY',
    'EXTERNAL_SERVICE_RETRY_STRATEGY'
]
198
services/training/app/utils/circuit_breaker.py
Normal file
@@ -0,0 +1,198 @@
"""
Circuit Breaker Pattern Implementation
Protects against cascading failures from external service calls
"""

import time
from enum import Enum
from typing import Callable, Any, Optional
import logging
from functools import wraps

logger = logging.getLogger(__name__)


class CircuitState(Enum):
    """Circuit breaker states"""
    CLOSED = "closed"        # Normal operation
    OPEN = "open"            # Circuit is open, rejecting requests
    HALF_OPEN = "half_open"  # Testing if the service has recovered


class CircuitBreakerError(Exception):
    """Raised when the circuit breaker is open"""
    pass


class CircuitBreaker:
    """
    Circuit breaker to prevent cascading failures.

    States:
    - CLOSED: Normal operation, requests pass through
    - OPEN: Too many failures, rejecting all requests
    - HALF_OPEN: Testing if the service recovered, allowing limited requests
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: type = Exception,
        name: str = "circuit_breaker"
    ):
        """
        Initialize circuit breaker.

        Args:
            failure_threshold: Number of failures before opening the circuit
            recovery_timeout: Seconds to wait before attempting recovery
            expected_exception: Exception type to catch (others pass through)
            name: Name for logging purposes
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.name = name

        self.failure_count = 0
        self.last_failure_time: Optional[float] = None
        self.state = CircuitState.CLOSED

    def _record_success(self):
        """Record a successful call"""
        self.failure_count = 0
        self.last_failure_time = None
        if self.state == CircuitState.HALF_OPEN:
            logger.info(f"Circuit breaker '{self.name}' recovered, closing circuit")
            self.state = CircuitState.CLOSED

    def _record_failure(self):
        """Record a failed call"""
        self.failure_count += 1
        self.last_failure_time = time.time()

        if self.failure_count >= self.failure_threshold:
            if self.state != CircuitState.OPEN:
                logger.warning(
                    f"Circuit breaker '{self.name}' opened after {self.failure_count} failures"
                )
                self.state = CircuitState.OPEN

    def _should_attempt_reset(self) -> bool:
        """Check whether we should attempt to reset the circuit"""
        return (
            self.state == CircuitState.OPEN
            and self.last_failure_time is not None
            and time.time() - self.last_failure_time >= self.recovery_timeout
        )

    async def call(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute a function with circuit breaker protection.

        Args:
            func: Async function to execute
            *args: Positional arguments for func
            **kwargs: Keyword arguments for func

        Returns:
            Result from func

        Raises:
            CircuitBreakerError: If the circuit is open
            Exception: Original exception if not of expected_exception type
        """
        # Check if the circuit is open
        if self.state == CircuitState.OPEN:
            if self._should_attempt_reset():
                logger.info(f"Circuit breaker '{self.name}' attempting recovery (half-open)")
                self.state = CircuitState.HALF_OPEN
            else:
                raise CircuitBreakerError(
                    f"Circuit breaker '{self.name}' is open. "
                    f"Service unavailable for {self.recovery_timeout}s after {self.failure_count} failures."
                )

        try:
            # Execute the function
            result = await func(*args, **kwargs)
            self._record_success()
            return result

        except self.expected_exception as e:
            self._record_failure()
            # Plain stdlib logging: interpolate the fields into the message
            # (structlog-style keyword arguments would raise a TypeError here)
            logger.error(
                f"Circuit breaker '{self.name}' caught failure: {e} "
                f"(failure_count={self.failure_count}, state={self.state.value})"
            )
            raise

    def __call__(self, func: Callable) -> Callable:
        """Decorator interface for the circuit breaker"""
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await self.call(func, *args, **kwargs)
        return wrapper

    def get_state(self) -> dict:
        """Get the current circuit breaker state for monitoring"""
        return {
            "name": self.name,
            "state": self.state.value,
            "failure_count": self.failure_count,
            "failure_threshold": self.failure_threshold,
            "last_failure_time": self.last_failure_time,
            "recovery_timeout": self.recovery_timeout
        }


class CircuitBreakerRegistry:
    """Registry to manage multiple circuit breakers"""

    def __init__(self):
        self._breakers: dict[str, CircuitBreaker] = {}

    def get_or_create(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: type = Exception
    ) -> CircuitBreaker:
        """Get an existing circuit breaker or create a new one"""
        if name not in self._breakers:
            self._breakers[name] = CircuitBreaker(
                failure_threshold=failure_threshold,
                recovery_timeout=recovery_timeout,
                expected_exception=expected_exception,
                name=name
            )
        return self._breakers[name]

    def get(self, name: str) -> Optional[CircuitBreaker]:
        """Get a circuit breaker by name"""
        return self._breakers.get(name)

    def get_all_states(self) -> dict:
        """Get the states of all circuit breakers"""
        return {
            name: breaker.get_state()
            for name, breaker in self._breakers.items()
        }

    def reset(self, name: str):
        """Manually reset a circuit breaker"""
        if name in self._breakers:
            breaker = self._breakers[name]
            breaker.failure_count = 0
            breaker.last_failure_time = None
            breaker.state = CircuitState.CLOSED
            logger.info(f"Circuit breaker '{name}' manually reset")


# Global registry instance
circuit_breaker_registry = CircuitBreakerRegistry()
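Usage sketch (illustrative, not part of this diff): wiring the registry-managed breaker around an external HTTP call via the decorator interface. httpx, the breaker name, and fetch_forecast are assumptions, not code from this commit.

import asyncio
import httpx

from app.utils import circuit_breaker_registry, CircuitBreakerError

# Hypothetical breaker for an external data service; httpx.HTTPError marks
# which failures count against the circuit.
breaker = circuit_breaker_registry.get_or_create(
    "external-data-service",
    failure_threshold=3,
    recovery_timeout=30.0,
    expected_exception=httpx.HTTPError,
)

@breaker
async def fetch_forecast(product_id: str) -> dict:
    # Illustrative endpoint; raise_for_status() raises httpx.HTTPError on 4xx/5xx
    async with httpx.AsyncClient() as client:
        resp = await client.get(f"https://data.example.com/forecast/{product_id}")
        resp.raise_for_status()
        return resp.json()

async def main() -> None:
    try:
        print(await fetch_forecast("sku-123"))
    except CircuitBreakerError:
        print("Circuit open: failing fast instead of hammering the service")
    except httpx.HTTPError as e:
        print(f"Call failed and was counted against the circuit: {e}")

asyncio.run(main())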
233
services/training/app/utils/distributed_lock.py
Normal file
@@ -0,0 +1,233 @@
"""
Distributed Locking Mechanisms
Prevents concurrent training jobs for the same product
"""

import asyncio
import hashlib
import time
from typing import Union
import logging
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timezone, timedelta

logger = logging.getLogger(__name__)


class LockAcquisitionError(Exception):
    """Raised when a lock cannot be acquired"""
    pass


class DatabaseLock:
    """
    Database-based distributed lock using PostgreSQL advisory locks.
    Works across multiple service instances.
    """

    def __init__(self, lock_name: str, timeout: float = 30.0):
        """
        Initialize database lock.

        Args:
            lock_name: Unique identifier for the lock
            timeout: Maximum seconds to wait for lock acquisition
        """
        self.lock_name = lock_name
        self.timeout = timeout
        self.lock_id = self._hash_lock_name(lock_name)

    def _hash_lock_name(self, name: str) -> int:
        """Convert the lock name to a stable positive 32-bit integer ID.

        Python's built-in hash() is randomized per process (PYTHONHASHSEED),
        so it would yield different advisory lock IDs on different service
        instances and silently break mutual exclusion. A fixed digest keeps
        the ID consistent everywhere.
        """
        digest = hashlib.sha256(name.encode("utf-8")).digest()
        return int.from_bytes(digest[:4], "big") % (2**31)

    @asynccontextmanager
    async def acquire(self, session: AsyncSession):
        """
        Acquire the distributed lock as an async context manager.

        Args:
            session: Database session for lock operations

        Raises:
            LockAcquisitionError: If the lock cannot be acquired within the timeout
        """
        acquired = False
        start_time = time.time()

        try:
            # Try to acquire the lock until the timeout expires
            while time.time() - start_time < self.timeout:
                # Non-blocking lock acquisition attempt
                result = await session.execute(
                    text("SELECT pg_try_advisory_lock(:lock_id)"),
                    {"lock_id": self.lock_id}
                )
                acquired = result.scalar()

                if acquired:
                    logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
                    break

                # Wait a bit before retrying
                await asyncio.sleep(0.1)

            if not acquired:
                raise LockAcquisitionError(
                    f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
                )

            yield

        finally:
            if acquired:
                # Release the lock on the same session that acquired it
                await session.execute(
                    text("SELECT pg_advisory_unlock(:lock_id)"),
                    {"lock_id": self.lock_id}
                )
                logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")


class SimpleDatabaseLock:
    """
    Simple table-based distributed lock.
    Alternative to advisory locks; uses a dedicated locks table.
    """

    def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
        """
        Initialize simple database lock.

        Args:
            lock_name: Unique identifier for the lock
            timeout: Maximum seconds to wait for lock acquisition
            ttl: Time-to-live for stale lock cleanup (seconds)
        """
        self.lock_name = lock_name
        self.timeout = timeout
        self.ttl = ttl

    async def _ensure_lock_table(self, session: AsyncSession):
        """Ensure the locks table exists"""
        create_table_sql = """
            CREATE TABLE IF NOT EXISTS distributed_locks (
                lock_name VARCHAR(255) PRIMARY KEY,
                acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
                acquired_by VARCHAR(255),
                expires_at TIMESTAMP WITH TIME ZONE NOT NULL
            )
        """
        await session.execute(text(create_table_sql))
        await session.commit()

    async def _cleanup_stale_locks(self, session: AsyncSession):
        """Remove expired locks"""
        cleanup_sql = """
            DELETE FROM distributed_locks
            WHERE expires_at < :now
        """
        await session.execute(
            text(cleanup_sql),
            {"now": datetime.now(timezone.utc)}
        )
        await session.commit()

    @asynccontextmanager
    async def acquire(self, session: AsyncSession, owner: str = "training-service"):
        """
        Acquire the simple database lock.

        Args:
            session: Database session
            owner: Identifier for the lock owner

        Raises:
            LockAcquisitionError: If the lock cannot be acquired
        """
        await self._ensure_lock_table(session)
        await self._cleanup_stale_locks(session)

        acquired = False
        start_time = time.time()

        try:
            # Try to acquire the lock until the timeout expires
            while time.time() - start_time < self.timeout:
                now = datetime.now(timezone.utc)
                expires_at = now + timedelta(seconds=self.ttl)

                try:
                    # Insert a lock record; ON CONFLICT makes this a no-op
                    # when another owner already holds the lock
                    insert_sql = """
                        INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
                        VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
                        ON CONFLICT (lock_name) DO NOTHING
                        RETURNING lock_name
                    """

                    result = await session.execute(
                        text(insert_sql),
                        {
                            "lock_name": self.lock_name,
                            "acquired_at": now,
                            "acquired_by": owner,
                            "expires_at": expires_at
                        }
                    )
                    await session.commit()

                    if result.rowcount > 0:
                        acquired = True
                        logger.info(f"Acquired simple lock: {self.lock_name}")
                        break

                except Exception as e:
                    logger.debug(f"Lock acquisition attempt failed: {e}")
                    await session.rollback()

                # Wait before retrying
                await asyncio.sleep(0.5)

            if not acquired:
                raise LockAcquisitionError(
                    f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
                )

            yield

        finally:
            if acquired:
                # Release the lock
                delete_sql = """
                    DELETE FROM distributed_locks
                    WHERE lock_name = :lock_name
                """
                await session.execute(
                    text(delete_sql),
                    {"lock_name": self.lock_name}
                )
                await session.commit()
                logger.info(f"Released simple lock: {self.lock_name}")


def get_training_lock(
    tenant_id: str,
    product_id: str,
    use_advisory: bool = True
) -> Union[DatabaseLock, SimpleDatabaseLock]:
    """
    Get a distributed lock for training a specific product.

    Args:
        tenant_id: Tenant identifier
        product_id: Product identifier
        use_advisory: Use PostgreSQL advisory locks (True) or table-based locks (False)

    Returns:
        Lock instance
    """
    lock_name = f"training:{tenant_id}:{product_id}"

    if use_advisory:
        return DatabaseLock(lock_name, timeout=60.0)
    else:
        return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)
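Usage sketch (illustrative, not part of this diff): serializing training runs per product with the advisory lock. async_session_factory and run_training are hypothetical stand-ins for the service's real session factory and training entry point.

import logging

from app.utils import get_training_lock, LockAcquisitionError

logger = logging.getLogger(__name__)

async def train_product(tenant_id: str, product_id: str) -> None:
    lock = get_training_lock(tenant_id, product_id, use_advisory=True)
    async with async_session_factory() as session:  # hypothetical session factory
        try:
            async with lock.acquire(session):
                await run_training(tenant_id, product_id)  # hypothetical entry point
        except LockAcquisitionError:
            # Another instance already holds the lock for this product; skip this run
            logger.info(f"Training already in progress for {tenant_id}:{product_id}")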
216
services/training/app/utils/file_utils.py
Normal file
@@ -0,0 +1,216 @@
"""
File Utility Functions
Utilities for secure file operations including checksum verification
"""

import hashlib
import os
from pathlib import Path
from typing import Optional
import logging

logger = logging.getLogger(__name__)


def calculate_file_checksum(file_path: str, algorithm: str = "sha256") -> str:
    """
    Calculate the checksum of a file.

    Args:
        file_path: Path to file
        algorithm: Hash algorithm (sha256, md5, etc.)

    Returns:
        Hexadecimal checksum string

    Raises:
        FileNotFoundError: If the file doesn't exist
        ValueError: If the algorithm is not supported
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    try:
        hash_func = hashlib.new(algorithm)
    except ValueError:
        raise ValueError(f"Unsupported hash algorithm: {algorithm}")

    # Read the file in chunks to handle large files efficiently
    with open(file_path, 'rb') as f:
        while chunk := f.read(8192):
            hash_func.update(chunk)

    return hash_func.hexdigest()


def verify_file_checksum(file_path: str, expected_checksum: str, algorithm: str = "sha256") -> bool:
    """
    Verify that a file matches the expected checksum.

    Args:
        file_path: Path to file
        expected_checksum: Expected checksum value
        algorithm: Hash algorithm used

    Returns:
        True if the checksum matches, False otherwise
    """
    try:
        actual_checksum = calculate_file_checksum(file_path, algorithm)
        matches = actual_checksum == expected_checksum

        if matches:
            logger.debug(f"Checksum verified for {file_path}")
        else:
            # Plain stdlib logging: interpolate the values into the message
            # (structlog-style keyword arguments would raise a TypeError here)
            logger.warning(
                f"Checksum mismatch for {file_path}: "
                f"expected {expected_checksum}, actual {actual_checksum}"
            )

        return matches

    except Exception as e:
        logger.error(f"Error verifying checksum for {file_path}: {e}")
        return False


def get_file_size(file_path: str) -> int:
    """
    Get the file size in bytes.

    Args:
        file_path: Path to file

    Returns:
        File size in bytes

    Raises:
        FileNotFoundError: If the file doesn't exist
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    return os.path.getsize(file_path)


def ensure_directory_exists(directory: str) -> Path:
    """
    Ensure a directory exists, creating it if necessary.

    Args:
        directory: Directory path

    Returns:
        Path object for the directory
    """
    path = Path(directory)
    path.mkdir(parents=True, exist_ok=True)
    return path


def safe_file_delete(file_path: str) -> bool:
    """
    Safely delete a file, logging any errors.

    Args:
        file_path: Path to file

    Returns:
        True if deleted successfully, False otherwise
    """
    try:
        if os.path.exists(file_path):
            os.remove(file_path)
            logger.info(f"Deleted file: {file_path}")
            return True
        else:
            logger.warning(f"File not found for deletion: {file_path}")
            return False
    except Exception as e:
        logger.error(f"Error deleting file {file_path}: {e}")
        return False


def get_file_metadata(file_path: str) -> dict:
    """
    Get comprehensive file metadata.

    Args:
        file_path: Path to file

    Returns:
        Dictionary with file metadata

    Raises:
        FileNotFoundError: If the file doesn't exist
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    stat = os.stat(file_path)

    return {
        "path": file_path,
        "size_bytes": stat.st_size,
        "created_at": stat.st_ctime,
        "modified_at": stat.st_mtime,
        "accessed_at": stat.st_atime,
        "is_file": os.path.isfile(file_path),
        "is_dir": os.path.isdir(file_path),
        "exists": True
    }


class ChecksummedFile:
    """
    Helper for working with checksummed files.
    Calculates and stores a sidecar checksum when the file is written.
    """

    def __init__(self, file_path: str, checksum_path: Optional[str] = None, algorithm: str = "sha256"):
        """
        Initialize checksummed file handler.

        Args:
            file_path: Path to the file
            checksum_path: Path to store the checksum (default: file_path + '.checksum')
            algorithm: Hash algorithm to use
        """
        self.file_path = file_path
        self.checksum_path = checksum_path or f"{file_path}.checksum"
        self.algorithm = algorithm
        self.checksum: Optional[str] = None

    def calculate_and_save_checksum(self) -> str:
        """Calculate the checksum and save it to the checksum file"""
        self.checksum = calculate_file_checksum(self.file_path, self.algorithm)

        with open(self.checksum_path, 'w') as f:
            f.write(f"{self.checksum} {os.path.basename(self.file_path)}\n")

        logger.info(f"Saved checksum for {self.file_path}: {self.checksum}")
        return self.checksum

    def load_and_verify_checksum(self) -> bool:
        """Load the expected checksum and verify the file"""
        try:
            with open(self.checksum_path, 'r') as f:
                expected_checksum = f.read().strip().split()[0]

            return verify_file_checksum(self.file_path, expected_checksum, self.algorithm)

        except FileNotFoundError:
            logger.warning(f"Checksum file not found: {self.checksum_path}")
            return False
        except Exception as e:
            logger.error(f"Error loading checksum: {e}")
            return False

    def get_stored_checksum(self) -> Optional[str]:
        """Get the checksum from the stored file"""
        try:
            with open(self.checksum_path, 'r') as f:
                return f.read().strip().split()[0]
        except FileNotFoundError:
            return None
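Usage sketch (illustrative, not part of this diff): writing a sidecar checksum for a saved model artifact and verifying it before reuse. The artifact path is an assumption.

from app.utils import ChecksummedFile

artifact = ChecksummedFile("/models/tenant-a/sku-123/model.pkl")  # hypothetical path

# After the training job writes model.pkl:
artifact.calculate_and_save_checksum()  # creates model.pkl.checksum alongside it

# Before loading the artifact later:
if not artifact.load_and_verify_checksum():
    raise RuntimeError("Model artifact failed checksum verification")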
316
services/training/app/utils/retry.py
Normal file
@@ -0,0 +1,316 @@
"""
Retry Mechanism with Exponential Backoff
Handles transient failures with intelligent retry strategies
"""

import asyncio
import time
import random
from typing import Callable, Any, Optional, Type, Tuple
from functools import wraps
import logging

logger = logging.getLogger(__name__)


class RetryError(Exception):
    """Raised when all retry attempts are exhausted"""
    def __init__(self, message: str, attempts: int, last_exception: Exception):
        super().__init__(message)
        self.attempts = attempts
        self.last_exception = last_exception


class RetryStrategy:
    """Base retry strategy"""

    def __init__(
        self,
        max_attempts: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0,
        jitter: bool = True,
        retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
    ):
        """
        Initialize retry strategy.

        Args:
            max_attempts: Maximum number of retry attempts
            initial_delay: Initial delay in seconds
            max_delay: Maximum delay between retries
            exponential_base: Base for exponential backoff
            jitter: Add random jitter to prevent a thundering herd
            retriable_exceptions: Tuple of exception types to retry
        """
        self.max_attempts = max_attempts
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base
        self.jitter = jitter
        self.retriable_exceptions = retriable_exceptions

    def calculate_delay(self, attempt: int) -> float:
        """Calculate the delay for a given attempt using exponential backoff"""
        delay = min(
            self.initial_delay * (self.exponential_base ** attempt),
            self.max_delay
        )

        if self.jitter:
            # Add random jitter (50-100% of the computed delay)
            delay = delay * (0.5 + random.random() * 0.5)

        return delay

    def is_retriable(self, exception: Exception) -> bool:
        """Check whether an exception should trigger a retry"""
        return isinstance(exception, self.retriable_exceptions)


async def retry_async(
    func: Callable,
    *args,
    strategy: Optional[RetryStrategy] = None,
    **kwargs
) -> Any:
    """
    Retry an async function with exponential backoff.

    Args:
        func: Async function to retry
        *args: Positional arguments for func
        strategy: Retry strategy (uses the default if None)
        **kwargs: Keyword arguments for func

    Returns:
        Result from func

    Raises:
        RetryError: When all attempts are exhausted
    """
    if strategy is None:
        strategy = RetryStrategy()

    last_exception = None

    for attempt in range(strategy.max_attempts):
        try:
            result = await func(*args, **kwargs)

            if attempt > 0:
                logger.info(f"{func.__name__}: retry succeeded on attempt {attempt + 1}")

            return result

        except Exception as e:
            last_exception = e

            if not strategy.is_retriable(e):
                # Plain stdlib logging: fields go into the message string
                logger.error(f"{func.__name__}: non-retriable exception occurred: {e}")
                raise

            if attempt < strategy.max_attempts - 1:
                delay = strategy.calculate_delay(attempt)
                logger.warning(
                    f"{func.__name__}: attempt {attempt + 1}/{strategy.max_attempts} "
                    f"failed ({e}), retrying in {delay:.2f}s"
                )
                await asyncio.sleep(delay)
            else:
                logger.error(
                    f"{func.__name__}: all {strategy.max_attempts} retry attempts exhausted ({e})"
                )

    raise RetryError(
        f"Failed after {strategy.max_attempts} attempts: {last_exception}",
        attempts=strategy.max_attempts,
        last_exception=last_exception
    )


def with_retry(
    max_attempts: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
):
    """
    Decorator to add retry logic to async functions.

    Example:
        @with_retry(max_attempts=5, initial_delay=2.0)
        async def fetch_data():
            # Your code here
            pass
    """
    strategy = RetryStrategy(
        max_attempts=max_attempts,
        initial_delay=initial_delay,
        max_delay=max_delay,
        exponential_base=exponential_base,
        jitter=jitter,
        retriable_exceptions=retriable_exceptions
    )

    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await retry_async(func, *args, strategy=strategy, **kwargs)
        return wrapper

    return decorator


class AdaptiveRetryStrategy(RetryStrategy):
    """
    Adaptive retry strategy that adjusts based on success/failure patterns.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.success_count = 0
        self.failure_count = 0
        self.consecutive_failures = 0

    def calculate_delay(self, attempt: int) -> float:
        """Calculate delay, adapting to recent failure history"""
        base_delay = super().calculate_delay(attempt)

        # Increase the delay when seeing many consecutive failures
        if self.consecutive_failures > 5:
            multiplier = min(2.0, 1.0 + (self.consecutive_failures - 5) * 0.2)
            base_delay *= multiplier

        return min(base_delay, self.max_delay)

    def record_success(self):
        """Record a successful attempt"""
        self.success_count += 1
        self.consecutive_failures = 0

    def record_failure(self):
        """Record a failed attempt"""
        self.failure_count += 1
        self.consecutive_failures += 1


class TimeoutRetryStrategy(RetryStrategy):
    """
    Retry strategy with an overall timeout across all attempts.
    """

    def __init__(self, *args, timeout: float = 300.0, **kwargs):
        """
        Args:
            timeout: Total timeout in seconds for all attempts
        """
        super().__init__(*args, **kwargs)
        self.timeout = timeout
        self.start_time: Optional[float] = None

    def should_retry(self, attempt: int) -> bool:
        """Check whether another retry attempt should be made"""
        if self.start_time is None:
            self.start_time = time.time()
            return True

        elapsed = time.time() - self.start_time
        return elapsed < self.timeout and attempt < self.max_attempts


async def retry_with_timeout(
    func: Callable,
    *args,
    max_attempts: int = 3,
    timeout: float = 300.0,
    **kwargs
) -> Any:
    """
    Retry with an overall timeout.

    Args:
        func: Function to retry
        max_attempts: Maximum attempts
        timeout: Overall timeout in seconds

    Returns:
        Result from func
    """
    strategy = TimeoutRetryStrategy(
        max_attempts=max_attempts,
        timeout=timeout
    )

    start_time = time.time()
    strategy.start_time = start_time

    last_exception = None

    for attempt in range(strategy.max_attempts):
        if time.time() - start_time >= timeout:
            raise RetryError(
                f"Timeout of {timeout}s exceeded",
                attempts=attempt + 1,
                last_exception=last_exception
            )

        try:
            return await func(*args, **kwargs)
        except Exception as e:
            last_exception = e

            if not strategy.is_retriable(e):
                raise

            if attempt < strategy.max_attempts - 1:
                delay = strategy.calculate_delay(attempt)
                await asyncio.sleep(delay)

    raise RetryError(
        f"Failed after {strategy.max_attempts} attempts",
        attempts=strategy.max_attempts,
        last_exception=last_exception
    )


# Pre-configured strategies for common use cases
HTTP_RETRY_STRATEGY = RetryStrategy(
    max_attempts=3,
    initial_delay=1.0,
    max_delay=10.0,
    exponential_base=2.0,
    jitter=True
)

DATABASE_RETRY_STRATEGY = RetryStrategy(
    max_attempts=5,
    initial_delay=0.5,
    max_delay=5.0,
    exponential_base=1.5,
    jitter=True
)

EXTERNAL_SERVICE_RETRY_STRATEGY = RetryStrategy(
    max_attempts=4,
    initial_delay=2.0,
    max_delay=30.0,
    exponential_base=2.5,
    jitter=True
)
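Usage sketch (illustrative, not part of this diff): applying the pre-configured HTTP strategy at a call site. post_metrics, the endpoint, and httpx are assumptions standing in for whatever client the service actually uses.

import asyncio
import httpx

from app.utils import retry_async, HTTP_RETRY_STRATEGY, RetryError

async def post_metrics(payload: dict) -> int:
    # Illustrative call that raises httpx.HTTPError on failure
    async with httpx.AsyncClient() as client:
        resp = await client.post("https://metrics.example.com/ingest", json=payload)
        resp.raise_for_status()
        return resp.status_code

async def main() -> None:
    try:
        status = await retry_async(post_metrics, {"loss": 0.42}, strategy=HTTP_RETRY_STRATEGY)
        print(f"Delivered with HTTP {status}")
    except RetryError as e:
        print(f"Gave up after {e.attempts} attempts: {e.last_exception}")

asyncio.run(main())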
184
services/training/app/utils/timezone_utils.py
Normal file
@@ -0,0 +1,184 @@
"""
Timezone Utility Functions
Centralized timezone handling to ensure consistency across the training service
"""

from datetime import datetime, timezone
from typing import Optional, Union
import pandas as pd
import logging

logger = logging.getLogger(__name__)


def ensure_timezone_aware(dt: Optional[datetime], default_tz=timezone.utc) -> Optional[datetime]:
    """
    Ensure a datetime is timezone-aware.

    Args:
        dt: Datetime to check
        default_tz: Timezone to apply if the datetime is naive (default: UTC)

    Returns:
        Timezone-aware datetime, or None if dt is None
    """
    if dt is None:
        return None

    if dt.tzinfo is None:
        return dt.replace(tzinfo=default_tz)
    return dt


def ensure_timezone_naive(dt: Optional[datetime]) -> Optional[datetime]:
    """
    Remove timezone information from a datetime.

    Args:
        dt: Datetime to process

    Returns:
        Timezone-naive datetime, or None if dt is None
    """
    if dt is None:
        return None

    if dt.tzinfo is not None:
        return dt.replace(tzinfo=None)
    return dt


def normalize_datetime_to_utc(dt: Union[datetime, pd.Timestamp, None]) -> Optional[datetime]:
    """
    Normalize any datetime to a UTC timezone-aware datetime.

    Args:
        dt: Datetime or pandas Timestamp to normalize

    Returns:
        UTC timezone-aware datetime, or None if dt is None
    """
    if dt is None:
        return None

    # Handle pandas Timestamp
    if isinstance(dt, pd.Timestamp):
        dt = dt.to_pydatetime()

    # If naive, assume UTC
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)

    # If aware but not UTC, convert to UTC
    return dt.astimezone(timezone.utc)


def normalize_dataframe_datetime_column(
    df: pd.DataFrame,
    column: str,
    target_format: str = 'naive'
) -> pd.DataFrame:
    """
    Normalize a datetime column in a dataframe to a consistent format.

    Args:
        df: DataFrame to process (the column is modified in place)
        column: Name of the datetime column
        target_format: 'naive' or 'aware' (UTC)

    Returns:
        DataFrame with normalized datetime column
    """
    if column not in df.columns:
        logger.warning(f"Column {column} not found in dataframe")
        return df

    # Convert to datetime if not already
    df[column] = pd.to_datetime(df[column])

    if target_format == 'naive':
        # Remove timezone if present
        if df[column].dt.tz is not None:
            df[column] = df[column].dt.tz_localize(None)
    elif target_format == 'aware':
        if df[column].dt.tz is None:
            # Add UTC timezone if not present
            df[column] = df[column].dt.tz_localize(timezone.utc)
        else:
            # Convert to UTC if in a different timezone
            df[column] = df[column].dt.tz_convert(timezone.utc)
    else:
        raise ValueError(f"Invalid target_format: {target_format}. Must be 'naive' or 'aware'")

    return df


def prepare_prophet_datetime(df: pd.DataFrame, datetime_col: str = 'ds') -> pd.DataFrame:
    """
    Prepare a datetime column for Prophet (which requires timezone-naive datetimes).

    Args:
        df: DataFrame with datetime column
        datetime_col: Name of the datetime column (default: 'ds')

    Returns:
        DataFrame with a Prophet-compatible datetime column
    """
    df = df.copy()
    df = normalize_dataframe_datetime_column(df, datetime_col, target_format='naive')
    return df


def safe_datetime_comparison(dt1: datetime, dt2: datetime) -> int:
    """
    Safely compare two datetimes, handling timezone mismatches.

    Args:
        dt1: First datetime
        dt2: Second datetime

    Returns:
        -1 if dt1 < dt2, 0 if equal, 1 if dt1 > dt2
    """
    # Normalize both to UTC for comparison
    dt1_utc = normalize_datetime_to_utc(dt1)
    dt2_utc = normalize_datetime_to_utc(dt2)

    if dt1_utc < dt2_utc:
        return -1
    elif dt1_utc > dt2_utc:
        return 1
    else:
        return 0


def get_current_utc() -> datetime:
    """
    Get the current datetime in UTC with timezone awareness.

    Returns:
        Current UTC datetime
    """
    return datetime.now(timezone.utc)


def convert_timestamp_to_datetime(timestamp: Union[int, float, str]) -> datetime:
    """
    Convert various timestamp formats to a datetime.

    Args:
        timestamp: Unix timestamp (seconds or milliseconds) or ISO string

    Returns:
        UTC timezone-aware datetime
    """
    if isinstance(timestamp, str):
        dt = pd.to_datetime(timestamp)
        return normalize_datetime_to_utc(dt)

    # Values above ~1e10 are treated as milliseconds (typical JavaScript timestamps)
    if timestamp > 1e10:
        timestamp = timestamp / 1000

    return datetime.fromtimestamp(timestamp, tz=timezone.utc)
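Usage sketch (illustrative, not part of this diff): normalizing offset-aware timestamps before handing them to Prophet, which rejects timezone-aware 'ds' columns. The column values are made up.

import pandas as pd

from app.utils import prepare_prophet_datetime

df = pd.DataFrame({
    "ds": pd.to_datetime(["2024-01-01T00:00:00+01:00", "2024-01-02T00:00:00+01:00"]),
    "y": [10.0, 12.5],
})

df = prepare_prophet_datetime(df)  # returns a copy with a timezone-naive 'ds'
print(df["ds"].dt.tz)              # None: safe to pass to Prophet's fit()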