Fix training start

This commit is contained in:
Urtzi Alfaro
2025-07-25 17:20:39 +02:00
parent 86bf95eb89
commit ebb39aa8c9
2 changed files with 82 additions and 102 deletions

View File

@@ -1,8 +1,8 @@
""" """
Training routes for gateway Training routes for gateway - FIXED VERSION
""" """
from fastapi import APIRouter, Request, HTTPException, Query from fastapi import APIRouter, Request, HTTPException, Query, Response
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
import httpx import httpx
import logging import logging
@@ -13,16 +13,51 @@ from app.core.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@router.get("/status/{training_job_id}") async def _proxy_training_request(request: Request, target_path: str, method: str = None):
async def get_training_status(training_job_id: str, request: Request): """Proxy request to training service with user context"""
"""Get training job status"""
try:
auth_header = request.headers.get("Authorization")
async with httpx.AsyncClient(timeout=10.0) as client: # Handle OPTIONS requests directly for CORS
response = await client.get( if request.method == "OPTIONS":
f"{settings.TRAINING_SERVICE_URL}/status/{training_job_id}", return Response(
headers={"Authorization": auth_header} status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
}
)
try:
url = f"{settings.TRAINING_SERVICE_URL}{target_path}"
# Forward headers AND add user context from gateway auth
headers = dict(request.headers)
headers.pop("host", None) # Remove host header
# ✅ ADD USER CONTEXT FROM GATEWAY AUTHENTICATION
# Gateway middleware already verified the token and added user to request.state
if hasattr(request.state, 'user'):
headers["X-User-ID"] = str(request.state.user.get("user_id"))
headers["X-User-Email"] = request.state.user.get("email", "")
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
headers["X-User-Roles"] = ",".join(request.state.user.get("roles", []))
headers["X-User-Permissions"] = ",".join(request.state.user.get("permissions", []))
# Get request body if present
body = None
request_method = method or request.method
if request_method in ["POST", "PUT", "PATCH"]:
body = await request.body()
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.request(
method=request_method,
url=url,
headers=headers,
content=body,
params=dict(request.query_params)
) )
return JSONResponse( return JSONResponse(
@@ -37,35 +72,18 @@ async def get_training_status(training_job_id: str, request: Request):
detail="Training service unavailable" detail="Training service unavailable"
) )
except Exception as e: except Exception as e:
logger.error(f"Training status error: {e}") logger.error(f"Training service error: {e}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/status/{training_job_id}")
async def get_training_status(training_job_id: str, request: Request):
"""Get training job status"""
return await _proxy_training_request(request, f"/training/status/{training_job_id}", "GET")
@router.get("/models") @router.get("/models")
async def get_trained_models(request: Request): async def get_trained_models(request: Request):
"""Get trained models""" """Get trained models"""
try: return await _proxy_training_request(request, "/training/models", "GET")
auth_header = request.headers.get("Authorization")
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
f"{settings.TRAINING_SERVICE_URL}/models",
headers={"Authorization": auth_header}
)
return JSONResponse(
status_code=response.status_code,
content=response.json()
)
except httpx.RequestError as e:
logger.error(f"Training service unavailable: {e}")
raise HTTPException(
status_code=503,
detail="Training service unavailable"
)
except Exception as e:
logger.error(f"Get models error: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/jobs") @router.get("/jobs")
async def get_training_jobs( async def get_training_jobs(
@@ -74,59 +92,9 @@ async def get_training_jobs(
offset: Optional[int] = Query(0, ge=0) offset: Optional[int] = Query(0, ge=0)
): ):
"""Get training jobs""" """Get training jobs"""
try: return await _proxy_training_request(request, f"/training/jobs?limit={limit}&offset={offset}", "GET")
auth_header = request.headers.get("Authorization")
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
f"{settings.TRAINING_SERVICE_URL}/jobs",
params={"limit": limit, "offset": offset},
headers={"Authorization": auth_header}
)
return JSONResponse(
status_code=response.status_code,
content=response.json()
)
except httpx.RequestError as e:
logger.error(f"Training service unavailable: {e}")
raise HTTPException(
status_code=503,
detail="Training service unavailable"
)
except Exception as e:
logger.error(f"Get training jobs error: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/jobs") @router.post("/jobs")
async def start_training_job(request: Request): async def start_training_job(request: Request):
"""Start a new training job - Proxy to training service""" """Start a new training job - Proxy to training service"""
try: return await _proxy_training_request(request, "/training/jobs", "POST")
body = await request.body()
auth_header = request.headers.get("Authorization")
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{settings.TRAINING_SERVICE_URL}/training/jobs", # Correct path
content=body,
headers={
"Content-Type": "application/json",
"Authorization": auth_header
}
)
return JSONResponse(
status_code=response.status_code,
content=response.json()
)
except httpx.RequestError as e:
logger.error(f"Training service unavailable: {e}")
raise HTTPException(
status_code=503,
detail="Training service unavailable"
)
except Exception as e:
logger.error(f"Start training job error: {e}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
from datetime import datetime from datetime import datetime
import structlog import structlog
import uuid
from app.schemas.training import ( from app.schemas.training import (
TrainingJobRequest, TrainingJobRequest,
@@ -27,6 +28,9 @@ from app.services.messaging import (
publish_product_training_completed publish_product_training_completed
) )
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
# Import unified authentication from shared library # Import unified authentication from shared library
from shared.auth.decorators import ( from shared.auth.decorators import (
get_current_user_dep, get_current_user_dep,
@@ -43,26 +47,31 @@ async def start_training_job(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
tenant_id: str = Depends(get_current_tenant_id_dep), tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: Dict[str, Any] = Depends(get_current_user_dep), current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends() training_service: TrainingService = Depends(),
db: AsyncSession = Depends(get_db) # Ensure db is available
): ):
"""Start a new training job for all products""" """Start a new training job for all products"""
try: try:
new_job_id = str(uuid.uuid4())
logger.info("Starting training job", logger.info("Starting training job",
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=current_user["user_id"], job_id=uuid.uuid4(),
config=request.dict()) config=request.dict())
# Create training job # Create training job
job = await training_service.create_training_job( job = await training_service.create_training_job(
db, # Pass db here
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=current_user["user_id"], job_id=new_job_id,
config=request.dict() config=request.dict()
) )
# Publish job started event # Publish job started event
try: try:
await publish_job_started( await publish_job_started(
job_id=job.job_id, job_id=new_job_id,
tenant_id=tenant_id, tenant_id=tenant_id,
config=request.dict() config=request.dict()
) )
@@ -72,7 +81,10 @@ async def start_training_job(
# Start training in background # Start training in background
background_tasks.add_task( background_tasks.add_task(
training_service.execute_training_job, training_service.execute_training_job,
job.job_id db, # Pass the database session
job.job_id,
job.tenant_id,
request # Pass the request object
) )
logger.info("Training job created", logger.info("Training job created",