Improve auth models
This commit is contained in:
@@ -9,6 +9,7 @@ import structlog
|
|||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.schemas.auth import UserResponse, PasswordChangeRequest
|
from app.schemas.auth import UserResponse, PasswordChangeRequest
|
||||||
|
from app.schemas.users import UserUpdate
|
||||||
from app.services.user_service import UserService
|
from app.services.user_service import UserService
|
||||||
from app.core.auth import get_current_user
|
from app.core.auth import get_current_user
|
||||||
from app.models.users import User
|
from app.models.users import User
|
||||||
@@ -29,8 +30,6 @@ async def get_current_user_info(
|
|||||||
full_name=current_user.full_name,
|
full_name=current_user.full_name,
|
||||||
is_active=current_user.is_active,
|
is_active=current_user.is_active,
|
||||||
is_verified=current_user.is_verified,
|
is_verified=current_user.is_verified,
|
||||||
tenant_id=str(current_user.tenant_id) if current_user.tenant_id else None,
|
|
||||||
role=current_user.role,
|
|
||||||
phone=current_user.phone,
|
phone=current_user.phone,
|
||||||
language=current_user.language,
|
language=current_user.language,
|
||||||
timezone=current_user.timezone,
|
timezone=current_user.timezone,
|
||||||
@@ -46,7 +45,7 @@ async def get_current_user_info(
|
|||||||
|
|
||||||
@router.put("/me", response_model=UserResponse)
|
@router.put("/me", response_model=UserResponse)
|
||||||
async def update_current_user(
|
async def update_current_user(
|
||||||
user_update: dict,
|
user_update: UserUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
@@ -59,8 +58,6 @@ async def update_current_user(
|
|||||||
full_name=updated_user.full_name,
|
full_name=updated_user.full_name,
|
||||||
is_active=updated_user.is_active,
|
is_active=updated_user.is_active,
|
||||||
is_verified=updated_user.is_verified,
|
is_verified=updated_user.is_verified,
|
||||||
tenant_id=str(updated_user.tenant_id) if updated_user.tenant_id else None,
|
|
||||||
role=updated_user.role,
|
|
||||||
phone=updated_user.phone,
|
phone=updated_user.phone,
|
||||||
language=updated_user.language,
|
language=updated_user.language,
|
||||||
timezone=updated_user.timezone,
|
timezone=updated_user.timezone,
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# ================================================================
|
|
||||||
# services/auth/app/models/users.py - FIXED VERSION
|
# services/auth/app/models/users.py - FIXED VERSION
|
||||||
# ================================================================
|
# ================================================================
|
||||||
"""
|
"""
|
||||||
@@ -7,6 +6,7 @@ User models for authentication service - FIXED
|
|||||||
|
|
||||||
from sqlalchemy import Column, String, Boolean, DateTime, Text
|
from sqlalchemy import Column, String, Boolean, DateTime, Text
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import relationship # Import relationship
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
@@ -22,8 +22,7 @@ class User(Base):
|
|||||||
full_name = Column(String(255), nullable=False)
|
full_name = Column(String(255), nullable=False)
|
||||||
is_active = Column(Boolean, default=True)
|
is_active = Column(Boolean, default=True)
|
||||||
is_verified = Column(Boolean, default=False)
|
is_verified = Column(Boolean, default=False)
|
||||||
tenant_id = Column(UUID(as_uuid=True), nullable=True)
|
# Removed tenant_id and role from User model
|
||||||
role = Column(String(50), default="user") # user, admin, super_admin
|
|
||||||
|
|
||||||
# FIXED: Use timezone-aware datetime for all datetime fields
|
# FIXED: Use timezone-aware datetime for all datetime fields
|
||||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||||
@@ -35,6 +34,11 @@ class User(Base):
|
|||||||
language = Column(String(10), default="es")
|
language = Column(String(10), default="es")
|
||||||
timezone = Column(String(50), default="Europe/Madrid")
|
timezone = Column(String(50), default="Europe/Madrid")
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
# Define the many-to-many relationship through TenantMember
|
||||||
|
tenant_memberships = relationship("TenantMember", back_populates="user", cascade="all, delete-orphan") # Changed back_populates to avoid conflict
|
||||||
|
tenants = relationship("Tenant", secondary="tenant_members", back_populates="users")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<User(id={self.id}, email={self.email})>"
|
return f"<User(id={self.id}, email={self.email})>"
|
||||||
|
|
||||||
@@ -46,8 +50,7 @@ class User(Base):
|
|||||||
"full_name": self.full_name,
|
"full_name": self.full_name,
|
||||||
"is_active": self.is_active,
|
"is_active": self.is_active,
|
||||||
"is_verified": self.is_verified,
|
"is_verified": self.is_verified,
|
||||||
"tenant_id": str(self.tenant_id) if self.tenant_id else None,
|
# Removed tenant_id and role from to_dict
|
||||||
"role": self.role,
|
|
||||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||||
"last_login": self.last_login.isoformat() if self.last_login else None,
|
"last_login": self.last_login.isoformat() if self.last_login else None,
|
||||||
"phone": self.phone,
|
"phone": self.phone,
|
||||||
|
|||||||
@@ -35,8 +35,6 @@ class UserProfile(BaseModel):
|
|||||||
timezone: str
|
timezone: str
|
||||||
is_active: bool
|
is_active: bool
|
||||||
is_verified: bool
|
is_verified: bool
|
||||||
tenant_id: Optional[str]
|
|
||||||
role: str
|
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
last_login: Optional[datetime]
|
last_login: Optional[datetime]
|
||||||
|
|
||||||
|
|||||||
@@ -39,9 +39,6 @@ class AuthService:
|
|||||||
detail="Email already registered"
|
detail="Email already registered"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate tenant_id if not provided
|
|
||||||
tenant_id = user_data.tenant_id if hasattr(user_data, 'tenant_id') and user_data.tenant_id else str(uuid.uuid4())
|
|
||||||
|
|
||||||
# Hash password
|
# Hash password
|
||||||
hashed_password = security_manager.hash_password(user_data.password)
|
hashed_password = security_manager.hash_password(user_data.password)
|
||||||
|
|
||||||
@@ -49,7 +46,6 @@ class AuthService:
|
|||||||
user = User(
|
user = User(
|
||||||
email=user_data.email,
|
email=user_data.email,
|
||||||
hashed_password=hashed_password,
|
hashed_password=hashed_password,
|
||||||
tenant_id=tenant_id,
|
|
||||||
full_name=user_data.full_name,
|
full_name=user_data.full_name,
|
||||||
phone=user_data.phone,
|
phone=user_data.phone,
|
||||||
language=user_data.language,
|
language=user_data.language,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from sqlalchemy import select, update, delete
|
|||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
import structlog
|
import structlog
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from app.models.users import User
|
from app.models.users import User
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
@@ -59,6 +60,7 @@ class UserService:
|
|||||||
update_data[field] = user_data[field]
|
update_data[field] = user_data[field]
|
||||||
|
|
||||||
if update_data:
|
if update_data:
|
||||||
|
update_data["updated_at"] = datetime.now(timezone.utc)
|
||||||
await db.execute(
|
await db.execute(
|
||||||
update(User)
|
update(User)
|
||||||
.where(User.id == user_id)
|
.where(User.id == user_id)
|
||||||
@@ -107,7 +109,7 @@ class UserService:
|
|||||||
await db.execute(
|
await db.execute(
|
||||||
update(User)
|
update(User)
|
||||||
.where(User.id == user_id)
|
.where(User.id == user_id)
|
||||||
.values(hashed_password=new_hashed_password)
|
.values(hashed_password=new_hashed_password, updated_at=datetime.now(timezone.utc))
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -39,25 +39,30 @@ class Tenant(Base):
|
|||||||
model_trained = Column(Boolean, default=False)
|
model_trained = Column(Boolean, default=False)
|
||||||
last_training_date = Column(DateTime)
|
last_training_date = Column(DateTime)
|
||||||
|
|
||||||
# Ownership
|
# Ownership (The user who created the tenant, still a direct link)
|
||||||
owner_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
owner_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||||
|
|
||||||
# Timestamps
|
# Timestamps
|
||||||
created_at = Column(DateTime, default=datetime.utcnow)
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
# Define the many-to-many relationship through TenantMember
|
||||||
|
members = relationship("TenantMember", back_populates="tenant", cascade="all, delete-orphan")
|
||||||
|
users = relationship("User", secondary="tenant_members", back_populates="tenants")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<Tenant(id={self.id}, name={self.name})>"
|
return f"<Tenant(id={self.id}, name={self.name})>"
|
||||||
|
|
||||||
class TenantMember(Base):
|
class TenantMember(Base):
|
||||||
"""Tenant membership model for team access"""
|
"""Tenant membership model for team access - Association Table"""
|
||||||
__tablename__ = "tenant_members"
|
__tablename__ = "tenant_members"
|
||||||
|
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False)
|
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False)
|
||||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True) # Added ForeignKey to users.id
|
||||||
|
|
||||||
# Role and permissions
|
# Role and permissions specific to this tenant
|
||||||
role = Column(String(50), default="member") # owner, admin, member, viewer
|
role = Column(String(50), default="member") # owner, admin, member, viewer
|
||||||
permissions = Column(Text) # JSON string of permissions
|
permissions = Column(Text) # JSON string of permissions
|
||||||
|
|
||||||
@@ -69,5 +74,9 @@ class TenantMember(Base):
|
|||||||
|
|
||||||
created_at = Column(DateTime, default=datetime.utcnow)
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
# Relationships to access the associated Tenant and User objects
|
||||||
|
tenant = relationship("Tenant", back_populates="members")
|
||||||
|
user = relationship("User", back_populates="tenant_memberships") # Changed back_populates to avoid conflict
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<TenantMember(tenant_id={self.tenant_id}, user_id={self.user_id}, role={self.role})>"
|
return f"<TenantMember(tenant_id={self.tenant_id}, user_id={self.user_id}, role={self.role})>"
|
||||||
@@ -0,0 +1,295 @@
|
|||||||
|
# services/tenant/app/services/tenant_service.py
|
||||||
|
"""
|
||||||
|
Tenant service business logic
|
||||||
|
"""
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select, update, and_
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
|
||||||
|
from app.models.tenants import Tenant, TenantMember
|
||||||
|
from app.schemas.tenants import BakeryRegistration, TenantResponse, TenantAccessResponse, TenantUpdate
|
||||||
|
from app.services.messaging import publish_tenant_created, publish_member_added
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
class TenantService:
|
||||||
|
"""Tenant management business logic"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_bakery(bakery_data: BakeryRegistration, owner_id: str, db: AsyncSession) -> TenantResponse:
|
||||||
|
"""Create a new bakery/tenant"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate subdomain if not provided
|
||||||
|
subdomain = bakery_data.name.lower().replace(' ', '-').replace('á', 'a').replace('é', 'e').replace('í', 'i').replace('ó', 'o').replace('ú', 'u')
|
||||||
|
subdomain = ''.join(c for c in subdomain if c.isalnum() or c == '-')
|
||||||
|
|
||||||
|
# Check if subdomain already exists
|
||||||
|
result = await db.execute(
|
||||||
|
select(Tenant).where(Tenant.subdomain == subdomain)
|
||||||
|
)
|
||||||
|
if result.scalar_one_or_none():
|
||||||
|
subdomain = f"{subdomain}-{uuid.uuid4().hex[:6]}"
|
||||||
|
|
||||||
|
# Create tenant
|
||||||
|
tenant = Tenant(
|
||||||
|
name=bakery_data.name,
|
||||||
|
subdomain=subdomain,
|
||||||
|
business_type=bakery_data.business_type,
|
||||||
|
address=bakery_data.address,
|
||||||
|
city=bakery_data.city,
|
||||||
|
postal_code=bakery_data.postal_code,
|
||||||
|
phone=bakery_data.phone,
|
||||||
|
owner_id=owner_id,
|
||||||
|
is_active=True
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(tenant)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(tenant)
|
||||||
|
|
||||||
|
# Create owner membership
|
||||||
|
owner_membership = TenantMember(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
user_id=owner_id,
|
||||||
|
role="owner",
|
||||||
|
permissions=json.dumps(["read", "write", "admin", "delete"]),
|
||||||
|
is_active=True,
|
||||||
|
joined_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(owner_membership)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Publish event
|
||||||
|
await publish_tenant_created(str(tenant.id), owner_id, bakery_data.name)
|
||||||
|
|
||||||
|
logger.info(f"Bakery created: {bakery_data.name} (ID: {tenant.id})")
|
||||||
|
|
||||||
|
return TenantResponse.from_orm(tenant)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error creating bakery: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to create bakery"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def verify_user_access(user_id: str, tenant_id: str, db: AsyncSession) -> TenantAccessResponse:
|
||||||
|
"""Verify if user has access to tenant"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if user is tenant member
|
||||||
|
result = await db.execute(
|
||||||
|
select(TenantMember).where(
|
||||||
|
and_(
|
||||||
|
TenantMember.user_id == user_id,
|
||||||
|
TenantMember.tenant_id == tenant_id,
|
||||||
|
TenantMember.is_active == True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
membership = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not membership:
|
||||||
|
return TenantAccessResponse(
|
||||||
|
has_access=False,
|
||||||
|
role="none",
|
||||||
|
permissions=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse permissions
|
||||||
|
permissions = []
|
||||||
|
if membership.permissions:
|
||||||
|
try:
|
||||||
|
permissions = json.loads(membership.permissions)
|
||||||
|
except:
|
||||||
|
permissions = []
|
||||||
|
|
||||||
|
return TenantAccessResponse(
|
||||||
|
has_access=True,
|
||||||
|
role=membership.role,
|
||||||
|
permissions=permissions
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error verifying user access: {e}")
|
||||||
|
return TenantAccessResponse(
|
||||||
|
has_access=False,
|
||||||
|
role="none",
|
||||||
|
permissions=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_tenants(user_id: str, db: AsyncSession) -> List[TenantResponse]:
|
||||||
|
"""Get all tenants accessible by user"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get user's tenant memberships
|
||||||
|
result = await db.execute(
|
||||||
|
select(Tenant)
|
||||||
|
.join(TenantMember, Tenant.id == TenantMember.tenant_id)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
TenantMember.user_id == user_id,
|
||||||
|
TenantMember.is_active == True,
|
||||||
|
Tenant.is_active == True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.order_by(Tenant.name)
|
||||||
|
)
|
||||||
|
|
||||||
|
tenants = result.scalars().all()
|
||||||
|
return [TenantResponse.from_orm(tenant) for tenant in tenants]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting user tenants: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_tenant_by_id(tenant_id: str, db: AsyncSession) -> Optional[TenantResponse]:
|
||||||
|
"""Get tenant by ID"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Tenant).where(Tenant.id == tenant_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
tenant = result.scalar_one_or_none()
|
||||||
|
if tenant:
|
||||||
|
return TenantResponse.from_orm(tenant)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting tenant: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def update_tenant(tenant_id: str, update_data: TenantUpdate, user_id: str, db: AsyncSession) -> TenantResponse:
|
||||||
|
"""Update tenant information"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify user has admin access
|
||||||
|
access = await TenantService.verify_user_access(user_id, tenant_id, db)
|
||||||
|
if not access.has_access or access.role not in ["owner", "admin"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Insufficient permissions to update tenant"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update tenant
|
||||||
|
update_values = update_data.dict(exclude_unset=True)
|
||||||
|
if update_values:
|
||||||
|
update_values["updated_at"] = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
await db.execute(
|
||||||
|
update(Tenant)
|
||||||
|
.where(Tenant.id == tenant_id)
|
||||||
|
.values(**update_values)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Get updated tenant
|
||||||
|
result = await db.execute(
|
||||||
|
select(Tenant).where(Tenant.id == tenant_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
tenant = result.scalar_one()
|
||||||
|
logger.info(f"Tenant updated: {tenant.name} (ID: {tenant_id})")
|
||||||
|
|
||||||
|
return TenantResponse.from_orm(tenant)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error updating tenant: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to update tenant"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def add_team_member(tenant_id: str, user_id: str, role: str, invited_by: str, db: AsyncSession) -> TenantMemberResponse:
|
||||||
|
"""Add a team member to tenant"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify inviter has admin access
|
||||||
|
access = await TenantService.verify_user_access(invited_by, tenant_id, db)
|
||||||
|
if not access.has_access or access.role not in ["owner", "admin"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Insufficient permissions to add team members"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user is already a member
|
||||||
|
result = await db.execute(
|
||||||
|
select(TenantMember).where(
|
||||||
|
and_(
|
||||||
|
TenantMember.tenant_id == tenant_id,
|
||||||
|
TenantMember.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_member = result.scalar_one_or_none()
|
||||||
|
if existing_member:
|
||||||
|
if existing_member.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="User is already a member of this tenant"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Reactivate existing membership
|
||||||
|
existing_member.is_active = True
|
||||||
|
existing_member.role = role
|
||||||
|
existing_member.joined_at = datetime.now(timezone.utc)
|
||||||
|
await db.commit()
|
||||||
|
return TenantMemberResponse.from_orm(existing_member)
|
||||||
|
|
||||||
|
# Create new membership
|
||||||
|
permissions = ["read"]
|
||||||
|
if role in ["admin", "owner"]:
|
||||||
|
permissions.extend(["write", "admin"])
|
||||||
|
if role == "owner":
|
||||||
|
permissions.append("delete")
|
||||||
|
|
||||||
|
member = TenantMember(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role=role,
|
||||||
|
permissions=json.dumps(permissions),
|
||||||
|
invited_by=invited_by,
|
||||||
|
is_active=True,
|
||||||
|
joined_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(member)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(member)
|
||||||
|
|
||||||
|
# Publish event
|
||||||
|
await publish_member_added(tenant_id, user_id, role)
|
||||||
|
|
||||||
|
logger.info(f"Team member added: {user_id} to tenant {tenant_id} as {role}")
|
||||||
|
|
||||||
|
return TenantMemberResponse.from_orm(member)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error adding team member: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to add team member"
|
||||||
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ Training service configuration
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import List
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
"""Application settings"""
|
"""Application settings"""
|
||||||
@@ -38,6 +39,14 @@ class Settings(BaseSettings):
|
|||||||
PROPHET_WEEKLY_SEASONALITY: bool = os.getenv("PROPHET_WEEKLY_SEASONALITY", "true").lower() == "true"
|
PROPHET_WEEKLY_SEASONALITY: bool = os.getenv("PROPHET_WEEKLY_SEASONALITY", "true").lower() == "true"
|
||||||
PROPHET_YEARLY_SEASONALITY: bool = os.getenv("PROPHET_YEARLY_SEASONALITY", "true").lower() == "true"
|
PROPHET_YEARLY_SEASONALITY: bool = os.getenv("PROPHET_YEARLY_SEASONALITY", "true").lower() == "true"
|
||||||
|
|
||||||
|
# CORS
|
||||||
|
CORS_ORIGINS: str = os.getenv("CORS_ORIGINS", "http://localhost:3000,http://localhost:3001")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def CORS_ORIGINS_LIST(self) -> List[str]:
|
||||||
|
"""Get CORS origins as list"""
|
||||||
|
return [origin.strip() for origin in self.CORS_ORIGINS.split(",")]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
# services/training/app/main.py
|
# ================================================================
|
||||||
|
# services/training/app/main.py - FIXED VERSION
|
||||||
|
# ================================================================
|
||||||
"""
|
"""
|
||||||
Training Service Main Application
|
Training Service Main Application
|
||||||
Enhanced with proper error handling, monitoring, and lifecycle management
|
Enhanced with proper error handling, monitoring, and lifecycle management
|
||||||
@@ -19,7 +21,7 @@ from app.api import training, models
|
|||||||
from app.services.messaging import setup_messaging, cleanup_messaging
|
from app.services.messaging import setup_messaging, cleanup_messaging
|
||||||
from shared.monitoring.logging import setup_logging
|
from shared.monitoring.logging import setup_logging
|
||||||
from shared.monitoring.metrics import MetricsCollector
|
from shared.monitoring.metrics import MetricsCollector
|
||||||
from shared.auth.decorators import require_auth
|
# REMOVED: from shared.auth.decorators import require_auth
|
||||||
|
|
||||||
# Setup structured logging
|
# Setup structured logging
|
||||||
setup_logging("training-service", settings.LOG_LEVEL)
|
setup_logging("training-service", settings.LOG_LEVEL)
|
||||||
@@ -52,6 +54,9 @@ async def lifespan(app: FastAPI):
|
|||||||
metrics_collector.start_metrics_server(8080)
|
metrics_collector.start_metrics_server(8080)
|
||||||
logger.info("Metrics server started on port 8080")
|
logger.info("Metrics server started on port 8080")
|
||||||
|
|
||||||
|
# Store metrics collector in app state
|
||||||
|
app.state.metrics_collector = metrics_collector
|
||||||
|
|
||||||
# Mark service as ready
|
# Mark service as ready
|
||||||
app.state.ready = True
|
app.state.ready = True
|
||||||
logger.info("Training Service startup completed successfully")
|
logger.info("Training Service startup completed successfully")
|
||||||
@@ -67,74 +72,57 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.info("Shutting down Training Service")
|
logger.info("Shutting down Training Service")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Stop metrics server
|
||||||
|
if hasattr(app.state, 'metrics_collector'):
|
||||||
|
await app.state.metrics_collector.shutdown()
|
||||||
|
|
||||||
# Cleanup messaging
|
# Cleanup messaging
|
||||||
logger.info("Cleaning up messaging")
|
|
||||||
await cleanup_messaging()
|
await cleanup_messaging()
|
||||||
|
logger.info("Messaging cleanup completed")
|
||||||
|
|
||||||
# Close database connections
|
# Close database connections
|
||||||
logger.info("Closing database connections")
|
|
||||||
await database_manager.close_connections()
|
await database_manager.close_connections()
|
||||||
|
logger.info("Database connections closed")
|
||||||
logger.info("Training Service shutdown completed")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error during shutdown", error=str(e))
|
logger.error("Error during shutdown", error=str(e))
|
||||||
|
|
||||||
# Create FastAPI app with lifespan
|
logger.info("Training Service shutdown completed")
|
||||||
|
|
||||||
|
# Create FastAPI application with lifespan
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Training Service",
|
title="Bakery Training Service",
|
||||||
description="ML model training service for bakery demand forecasting",
|
description="ML training service for bakery demand forecasting",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
docs_url="/docs" if settings.DEBUG else None,
|
docs_url="/docs",
|
||||||
redoc_url="/redoc" if settings.DEBUG else None,
|
redoc_url="/redoc",
|
||||||
lifespan=lifespan
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize app state
|
# Add middleware
|
||||||
app.state.ready = False
|
|
||||||
|
|
||||||
# Security middleware
|
|
||||||
if not settings.DEBUG:
|
|
||||||
app.add_middleware(
|
|
||||||
TrustedHostMiddleware,
|
|
||||||
allowed_hosts=["localhost", "127.0.0.1", "training-service", "*.bakery-forecast.local"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# CORS middleware
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"] if settings.DEBUG else [
|
allow_origins=settings.CORS_ORIGINS_LIST,
|
||||||
"http://localhost:3000",
|
|
||||||
"http://localhost:8000",
|
|
||||||
"https://dashboard.bakery-forecast.es"
|
|
||||||
],
|
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Request logging middleware
|
app.add_middleware(
|
||||||
|
TrustedHostMiddleware,
|
||||||
|
allowed_hosts=settings.ALLOWED_HOSTS
|
||||||
|
)
|
||||||
|
|
||||||
|
# Request middleware for logging and metrics
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def log_requests(request: Request, call_next):
|
async def process_request(request: Request, call_next):
|
||||||
"""Log all incoming requests with timing"""
|
"""Process requests with logging and metrics"""
|
||||||
start_time = asyncio.get_event_loop().time()
|
start_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
# Log request
|
|
||||||
logger.info(
|
|
||||||
"Request started",
|
|
||||||
method=request.method,
|
|
||||||
path=request.url.path,
|
|
||||||
client_ip=request.client.host if request.client else "unknown"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process request
|
|
||||||
try:
|
try:
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
|
|
||||||
# Calculate duration
|
|
||||||
duration = asyncio.get_event_loop().time() - start_time
|
duration = asyncio.get_event_loop().time() - start_time
|
||||||
|
|
||||||
# Log response
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Request completed",
|
"Request completed",
|
||||||
method=request.method,
|
method=request.method,
|
||||||
@@ -189,19 +177,18 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Include API routers
|
# Include API routers - NO AUTH DEPENDENCIES HERE
|
||||||
|
# Authentication is handled by API Gateway
|
||||||
app.include_router(
|
app.include_router(
|
||||||
training.router,
|
training.router,
|
||||||
prefix="/training",
|
prefix="/training",
|
||||||
tags=["training"],
|
tags=["training"]
|
||||||
dependencies=[require_auth] if not settings.DEBUG else []
|
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
models.router,
|
models.router,
|
||||||
prefix="/models",
|
prefix="/models",
|
||||||
tags=["models"],
|
tags=["models"]
|
||||||
dependencies=[require_auth] if not settings.DEBUG else []
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Health check endpoints
|
# Health check endpoints
|
||||||
@@ -217,66 +204,32 @@ async def health_check():
|
|||||||
|
|
||||||
@app.get("/health/ready")
|
@app.get("/health/ready")
|
||||||
async def readiness_check():
|
async def readiness_check():
|
||||||
"""Kubernetes readiness probe"""
|
"""Kubernetes readiness probe endpoint"""
|
||||||
if not app.state.ready:
|
checks = {
|
||||||
|
"database": await get_db_health(),
|
||||||
|
"application": getattr(app.state, 'ready', False)
|
||||||
|
}
|
||||||
|
|
||||||
|
if all(checks.values()):
|
||||||
|
return {"status": "ready", "checks": checks}
|
||||||
|
else:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=503,
|
status_code=503,
|
||||||
content={"status": "not_ready", "message": "Service is starting up"}
|
content={"status": "not ready", "checks": checks}
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"status": "ready", "service": "training-service"}
|
|
||||||
|
|
||||||
@app.get("/health/live")
|
|
||||||
async def liveness_check():
|
|
||||||
"""Kubernetes liveness probe"""
|
|
||||||
# Check database connectivity
|
|
||||||
try:
|
|
||||||
db_healthy = await get_db_health()
|
|
||||||
if not db_healthy:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=503,
|
|
||||||
content={"status": "unhealthy", "reason": "database_unavailable"}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Database health check failed", error=str(e))
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=503,
|
|
||||||
content={"status": "unhealthy", "reason": "database_error"}
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"status": "alive", "service": "training-service"}
|
|
||||||
|
|
||||||
@app.get("/metrics")
|
@app.get("/metrics")
|
||||||
async def get_metrics():
|
async def get_metrics():
|
||||||
"""Expose service metrics"""
|
"""Prometheus metrics endpoint"""
|
||||||
return {
|
if hasattr(app.state, 'metrics_collector'):
|
||||||
"training_jobs_active": metrics_collector.get_gauge("training_jobs_active", 0),
|
return app.state.metrics_collector.get_metrics()
|
||||||
"training_jobs_completed": metrics_collector.get_counter("training_jobs_completed", 0),
|
return {"status": "metrics not available"}
|
||||||
"training_jobs_failed": metrics_collector.get_counter("training_jobs_failed", 0),
|
|
||||||
"models_trained_total": metrics_collector.get_counter("models_trained_total", 0),
|
|
||||||
"uptime_seconds": metrics_collector.get_gauge("uptime_seconds", 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
async def root():
|
|
||||||
"""Root endpoint with service information"""
|
|
||||||
return {
|
|
||||||
"service": "training-service",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": "ML model training service for bakery demand forecasting",
|
|
||||||
"docs": "/docs" if settings.DEBUG else "Documentation disabled in production",
|
|
||||||
"health": "/health"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Development server configuration
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"app.main:app",
|
"app.main:app",
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
port=8000,
|
port=settings.PORT,
|
||||||
reload=settings.DEBUG,
|
reload=settings.DEBUG,
|
||||||
log_level=settings.LOG_LEVEL.lower(),
|
log_level=settings.LOG_LEVEL.lower()
|
||||||
access_log=settings.DEBUG,
|
|
||||||
server_header=False,
|
|
||||||
date_header=False
|
|
||||||
)
|
)
|
||||||
Reference in New Issue
Block a user