Fix training start
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Training routes for gateway
|
||||
Training routes for gateway - FIXED VERSION
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, Query
|
||||
from fastapi import APIRouter, Request, HTTPException, Query, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import logging
|
||||
@@ -13,16 +13,51 @@ from app.core.config import settings
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/status/{training_job_id}")
|
||||
async def get_training_status(training_job_id: str, request: Request):
|
||||
"""Get training job status"""
|
||||
try:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
async def _proxy_training_request(request: Request, target_path: str, method: str = None):
|
||||
"""Proxy request to training service with user context"""
|
||||
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TRAINING_SERVICE_URL}/status/{training_job_id}",
|
||||
headers={"Authorization": auth_header}
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
url = f"{settings.TRAINING_SERVICE_URL}{target_path}"
|
||||
|
||||
# Forward headers AND add user context from gateway auth
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None) # Remove host header
|
||||
|
||||
# ✅ ADD USER CONTEXT FROM GATEWAY AUTHENTICATION
|
||||
# Gateway middleware already verified the token and added user to request.state
|
||||
if hasattr(request.state, 'user'):
|
||||
headers["X-User-ID"] = str(request.state.user.get("user_id"))
|
||||
headers["X-User-Email"] = request.state.user.get("email", "")
|
||||
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
|
||||
headers["X-User-Roles"] = ",".join(request.state.user.get("roles", []))
|
||||
headers["X-User-Permissions"] = ",".join(request.state.user.get("permissions", []))
|
||||
|
||||
# Get request body if present
|
||||
body = None
|
||||
request_method = method or request.method
|
||||
if request_method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.request(
|
||||
method=request_method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
params=dict(request.query_params)
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -37,35 +72,18 @@ async def get_training_status(training_job_id: str, request: Request):
|
||||
detail="Training service unavailable"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Training status error: {e}")
|
||||
logger.error(f"Training service error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/status/{training_job_id}")
|
||||
async def get_training_status(training_job_id: str, request: Request):
|
||||
"""Get training job status"""
|
||||
return await _proxy_training_request(request, f"/training/status/{training_job_id}", "GET")
|
||||
|
||||
@router.get("/models")
|
||||
async def get_trained_models(request: Request):
|
||||
"""Get trained models"""
|
||||
try:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TRAINING_SERVICE_URL}/models",
|
||||
headers={"Authorization": auth_header}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status_code,
|
||||
content=response.json()
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Training service unavailable: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Training service unavailable"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Get models error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
return await _proxy_training_request(request, "/training/models", "GET")
|
||||
|
||||
@router.get("/jobs")
|
||||
async def get_training_jobs(
|
||||
@@ -74,59 +92,9 @@ async def get_training_jobs(
|
||||
offset: Optional[int] = Query(0, ge=0)
|
||||
):
|
||||
"""Get training jobs"""
|
||||
try:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TRAINING_SERVICE_URL}/jobs",
|
||||
params={"limit": limit, "offset": offset},
|
||||
headers={"Authorization": auth_header}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status_code,
|
||||
content=response.json()
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Training service unavailable: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Training service unavailable"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Get training jobs error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
return await _proxy_training_request(request, f"/training/jobs?limit={limit}&offset={offset}", "GET")
|
||||
|
||||
@router.post("/jobs")
|
||||
async def start_training_job(request: Request):
|
||||
"""Start a new training job - Proxy to training service"""
|
||||
try:
|
||||
body = await request.body()
|
||||
auth_header = request.headers.get("Authorization")
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.TRAINING_SERVICE_URL}/training/jobs", # Correct path
|
||||
content=body,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": auth_header
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status_code,
|
||||
content=response.json()
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Training service unavailable: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Training service unavailable"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Start training job error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
return await _proxy_training_request(request, "/training/jobs", "POST")
|
||||
@@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.schemas.training import (
|
||||
TrainingJobRequest,
|
||||
@@ -27,6 +28,9 @@ from app.services.messaging import (
|
||||
publish_product_training_completed
|
||||
)
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.core.database import get_db
|
||||
|
||||
# Import unified authentication from shared library
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
@@ -43,26 +47,31 @@ async def start_training_job(
|
||||
background_tasks: BackgroundTasks,
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends()
|
||||
training_service: TrainingService = Depends(),
|
||||
db: AsyncSession = Depends(get_db) # Ensure db is available
|
||||
):
|
||||
"""Start a new training job for all products"""
|
||||
try:
|
||||
|
||||
new_job_id = str(uuid.uuid4())
|
||||
|
||||
logger.info("Starting training job",
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"],
|
||||
job_id=uuid.uuid4(),
|
||||
config=request.dict())
|
||||
|
||||
# Create training job
|
||||
job = await training_service.create_training_job(
|
||||
db, # Pass db here
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"],
|
||||
job_id=new_job_id,
|
||||
config=request.dict()
|
||||
)
|
||||
|
||||
# Publish job started event
|
||||
try:
|
||||
await publish_job_started(
|
||||
job_id=job.job_id,
|
||||
job_id=new_job_id,
|
||||
tenant_id=tenant_id,
|
||||
config=request.dict()
|
||||
)
|
||||
@@ -72,7 +81,10 @@ async def start_training_job(
|
||||
# Start training in background
|
||||
background_tasks.add_task(
|
||||
training_service.execute_training_job,
|
||||
job.job_id
|
||||
db, # Pass the database session
|
||||
job.job_id,
|
||||
job.tenant_id,
|
||||
request # Pass the request object
|
||||
)
|
||||
|
||||
logger.info("Training job created",
|
||||
|
||||
Reference in New Issue
Block a user