Improve gateway service

This commit is contained in:
Urtzi Alfaro
2025-07-20 07:24:04 +02:00
parent 1c730c3c81
commit 8cd433c0cd
4 changed files with 816 additions and 373 deletions

View File

@@ -244,14 +244,41 @@ class AuthMiddleware(BaseHTTPMiddleware):
await self.redis_client.setex(cache_key, ttl, json.dumps(user_context)) await self.redis_client.setex(cache_key, ttl, json.dumps(user_context))
def _inject_auth_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str]): def _inject_auth_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str]):
"""Inject authentication context into forwarded requests""" """
# Add user context headers for downstream services Inject authentication headers for downstream services
if hasattr(request, "headers"):
# Create mutable headers This allows services to work both:
headers = dict(request.headers) 1. Behind the gateway (using request.state)
headers["X-User-ID"] = user_context["user_id"] 2. Called directly (using headers) for development/testing
headers["X-User-Email"] = user_context["email"] """
# Remove any existing auth headers to prevent spoofing
headers_to_remove = [
"x-user-id", "x-user-email", "x-user-role",
"x-tenant-id", "x-user-permissions", "x-authenticated"
]
for header in headers_to_remove:
request.headers.__dict__["_list"] = [
(k, v) for k, v in request.headers.raw
if k.lower() != header.lower()
]
# Inject new headers
new_headers = [
(b"x-authenticated", b"true"),
(b"x-user-id", str(user_context.get("user_id", "")).encode()),
(b"x-user-email", str(user_context.get("email", "")).encode()),
(b"x-user-role", str(user_context.get("role", "user")).encode()),
]
if tenant_id: if tenant_id:
headers["X-Tenant-ID"] = tenant_id new_headers.append((b"x-tenant-id", tenant_id.encode()))
# Update request headers
request.scope["headers"] = [(k.lower().encode(), v.encode()) for k, v in headers.items()] permissions = user_context.get("permissions", [])
if permissions:
new_headers.append((b"x-user-permissions", ",".join(permissions).encode()))
# Add headers to request
request.headers.__dict__["_list"].extend(new_headers)
logger.debug(f"Injected auth headers for user {user_context.get('email')}")

View File

@@ -1,22 +1,18 @@
# ================================================================ # ================================================================
# services/data/app/api/sales.py - FIXED VERSION # services/data/app/api/sales.py - UPDATED WITH UNIFIED AUTH
# ================================================================ # ================================================================
"""Sales data API endpoints with improved error handling""" """Sales data API endpoints with unified authentication"""
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Query, Response from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Query, Response
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional from typing import List, Optional, Dict, Any
import uuid import uuid
from datetime import datetime from datetime import datetime
import base64 import base64
import structlog import structlog
from app.core.database import get_db from app.core.database import get_db
from app.core.auth import get_current_user, AuthInfo
from app.services.sales_service import SalesService
from app.services.data_import_service import DataImportService
from app.services.messaging import publish_sales_created
from app.schemas.sales import ( from app.schemas.sales import (
SalesDataCreate, SalesDataCreate,
SalesDataResponse, SalesDataResponse,
@@ -26,75 +22,163 @@ from app.schemas.sales import (
SalesValidationResult, SalesValidationResult,
SalesExportRequest SalesExportRequest
) )
from app.services.sales_service import SalesService
from app.services.data_import_service import DataImportService
from app.services.messaging import (
publish_sales_created,
publish_data_imported,
publish_export_completed
)
router = APIRouter() # Import unified authentication from shared library
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep
)
router = APIRouter(prefix="/sales", tags=["sales"])
logger = structlog.get_logger() logger = structlog.get_logger()
@router.post("/", response_model=SalesDataResponse) @router.post("/", response_model=SalesDataResponse)
async def create_sales_record( async def create_sales_record(
sales_data: SalesDataCreate, sales_data: SalesDataCreate,
db: AsyncSession = Depends(get_db), tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: AuthInfo = Depends(get_current_user) current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
): ):
"""Create a new sales record""" """Create a new sales record"""
try: try:
logger.debug("API: Creating sales record", product=sales_data.product_name, quantity=sales_data.quantity_sold) logger.debug("Creating sales record",
product=sales_data.product_name,
quantity=sales_data.quantity_sold,
tenant_id=tenant_id,
user_id=current_user["user_id"])
# Override tenant_id from token/header
sales_data.tenant_id = tenant_id
record = await SalesService.create_sales_record(sales_data, db) record = await SalesService.create_sales_record(sales_data, db)
# Publish event (with error handling) # Publish event (non-blocking)
try: try:
await publish_sales_created({ await publish_sales_created({
"tenant_id": str(sales_data.tenant_id), "tenant_id": tenant_id,
"product_name": sales_data.product_name, "product_name": sales_data.product_name,
"quantity_sold": sales_data.quantity_sold, "quantity_sold": sales_data.quantity_sold,
"revenue": sales_data.revenue, "revenue": sales_data.revenue,
"source": sales_data.source, "source": sales_data.source,
"created_by": current_user["user_id"],
"timestamp": datetime.utcnow().isoformat() "timestamp": datetime.utcnow().isoformat()
}) })
except Exception as pub_error: except Exception as pub_error:
logger.warning("Failed to publish sales created event", error=str(pub_error)) logger.warning("Failed to publish sales created event", error=str(pub_error))
# Continue processing - event publishing failure shouldn't break the API # Continue - event failure shouldn't break API
logger.debug("Successfully created sales record", record_id=record.id) logger.info("Successfully created sales record",
record_id=record.id,
tenant_id=tenant_id)
return record return record
except Exception as e: except Exception as e:
logger.error("Failed to create sales record", error=str(e)) logger.error("Failed to create sales record",
import traceback error=str(e),
logger.error("Sales creation traceback", traceback=traceback.format_exc()) tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to create sales record: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to create sales record: {str(e)}")
@router.post("/query", response_model=List[SalesDataResponse]) @router.post("/bulk", response_model=List[SalesDataResponse])
async def get_sales_data( async def create_bulk_sales(
query: SalesDataQuery, sales_data: List[SalesDataCreate],
db: AsyncSession = Depends(get_db), tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: AuthInfo = Depends(get_current_user) current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
): ):
"""Get sales data by query parameters""" """Create multiple sales records"""
try: try:
logger.debug("API: Querying sales data", tenant_id=query.tenant_id) logger.debug("Creating bulk sales records",
count=len(sales_data),
tenant_id=tenant_id)
records = await SalesService.get_sales_data(query, db) # Override tenant_id for all records
for record in sales_data:
record.tenant_id = tenant_id
logger.debug("Successfully retrieved sales data", count=len(records)) records = await SalesService.create_bulk_sales(sales_data, db)
# Publish event
try:
await publish_data_imported({
"tenant_id": tenant_id,
"type": "bulk_create",
"records_created": len(records),
"created_by": current_user["user_id"],
"timestamp": datetime.utcnow().isoformat()
})
except Exception as pub_error:
logger.warning("Failed to publish bulk import event", error=str(pub_error))
logger.info("Successfully created bulk sales records",
count=len(records),
tenant_id=tenant_id)
return records return records
except Exception as e: except Exception as e:
logger.error("Failed to query sales data", error=str(e)) logger.error("Failed to create bulk sales records",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to create bulk sales records: {str(e)}")
@router.get("/", response_model=List[SalesDataResponse])
async def get_sales_data(
start_date: Optional[datetime] = Query(None),
end_date: Optional[datetime] = Query(None),
product_name: Optional[str] = Query(None),
tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""Get sales data with filters"""
try:
logger.debug("Querying sales data",
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
product_name=product_name)
query = SalesDataQuery(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
product_name=product_name
)
records = await SalesService.get_sales_data(query, db)
logger.debug("Successfully retrieved sales data",
count=len(records),
tenant_id=tenant_id)
return records
except Exception as e:
logger.error("Failed to query sales data",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to query sales data: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to query sales data: {str(e)}")
@router.post("/import", response_model=SalesImportResult) @router.post("/import", response_model=SalesImportResult)
async def import_sales_data( async def import_sales_data(
tenant_id: str = Form(...),
file_format: str = Form(...),
file: UploadFile = File(...), file: UploadFile = File(...),
db: AsyncSession = Depends(get_db), file_format: str = Form(...),
current_user: AuthInfo = Depends(get_current_user) tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
): ):
"""Import sales data from file""" """Import sales data from file"""
try: try:
logger.debug("API: Importing sales data", tenant_id=tenant_id, format=file_format, filename=file.filename) logger.info("Importing sales data",
tenant_id=tenant_id,
format=file_format,
filename=file.filename,
user_id=current_user["user_id"])
# Read file content # Read file content
content = await file.read() content = await file.read()
@@ -102,100 +186,78 @@ async def import_sales_data(
# Process import # Process import
result = await DataImportService.process_upload( result = await DataImportService.process_upload(
tenant_id, file_content, file_format, db tenant_id,
file_content,
file_format,
db,
user_id=current_user["user_id"]
) )
if result["success"]: if result["success"]:
# Publish event (with error handling) # Publish event
try: try:
await data_publisher.publish_data_imported({ await publish_data_imported({
"tenant_id": tenant_id, "tenant_id": tenant_id,
"type": "bulk_import", "type": "file_import",
"format": file_format, "format": file_format,
"filename": file.filename, "filename": file.filename,
"records_created": result["records_created"], "records_created": result["records_created"],
"imported_by": current_user["user_id"],
"timestamp": datetime.utcnow().isoformat() "timestamp": datetime.utcnow().isoformat()
}) })
except Exception as pub_error: except Exception as pub_error:
logger.warning("Failed to publish data imported event", error=str(pub_error)) logger.warning("Failed to publish import event", error=str(pub_error))
# Continue processing
logger.debug("Import completed", success=result["success"], records_created=result.get("records_created", 0)) logger.info("Import completed",
success=result["success"],
records_created=result.get("records_created", 0),
tenant_id=tenant_id)
return result return result
except Exception as e: except Exception as e:
logger.error("Failed to import sales data", error=str(e)) logger.error("Failed to import sales data",
import traceback error=str(e),
logger.error("Sales import traceback", traceback=traceback.format_exc()) tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to import sales data: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to import sales data: {str(e)}")
@router.post("/import/json", response_model=SalesImportResult)
async def import_sales_json(
import_data: SalesDataImport,
db: AsyncSession = Depends(get_db),
current_user: AuthInfo = Depends(get_current_user)
):
"""Import sales data from JSON"""
try:
logger.debug("API: Importing JSON sales data", tenant_id=import_data.tenant_id)
result = await DataImportService.process_upload(
str(import_data.tenant_id),
import_data.data,
import_data.data_format,
db
)
if result["success"]:
# Publish event (with error handling)
try:
await publish_data_imported({
"tenant_id": str(import_data.tenant_id),
"type": "json_import",
"records_created": result["records_created"],
"timestamp": datetime.utcnow().isoformat()
})
except Exception as pub_error:
logger.warning("Failed to publish JSON import event", error=str(pub_error))
# Continue processing
logger.debug("JSON import completed", success=result["success"], records_created=result.get("records_created", 0))
return result
except Exception as e:
logger.error("Failed to import JSON sales data", error=str(e))
import traceback
logger.error("JSON import traceback", traceback=traceback.format_exc())
raise HTTPException(status_code=500, detail=f"Failed to import JSON sales data: {str(e)}")
@router.post("/import/validate", response_model=SalesValidationResult) @router.post("/import/validate", response_model=SalesValidationResult)
async def validate_import_data( async def validate_import_data(
import_data: SalesDataImport, import_data: SalesDataImport,
current_user: AuthInfo = Depends(get_current_user) tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: Dict[str, Any] = Depends(get_current_user_dep)
): ):
"""Validate import data before processing""" """Validate import data before processing"""
try: try:
logger.debug("API: Validating import data", tenant_id=import_data.tenant_id) logger.debug("Validating import data", tenant_id=tenant_id)
# Override tenant_id
import_data.tenant_id = tenant_id
validation = await DataImportService.validate_import_data( validation = await DataImportService.validate_import_data(
import_data.model_dump() import_data.model_dump()
) )
logger.debug("Validation completed", is_valid=validation.get("is_valid", False)) logger.debug("Validation completed",
is_valid=validation.get("is_valid", False),
tenant_id=tenant_id)
return validation return validation
except Exception as e: except Exception as e:
logger.error("Failed to validate import data", error=str(e)) logger.error("Failed to validate import data",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to validate import data: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to validate import data: {str(e)}")
@router.get("/import/template/{format_type}") @router.get("/import/template/{format_type}")
async def get_import_template( async def get_import_template(
format_type: str, format_type: str,
current_user: AuthInfo = Depends(get_current_user) current_user: Dict[str, Any] = Depends(get_current_user_dep)
): ):
"""Get import template for specified format""" """Get import template for specified format"""
try: try:
logger.debug("API: Getting import template", format=format_type) logger.debug("Getting import template",
format=format_type,
user_id=current_user["user_id"])
template = await DataImportService.get_import_template(format_type) template = await DataImportService.get_import_template(format_type)
@@ -230,21 +292,22 @@ async def get_import_template(
raise raise
except Exception as e: except Exception as e:
logger.error("Failed to generate import template", error=str(e)) logger.error("Failed to generate import template", error=str(e))
import traceback
logger.error("Template generation traceback", traceback=traceback.format_exc())
raise HTTPException(status_code=500, detail=f"Failed to generate template: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to generate template: {str(e)}")
@router.get("/analytics/{tenant_id}") @router.get("/analytics")
async def get_sales_analytics( async def get_sales_analytics(
tenant_id: str,
start_date: Optional[datetime] = Query(None, description="Start date"), start_date: Optional[datetime] = Query(None, description="Start date"),
end_date: Optional[datetime] = Query(None, description="End date"), end_date: Optional[datetime] = Query(None, description="End date"),
db: AsyncSession = Depends(get_db), tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: AuthInfo = Depends(get_current_user) current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
): ):
"""Get sales analytics for tenant""" """Get sales analytics for tenant"""
try: try:
logger.debug("API: Getting sales analytics", tenant_id=tenant_id) logger.debug("Getting sales analytics",
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date)
analytics = await SalesService.get_sales_analytics( analytics = await SalesService.get_sales_analytics(
tenant_id, start_date, end_date, db tenant_id, start_date, end_date, db
@@ -254,22 +317,27 @@ async def get_sales_analytics(
return analytics return analytics
except Exception as e: except Exception as e:
logger.error("Failed to generate sales analytics", error=str(e)) logger.error("Failed to generate sales analytics",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to generate analytics: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to generate analytics: {str(e)}")
@router.post("/export/{tenant_id}") @router.post("/export")
async def export_sales_data( async def export_sales_data(
tenant_id: str,
export_format: str = Query("csv", description="Export format: csv, excel, json"), export_format: str = Query("csv", description="Export format: csv, excel, json"),
start_date: Optional[datetime] = Query(None, description="Start date"), start_date: Optional[datetime] = Query(None, description="Start date"),
end_date: Optional[datetime] = Query(None, description="End date"), end_date: Optional[datetime] = Query(None, description="End date"),
products: Optional[List[str]] = Query(None, description="Filter by products"), products: Optional[List[str]] = Query(None, description="Filter by products"),
db: AsyncSession = Depends(get_db), tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: AuthInfo = Depends(get_current_user) current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
): ):
"""Export sales data in specified format""" """Export sales data in specified format"""
try: try:
logger.debug("API: Exporting sales data", tenant_id=tenant_id, format=export_format) logger.info("Exporting sales data",
tenant_id=tenant_id,
format=export_format,
user_id=current_user["user_id"])
export_result = await SalesService.export_sales_data( export_result = await SalesService.export_sales_data(
tenant_id, export_format, start_date, end_date, products, db tenant_id, export_format, start_date, end_date, products, db
@@ -278,7 +346,21 @@ async def export_sales_data(
if not export_result: if not export_result:
raise HTTPException(status_code=404, detail="No data found for export") raise HTTPException(status_code=404, detail="No data found for export")
logger.debug("Export completed successfully", tenant_id=tenant_id, format=export_format) # Publish export event
try:
await publish_export_completed({
"tenant_id": tenant_id,
"format": export_format,
"exported_by": current_user["user_id"],
"record_count": export_result.get("record_count", 0),
"timestamp": datetime.utcnow().isoformat()
})
except Exception as pub_error:
logger.warning("Failed to publish export event", error=str(pub_error))
logger.info("Export completed successfully",
tenant_id=tenant_id,
format=export_format)
return StreamingResponse( return StreamingResponse(
iter([export_result["content"]]), iter([export_result["content"]]),
@@ -289,29 +371,91 @@ async def export_sales_data(
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error("Failed to export sales data", error=str(e)) logger.error("Failed to export sales data",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to export sales data: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to export sales data: {str(e)}")
@router.delete("/{record_id}") @router.delete("/{record_id}")
async def delete_sales_record( async def delete_sales_record(
record_id: str, record_id: str,
db: AsyncSession = Depends(get_db), tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: AuthInfo = Depends(get_current_user) current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
): ):
"""Delete a sales record""" """Delete a sales record"""
try: try:
logger.debug("API: Deleting sales record", record_id=record_id) logger.info("Deleting sales record",
record_id=record_id,
tenant_id=tenant_id,
user_id=current_user["user_id"])
# Verify record belongs to tenant before deletion
record = await SalesService.get_sales_record(record_id, db)
if not record or record.tenant_id != tenant_id:
raise HTTPException(status_code=404, detail="Sales record not found")
success = await SalesService.delete_sales_record(record_id, db) success = await SalesService.delete_sales_record(record_id, db)
if not success: if not success:
raise HTTPException(status_code=404, detail="Sales record not found") raise HTTPException(status_code=404, detail="Sales record not found")
logger.debug("Sales record deleted successfully", record_id=record_id) logger.info("Sales record deleted successfully",
record_id=record_id,
tenant_id=tenant_id)
return {"status": "success", "message": "Sales record deleted successfully"} return {"status": "success", "message": "Sales record deleted successfully"}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error("Failed to delete sales record", error=str(e)) logger.error("Failed to delete sales record",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to delete sales record: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to delete sales record: {str(e)}")
@router.get("/summary")
async def get_sales_summary(
period: str = Query("daily", description="Summary period: daily, weekly, monthly"),
tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""Get sales summary for specified period"""
try:
logger.debug("Getting sales summary",
tenant_id=tenant_id,
period=period)
summary = await SalesService.get_sales_summary(tenant_id, period, db)
logger.debug("Summary generated successfully", tenant_id=tenant_id)
return summary
except Exception as e:
logger.error("Failed to generate sales summary",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to generate summary: {str(e)}")
@router.get("/products")
async def get_products_list(
tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""Get list of all products with sales data"""
try:
logger.debug("Getting products list", tenant_id=tenant_id)
products = await SalesService.get_products_list(tenant_id, db)
logger.debug("Products list retrieved",
count=len(products),
tenant_id=tenant_id)
return products
except Exception as e:
logger.error("Failed to get products list",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to get products list: {str(e)}")

View File

@@ -1,299 +1,472 @@
# services/training/app/api/training.py # ================================================================
""" # services/training/app/api/training.py - UPDATED WITH UNIFIED AUTH
Training API endpoints for the training service # ================================================================
""" """Training API endpoints with unified authentication"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
from sqlalchemy.ext.asyncio import AsyncSession from typing import List, Optional, Dict, Any
from typing import Dict, List, Any, Optional
import logging
from datetime import datetime from datetime import datetime
import uuid import structlog
from app.core.database import get_db
from app.core.auth import get_current_tenant_id
from app.schemas.training import ( from app.schemas.training import (
TrainingJobRequest, TrainingJobRequest,
TrainingJobResponse, TrainingJobResponse,
TrainingStatusResponse, TrainingJobStatus,
SingleProductTrainingRequest SingleProductTrainingRequest,
TrainingJobProgress,
DataValidationRequest,
DataValidationResponse
) )
from app.services.training_service import TrainingService from app.services.training_service import TrainingService
from app.services.messaging import publish_job_started, publish_job_cancelled, publish_product_training_started from app.services.messaging import (
from shared.monitoring.metrics import MetricsCollector publish_job_started,
publish_job_completed,
publish_job_failed,
publish_job_progress,
publish_product_training_started,
publish_product_training_completed
)
logger = logging.getLogger(__name__) # Import unified authentication from shared library
router = APIRouter() from shared.auth.decorators import (
metrics = MetricsCollector("training-service") get_current_user_dep,
get_current_tenant_id_dep,
require_role
)
# Initialize training service logger = structlog.get_logger()
training_service = TrainingService() router = APIRouter(prefix="/training", tags=["training"])
@router.post("/jobs", response_model=TrainingJobResponse) @router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job( async def start_training_job(
request: TrainingJobRequest, request: TrainingJobRequest,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
tenant_id: str = Depends(get_current_tenant_id), tenant_id: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db) current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends()
): ):
""" """Start a new training job for all products"""
Start a new training job for all products of a tenant.
Replaces the old Celery-based training system.
"""
try: try:
logger.info(f"Starting training job for tenant {tenant_id}") logger.info("Starting training job",
metrics.increment_counter("training_jobs_started") tenant_id=tenant_id,
user_id=current_user["user_id"],
config=request.dict())
# Generate job ID # Create training job
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}" job = await training_service.create_training_job(
# Create training job record
training_job = await training_service.create_training_job(
db=db,
tenant_id=tenant_id, tenant_id=tenant_id,
job_id=job_id, user_id=current_user["user_id"],
config=request.dict() config=request.dict()
) )
# Publish job started event
try:
await publish_job_started(
job_id=job.job_id,
tenant_id=tenant_id,
config=request.dict()
)
except Exception as e:
logger.warning("Failed to publish job started event", error=str(e))
# Start training in background # Start training in background
background_tasks.add_task( background_tasks.add_task(
training_service.execute_training_job, training_service.execute_training_job,
db, job.job_id
job_id,
tenant_id,
request
) )
# Publish training started event logger.info("Training job created",
await publish_job_started(job_id, tenant_id, request.dict()) job_id=job.job_id,
tenant_id=tenant_id)
return TrainingJobResponse( return job
job_id=job_id,
status="started",
message="Training job started successfully",
tenant_id=tenant_id,
created_at=training_job.start_time,
estimated_duration_minutes=request.estimated_duration or 15
)
except Exception as e: except Exception as e:
logger.error(f"Failed to start training job: {str(e)}") logger.error("Failed to start training job",
metrics.increment_counter("training_jobs_failed") error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
@router.get("/jobs/{job_id}/status", response_model=TrainingStatusResponse) @router.get("/jobs", response_model=List[TrainingJobResponse])
async def get_training_status( async def get_training_jobs(
job_id: str, status: Optional[TrainingJobStatus] = Query(None),
tenant_id: str = Depends(get_current_tenant_id), limit: int = Query(100, ge=1, le=1000),
db: AsyncSession = Depends(get_db) offset: int = Query(0, ge=0),
tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends()
): ):
""" """Get training jobs for tenant"""
Get the status of a training job.
Provides real-time progress updates.
"""
try: try:
# Get job status from database logger.debug("Getting training jobs",
job_status = await training_service.get_job_status(db, job_id, tenant_id) tenant_id=tenant_id,
status=status,
limit=limit,
offset=offset)
if not job_status: jobs = await training_service.get_training_jobs(
raise HTTPException(status_code=404, detail="Training job not found") tenant_id=tenant_id,
status=status,
return TrainingStatusResponse( limit=limit,
job_id=job_id, offset=offset
status=job_status.status,
progress=job_status.progress,
current_step=job_status.current_step,
started_at=job_status.start_time,
completed_at=job_status.end_time,
results=job_status.results,
error_message=job_status.error_message
) )
logger.debug("Retrieved training jobs",
count=len(jobs),
tenant_id=tenant_id)
return jobs
except Exception as e:
logger.error("Failed to get training jobs",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}")
@router.get("/jobs/{job_id}", response_model=TrainingJobResponse)
async def get_training_job(
job_id: str,
tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends()
):
"""Get specific training job details"""
try:
logger.debug("Getting training job",
job_id=job_id,
tenant_id=tenant_id)
job = await training_service.get_training_job(job_id)
# Verify tenant access
if job.tenant_id != tenant_id:
logger.warning("Unauthorized job access attempt",
job_id=job_id,
tenant_id=tenant_id,
job_tenant_id=job.tenant_id)
raise HTTPException(status_code=404, detail="Job not found")
return job
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Failed to get training status: {str(e)}") logger.error("Failed to get training job",
raise HTTPException(status_code=500, detail=f"Failed to get training status: {str(e)}") error=str(e),
job_id=job_id)
raise HTTPException(status_code=500, detail=f"Failed to get training job: {str(e)}")
@router.get("/jobs/{job_id}/progress", response_model=TrainingJobProgress)
async def get_training_progress(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get real-time training progress for a job owned by the caller's tenant.

    Args:
        job_id: Identifier of the training job (path parameter).
        tenant_id: Tenant id resolved by ``get_current_tenant_id_dep``.
        current_user: User context resolved by ``get_current_user_dep``; not
            read in the body, but declaring it makes the dependency run.
        training_service: Injected ``TrainingService`` instance.

    Returns:
        The value produced by ``training_service.get_job_progress``.

    Raises:
        HTTPException: 404 when the job is missing or belongs to another
            tenant; 500 for any other failure.
    """
    try:
        logger.debug("Getting training progress",
                     job_id=job_id,
                     tenant_id=tenant_id)

        # Verify job belongs to tenant. A missing job is answered exactly like
        # a foreign one so the response does not reveal which job ids exist.
        # NOTE(review): assumes get_training_job may return None for unknown
        # ids - confirm against TrainingService.
        job = await training_service.get_training_job(job_id)
        if job is None or job.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Job not found")

        progress = await training_service.get_job_progress(job_id)
        return progress

    except HTTPException:
        # Deliberate HTTP errors (the 404 above) must reach the client as-is.
        raise
    except Exception as e:
        logger.error("Failed to get training progress",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get training progress: {str(e)}"
        ) from e
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Cancel a running training job owned by the caller's tenant.

    Args:
        job_id: Identifier of the training job (path parameter).
        tenant_id: Tenant id resolved by ``get_current_tenant_id_dep``.
        current_user: User context resolved by ``get_current_user_dep``; its
            ``user_id`` is written to the audit log line.
        training_service: Injected ``TrainingService`` instance.

    Returns:
        A dict with a confirmation ``message`` and the ``job_id``.

    Raises:
        HTTPException: 404 when the job is missing or belongs to another
            tenant; 500 for any other failure.
    """
    try:
        logger.info("Cancelling training job",
                    job_id=job_id,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        job = await training_service.get_training_job(job_id)

        # Verify tenant access. A missing job is answered exactly like a
        # foreign one so the response does not reveal which job ids exist.
        # NOTE(review): assumes get_training_job may return None for unknown
        # ids - confirm against TrainingService.
        if job is None or job.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Job not found")

        await training_service.cancel_training_job(job_id)

        # Publish cancellation event. Best effort: the job is already
        # cancelled, so a messaging failure is logged and not surfaced.
        try:
            await publish_job_failed(
                job_id=job_id,
                tenant_id=tenant_id,
                error="Job cancelled by user",
                failed_at="cancellation"
            )
        except Exception as e:
            logger.warning("Failed to publish cancellation event", error=str(e))

        logger.info("Training job cancelled", job_id=job_id)
        return {"message": "Job cancelled successfully", "job_id": job_id}

    except HTTPException:
        # Deliberate HTTP errors (the 404 above) must reach the client as-is.
        raise
    except Exception as e:
        logger.error("Failed to cancel training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to cancel training job: {str(e)}"
        ) from e
@router.post("/products/{product_name}", response_model=TrainingJobResponse) @router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product( async def train_single_product(
product_name: str, product_name: str,
request: SingleProductTrainingRequest, request: SingleProductTrainingRequest,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
tenant_id: str = Depends(get_current_tenant_id), tenant_id: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db) current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends()
): ):
""" """Train model for a single product"""
Train a model for a single product.
Useful for quick model updates or new products.
"""
try: try:
logger.info(f"Starting single product training for {product_name}, tenant {tenant_id}") logger.info("Training single product",
metrics.increment_counter("single_product_training_started") product_name=product_name,
tenant_id=tenant_id,
user_id=current_user["user_id"])
# Generate job ID # Create training job for single product
job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}" job = await training_service.create_single_product_job(
# Create training job record
training_job = await training_service.create_single_product_job(
db=db,
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=current_user["user_id"],
product_name=product_name, product_name=product_name,
job_id=job_id,
config=request.dict() config=request.dict()
) )
# Publish event
try:
await publish_product_training_started(
job_id=job.job_id,
tenant_id=tenant_id,
product_name=product_name
)
except Exception as e:
logger.warning("Failed to publish product training event", error=str(e))
# Start training in background # Start training in background
background_tasks.add_task( background_tasks.add_task(
training_service.execute_single_product_training, training_service.execute_single_product_training,
db, job.job_id,
job_id, product_name
tenant_id, )
product_name,
request logger.info("Single product training started",
job_id=job.job_id,
product_name=product_name)
return job
except Exception as e:
logger.error("Failed to train single product",
error=str(e),
product_name=product_name,
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to train product: {str(e)}")
@router.post("/validate", response_model=DataValidationResponse)
async def validate_training_data(
    request: DataValidationRequest,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Validate a tenant's data before training.

    Args:
        request: Validation request; ``products`` and ``min_data_points`` are
            forwarded to the service.
        tenant_id: Tenant id resolved by ``get_current_tenant_id_dep``.
        current_user: User context resolved by ``get_current_user_dep``; not
            read in the body, but declaring it makes the dependency run.
        training_service: Injected ``TrainingService`` instance.

    Returns:
        The result of ``training_service.validate_training_data``.

    Raises:
        HTTPException: Any HTTPException raised while validating is passed
            through unchanged; every other failure becomes a 500.
    """
    try:
        logger.debug("Validating training data",
                     tenant_id=tenant_id,
                     products=request.products)

        validation_result = await training_service.validate_training_data(
            tenant_id=tenant_id,
            products=request.products,
            min_data_points=request.min_data_points
        )

        logger.debug("Data validation completed",
                     is_valid=validation_result.is_valid,
                     tenant_id=tenant_id)

        return validation_result

    except HTTPException:
        # Same pass-through the job endpoints use: a deliberate HTTP status
        # must not be rewritten into a generic 500.
        raise
    except Exception as e:
        logger.error("Failed to validate training data",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to validate data: {str(e)}"
        ) from e
@router.get("/models")
async def get_trained_models(
    product_name: Optional[str] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get the list of trained models for the caller's tenant.

    Args:
        product_name: Optional query parameter forwarded to the service as a
            filter; ``None`` when omitted.
        tenant_id: Tenant id resolved by ``get_current_tenant_id_dep``.
        current_user: User context resolved by ``get_current_user_dep``; not
            read in the body, but declaring it makes the dependency run.
        training_service: Injected ``TrainingService`` instance.

    Returns:
        The collection returned by ``training_service.get_trained_models``.

    Raises:
        HTTPException: Any HTTPException raised while loading is passed
            through unchanged; every other failure becomes a 500.
    """
    try:
        logger.debug("Getting trained models",
                     tenant_id=tenant_id,
                     product_name=product_name)

        models = await training_service.get_trained_models(
            tenant_id=tenant_id,
            product_name=product_name
        )

        logger.debug("Retrieved trained models",
                     count=len(models),
                     tenant_id=tenant_id)

        return models

    except HTTPException:
        # Same pass-through the job endpoints use: a deliberate HTTP status
        # must not be rewritten into a generic 500.
        raise
    except Exception as e:
        logger.error("Failed to get trained models",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get models: {str(e)}"
        ) from e
@router.delete("/models/{model_id}")
@require_role("admin")  # Only admins can delete models
async def delete_model(
    model_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Delete a trained model owned by the caller's tenant (admin only).

    NOTE(review): ``require_role`` resolves the user from a ``Request``
    argument, and this endpoint does not declare one - confirm the decorator
    can resolve the user for this route.

    Args:
        model_id: Identifier of the model (path parameter).
        tenant_id: Tenant id resolved by ``get_current_tenant_id_dep``.
        current_user: User context resolved by ``get_current_user_dep``; its
            ``user_id`` is written to the audit log line as ``admin_id``.
        training_service: Injected ``TrainingService`` instance.

    Returns:
        A dict with a confirmation ``message`` and the ``model_id``.

    Raises:
        HTTPException: 404 when the model is missing, belongs to another
            tenant, or the service reports nothing was deleted; 500 for any
            other failure.
    """
    try:
        logger.info("Deleting model",
                    model_id=model_id,
                    tenant_id=tenant_id,
                    admin_id=current_user["user_id"])

        # Verify model belongs to tenant. A missing model is answered exactly
        # like a foreign one so the response does not reveal which ids exist.
        # NOTE(review): assumes get_model may return None for unknown ids -
        # confirm against TrainingService.
        model = await training_service.get_model(model_id)
        if model is None or model.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Model not found")

        success = await training_service.delete_model(model_id)

        if not success:
            raise HTTPException(status_code=404, detail="Model not found")

        logger.info("Model deleted successfully", model_id=model_id)

        return {"message": "Model deleted successfully", "model_id": model_id}

    except HTTPException:
        # Deliberate HTTP errors (the 404s above) must reach the client as-is.
        raise
    except Exception as e:
        logger.error("Failed to delete model",
                     error=str(e),
                     model_id=model_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete model: {str(e)}"
        ) from e
@router.get("/stats")
async def get_training_stats(
    start_date: Optional[datetime] = Query(None),
    end_date: Optional[datetime] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get training statistics for the caller's tenant.

    Args:
        start_date: Optional query parameter forwarded to the service;
            ``None`` when omitted.
        end_date: Optional query parameter forwarded to the service;
            ``None`` when omitted.
        tenant_id: Tenant id resolved by ``get_current_tenant_id_dep``.
        current_user: User context resolved by ``get_current_user_dep``; not
            read in the body, but declaring it makes the dependency run.
        training_service: Injected ``TrainingService`` instance.

    Returns:
        The value produced by ``training_service.get_training_stats``.

    Raises:
        HTTPException: Any HTTPException raised while loading is passed
            through unchanged; every other failure becomes a 500.
    """
    try:
        logger.debug("Getting training stats",
                     tenant_id=tenant_id,
                     start_date=start_date,
                     end_date=end_date)

        stats = await training_service.get_training_stats(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date
        )

        logger.debug("Training stats retrieved", tenant_id=tenant_id)

        return stats

    except HTTPException:
        # Same pass-through the job endpoints use: a deliberate HTTP status
        # must not be rewritten into a generic 500.
        raise
    except Exception as e:
        logger.error("Failed to get training stats",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get stats: {str(e)}"
        ) from e
@router.post("/retrain/all")
async def retrain_all_products(
request: TrainingJobRequest,
background_tasks: BackgroundTasks,
tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends()
):
"""Retrain all products with existing models"""
try:
logger.info("Retraining all products",
tenant_id=tenant_id,
user_id=current_user["user_id"])
# Check if models exist
existing_models = await training_service.get_trained_models(tenant_id)
if not existing_models:
raise HTTPException(
status_code=400,
detail="No existing models found. Please run initial training first."
)
# Create retraining job
job = await training_service.create_training_job(
tenant_id=tenant_id,
user_id=current_user["user_id"],
config={**request.dict(), "is_retrain": True}
) )
# Publish event # Publish event
await publish_product_training_started(job_id, tenant_id, product_name) try:
await publish_job_started(
return TrainingJobResponse(
job_id=job_id,
status="started",
message=f"Single product training started for {product_name}",
tenant_id=tenant_id,
created_at=training_job.start_time,
estimated_duration_minutes=5
)
except Exception as e:
logger.error(f"Failed to start single product training: {str(e)}")
metrics.increment_counter("single_product_training_failed")
raise HTTPException(status_code=500, detail=f"Failed to start training: {str(e)}")
@router.get("/jobs", response_model=List[TrainingStatusResponse])
async def list_training_jobs(
limit: int = 10,
status: Optional[str] = None,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""
List training jobs for a tenant.
"""
try:
jobs = await training_service.list_training_jobs(
db=db,
tenant_id=tenant_id,
limit=limit,
status_filter=status
)
return [
TrainingStatusResponse(
job_id=job.job_id, job_id=job.job_id,
status=job.status, tenant_id=tenant_id,
progress=job.progress, config={**request.dict(), "is_retrain": True}
current_step=job.current_step,
started_at=job.start_time,
completed_at=job.end_time,
results=job.results,
error_message=job.error_message
) )
for job in jobs except Exception as e:
] logger.warning("Failed to publish retrain event", error=str(e))
except Exception as e: # Start retraining in background
logger.error(f"Failed to list training jobs: {str(e)}") background_tasks.add_task(
raise HTTPException(status_code=500, detail=f"Failed to list training jobs: {str(e)}") training_service.execute_training_job,
job.job_id
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Cancel a running training job.

    Asks the training service to cancel the job for this tenant and, when the
    service reports success, publishes a cancellation event. Responds with 404
    when the service reports that nothing was cancelled and with 500 on any
    unexpected error.
    """
    try:
        logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")

        # The service call flips the job status; a falsy result means there
        # was nothing to cancel.
        cancelled = await training_service.cancel_training_job(db, job_id, tenant_id)
        if cancelled:
            # Publish cancellation event
            await publish_job_cancelled(job_id, tenant_id)
            return {"message": "Training job cancelled successfully"}

        raise HTTPException(
            status_code=404,
            detail="Training job not found or cannot be cancelled"
        )

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Failed to cancel training job: {str(exc)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to cancel training job: {str(exc)}"
        )
@router.get("/jobs/{job_id}/logs")
async def get_training_logs(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Get detailed logs for a training job.

    Responds with the job id and its logs, with 404 when the service returns
    nothing for this job/tenant pair, and with 500 on any unexpected error.
    """
    try:
        log_entries = await training_service.get_training_logs(db, job_id, tenant_id)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Failed to get training logs: {str(exc)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get training logs: {str(exc)}"
        )

    # A falsy result (None or empty) is reported as an unknown job.
    if log_entries:
        return {"job_id": job_id, "logs": log_entries}

    raise HTTPException(status_code=404, detail="Training job not found")
@router.post("/validate")
async def validate_training_data(
request: TrainingJobRequest,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""
Validate training data before starting a job.
Provides early feedback on data quality issues.
"""
try:
logger.info(f"Validating training data for tenant {tenant_id}")
# Perform data validation
validation_result = await training_service.validate_training_data(
db=db,
tenant_id=tenant_id,
config=request.dict()
) )
return { logger.info("Retraining job created", job_id=job.job_id)
"is_valid": validation_result["is_valid"],
"issues": validation_result.get("issues", []),
"recommendations": validation_result.get("recommendations", []),
"estimated_training_time": validation_result.get("estimated_time_minutes", 15)
}
return job
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to validate training data: {str(e)}") logger.error("Failed to start retraining",
raise HTTPException(status_code=500, detail=f"Failed to validate training data: {str(e)}") error=str(e),
tenant_id=tenant_id)
@router.get("/health") raise HTTPException(status_code=500, detail=f"Failed to start retraining: {str(e)}")
async def health_check():
    """Report service liveness together with the current timestamp."""
    # datetime.now() without a tz argument yields a naive local-time value.
    checked_at = datetime.now().isoformat()
    return dict(
        status="healthy",
        service="training-service",
        timestamp=checked_at,
    )

View File

@@ -1,14 +1,24 @@
# shared/auth/decorators.py - NEW FILE
""" """
Authentication decorators for microservices Unified authentication decorators for microservices
Designed to work with gateway authentication middleware
""" """
from functools import wraps from functools import wraps
from fastapi import HTTPException, status, Request from fastapi import HTTPException, status, Request, Depends
from typing import Callable, Optional from fastapi.security import HTTPBearer
from typing import Callable, Optional, Dict, Any
import structlog
logger = structlog.get_logger()
# Bearer token scheme for services that need it
security = HTTPBearer(auto_error=False)
def require_authentication(func: Callable) -> Callable: def require_authentication(func: Callable) -> Callable:
"""Decorator to require authentication - assumes gateway has validated token""" """
Decorator to require authentication - trusts gateway validation
Services behind the gateway should use this decorator
"""
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
@@ -20,7 +30,6 @@ def require_authentication(func: Callable) -> Callable:
break break
if not request: if not request:
# Check kwargs
request = kwargs.get('request') request = kwargs.get('request')
if not request: if not request:
@@ -31,10 +40,14 @@ def require_authentication(func: Callable) -> Callable:
# Check if user context exists (set by gateway) # Check if user context exists (set by gateway)
if not hasattr(request.state, 'user') or not request.state.user: if not hasattr(request.state, 'user') or not request.state.user:
raise HTTPException( # Check headers as fallback (for direct service calls in dev)
status_code=status.HTTP_401_UNAUTHORIZED, user_info = extract_user_from_headers(request)
detail="Authentication required" if not user_info:
) raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required"
)
request.state.user = user_info
return await func(*args, **kwargs) return await func(*args, **kwargs)
@@ -45,32 +58,118 @@ def require_tenant_access(func: Callable) -> Callable:
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
# Find request object
request = None request = None
for arg in args: for arg in args:
if isinstance(arg, Request): if isinstance(arg, Request):
request = arg request = arg
break break
if not request:
request = kwargs.get('request')
if not request or not hasattr(request.state, 'tenant_id'): if not request or not hasattr(request.state, 'tenant_id'):
raise HTTPException( # Try to extract from headers
status_code=status.HTTP_403_FORBIDDEN, tenant_id = extract_tenant_from_headers(request)
detail="Tenant access required" if not tenant_id:
) raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Tenant access required"
)
request.state.tenant_id = tenant_id
return await func(*args, **kwargs) return await func(*args, **kwargs)
return wrapper return wrapper
def get_current_user(request: Request) -> dict: def require_role(role: str):
"""Get current user from request state""" """Decorator to require specific role"""
if not hasattr(request.state, 'user'):
def decorator(func: Callable) -> Callable:
    """Wrap ``func`` so it only runs for users holding the required role.

    The user is resolved from the ``Request`` found among the call arguments.
    If the wrapped endpoint declares no ``Request`` parameter, the wrapper
    falls back to a ``current_user`` keyword argument (the value endpoints
    obtain from ``get_current_user_dep``).

    Raises (from the wrapper):
        HTTPException: 401 when no user context can be resolved; 403 when the
            user's role is neither the required role nor ``admin``.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        request = next((arg for arg in args if isinstance(arg, Request)), None)

        if not request:
            request = kwargs.get('request')

        if request is not None:
            user = get_current_user(request)
        else:
            # No Request was passed to the endpoint, so get_current_user()
            # cannot be used; rely on the already-injected user context
            # instead of failing with an AttributeError (HTTP 500).
            user = kwargs.get('current_user')
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="User not authenticated"
                )

        # `or ''` also covers a user context whose role is present but None.
        user_role = (user.get('role') or '').lower()

        # Admins are allowed through every role gate.
        if user_role != role.lower() and user_role != 'admin':
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"{role} role required"
            )

        return await func(*args, **kwargs)

    return wrapper
return decorator
def get_current_user(request: Request) -> Dict[str, Any]:
"""Get current user from request state or headers"""
if hasattr(request.state, 'user') and request.state.user:
return request.state.user
# Fallback to headers (for dev/testing)
user_info = extract_user_from_headers(request)
if not user_info:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not authenticated" detail="User not authenticated"
) )
return request.state.user
return user_info
def get_current_tenant_id(request: Request) -> Optional[str]: def get_current_tenant_id(request: Request) -> Optional[str]:
"""Get current tenant ID from request state""" """Get current tenant ID from request state or headers"""
return getattr(request.state, 'tenant_id', None) if hasattr(request.state, 'tenant_id'):
return request.state.tenant_id
# Fallback to headers
return extract_tenant_from_headers(request)
def extract_user_from_headers(request: Request) -> Optional[Dict[str, Any]]:
    """Extract user information from forwarded headers (gateway sets these).

    Args:
        request: Incoming request; only ``request.headers.get`` is used.

    Returns:
        ``None`` when ``X-User-ID`` is absent or empty, otherwise a dict with
        ``user_id``, ``email`` (default ``""``), ``role`` (default ``"user"``),
        ``tenant_id`` (``None`` when absent) and ``permissions`` (a list, empty
        when the header is absent).

    NOTE(review): these headers are taken at face value; presumably only the
    gateway can reach this service in production - confirm, since a direct
    caller could otherwise supply any identity.
    """
    headers = request.headers

    user_id = headers.get("X-User-ID")
    if not user_id:
        return None

    # The value is a comma-separated list. Strip padding and drop blank
    # entries so "read, write," parses as ["read", "write"] rather than
    # producing " write" and "" as permissions.
    raw_permissions = headers.get("X-User-Permissions") or ""
    permissions = [
        entry.strip() for entry in raw_permissions.split(",") if entry.strip()
    ]

    return {
        "user_id": user_id,
        "email": headers.get("X-User-Email", ""),
        "role": headers.get("X-User-Role", "user"),
        "tenant_id": headers.get("X-Tenant-ID"),
        "permissions": permissions,
    }
def extract_tenant_from_headers(request: Request) -> Optional[str]:
    """Return the value of the ``X-Tenant-ID`` header, or ``None`` if absent."""
    forwarded_headers = request.headers
    tenant_id = forwarded_headers.get("X-Tenant-ID")
    return tenant_id
# FastAPI Dependencies for injection
async def get_current_user_dep(request: Request) -> Dict[str, Any]:
    """FastAPI dependency wrapper around ``get_current_user``.

    Returns whatever ``get_current_user`` returns for this request and lets
    any exception it raises propagate.
    """
    user_context = get_current_user(request)
    return user_context
async def get_current_tenant_id_dep(request: Request) -> Optional[str]:
    """FastAPI dependency wrapper around ``get_current_tenant_id``.

    Returns whatever ``get_current_tenant_id`` returns for this request,
    which may be ``None``.
    """
    resolved_tenant_id = get_current_tenant_id(request)
    return resolved_tenant_id
# Public API of this module: the names exported by `from ... import *`.
# Lists the three decorators, the two request accessors, their FastAPI
# dependency wrappers, and the two header-extraction helpers.
__all__ = [
'require_authentication',
'require_tenant_access',
'require_role',
'get_current_user',
'get_current_tenant_id',
'get_current_user_dep',
'get_current_tenant_id_dep',
'extract_user_from_headers',
'extract_tenant_from_headers'
]