2025-07-17 13:09:24 +02:00
|
|
|
"""
|
|
|
|
|
Training API endpoints
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
from typing import List, Optional
|
2025-07-18 14:41:39 +02:00
|
|
|
import structlog
|
2025-07-17 13:09:24 +02:00
|
|
|
|
|
|
|
|
from app.core.database import get_db
|
|
|
|
|
from app.core.auth import verify_token
|
|
|
|
|
from app.schemas.training import TrainingRequest, TrainingJobResponse, TrainedModelResponse
|
|
|
|
|
from app.services.training_service import TrainingService
|
|
|
|
|
|
2025-07-18 14:41:39 +02:00
|
|
|
# Router that the endpoints in this module register themselves on.
router = APIRouter()

# Module logger used by the endpoint error handlers.
logger = structlog.get_logger()

# A single service instance, created at import time and shared by every
# endpoint in this module.
training_service = TrainingService()
|
|
|
|
|
|
|
|
|
|
@router.post("/train", response_model=TrainingJobResponse)
async def start_training(
    request: TrainingRequest,
    user_data: dict = Depends(verify_token),
    db: AsyncSession = Depends(get_db),
):
    """Start a training job.

    Delegates to ``training_service.start_training`` with the request body,
    the token payload produced by ``verify_token``, and the DB session.

    Raises:
        HTTPException: re-raised unchanged if the service raises one;
            400 when the service raises ``ValueError`` (its message becomes
            the detail); 500 for any other exception.
    """
    try:
        return await training_service.start_training(request, user_data, db)
    except HTTPException:
        # An HTTPException from the service already carries the intended
        # status code. Without this clause, the generic handler below would
        # turn it into a 500.
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        # .exception() logs the traceback as well as the message.
        logger.exception(f"Training start error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to start training",
        ) from e
|
|
|
|
|
|
|
|
|
|
@router.get("/status/{job_id}", response_model=TrainingJobResponse)
async def get_training_status(
    job_id: str,
    user_data: dict = Depends(verify_token),
    db: AsyncSession = Depends(get_db),
):
    """Return the status of a single training job.

    Delegates to ``training_service.get_training_status`` with the path
    parameter ``job_id``, the token payload from ``verify_token``, and the
    DB session.

    Raises:
        HTTPException: re-raised unchanged if the service raises one;
            404 when the service raises ``ValueError`` (its message becomes
            the detail); 500 for any other exception.
    """
    try:
        return await training_service.get_training_status(job_id, user_data, db)
    except HTTPException:
        # Preserve deliberate HTTP errors from the service. Without this
        # clause, the generic handler below would turn them into a 500.
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    except Exception as e:
        # .exception() logs the traceback as well as the message.
        logger.exception(f"Get training status error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get training status",
        ) from e
|
|
|
|
|
|
|
|
|
|
@router.get("/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
    limit: int = Query(10, ge=1, le=100),
    offset: int = Query(0, ge=0),
    user_data: dict = Depends(verify_token),
    db: AsyncSession = Depends(get_db),
):
    """Return a page of training jobs.

    Args:
        limit: page size, 1-100 (default 10), validated by FastAPI.
        offset: number of items to skip, >= 0 (default 0).

    Delegates to ``training_service.get_training_jobs`` with the token
    payload from ``verify_token``, the paging values, and the DB session.

    Raises:
        HTTPException: re-raised unchanged if the service raises one;
            500 for any other exception.
    """
    try:
        return await training_service.get_training_jobs(user_data, limit, offset, db)
    except HTTPException:
        # Preserve deliberate HTTP errors from the service. Without this
        # clause, the generic handler below would turn them into a 500.
        raise
    except Exception as e:
        # .exception() logs the traceback as well as the message.
        logger.exception(f"Get training jobs error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get training jobs",
        ) from e
|