Fix training start
This commit is contained in:
@@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.schemas.training import (
|
||||
TrainingJobRequest,
|
||||
@@ -27,6 +28,9 @@ from app.services.messaging import (
|
||||
publish_product_training_completed
|
||||
)
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.core.database import get_db
|
||||
|
||||
# Import unified authentication from shared library
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
@@ -43,46 +47,54 @@ async def start_training_job(
|
||||
background_tasks: BackgroundTasks,
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends()
|
||||
training_service: TrainingService = Depends(),
|
||||
db: AsyncSession = Depends(get_db) # Ensure db is available
|
||||
):
|
||||
"""Start a new training job for all products"""
|
||||
try:
|
||||
logger.info("Starting training job",
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"],
|
||||
config=request.dict())
|
||||
|
||||
new_job_id = str(uuid.uuid4())
|
||||
|
||||
logger.info("Starting training job",
|
||||
tenant_id=tenant_id,
|
||||
job_id=uuid.uuid4(),
|
||||
config=request.dict())
|
||||
|
||||
# Create training job
|
||||
job = await training_service.create_training_job(
|
||||
db, # Pass db here
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"],
|
||||
job_id=new_job_id,
|
||||
config=request.dict()
|
||||
)
|
||||
|
||||
|
||||
# Publish job started event
|
||||
try:
|
||||
await publish_job_started(
|
||||
job_id=job.job_id,
|
||||
job_id=new_job_id,
|
||||
tenant_id=tenant_id,
|
||||
config=request.dict()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish job started event", error=str(e))
|
||||
|
||||
|
||||
# Start training in background
|
||||
background_tasks.add_task(
|
||||
training_service.execute_training_job,
|
||||
job.job_id
|
||||
db, # Pass the database session
|
||||
job.job_id,
|
||||
job.tenant_id,
|
||||
request # Pass the request object
|
||||
)
|
||||
|
||||
logger.info("Training job created",
|
||||
|
||||
logger.info("Training job created",
|
||||
job_id=job.job_id,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
|
||||
return job
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to start training job",
|
||||
logger.error("Failed to start training job",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user