Few fixes
This commit is contained in:
@@ -15,7 +15,7 @@ COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy shared libraries
|
||||
COPY --from=shared /shared /app/shared
|
||||
COPY shared/ /app/shared/
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
@@ -3,7 +3,7 @@ Authentication service configuration
|
||||
"""
|
||||
|
||||
import os
|
||||
from pydantic import BaseSettings
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings"""
|
||||
|
||||
0
services/auth/app/models/__init__.py
Normal file
0
services/auth/app/models/__init__.py
Normal file
74
services/auth/app/models/users.py
Normal file
74
services/auth/app/models/users.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
User models for authentication service
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
class User(Base):
    """Registered user account for the authentication service."""
    __tablename__ = "users"

    # Identity / credentials
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    email = Column(String(255), unique=True, index=True, nullable=False)
    hashed_password = Column(String(255), nullable=False)
    full_name = Column(String(255), nullable=False)

    # Account state
    is_active = Column(Boolean, default=True)
    is_verified = Column(Boolean, default=False)
    tenant_id = Column(UUID(as_uuid=True), nullable=True)
    role = Column(String(50), default="user")  # user, admin, super_admin

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    last_login = Column(DateTime)

    # Profile fields
    phone = Column(String(20))
    language = Column(String(10), default="es")
    timezone = Column(String(50), default="Europe/Madrid")

    def __repr__(self):
        return f"<User(id={self.id}, email={self.email})>"

    def to_dict(self):
        """Convert user to dictionary"""
        # UUIDs and datetimes are stringified so the result is JSON-safe.
        tenant = str(self.tenant_id) if self.tenant_id else None
        created = self.created_at.isoformat() if self.created_at else None
        seen = self.last_login.isoformat() if self.last_login else None
        return {
            "id": str(self.id),
            "email": self.email,
            "full_name": self.full_name,
            "is_active": self.is_active,
            "is_verified": self.is_verified,
            "tenant_id": tenant,
            "role": self.role,
            "created_at": created,
            "last_login": seen,
            "phone": self.phone,
            "language": self.language,
            "timezone": self.timezone
        }
|
||||
|
||||
|
||||
class UserSession(Base):
    """Refresh-token session record tied to a single user."""
    __tablename__ = "user_sessions"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    # Only the hash of the refresh token is persisted.
    refresh_token_hash = Column(String(255), nullable=False)
    is_active = Column(Boolean, default=True)
    expires_at = Column(DateTime, nullable=False)

    # Session metadata captured at login time
    ip_address = Column(String(45))  # wide enough for IPv6 text form
    user_agent = Column(Text)
    device_info = Column(Text)

    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def __repr__(self):
        return f"<UserSession(id={self.id}, user_id={self.user_id})>"
|
||||
@@ -15,7 +15,7 @@ class UserRegistration(BaseModel):
|
||||
password: str = Field(..., min_length=settings.PASSWORD_MIN_LENGTH)
|
||||
full_name: str = Field(..., min_length=2, max_length=100)
|
||||
phone: Optional[str] = None
|
||||
language: str = Field(default="es", regex="^(es|en)$")
|
||||
language: str = Field(default="es", pattern="^(es|en)$")
|
||||
|
||||
@validator('password')
|
||||
def validate_password(cls, v):
|
||||
@@ -97,7 +97,7 @@ class UserUpdate(BaseModel):
|
||||
"""User update schema"""
|
||||
full_name: Optional[str] = Field(None, min_length=2, max_length=100)
|
||||
phone: Optional[str] = None
|
||||
language: Optional[str] = Field(None, regex="^(es|en)$")
|
||||
language: Optional[str] = Field(None, pattern="^(es|en)$")
|
||||
timezone: Optional[str] = None
|
||||
|
||||
@validator('phone')
|
||||
|
||||
284
services/auth/app/services/auth_service.py
Normal file
284
services/auth/app/services/auth_service.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Authentication service business logic
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.users import User, UserSession
|
||||
from app.schemas.auth import UserRegistration, UserLogin, TokenResponse, UserResponse
|
||||
from app.core.security import security_manager
|
||||
from app.services.messaging import message_publisher
|
||||
from shared.messaging.events import UserRegisteredEvent, UserLoginEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AuthService:
    """Authentication service business logic.

    Stateless collection of async helpers: registration, login, token
    refresh/verification and logout. All methods take the caller's
    AsyncSession and raise fastapi.HTTPException on failure.
    """

    @staticmethod
    async def register_user(user_data: UserRegistration, db: AsyncSession) -> UserResponse:
        """Register a new user.

        Raises:
            HTTPException: 400 if the email is already registered.
        """
        import uuid  # fix: uuid is not imported at this module's top level

        # Reject duplicate registrations up front.
        result = await db.execute(
            select(User).where(User.email == user_data.email)
        )
        if result.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already registered"
            )

        user = User(
            email=user_data.email,
            hashed_password=security_manager.hash_password(user_data.password),
            full_name=user_data.full_name,
            phone=user_data.phone,
            language=user_data.language,
            is_active=True,
            is_verified=False  # Email verification required
        )

        db.add(user)
        await db.commit()
        await db.refresh(user)

        # Publish user registered event.
        # Fix: event_id used to be the empty string; generate a real id,
        # consistent with the training service's events.
        await message_publisher.publish_event(
            "user_events",
            "user.registered",
            UserRegisteredEvent(
                event_id=str(uuid.uuid4()),
                service_name="auth-service",
                timestamp=datetime.utcnow(),
                data={
                    "user_id": str(user.id),
                    "email": user.email,
                    "full_name": user.full_name,
                    "language": user.language
                }
            ).__dict__
        )

        logger.info(f"User registered: {user.email}")
        return UserResponse(**user.to_dict())

    @staticmethod
    async def login_user(login_data: UserLogin, db: AsyncSession, ip_address: str, user_agent: str) -> TokenResponse:
        """Authenticate user and create access/refresh tokens.

        Raises:
            HTTPException: 429 on too many attempts, 401 on bad
                credentials or inactive account.
        """
        import uuid  # fix: uuid is not imported at this module's top level
        # fix: `settings` was referenced below but never imported in this module
        from app.core.config import settings

        # Rate-limit brute-force attempts per email.
        if not await security_manager.check_login_attempts(login_data.email):
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Too many login attempts. Please try again later."
            )

        result = await db.execute(
            select(User).where(User.email == login_data.email)
        )
        user = result.scalar_one_or_none()

        # Same error for unknown email and wrong password (no user enumeration).
        if not user or not security_manager.verify_password(login_data.password, user.hashed_password):
            await security_manager.increment_login_attempts(login_data.email)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid email or password"
            )

        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Account is inactive"
            )

        await security_manager.clear_login_attempts(login_data.email)

        # Record the successful login time.
        await db.execute(
            update(User)
            .where(User.id == user.id)
            .values(last_login=datetime.utcnow())
        )
        await db.commit()

        token_data = {
            "user_id": str(user.id),
            "email": user.email,
            "tenant_id": str(user.tenant_id) if user.tenant_id else None,
            "role": user.role
        }

        access_token = security_manager.create_access_token(token_data)
        refresh_token = security_manager.create_refresh_token(token_data)

        await security_manager.store_refresh_token(str(user.id), refresh_token)

        # Persist a session record; only the token hash is stored.
        # NOTE(review): expiry is hard-coded to 7 days here — presumably this
        # should follow the refresh-token TTL in settings; confirm.
        session = UserSession(
            user_id=user.id,
            refresh_token_hash=security_manager.hash_password(refresh_token),
            expires_at=datetime.utcnow() + timedelta(days=7),
            ip_address=ip_address,
            user_agent=user_agent
        )
        db.add(session)
        await db.commit()

        # Publish login event (fix: real event_id instead of "").
        await message_publisher.publish_event(
            "user_events",
            "user.login",
            UserLoginEvent(
                event_id=str(uuid.uuid4()),
                service_name="auth-service",
                timestamp=datetime.utcnow(),
                data={
                    "user_id": str(user.id),
                    "email": user.email,
                    "ip_address": ip_address,
                    "user_agent": user_agent
                }
            ).__dict__
        )

        logger.info(f"User logged in: {user.email}")

        return TokenResponse(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
        )

    @staticmethod
    async def refresh_token(refresh_token: str, db: AsyncSession) -> TokenResponse:
        """Rotate the refresh token and issue a new access token.

        Raises:
            HTTPException: 401 if the token is invalid, unknown, or the
                user is missing/inactive.
        """
        # fix: `settings` was referenced below but never imported in this module
        from app.core.config import settings

        payload = security_manager.verify_token(refresh_token)
        if not payload or payload.get("type") != "refresh":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token"
            )

        user_id = payload.get("user_id")
        if not user_id:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token"
            )

        # The token must also match the stored (server-side) copy.
        if not await security_manager.verify_refresh_token(user_id, refresh_token):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token"
            )

        result = await db.execute(
            select(User).where(User.id == user_id)
        )
        user = result.scalar_one_or_none()

        if not user or not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User not found or inactive"
            )

        token_data = {
            "user_id": str(user.id),
            "email": user.email,
            "tenant_id": str(user.tenant_id) if user.tenant_id else None,
            "role": user.role
        }

        new_access_token = security_manager.create_access_token(token_data)
        new_refresh_token = security_manager.create_refresh_token(token_data)

        # Rotation: the old stored refresh token is replaced.
        await security_manager.store_refresh_token(str(user.id), new_refresh_token)

        logger.info(f"Token refreshed for user: {user.email}")

        return TokenResponse(
            access_token=new_access_token,
            refresh_token=new_refresh_token,
            expires_in=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
        )

    @staticmethod
    async def verify_token(token: str, db: AsyncSession) -> Dict[str, Any]:
        """Verify an access token and return the caller's identity claims.

        Raises:
            HTTPException: 401 on invalid/expired token or missing/inactive user.
        """
        payload = security_manager.verify_token(token)
        if not payload or payload.get("type") != "access":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid or expired token"
            )

        user_id = payload.get("user_id")
        if not user_id:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token"
            )

        # Re-check the user row so revoked/deactivated accounts are rejected
        # even while their token is still cryptographically valid.
        result = await db.execute(
            select(User).where(User.id == user_id)
        )
        user = result.scalar_one_or_none()

        if not user or not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User not found or inactive"
            )

        return {
            "user_id": str(user.id),
            "email": user.email,
            "tenant_id": str(user.tenant_id) if user.tenant_id else None,
            "role": user.role,
            "full_name": user.full_name,
            "is_verified": user.is_verified
        }

    @staticmethod
    async def logout_user(user_id: str, db: AsyncSession):
        """Logout user: revoke the stored refresh token and deactivate sessions."""
        await security_manager.revoke_refresh_token(user_id)

        await db.execute(
            update(UserSession)
            .where(UserSession.user_id == user_id)
            .values(is_active=False)
        )
        await db.commit()

        logger.info(f"User logged out: {user_id}")
|
||||
@@ -3,7 +3,7 @@ Shared JWT Authentication Handler
|
||||
Used across all microservices for consistent authentication
|
||||
"""
|
||||
|
||||
import jwt
|
||||
from jose import jwt
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
@@ -53,6 +53,6 @@ class JWTHandler:
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("Token has expired")
|
||||
return None
|
||||
except jwt.InvalidTokenError:
|
||||
except jwt.JWTError:
|
||||
logger.warning("Invalid token")
|
||||
return None
|
||||
@@ -3,7 +3,7 @@ Data service configuration
|
||||
"""
|
||||
|
||||
import os
|
||||
from pydantic import BaseSettings
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings"""
|
||||
|
||||
0
services/data/app/models/__init__.py
Normal file
0
services/data/app/models/__init__.py
Normal file
@@ -3,7 +3,7 @@ Shared JWT Authentication Handler
|
||||
Used across all microservices for consistent authentication
|
||||
"""
|
||||
|
||||
import jwt
|
||||
from jose import jwt
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
@@ -53,6 +53,6 @@ class JWTHandler:
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("Token has expired")
|
||||
return None
|
||||
except jwt.InvalidTokenError:
|
||||
except jwt.JWTError:
|
||||
logger.warning("Invalid token")
|
||||
return None
|
||||
@@ -3,7 +3,7 @@ Forecasting service configuration
|
||||
"""
|
||||
|
||||
import os
|
||||
from pydantic import BaseSettings
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings"""
|
||||
|
||||
0
services/forecasting/app/models/__init__.py
Normal file
0
services/forecasting/app/models/__init__.py
Normal file
@@ -3,7 +3,7 @@ Shared JWT Authentication Handler
|
||||
Used across all microservices for consistent authentication
|
||||
"""
|
||||
|
||||
import jwt
|
||||
from jose import jwt
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
@@ -53,6 +53,6 @@ class JWTHandler:
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("Token has expired")
|
||||
return None
|
||||
except jwt.InvalidTokenError:
|
||||
except jwt.JWTError:
|
||||
logger.warning("Invalid token")
|
||||
return None
|
||||
@@ -3,7 +3,7 @@ Notification service configuration
|
||||
"""
|
||||
|
||||
import os
|
||||
from pydantic import BaseSettings
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings"""
|
||||
|
||||
0
services/notification/app/models/__init__.py
Normal file
0
services/notification/app/models/__init__.py
Normal file
@@ -3,7 +3,7 @@ Shared JWT Authentication Handler
|
||||
Used across all microservices for consistent authentication
|
||||
"""
|
||||
|
||||
import jwt
|
||||
from jose import jwt
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
@@ -53,6 +53,6 @@ class JWTHandler:
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("Token has expired")
|
||||
return None
|
||||
except jwt.InvalidTokenError:
|
||||
except jwt.JWTError:
|
||||
logger.warning("Invalid token")
|
||||
return None
|
||||
@@ -3,7 +3,7 @@ Tenant service configuration
|
||||
"""
|
||||
|
||||
import os
|
||||
from pydantic import BaseSettings
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings"""
|
||||
|
||||
0
services/tenant/app/models/__init__.py
Normal file
0
services/tenant/app/models/__init__.py
Normal file
@@ -3,7 +3,7 @@ Shared JWT Authentication Handler
|
||||
Used across all microservices for consistent authentication
|
||||
"""
|
||||
|
||||
import jwt
|
||||
from jose import jwt
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
@@ -53,6 +53,6 @@ class JWTHandler:
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("Token has expired")
|
||||
return None
|
||||
except jwt.InvalidTokenError:
|
||||
except jwt.JWTError:
|
||||
logger.warning("Invalid token")
|
||||
return None
|
||||
@@ -17,7 +17,7 @@ COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy shared libraries
|
||||
COPY --from=shared /shared /app/shared
|
||||
COPY shared/ /app/shared/
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
@@ -3,7 +3,7 @@ Training service configuration
|
||||
"""
|
||||
|
||||
import os
|
||||
from pydantic import BaseSettings
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings"""
|
||||
|
||||
0
services/training/app/models/__init__.py
Normal file
0
services/training/app/models/__init__.py
Normal file
101
services/training/app/models/training.py
Normal file
101
services/training/app/models/training.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Training models
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Text, JSON, Boolean, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
class TrainingJob(Base):
    """One end-to-end model-training run for a tenant."""
    __tablename__ = "training_jobs"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    # Lifecycle: queued -> running -> completed | failed
    status = Column(String(20), nullable=False, default="queued")
    progress = Column(Integer, default=0)  # percent, 0-100
    current_step = Column(String(200))
    requested_by = Column(UUID(as_uuid=True), nullable=False)

    # Timing
    started_at = Column(DateTime, default=datetime.utcnow)
    completed_at = Column(DateTime)
    duration_seconds = Column(Integer)

    # Results
    models_trained = Column(JSON)
    metrics = Column(JSON)
    error_message = Column(Text)

    # Metadata about the training window and dataset
    training_data_from = Column(DateTime)
    training_data_to = Column(DateTime)
    total_data_points = Column(Integer)
    products_count = Column(Integer)

    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def __repr__(self):
        return f"<TrainingJob(id={self.id}, tenant_id={self.tenant_id}, status={self.status})>"
|
||||
|
||||
class TrainedModel(Base):
    """A single trained model produced by a training job."""
    __tablename__ = "trained_models"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    training_job_id = Column(UUID(as_uuid=True), nullable=False)

    # Model details
    product_name = Column(String(100), nullable=False)
    model_type = Column(String(50), nullable=False, default="prophet")
    model_version = Column(String(20), nullable=False)
    model_path = Column(String(500))  # Path to saved model file

    # Performance metrics
    mape = Column(Float)      # Mean Absolute Percentage Error
    rmse = Column(Float)      # Root Mean Square Error
    mae = Column(Float)       # Mean Absolute Error
    r2_score = Column(Float)  # R-squared score

    # Training details
    training_samples = Column(Integer)
    validation_samples = Column(Integer)
    features_used = Column(JSON)
    hyperparameters = Column(JSON)

    # Status
    is_active = Column(Boolean, default=True)
    last_used_at = Column(DateTime)

    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def __repr__(self):
        return f"<TrainedModel(id={self.id}, product={self.product_name}, tenant={self.tenant_id})>"
|
||||
|
||||
class TrainingLog(Base):
    """Per-step log entry emitted while a training job runs."""
    __tablename__ = "training_logs"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    training_job_id = Column(UUID(as_uuid=True), nullable=False, index=True)

    level = Column(String(10), nullable=False)  # DEBUG, INFO, WARNING, ERROR
    message = Column(Text, nullable=False)
    step = Column(String(100))
    progress = Column(Integer)

    # Additional data
    execution_time = Column(Float)  # Time taken for this step
    memory_usage = Column(Float)    # Memory usage in MB
    # Fix: the attribute was named `metadata`, which is reserved on SQLAlchemy
    # declarative classes (it shadows Base.metadata and raises
    # InvalidRequestError at class-definition time). Keep the database column
    # named "metadata" but expose it on the model as `extra_metadata`.
    extra_metadata = Column("metadata", JSON)  # Additional metadata

    created_at = Column(DateTime, default=datetime.utcnow)

    def __repr__(self):
        return f"<TrainingLog(id={self.id}, level={self.level})>"
|
||||
@@ -4,47 +4,3 @@ Messaging service for training service
|
||||
|
||||
from shared.messaging.rabbitmq import RabbitMQClient
|
||||
from app.core.config import settings
|
||||
|
||||
# Global message publisher
|
||||
message_publisher = RabbitMQClient(settings.RABBITMQ_URL)
|
||||
|
||||
|
||||
# services/training/Dockerfile
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
g++ \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy shared libraries
|
||||
COPY --from=shared /shared /app/shared
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Create model storage directory
|
||||
RUN mkdir -p /app/models
|
||||
|
||||
# Add shared libraries to Python path
|
||||
ENV PYTHONPATH="/app:/app/shared:$PYTHONPATH"
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
# Run application
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
384
services/training/app/services/training_service.py
Normal file
384
services/training/app/services/training_service.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Training service business logic
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, and_
|
||||
import httpx
|
||||
import uuid
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.training import TrainingJob, TrainedModel, TrainingLog
|
||||
from app.schemas.training import TrainingRequest, TrainingJobResponse, TrainedModelResponse
|
||||
from app.ml.trainer import MLTrainer
|
||||
from app.services.messaging import message_publisher
|
||||
from shared.messaging.events import TrainingStartedEvent, TrainingCompletedEvent, TrainingFailedEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TrainingService:
    """Training service business logic."""

    def __init__(self):
        # One trainer instance, reused for every job this service runs.
        self.ml_trainer = MLTrainer()
|
||||
|
||||
async def start_training(self, request: TrainingRequest, user_data: dict, db: AsyncSession) -> TrainingJobResponse:
    """Create and queue a training job for the caller's tenant.

    Raises ValueError when the user has no tenant or a job is already
    queued/running for that tenant.
    """
    tenant_id = user_data.get("tenant_id")
    if not tenant_id:
        raise ValueError("User must be associated with a tenant")

    # Only one concurrent job per tenant.
    if await self._get_running_job(tenant_id, db):
        raise ValueError("Training job already running for this tenant")

    now = datetime.utcnow()
    job = TrainingJob(
        tenant_id=tenant_id,
        status="queued",
        progress=0,
        current_step="Queued for training",
        requested_by=user_data.get("user_id"),
        training_data_from=now - timedelta(days=request.training_days),
        training_data_to=now
    )

    db.add(job)
    await db.commit()
    await db.refresh(job)

    # Run the actual training asynchronously.
    # NOTE(review): the request-scoped AsyncSession is shared with this
    # background task; AsyncSession is not safe for concurrent use across
    # tasks — the task should probably open its own session. Confirm.
    asyncio.create_task(self._execute_training(job.id, request, db))

    await message_publisher.publish_event(
        "training_events",
        "training.started",
        TrainingStartedEvent(
            event_id=str(uuid.uuid4()),
            service_name="training-service",
            timestamp=datetime.utcnow(),
            data={
                "job_id": str(job.id),
                "tenant_id": tenant_id,
                "requested_by": user_data.get("user_id"),
                "training_days": request.training_days
            }
        ).__dict__
    )

    logger.info(f"Training job started: {job.id} for tenant: {tenant_id}")

    return TrainingJobResponse(
        id=str(job.id),
        tenant_id=tenant_id,
        status=job.status,
        progress=job.progress,
        current_step=job.current_step,
        started_at=job.started_at,
        completed_at=job.completed_at,
        duration_seconds=job.duration_seconds,
        models_trained=job.models_trained,
        metrics=job.metrics,
        error_message=job.error_message
    )
|
||||
|
||||
async def get_training_status(self, job_id: str, user_data: dict, db: AsyncSession) -> TrainingJobResponse:
    """Return the status of one training job, scoped to the caller's tenant.

    Raises ValueError when no matching job exists for this tenant.
    """
    tenant_id = user_data.get("tenant_id")

    # Tenant filter doubles as the authorization check.
    query = select(TrainingJob).where(
        and_(
            TrainingJob.id == job_id,
            TrainingJob.tenant_id == tenant_id
        )
    )
    job = (await db.execute(query)).scalar_one_or_none()

    if job is None:
        raise ValueError("Training job not found")

    return TrainingJobResponse(
        id=str(job.id),
        tenant_id=str(job.tenant_id),
        status=job.status,
        progress=job.progress,
        current_step=job.current_step,
        started_at=job.started_at,
        completed_at=job.completed_at,
        duration_seconds=job.duration_seconds,
        models_trained=job.models_trained,
        metrics=job.metrics,
        error_message=job.error_message
    )
|
||||
|
||||
async def get_trained_models(self, user_data: dict, db: AsyncSession) -> List[TrainedModelResponse]:
    """List the caller tenant's active trained models, newest first."""
    tenant_id = user_data.get("tenant_id")

    query = (
        select(TrainedModel)
        .where(
            and_(
                TrainedModel.tenant_id == tenant_id,
                TrainedModel.is_active == True
            )
        )
        .order_by(TrainedModel.created_at.desc())
    )
    rows = (await db.execute(query)).scalars().all()

    responses: List[TrainedModelResponse] = []
    for m in rows:
        responses.append(
            TrainedModelResponse(
                id=str(m.id),
                product_name=m.product_name,
                model_type=m.model_type,
                model_version=m.model_version,
                mape=m.mape,
                rmse=m.rmse,
                mae=m.mae,
                r2_score=m.r2_score,
                training_samples=m.training_samples,
                features_used=m.features_used,
                is_active=m.is_active,
                created_at=m.created_at,
                last_used_at=m.last_used_at
            )
        )
    return responses
|
||||
|
||||
async def get_training_jobs(self, user_data: dict, limit: int, offset: int, db: AsyncSession) -> List[TrainingJobResponse]:
    """Page through the caller tenant's training jobs, newest first."""
    tenant_id = user_data.get("tenant_id")

    query = (
        select(TrainingJob)
        .where(TrainingJob.tenant_id == tenant_id)
        .order_by(TrainingJob.created_at.desc())
        .limit(limit)
        .offset(offset)
    )
    rows = (await db.execute(query)).scalars().all()

    return [
        TrainingJobResponse(
            id=str(j.id),
            tenant_id=str(j.tenant_id),
            status=j.status,
            progress=j.progress,
            current_step=j.current_step,
            started_at=j.started_at,
            completed_at=j.completed_at,
            duration_seconds=j.duration_seconds,
            models_trained=j.models_trained,
            metrics=j.metrics,
            error_message=j.error_message
        )
        for j in rows
    ]
|
||||
|
||||
async def _get_running_job(self, tenant_id: str, db: AsyncSession) -> Optional[TrainingJob]:
    """Return the tenant's queued or running job, or None if idle."""
    active_statuses = ["queued", "running"]
    query = select(TrainingJob).where(
        and_(
            TrainingJob.tenant_id == tenant_id,
            TrainingJob.status.in_(active_statuses)
        )
    )
    return (await db.execute(query)).scalar_one_or_none()
|
||||
|
||||
async def _execute_training(self, job_id: str, request: TrainingRequest, db: AsyncSession):
    """Run a training job end to end and record its outcome.

    Steps: fetch data -> train -> validate -> persist models -> mark job
    complete, publishing a ``training.completed`` or ``training.failed``
    event at the end. Intended to run as a background task: any exception
    is caught, stored on the job row, and turned into a failure event
    rather than propagated to the caller.

    Args:
        job_id: Identifier of the TrainingJob row being executed.
        request: The original training request passed through to data fetch.
        db: Async database session used for all status updates.
    """
    start_time = datetime.utcnow()

    try:
        # Update job status
        await self._update_job_status(job_id, "running", 0, "Starting training...", db)

        # Get training data
        await self._update_job_status(job_id, "running", 10, "Fetching training data...", db)
        training_data = await self._get_training_data(job_id, request, db)

        # Train models
        await self._update_job_status(job_id, "running", 30, "Training models...", db)
        models_result = await self.ml_trainer.train_models(training_data, job_id, db)

        # Validate models
        await self._update_job_status(job_id, "running", 80, "Validating models...", db)
        validation_result = await self.ml_trainer.validate_models(models_result, db)

        # Save models
        await self._update_job_status(job_id, "running", 90, "Saving models...", db)
        await self._save_trained_models(job_id, models_result, validation_result, db)

        # Complete job (duration measured from the top of this method)
        duration = int((datetime.utcnow() - start_time).total_seconds())
        await self._complete_job(job_id, models_result, validation_result, duration, db)

        # Publish completion event
        await message_publisher.publish_event(
            "training_events",
            "training.completed",
            TrainingCompletedEvent(
                event_id=str(uuid.uuid4()),
                service_name="training-service",
                timestamp=datetime.utcnow(),
                data={
                    "job_id": str(job_id),
                    "models_trained": len(models_result),
                    "duration_seconds": duration
                }
            ).__dict__  # publisher is handed a plain dict payload
        )

        logger.info(f"Training job completed: {job_id}")

    except Exception as e:
        logger.error(f"Training job failed: {job_id} - {e}")

        # Update job as failed.
        # NOTE(review): progress is reset to 0 and the error text is stored
        # in current_step — confirm consumers read failures that way.
        await self._update_job_status(job_id, "failed", 0, f"Training failed: {str(e)}", db)

        # Publish failure event
        await message_publisher.publish_event(
            "training_events",
            "training.failed",
            TrainingFailedEvent(
                event_id=str(uuid.uuid4()),
                service_name="training-service",
                timestamp=datetime.utcnow(),
                data={
                    "job_id": str(job_id),
                    "error": str(e)
                }
            ).__dict__
        )
|
||||
|
||||
async def _update_job_status(self, job_id: str, status: str, progress: int, current_step: str, db: AsyncSession):
    """Persist a status/progress snapshot for a training job.

    Commits immediately so anything polling the job row sees fresh
    progress while training is still running.
    """
    new_values = {
        "status": status,
        "progress": progress,
        "current_step": current_step,
        "updated_at": datetime.utcnow(),
    }
    stmt = update(TrainingJob).where(TrainingJob.id == job_id).values(**new_values)
    await db.execute(stmt)
    await db.commit()
|
||||
|
||||
async def _get_training_data(self, job_id: str, request: TrainingRequest, db: AsyncSession) -> Dict[str, Any]:
    """Fetch training data for a job from the data service.

    Reads the job row for its tenant and training date window, then calls
    the data service's training-data endpoint for the requested products.

    Raises:
        Exception: If the data service answers with a non-200 status, or
            on any transport error (logged before re-raising).
    """
    # The job row carries the tenant id and the training date window.
    job_row = (await db.execute(
        select(TrainingJob).where(TrainingJob.id == job_id)
    )).scalar_one()

    url = f"{settings.DATA_SERVICE_URL}/training-data/{job_row.tenant_id}"
    query_params = {
        "from_date": job_row.training_data_from.isoformat(),
        "to_date": job_row.training_data_to.isoformat(),
        "products": request.products
    }

    try:
        # NOTE(review): no explicit timeout is set, so httpx's default
        # applies — confirm it is long enough for large training windows.
        async with httpx.AsyncClient() as client:
            response = await client.get(url, params=query_params)

        if response.status_code == 200:
            return response.json()
        raise Exception(f"Failed to get training data: {response.status_code}")
    except Exception as e:
        logger.error(f"Error getting training data: {e}")
        raise
|
||||
|
||||
async def _save_trained_models(self, job_id: str, models_result: Dict[str, Any], validation_result: Dict[str, Any], db: AsyncSession):
    """Persist newly trained models and retire the tenant's previous ones.

    Deactivation and inserts happen in the same session and are committed
    together at the end.
    """
    # Resolve the owning tenant from the job row.
    job_row = (await db.execute(
        select(TrainingJob).where(TrainingJob.id == job_id)
    )).scalar_one()

    # Retire every previously active model for this tenant so the new set
    # becomes the only active one.
    # NOTE(review): this deactivates ALL of the tenant's models, including
    # products not retrained by this job — confirm that is intended.
    await db.execute(
        update(TrainedModel)
        .where(TrainedModel.tenant_id == job_row.tenant_id)
        .values(is_active=False)
    )

    # One row per trained product, enriched with its validation metrics.
    for product_name, model_data in models_result.items():
        product_metrics = validation_result.get(product_name, {})
        db.add(TrainedModel(
            tenant_id=job_row.tenant_id,
            training_job_id=job_id,
            product_name=product_name,
            model_type=model_data.get("type", "prophet"),
            model_version="1.0",
            model_path=model_data.get("path"),
            mape=product_metrics.get("mape"),
            rmse=product_metrics.get("rmse"),
            mae=product_metrics.get("mae"),
            r2_score=product_metrics.get("r2_score"),
            training_samples=model_data.get("training_samples"),
            features_used=model_data.get("features", []),
            hyperparameters=model_data.get("hyperparameters", {}),
            is_active=True
        ))

    await db.commit()
|
||||
|
||||
async def _complete_job(self, job_id: str, models_result: Dict[str, Any], validation_result: Dict[str, Any], duration: int, db: AsyncSession):
    """Mark a training job as completed and store its summary metrics."""
    # NOTE(review): a model missing "mape" contributes 0 to the average,
    # which pulls it down — confirm that is the intended aggregation.
    if validation_result:
        average_mape = sum(v.get("mape", 0) for v in validation_result.values()) / len(validation_result)
    else:
        average_mape = 0

    summary = {
        "models_trained": len(models_result),
        "average_mape": average_mape,
        "training_duration": duration,
        "validation_results": validation_result
    }

    await db.execute(
        update(TrainingJob)
        .where(TrainingJob.id == job_id)
        .values(
            status="completed",
            progress=100,
            current_step="Training completed successfully",
            completed_at=datetime.utcnow(),
            duration_seconds=duration,
            models_trained=models_result,
            metrics=summary,
            products_count=len(models_result)
        )
    )
    await db.commit()
|
||||
@@ -21,64 +21,4 @@ scipy==1.11.4
|
||||
|
||||
# Utilities
|
||||
pytz==2023.3
|
||||
python-dateutil==2.8.2
# services/training/app/main.py
|
||||
"""
|
||||
Training Service
|
||||
Handles ML model training for bakery demand forecasting
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import FastAPI, BackgroundTasks
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import database_manager
|
||||
from app.api import training, models
|
||||
from app.services.messaging import message_publisher
|
||||
from shared.monitoring.logging import setup_logging
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
|
||||
# Setup logging
|
||||
setup_logging("training-service", settings.LOG_LEVEL)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Training Service",
|
||||
description="ML model training service for bakery demand forecasting",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Initialize metrics collector
|
||||
metrics_collector = MetricsCollector("training-service")
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
app.include_router(training.router, prefix="/training", tags=["training"])
|
||||
app.include_router(models.router, prefix="/models", tags=["models"])
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Application startup"""
|
||||
logger.info("Starting Training Service")
|
||||
|
||||
# Create database tables
|
||||
await database_manager.create_tables()
|
||||
|
||||
# Initialize message publisher
|
||||
await message_publisher.connect()
|
||||
|
||||
# Start metrics server
|
||||
metrics_collector.start_metrics_server(8080)
|
||||
|
||||
logger.info("Training Service started successfully")
|
||||
|
||||
@
|
||||
python-dateutil==2.8.2
|
||||
|
||||
@@ -3,7 +3,7 @@ Shared JWT Authentication Handler
|
||||
Used across all microservices for consistent authentication
|
||||
"""
|
||||
|
||||
from jose import jwt
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
@@ -53,6 +53,6 @@ class JWTHandler:
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("Token has expired")
|
||||
return None
|
||||
except jwt.JWTError:
|
||||
logger.warning("Invalid token")
|
||||
return None
|
||||
Reference in New Issue
Block a user