Add all the code for training service
This commit is contained in:
@@ -1,38 +1,303 @@
|
||||
# services/training/app/core/auth.py
|
||||
"""
|
||||
Authentication utilities for training service
|
||||
Authentication and authorization for training service
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer
|
||||
import structlog
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Depends, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Module-level structured logger used by every helper in this file.
logger = structlog.get_logger()

# HTTP Bearer token scheme.
# Declared once: with auto_error=False a missing Authorization header is
# delivered to the dependency as ``None`` instead of being rejected by the
# scheme itself, which is what the ``if not credentials`` checks in the
# dependencies below rely on.
security = HTTPBearer(auto_error=False)
|
||||
|
||||
async def verify_token(token: str = Depends(security)):
|
||||
"""Verify token with auth service"""
|
||||
class AuthenticationError(Exception):
    """Custom exception for authentication errors.

    Raised by ``verify_token`` whenever a token cannot be verified; the
    FastAPI dependencies in this module catch it and turn it into an HTTP
    response (or into ``None`` for the optional variants).
    """
|
||||
|
||||
class AuthorizationError(Exception):
    """Custom exception for authorization errors.

    Exported through ``__all__``. NOTE(review): nothing in the visible part
    of this module raises it; the role/permission dependencies raise
    ``HTTPException`` directly.
    """
|
||||
|
||||
async def verify_token(token: str) -> dict:
    """
    Verify JWT token with auth service

    Posts the token as a Bearer ``Authorization`` header to
    ``{settings.AUTH_SERVICE_URL}/auth/verify`` with a 10 second timeout.

    Args:
        token: JWT token to verify

    Returns:
        dict: The decoded JSON body of the auth service's 200 response
            (the token payload).

    Raises:
        AuthenticationError: If the token is rejected (401), the auth
            service answers with any other non-200 status, the request
            times out or fails, or an unexpected error occurs. When the
            error wraps another exception, that exception is attached as
            ``__cause__`` so the root cause stays visible in tracebacks.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{settings.AUTH_SERVICE_URL}/auth/verify",
                headers={"Authorization": f"Bearer {token}"},
                timeout=10.0
            )

            if response.status_code == 200:
                token_data = response.json()
                logger.debug("Token verified successfully", user_id=token_data.get("user_id"))
                return token_data
            elif response.status_code == 401:
                logger.warning("Invalid token provided")
                raise AuthenticationError("Invalid or expired token")
            else:
                logger.error("Auth service error", status_code=response.status_code)
                raise AuthenticationError("Authentication service unavailable")

    except httpx.TimeoutException as exc:
        logger.error("Auth service timeout")
        raise AuthenticationError("Authentication service timeout") from exc
    except httpx.RequestError as exc:
        logger.error("Auth service request error", error=str(exc))
        raise AuthenticationError("Authentication service unavailable") from exc
    except AuthenticationError:
        # Already the right type (raised above for non-200 answers); let it
        # through untouched instead of re-wrapping it in the handler below.
        raise
    except Exception as exc:
        # Anything else (e.g. a body that is not valid JSON, or a JSON value
        # without ``.get``) is reported as a generic authentication failure.
        logger.error("Unexpected auth error", error=str(exc))
        raise AuthenticationError("Authentication failed") from exc
|
||||
|
||||
async def get_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> dict:
    """
    Get current authenticated user

    FastAPI dependency: requires a Bearer token and returns whatever
    ``verify_token`` returns for it.

    Args:
        credentials: HTTP Bearer credentials

    Returns:
        dict: User information

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` header) if
            no credentials were supplied or if ``verify_token`` raises
            ``AuthenticationError``; the error text becomes the detail.
    """
    bearer_challenge = {"WWW-Authenticate": "Bearer"}

    if not credentials:
        logger.warning("No credentials provided")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication credentials required",
            headers=bearer_challenge,
        )

    try:
        return await verify_token(credentials.credentials)
    except AuthenticationError as auth_error:
        reason = str(auth_error)
        logger.warning("Authentication failed", error=reason)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=reason,
            headers=bearer_challenge,
        )
|
||||
|
||||
async def get_current_tenant_id(
    current_user: dict = Depends(get_current_user)
) -> str:
    """
    Get current tenant ID from authenticated user

    Args:
        current_user: Current authenticated user data

    Returns:
        str: Tenant ID (the ``tenant_id`` value of the token payload)

    Raises:
        HTTPException: 403 if ``tenant_id`` is missing or falsy
    """
    tenant_id = current_user.get("tenant_id")
    if not tenant_id:
        # Log who was affected and which fields were present, but not the
        # payload values: the token data may carry personal information and
        # the sibling dependencies only ever log ``user_id``.
        logger.error("Missing tenant_id in token",
                     user_id=current_user.get("user_id"),
                     token_fields=sorted(str(key) for key in current_user))
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid token: missing tenant information"
        )

    return tenant_id
|
||||
|
||||
async def require_admin_role(
    current_user: dict = Depends(get_current_user)
) -> dict:
    """
    Require admin role for endpoint access

    The role comparison is case-insensitive.

    Args:
        current_user: Current authenticated user data

    Returns:
        dict: User information (the same ``current_user`` dict)

    Raises:
        HTTPException: 403 if the user's role is not ``admin``
    """
    # ``or ""`` covers a payload that carries ``"role": null``; without it
    # ``.lower()`` would raise and the caller would see a 500 instead of 403.
    user_role = str(current_user.get("role") or "").lower()
    if user_role != "admin":
        logger.warning("Access denied - admin role required",
                       user_id=current_user.get("user_id"),
                       role=user_role)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin role required"
        )

    return current_user
|
||||
|
||||
async def require_training_permission(
    current_user: dict = Depends(get_current_user)
) -> dict:
    """
    Require training permission for endpoint access

    Access is granted when ``"training"`` is one of the user's permissions
    or when the user's role is ``admin`` (case-insensitive).

    Args:
        current_user: Current authenticated user data

    Returns:
        dict: User information (the same ``current_user`` dict)

    Raises:
        HTTPException: 403 if user doesn't have training permission
    """
    raw_permissions = current_user.get("permissions")
    if raw_permissions is None:
        permissions = []
    elif isinstance(raw_permissions, str):
        # A bare string must be treated as ONE permission. Using ``in`` on
        # the string itself would be a substring test, so a value such as
        # "training_readonly" would wrongly grant "training".
        permissions = [raw_permissions]
    else:
        try:
            permissions = list(raw_permissions)
        except TypeError:
            # Not iterable (e.g. a number): treat as "no permissions".
            permissions = []

    # ``or ""`` covers ``"role": null`` so the check fails closed with 403
    # instead of crashing on ``None.lower()``.
    is_admin = str(current_user.get("role") or "").lower() == "admin"

    if "training" not in permissions and not is_admin:
        logger.warning("Access denied - training permission required",
                       user_id=current_user.get("user_id"),
                       permissions=permissions)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Training permission required"
        )

    return current_user
|
||||
|
||||
# Optional authentication for development/testing
|
||||
async def get_current_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[dict]:
    """
    Get current user but don't require authentication (for development)

    Never raises for authentication problems: both a missing token and a
    token that ``verify_token`` rejects yield ``None``.

    Args:
        credentials: HTTP Bearer credentials

    Returns:
        dict or None: User information if authenticated, None otherwise
    """
    if credentials:
        try:
            return await verify_token(credentials.credentials)
        except AuthenticationError:
            # Deliberately swallowed: an unverifiable token is treated the
            # same as no token at all.
            pass

    return None
|
||||
|
||||
async def get_tenant_id_optional(
    current_user: Optional[dict] = Depends(get_current_user_optional)
) -> Optional[str]:
    """
    Get tenant ID but don't require authentication (for development)

    Args:
        current_user: Current user data (optional)

    Returns:
        str or None: The payload's ``tenant_id`` if a user is present and
            the key exists, None otherwise
    """
    return current_user.get("tenant_id") if current_user else None
|
||||
|
||||
# Development/testing auth bypass
|
||||
async def get_test_tenant_id() -> str:
    """
    Get test tenant ID for development/testing

    Only works when DEBUG is enabled

    Returns:
        str: Test tenant ID (the fixed value ``"test-tenant-development"``)

    Raises:
        HTTPException: 403 when ``settings.DEBUG`` is falsy
    """
    # Guard first: outside debug mode this bypass must never hand out a tenant.
    if not settings.DEBUG:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Test authentication only available in debug mode"
        )

    return "test-tenant-development"
|
||||
|
||||
# Token validation utility
|
||||
def validate_token_structure(token_data: dict) -> bool:
    """
    Validate that token data has required structure

    A payload is considered valid when it is a mapping that contains both
    the ``user_id`` and ``tenant_id`` keys. Only key presence is checked,
    not the values.

    Args:
        token_data: Token payload data

    Returns:
        bool: True if valid structure, False otherwise (including when the
            payload is not a mapping at all, e.g. ``None`` or a list)
    """
    from collections.abc import Mapping

    # A decoded JSON body is not guaranteed to be an object. Without this
    # guard ``None`` raises TypeError on the membership test, and a list or
    # string would be tested for element/substring membership instead.
    if not isinstance(token_data, Mapping):
        logger.warning("Invalid token structure - payload is not a mapping",
                       payload_type=type(token_data).__name__)
        return False

    required_fields = ("user_id", "tenant_id")

    for field in required_fields:
        if field not in token_data:
            logger.warning("Invalid token structure - missing field", field=field)
            return False

    return True
|
||||
|
||||
# Role checking utilities
|
||||
def has_role(user_data: dict, required_role: str) -> bool:
    """
    Check if user has required role

    The comparison is case-insensitive. A missing or ``None`` role never
    matches.

    Args:
        user_data: User data from token
        required_role: Required role name

    Returns:
        bool: True if user has role, False otherwise
    """
    # ``or ""`` handles a payload with ``"role": None``; ``.get``'s default
    # only applies when the key is absent, so ``None.lower()`` would raise.
    user_role = str(user_data.get("role") or "").lower()
    return user_role == required_role.lower()
|
||||
|
||||
def has_permission(user_data: dict, required_permission: str) -> bool:
    """
    Check if user has required permission

    Users whose role is ``admin`` are treated as having every permission.

    Args:
        user_data: User data from token
        required_permission: Required permission name

    Returns:
        bool: True if user has permission, False otherwise
    """
    granted = user_data.get("permissions")
    if granted is None:
        # Key missing or explicitly null: no permissions rather than a
        # TypeError from ``x in None``.
        granted = ()
    elif isinstance(granted, str):
        # A bare string is ONE permission; ``in`` on the string itself would
        # be a substring test ("train" would match "training").
        granted = (granted,)

    return required_permission in granted or has_role(user_data, "admin")
|
||||
|
||||
# Export commonly used items
|
||||
# Names exposed by ``from <this module> import *``.
__all__ = [
    # FastAPI dependencies that enforce authentication / authorization
    "get_current_user",
    "get_current_tenant_id",
    "require_admin_role",
    "require_training_permission",
    # Non-enforcing and debug-only variants
    "get_current_user_optional",
    "get_tenant_id_optional",
    "get_test_tenant_id",
    # Plain helper predicates
    "has_role",
    "has_permission",
    # Exception types
    "AuthenticationError",
    "AuthorizationError",
]
|
||||
@@ -1,12 +1,260 @@
|
||||
# services/training/app/core/database.py
|
||||
"""
|
||||
Database configuration for training service
|
||||
Uses shared database infrastructure
|
||||
"""
|
||||
|
||||
from shared.database.base import DatabaseManager
|
||||
import structlog
|
||||
from typing import AsyncGenerator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
|
||||
from shared.database.base import DatabaseManager, Base
|
||||
from app.core.config import settings
|
||||
|
||||
# Module-level structured logger used by every helper in this file.
logger = structlog.get_logger()

# Initialize database manager using shared infrastructure
database_manager = DatabaseManager(settings.DATABASE_URL)

# Alias for convenience - matches the existing interface.
# Bound exactly once; a second identical assignment would be dead code.
get_db = database_manager.get_db
|
||||
|
||||
async def get_db_health() -> bool:
    """
    Health check function for database connectivity

    Runs ``SELECT 1`` through ``database_manager.async_engine``.

    Returns:
        bool: True if the probe statement executed, False if anything
            raised (the error is logged, never propagated).
    """
    probe = text("SELECT 1")
    try:
        async with database_manager.async_engine.begin() as connection:
            await connection.execute(probe)
        logger.debug("Database health check passed")
        return True
    except Exception as health_error:
        logger.error("Database health check failed", error=str(health_error))
        return False
|
||||
|
||||
# Training service specific database utilities
|
||||
class TrainingDatabaseUtils:
    """Training service specific database utilities

    All helpers are static and open their own session from
    ``database_manager.async_session_local()``. Queries are raw SQL against
    the ``model_training_logs`` and ``trained_models`` tables. Where SQL
    differs between backends, the SQLite form is used when
    ``settings.DATABASE_URL`` starts with ``"sqlite"``; the other branch is
    written for PostgreSQL (``NOW()``, ``make_interval``).
    """

    @staticmethod
    async def cleanup_old_training_logs(days_old: int = 90):
        """Clean up old training logs

        Deletes rows of ``model_training_logs`` whose ``start_time`` is more
        than ``days_old`` days in the past, then commits.

        Args:
            days_old: Minimum age, in days, of the rows to delete.

        Returns:
            The ``rowcount`` reported for the DELETE statement.

        Raises:
            ValueError: If ``days_old`` is negative.
            Exception: Any database error is logged and re-raised.
        """
        days_old = int(days_old)
        if days_old < 0:
            # A negative age would move the cutoff into the future on the
            # PostgreSQL branch and delete every row.
            raise ValueError("days_old must be non-negative")

        try:
            async with database_manager.async_session_local() as session:
                if settings.DATABASE_URL.startswith("sqlite"):
                    query = text(
                        "DELETE FROM model_training_logs "
                        "WHERE start_time < datetime('now', :days_param)"
                    )
                    params = {"days_param": f"-{days_old} days"}
                else:
                    # ``INTERVAL :param`` is not valid PostgreSQL when the
                    # parameter is bound server-side: the ``INTERVAL '...'``
                    # form only accepts a literal. make_interval() takes a
                    # normal integer argument instead.
                    query = text(
                        "DELETE FROM model_training_logs "
                        "WHERE start_time < NOW() - make_interval(days => :days_param)"
                    )
                    params = {"days_param": days_old}

                result = await session.execute(query, params)
                await session.commit()

                deleted_count = result.rowcount
                logger.info("Cleaned up old training logs",
                            deleted_count=deleted_count,
                            days_old=days_old)

                return deleted_count

        except Exception as e:
            logger.error("Training logs cleanup failed", error=str(e))
            raise

    @staticmethod
    async def cleanup_old_models(days_old: int = 365):
        """Clean up old inactive models

        Deletes rows of ``trained_models`` that are inactive and whose
        ``created_at`` is more than ``days_old`` days in the past, then
        commits.

        Args:
            days_old: Minimum age, in days, of the inactive rows to delete.

        Returns:
            The ``rowcount`` reported for the DELETE statement.

        Raises:
            ValueError: If ``days_old`` is negative.
            Exception: Any database error is logged and re-raised.
        """
        days_old = int(days_old)
        if days_old < 0:
            # See cleanup_old_training_logs: a negative age would select
            # every inactive model on the PostgreSQL branch.
            raise ValueError("days_old must be non-negative")

        try:
            async with database_manager.async_session_local() as session:
                if settings.DATABASE_URL.startswith("sqlite"):
                    query = text(
                        "DELETE FROM trained_models "
                        "WHERE is_active = 0 AND created_at < datetime('now', :days_param)"
                    )
                    params = {"days_param": f"-{days_old} days"}
                else:
                    # make_interval() instead of ``INTERVAL :param``, which
                    # PostgreSQL rejects for server-side bound parameters.
                    query = text(
                        "DELETE FROM trained_models "
                        "WHERE is_active = false "
                        "AND created_at < NOW() - make_interval(days => :days_param)"
                    )
                    params = {"days_param": days_old}

                result = await session.execute(query, params)
                await session.commit()

                deleted_count = result.rowcount
                logger.info("Cleaned up old models",
                            deleted_count=deleted_count,
                            days_old=days_old)

                return deleted_count

        except Exception as e:
            logger.error("Model cleanup failed", error=str(e))
            raise

    @staticmethod
    async def get_training_statistics(tenant_id: str = None) -> dict:
        """Get training statistics

        Args:
            tenant_id: When truthy, restrict all counts to this tenant;
                otherwise count across all tenants.

        Returns:
            dict with keys ``training_jobs`` (mapping of status -> number of
            training log rows), ``active_models``, ``inactive_models`` and
            ``total_models``. If any query fails the error is logged and the
            same structure is returned with empty/zero values, so callers
            cannot distinguish "no data" from "lookup failed".
        """
        try:
            async with database_manager.async_session_local() as session:
                # Base query for training logs
                if tenant_id:
                    logs_query = text(
                        "SELECT status, COUNT(*) as count "
                        "FROM model_training_logs "
                        "WHERE tenant_id = :tenant_id "
                        "GROUP BY status"
                    )
                    models_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM trained_models "
                        "WHERE tenant_id = :tenant_id AND is_active = :is_active"
                    )
                    params = {"tenant_id": tenant_id}
                else:
                    logs_query = text(
                        "SELECT status, COUNT(*) as count "
                        "FROM model_training_logs "
                        "GROUP BY status"
                    )
                    models_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM trained_models "
                        "WHERE is_active = :is_active"
                    )
                    params = {}

                # Get training job statistics
                logs_result = await session.execute(logs_query, params)
                job_stats = {row.status: row.count for row in logs_result.fetchall()}

                # Get active models count
                active_models_result = await session.execute(
                    models_query,
                    {**params, "is_active": True}
                )
                active_models = active_models_result.scalar() or 0

                # Get inactive models count (same statement, flag flipped)
                inactive_models_result = await session.execute(
                    models_query,
                    {**params, "is_active": False}
                )
                inactive_models = inactive_models_result.scalar() or 0

                return {
                    "training_jobs": job_stats,
                    "active_models": active_models,
                    "inactive_models": inactive_models,
                    "total_models": active_models + inactive_models
                }

        except Exception as e:
            logger.error("Failed to get training statistics", error=str(e))
            return {
                "training_jobs": {},
                "active_models": 0,
                "inactive_models": 0,
                "total_models": 0
            }

    @staticmethod
    async def check_tenant_data_exists(tenant_id: str) -> bool:
        """Check if tenant has any training data

        Args:
            tenant_id: Tenant whose ``model_training_logs`` rows are probed.

        Returns:
            bool: True if at least one row exists for the tenant. False if
            none exists or if the lookup failed (the failure is logged).
        """
        try:
            async with database_manager.async_session_local() as session:
                # Existence probe: fetch at most one matching row. A
                # ``COUNT(*) ... LIMIT 1`` would still count every row of
                # the tenant, because LIMIT applies to the single aggregate
                # result row, not to the rows being counted.
                query = text(
                    "SELECT 1 "
                    "FROM model_training_logs "
                    "WHERE tenant_id = :tenant_id "
                    "LIMIT 1"
                )

                result = await session.execute(query, {"tenant_id": tenant_id})

                return result.first() is not None

        except Exception as e:
            logger.error("Failed to check tenant data existence",
                         tenant_id=tenant_id, error=str(e))
            return False
|
||||
|
||||
# Enhanced database session dependency with better error handling
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Enhanced database session dependency with better logging and error handling

    Yields one session from ``database_manager.async_session_local()``.
    If an exception reaches the generator while the session is in use, it is
    logged with traceback, the session is rolled back, and the exception is
    re-raised. The session is closed in every case.

    Yields:
        AsyncSession: The session for the duration of the caller's work.
    """
    session_factory = database_manager.async_session_local
    async with session_factory() as session:
        try:
            logger.debug("Database session created")
            yield session
        except Exception as session_error:
            logger.error("Database session error", error=str(session_error), exc_info=True)
            await session.rollback()
            raise
        finally:
            # Explicit close, in addition to the context manager's own exit.
            await session.close()
            logger.debug("Database session closed")
|
||||
|
||||
# Database initialization for training service
|
||||
async def initialize_training_database():
    """Initialize database tables for training service

    Imports the training models and then calls
    ``database_manager.create_tables()``.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        logger.info("Initializing training service database")

        # Import models to ensure they're registered. The names are not
        # used here; the import itself is the point.
        from app.models.training import (  # noqa: F401
            ModelTrainingLog,
            TrainedModel,
            ModelPerformanceMetric,
            TrainingJobQueue,
            ModelArtifact,
        )

        # Create tables using shared infrastructure
        await database_manager.create_tables()

        logger.info("Training service database initialized successfully")

    except Exception as init_error:
        logger.error("Failed to initialize training service database", error=str(init_error))
        raise
|
||||
|
||||
# Database cleanup for training service
|
||||
async def cleanup_training_database():
    """Cleanup database connections for training service

    Disposes ``database_manager.async_engine`` when the manager has one.
    Best-effort: any error is logged and NOT re-raised.
    """
    try:
        logger.info("Cleaning up training service database connections")

        # Close engine connections (skipped when the attribute is missing
        # or falsy)
        engine = getattr(database_manager, 'async_engine', None)
        if engine:
            await engine.dispose()

        logger.info("Training service database cleanup completed")

    except Exception as cleanup_error:
        logger.error("Failed to cleanup training service database", error=str(cleanup_error))
|
||||
|
||||
# Export the commonly used items to maintain compatibility
|
||||
# Names exposed by ``from <this module> import *``; kept to maintain
# compatibility with existing importers.
__all__ = [
    # Shared infrastructure re-exported from here
    "Base",
    "database_manager",
    # Session dependencies and health probe
    "get_db",
    "get_db_session",
    "get_db_health",
    # Training-specific helpers and lifecycle hooks
    "TrainingDatabaseUtils",
    "initialize_training_database",
    "cleanup_training_database",
]
|
||||
Reference in New Issue
Block a user