Files
bakery-ia/services/training/app/api/training.py
Urtzi Alfaro 4073222888 Fix imports
2025-07-18 14:41:39 +02:00

77 lines
2.5 KiB
Python

"""
Training API endpoints
"""
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
import structlog
from app.core.database import get_db
from app.core.auth import verify_token
from app.schemas.training import TrainingRequest, TrainingJobResponse, TrainedModelResponse
from app.services.training_service import TrainingService
# Module-level structlog logger used by the endpoint error handlers below.
logger = structlog.get_logger()
# Router on which every endpoint in this module is registered.
router = APIRouter()
# Single TrainingService instance, created at import time and shared by all
# handlers in this module.
training_service = TrainingService()
@router.post("/train", response_model=TrainingJobResponse)
async def start_training(
    request: TrainingRequest,
    user_data: dict = Depends(verify_token),
    db: AsyncSession = Depends(get_db),
):
    """Start a training job.

    Delegates to ``training_service.start_training`` and translates its
    failures into HTTP errors.

    Args:
        request: Validated training request body.
        user_data: Value produced by the ``verify_token`` dependency.
        db: Async database session produced by the ``get_db`` dependency.

    Returns:
        Whatever ``training_service.start_training`` returns, serialized
        through ``TrainingJobResponse``.

    Raises:
        HTTPException: 400 when the service raises ``ValueError``; 500 for
            any other unexpected error. An ``HTTPException`` raised by the
            service itself is propagated unchanged.
    """
    try:
        return await training_service.start_training(request, user_data, db)
    except HTTPException:
        # HTTPException is an Exception subclass; without this passthrough a
        # deliberate HTTP error from the service layer would be rewritten as
        # a generic 500 by the catch-all below.
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Training start error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to start training",
        ) from e
@router.get("/status/{job_id}", response_model=TrainingJobResponse)
async def get_training_status(
    job_id: str,
    user_data: dict = Depends(verify_token),
    db: AsyncSession = Depends(get_db),
):
    """Get the status of a training job.

    Delegates to ``training_service.get_training_status`` and translates its
    failures into HTTP errors.

    Args:
        job_id: Job identifier taken from the URL path.
        user_data: Value produced by the ``verify_token`` dependency.
        db: Async database session produced by the ``get_db`` dependency.

    Returns:
        Whatever ``training_service.get_training_status`` returns, serialized
        through ``TrainingJobResponse``.

    Raises:
        HTTPException: 404 when the service raises ``ValueError``; 500 for
            any other unexpected error. An ``HTTPException`` raised by the
            service itself is propagated unchanged.
    """
    try:
        return await training_service.get_training_status(job_id, user_data, db)
    except HTTPException:
        # HTTPException is an Exception subclass; without this passthrough a
        # deliberate HTTP error from the service layer would be rewritten as
        # a generic 500 by the catch-all below.
        raise
    except ValueError as e:
        # Every ValueError from the service is reported as "not found".
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Get training status error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get training status",
        ) from e
@router.get("/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
    limit: int = Query(10, ge=1, le=100),
    offset: int = Query(0, ge=0),
    user_data: dict = Depends(verify_token),
    db: AsyncSession = Depends(get_db),
):
    """List training jobs with pagination.

    Delegates to ``training_service.get_training_jobs`` and translates its
    failures into HTTP errors.

    Args:
        limit: Page size query parameter; 1-100, defaults to 10.
        offset: Query parameter giving the number of items to skip; must be
            >= 0, defaults to 0.
        user_data: Value produced by the ``verify_token`` dependency.
        db: Async database session produced by the ``get_db`` dependency.

    Returns:
        Whatever ``training_service.get_training_jobs`` returns, serialized
        as a list of ``TrainingJobResponse``.

    Raises:
        HTTPException: 500 for any unexpected error. An ``HTTPException``
            raised by the service itself is propagated unchanged.
    """
    try:
        # Note the service's argument order: user_data precedes limit/offset.
        return await training_service.get_training_jobs(user_data, limit, offset, db)
    except HTTPException:
        # HTTPException is an Exception subclass; without this passthrough a
        # deliberate HTTP error from the service layer would be rewritten as
        # a generic 500 by the catch-all below.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Get training jobs error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get training jobs",
        ) from e