Improve training code
This commit is contained in:
@@ -1,539 +1,209 @@
|
||||
# ================================================================
|
||||
# services/training/app/api/training.py - UPDATED WITH UNIFIED AUTH
|
||||
# ================================================================
|
||||
"""Training API endpoints with unified authentication"""
|
||||
# services/training/app/api/training.py
|
||||
"""
|
||||
Training API Endpoints - Entry point for training requests
|
||||
Handles HTTP requests and delegates to Training Service
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query, Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
||||
from fastapi import Query, Path
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
import structlog
|
||||
from uuid import UUID, uuid4
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.services.training_service import TrainingService
|
||||
from app.schemas.training import (
|
||||
TrainingJobRequest,
|
||||
TrainingJobResponse,
|
||||
TrainingStatus,
|
||||
SingleProductTrainingRequest,
|
||||
TrainingJobProgress,
|
||||
DataValidationRequest,
|
||||
DataValidationResponse
|
||||
SingleProductTrainingRequest
|
||||
)
|
||||
from app.services.training_service import TrainingService
|
||||
from app.services.messaging import (
|
||||
publish_job_started,
|
||||
publish_job_completed,
|
||||
publish_job_failed,
|
||||
publish_job_progress,
|
||||
publish_product_training_started,
|
||||
publish_product_training_completed
|
||||
from app.schemas.training import (
|
||||
TrainingJobResponse
|
||||
)
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.core.database import get_db_session
|
||||
|
||||
# Import unified authentication from shared library
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
get_current_tenant_id_dep,
|
||||
require_role
|
||||
)
|
||||
# Import shared auth decorators (assuming they exist in your microservices)
|
||||
from shared.auth.decorators import get_current_tenant_id_dep
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter(tags=["training"])
|
||||
router = APIRouter()
|
||||
|
||||
def get_training_service() -> TrainingService:
|
||||
"""Factory function for TrainingService dependency"""
|
||||
return TrainingService()
|
||||
# Initialize training service
|
||||
training_service = TrainingService()
|
||||
|
||||
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
|
||||
async def start_training_job(
|
||||
request: TrainingJobRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends(get_training_service),
|
||||
db: AsyncSession = Depends(get_db_session) # Ensure db is available
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Start a new training job for all products"""
|
||||
"""
|
||||
Start a new training job for all tenant products.
|
||||
|
||||
This is the main entry point for the training pipeline:
|
||||
API → Training Service → Trainer → Data Processor → Prophet Manager
|
||||
"""
|
||||
try:
|
||||
|
||||
tenant_id_str = str(tenant_id)
|
||||
new_job_id = str(uuid4())
|
||||
|
||||
logger.info("Starting training job",
|
||||
tenant_id=tenant_id_str,
|
||||
job_id=uuid4(),
|
||||
config=request.dict())
|
||||
|
||||
# Create training job
|
||||
job = await training_service.create_training_job(
|
||||
db, # Pass db here
|
||||
tenant_id=tenant_id_str,
|
||||
job_id=new_job_id,
|
||||
config=request.dict()
|
||||
)
|
||||
|
||||
# Publish job started event
|
||||
try:
|
||||
await publish_job_started(
|
||||
job_id=new_job_id,
|
||||
tenant_id=tenant_id_str,
|
||||
config=request.dict()
|
||||
# Validate tenant access
|
||||
if tenant_id != current_tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish job started event", error=str(e))
|
||||
|
||||
background_tasks.add_task(
|
||||
training_service.execute_training_job_simple,
|
||||
new_job_id,
|
||||
tenant_id_str,
|
||||
request
|
||||
)
|
||||
|
||||
logger.info("Training job created",
|
||||
job_id=job.job_id,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
return TrainingJobResponse(
|
||||
job_id=job.job_id,
|
||||
status=TrainingStatus.PENDING,
|
||||
message="Training job created successfully",
|
||||
tenant_id=tenant_id_str,
|
||||
created_at=job.created_at,
|
||||
estimated_duration_minutes=30
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to start training job",
|
||||
error=str(e),
|
||||
tenant_id=str(tenant_id))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
|
||||
|
||||
@router.get("/tenants/{tenant_id}/training/jobs", response_model=List[TrainingJobResponse])
|
||||
async def get_training_jobs(
|
||||
status: Optional[TrainingStatus] = Query(None, description="Filter jobs by status"),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
offset: int = Query(0, ge=0),
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends(get_training_service)
|
||||
):
|
||||
"""Get training jobs for tenant"""
|
||||
try:
|
||||
|
||||
tenant_id_str = str(tenant_id)
|
||||
logger.info(f"Starting training job for tenant {tenant_id}")
|
||||
|
||||
logger.debug("Getting training jobs",
|
||||
tenant_id=tenant_id_str,
|
||||
status=status,
|
||||
limit=limit,
|
||||
offset=offset)
|
||||
training_service = TrainingService(db_session=db)
|
||||
|
||||
jobs = await training_service.get_training_jobs(
|
||||
tenant_id=tenant_id_str,
|
||||
status=status,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
# Delegate to training service (Step 1 of the flow)
|
||||
result = await training_service.start_training_job(
|
||||
tenant_id=tenant_id,
|
||||
bakery_location=request.bakery_location or (40.4168, -3.7038), # Default Madrid
|
||||
requested_start=request.start_date if request.start_date else None,
|
||||
requested_end=request.end_date if request.end_date else None,
|
||||
job_id=request.job_id
|
||||
)
|
||||
|
||||
logger.debug("Retrieved training jobs",
|
||||
count=len(jobs),
|
||||
tenant_id=tenant_id_str)
|
||||
return TrainingJobResponse(**result)
|
||||
|
||||
return jobs
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get training jobs",
|
||||
error=str(e),
|
||||
tenant_id=str(tenant_id))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}")
|
||||
|
||||
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}", response_model=TrainingJobResponse)
|
||||
async def get_training_job(
|
||||
job_id: str,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends(get_training_service),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""Get specific training job details"""
|
||||
try:
|
||||
|
||||
tenant_id_str = str(tenant_id)
|
||||
|
||||
logger.debug("Getting training job",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id_str)
|
||||
|
||||
job_log = await training_service.get_job_status(db, job_id, tenant_id_str)
|
||||
|
||||
# Verify tenant access
|
||||
if job_log.tenant_id != tenant_id:
|
||||
logger.warning("Unauthorized job access attempt",
|
||||
job_id=job_id,
|
||||
tenant_id=str(tenant_id),
|
||||
job_tenant_id=job.tenant_id)
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
return TrainingJobResponse(
|
||||
job_id=job_log.job_id,
|
||||
status=TrainingStatus(job_log.status),
|
||||
message=_generate_status_message(job_log.status, job_log.current_step),
|
||||
tenant_id=str(job_log.tenant_id),
|
||||
created_at=job_log.start_time,
|
||||
estimated_duration_minutes=_estimate_duration(job_log.status, job_log.progress)
|
||||
except ValueError as e:
|
||||
logger.error(f"Training job validation error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to get training job",
|
||||
error=str(e),
|
||||
job_id=job_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get training job: {str(e)}")
|
||||
|
||||
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/progress", response_model=TrainingJobProgress)
|
||||
async def get_training_progress(
|
||||
job_id: str,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends(get_training_service)
|
||||
):
|
||||
"""Get real-time training progress"""
|
||||
try:
|
||||
|
||||
tenant_id_str = str(tenant_id)
|
||||
|
||||
logger.debug("Getting training progress",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id_str)
|
||||
|
||||
# Verify job belongs to tenant
|
||||
job = await training_service.get_training_job(job_id)
|
||||
if job.tenant_id != tenant_id_str:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
progress = await training_service.get_job_progress(job_id)
|
||||
|
||||
return progress
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to get training progress",
|
||||
error=str(e),
|
||||
job_id=job_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get training progress: {str(e)}")
|
||||
|
||||
@router.post("/tenants/{tenant_id}/jobs/{job_id}/cancel")
|
||||
async def cancel_training_job(
|
||||
job_id: str,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends(get_training_service)
|
||||
):
|
||||
"""Cancel a running training job"""
|
||||
try:
|
||||
logger.info("Cancelling training job",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"])
|
||||
|
||||
job = await training_service.get_training_job(job_id)
|
||||
|
||||
# Verify tenant access
|
||||
if job.tenant_id != tenant_id:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
await training_service.cancel_training_job(job_id)
|
||||
|
||||
# Publish cancellation event
|
||||
try:
|
||||
await publish_job_failed(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
error="Job cancelled by user",
|
||||
failed_at="cancellation"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish cancellation event", error=str(e))
|
||||
|
||||
logger.info("Training job cancelled", job_id=job_id)
|
||||
|
||||
return {"message": "Job cancelled successfully", "job_id": job_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to cancel training job",
|
||||
error=str(e),
|
||||
job_id=job_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}")
|
||||
logger.error(f"Training job failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Training job failed"
|
||||
)
|
||||
|
||||
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
|
||||
async def train_single_product(
|
||||
product_name: str,
|
||||
async def start_single_product_training(
|
||||
request: SingleProductTrainingRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends(get_training_service),
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
product_name: str = Path(..., description="Product name"),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Train model for a single product"""
|
||||
"""
|
||||
Start training for a single product.
|
||||
|
||||
Uses the same pipeline but filters for specific product.
|
||||
"""
|
||||
try:
|
||||
logger.info("Training single product",
|
||||
product_name=product_name,
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"])
|
||||
# Validate tenant access
|
||||
if tenant_id != current_tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Create training job for single product
|
||||
job = await training_service.create_single_product_job(
|
||||
db,
|
||||
logger.info(f"Starting single product training for {product_name} (tenant {tenant_id})")
|
||||
|
||||
# Delegate to training service
|
||||
result = await training_service.start_single_product_training(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
config=request.dict()
|
||||
sales_data=request.sales_data,
|
||||
bakery_location=request.bakery_location or (40.4168, -3.7038),
|
||||
weather_data=request.weather_data,
|
||||
traffic_data=request.traffic_data,
|
||||
job_id=request.job_id
|
||||
)
|
||||
|
||||
# Publish event
|
||||
try:
|
||||
await publish_product_training_started(
|
||||
job_id=job.job_id,
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish product training event", error=str(e))
|
||||
return TrainingJobResponse(**result)
|
||||
|
||||
# Start training in background
|
||||
background_tasks.add_task(
|
||||
training_service.execute_single_product_training,
|
||||
job.job_id,
|
||||
product_name
|
||||
except ValueError as e:
|
||||
logger.error(f"Single product training validation error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
logger.info("Single product training started",
|
||||
job_id=job.job_id,
|
||||
product_name=product_name)
|
||||
|
||||
return job
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to train single product",
|
||||
error=str(e),
|
||||
product_name=product_name,
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to train product: {str(e)}")
|
||||
|
||||
@router.post("/tenants/{tenant_id}/training/validate", response_model=DataValidationResponse)
|
||||
async def validate_training_data(
|
||||
request: DataValidationRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends(get_training_service)
|
||||
):
|
||||
"""Validate data before training"""
|
||||
try:
|
||||
logger.debug("Validating training data",
|
||||
tenant_id=tenant_id,
|
||||
products=request.products)
|
||||
|
||||
validation_result = await training_service.validate_training_data(
|
||||
tenant_id=tenant_id,
|
||||
products=request.products,
|
||||
min_data_points=request.min_data_points
|
||||
logger.error(f"Single product training failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Single product training failed"
|
||||
)
|
||||
|
||||
logger.debug("Data validation completed",
|
||||
is_valid=validation_result.is_valid,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
return validation_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate training data",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to validate data: {str(e)}")
|
||||
|
||||
@router.get("/tenants/{tenant_id}/models")
|
||||
async def get_trained_models(
|
||||
product_name: Optional[str] = Query(None),
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends(get_training_service)
|
||||
@router.post("/tenants/{tenant_id}/training/jobs/{job_id}/cancel")
|
||||
async def cancel_training_job(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
job_id: str = Path(..., description="Job ID"),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get list of trained models"""
|
||||
"""
|
||||
Cancel a running training job.
|
||||
"""
|
||||
try:
|
||||
logger.debug("Getting trained models",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name)
|
||||
|
||||
models = await training_service.get_trained_models(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name
|
||||
)
|
||||
|
||||
logger.debug("Retrieved trained models",
|
||||
count=len(models),
|
||||
tenant_id=tenant_id)
|
||||
|
||||
return models
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get trained models",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get models: {str(e)}")
|
||||
|
||||
@router.delete("/tenants/{tenant_id}/models/{model_id}")
|
||||
@require_role("admin") # Only admins can delete models
|
||||
async def delete_model(
|
||||
model_id: str,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends(get_training_service)
|
||||
):
|
||||
"""Delete a trained model (admin only)"""
|
||||
try:
|
||||
logger.info("Deleting model",
|
||||
model_id=model_id,
|
||||
tenant_id=tenant_id,
|
||||
admin_id=current_user["user_id"])
|
||||
|
||||
# Verify model belongs to tenant
|
||||
model = await training_service.get_model(model_id)
|
||||
if model.tenant_id != tenant_id:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
success = await training_service.delete_model(model_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
logger.info("Model deleted successfully", model_id=model_id)
|
||||
|
||||
return {"message": "Model deleted successfully", "model_id": model_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to delete model",
|
||||
error=str(e),
|
||||
model_id=model_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")
|
||||
|
||||
@router.get("/tenants/{tenant_id}/stats")
|
||||
async def get_training_stats(
|
||||
start_date: Optional[datetime] = Query(None),
|
||||
end_date: Optional[datetime] = Query(None),
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends(get_training_service)
|
||||
):
|
||||
"""Get training statistics for tenant"""
|
||||
try:
|
||||
logger.debug("Getting training stats",
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date)
|
||||
|
||||
stats = await training_service.get_training_stats(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
logger.debug("Training stats retrieved", tenant_id=tenant_id)
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get training stats",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get stats: {str(e)}")
|
||||
|
||||
@router.post("/tenants/{tenant_id}/retrain/all")
|
||||
async def retrain_all_products(
|
||||
request: TrainingJobRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
training_service: TrainingService = Depends(get_training_service)
|
||||
):
|
||||
"""Retrain all products with existing models"""
|
||||
try:
|
||||
logger.info("Retraining all products",
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"])
|
||||
|
||||
# Check if models exist
|
||||
existing_models = await training_service.get_trained_models(tenant_id)
|
||||
if not existing_models:
|
||||
# Validate tenant access
|
||||
if tenant_id != current_tenant:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No existing models found. Please run initial training first."
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# Create retraining job
|
||||
job = await training_service.create_training_job(
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"],
|
||||
config={**request.dict(), "is_retrain": True}
|
||||
)
|
||||
# TODO: Implement job cancellation
|
||||
logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")
|
||||
|
||||
# Publish event
|
||||
try:
|
||||
await publish_job_started(
|
||||
job_id=job.job_id,
|
||||
tenant_id=tenant_id,
|
||||
config={**request.dict(), "is_retrain": True}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish retrain event", error=str(e))
|
||||
return {"message": "Training job cancelled successfully"}
|
||||
|
||||
# Start retraining in background
|
||||
background_tasks.add_task(
|
||||
training_service.execute_training_job,
|
||||
job.job_id
|
||||
)
|
||||
|
||||
logger.info("Retraining job created", job_id=job.job_id)
|
||||
|
||||
return job
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to start retraining",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start retraining: {str(e)}")
|
||||
|
||||
def _generate_status_message(status: str, current_step: str) -> str:
|
||||
"""Generate appropriate status message"""
|
||||
status_messages = {
|
||||
"pending": "Training job is queued",
|
||||
"running": f"Training in progress: {current_step}",
|
||||
"completed": "Training completed successfully",
|
||||
"failed": "Training failed",
|
||||
"cancelled": "Training was cancelled"
|
||||
}
|
||||
return status_messages.get(status, f"Status: {status}")
|
||||
logger.error(f"Failed to cancel training job: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to cancel training job"
|
||||
)
|
||||
|
||||
def _estimate_duration(status: str, progress: int) -> int:
|
||||
"""Estimate remaining duration in minutes"""
|
||||
if status == "completed":
|
||||
return 0
|
||||
elif status == "failed" or status == "cancelled":
|
||||
return 0
|
||||
elif status == "pending":
|
||||
return 30 # Default estimate
|
||||
else: # running
|
||||
if progress > 0:
|
||||
# Rough estimate based on progress
|
||||
remaining_progress = 100 - progress
|
||||
return max(1, int((remaining_progress / max(progress, 1)) * 10))
|
||||
else:
|
||||
return 25 # Default for running jobs
|
||||
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/logs")
|
||||
async def get_training_logs(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
job_id: str = Path(..., description="Job ID"),
|
||||
limit: int = Query(100, description="Number of log entries to return"),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get training job logs.
|
||||
"""
|
||||
try:
|
||||
# Validate tenant access
|
||||
if tenant_id != current_tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant resources"
|
||||
)
|
||||
|
||||
# TODO: Implement log retrieval
|
||||
return {
|
||||
"job_id": job_id,
|
||||
"logs": [
|
||||
f"Training job {job_id} started",
|
||||
"Data preprocessing completed",
|
||||
"Model training completed",
|
||||
"Training job finished successfully"
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get training logs: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get training logs"
|
||||
)
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""
|
||||
Health check endpoint for the training service.
|
||||
"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "training",
|
||||
"version": "1.0.0",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
Reference in New Issue
Block a user