REFACTOR API gateway fix 8

This commit is contained in:
Urtzi Alfaro
2025-07-26 23:29:57 +02:00
parent 1291d05183
commit 97ae58fb06
8 changed files with 997 additions and 375 deletions

View File

@@ -56,14 +56,8 @@ async def register(
logger.debug(f"Input validation passed for {user_data.email}")
# ✅ DEBUG: Call auth service with enhanced error tracking
result = await AuthService.register_user_with_tokens(
email=user_data.email.strip().lower(), # Normalize email
password=user_data.password,
full_name=user_data.full_name.strip(),
db=db
)
result = await AuthService.register_user(user_data, db)
logger.info(f"Registration successful for {user_data.email}")
# Record successful registration
@@ -132,11 +126,7 @@ async def login(
)
# Attempt login through AuthService
result = await AuthService.login(
email=login_data.email.strip().lower(), # Normalize email
password=login_data.password,
db=db
)
result = await AuthService.login_user(login_data, db)
# Record successful login
if metrics:

View File

@@ -30,20 +30,6 @@ redis_client = redis.from_url(settings.REDIS_URL)
class SecurityManager:
"""Security utilities for authentication - FIXED VERSION"""
@staticmethod
def hash_password(password: str) -> str:
"""Hash password using passlib bcrypt - FIXED"""
return pwd_context.hash(password)
@staticmethod
def verify_password(password: str, hashed_password: str) -> bool:
"""Verify password against hash using passlib - FIXED"""
try:
return pwd_context.verify(password, hashed_password)
except Exception as e:
logger.error(f"Password verification error: {e}")
return False
@staticmethod
def validate_password(password: str) -> bool:
"""Validate password strength"""
@@ -65,48 +51,59 @@ class SecurityManager:
return True
@staticmethod
def create_access_token(user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
"""Create JWT access token with PROPER validation"""
def hash_password(password: str) -> str:
"""Hash password using passlib bcrypt - FIXED"""
return pwd_context.hash(password)
@staticmethod
def verify_password(password: str, hashed_password: str) -> bool:
"""Verify password against hash using passlib - FIXED"""
try:
return pwd_context.verify(password, hashed_password)
except Exception as e:
logger.error(f"Password verification error: {e}")
return False
@staticmethod
def create_access_token(user_data: Dict[str, Any]) -> str:
"""
Create JWT ACCESS token with proper payload structure
✅ FIXED: Only creates access tokens
"""
# ✅ FIX 1: Validate required fields BEFORE token creation
required_fields = ["user_id", "email"]
missing_fields = [field for field in required_fields if field not in user_data]
# Validate required fields for access token
if "user_id" not in user_data:
raise ValueError("user_id required for access token creation")
if missing_fields:
error_msg = f"Missing required fields for token creation: {missing_fields}"
logger.error(f"Token creation failed: {error_msg}")
raise ValueError(error_msg)
# ✅ FIX 2: Validate that required fields are not None/empty
if not user_data.get("user_id"):
raise ValueError("user_id cannot be empty")
if not user_data.get("email"):
raise ValueError("email cannot be empty")
if "email" not in user_data:
raise ValueError("email required for access token creation")
try:
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
# ✅ FIX 3: Build payload with SAFE access to user_data
# ✅ FIX 1: ACCESS TOKEN payload structure
payload = {
"sub": user_data["user_id"],
"user_id": user_data["user_id"],
"email": user_data["email"], # ✅ Guaranteed to exist now
"type": "access",
"full_name": user_data.get("full_name", ""), # Safe access with default
"is_verified": user_data.get("is_verified", False), # Safe access with default
"is_active": user_data.get("is_active", True), # Safe access with default
"email": user_data["email"],
"type": "access", # ✅ EXPLICITLY set as access token
"exp": expire,
"iat": datetime.now(timezone.utc),
"iss": "bakery-auth" # Token issuer
"iss": "bakery-auth"
}
# Add optional fields for access tokens
if "full_name" in user_data:
payload["full_name"] = user_data["full_name"]
if "is_verified" in user_data:
payload["is_verified"] = user_data["is_verified"]
if "is_active" in user_data:
payload["is_active"] = user_data["is_active"]
logger.debug(f"Creating access token with payload keys: {list(payload.keys())}")
# ✅ FIX 4: Use jwt_handler with proper error handling
token = jwt_handler.create_access_token(payload)
# ✅ FIX 2: Use JWT handler to create access token
token = jwt_handler.create_access_token_from_payload(payload)
logger.debug(f"Access token created successfully for user {user_data['email']}")
return token
@@ -116,13 +113,14 @@ class SecurityManager:
@staticmethod
def create_refresh_token(user_data: Dict[str, Any]) -> str:
"""Create JWT refresh token with FLEXIBLE validation"""
"""
Create JWT REFRESH token with minimal payload structure
✅ FIXED: Only creates refresh tokens, different from access tokens
"""
# ✅ FIX 1: Validate only essential fields for refresh token
# Validate required fields for refresh token
if "user_id" not in user_data:
error_msg = "user_id required for refresh token creation"
logger.error(f"Refresh token creation failed: {error_msg}")
raise ValueError(error_msg)
raise ValueError("user_id required for refresh token creation")
if not user_data.get("user_id"):
raise ValueError("user_id cannot be empty")
@@ -130,24 +128,31 @@ class SecurityManager:
try:
expire = datetime.now(timezone.utc) + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
# ✅ FIX 2: Minimal payload for refresh token (email is optional)
# ✅ FIX 3: REFRESH TOKEN payload structure (minimal, different from access)
payload = {
"sub": user_data["user_id"],
"user_id": user_data["user_id"],
"type": "refresh",
"type": "refresh", # ✅ EXPLICITLY set as refresh token
"exp": expire,
"iat": datetime.now(timezone.utc),
"iss": "bakery-auth"
}
# ✅ FIX 3: Include email only if available (no longer required)
# Add unique JTI for refresh tokens to prevent duplicates
if "jti" in user_data:
payload["jti"] = user_data["jti"]
else:
import uuid
payload["jti"] = str(uuid.uuid4())
# Include email only if available (optional for refresh tokens)
if "email" in user_data and user_data["email"]:
payload["email"] = user_data["email"]
logger.debug(f"Creating refresh token with payload keys: {list(payload.keys())}")
# Use the same JWT handler method (it handles both access and refresh)
token = jwt_handler.create_access_token(payload)
# ✅ FIX 4: Use JWT handler to create REFRESH token (not access token!)
token = jwt_handler.create_refresh_token_from_payload(payload)
logger.debug(f"Refresh token created successfully for user {user_data['user_id']}")
return token
@@ -167,6 +172,55 @@ class SecurityManager:
logger.warning(f"Token verification failed: {e}")
return None
@staticmethod
def decode_token(token: str) -> Dict[str, Any]:
"""Decode JWT token without verification (for refresh token handling)"""
try:
payload = jwt_handler.decode_token_no_verify(token)
return payload
except Exception as e:
logger.error(f"Token decoding failed: {e}")
raise ValueError("Invalid token format")
@staticmethod
def generate_secure_hash(data: str) -> str:
"""Generate secure hash for token storage"""
return hashlib.sha256(data.encode()).hexdigest()
@staticmethod
async def track_login_attempt(email: str, ip_address: str, success: bool) -> None:
"""Track login attempts for security monitoring"""
try:
# This would use Redis for production
# For now, just log the attempt
logger.info(f"Login attempt tracked: email={email}, ip={ip_address}, success={success}")
except Exception as e:
logger.warning(f"Failed to track login attempt: {e}")
@staticmethod
def is_token_expired(token: str) -> bool:
"""Check if token is expired"""
try:
payload = SecurityManager.decode_token(token)
exp_timestamp = payload.get("exp")
if exp_timestamp:
exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
return datetime.now(timezone.utc) > exp_datetime
return True
except Exception:
return True
@staticmethod
def verify_token(token: str) -> Optional[Dict[str, Any]]:
"""Verify JWT token with enhanced error handling"""
try:
payload = jwt_handler.verify_token(token)
if payload:
logger.debug(f"Token verified successfully for user: {payload.get('email', 'unknown')}")
return payload
except Exception as e:
logger.warning(f"Token verification failed: {e}")
return None
@staticmethod
async def track_login_attempt(email: str, ip_address: str, success: bool) -> None:
"""Track login attempts for security monitoring"""

View File

@@ -5,34 +5,75 @@
Token models for authentication service
"""
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from datetime import datetime
import hashlib
import uuid
from datetime import datetime, timezone
from sqlalchemy import Column, String, Boolean, DateTime, Text, Index
from sqlalchemy.dialects.postgresql import UUID
from shared.database.base import Base
class RefreshToken(Base):
"""Refresh token model"""
"""
Refresh token model - FIXED to prevent duplicate constraint violations
"""
__tablename__ = "refresh_tokens"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
token_hash = Column(String(255), nullable=False, unique=True)
is_active = Column(Boolean, default=True)
expires_at = Column(DateTime, nullable=False)
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# Session metadata
ip_address = Column(String(45))
user_agent = Column(Text)
device_info = Column(Text)
# ✅ FIX 1: Use TEXT instead of VARCHAR to handle longer tokens
token = Column(Text, nullable=False)
created_at = Column(DateTime, default=datetime.utcnow)
revoked_at = Column(DateTime)
# ✅ FIX 2: Add token hash for uniqueness instead of full token
token_hash = Column(String(255), nullable=True, unique=True)
expires_at = Column(DateTime(timezone=True), nullable=False)
is_revoked = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
revoked_at = Column(DateTime(timezone=True), nullable=True)
# ✅ FIX 3: Add indexes for better performance
__table_args__ = (
Index('ix_refresh_tokens_user_id_active', 'user_id', 'is_revoked'),
Index('ix_refresh_tokens_expires_at', 'expires_at'),
Index('ix_refresh_tokens_token_hash', 'token_hash'),
)
def __init__(self, **kwargs):
"""Initialize refresh token with automatic hash generation"""
super().__init__(**kwargs)
if self.token and not self.token_hash:
self.token_hash = self._generate_token_hash(self.token)
@staticmethod
def _generate_token_hash(token: str) -> str:
"""Generate a hash of the token for uniqueness checking"""
return hashlib.sha256(token.encode()).hexdigest()
def update_token(self, new_token: str):
"""Update token and regenerate hash"""
self.token = new_token
self.token_hash = self._generate_token_hash(new_token)
@classmethod
async def create_refresh_token(cls, user_id: uuid.UUID, token: str, expires_at: datetime):
"""
Create a new refresh token with proper hash generation
"""
return cls(
id=uuid.uuid4(),
user_id=user_id,
token=token,
token_hash=cls._generate_token_hash(token),
expires_at=expires_at,
is_revoked=False,
created_at=datetime.now(timezone.utc)
)
def __repr__(self):
return f"<RefreshToken(id={self.id}, user_id={self.user_id})>"
return f"<RefreshToken(id={self.id}, user_id={self.user_id}, expires_at={self.expires_at})>"
class LoginAttempt(Base):
"""Login attempt tracking model"""
@@ -45,7 +86,7 @@ class LoginAttempt(Base):
success = Column(Boolean, default=False)
failure_reason = Column(String(255))
created_at = Column(DateTime, default=datetime.utcnow)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
def __repr__(self):
return f"<LoginAttempt(id={self.id}, email={self.email}, success={self.success})>"

View File

@@ -3,16 +3,21 @@
Authentication Service - Updated to support registration with direct token issuance
"""
from datetime import datetime, timezone, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from fastapi import HTTPException, status
import hashlib
import uuid
from datetime import datetime, timedelta, timezone
from typing import Dict, Any, Optional
from fastapi import HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from sqlalchemy.exc import IntegrityError
import structlog
from app.models.users import User, RefreshToken
from app.schemas.auth import UserRegistration, UserLogin
from app.core.security import SecurityManager
from app.services.messaging import publish_user_registered, publish_user_login, publish_user_logout
from app.services.messaging import publish_user_registered, publish_user_login
logger = structlog.get_logger()
@@ -20,91 +25,71 @@ class AuthService:
"""Enhanced Authentication service with unified token response"""
@staticmethod
async def register_user_with_tokens(
email: str,
password: str,
full_name: str,
db: AsyncSession
) -> Dict[str, Any]:
"""Register new user and return tokens directly - COMPLETELY FIXED"""
async def register_user(user_data: UserRegistration, db: AsyncSession) -> Dict[str, Any]:
"""Register a new user with FIXED token generation"""
try:
# Check if user already exists
result = await db.execute(select(User).where(User.email == email))
existing_user = result.scalar_one_or_none()
if existing_user:
existing_user = await db.execute(
select(User).where(User.email == user_data.email)
)
if existing_user.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
status_code=status.HTTP_400_BAD_REQUEST,
detail="User with this email already exists"
)
# Create new user
hashed_password = SecurityManager.hash_password(password)
hashed_password = SecurityManager.hash_password(user_data.password)
new_user = User(
email=email,
id=uuid.uuid4(),
email=user_data.email,
full_name=user_data.full_name,
hashed_password=hashed_password,
full_name=full_name,
is_active=True,
is_verified=False,
created_at=datetime.now(timezone.utc),
language='es', # Default language from logs
timezone='Europe/Madrid' # Default timezone from logs
updated_at=datetime.now(timezone.utc)
)
db.add(new_user)
await db.flush() # Get user ID without committing
# ✅ FIX 1: Create COMPLETE and CONSISTENT user_data for token generation
token_user_data = {
# ✅ FIX 1: Create SEPARATE access and refresh tokens with different payloads
access_token_data = {
"user_id": str(new_user.id),
"email": new_user.email, # ✅ Ensure email is included
"email": new_user.email,
"full_name": new_user.full_name,
"is_verified": new_user.is_verified,
"is_active": new_user.is_active
"is_active": new_user.is_active,
"type": "access" # ✅ Explicitly mark as access token
}
logger.debug(f"Creating tokens for user: {email} with data: {token_user_data}")
# ✅ FIX 2: Generate tokens with VALIDATED user data
try:
access_token = SecurityManager.create_access_token(user_data=token_user_data)
logger.debug(f"Access token created successfully for {email}")
except Exception as token_error:
logger.error(f"Access token creation failed for {email}: {token_error}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Token creation failed: {token_error}"
)
# ✅ FIX 3: Create refresh token with minimal but complete data
refresh_token_data = {
"user_id": str(new_user.id),
"email": new_user.email # Include email for consistency
"email": new_user.email,
"type": "refresh" # ✅ Explicitly mark as refresh token
}
try:
refresh_token_value = SecurityManager.create_refresh_token(user_data=refresh_token_data)
logger.debug(f"Refresh token created successfully for {email}")
except Exception as refresh_error:
logger.error(f"Refresh token creation failed for {email}: {refresh_error}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Refresh token creation failed: {refresh_error}"
)
logger.debug(f"Creating tokens for registration: {user_data.email}")
# Store refresh token in database
# ✅ FIX 2: Generate tokens with different payloads
access_token = SecurityManager.create_access_token(user_data=access_token_data)
refresh_token_value = SecurityManager.create_refresh_token(user_data=refresh_token_data)
logger.debug(f"Tokens created successfully for {user_data.email}")
# ✅ FIX 3: Store ONLY the refresh token in database (not access token)
refresh_token = RefreshToken(
id=uuid.uuid4(),
user_id=new_user.id,
token=refresh_token_value,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
is_revoked=False
token=refresh_token_value, # Store the actual refresh token
expires_at=datetime.now(timezone.utc) + timedelta(days=30),
is_revoked=False,
created_at=datetime.now(timezone.utc)
)
db.add(refresh_token)
# ✅ FIX 4: Only commit after ALL token creation succeeds
await db.commit()
await db.refresh(new_user)
# Publish registration event (non-blocking)
try:
@@ -112,14 +97,13 @@ class AuthService:
"user_id": str(new_user.id),
"email": new_user.email,
"full_name": new_user.full_name,
"registered_at": new_user.created_at.isoformat()
"registered_at": datetime.now(timezone.utc).isoformat()
})
except Exception as e:
logger.warning(f"Failed to publish registration event: {e}")
logger.info(f"User registered successfully with tokens: {email}")
logger.info(f"User registered successfully: {user_data.email}")
# Return unified token response format
return {
"access_token": access_token,
"refresh_token": refresh_token_value,
@@ -138,111 +122,101 @@ class AuthService:
except HTTPException:
await db.rollback()
raise
except Exception as e:
except IntegrityError as e:
await db.rollback()
logger.error(f"Registration with tokens failed for {email}: {e}")
# ✅ FIX 5: Provide more specific error information
logger.error(f"Registration failed for {user_data.email}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Registration failed: {str(e)}"
detail="Registration failed"
)
except Exception as e:
await db.rollback()
logger.error(f"Registration failed for {user_data.email}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Registration failed"
)
@staticmethod
async def create_user(
email: str,
password: str,
full_name: str,
db: AsyncSession
) -> User:
"""
Create user without tokens (LEGACY METHOD - kept for compatibility)
Use register_user_with_tokens() for new implementations
"""
async def login_user(login_data: UserLogin, db: AsyncSession) -> Dict[str, Any]:
"""Login user with FIXED token generation and SQLAlchemy syntax"""
try:
# Check if user already exists
result = await db.execute(select(User).where(User.email == email))
existing_user = result.scalar_one_or_none()
if existing_user:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="User with this email already exists"
)
# Create new user
hashed_password = SecurityManager.hash_password(password)
new_user = User(
email=email,
hashed_password=hashed_password,
full_name=full_name,
is_active=True,
is_verified=False,
created_at=datetime.now(timezone.utc)
# Find user
result = await db.execute(
select(User).where(User.email == login_data.email)
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
logger.info(f"User created (legacy): {email}")
return new_user
except HTTPException:
await db.rollback()
raise
except Exception as e:
await db.rollback()
logger.error(f"User creation failed for {email}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="User creation failed"
)
@staticmethod
async def login(email: str, password: str, db: AsyncSession) -> Dict[str, Any]:
"""Login user and return tokens - FIXED VERSION"""
try:
# Get user
result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
if not user or not SecurityManager.verify_password(password, user.hashed_password):
if not user or not SecurityManager.verify_password(login_data.password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials"
detail="Invalid email or password"
)
# ✅ FIX 1: Create COMPLETE user data for access token
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is deactivated"
)
# ✅ FIX 4: Revoke existing refresh tokens using proper SQLAlchemy ORM syntax
logger.debug(f"Revoking existing refresh tokens for user: {user.id}")
# Using SQLAlchemy ORM update (more reliable than raw SQL)
stmt = update(RefreshToken).where(
RefreshToken.user_id == user.id,
RefreshToken.is_revoked == False
).values(
is_revoked=True,
revoked_at=datetime.now(timezone.utc)
)
result = await db.execute(stmt)
revoked_count = result.rowcount
logger.debug(f"Revoked {revoked_count} existing refresh tokens for user: {user.id}")
# ✅ FIX 5: Create DIFFERENT token payloads
access_token_data = {
"user_id": str(user.id),
"email": user.email, # ✅ Include email
"email": user.email,
"full_name": user.full_name,
"is_verified": user.is_verified,
"is_active": user.is_active
"is_active": user.is_active,
"type": "access" # ✅ Explicitly mark as access token
}
# ✅ FIX 2: Create COMPLETE user data for refresh token
refresh_token_data = {
"user_id": str(user.id),
"email": user.email # ✅ Include email for consistency
"email": user.email,
"type": "refresh", # ✅ Explicitly mark as refresh token
"jti": str(uuid.uuid4()) # ✅ Add unique identifier for each refresh token
}
logger.debug(f"Creating access token for login with data: {list(access_token_data.keys())}")
logger.debug(f"Creating refresh token for login with data: {list(refresh_token_data.keys())}")
# Create tokens with complete data
# ✅ FIX 6: Generate tokens with different payloads and expiration
access_token = SecurityManager.create_access_token(user_data=access_token_data)
refresh_token_value = SecurityManager.create_refresh_token(user_data=refresh_token_data)
# Store refresh token in database
logger.debug(f"Access token created successfully for user {login_data.email}")
logger.debug(f"Refresh token created successfully for user {str(user.id)}")
# ✅ FIX 7: Store ONLY refresh token in database with unique constraint handling
refresh_token = RefreshToken(
id=uuid.uuid4(),
user_id=user.id,
token=refresh_token_value,
token=refresh_token_value, # This should be the refresh token, not access token
expires_at=datetime.now(timezone.utc) + timedelta(days=30),
is_revoked=False
is_revoked=False,
created_at=datetime.now(timezone.utc)
)
db.add(refresh_token)
# Update last login
user.last_login = datetime.now(timezone.utc)
await db.commit()
# Publish login event (non-blocking)
@@ -255,13 +229,13 @@ class AuthService:
except Exception as e:
logger.warning(f"Failed to publish login event: {e}")
logger.info(f"User logged in successfully: {email}")
logger.info(f"User logged in successfully: {login_data.email}")
return {
"access_token": access_token,
"refresh_token": refresh_token_value,
"token_type": "bearer",
"expires_in": 3600, # 1 hour
"expires_in": 1800, # 30 minutes
"user": {
"id": str(user.id),
"email": user.email,
@@ -275,108 +249,119 @@ class AuthService:
except HTTPException:
await db.rollback()
raise
except IntegrityError as e:
await db.rollback()
logger.error(f"Login failed for {login_data.email}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Login failed"
)
except Exception as e:
await db.rollback()
logger.error(f"Login failed for {email}: {e}")
logger.error(f"Login failed for {login_data.email}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Login failed"
)
@staticmethod
async def logout_user(user_id: str, refresh_token: str, db: AsyncSession) -> bool:
"""Logout user by revoking refresh token"""
try:
# Revoke the specific refresh token using ORM
stmt = update(RefreshToken).where(
RefreshToken.user_id == user_id,
RefreshToken.token == refresh_token,
RefreshToken.is_revoked == False
).values(
is_revoked=True,
revoked_at=datetime.now(timezone.utc)
)
result = await db.execute(stmt)
if result.rowcount > 0:
await db.commit()
logger.info(f"User logged out successfully: {user_id}")
return True
return False
except Exception as e:
await db.rollback()
logger.error(f"Logout failed for user {user_id}: {e}")
return False
@staticmethod
async def refresh_access_token(refresh_token: str, db: AsyncSession) -> Dict[str, Any]:
"""Refresh access token using refresh token (UNCHANGED)"""
"""Refresh access token using refresh token"""
try:
# Verify refresh token
payload = SecurityManager.verify_token(refresh_token)
if not payload:
payload = SecurityManager.decode_token(refresh_token)
user_id = payload.get("user_id")
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token"
)
user_id = payload.get("user_id")
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token payload"
)
# Check if refresh token exists and is not revoked
# Check if refresh token exists and is valid using ORM
result = await db.execute(
select(RefreshToken).where(
RefreshToken.user_id == user_id,
RefreshToken.token == refresh_token,
RefreshToken.is_revoked == False,
RefreshToken.expires_at > datetime.now(timezone.utc)
)
)
stored_token = result.scalar_one_or_none()
if not stored_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token"
)
# Get user info
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
# Get user
user_result = await db.execute(
select(User).where(User.id == user_id)
)
user = user_result.scalar_one_or_none()
if not user:
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found"
detail="User not found or inactive"
)
# Create new access token
access_token = SecurityManager.create_access_token(
user_data={
"user_id": str(user.id),
"email": user.email,
"full_name": user.full_name,
"is_verified": user.is_verified
}
)
access_token_data = {
"user_id": str(user.id),
"email": user.email,
"full_name": user.full_name,
"is_verified": user.is_verified,
"is_active": user.is_active,
"type": "access"
}
logger.info(f"Token refreshed successfully for user {user_id}")
new_access_token = SecurityManager.create_access_token(user_data=access_token_data)
return {
"access_token": access_token,
"access_token": new_access_token,
"token_type": "bearer",
"expires_in": 3600
"expires_in": 1800
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Token refresh error: {e}")
logger.error(f"Token refresh failed: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token refresh failed"
)
@staticmethod
async def logout(refresh_token: str, db: AsyncSession) -> bool:
"""Logout user by revoking refresh token (UNCHANGED)"""
try:
# Revoke refresh token
result = await db.execute(
select(RefreshToken).where(RefreshToken.token == refresh_token)
)
token = result.scalar_one_or_none()
if token:
token.is_revoked = True
token.revoked_at = datetime.now(timezone.utc)
await db.commit()
return True
except Exception as e:
logger.error(f"Logout error: {e}")
await db.rollback()
return False
@staticmethod
async def verify_user_token(token: str) -> Dict[str, Any]:
"""Verify access token and return user info (UNCHANGED)"""

View File

@@ -20,7 +20,7 @@ router = APIRouter()
training_service = TrainingService()
@router.get("/", response_model=List[TrainedModelResponse])
@router.get("/tenants/{tenant_id}/", response_model=List[TrainedModelResponse])
async def get_trained_models(
tenant_id: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)

View File

@@ -45,7 +45,7 @@ def get_training_service() -> TrainingService:
"""Factory function for TrainingService dependency"""
return TrainingService()
@router.post("/jobs", response_model=TrainingJobResponse)
@router.post("/tenants/{tenant_id}/jobs", response_model=TrainingJobResponse)
async def start_training_job(
request: TrainingJobRequest,
background_tasks: BackgroundTasks,
@@ -110,7 +110,7 @@ async def start_training_job(
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
@router.get("/jobs", response_model=List[TrainingJobResponse])
@router.get("/tenants/{tenant_id}/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
status: Optional[TrainingStatus] = Query(None, description="Filter jobs by status"),
limit: int = Query(100, ge=1, le=1000),
@@ -146,7 +146,7 @@ async def get_training_jobs(
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}")
@router.get("/jobs/{job_id}", response_model=TrainingJobResponse)
@router.get("/tenants/{tenant_id}/jobs/{job_id}", response_model=TrainingJobResponse)
async def get_training_job(
job_id: str,
tenant_id: UUID = Path(..., description="Tenant ID"),
@@ -179,7 +179,7 @@ async def get_training_job(
job_id=job_id)
raise HTTPException(status_code=500, detail=f"Failed to get training job: {str(e)}")
@router.get("/jobs/{job_id}/progress", response_model=TrainingJobProgress)
@router.get("/tenants/{tenant_id}/jobs/{job_id}/progress", response_model=TrainingJobProgress)
async def get_training_progress(
job_id: str,
tenant_id: UUID = Path(..., description="Tenant ID"),
@@ -209,7 +209,7 @@ async def get_training_progress(
job_id=job_id)
raise HTTPException(status_code=500, detail=f"Failed to get training progress: {str(e)}")
@router.post("/jobs/{job_id}/cancel")
@router.post("/tenants/{tenant_id}/jobs/{job_id}/cancel")
async def cancel_training_job(
job_id: str,
tenant_id: UUID = Path(..., description="Tenant ID"),
@@ -254,7 +254,7 @@ async def cancel_training_job(
job_id=job_id)
raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}")
@router.post("/products/{product_name}", response_model=TrainingJobResponse)
@router.post("/tenants/{tenant_id}/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
product_name: str,
request: SingleProductTrainingRequest,
@@ -309,7 +309,7 @@ async def train_single_product(
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to train product: {str(e)}")
@router.post("/validate", response_model=DataValidationResponse)
@router.post("/tenants/{tenant_id}/validate", response_model=DataValidationResponse)
async def validate_training_data(
request: DataValidationRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
@@ -340,7 +340,7 @@ async def validate_training_data(
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to validate data: {str(e)}")
@router.get("/models")
@router.get("/tenants/{tenant_id}/models")
async def get_trained_models(
product_name: Optional[str] = Query(None),
tenant_id: UUID = Path(..., description="Tenant ID"),
@@ -370,7 +370,7 @@ async def get_trained_models(
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to get models: {str(e)}")
@router.delete("/models/{model_id}")
@router.delete("/tenants/{tenant_id}/models/{model_id}")
@require_role("admin") # Only admins can delete models
async def delete_model(
model_id: str,
@@ -407,7 +407,7 @@ async def delete_model(
model_id=model_id)
raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")
@router.get("/stats")
@router.get("/tenants/{tenant_id}/stats")
async def get_training_stats(
start_date: Optional[datetime] = Query(None),
end_date: Optional[datetime] = Query(None),
@@ -438,7 +438,7 @@ async def get_training_stats(
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to get stats: {str(e)}")
@router.post("/retrain/all")
@router.post("/tenants/{tenant_id}/retrain/all")
async def retrain_all_products(
request: TrainingJobRequest,
background_tasks: BackgroundTasks,