diff --git a/gateway/app/routes/training.py b/gateway/app/routes/training.py
index 0bea905f..a3bcab11 100644
--- a/gateway/app/routes/training.py
+++ b/gateway/app/routes/training.py
@@ -1,8 +1,8 @@
 """
-Training routes for gateway
+Training routes for gateway - FIXED VERSION
 """
 
-from fastapi import APIRouter, Request, HTTPException, Query
+from fastapi import APIRouter, Request, HTTPException, Query, Response
 from fastapi.responses import JSONResponse
 import httpx
 import logging
@@ -13,16 +13,58 @@ from app.core.config import settings
 logger = logging.getLogger(__name__)
 router = APIRouter()
 
-@router.get("/status/{training_job_id}")
-async def get_training_status(training_job_id: str, request: Request):
-    """Get training job status"""
-    try:
-        auth_header = request.headers.get("Authorization")
+async def _proxy_training_request(request: Request, target_path: str, method: Optional[str] = None, params: Optional[dict] = None):
+    """Proxy request to training service with user context.
+
+    `method` defaults to the incoming request's method. `params` overrides the
+    forwarded query string; when omitted, the incoming query params are forwarded.
+    """
+
+    # Handle OPTIONS requests directly for CORS
+    if request.method == "OPTIONS":
+        # Header values must be strings and this header carries a single origin: echo the caller's Origin only if allow-listed
+        # NOTE(review): assumes settings.CORS_ORIGINS_LIST is a list of origin strings - confirm
+        origin = request.headers.get("origin", "")
+        return Response(
+            status_code=200,
+            headers={
+                "Access-Control-Allow-Origin": origin if origin in settings.CORS_ORIGINS_LIST else "",
+                "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
+                "Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
+                "Access-Control-Allow-Credentials": "true",
+                "Access-Control-Max-Age": "86400"  # Cache preflight for 24 hours
+            }
+        )
 
-        async with httpx.AsyncClient(timeout=10.0) as client:
-            response = await client.get(
-                f"{settings.TRAINING_SERVICE_URL}/status/{training_job_id}",
-                headers={"Authorization": auth_header}
+    try:
+        url = f"{settings.TRAINING_SERVICE_URL}{target_path}"
+
+        # Forward headers AND add user context from gateway auth
+        headers = dict(request.headers)
+        headers.pop("host", None)  # Remove host header
+
+        # Add user context that the gateway auth middleware stored on request.state (token already verified there).
+        # Keys are lowercase, like those of dict(request.headers), so they overwrite client-sent copies instead of duplicating them.
+        if hasattr(request.state, 'user'):
+            headers["x-user-id"] = str(request.state.user.get("user_id"))
+            headers["x-user-email"] = request.state.user.get("email", "")
+            headers["x-tenant-id"] = str(request.state.user.get("tenant_id"))
+            headers["x-user-roles"] = ",".join(request.state.user.get("roles", []))
+            headers["x-user-permissions"] = ",".join(request.state.user.get("permissions", []))
+
+        # Get request body if present
+        body = None
+        request_method = method or request.method
+        if request_method in ["POST", "PUT", "PATCH"]:
+            body = await request.body()
+
+        async with httpx.AsyncClient(timeout=30.0) as client:
+            response = await client.request(
+                method=request_method,
+                url=url,
+                headers=headers,
+                content=body,
+                params=dict(request.query_params) if params is None else params
             )
 
             return JSONResponse(
@@ -37,35 +79,18 @@ async def get_training_status(training_job_id: str, request: Request):
             detail="Training service unavailable"
         )
     except Exception as e:
-        logger.error(f"Training status error: {e}")
+        logger.error(f"Training service error: {e}")
         raise HTTPException(status_code=500, detail="Internal server error")
 
+@router.get("/status/{training_job_id}")
+async def get_training_status(training_job_id: str, request: Request):
+    """Get training job status"""
+    return await _proxy_training_request(request, f"/training/status/{training_job_id}", "GET")
+
 @router.get("/models")
 async def get_trained_models(request: Request):
     """Get trained models"""
-    try:
-        auth_header = request.headers.get("Authorization")
-
-        async with httpx.AsyncClient(timeout=10.0) as client:
-            response = await client.get(
-                f"{settings.TRAINING_SERVICE_URL}/models",
-                headers={"Authorization": auth_header}
-            )
-
-            return JSONResponse(
-                status_code=response.status_code,
-                content=response.json()
-            )
-
-    except httpx.RequestError as e:
-        logger.error(f"Training service unavailable: {e}")
-        raise HTTPException(
-            status_code=503,
-            detail="Training service unavailable"
-        )
-    except Exception as e:
-        logger.error(f"Get models error: {e}")
-        raise HTTPException(status_code=500, detail="Internal server error")
+    return await _proxy_training_request(request, "/training/models", "GET")
 
 @router.get("/jobs")
 async def get_training_jobs(
@@ -74,59 +99,9 @@ async def get_training_jobs(
     offset: Optional[int] = Query(0, ge=0)
 ):
     """Get training jobs"""
-    try:
-        auth_header = request.headers.get("Authorization")
-
-        async with httpx.AsyncClient(timeout=10.0) as client:
-            response = await client.get(
-                f"{settings.TRAINING_SERVICE_URL}/jobs",
-                params={"limit": limit, "offset": offset},
-                headers={"Authorization": auth_header}
-            )
-
-            return JSONResponse(
-                status_code=response.status_code,
-                content=response.json()
-            )
-
-    except httpx.RequestError as e:
-        logger.error(f"Training service unavailable: {e}")
-        raise HTTPException(
-            status_code=503,
-            detail="Training service unavailable"
-        )
-    except Exception as e:
-        logger.error(f"Get training jobs error: {e}")
-        raise HTTPException(status_code=500, detail="Internal server error")
+    return await _proxy_training_request(request, "/training/jobs", "GET", params={"limit": limit, "offset": offset})
 
 @router.post("/jobs")
 async def start_training_job(request: Request):
     """Start a new training job - Proxy to training service"""
-    try:
-        body = await request.body()
-        auth_header = request.headers.get("Authorization")
-
-        async with httpx.AsyncClient(timeout=30.0) as client:
-            response = await client.post(
-                f"{settings.TRAINING_SERVICE_URL}/training/jobs",  # Correct path
-                content=body,
-                headers={
-                    "Content-Type": "application/json",
-                    "Authorization": auth_header
-                }
-            )
-
-            return JSONResponse(
-                status_code=response.status_code,
-                content=response.json()
-            )
-
-    except httpx.RequestError as e:
-        logger.error(f"Training service unavailable: {e}")
-        raise HTTPException(
-            status_code=503,
-            detail="Training service unavailable"
-        )
-    except Exception as e:
-        logger.error(f"Start training job error: {e}")
-        raise HTTPException(status_code=500, detail="Internal server error")
\ No newline at end of file
+    return await _proxy_training_request(request, "/training/jobs", "POST")
\ No newline at end of file
diff --git a/services/training/app/api/training.py b/services/training/app/api/training.py
index f0b89110..6a92779d 100644
--- a/services/training/app/api/training.py
+++ b/services/training/app/api/training.py
@@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
 from typing import List, Optional, Dict, Any
 from datetime import datetime
 import structlog
+import uuid
 
 from app.schemas.training import (
     TrainingJobRequest,
@@ -27,6 +28,9 @@ from app.services.messaging import (
     publish_product_training_completed
 )
 
+from sqlalchemy.ext.asyncio import AsyncSession
+from app.core.database import get_db
+
 # Import unified authentication from shared library
 from shared.auth.decorators import (
     get_current_user_dep,
@@ -43,46 +47,54 @@ async def start_training_job(
     background_tasks: BackgroundTasks,
     tenant_id: str = Depends(get_current_tenant_id_dep),
     current_user: Dict[str, Any] = Depends(get_current_user_dep),
-    training_service: TrainingService = Depends()
+    training_service: TrainingService = Depends(),
+    db: AsyncSession = Depends(get_db)  # Ensure db is available
 ):
     """Start a new training job for all products"""
     try:
-        logger.info("Starting training job",
-                    tenant_id=tenant_id,
-                    user_id=current_user["user_id"],
-                    config=request.dict())
+        new_job_id = str(uuid.uuid4())
+
+        logger.info("Starting training job",
+                    tenant_id=tenant_id,
+                    job_id=new_job_id,
+                    config=request.dict())
+
         # Create training job
         job = await training_service.create_training_job(
+            db,  # Pass db here
             tenant_id=tenant_id,
-            user_id=current_user["user_id"],
+            job_id=new_job_id,
             config=request.dict()
         )
-        
+
         # Publish job started event
         try:
             await publish_job_started(
-                job_id=job.job_id,
+                job_id=new_job_id,
                 tenant_id=tenant_id,
                 config=request.dict()
             )
         except Exception as e:
             logger.warning("Failed to publish job started event",
                            error=str(e))
-        
+
         # Start training in background
         background_tasks.add_task(
             training_service.execute_training_job,
-            job.job_id
+            db,  # NOTE(review): request-scoped session may be closed before the task runs - confirm
+            job.job_id,
+            job.tenant_id,
+            request  # Pass the request object
         )
-        
-        logger.info("Training job created", 
+
+        logger.info("Training job created",
                     job_id=job.job_id,
                     tenant_id=tenant_id)
-        
+
         return job
-        
+
     except Exception as e:
-        logger.error("Failed to start training job", 
+        logger.error("Failed to start training job",
                      error=str(e),
                      tenant_id=tenant_id)
         raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")