REFACTOR - Database logic
This commit is contained in:
@@ -1,41 +1,48 @@
|
||||
# services/auth/app/api/auth.py - Fixed Login Method
|
||||
"""
|
||||
Authentication API endpoints - FIXED VERSION
|
||||
Enhanced Authentication API Endpoints
|
||||
Updated to use repository pattern with dependency injection and improved error handling
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import structlog
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.security import SecurityManager
|
||||
from app.services.auth_service import AuthService
|
||||
from app.schemas.auth import PasswordReset, UserRegistration, UserLogin, TokenResponse, RefreshTokenRequest, PasswordChange
|
||||
from app.schemas.auth import (
|
||||
UserRegistration, UserLogin, TokenResponse, RefreshTokenRequest,
|
||||
PasswordChange, PasswordReset, UserResponse
|
||||
)
|
||||
from app.services.auth_service import EnhancedAuthService
|
||||
from shared.database.base import create_database_manager
|
||||
from shared.monitoring.decorators import track_execution_time
|
||||
from shared.monitoring.metrics import get_metrics_collector
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter()
|
||||
router = APIRouter(tags=["enhanced-auth"])
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
def get_auth_service():
|
||||
"""Dependency injection for EnhancedAuthService"""
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "auth-service")
|
||||
return EnhancedAuthService(database_manager)
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse)
|
||||
@track_execution_time("registration_duration_seconds", "auth-service")
|
||||
@track_execution_time("enhanced_registration_duration_seconds", "auth-service")
|
||||
async def register(
|
||||
user_data: UserRegistration,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
auth_service: EnhancedAuthService = Depends(get_auth_service)
|
||||
):
|
||||
"""Register new user with enhanced debugging"""
|
||||
"""Register new user using enhanced repository pattern"""
|
||||
metrics = get_metrics_collector(request)
|
||||
|
||||
# ✅ DEBUG: Log incoming registration data (without password)
|
||||
logger.info(f"Registration attempt for email: {user_data.email}")
|
||||
logger.debug(f"Registration data - email: {user_data.email}, full_name: {user_data.full_name}, role: {user_data.role}")
|
||||
logger.info("Registration attempt using repository pattern",
|
||||
email=user_data.email)
|
||||
|
||||
try:
|
||||
# ✅ DEBUG: Validate input data
|
||||
# Enhanced input validation
|
||||
if not user_data.email or not user_data.email.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -54,65 +61,58 @@ async def register(
|
||||
detail="Full name is required"
|
||||
)
|
||||
|
||||
logger.debug(f"Input validation passed for {user_data.email}")
|
||||
|
||||
result = await AuthService.register_user(user_data, db)
|
||||
|
||||
logger.info(f"Registration successful for {user_data.email}")
|
||||
# Register user using enhanced service
|
||||
result = await auth_service.register_user(user_data)
|
||||
|
||||
# Record successful registration
|
||||
if metrics:
|
||||
metrics.increment_counter("registration_total", labels={"status": "success"})
|
||||
metrics.increment_counter("enhanced_registration_total", labels={"status": "success"})
|
||||
|
||||
# ✅ DEBUG: Validate response before returning
|
||||
if not result.get("access_token"):
|
||||
logger.error(f"Registration succeeded but no access_token in result for {user_data.email}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Registration completed but token generation failed"
|
||||
)
|
||||
|
||||
logger.debug(f"Returning token response for {user_data.email}")
|
||||
return TokenResponse(**result)
|
||||
logger.info("Registration successful using repository pattern",
|
||||
user_id=result.user.id,
|
||||
email=user_data.email)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException as e:
|
||||
# Record failed registration with specific error
|
||||
if metrics:
|
||||
error_type = "validation_error" if e.status_code == 400 else "conflict" if e.status_code == 409 else "failed"
|
||||
metrics.increment_counter("registration_total", labels={"status": error_type})
|
||||
metrics.increment_counter("enhanced_registration_total", labels={"status": error_type})
|
||||
|
||||
logger.warning(f"Registration failed for {user_data.email}: {e.detail}")
|
||||
logger.warning("Registration failed using repository pattern",
|
||||
email=user_data.email,
|
||||
error=e.detail)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
# Record registration system error
|
||||
if metrics:
|
||||
metrics.increment_counter("registration_total", labels={"status": "error"})
|
||||
metrics.increment_counter("enhanced_registration_total", labels={"status": "error"})
|
||||
|
||||
logger.error(f"Registration system error for {user_data.email}: {str(e)}", exc_info=True)
|
||||
|
||||
# ✅ DEBUG: Provide more specific error information in development
|
||||
error_detail = f"Registration failed: {str(e)}" if logger.level == "DEBUG" else "Registration failed"
|
||||
logger.error("Registration system error using repository pattern",
|
||||
email=user_data.email,
|
||||
error=str(e))
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=error_detail
|
||||
detail="Registration failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
@track_execution_time("login_duration_seconds", "auth-service")
|
||||
@track_execution_time("enhanced_login_duration_seconds", "auth-service")
|
||||
async def login(
|
||||
login_data: UserLogin,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
auth_service: EnhancedAuthService = Depends(get_auth_service)
|
||||
):
|
||||
"""Login user with enhanced debugging"""
|
||||
"""Login user using enhanced repository pattern"""
|
||||
metrics = get_metrics_collector(request)
|
||||
|
||||
logger.info(f"Login attempt for email: {login_data.email}")
|
||||
logger.info("Login attempt using repository pattern",
|
||||
email=login_data.email)
|
||||
|
||||
try:
|
||||
# ✅ DEBUG: Validate login data
|
||||
# Enhanced input validation
|
||||
if not login_data.email or not login_data.email.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -125,76 +125,88 @@ async def login(
|
||||
detail="Password is required"
|
||||
)
|
||||
|
||||
# Attempt login through AuthService
|
||||
result = await AuthService.login_user(login_data, db)
|
||||
# Login using enhanced service
|
||||
result = await auth_service.login_user(login_data)
|
||||
|
||||
# Record successful login
|
||||
if metrics:
|
||||
metrics.increment_counter("login_success_total")
|
||||
metrics.increment_counter("enhanced_login_success_total")
|
||||
|
||||
logger.info(f"Login successful for {login_data.email}")
|
||||
return TokenResponse(**result)
|
||||
logger.info("Login successful using repository pattern",
|
||||
user_id=result.user.id,
|
||||
email=login_data.email)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException as e:
|
||||
# Record failed login with specific reason
|
||||
if metrics:
|
||||
reason = "validation_error" if e.status_code == 400 else "auth_failed"
|
||||
metrics.increment_counter("login_failure_total", labels={"reason": reason})
|
||||
metrics.increment_counter("enhanced_login_failure_total", labels={"reason": reason})
|
||||
|
||||
logger.warning(f"Login failed for {login_data.email}: {e.detail}")
|
||||
logger.warning("Login failed using repository pattern",
|
||||
email=login_data.email,
|
||||
error=e.detail)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
# Record login system error
|
||||
if metrics:
|
||||
metrics.increment_counter("login_failure_total", labels={"reason": "error"})
|
||||
metrics.increment_counter("enhanced_login_failure_total", labels={"reason": "error"})
|
||||
|
||||
logger.error("Login system error using repository pattern",
|
||||
email=login_data.email,
|
||||
error=str(e))
|
||||
|
||||
logger.error(f"Login system error for {login_data.email}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Login failed"
|
||||
)
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
@track_execution_time("token_refresh_duration_seconds", "auth-service")
|
||||
|
||||
@router.post("/refresh")
|
||||
@track_execution_time("enhanced_token_refresh_duration_seconds", "auth-service")
|
||||
async def refresh_token(
|
||||
refresh_data: RefreshTokenRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
auth_service: EnhancedAuthService = Depends(get_auth_service)
|
||||
):
|
||||
"""Refresh access token"""
|
||||
"""Refresh access token using repository pattern"""
|
||||
metrics = get_metrics_collector(request)
|
||||
|
||||
try:
|
||||
result = await AuthService.refresh_access_token(refresh_data.refresh_token, db)
|
||||
result = await auth_service.refresh_access_token(refresh_data.refresh_token)
|
||||
|
||||
# Record successful refresh
|
||||
if metrics:
|
||||
metrics.increment_counter("token_refresh_success_total")
|
||||
metrics.increment_counter("enhanced_token_refresh_success_total")
|
||||
|
||||
return TokenResponse(**result)
|
||||
logger.debug("Access token refreshed using repository pattern")
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("token_refresh_failure_total")
|
||||
logger.warning(f"Token refresh failed: {e.detail}")
|
||||
metrics.increment_counter("enhanced_token_refresh_failure_total")
|
||||
logger.warning("Token refresh failed using repository pattern", error=e.detail)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("token_refresh_failure_total")
|
||||
logger.error(f"Token refresh error: {e}")
|
||||
metrics.increment_counter("enhanced_token_refresh_failure_total")
|
||||
logger.error("Token refresh error using repository pattern", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Token refresh failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/verify")
|
||||
@track_execution_time("token_verify_duration_seconds", "auth-service")
|
||||
@track_execution_time("enhanced_token_verify_duration_seconds", "auth-service")
|
||||
async def verify_token(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
request: Request = None
|
||||
request: Request = None,
|
||||
auth_service: EnhancedAuthService = Depends(get_auth_service)
|
||||
):
|
||||
"""Verify access token and return user info"""
|
||||
"""Verify access token using repository pattern"""
|
||||
metrics = get_metrics_collector(request) if request else None
|
||||
|
||||
try:
|
||||
@@ -204,74 +216,91 @@ async def verify_token(
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
result = await AuthService.verify_user_token(credentials.credentials)
|
||||
result = await auth_service.verify_user_token(credentials.credentials)
|
||||
|
||||
# Record successful verification
|
||||
if metrics:
|
||||
metrics.increment_counter("token_verify_success_total")
|
||||
metrics.increment_counter("enhanced_token_verify_success_total")
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"user_id": result.get("user_id"),
|
||||
"email": result.get("email"),
|
||||
"role": result.get("role"),
|
||||
"exp": result.get("exp"),
|
||||
"message": None
|
||||
}
|
||||
|
||||
except HTTPException as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("token_verify_failure_total")
|
||||
logger.warning(f"Token verification failed: {e.detail}")
|
||||
metrics.increment_counter("enhanced_token_verify_failure_total")
|
||||
logger.warning("Token verification failed using repository pattern", error=e.detail)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("token_verify_failure_total")
|
||||
logger.error(f"Token verification error: {e}")
|
||||
metrics.increment_counter("enhanced_token_verify_failure_total")
|
||||
logger.error("Token verification error using repository pattern", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
@track_execution_time("logout_duration_seconds", "auth-service")
|
||||
@track_execution_time("enhanced_logout_duration_seconds", "auth-service")
|
||||
async def logout(
|
||||
refresh_data: RefreshTokenRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
auth_service: EnhancedAuthService = Depends(get_auth_service)
|
||||
):
|
||||
"""Logout user by revoking refresh token"""
|
||||
"""Logout user using repository pattern"""
|
||||
metrics = get_metrics_collector(request)
|
||||
|
||||
try:
|
||||
success = await AuthService.logout(refresh_data.refresh_token, db)
|
||||
# Verify token to get user_id
|
||||
payload = await auth_service.verify_user_token(credentials.credentials)
|
||||
user_id = payload.get("user_id")
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token"
|
||||
)
|
||||
|
||||
success = await auth_service.logout_user(user_id, refresh_data.refresh_token)
|
||||
|
||||
if metrics:
|
||||
status_label = "success" if success else "failed"
|
||||
metrics.increment_counter("logout_total", labels={"status": status_label})
|
||||
metrics.increment_counter("enhanced_logout_total", labels={"status": status_label})
|
||||
|
||||
logger.info("Logout using repository pattern",
|
||||
user_id=user_id,
|
||||
success=success)
|
||||
|
||||
return {"message": "Logout successful" if success else "Logout failed"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("logout_total", labels={"status": "error"})
|
||||
logger.error(f"Logout error: {e}")
|
||||
metrics.increment_counter("enhanced_logout_total", labels={"status": "error"})
|
||||
logger.error("Logout error using repository pattern", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Logout failed"
|
||||
)
|
||||
|
||||
# ================================================================
|
||||
# PASSWORD MANAGEMENT ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.post("/change-password")
|
||||
async def change_password(
|
||||
password_data: PasswordChange,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
request: Request = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
auth_service: EnhancedAuthService = Depends(get_auth_service)
|
||||
):
|
||||
"""Change user password"""
|
||||
"""Change user password using repository pattern"""
|
||||
metrics = get_metrics_collector(request) if request else None
|
||||
|
||||
try:
|
||||
@@ -282,7 +311,7 @@ async def change_password(
|
||||
)
|
||||
|
||||
# Verify current token
|
||||
payload = await AuthService.verify_user_token(credentials.credentials)
|
||||
payload = await auth_service.verify_user_token(credentials.credentials)
|
||||
user_id = payload.get("user_id")
|
||||
|
||||
if not user_id:
|
||||
@@ -291,74 +320,194 @@ async def change_password(
|
||||
detail="Invalid token"
|
||||
)
|
||||
|
||||
# Validate new password
|
||||
if not SecurityManager.validate_password(password_data.new_password):
|
||||
# Validate new password length
|
||||
if len(password_data.new_password) < 8:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="New password does not meet security requirements"
|
||||
detail="New password must be at least 8 characters long"
|
||||
)
|
||||
|
||||
# Change password logic would go here
|
||||
# This is a simplified version - you'd need to implement the actual password change in AuthService
|
||||
# Change password using enhanced service
|
||||
success = await auth_service.change_password(
|
||||
user_id,
|
||||
password_data.current_password,
|
||||
password_data.new_password
|
||||
)
|
||||
|
||||
# Record password change
|
||||
if metrics:
|
||||
metrics.increment_counter("password_change_total", labels={"status": "success"})
|
||||
status_label = "success" if success else "failed"
|
||||
metrics.increment_counter("enhanced_password_change_total", labels={"status": status_label})
|
||||
|
||||
logger.info("Password changed using repository pattern",
|
||||
user_id=user_id,
|
||||
success=success)
|
||||
|
||||
logger.info(f"Password changed for user: {user_id}")
|
||||
return {"message": "Password changed successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Record password change error
|
||||
if metrics:
|
||||
metrics.increment_counter("password_change_total", labels={"status": "error"})
|
||||
logger.error(f"Password change error: {e}")
|
||||
metrics.increment_counter("enhanced_password_change_total", labels={"status": "error"})
|
||||
logger.error("Password change error using repository pattern", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Password change failed"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/profile", response_model=UserResponse)
|
||||
async def get_profile(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
auth_service: EnhancedAuthService = Depends(get_auth_service)
|
||||
):
|
||||
"""Get user profile using repository pattern"""
|
||||
try:
|
||||
if not credentials:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Verify token and get user_id
|
||||
payload = await auth_service.verify_user_token(credentials.credentials)
|
||||
user_id = payload.get("user_id")
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token"
|
||||
)
|
||||
|
||||
# Get user profile using enhanced service
|
||||
profile = await auth_service.get_user_profile(user_id)
|
||||
if not profile:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User profile not found"
|
||||
)
|
||||
|
||||
return profile
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get profile error using repository pattern", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get profile"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/profile", response_model=UserResponse)
|
||||
async def update_profile(
|
||||
update_data: dict,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
auth_service: EnhancedAuthService = Depends(get_auth_service)
|
||||
):
|
||||
"""Update user profile using repository pattern"""
|
||||
try:
|
||||
if not credentials:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Verify token and get user_id
|
||||
payload = await auth_service.verify_user_token(credentials.credentials)
|
||||
user_id = payload.get("user_id")
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token"
|
||||
)
|
||||
|
||||
# Update profile using enhanced service
|
||||
updated_profile = await auth_service.update_user_profile(user_id, update_data)
|
||||
if not updated_profile:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
logger.info("Profile updated using repository pattern",
|
||||
user_id=user_id,
|
||||
updated_fields=list(update_data.keys()))
|
||||
|
||||
return updated_profile
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Update profile error using repository pattern", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update profile"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/verify-email")
|
||||
async def verify_email(
|
||||
user_id: str,
|
||||
verification_token: str,
|
||||
auth_service: EnhancedAuthService = Depends(get_auth_service)
|
||||
):
|
||||
"""Verify user email using repository pattern"""
|
||||
try:
|
||||
success = await auth_service.verify_user_email(user_id, verification_token)
|
||||
|
||||
logger.info("Email verification using repository pattern",
|
||||
user_id=user_id,
|
||||
success=success)
|
||||
|
||||
return {"message": "Email verified successfully" if success else "Email verification failed"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Email verification error using repository pattern", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Email verification failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reset-password")
|
||||
async def reset_password(
|
||||
reset_data: PasswordReset,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
auth_service: EnhancedAuthService = Depends(get_auth_service)
|
||||
):
|
||||
"""Request password reset"""
|
||||
"""Request password reset using repository pattern"""
|
||||
metrics = get_metrics_collector(request)
|
||||
|
||||
try:
|
||||
# Password reset logic would go here
|
||||
# This is a simplified version - you'd need to implement email sending, etc.
|
||||
# In a full implementation, you'd send an email with a reset token
|
||||
# For now, just log the request
|
||||
|
||||
# Record password reset request
|
||||
if metrics:
|
||||
metrics.increment_counter("password_reset_total", labels={"status": "requested"})
|
||||
metrics.increment_counter("enhanced_password_reset_total", labels={"status": "requested"})
|
||||
|
||||
logger.info("Password reset requested using repository pattern",
|
||||
email=reset_data.email)
|
||||
|
||||
logger.info(f"Password reset requested for: {reset_data.email}")
|
||||
return {"message": "Password reset email sent if account exists"}
|
||||
|
||||
except Exception as e:
|
||||
# Record password reset error
|
||||
if metrics:
|
||||
metrics.increment_counter("password_reset_total", labels={"status": "error"})
|
||||
logger.error(f"Password reset error: {e}")
|
||||
metrics.increment_counter("enhanced_password_reset_total", labels={"status": "error"})
|
||||
logger.error("Password reset error using repository pattern", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Password reset failed"
|
||||
)
|
||||
|
||||
# ================================================================
|
||||
# HEALTH AND STATUS ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
"""Health check endpoint for enhanced auth service"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "auth-service",
|
||||
"version": "1.0.0"
|
||||
"service": "enhanced-auth-service",
|
||||
"version": "2.0.0",
|
||||
"features": ["repository-pattern", "dependency-injection", "enhanced-error-handling"]
|
||||
}
|
||||
@@ -95,8 +95,9 @@ async def lifespan(app: FastAPI):
|
||||
async def check_database():
|
||||
try:
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy import text
|
||||
async for db in get_db():
|
||||
await db.execute("SELECT 1")
|
||||
await db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
return f"Database error: {e}"
|
||||
|
||||
@@ -4,7 +4,7 @@ User models for authentication service - FIXED
|
||||
Removed tenant relationships to eliminate cross-service dependencies
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
@@ -56,18 +56,33 @@ class User(Base):
|
||||
"last_login": self.last_login.isoformat() if self.last_login else None
|
||||
}
|
||||
|
||||
|
||||
class RefreshToken(Base):
|
||||
"""Refresh token model for JWT authentication"""
|
||||
"""Refresh token model for JWT token management"""
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True) # No FK - cross-service
|
||||
token = Column(Text, unique=True, nullable=False) # CHANGED FROM String(255) TO Text
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
||||
token = Column(String(500), unique=True, nullable=False)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
is_revoked = Column(Boolean, default=False)
|
||||
revoked_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Timezone-aware datetime fields
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
revoked_at = Column(DateTime(timezone=True))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RefreshToken(user_id={self.user_id}, expires_at={self.expires_at})>"
|
||||
return f"<RefreshToken(id={self.id}, user_id={self.user_id}, is_revoked={self.is_revoked})>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert refresh token to dictionary"""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"user_id": str(self.user_id),
|
||||
"token": self.token,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"is_revoked": self.is_revoked,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None
|
||||
}
|
||||
14
services/auth/app/repositories/__init__.py
Normal file
14
services/auth/app/repositories/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Auth Service Repositories
|
||||
Repository implementations for authentication service
|
||||
"""
|
||||
|
||||
from .base import AuthBaseRepository
|
||||
from .user_repository import UserRepository
|
||||
from .token_repository import TokenRepository
|
||||
|
||||
__all__ = [
|
||||
"AuthBaseRepository",
|
||||
"UserRepository",
|
||||
"TokenRepository"
|
||||
]
|
||||
101
services/auth/app/repositories/base.py
Normal file
101
services/auth/app/repositories/base.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Base Repository for Auth Service
|
||||
Service-specific repository base class with auth service utilities
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Type
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
|
||||
from shared.database.repository import BaseRepository
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class AuthBaseRepository(BaseRepository):
|
||||
"""Base repository for auth service with common auth operations"""
|
||||
|
||||
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Auth data benefits from longer caching (10 minutes)
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get active records (if model has is_active field)"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional:
|
||||
"""Get record by email (if model has email field)"""
|
||||
if hasattr(self.model, 'email'):
|
||||
return await self.get_by_field("email", email)
|
||||
return None
|
||||
|
||||
async def get_by_username(self, username: str) -> Optional:
|
||||
"""Get record by username (if model has username field)"""
|
||||
if hasattr(self.model, 'username'):
|
||||
return await self.get_by_field("username", username)
|
||||
return None
|
||||
|
||||
async def deactivate_record(self, record_id: Any) -> Optional:
|
||||
"""Deactivate a record instead of deleting it"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": False})
|
||||
return await self.delete(record_id)
|
||||
|
||||
async def activate_record(self, record_id: Any) -> Optional:
|
||||
"""Activate a record"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": True})
|
||||
return await self.get_by_id(record_id)
|
||||
|
||||
async def cleanup_expired_records(self, field_name: str = "expires_at") -> int:
|
||||
"""Clean up expired records (for tokens, sessions, etc.)"""
|
||||
try:
|
||||
if not hasattr(self.model, field_name):
|
||||
logger.warning(f"Model {self.model.__name__} has no {field_name} field for cleanup")
|
||||
return 0
|
||||
|
||||
# This would need custom implementation with raw SQL for date comparison
|
||||
# For now, return 0 to indicate no cleanup performed
|
||||
logger.info(f"Cleanup requested for {self.model.__name__} but not implemented")
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired records",
|
||||
model=self.model.__name__,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
def _validate_auth_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
|
||||
"""Validate authentication-related data"""
|
||||
errors = []
|
||||
|
||||
for field in required_fields:
|
||||
if field not in data or not data[field]:
|
||||
errors.append(f"Missing required field: {field}")
|
||||
|
||||
# Validate email format if present
|
||||
if "email" in data and data["email"]:
|
||||
email = data["email"]
|
||||
if "@" not in email or "." not in email.split("@")[-1]:
|
||||
errors.append("Invalid email format")
|
||||
|
||||
# Validate password strength if present
|
||||
if "password" in data and data["password"]:
|
||||
password = data["password"]
|
||||
if len(password) < 8:
|
||||
errors.append("Password must be at least 8 characters long")
|
||||
|
||||
return {
|
||||
"is_valid": len(errors) == 0,
|
||||
"errors": errors
|
||||
}
|
||||
269
services/auth/app/repositories/token_repository.py
Normal file
269
services/auth/app/repositories/token_repository.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Token Repository
|
||||
Repository for refresh token operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import AuthBaseRepository
|
||||
from app.models.users import RefreshToken
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TokenRepository(AuthBaseRepository):
|
||||
"""Repository for refresh token operations"""
|
||||
|
||||
def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 300):
|
||||
# Tokens change frequently, shorter cache time
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def create_token(self, token_data: Dict[str, Any]) -> RefreshToken:
|
||||
"""Create a new refresh token from dictionary data"""
|
||||
return await self.create(token_data)
|
||||
|
||||
async def create_refresh_token(
|
||||
self,
|
||||
user_id: str,
|
||||
token: str,
|
||||
expires_at: datetime
|
||||
) -> RefreshToken:
|
||||
"""Create a new refresh token"""
|
||||
try:
|
||||
token_data = {
|
||||
"user_id": user_id,
|
||||
"token": token,
|
||||
"expires_at": expires_at,
|
||||
"is_revoked": False
|
||||
}
|
||||
|
||||
refresh_token = await self.create(token_data)
|
||||
|
||||
logger.debug("Refresh token created",
|
||||
user_id=user_id,
|
||||
token_id=refresh_token.id,
|
||||
expires_at=expires_at)
|
||||
|
||||
return refresh_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create refresh token",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create refresh token: {str(e)}")
|
||||
|
||||
async def get_token_by_value(self, token: str) -> Optional[RefreshToken]:
|
||||
"""Get refresh token by token value"""
|
||||
try:
|
||||
return await self.get_by_field("token", token)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get token by value", error=str(e))
|
||||
raise DatabaseError(f"Failed to get token: {str(e)}")
|
||||
|
||||
async def get_active_tokens_for_user(self, user_id: str) -> List[RefreshToken]:
|
||||
"""Get all active (non-revoked, non-expired) tokens for a user"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Use raw query for complex filtering
|
||||
query = text("""
|
||||
SELECT * FROM refresh_tokens
|
||||
WHERE user_id = :user_id
|
||||
AND is_revoked = false
|
||||
AND expires_at > :now
|
||||
ORDER BY created_at DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"user_id": user_id,
|
||||
"now": now
|
||||
})
|
||||
|
||||
# Convert rows to RefreshToken objects
|
||||
tokens = []
|
||||
for row in result.fetchall():
|
||||
token = RefreshToken(
|
||||
id=row.id,
|
||||
user_id=row.user_id,
|
||||
token=row.token,
|
||||
expires_at=row.expires_at,
|
||||
is_revoked=row.is_revoked,
|
||||
created_at=row.created_at,
|
||||
revoked_at=row.revoked_at
|
||||
)
|
||||
tokens.append(token)
|
||||
|
||||
return tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get active tokens for user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get active tokens: {str(e)}")
|
||||
|
||||
async def revoke_token(self, token_id: str) -> Optional[RefreshToken]:
|
||||
"""Revoke a refresh token"""
|
||||
try:
|
||||
return await self.update(token_id, {
|
||||
"is_revoked": True,
|
||||
"revoked_at": datetime.now(timezone.utc)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Failed to revoke token",
|
||||
token_id=token_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to revoke token: {str(e)}")
|
||||
|
||||
async def revoke_all_user_tokens(self, user_id: str) -> int:
|
||||
"""Revoke all tokens for a user"""
|
||||
try:
|
||||
# Use bulk update for efficiency
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
query = text("""
|
||||
UPDATE refresh_tokens
|
||||
SET is_revoked = true, revoked_at = :revoked_at
|
||||
WHERE user_id = :user_id AND is_revoked = false
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"user_id": user_id,
|
||||
"revoked_at": now
|
||||
})
|
||||
|
||||
revoked_count = result.rowcount
|
||||
|
||||
logger.info("Revoked all user tokens",
|
||||
user_id=user_id,
|
||||
revoked_count=revoked_count)
|
||||
|
||||
return revoked_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to revoke all user tokens",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to revoke user tokens: {str(e)}")
|
||||
|
||||
async def is_token_valid(self, token: str) -> bool:
|
||||
"""Check if a token is valid (exists, not revoked, not expired)"""
|
||||
try:
|
||||
refresh_token = await self.get_token_by_value(token)
|
||||
|
||||
if not refresh_token:
|
||||
return False
|
||||
|
||||
if refresh_token.is_revoked:
|
||||
return False
|
||||
|
||||
if refresh_token.expires_at < datetime.now(timezone.utc):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate token", error=str(e))
|
||||
return False
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Clean up expired refresh tokens"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Delete expired tokens
|
||||
query = text("""
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE expires_at < :now
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {"now": now})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up expired tokens",
|
||||
deleted_count=deleted_count)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired tokens", error=str(e))
|
||||
raise DatabaseError(f"Token cleanup failed: {str(e)}")
|
||||
|
||||
async def cleanup_old_revoked_tokens(self, days_old: int = 30) -> int:
|
||||
"""Clean up old revoked tokens"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
|
||||
|
||||
query = text("""
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE is_revoked = true
|
||||
AND revoked_at < :cutoff_date
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"cutoff_date": cutoff_date
|
||||
})
|
||||
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up old revoked tokens",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old revoked tokens",
|
||||
days_old=days_old,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Revoked token cleanup failed: {str(e)}")
|
||||
|
||||
async def get_token_statistics(self) -> Dict[str, Any]:
|
||||
"""Get token statistics"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Get counts with raw queries
|
||||
stats_query = text("""
|
||||
SELECT
|
||||
COUNT(*) as total_tokens,
|
||||
COUNT(CASE WHEN is_revoked = false AND expires_at > :now THEN 1 END) as active_tokens,
|
||||
COUNT(CASE WHEN is_revoked = true THEN 1 END) as revoked_tokens,
|
||||
COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_tokens,
|
||||
COUNT(DISTINCT user_id) as users_with_tokens
|
||||
FROM refresh_tokens
|
||||
""")
|
||||
|
||||
result = await self.session.execute(stats_query, {"now": now})
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
return {
|
||||
"total_tokens": row.total_tokens,
|
||||
"active_tokens": row.active_tokens,
|
||||
"revoked_tokens": row.revoked_tokens,
|
||||
"expired_tokens": row.expired_tokens,
|
||||
"users_with_tokens": row.users_with_tokens
|
||||
}
|
||||
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"active_tokens": 0,
|
||||
"revoked_tokens": 0,
|
||||
"expired_tokens": 0,
|
||||
"users_with_tokens": 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get token statistics", error=str(e))
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"active_tokens": 0,
|
||||
"revoked_tokens": 0,
|
||||
"expired_tokens": 0,
|
||||
"users_with_tokens": 0
|
||||
}
|
||||
277
services/auth/app/repositories/user_repository.py
Normal file
277
services/auth/app/repositories/user_repository.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
User Repository
|
||||
Repository for user operations with authentication-specific queries
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_, func, desc, text
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import AuthBaseRepository
|
||||
from app.models.users import User
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class UserRepository(AuthBaseRepository):
|
||||
"""Repository for user operations"""
|
||||
|
||||
def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def create_user(self, user_data: Dict[str, Any]) -> User:
|
||||
"""Create a new user with validation"""
|
||||
try:
|
||||
# Validate user data
|
||||
validation_result = self._validate_auth_data(
|
||||
user_data,
|
||||
["email", "hashed_password", "full_name", "role"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid user data: {validation_result['errors']}")
|
||||
|
||||
# Check if user already exists
|
||||
existing_user = await self.get_by_email(user_data["email"])
|
||||
if existing_user:
|
||||
raise DuplicateRecordError(f"User with email {user_data['email']} already exists")
|
||||
|
||||
# Create user
|
||||
user = await self.create(user_data)
|
||||
|
||||
logger.info("User created successfully",
|
||||
user_id=user.id,
|
||||
email=user.email,
|
||||
role=user.role)
|
||||
|
||||
return user
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create user",
|
||||
email=user_data.get("email"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create user: {str(e)}")
|
||||
|
||||
async def get_user_by_email(self, email: str) -> Optional[User]:
|
||||
"""Get user by email address"""
|
||||
return await self.get_by_email(email)
|
||||
|
||||
async def get_active_users(self, skip: int = 0, limit: int = 100) -> List[User]:
|
||||
"""Get all active users"""
|
||||
return await self.get_active_records(skip=skip, limit=limit)
|
||||
|
||||
async def authenticate_user(self, email: str, password: str) -> Optional[User]:
|
||||
"""Authenticate user with email and plain password"""
|
||||
try:
|
||||
user = await self.get_by_email(email)
|
||||
|
||||
if not user:
|
||||
logger.debug("User not found for authentication", email=email)
|
||||
return None
|
||||
|
||||
if not user.is_active:
|
||||
logger.debug("User account is inactive", email=email)
|
||||
return None
|
||||
|
||||
# Verify password using security manager
|
||||
from app.core.security import SecurityManager
|
||||
if SecurityManager.verify_password(password, user.hashed_password):
|
||||
# Update last login
|
||||
await self.update_last_login(user.id)
|
||||
logger.info("User authenticated successfully",
|
||||
user_id=user.id,
|
||||
email=email)
|
||||
return user
|
||||
|
||||
logger.debug("Invalid password for user", email=email)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Authentication failed",
|
||||
email=email,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Authentication failed: {str(e)}")
|
||||
|
||||
async def update_last_login(self, user_id: str) -> Optional[User]:
|
||||
"""Update user's last login timestamp"""
|
||||
try:
|
||||
return await self.update(user_id, {
|
||||
"last_login": datetime.now(timezone.utc)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Failed to update last login",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
# Don't raise here - last login update is not critical
|
||||
return None
|
||||
|
||||
async def update_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> Optional[User]:
|
||||
"""Update user profile information"""
|
||||
try:
|
||||
# Remove sensitive fields that shouldn't be updated via profile
|
||||
profile_data.pop("id", None)
|
||||
profile_data.pop("hashed_password", None)
|
||||
profile_data.pop("created_at", None)
|
||||
profile_data.pop("is_active", None)
|
||||
|
||||
# Validate email if being updated
|
||||
if "email" in profile_data:
|
||||
validation_result = self._validate_auth_data(
|
||||
profile_data,
|
||||
["email"]
|
||||
)
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid profile data: {validation_result['errors']}")
|
||||
|
||||
# Check for email conflicts
|
||||
existing_user = await self.get_by_email(profile_data["email"])
|
||||
if existing_user and str(existing_user.id) != str(user_id):
|
||||
raise DuplicateRecordError(f"Email {profile_data['email']} is already in use")
|
||||
|
||||
updated_user = await self.update(user_id, profile_data)
|
||||
|
||||
if updated_user:
|
||||
logger.info("User profile updated",
|
||||
user_id=user_id,
|
||||
updated_fields=list(profile_data.keys()))
|
||||
|
||||
return updated_user
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update user profile",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update profile: {str(e)}")
|
||||
|
||||
async def change_password(self, user_id: str, new_password_hash: str) -> bool:
|
||||
"""Change user password"""
|
||||
try:
|
||||
updated_user = await self.update(user_id, {
|
||||
"hashed_password": new_password_hash
|
||||
})
|
||||
|
||||
if updated_user:
|
||||
logger.info("Password changed successfully", user_id=user_id)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to change password",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to change password: {str(e)}")
|
||||
|
||||
async def verify_user_email(self, user_id: str) -> Optional[User]:
|
||||
"""Mark user email as verified"""
|
||||
try:
|
||||
return await self.update(user_id, {
|
||||
"is_verified": True
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Failed to verify user email",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to verify email: {str(e)}")
|
||||
|
||||
async def deactivate_user(self, user_id: str) -> Optional[User]:
|
||||
"""Deactivate user account"""
|
||||
return await self.deactivate_record(user_id)
|
||||
|
||||
async def activate_user(self, user_id: str) -> Optional[User]:
|
||||
"""Activate user account"""
|
||||
return await self.activate_record(user_id)
|
||||
|
||||
async def get_users_by_role(self, role: str, skip: int = 0, limit: int = 100) -> List[User]:
|
||||
"""Get users by role"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"role": role, "is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get users by role",
|
||||
role=role,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get users by role: {str(e)}")
|
||||
|
||||
async def search_users(self, search_term: str, skip: int = 0, limit: int = 50) -> List[User]:
|
||||
"""Search users by email or full name"""
|
||||
try:
|
||||
return await self.search(
|
||||
search_term=search_term,
|
||||
search_fields=["email", "full_name"],
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to search users",
|
||||
search_term=search_term,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to search users: {str(e)}")
|
||||
|
||||
async def get_user_statistics(self) -> Dict[str, Any]:
|
||||
"""Get user statistics"""
|
||||
try:
|
||||
# Get basic counts
|
||||
total_users = await self.count()
|
||||
active_users = await self.count(filters={"is_active": True})
|
||||
verified_users = await self.count(filters={"is_verified": True})
|
||||
|
||||
# Get users by role using raw query
|
||||
role_query = text("""
|
||||
SELECT role, COUNT(*) as count
|
||||
FROM users
|
||||
WHERE is_active = true
|
||||
GROUP BY role
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(role_query)
|
||||
role_stats = {row.role: row.count for row in result.fetchall()}
|
||||
|
||||
# Recent activity (users created in last 30 days)
|
||||
thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
recent_users_query = text("""
|
||||
SELECT COUNT(*) as count
|
||||
FROM users
|
||||
WHERE created_at >= :thirty_days_ago
|
||||
""")
|
||||
|
||||
recent_result = await self.session.execute(
|
||||
recent_users_query,
|
||||
{"thirty_days_ago": thirty_days_ago}
|
||||
)
|
||||
recent_users = recent_result.scalar() or 0
|
||||
|
||||
return {
|
||||
"total_users": total_users,
|
||||
"active_users": active_users,
|
||||
"inactive_users": total_users - active_users,
|
||||
"verified_users": verified_users,
|
||||
"unverified_users": active_users - verified_users,
|
||||
"recent_registrations": recent_users,
|
||||
"users_by_role": role_stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user statistics", error=str(e))
|
||||
return {
|
||||
"total_users": 0,
|
||||
"active_users": 0,
|
||||
"inactive_users": 0,
|
||||
"verified_users": 0,
|
||||
"unverified_users": 0,
|
||||
"recent_registrations": 0,
|
||||
"users_by_role": {}
|
||||
}
|
||||
@@ -106,6 +106,17 @@ class UserResponse(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True # ✅ Enable ORM mode for SQLAlchemy objects
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""User update schema"""
|
||||
full_name: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
timezone: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class TokenVerification(BaseModel):
|
||||
"""Token verification response"""
|
||||
valid: bool
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
Auth Service Layer
|
||||
Business logic services for authentication and user management
|
||||
"""
|
||||
|
||||
from .auth_service import AuthService
|
||||
from .auth_service import EnhancedAuthService
|
||||
from .user_service import UserService
|
||||
from .auth_service import EnhancedUserService
|
||||
from .auth_service_clients import AuthServiceClientFactory
|
||||
from .admin_delete import AdminUserDeleteService
|
||||
from .messaging import (
|
||||
publish_user_registered,
|
||||
publish_user_login,
|
||||
publish_user_updated,
|
||||
publish_user_deactivated
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AuthService",
|
||||
"EnhancedAuthService",
|
||||
"UserService",
|
||||
"EnhancedUserService",
|
||||
"AuthServiceClientFactory",
|
||||
"AdminUserDeleteService",
|
||||
"publish_user_registered",
|
||||
"publish_user_login",
|
||||
"publish_user_updated",
|
||||
"publish_user_deactivated"
|
||||
]
|
||||
@@ -1,310 +1,284 @@
|
||||
# services/auth/app/services/auth_service.py - UPDATED WITH NEW REGISTRATION METHOD
|
||||
"""
|
||||
Authentication Service - Updated to support registration with direct token issuance
|
||||
Enhanced Authentication Service
|
||||
Updated to use repository pattern with dependency injection and improved error handling
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
import structlog
|
||||
|
||||
from app.repositories import UserRepository, TokenRepository
|
||||
from app.schemas.auth import UserRegistration, UserLogin, TokenResponse, UserResponse
|
||||
from app.models.users import User, RefreshToken
|
||||
from app.schemas.auth import UserRegistration, UserLogin
|
||||
from app.core.security import SecurityManager
|
||||
from app.services.messaging import publish_user_registered, publish_user_login
|
||||
from shared.database.unit_of_work import UnitOfWork
|
||||
from shared.database.transactions import transactional
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class AuthService:
|
||||
"""Enhanced Authentication service with unified token response"""
|
||||
|
||||
@staticmethod
|
||||
async def register_user(user_data: UserRegistration, db: AsyncSession) -> Dict[str, Any]:
|
||||
"""Register a new user with FIXED token generation"""
|
||||
# Legacy compatibility alias
|
||||
AuthService = None # Will be set at the end of the file
|
||||
|
||||
|
||||
class EnhancedAuthService:
|
||||
"""Enhanced authentication service using repository pattern"""
|
||||
|
||||
def __init__(self, database_manager):
|
||||
"""Initialize service with database manager"""
|
||||
self.database_manager = database_manager
|
||||
|
||||
async def register_user(
|
||||
self,
|
||||
user_data: UserRegistration
|
||||
) -> TokenResponse:
|
||||
"""Register a new user using repository pattern"""
|
||||
try:
|
||||
# Check if user already exists
|
||||
existing_user = await db.execute(
|
||||
select(User).where(User.email == user_data.email)
|
||||
)
|
||||
if existing_user.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User with this email already exists"
|
||||
)
|
||||
|
||||
user_role = user_data.role if user_data.role else "user"
|
||||
|
||||
# Create new user
|
||||
hashed_password = SecurityManager.hash_password(user_data.password)
|
||||
new_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=user_data.email,
|
||||
full_name=user_data.full_name,
|
||||
hashed_password=hashed_password,
|
||||
is_active=True,
|
||||
is_verified=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
role=user_role
|
||||
)
|
||||
|
||||
db.add(new_user)
|
||||
await db.flush() # Get user ID without committing
|
||||
|
||||
logger.debug(f"User created with role: {new_user.role} for {user_data.email}")
|
||||
|
||||
# ✅ FIX 1: Create SEPARATE access and refresh tokens with different payloads
|
||||
access_token_data = {
|
||||
"user_id": str(new_user.id),
|
||||
"email": new_user.email,
|
||||
"full_name": new_user.full_name,
|
||||
"is_verified": new_user.is_verified,
|
||||
"is_active": new_user.is_active,
|
||||
"role": new_user.role,
|
||||
"type": "access" # ✅ Explicitly mark as access token
|
||||
}
|
||||
|
||||
refresh_token_data = {
|
||||
"user_id": str(new_user.id),
|
||||
"email": new_user.email,
|
||||
"type": "refresh" # ✅ Explicitly mark as refresh token
|
||||
}
|
||||
|
||||
logger.debug(f"Creating tokens for registration: {user_data.email}")
|
||||
|
||||
# ✅ FIX 2: Generate tokens with different payloads
|
||||
access_token = SecurityManager.create_access_token(user_data=access_token_data)
|
||||
refresh_token_value = SecurityManager.create_refresh_token(user_data=refresh_token_data)
|
||||
|
||||
logger.debug(f"Tokens created successfully for {user_data.email}")
|
||||
|
||||
# ✅ FIX 3: Store ONLY the refresh token in database (not access token)
|
||||
refresh_token = RefreshToken(
|
||||
id=uuid.uuid4(),
|
||||
user_id=new_user.id,
|
||||
token=refresh_token_value, # Store the actual refresh token
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=30),
|
||||
is_revoked=False,
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db.add(refresh_token)
|
||||
await db.commit()
|
||||
|
||||
# Publish registration event (non-blocking)
|
||||
try:
|
||||
await publish_user_registered({
|
||||
"user_id": str(new_user.id),
|
||||
"email": new_user.email,
|
||||
"full_name": new_user.full_name,
|
||||
"role": new_user.role,
|
||||
"registered_at": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to publish registration event: {e}")
|
||||
|
||||
logger.info(f"User registered successfully: {user_data.email}")
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token_value,
|
||||
"token_type": "bearer",
|
||||
"expires_in": 1800, # 30 minutes
|
||||
"user": {
|
||||
"id": str(new_user.id),
|
||||
"email": new_user.email,
|
||||
"full_name": new_user.full_name,
|
||||
"is_active": new_user.is_active,
|
||||
"is_verified": new_user.is_verified,
|
||||
"created_at": new_user.created_at.isoformat(),
|
||||
"role": new_user.role
|
||||
}
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
await db.rollback()
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
async with UnitOfWork(db_session) as uow:
|
||||
# Register repositories
|
||||
user_repo = uow.register_repository("users", UserRepository, User)
|
||||
token_repo = uow.register_repository("tokens", TokenRepository, RefreshToken)
|
||||
|
||||
# Check if user already exists
|
||||
existing_user = await user_repo.get_by_email(user_data.email)
|
||||
if existing_user:
|
||||
raise DuplicateRecordError("User with this email already exists")
|
||||
|
||||
# Create user data
|
||||
user_role = user_data.role if user_data.role else "user"
|
||||
hashed_password = SecurityManager.hash_password(user_data.password)
|
||||
|
||||
create_data = {
|
||||
"email": user_data.email,
|
||||
"full_name": user_data.full_name,
|
||||
"hashed_password": hashed_password,
|
||||
"is_active": True,
|
||||
"is_verified": False,
|
||||
"role": user_role
|
||||
}
|
||||
|
||||
# Create user using repository
|
||||
new_user = await user_repo.create_user(create_data)
|
||||
|
||||
logger.debug("User created with repository pattern",
|
||||
user_id=new_user.id,
|
||||
email=user_data.email,
|
||||
role=user_role)
|
||||
|
||||
# Create tokens with different payloads
|
||||
access_token_data = {
|
||||
"user_id": str(new_user.id),
|
||||
"email": new_user.email,
|
||||
"full_name": new_user.full_name,
|
||||
"is_verified": new_user.is_verified,
|
||||
"is_active": new_user.is_active,
|
||||
"role": new_user.role,
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
refresh_token_data = {
|
||||
"user_id": str(new_user.id),
|
||||
"email": new_user.email,
|
||||
"type": "refresh"
|
||||
}
|
||||
|
||||
# Generate tokens
|
||||
access_token = SecurityManager.create_access_token(user_data=access_token_data)
|
||||
refresh_token_value = SecurityManager.create_refresh_token(user_data=refresh_token_data)
|
||||
|
||||
# Store refresh token using repository
|
||||
token_data = {
|
||||
"user_id": str(new_user.id),
|
||||
"token": refresh_token_value,
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=30),
|
||||
"is_revoked": False
|
||||
}
|
||||
|
||||
await token_repo.create_token(token_data)
|
||||
|
||||
# Commit transaction
|
||||
await uow.commit()
|
||||
|
||||
# Publish registration event (non-blocking)
|
||||
try:
|
||||
await publish_user_registered({
|
||||
"user_id": str(new_user.id),
|
||||
"email": new_user.email,
|
||||
"full_name": new_user.full_name,
|
||||
"role": new_user.role,
|
||||
"registered_at": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish registration event", error=str(e))
|
||||
|
||||
logger.info("User registered successfully using repository pattern",
|
||||
user_id=new_user.id,
|
||||
email=user_data.email)
|
||||
|
||||
from app.schemas.auth import UserData
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token_value,
|
||||
token_type="bearer",
|
||||
expires_in=1800,
|
||||
user=UserData(
|
||||
id=str(new_user.id),
|
||||
email=new_user.email,
|
||||
full_name=new_user.full_name,
|
||||
is_active=new_user.is_active,
|
||||
is_verified=new_user.is_verified,
|
||||
created_at=new_user.created_at.isoformat() if new_user.created_at else datetime.now(timezone.utc).isoformat(),
|
||||
role=new_user.role
|
||||
)
|
||||
)
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Registration failed for {user_data.email}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Registration failed"
|
||||
)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Registration failed for {user_data.email}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Registration failed"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def login_user(login_data: UserLogin, db: AsyncSession) -> Dict[str, Any]:
|
||||
"""Login user with FIXED token generation and SQLAlchemy syntax"""
|
||||
logger.error("Registration failed using repository pattern",
|
||||
email=user_data.email,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Registration failed: {str(e)}")
|
||||
|
||||
async def login_user(
|
||||
self,
|
||||
login_data: UserLogin
|
||||
) -> TokenResponse:
|
||||
"""Login user using repository pattern"""
|
||||
try:
|
||||
# Find user
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == login_data.email)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not SecurityManager.verify_password(login_data.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Account is deactivated"
|
||||
)
|
||||
|
||||
# ✅ FIX 4: Revoke existing refresh tokens using proper SQLAlchemy ORM syntax
|
||||
logger.debug(f"Revoking existing refresh tokens for user: {user.id}")
|
||||
|
||||
# Using SQLAlchemy ORM update (more reliable than raw SQL)
|
||||
stmt = update(RefreshToken).where(
|
||||
RefreshToken.user_id == user.id,
|
||||
RefreshToken.is_revoked == False
|
||||
).values(
|
||||
is_revoked=True,
|
||||
revoked_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
revoked_count = result.rowcount
|
||||
logger.debug(f"Revoked {revoked_count} existing refresh tokens for user: {user.id}")
|
||||
|
||||
# ✅ FIX 5: Create DIFFERENT token payloads
|
||||
access_token_data = {
|
||||
"user_id": str(user.id),
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"is_verified": user.is_verified,
|
||||
"is_active": user.is_active,
|
||||
"role": user.role,
|
||||
"type": "access" # ✅ Explicitly mark as access token
|
||||
}
|
||||
|
||||
refresh_token_data = {
|
||||
"user_id": str(user.id),
|
||||
"email": user.email,
|
||||
"type": "refresh", # ✅ Explicitly mark as refresh token
|
||||
"jti": str(uuid.uuid4()) # ✅ Add unique identifier for each refresh token
|
||||
}
|
||||
|
||||
logger.debug(f"Creating access token for login with data: {list(access_token_data.keys())}")
|
||||
logger.debug(f"Creating refresh token for login with data: {list(refresh_token_data.keys())}")
|
||||
|
||||
# ✅ FIX 6: Generate tokens with different payloads and expiration
|
||||
access_token = SecurityManager.create_access_token(user_data=access_token_data)
|
||||
refresh_token_value = SecurityManager.create_refresh_token(user_data=refresh_token_data)
|
||||
|
||||
logger.debug(f"Access token created successfully for user {login_data.email}")
|
||||
logger.debug(f"Refresh token created successfully for user {str(user.id)}")
|
||||
|
||||
# ✅ FIX 7: Store ONLY refresh token in database with unique constraint handling
|
||||
refresh_token = RefreshToken(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user.id,
|
||||
token=refresh_token_value, # This should be the refresh token, not access token
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=30),
|
||||
is_revoked=False,
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db.add(refresh_token)
|
||||
|
||||
# Update last login
|
||||
user.last_login = datetime.now(timezone.utc)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Publish login event (non-blocking)
|
||||
try:
|
||||
await publish_user_login({
|
||||
"user_id": str(user.id),
|
||||
"email": user.email,
|
||||
"login_at": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to publish login event: {e}")
|
||||
|
||||
logger.info(f"User logged in successfully: {login_data.email}")
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token_value,
|
||||
"token_type": "bearer",
|
||||
"expires_in": 1800, # 30 minutes
|
||||
"user": {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"is_active": user.is_active,
|
||||
"is_verified": user.is_verified,
|
||||
"created_at": user.created_at.isoformat(),
|
||||
"role": user.role
|
||||
}
|
||||
}
|
||||
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
async with UnitOfWork(db_session) as uow:
|
||||
# Register repositories
|
||||
user_repo = uow.register_repository("users", UserRepository, User)
|
||||
token_repo = uow.register_repository("tokens", TokenRepository, RefreshToken)
|
||||
|
||||
# Authenticate user using repository
|
||||
user = await user_repo.authenticate_user(login_data.email, login_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Account is deactivated"
|
||||
)
|
||||
|
||||
# Revoke existing refresh tokens using repository
|
||||
await token_repo.revoke_all_user_tokens(str(user.id))
|
||||
|
||||
logger.debug("Existing tokens revoked using repository pattern",
|
||||
user_id=user.id)
|
||||
|
||||
# Create tokens with different payloads
|
||||
access_token_data = {
|
||||
"user_id": str(user.id),
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"is_verified": user.is_verified,
|
||||
"is_active": user.is_active,
|
||||
"role": user.role,
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
refresh_token_data = {
|
||||
"user_id": str(user.id),
|
||||
"email": user.email,
|
||||
"type": "refresh",
|
||||
"jti": str(uuid.uuid4())
|
||||
}
|
||||
|
||||
# Generate tokens
|
||||
access_token = SecurityManager.create_access_token(user_data=access_token_data)
|
||||
refresh_token_value = SecurityManager.create_refresh_token(user_data=refresh_token_data)
|
||||
|
||||
# Store refresh token using repository
|
||||
token_data = {
|
||||
"user_id": str(user.id),
|
||||
"token": refresh_token_value,
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=30),
|
||||
"is_revoked": False
|
||||
}
|
||||
|
||||
await token_repo.create_token(token_data)
|
||||
|
||||
# Update last login using repository
|
||||
await user_repo.update_last_login(str(user.id))
|
||||
|
||||
# Commit transaction
|
||||
await uow.commit()
|
||||
|
||||
# Publish login event (non-blocking)
|
||||
try:
|
||||
await publish_user_login({
|
||||
"user_id": str(user.id),
|
||||
"email": user.email,
|
||||
"login_at": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish login event", error=str(e))
|
||||
|
||||
logger.info("User logged in successfully using repository pattern",
|
||||
user_id=user.id,
|
||||
email=login_data.email)
|
||||
|
||||
from app.schemas.auth import UserData
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token_value,
|
||||
token_type="bearer",
|
||||
expires_in=1800,
|
||||
user=UserData(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
created_at=user.created_at.isoformat() if user.created_at else datetime.now(timezone.utc).isoformat(),
|
||||
role=user.role
|
||||
)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
await db.rollback()
|
||||
raise
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Login failed for {login_data.email}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Login failed"
|
||||
)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Login failed for {login_data.email}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Login failed"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def logout_user(user_id: str, refresh_token: str, db: AsyncSession) -> bool:
|
||||
"""Logout user by revoking refresh token"""
|
||||
logger.error("Login failed using repository pattern",
|
||||
email=login_data.email,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Login failed: {str(e)}")
|
||||
|
||||
async def logout_user(self, user_id: str, refresh_token: str) -> bool:
|
||||
"""Logout user using repository pattern"""
|
||||
try:
|
||||
# Revoke the specific refresh token using ORM
|
||||
stmt = update(RefreshToken).where(
|
||||
RefreshToken.user_id == user_id,
|
||||
RefreshToken.token == refresh_token,
|
||||
RefreshToken.is_revoked == False
|
||||
).values(
|
||||
is_revoked=True,
|
||||
revoked_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
|
||||
if result.rowcount > 0:
|
||||
await db.commit()
|
||||
logger.info(f"User logged out successfully: {user_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async with self.database_manager.get_session() as session:
|
||||
token_repo = TokenRepository(session)
|
||||
|
||||
# Revoke specific refresh token using repository
|
||||
success = await token_repo.revoke_token(user_id, refresh_token)
|
||||
|
||||
if success:
|
||||
logger.info("User logged out successfully using repository pattern",
|
||||
user_id=user_id)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Logout failed for user {user_id}: {e}")
|
||||
logger.error("Logout failed using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def refresh_access_token(refresh_token: str, db: AsyncSession) -> Dict[str, Any]:
|
||||
"""Refresh access token using refresh token"""
|
||||
|
||||
async def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
"""Refresh access token using repository pattern"""
|
||||
try:
|
||||
# Verify refresh token
|
||||
payload = SecurityManager.decode_token(refresh_token)
|
||||
@@ -316,66 +290,59 @@ class AuthService:
|
||||
detail="Invalid refresh token"
|
||||
)
|
||||
|
||||
# Check if refresh token exists and is valid using ORM
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.user_id == user_id,
|
||||
RefreshToken.token == refresh_token,
|
||||
RefreshToken.is_revoked == False,
|
||||
RefreshToken.expires_at > datetime.now(timezone.utc)
|
||||
)
|
||||
)
|
||||
stored_token = result.scalar_one_or_none()
|
||||
|
||||
if not stored_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token"
|
||||
)
|
||||
|
||||
# Get user
|
||||
user_result = await db.execute(
|
||||
select(User).where(User.id == user_id)
|
||||
)
|
||||
user = user_result.scalar_one_or_none()
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or inactive"
|
||||
)
|
||||
|
||||
# Create new access token
|
||||
access_token_data = {
|
||||
"user_id": str(user.id),
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"is_verified": user.is_verified,
|
||||
"is_active": user.is_active,
|
||||
"role": user.role,
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
new_access_token = SecurityManager.create_access_token(user_data=access_token_data)
|
||||
|
||||
return {
|
||||
"access_token": new_access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": 1800
|
||||
}
|
||||
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
token_repo = TokenRepository(session)
|
||||
|
||||
# Validate refresh token using repository
|
||||
is_valid = await token_repo.validate_refresh_token(refresh_token, user_id)
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token"
|
||||
)
|
||||
|
||||
# Get user using repository
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or inactive"
|
||||
)
|
||||
|
||||
# Create new access token
|
||||
access_token_data = {
|
||||
"user_id": str(user.id),
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"is_verified": user.is_verified,
|
||||
"is_active": user.is_active,
|
||||
"role": user.role,
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
new_access_token = SecurityManager.create_access_token(user_data=access_token_data)
|
||||
|
||||
logger.debug("Access token refreshed successfully using repository pattern",
|
||||
user_id=user_id)
|
||||
|
||||
return {
|
||||
"access_token": new_access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": 1800
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh failed: {e}")
|
||||
logger.error("Token refresh failed using repository pattern", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token refresh failed"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def verify_user_token(token: str) -> Dict[str, Any]:
|
||||
"""Verify access token and return user info (UNCHANGED)"""
|
||||
|
||||
async def verify_user_token(self, token: str) -> Dict[str, Any]:
|
||||
"""Verify access token and return user info"""
|
||||
try:
|
||||
payload = SecurityManager.verify_token(token)
|
||||
if not payload:
|
||||
@@ -387,8 +354,173 @@ class AuthService:
|
||||
return payload
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification error: {e}")
|
||||
logger.error("Token verification error using repository pattern", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token"
|
||||
)
|
||||
)
|
||||
|
||||
async def get_user_profile(self, user_id: str) -> Optional[UserResponse]:
|
||||
"""Get user profile using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
return UserResponse(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
created_at=user.created_at,
|
||||
role=user.role
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user profile using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def update_user_profile(
|
||||
self,
|
||||
user_id: str,
|
||||
update_data: Dict[str, Any]
|
||||
) -> Optional[UserResponse]:
|
||||
"""Update user profile using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
updated_user = await user_repo.update(user_id, update_data)
|
||||
if not updated_user:
|
||||
return None
|
||||
|
||||
logger.info("User profile updated using repository pattern",
|
||||
user_id=user_id,
|
||||
updated_fields=list(update_data.keys()))
|
||||
|
||||
return UserResponse(
|
||||
id=str(updated_user.id),
|
||||
email=updated_user.email,
|
||||
full_name=updated_user.full_name,
|
||||
is_active=updated_user.is_active,
|
||||
is_verified=updated_user.is_verified,
|
||||
created_at=updated_user.created_at,
|
||||
role=updated_user.role
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update user profile using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update profile: {str(e)}")
|
||||
|
||||
async def change_password(
|
||||
self,
|
||||
user_id: str,
|
||||
old_password: str,
|
||||
new_password: str
|
||||
) -> bool:
|
||||
"""Change user password using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
token_repo = TokenRepository(session)
|
||||
|
||||
# Get user and verify old password
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if not SecurityManager.verify_password(old_password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid old password"
|
||||
)
|
||||
|
||||
# Hash new password and update
|
||||
new_hashed_password = SecurityManager.hash_password(new_password)
|
||||
await user_repo.update(user_id, {"hashed_password": new_hashed_password})
|
||||
|
||||
# Revoke all existing tokens for security
|
||||
await token_repo.revoke_all_user_tokens(user_id)
|
||||
|
||||
logger.info("Password changed successfully using repository pattern",
|
||||
user_id=user_id)
|
||||
|
||||
return True
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to change password using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to change password: {str(e)}")
|
||||
|
||||
async def verify_user_email(self, user_id: str, verification_token: str) -> bool:
|
||||
"""Verify user email using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
# In a real implementation, you'd verify the verification_token
|
||||
# For now, just mark user as verified
|
||||
updated_user = await user_repo.update(user_id, {"is_verified": True})
|
||||
|
||||
if updated_user:
|
||||
logger.info("User email verified using repository pattern",
|
||||
user_id=user_id)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to verify email using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return False
|
||||
|
||||
async def deactivate_user(self, user_id: str, admin_user_id: str) -> bool:
|
||||
"""Deactivate user account using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
token_repo = TokenRepository(session)
|
||||
|
||||
# Update user status
|
||||
updated_user = await user_repo.update(user_id, {"is_active": False})
|
||||
if not updated_user:
|
||||
return False
|
||||
|
||||
# Revoke all tokens
|
||||
await token_repo.revoke_all_user_tokens(user_id)
|
||||
|
||||
logger.info("User deactivated using repository pattern",
|
||||
user_id=user_id,
|
||||
admin_user_id=admin_user_id)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to deactivate user using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return False
|
||||
|
||||
|
||||
# Legacy compatibility - alias EnhancedAuthService as AuthService
|
||||
AuthService = EnhancedAuthService
|
||||
|
||||
|
||||
class EnhancedUserService(EnhancedAuthService):
|
||||
"""User service alias for backward compatibility"""
|
||||
pass
|
||||
@@ -36,3 +36,11 @@ async def publish_user_login(user_data: dict) -> bool:
|
||||
async def publish_user_logout(user_data: dict) -> bool:
|
||||
"""Publish user logout event"""
|
||||
return await auth_publisher.publish_user_event("logout", user_data)
|
||||
|
||||
async def publish_user_updated(user_data: dict) -> bool:
|
||||
"""Publish user updated event"""
|
||||
return await auth_publisher.publish_user_event("updated", user_data)
|
||||
|
||||
async def publish_user_deactivated(user_data: dict) -> bool:
|
||||
"""Publish user deactivated event"""
|
||||
return await auth_publisher.publish_user_event("deactivated", user_data)
|
||||
|
||||
@@ -1,153 +1,484 @@
|
||||
"""
|
||||
User service for managing user operations
|
||||
Enhanced User Service
|
||||
Updated to use repository pattern with dependency injection and improved error handling
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, delete
|
||||
from fastapi import HTTPException, status
|
||||
from passlib.context import CryptContext
|
||||
import structlog
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import HTTPException, status
|
||||
import structlog
|
||||
|
||||
from app.models.users import User
|
||||
from app.core.config import settings
|
||||
from app.repositories import UserRepository, TokenRepository
|
||||
from app.schemas.auth import UserResponse, UserUpdate
|
||||
from app.core.security import SecurityManager
|
||||
from shared.database.unit_of_work import UnitOfWork
|
||||
from shared.database.transactions import transactional
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Password hashing
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
class UserService:
|
||||
"""Service for user management operations"""
|
||||
class EnhancedUserService:
|
||||
"""Enhanced user management service using repository pattern"""
|
||||
|
||||
@staticmethod
|
||||
async def get_user_by_id(user_id: str, db: AsyncSession) -> User:
|
||||
"""Get user by ID"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user by ID {user_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get user"
|
||||
)
|
||||
def __init__(self, database_manager):
|
||||
"""Initialize service with database manager"""
|
||||
self.database_manager = database_manager
|
||||
|
||||
@staticmethod
|
||||
async def update_user(user_id: str, user_data: dict, db: AsyncSession) -> User:
|
||||
"""Update user information"""
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[UserResponse]:
|
||||
"""Get user by ID using repository pattern"""
|
||||
try:
|
||||
# Get current user
|
||||
user = await UserService.get_user_by_id(user_id, db)
|
||||
|
||||
# Update fields
|
||||
update_data = {}
|
||||
allowed_fields = ['full_name', 'phone', 'language', 'timezone']
|
||||
|
||||
for field in allowed_fields:
|
||||
if field in user_data:
|
||||
update_data[field] = user_data[field]
|
||||
|
||||
if update_data:
|
||||
update_data["updated_at"] = datetime.now(timezone.utc)
|
||||
await db.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id)
|
||||
.values(**update_data)
|
||||
)
|
||||
await db.commit()
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
# Refresh user object
|
||||
await db.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating user {user_id}: {e}")
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update user"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def change_password(
|
||||
user_id: str,
|
||||
current_password: str,
|
||||
new_password: str,
|
||||
db: AsyncSession
|
||||
):
|
||||
"""Change user password"""
|
||||
try:
|
||||
# Get current user
|
||||
user = await UserService.get_user_by_id(user_id, db)
|
||||
|
||||
# Verify current password
|
||||
if not pwd_context.verify(current_password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is incorrect"
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
return UserResponse(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
created_at=user.created_at,
|
||||
role=user.role,
|
||||
phone=getattr(user, 'phone', None),
|
||||
language=getattr(user, 'language', None),
|
||||
timezone=getattr(user, 'timezone', None)
|
||||
)
|
||||
|
||||
# Hash new password
|
||||
new_hashed_password = pwd_context.hash(new_password)
|
||||
|
||||
# Update password
|
||||
await db.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id)
|
||||
.values(hashed_password=new_hashed_password, updated_at=datetime.now(timezone.utc))
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"Password changed for user {user_id}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error changing password for user {user_id}: {e}")
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to change password"
|
||||
)
|
||||
logger.error("Failed to get user by ID using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get user: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
async def delete_user(user_id: str, db: AsyncSession):
|
||||
"""Delete user account"""
|
||||
async def get_user_by_email(self, email: str) -> Optional[UserResponse]:
|
||||
"""Get user by email using repository pattern"""
|
||||
try:
|
||||
# Get current user first
|
||||
user = await UserService.get_user_by_id(user_id, db)
|
||||
|
||||
# Soft delete by deactivating
|
||||
await db.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id)
|
||||
.values(is_active=False)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"User {user_id} deactivated (soft delete)")
|
||||
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
user = await user_repo.get_by_email(email)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
return UserResponse(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
created_at=user.created_at,
|
||||
role=user.role,
|
||||
phone=getattr(user, 'phone', None),
|
||||
language=getattr(user, 'language', None),
|
||||
timezone=getattr(user, 'timezone', None)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user by email using repository pattern",
|
||||
email=email,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get user: {str(e)}")
|
||||
|
||||
async def get_users_list(
|
||||
self,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
active_only: bool = True,
|
||||
role: str = None
|
||||
) -> List[UserResponse]:
|
||||
"""Get paginated list of users using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
filters = {}
|
||||
if active_only:
|
||||
filters["is_active"] = True
|
||||
if role:
|
||||
filters["role"] = role
|
||||
|
||||
users = await user_repo.get_multi(
|
||||
filters=filters,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
|
||||
return [
|
||||
UserResponse(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
created_at=user.created_at,
|
||||
role=user.role,
|
||||
phone=getattr(user, 'phone', None),
|
||||
language=getattr(user, 'language', None),
|
||||
timezone=getattr(user, 'timezone', None)
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get users list using repository pattern", error=str(e))
|
||||
return []
|
||||
|
||||
@transactional
|
||||
async def update_user(
|
||||
self,
|
||||
user_id: str,
|
||||
user_data: UserUpdate,
|
||||
session=None
|
||||
) -> Optional[UserResponse]:
|
||||
"""Update user information using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
user_repo = UserRepository(db_session)
|
||||
|
||||
# Validate user exists
|
||||
existing_user = await user_repo.get_by_id(user_id)
|
||||
if not existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
# Prepare update data
|
||||
update_data = {}
|
||||
if user_data.full_name is not None:
|
||||
update_data["full_name"] = user_data.full_name
|
||||
if user_data.phone is not None:
|
||||
update_data["phone"] = user_data.phone
|
||||
if user_data.language is not None:
|
||||
update_data["language"] = user_data.language
|
||||
if user_data.timezone is not None:
|
||||
update_data["timezone"] = user_data.timezone
|
||||
|
||||
if not update_data:
|
||||
# No updates to apply
|
||||
return UserResponse(
|
||||
id=str(existing_user.id),
|
||||
email=existing_user.email,
|
||||
full_name=existing_user.full_name,
|
||||
is_active=existing_user.is_active,
|
||||
is_verified=existing_user.is_verified,
|
||||
created_at=existing_user.created_at,
|
||||
role=existing_user.role
|
||||
)
|
||||
|
||||
# Update user using repository
|
||||
updated_user = await user_repo.update(user_id, update_data)
|
||||
if not updated_user:
|
||||
raise DatabaseError("Failed to update user")
|
||||
|
||||
logger.info("User updated successfully using repository pattern",
|
||||
user_id=user_id,
|
||||
updated_fields=list(update_data.keys()))
|
||||
|
||||
return UserResponse(
|
||||
id=str(updated_user.id),
|
||||
email=updated_user.email,
|
||||
full_name=updated_user.full_name,
|
||||
is_active=updated_user.is_active,
|
||||
is_verified=updated_user.is_verified,
|
||||
created_at=updated_user.created_at,
|
||||
role=updated_user.role,
|
||||
phone=getattr(updated_user, 'phone', None),
|
||||
language=getattr(updated_user, 'language', None),
|
||||
timezone=getattr(updated_user, 'timezone', None)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting user {user_id}: {e}")
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to delete user"
|
||||
)
|
||||
logger.error("Failed to update user using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update user: {str(e)}")
|
||||
|
||||
@transactional
|
||||
async def change_password(
|
||||
self,
|
||||
user_id: str,
|
||||
current_password: str,
|
||||
new_password: str,
|
||||
session=None
|
||||
) -> bool:
|
||||
"""Change user password using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
async with UnitOfWork(db_session) as uow:
|
||||
# Register repositories
|
||||
user_repo = uow.register_repository("users", UserRepository)
|
||||
token_repo = uow.register_repository("tokens", TokenRepository)
|
||||
|
||||
# Get user and verify current password
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if not SecurityManager.verify_password(current_password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is incorrect"
|
||||
)
|
||||
|
||||
# Hash new password and update
|
||||
new_hashed_password = SecurityManager.hash_password(new_password)
|
||||
await user_repo.update(user_id, {"hashed_password": new_hashed_password})
|
||||
|
||||
# Revoke all existing tokens for security
|
||||
await token_repo.revoke_user_tokens(user_id)
|
||||
|
||||
# Commit transaction
|
||||
await uow.commit()
|
||||
|
||||
logger.info("Password changed successfully using repository pattern",
|
||||
user_id=user_id)
|
||||
|
||||
return True
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to change password using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to change password: {str(e)}")
|
||||
|
||||
@transactional
|
||||
async def deactivate_user(self, user_id: str, admin_user_id: str, session=None) -> bool:
|
||||
"""Deactivate user account using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
async with UnitOfWork(db_session) as uow:
|
||||
# Register repositories
|
||||
user_repo = uow.register_repository("users", UserRepository)
|
||||
token_repo = uow.register_repository("tokens", TokenRepository)
|
||||
|
||||
# Verify user exists
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
# Update user status (soft delete)
|
||||
updated_user = await user_repo.update(user_id, {"is_active": False})
|
||||
if not updated_user:
|
||||
return False
|
||||
|
||||
# Revoke all tokens
|
||||
await token_repo.revoke_user_tokens(user_id)
|
||||
|
||||
# Commit transaction
|
||||
await uow.commit()
|
||||
|
||||
logger.info("User deactivated successfully using repository pattern",
|
||||
user_id=user_id,
|
||||
admin_user_id=admin_user_id)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to deactivate user using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return False
|
||||
|
||||
@transactional
|
||||
async def activate_user(self, user_id: str, admin_user_id: str, session=None) -> bool:
|
||||
"""Activate user account using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
user_repo = UserRepository(db_session)
|
||||
|
||||
# Update user status
|
||||
updated_user = await user_repo.update(user_id, {"is_active": True})
|
||||
if not updated_user:
|
||||
return False
|
||||
|
||||
logger.info("User activated successfully using repository pattern",
|
||||
user_id=user_id,
|
||||
admin_user_id=admin_user_id)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to activate user using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return False
|
||||
|
||||
async def verify_user_email(self, user_id: str, verification_token: str) -> bool:
|
||||
"""Verify user email using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
# In a real implementation, you'd verify the verification_token
|
||||
# For now, just mark user as verified
|
||||
updated_user = await user_repo.update(user_id, {"is_verified": True})
|
||||
|
||||
if updated_user:
|
||||
logger.info("User email verified using repository pattern",
|
||||
user_id=user_id)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to verify email using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return False
|
||||
|
||||
async def get_user_statistics(self) -> Dict[str, Any]:
|
||||
"""Get user statistics using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
# Get basic user statistics
|
||||
statistics = await user_repo.get_user_statistics()
|
||||
|
||||
return statistics
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user statistics using repository pattern", error=str(e))
|
||||
return {
|
||||
"total_users": 0,
|
||||
"active_users": 0,
|
||||
"verified_users": 0,
|
||||
"users_by_role": {},
|
||||
"recent_registrations_7d": 0
|
||||
}
|
||||
|
||||
async def search_users(
|
||||
self,
|
||||
search_term: str,
|
||||
role: str = None,
|
||||
active_only: bool = True,
|
||||
skip: int = 0,
|
||||
limit: int = 50
|
||||
) -> List[UserResponse]:
|
||||
"""Search users by email or name using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
users = await user_repo.search_users(
|
||||
search_term, role, active_only, skip, limit
|
||||
)
|
||||
|
||||
return [
|
||||
UserResponse(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
full_name=user.full_name,
|
||||
is_active=user.is_active,
|
||||
is_verified=user.is_verified,
|
||||
created_at=user.created_at,
|
||||
role=user.role,
|
||||
phone=getattr(user, 'phone', None),
|
||||
language=getattr(user, 'language', None),
|
||||
timezone=getattr(user, 'timezone', None)
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to search users using repository pattern",
|
||||
search_term=search_term,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def update_user_role(
|
||||
self,
|
||||
user_id: str,
|
||||
new_role: str,
|
||||
admin_user_id: str
|
||||
) -> Optional[UserResponse]:
|
||||
"""Update user role using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
# Validate role
|
||||
valid_roles = ["user", "admin", "super_admin"]
|
||||
if new_role not in valid_roles:
|
||||
raise ValidationError(f"Invalid role. Must be one of: {valid_roles}")
|
||||
|
||||
# Update user role
|
||||
updated_user = await user_repo.update(user_id, {"role": new_role})
|
||||
if not updated_user:
|
||||
return None
|
||||
|
||||
logger.info("User role updated using repository pattern",
|
||||
user_id=user_id,
|
||||
new_role=new_role,
|
||||
admin_user_id=admin_user_id)
|
||||
|
||||
return UserResponse(
|
||||
id=str(updated_user.id),
|
||||
email=updated_user.email,
|
||||
full_name=updated_user.full_name,
|
||||
is_active=updated_user.is_active,
|
||||
is_verified=updated_user.is_verified,
|
||||
created_at=updated_user.created_at,
|
||||
role=updated_user.role,
|
||||
phone=getattr(updated_user, 'phone', None),
|
||||
language=getattr(updated_user, 'language', None),
|
||||
timezone=getattr(updated_user, 'timezone', None)
|
||||
)
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update user role using repository pattern",
|
||||
user_id=user_id,
|
||||
new_role=new_role,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update role: {str(e)}")
|
||||
|
||||
async def get_user_activity(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get user activity information using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
user_repo = UserRepository(session)
|
||||
token_repo = TokenRepository(session)
|
||||
|
||||
# Get user
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
return {"error": "User not found"}
|
||||
|
||||
# Get token activity
|
||||
active_tokens = await token_repo.get_user_active_tokens(user_id)
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"last_login": user.last_login.isoformat() if user.last_login else None,
|
||||
"account_created": user.created_at.isoformat(),
|
||||
"is_active": user.is_active,
|
||||
"is_verified": user.is_verified,
|
||||
"active_sessions": len(active_tokens),
|
||||
"last_activity": max([token.created_at for token in active_tokens]).isoformat() if active_tokens else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user activity using repository pattern",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
# Legacy compatibility - alias EnhancedUserService as UserService
|
||||
UserService = EnhancedUserService
|
||||
Reference in New Issue
Block a user