diff --git a/gateway/app/middleware/auth.py b/gateway/app/middleware/auth.py index ffaae85c..c767a81b 100644 --- a/gateway/app/middleware/auth.py +++ b/gateway/app/middleware/auth.py @@ -244,14 +244,41 @@ class AuthMiddleware(BaseHTTPMiddleware): await self.redis_client.setex(cache_key, ttl, json.dumps(user_context)) def _inject_auth_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str]): - """Inject authentication context into forwarded requests""" - # Add user context headers for downstream services - if hasattr(request, "headers"): - # Create mutable headers - headers = dict(request.headers) - headers["X-User-ID"] = user_context["user_id"] - headers["X-User-Email"] = user_context["email"] + """ + Inject authentication headers for downstream services + + This allows services to work both: + 1. Behind the gateway (using request.state) + 2. Called directly (using headers) for development/testing + """ + # Remove any existing auth headers to prevent spoofing + headers_to_remove = [ + "x-user-id", "x-user-email", "x-user-role", + "x-tenant-id", "x-user-permissions", "x-authenticated" + ] + + for header in headers_to_remove: + request.headers.__dict__["_list"] = [ + (k, v) for k, v in request.headers.raw + if k.lower() != header.lower() + ] + + # Inject new headers + new_headers = [ + (b"x-authenticated", b"true"), + (b"x-user-id", str(user_context.get("user_id", "")).encode()), + (b"x-user-email", str(user_context.get("email", "")).encode()), + (b"x-user-role", str(user_context.get("role", "user")).encode()), + ] + if tenant_id: - headers["X-Tenant-ID"] = tenant_id - # Update request headers - request.scope["headers"] = [(k.lower().encode(), v.encode()) for k, v in headers.items()] + new_headers.append((b"x-tenant-id", tenant_id.encode())) + + permissions = user_context.get("permissions", []) + if permissions: + new_headers.append((b"x-user-permissions", ",".join(permissions).encode())) + + # Add headers to request + request.headers.__dict__["_list"].extend(new_headers) + + logger.debug(f"Injected auth headers for user {user_context.get('email')}") diff --git a/services/data/app/api/sales.py b/services/data/app/api/sales.py index 38a5fd3f..261b8df2 100644 --- a/services/data/app/api/sales.py +++ b/services/data/app/api/sales.py @@ -1,22 +1,18 @@ # ================================================================ -# services/data/app/api/sales.py - FIXED VERSION +# services/data/app/api/sales.py - UPDATED WITH UNIFIED AUTH # ================================================================ -"""Sales data API endpoints with improved error handling""" +"""Sales data API endpoints with unified authentication""" from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Query, Response from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession -from typing import List, Optional +from typing import List, Optional, Dict, Any import uuid from datetime import datetime import base64 import structlog from app.core.database import get_db -from app.core.auth import get_current_user, AuthInfo -from app.services.sales_service import SalesService -from app.services.data_import_service import DataImportService -from app.services.messaging import publish_sales_created from app.schemas.sales import ( SalesDataCreate, SalesDataResponse, @@ -26,75 +22,163 @@ from app.schemas.sales import ( SalesValidationResult, SalesExportRequest ) +from app.services.sales_service import SalesService +from app.services.data_import_service import DataImportService +from app.services.messaging import ( + publish_sales_created, + publish_data_imported, + publish_export_completed +) -router = APIRouter() +# Import unified authentication from shared library +from shared.auth.decorators import ( + get_current_user_dep, + get_current_tenant_id_dep +) + +router = APIRouter(prefix="/sales", tags=["sales"]) logger = structlog.get_logger() @router.post("/", response_model=SalesDataResponse) async def create_sales_record( sales_data: SalesDataCreate, - db: AsyncSession = Depends(get_db), - current_user: AuthInfo = Depends(get_current_user) + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + db: AsyncSession = Depends(get_db) ): """Create a new sales record""" try: - logger.debug("API: Creating sales record", product=sales_data.product_name, quantity=sales_data.quantity_sold) + logger.debug("Creating sales record", + product=sales_data.product_name, + quantity=sales_data.quantity_sold, + tenant_id=tenant_id, + user_id=current_user["user_id"]) + + # Override tenant_id from token/header + sales_data.tenant_id = tenant_id record = await SalesService.create_sales_record(sales_data, db) - # Publish event (with error handling) + # Publish event (non-blocking) try: await publish_sales_created({ - "tenant_id": str(sales_data.tenant_id), + "tenant_id": tenant_id, "product_name": sales_data.product_name, "quantity_sold": sales_data.quantity_sold, "revenue": sales_data.revenue, "source": sales_data.source, + "created_by": current_user["user_id"], "timestamp": datetime.utcnow().isoformat() }) except Exception as pub_error: logger.warning("Failed to publish sales created event", error=str(pub_error)) - # Continue processing - event publishing failure shouldn't break the API + # Continue - event failure shouldn't break API - logger.debug("Successfully created sales record", record_id=record.id) + logger.info("Successfully created sales record", + record_id=record.id, + tenant_id=tenant_id) return record except Exception as e: - logger.error("Failed to create sales record", error=str(e)) - import traceback - logger.error("Sales creation traceback", traceback=traceback.format_exc()) + logger.error("Failed to create sales record", + error=str(e), + tenant_id=tenant_id) raise HTTPException(status_code=500, detail=f"Failed to create sales record: {str(e)}") -@router.post("/query", response_model=List[SalesDataResponse]) -async def get_sales_data( - query: SalesDataQuery, - db: AsyncSession = Depends(get_db), - current_user: AuthInfo = Depends(get_current_user) +@router.post("/bulk", response_model=List[SalesDataResponse]) +async def create_bulk_sales( + sales_data: List[SalesDataCreate], + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + db: AsyncSession = Depends(get_db) ): - """Get sales data by query parameters""" + """Create multiple sales records""" try: - logger.debug("API: Querying sales data", tenant_id=query.tenant_id) + logger.debug("Creating bulk sales records", + count=len(sales_data), + tenant_id=tenant_id) - records = await SalesService.get_sales_data(query, db) + # Override tenant_id for all records + for record in sales_data: + record.tenant_id = tenant_id - logger.debug("Successfully retrieved sales data", count=len(records)) + records = await SalesService.create_bulk_sales(sales_data, db) + + # Publish event + try: + await publish_data_imported({ + "tenant_id": tenant_id, + "type": "bulk_create", + "records_created": len(records), + "created_by": current_user["user_id"], + "timestamp": datetime.utcnow().isoformat() + }) + except Exception as pub_error: + logger.warning("Failed to publish bulk import event", error=str(pub_error)) + + logger.info("Successfully created bulk sales records", + count=len(records), + tenant_id=tenant_id) return records except Exception as e: - logger.error("Failed to query sales data", error=str(e)) + logger.error("Failed to create bulk sales records", + error=str(e), + tenant_id=tenant_id) + raise HTTPException(status_code=500, detail=f"Failed to create bulk sales records: {str(e)}") + +@router.get("/", response_model=List[SalesDataResponse]) +async def get_sales_data( + start_date: Optional[datetime] = Query(None), + end_date: Optional[datetime] = Query(None), + product_name: Optional[str] = Query(None), + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + db: AsyncSession = Depends(get_db) +): + """Get sales data with filters""" + try: + logger.debug("Querying sales data", + tenant_id=tenant_id, + start_date=start_date, + end_date=end_date, + product_name=product_name) + + query = SalesDataQuery( + tenant_id=tenant_id, + start_date=start_date, + end_date=end_date, + product_name=product_name + ) + + records = await SalesService.get_sales_data(query, db) + + logger.debug("Successfully retrieved sales data", + count=len(records), + tenant_id=tenant_id) + return records + + except Exception as e: + logger.error("Failed to query sales data", + error=str(e), + tenant_id=tenant_id) raise HTTPException(status_code=500, detail=f"Failed to query sales data: {str(e)}") @router.post("/import", response_model=SalesImportResult) async def import_sales_data( - tenant_id: str = Form(...), - file_format: str = Form(...), file: UploadFile = File(...), - db: AsyncSession = Depends(get_db), - current_user: AuthInfo = Depends(get_current_user) + file_format: str = Form(...), + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + db: AsyncSession = Depends(get_db) ): """Import sales data from file""" try: - logger.debug("API: Importing sales data", tenant_id=tenant_id, format=file_format, filename=file.filename) + logger.info("Importing sales data", + tenant_id=tenant_id, + format=file_format, + filename=file.filename, + user_id=current_user["user_id"]) # Read file content content = await file.read() @@ -102,100 +186,78 @@ async def import_sales_data( # Process import result = await DataImportService.process_upload( - tenant_id, file_content, file_format, db + tenant_id, + file_content, + file_format, + db, + user_id=current_user["user_id"] ) if result["success"]: - # Publish event (with error handling) + # Publish event try: - await data_publisher.publish_data_imported({ + await publish_data_imported({ "tenant_id": tenant_id, - "type": "bulk_import", + "type": "file_import", "format": file_format, "filename": file.filename, "records_created": result["records_created"], + "imported_by": current_user["user_id"], "timestamp": datetime.utcnow().isoformat() }) except Exception as pub_error: - logger.warning("Failed to publish data imported event", error=str(pub_error)) - # Continue processing + logger.warning("Failed to publish import event", error=str(pub_error)) - logger.debug("Import completed", success=result["success"], records_created=result.get("records_created", 0)) + logger.info("Import completed", + success=result["success"], + records_created=result.get("records_created", 0), + tenant_id=tenant_id) return result except Exception as e: - logger.error("Failed to import sales data", error=str(e)) - import traceback - logger.error("Sales import traceback", traceback=traceback.format_exc()) + logger.error("Failed to import sales data", + error=str(e), + tenant_id=tenant_id) raise HTTPException(status_code=500, detail=f"Failed to import sales data: {str(e)}") -@router.post("/import/json", response_model=SalesImportResult) -async def import_sales_json( - import_data: SalesDataImport, - db: AsyncSession = Depends(get_db), - current_user: AuthInfo = Depends(get_current_user) -): - """Import sales data from JSON""" - try: - logger.debug("API: Importing JSON sales data", tenant_id=import_data.tenant_id) - - result = await DataImportService.process_upload( - str(import_data.tenant_id), - import_data.data, - import_data.data_format, - db - ) - - if result["success"]: - # Publish event (with error handling) - try: - await publish_data_imported({ - "tenant_id": str(import_data.tenant_id), - "type": "json_import", - "records_created": result["records_created"], - "timestamp": datetime.utcnow().isoformat() - }) - except Exception as pub_error: - logger.warning("Failed to publish JSON import event", error=str(pub_error)) - # Continue processing - - logger.debug("JSON import completed", success=result["success"], records_created=result.get("records_created", 0)) - return result - - except Exception as e: - logger.error("Failed to import JSON sales data", error=str(e)) - import traceback - logger.error("JSON import traceback", traceback=traceback.format_exc()) - raise HTTPException(status_code=500, detail=f"Failed to import JSON sales data: {str(e)}") - @router.post("/import/validate", response_model=SalesValidationResult) async def validate_import_data( import_data: SalesDataImport, - current_user: AuthInfo = Depends(get_current_user) + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep) ): """Validate import data before processing""" try: - logger.debug("API: Validating import data", tenant_id=import_data.tenant_id) + logger.debug("Validating import data", tenant_id=tenant_id) + + # Override tenant_id + import_data.tenant_id = tenant_id validation = await DataImportService.validate_import_data( import_data.model_dump() ) - logger.debug("Validation completed", is_valid=validation.get("is_valid", False)) + logger.debug("Validation completed", + is_valid=validation.get("is_valid", False), + tenant_id=tenant_id) return validation except Exception as e: - logger.error("Failed to validate import data", error=str(e)) + logger.error("Failed to validate import data", + error=str(e), + tenant_id=tenant_id) raise HTTPException(status_code=500, detail=f"Failed to validate import data: {str(e)}") @router.get("/import/template/{format_type}") async def get_import_template( format_type: str, - current_user: AuthInfo = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user_dep) ): """Get import template for specified format""" try: - logger.debug("API: Getting import template", format=format_type) + logger.debug("Getting import template", + format=format_type, + user_id=current_user["user_id"]) template = await DataImportService.get_import_template(format_type) @@ -230,21 +292,22 @@ async def get_import_template( raise except Exception as e: logger.error("Failed to generate import template", error=str(e)) - import traceback - logger.error("Template generation traceback", traceback=traceback.format_exc()) raise HTTPException(status_code=500, detail=f"Failed to generate template: {str(e)}") -@router.get("/analytics/{tenant_id}") +@router.get("/analytics") async def get_sales_analytics( - tenant_id: str, start_date: Optional[datetime] = Query(None, description="Start date"), end_date: Optional[datetime] = Query(None, description="End date"), - db: AsyncSession = Depends(get_db), - current_user: AuthInfo = Depends(get_current_user) + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + db: AsyncSession = Depends(get_db) ): """Get sales analytics for tenant""" try: - logger.debug("API: Getting sales analytics", tenant_id=tenant_id) + logger.debug("Getting sales analytics", + tenant_id=tenant_id, + start_date=start_date, + end_date=end_date) analytics = await SalesService.get_sales_analytics( tenant_id, start_date, end_date, db @@ -254,22 +317,27 @@ async def get_sales_analytics( return analytics except Exception as e: - logger.error("Failed to generate sales analytics", error=str(e)) + logger.error("Failed to generate sales analytics", + error=str(e), + tenant_id=tenant_id) raise HTTPException(status_code=500, detail=f"Failed to generate analytics: {str(e)}") -@router.post("/export/{tenant_id}") +@router.post("/export") async def export_sales_data( - tenant_id: str, export_format: str = Query("csv", description="Export format: csv, excel, json"), start_date: Optional[datetime] = Query(None, description="Start date"), end_date: Optional[datetime] = Query(None, description="End date"), products: Optional[List[str]] = Query(None, description="Filter by products"), - db: AsyncSession = Depends(get_db), - current_user: AuthInfo = Depends(get_current_user) + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + db: AsyncSession = Depends(get_db) ): """Export sales data in specified format""" try: - logger.debug("API: Exporting sales data", tenant_id=tenant_id, format=export_format) + logger.info("Exporting sales data", + tenant_id=tenant_id, + format=export_format, + user_id=current_user["user_id"]) export_result = await SalesService.export_sales_data( tenant_id, export_format, start_date, end_date, products, db @@ -278,7 +346,21 @@ async def export_sales_data( if not export_result: raise HTTPException(status_code=404, detail="No data found for export") - logger.debug("Export completed successfully", tenant_id=tenant_id, format=export_format) + # Publish export event + try: + await publish_export_completed({ + "tenant_id": tenant_id, + "format": export_format, + "exported_by": current_user["user_id"], + "record_count": export_result.get("record_count", 0), + "timestamp": datetime.utcnow().isoformat() + }) + except Exception as pub_error: + logger.warning("Failed to publish export event", error=str(pub_error)) + + logger.info("Export completed successfully", + tenant_id=tenant_id, + format=export_format) return StreamingResponse( iter([export_result["content"]]), @@ -289,29 +371,91 @@ async def export_sales_data( except HTTPException: raise except Exception as e: - logger.error("Failed to export sales data", error=str(e)) + logger.error("Failed to export sales data", + error=str(e), + tenant_id=tenant_id) raise HTTPException(status_code=500, detail=f"Failed to export sales data: {str(e)}") @router.delete("/{record_id}") async def delete_sales_record( record_id: str, - db: AsyncSession = Depends(get_db), - current_user: AuthInfo = Depends(get_current_user) + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + db: AsyncSession = Depends(get_db) ): """Delete a sales record""" try: - logger.debug("API: Deleting sales record", record_id=record_id) + logger.info("Deleting sales record", + record_id=record_id, + tenant_id=tenant_id, + user_id=current_user["user_id"]) + + # Verify record belongs to tenant before deletion + record = await SalesService.get_sales_record(record_id, db) + if not record or record.tenant_id != tenant_id: + raise HTTPException(status_code=404, detail="Sales record not found") success = await SalesService.delete_sales_record(record_id, db) if not success: raise HTTPException(status_code=404, detail="Sales record not found") - logger.debug("Sales record deleted successfully", record_id=record_id) + logger.info("Sales record deleted successfully", + record_id=record_id, + tenant_id=tenant_id) return {"status": "success", "message": "Sales record deleted successfully"} except HTTPException: raise except Exception as e: - logger.error("Failed to delete sales record", error=str(e)) - raise HTTPException(status_code=500, detail=f"Failed to delete sales record: {str(e)}") \ No newline at end of file + logger.error("Failed to delete sales record", + error=str(e), + tenant_id=tenant_id) + raise HTTPException(status_code=500, detail=f"Failed to delete sales record: {str(e)}") + +@router.get("/summary") +async def get_sales_summary( + period: str = Query("daily", description="Summary period: daily, weekly, monthly"), + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + db: AsyncSession = Depends(get_db) +): + """Get sales summary for specified period""" + try: + logger.debug("Getting sales summary", + tenant_id=tenant_id, + period=period) + + summary = await SalesService.get_sales_summary(tenant_id, period, db) + + logger.debug("Summary generated successfully", tenant_id=tenant_id) + return summary + + except Exception as e: + logger.error("Failed to generate sales summary", + error=str(e), + tenant_id=tenant_id) + raise HTTPException(status_code=500, detail=f"Failed to generate summary: {str(e)}") + +@router.get("/products") +async def get_products_list( + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + db: AsyncSession = Depends(get_db) +): + """Get list of all products with sales data""" + try: + logger.debug("Getting products list", tenant_id=tenant_id) + + products = await SalesService.get_products_list(tenant_id, db) + + logger.debug("Products list retrieved", + count=len(products), + tenant_id=tenant_id) + return products + + except Exception as e: + logger.error("Failed to get products list", + error=str(e), + tenant_id=tenant_id) + raise HTTPException(status_code=500, detail=f"Failed to get products list: {str(e)}") \ No newline at end of file diff --git a/services/training/app/api/training.py b/services/training/app/api/training.py index 8c3632ee..218564c6 100644 --- a/services/training/app/api/training.py +++ b/services/training/app/api/training.py @@ -1,299 +1,472 @@ -# services/training/app/api/training.py -""" -Training API endpoints for the training service -""" +# ================================================================ +# services/training/app/api/training.py - UPDATED WITH UNIFIED AUTH +# ================================================================ +"""Training API endpoints with unified authentication""" -from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks -from sqlalchemy.ext.asyncio import AsyncSession -from typing import Dict, List, Any, Optional -import logging +from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query +from typing import List, Optional, Dict, Any from datetime import datetime -import uuid +import structlog -from app.core.database import get_db -from app.core.auth import get_current_tenant_id from app.schemas.training import ( - TrainingJobRequest, + TrainingJobRequest, TrainingJobResponse, - TrainingStatusResponse, - SingleProductTrainingRequest + TrainingJobStatus, + SingleProductTrainingRequest, + TrainingJobProgress, + DataValidationRequest, + DataValidationResponse ) from app.services.training_service import TrainingService -from app.services.messaging import publish_job_started, publish_job_cancelled, publish_product_training_started -from shared.monitoring.metrics import MetricsCollector +from app.services.messaging import ( + publish_job_started, + publish_job_completed, + publish_job_failed, + publish_job_progress, + publish_product_training_started, + publish_product_training_completed +) -logger = logging.getLogger(__name__) -router = APIRouter() -metrics = MetricsCollector("training-service") +# Import unified authentication from shared library +from shared.auth.decorators import ( + get_current_user_dep, + get_current_tenant_id_dep, + require_role +) -# Initialize training service -training_service = TrainingService() +logger = structlog.get_logger() +router = APIRouter(prefix="/training", tags=["training"]) @router.post("/jobs", response_model=TrainingJobResponse) async def start_training_job( request: TrainingJobRequest, background_tasks: BackgroundTasks, - tenant_id: str = Depends(get_current_tenant_id), - db: AsyncSession = Depends(get_db) + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + training_service: TrainingService = Depends() ): - """ - Start a new training job for all products of a tenant. - Replaces the old Celery-based training system. - """ + """Start a new training job for all products""" try: - logger.info(f"Starting training job for tenant {tenant_id}") - metrics.increment_counter("training_jobs_started") + logger.info("Starting training job", + tenant_id=tenant_id, + user_id=current_user["user_id"], + config=request.dict()) - # Generate job ID - job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}" - - # Create training job record - training_job = await training_service.create_training_job( - db=db, + # Create training job + job = await training_service.create_training_job( tenant_id=tenant_id, - job_id=job_id, + user_id=current_user["user_id"], config=request.dict() ) + # Publish job started event + try: + await publish_job_started( + job_id=job.job_id, + tenant_id=tenant_id, + config=request.dict() + ) + except Exception as e: + logger.warning("Failed to publish job started event", error=str(e)) + # Start training in background background_tasks.add_task( training_service.execute_training_job, - db, - job_id, - tenant_id, - request + job.job_id ) - # Publish training started event - await publish_job_started(job_id, tenant_id, request.dict()) + logger.info("Training job created", + job_id=job.job_id, + tenant_id=tenant_id) - return TrainingJobResponse( - job_id=job_id, - status="started", - message="Training job started successfully", - tenant_id=tenant_id, - created_at=training_job.start_time, - estimated_duration_minutes=request.estimated_duration or 15 - ) + return job except Exception as e: - logger.error(f"Failed to start training job: {str(e)}") - metrics.increment_counter("training_jobs_failed") + logger.error("Failed to start training job", + error=str(e), + tenant_id=tenant_id) raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}") -@router.get("/jobs/{job_id}/status", response_model=TrainingStatusResponse) -async def get_training_status( - job_id: str, - tenant_id: str = Depends(get_current_tenant_id), - db: AsyncSession = Depends(get_db) +@router.get("/jobs", response_model=List[TrainingJobResponse]) +async def get_training_jobs( + status: Optional[TrainingJobStatus] = Query(None), + limit: int = Query(100, ge=1, le=1000), + offset: int = Query(0, ge=0), + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + training_service: TrainingService = Depends() ): - """ - Get the status of a training job. - Provides real-time progress updates. - """ + """Get training jobs for tenant""" try: - # Get job status from database - job_status = await training_service.get_job_status(db, job_id, tenant_id) + logger.debug("Getting training jobs", + tenant_id=tenant_id, + status=status, + limit=limit, + offset=offset) - if not job_status: - raise HTTPException(status_code=404, detail="Training job not found") - - return TrainingStatusResponse( - job_id=job_id, - status=job_status.status, - progress=job_status.progress, - current_step=job_status.current_step, - started_at=job_status.start_time, - completed_at=job_status.end_time, - results=job_status.results, - error_message=job_status.error_message + jobs = await training_service.get_training_jobs( + tenant_id=tenant_id, + status=status, + limit=limit, + offset=offset ) + logger.debug("Retrieved training jobs", + count=len(jobs), + tenant_id=tenant_id) + + return jobs + + except Exception as e: + logger.error("Failed to get training jobs", + error=str(e), + tenant_id=tenant_id) + raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}") + +@router.get("/jobs/{job_id}", response_model=TrainingJobResponse) +async def get_training_job( + job_id: str, + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + training_service: TrainingService = Depends() +): + """Get specific training job details""" + try: + logger.debug("Getting training job", + job_id=job_id, + tenant_id=tenant_id) + + job = await training_service.get_training_job(job_id) + + # Verify tenant access + if job.tenant_id != tenant_id: + logger.warning("Unauthorized job access attempt", + job_id=job_id, + tenant_id=tenant_id, + job_tenant_id=job.tenant_id) + raise HTTPException(status_code=404, detail="Job not found") + + return job + except HTTPException: raise except Exception as e: - logger.error(f"Failed to get training status: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to get training status: {str(e)}") + logger.error("Failed to get training job", + error=str(e), + job_id=job_id) + raise HTTPException(status_code=500, detail=f"Failed to get training job: {str(e)}") + +@router.get("/jobs/{job_id}/progress", response_model=TrainingJobProgress) +async def get_training_progress( + job_id: str, + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + training_service: TrainingService = Depends() +): + """Get real-time training progress""" + try: + logger.debug("Getting training progress", + job_id=job_id, + tenant_id=tenant_id) + + # Verify job belongs to tenant + job = await training_service.get_training_job(job_id) + if job.tenant_id != tenant_id: + raise HTTPException(status_code=404, detail="Job not found") + + progress = await training_service.get_job_progress(job_id) + + return progress + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to get training progress", + error=str(e), + job_id=job_id) + raise HTTPException(status_code=500, detail=f"Failed to get training progress: {str(e)}") + +@router.post("/jobs/{job_id}/cancel") +async def cancel_training_job( + job_id: str, + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + training_service: TrainingService = Depends() +): + """Cancel a running training job""" + try: + logger.info("Cancelling training job", + job_id=job_id, + tenant_id=tenant_id, + user_id=current_user["user_id"]) + + job = await training_service.get_training_job(job_id) + + # Verify tenant access + if job.tenant_id != tenant_id: + raise HTTPException(status_code=404, detail="Job not found") + + await training_service.cancel_training_job(job_id) + + # Publish cancellation event + try: + await publish_job_failed( + job_id=job_id, + tenant_id=tenant_id, + error="Job cancelled by user", + failed_at="cancellation" + ) + except Exception as e: + logger.warning("Failed to publish cancellation event", error=str(e)) + + logger.info("Training job cancelled", job_id=job_id) + + return {"message": "Job cancelled successfully", "job_id": job_id} + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to cancel training job", + error=str(e), + job_id=job_id) + raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}") @router.post("/products/{product_name}", response_model=TrainingJobResponse) async def train_single_product( product_name: str, request: SingleProductTrainingRequest, background_tasks: BackgroundTasks, - tenant_id: str = Depends(get_current_tenant_id), - db: AsyncSession = Depends(get_db) + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + training_service: TrainingService = Depends() ): - """ - Train a model for a single product. - Useful for quick model updates or new products. - """ + """Train model for a single product""" try: - logger.info(f"Starting single product training for {product_name}, tenant {tenant_id}") - metrics.increment_counter("single_product_training_started") + logger.info("Training single product", + product_name=product_name, + tenant_id=tenant_id, + user_id=current_user["user_id"]) - # Generate job ID - job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}" - - # Create training job record - training_job = await training_service.create_single_product_job( - db=db, + # Create training job for single product + job = await training_service.create_single_product_job( tenant_id=tenant_id, + user_id=current_user["user_id"], product_name=product_name, - job_id=job_id, config=request.dict() ) + # Publish event + try: + await publish_product_training_started( + job_id=job.job_id, + tenant_id=tenant_id, + product_name=product_name + ) + except Exception as e: + logger.warning("Failed to publish product training event", error=str(e)) + # Start training in background background_tasks.add_task( training_service.execute_single_product_training, - db, - job_id, - tenant_id, - product_name, - request + job.job_id, + product_name + ) + + logger.info("Single product training started", + job_id=job.job_id, + product_name=product_name) + + return job + + except Exception as e: + logger.error("Failed to train single product", + error=str(e), + product_name=product_name, + tenant_id=tenant_id) + raise HTTPException(status_code=500, detail=f"Failed to train product: {str(e)}") + +@router.post("/validate", response_model=DataValidationResponse) +async def validate_training_data( + request: DataValidationRequest, + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + training_service: TrainingService = Depends() +): + """Validate data before training""" + try: + logger.debug("Validating training data", + tenant_id=tenant_id, + products=request.products) + + validation_result = await training_service.validate_training_data( + tenant_id=tenant_id, + products=request.products, + min_data_points=request.min_data_points + ) + + logger.debug("Data validation completed", + is_valid=validation_result.is_valid, + tenant_id=tenant_id) + + return validation_result + + except Exception as e: + logger.error("Failed to validate training data", + error=str(e), + tenant_id=tenant_id) + raise HTTPException(status_code=500, detail=f"Failed to validate data: {str(e)}") + +@router.get("/models") +async def get_trained_models( + product_name: Optional[str] = Query(None), + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + training_service: TrainingService = Depends() +): + """Get list of trained models""" + try: + logger.debug("Getting trained models", + tenant_id=tenant_id, + product_name=product_name) + + models = await training_service.get_trained_models( + tenant_id=tenant_id, + product_name=product_name + ) + + logger.debug("Retrieved trained models", + count=len(models), + tenant_id=tenant_id) + + return models + + except Exception as e: + logger.error("Failed to get trained models", + error=str(e), + tenant_id=tenant_id) + raise HTTPException(status_code=500, detail=f"Failed to get models: {str(e)}") + +@router.delete("/models/{model_id}") +@require_role("admin") # Only admins can delete models +async def delete_model( + model_id: str, + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + training_service: TrainingService = Depends() +): + """Delete a trained model (admin only)""" + try: + logger.info("Deleting model", + model_id=model_id, + tenant_id=tenant_id, + admin_id=current_user["user_id"]) + + # Verify model belongs to tenant + model = await training_service.get_model(model_id) + if model.tenant_id != tenant_id: + raise HTTPException(status_code=404, detail="Model not found") + + success = await training_service.delete_model(model_id) + + if not success: + raise HTTPException(status_code=404, detail="Model not found") + + logger.info("Model deleted successfully", model_id=model_id) + + return {"message": "Model deleted successfully", "model_id": model_id} + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to delete model", + error=str(e), + model_id=model_id) + raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}") + +@router.get("/stats") +async def get_training_stats( + start_date: Optional[datetime] = Query(None), + end_date: Optional[datetime] = Query(None), + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + training_service: TrainingService = Depends() +): + """Get training statistics for tenant""" + try: + logger.debug("Getting training stats", + tenant_id=tenant_id, + start_date=start_date, + end_date=end_date) + + stats = await training_service.get_training_stats( + tenant_id=tenant_id, + start_date=start_date, + end_date=end_date + ) + + logger.debug("Training stats retrieved", tenant_id=tenant_id) + + return stats + + except Exception as e: + logger.error("Failed to get training stats", + error=str(e), + tenant_id=tenant_id) + raise HTTPException(status_code=500, detail=f"Failed to get stats: {str(e)}") + +@router.post("/retrain/all") +async def retrain_all_products( + request: TrainingJobRequest, + background_tasks: BackgroundTasks, + tenant_id: str = Depends(get_current_tenant_id_dep), + current_user: Dict[str, Any] = Depends(get_current_user_dep), + training_service: TrainingService = Depends() +): + """Retrain all products with existing models""" + try: + logger.info("Retraining all products", + tenant_id=tenant_id, + user_id=current_user["user_id"]) + + # Check if models exist + existing_models = await training_service.get_trained_models(tenant_id) + if not existing_models: + raise HTTPException( + status_code=400, + detail="No existing models found. Please run initial training first." + ) + + # Create retraining job + job = await training_service.create_training_job( + tenant_id=tenant_id, + user_id=current_user["user_id"], + config={**request.dict(), "is_retrain": True} ) # Publish event - await publish_product_training_started(job_id, tenant_id, product_name) - - return TrainingJobResponse( - job_id=job_id, - status="started", - message=f"Single product training started for {product_name}", - tenant_id=tenant_id, - created_at=training_job.start_time, - estimated_duration_minutes=5 - ) - - except Exception as e: - logger.error(f"Failed to start single product training: {str(e)}") - metrics.increment_counter("single_product_training_failed") - raise HTTPException(status_code=500, detail=f"Failed to start training: {str(e)}") - -@router.get("/jobs", response_model=List[TrainingStatusResponse]) -async def list_training_jobs( - limit: int = 10, - status: Optional[str] = None, - tenant_id: str = Depends(get_current_tenant_id), - db: AsyncSession = Depends(get_db) -): - """ - List training jobs for a tenant. - """ - try: - jobs = await training_service.list_training_jobs( - db=db, - tenant_id=tenant_id, - limit=limit, - status_filter=status - ) - - return [ - TrainingStatusResponse( + try: + await publish_job_started( job_id=job.job_id, - status=job.status, - progress=job.progress, - current_step=job.current_step, - started_at=job.start_time, - completed_at=job.end_time, - results=job.results, - error_message=job.error_message + tenant_id=tenant_id, + config={**request.dict(), "is_retrain": True} ) - for job in jobs - ] + except Exception as e: + logger.warning("Failed to publish retrain event", error=str(e)) - except Exception as e: - logger.error(f"Failed to list training jobs: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to list training jobs: {str(e)}") - -@router.post("/jobs/{job_id}/cancel") -async def cancel_training_job( - job_id: str, - tenant_id: str = Depends(get_current_tenant_id), - db: AsyncSession = Depends(get_db) -): - """ - Cancel a running training job. - """ - try: - logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}") - - # Update job status to cancelled - success = await training_service.cancel_training_job(db, job_id, tenant_id) - - if not success: - raise HTTPException(status_code=404, detail="Training job not found or cannot be cancelled") - - # Publish cancellation event - await publish_job_cancelled(job_id, tenant_id) - - return {"message": "Training job cancelled successfully"} - - except HTTPException: - raise - except Exception as e: - logger.error(f"Failed to cancel training job: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}") - -@router.get("/jobs/{job_id}/logs") -async def get_training_logs( - job_id: str, - tenant_id: str = Depends(get_current_tenant_id), - db: AsyncSession = Depends(get_db) -): - """ - Get detailed logs for a training job. - """ - try: - logs = await training_service.get_training_logs(db, job_id, tenant_id) - - if not logs: - raise HTTPException(status_code=404, detail="Training job not found") - - return {"job_id": job_id, "logs": logs} - - except HTTPException: - raise - except Exception as e: - logger.error(f"Failed to get training logs: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to get training logs: {str(e)}") - -@router.post("/validate") -async def validate_training_data( - request: TrainingJobRequest, - tenant_id: str = Depends(get_current_tenant_id), - db: AsyncSession = Depends(get_db) -): - """ - Validate training data before starting a job. - Provides early feedback on data quality issues. - """ - try: - logger.info(f"Validating training data for tenant {tenant_id}") - - # Perform data validation - validation_result = await training_service.validate_training_data( - db=db, - tenant_id=tenant_id, - config=request.dict() + # Start retraining in background + background_tasks.add_task( + training_service.execute_training_job, + job.job_id ) - return { - "is_valid": validation_result["is_valid"], - "issues": validation_result.get("issues", []), - "recommendations": validation_result.get("recommendations", []), - "estimated_training_time": validation_result.get("estimated_time_minutes", 15) - } + logger.info("Retraining job created", job_id=job.job_id) + return job + + except HTTPException: + raise except Exception as e: - logger.error(f"Failed to validate training data: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to validate training data: {str(e)}") - -@router.get("/health") -async def health_check(): - """Health check for the training service""" - return { - "status": "healthy", - "service": "training-service", - "timestamp": datetime.now().isoformat() - } \ No newline at end of file + logger.error("Failed to start retraining", + error=str(e), + tenant_id=tenant_id) + raise HTTPException(status_code=500, detail=f"Failed to start retraining: {str(e)}") \ No newline at end of file diff --git a/shared/auth/decorators.py b/shared/auth/decorators.py index 5080531e..213bc88c 100644 --- a/shared/auth/decorators.py +++ b/shared/auth/decorators.py @@ -1,14 +1,24 @@ -# shared/auth/decorators.py - NEW FILE """ -Authentication decorators for microservices +Unified authentication decorators for microservices +Designed to work with gateway authentication middleware """ from functools import wraps -from fastapi import HTTPException, status, Request -from typing import Callable, Optional +from fastapi import HTTPException, status, Request, Depends +from fastapi.security import HTTPBearer +from typing import Callable, Optional, Dict, Any +import structlog + +logger = structlog.get_logger() + +# Bearer token scheme for services that need it +security = HTTPBearer(auto_error=False) def require_authentication(func: Callable) -> Callable: - """Decorator to require authentication - assumes gateway has validated token""" + """ + Decorator to require authentication - trusts gateway validation + Services behind the gateway should use this decorator + """ @wraps(func) async def wrapper(*args, **kwargs): @@ -20,7 +30,6 @@ def require_authentication(func: Callable) -> Callable: break if not request: - # Check kwargs request = kwargs.get('request') if not request: @@ -31,10 +40,14 @@ def require_authentication(func: Callable) -> Callable: # Check if user context exists (set by gateway) if not hasattr(request.state, 'user') or not request.state.user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Authentication required" - ) + # Check headers as fallback (for direct service calls in dev) + user_info = extract_user_from_headers(request) + if not user_info: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required" + ) + request.state.user = user_info return await func(*args, **kwargs) @@ -45,32 +58,118 @@ def require_tenant_access(func: Callable) -> Callable: @wraps(func) async def wrapper(*args, **kwargs): - # Find request object request = None for arg in args: if isinstance(arg, Request): request = arg break + if not request: + request = kwargs.get('request') + if not request or not hasattr(request.state, 'tenant_id'): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Tenant access required" - ) + # Try to extract from headers + tenant_id = extract_tenant_from_headers(request) + if not tenant_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Tenant access required" + ) + request.state.tenant_id = tenant_id return await func(*args, **kwargs) return wrapper -def get_current_user(request: Request) -> dict: - """Get current user from request state""" - if not hasattr(request.state, 'user'): +def require_role(role: str): + """Decorator to require specific role""" + + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + request = None + for arg in args: + if isinstance(arg, Request): + request = arg + break + + if not request: + request = kwargs.get('request') + + user = get_current_user(request) + user_role = user.get('role', '').lower() + + if user_role != role.lower() and user_role != 'admin': + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"{role} role required" + ) + + return await func(*args, **kwargs) + + return wrapper + + return decorator + +def get_current_user(request: Request) -> Dict[str, Any]: + """Get current user from request state or headers""" + if hasattr(request.state, 'user') and request.state.user: + return request.state.user + + # Fallback to headers (for dev/testing) + user_info = extract_user_from_headers(request) + if not user_info: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="User not authenticated" ) - return request.state.user + + return user_info def get_current_tenant_id(request: Request) -> Optional[str]: - """Get current tenant ID from request state""" - return getattr(request.state, 'tenant_id', None) + """Get current tenant ID from request state or headers""" + if hasattr(request.state, 'tenant_id'): + return request.state.tenant_id + + # Fallback to headers + return extract_tenant_from_headers(request) + +def extract_user_from_headers(request: Request) -> Optional[Dict[str, Any]]: + """Extract user information from forwarded headers (gateway sets these)""" + user_id = request.headers.get("X-User-ID") + if not user_id: + return None + + return { + "user_id": user_id, + "email": request.headers.get("X-User-Email", ""), + "role": request.headers.get("X-User-Role", "user"), + "tenant_id": request.headers.get("X-Tenant-ID"), + "permissions": request.headers.get("X-User-Permissions", "").split(",") if request.headers.get("X-User-Permissions") else [] + } + +def extract_tenant_from_headers(request: Request) -> Optional[str]: + """Extract tenant ID from headers""" + return request.headers.get("X-Tenant-ID") + +# FastAPI Dependencies for injection +async def get_current_user_dep(request: Request) -> Dict[str, Any]: + """FastAPI dependency to get current user""" + return get_current_user(request) + +async def get_current_tenant_id_dep(request: Request) -> Optional[str]: + """FastAPI dependency to get current tenant ID""" + return get_current_tenant_id(request) + +# Export all decorators and functions +__all__ = [ + 'require_authentication', + 'require_tenant_access', + 'require_role', + 'get_current_user', + 'get_current_tenant_id', + 'get_current_user_dep', + 'get_current_tenant_id_dep', + 'extract_user_from_headers', + 'extract_tenant_from_headers' +] \ No newline at end of file