Fix training start

This commit is contained in:
Urtzi Alfaro
2025-07-25 17:20:39 +02:00
parent 86bf95eb89
commit ebb39aa8c9
2 changed files with 82 additions and 102 deletions

View File

@@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
from typing import List, Optional, Dict, Any
from datetime import datetime
import structlog
import uuid
from app.schemas.training import (
TrainingJobRequest,
@@ -27,6 +28,9 @@ from app.services.messaging import (
publish_product_training_completed
)
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
# Import unified authentication from shared library
from shared.auth.decorators import (
get_current_user_dep,
@@ -43,46 +47,54 @@ async def start_training_job(
background_tasks: BackgroundTasks,
tenant_id: str = Depends(get_current_tenant_id_dep),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends()
training_service: TrainingService = Depends(),
db: AsyncSession = Depends(get_db) # Ensure db is available
):
"""Start a new training job for all products"""
try:
logger.info("Starting training job",
tenant_id=tenant_id,
user_id=current_user["user_id"],
config=request.dict())
new_job_id = str(uuid.uuid4())
logger.info("Starting training job",
tenant_id=tenant_id,
job_id=uuid.uuid4(),
config=request.dict())
# Create training job
job = await training_service.create_training_job(
db, # Pass db here
tenant_id=tenant_id,
user_id=current_user["user_id"],
job_id=new_job_id,
config=request.dict()
)
# Publish job started event
try:
await publish_job_started(
job_id=job.job_id,
job_id=new_job_id,
tenant_id=tenant_id,
config=request.dict()
)
except Exception as e:
logger.warning("Failed to publish job started event", error=str(e))
# Start training in background
background_tasks.add_task(
training_service.execute_training_job,
job.job_id
db, # Pass the database session
job.job_id,
job.tenant_id,
request # Pass the request object
)
logger.info("Training job created",
logger.info("Training job created",
job_id=job.job_id,
tenant_id=tenant_id)
return job
except Exception as e:
logger.error("Failed to start training job",
logger.error("Failed to start training job",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")