Files
bakery-ia/services/training/app/api/training.py

209 lines
7.1 KiB
Python
Raw Normal View History

2025-07-28 19:28:39 +02:00
# services/training/app/api/training.py
"""
Training API Endpoints - Entry point for training requests
Handles HTTP requests and delegates to Training Service
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from fastapi import Query, Path
from sqlalchemy.ext.asyncio import AsyncSession
2025-07-20 07:24:04 +02:00
from typing import List, Optional, Dict, Any
import structlog
2025-07-28 19:28:39 +02:00
from datetime import datetime
2025-07-28 19:28:39 +02:00
from app.core.database import get_db
from app.services.training_service import TrainingService
2025-07-19 16:59:37 +02:00
from app.schemas.training import (
2025-07-20 07:24:04 +02:00
TrainingJobRequest,
2025-07-28 19:28:39 +02:00
SingleProductTrainingRequest
2025-07-19 16:59:37 +02:00
)
2025-07-28 19:28:39 +02:00
from app.schemas.training import (
TrainingJobResponse
2025-07-20 07:24:04 +02:00
)
2025-07-28 19:28:39 +02:00
# Import shared auth decorators (assuming they exist in your microservices)
from shared.auth.decorators import get_current_tenant_id_dep
2025-07-20 07:24:04 +02:00
# Module-level structlog logger shared by the endpoints in this file.
logger = structlog.get_logger()
2025-07-28 19:28:39 +02:00
# Router on which the training and health endpoints in this file are registered.
router = APIRouter()
2025-07-28 19:28:39 +02:00
# Initialize a module-level training service (constructed without a DB session).
# start_single_product_training uses this shared instance, whereas
# start_training_job builds its own TrainingService(db_session=db) per request.
training_service = TrainingService()
2025-07-25 19:40:49 +02:00
2025-07-26 23:34:46 +02:00
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    # NOTE(review): instantiating BackgroundTasks() as a default is unusual for
    # FastAPI and the parameter is not used in this handler; kept as-is so the
    # signature stays unchanged.
    background_tasks: BackgroundTasks = BackgroundTasks(),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Start a new training job for all tenant products.

    This is the main entry point for the training pipeline:
    API -> Training Service -> Trainer -> Data Processor -> Prophet Manager

    Args:
        request: Job request; its ``start_date`` and ``end_date`` are forwarded
            to the service (falsy values are passed as ``None``).
        tenant_id: Tenant ID taken from the URL path.
        background_tasks: Currently unused.
        current_tenant: Tenant ID resolved by the auth dependency.
        db: Database session handed to a per-request ``TrainingService``.

    Returns:
        TrainingJobResponse built from the mapping returned by the service.

    Raises:
        HTTPException: 403 when ``tenant_id`` differs from ``current_tenant``,
            400 when the service raises ``ValueError``, 500 on any other error.
    """
    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        logger.info(f"Starting training job for tenant {tenant_id}")

        # Per-request service bound to this request's DB session.
        training_service = TrainingService(db_session=db)

        # Delegate to training service (Step 1 of the flow)
        result = await training_service.start_training_job(
            tenant_id=tenant_id,
            bakery_location=(40.4168, -3.7038),  # Default Madrid coordinates
            requested_start=request.start_date or None,
            requested_end=request.end_date or None,
            job_id=None  # Let the service generate it
        )

        return TrainingJobResponse(**result)

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 403 above) through unchanged;
        # otherwise the generic handler below would rewrite them as 500.
        raise
    except ValueError as e:
        logger.error(f"Training job validation error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e
    except Exception as e:
        logger.error(f"Training job failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Training job failed"
        ) from e
2025-07-19 16:59:37 +02:00
2025-07-26 23:34:46 +02:00
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
async def start_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: str = Path(..., description="Product name"),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Start training for a single product.

    Uses the same pipeline but filters for specific product.

    Args:
        request: Supplies ``sales_data``, ``bakery_location``, ``weather_data``,
            ``traffic_data`` and ``job_id`` for the service call.
        tenant_id: Tenant ID taken from the URL path.
        product_name: Product name taken from the URL path.
        current_tenant: Tenant ID resolved by the auth dependency.
        db: Database session; not used by this handler at present.

    Returns:
        TrainingJobResponse built from the mapping returned by the service.

    Raises:
        HTTPException: 403 when ``tenant_id`` differs from ``current_tenant``,
            400 when the service raises ``ValueError``, 500 on any other error.
    """
    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        logger.info(f"Starting single product training for {product_name} (tenant {tenant_id})")

        # NOTE(review): this uses the module-level training_service, not one
        # bound to ``db`` as start_training_job does - confirm that is intended.
        # Delegate to training service
        result = await training_service.start_single_product_training(
            tenant_id=tenant_id,
            product_name=product_name,
            sales_data=request.sales_data,
            # Fall back to the default Madrid coordinates when none are given.
            bakery_location=request.bakery_location or (40.4168, -3.7038),
            weather_data=request.weather_data,
            traffic_data=request.traffic_data,
            job_id=request.job_id
        )

        return TrainingJobResponse(**result)

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 403 above) through unchanged;
        # otherwise the generic handler below would rewrite them as 500.
        raise
    except ValueError as e:
        logger.error(f"Single product training validation error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e
    except Exception as e:
        logger.error(f"Single product training failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Single product training failed"
        ) from e
2025-07-19 16:59:37 +02:00
2025-07-28 19:28:39 +02:00
@router.post("/tenants/{tenant_id}/training/jobs/{job_id}/cancel")
async def cancel_training_job(
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Cancel a running training job.

    Placeholder: cancellation is not implemented yet, so after the tenant
    check this handler only logs the request and returns a success message.

    Args:
        tenant_id: Tenant ID taken from the URL path.
        job_id: Job ID taken from the URL path.
        current_tenant: Tenant ID resolved by the auth dependency.
        db: Database session; not used by this handler at present.

    Returns:
        A dict with a single ``"message"`` key.

    Raises:
        HTTPException: 403 when ``tenant_id`` differs from ``current_tenant``,
            500 on any unexpected error.
    """
    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # TODO: Implement job cancellation
        logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")

        return {"message": "Training job cancelled successfully"}

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 403 above) through unchanged;
        # otherwise the generic handler below would rewrite them as 500.
        raise
    except Exception as e:
        logger.error(f"Failed to cancel training job: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to cancel training job"
        ) from e
2025-07-19 16:59:37 +02:00
2025-07-28 19:28:39 +02:00
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/logs")
async def get_training_logs(
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    limit: int = Query(100, description="Number of log entries to return"),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get training job logs.

    Placeholder: log retrieval is not implemented yet, so after the tenant
    check this handler returns a fixed list of four messages.

    Args:
        tenant_id: Tenant ID taken from the URL path.
        job_id: Job ID taken from the URL path; echoed in the response.
        limit: Accepted for API compatibility but not applied yet.
        current_tenant: Tenant ID resolved by the auth dependency.
        db: Database session; not used by this handler at present.

    Returns:
        A dict with ``"job_id"`` and ``"logs"`` keys.

    Raises:
        HTTPException: 403 when ``tenant_id`` differs from ``current_tenant``,
            500 on any unexpected error.
    """
    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # TODO: Implement log retrieval
        return {
            "job_id": job_id,
            "logs": [
                f"Training job {job_id} started",
                "Data preprocessing completed",
                "Model training completed",
                "Training job finished successfully"
            ]
        }

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 403 above) through unchanged;
        # otherwise the generic handler below would rewrite them as 500.
        raise
    except Exception as e:
        logger.error(f"Failed to get training logs: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get training logs"
        ) from e
2025-07-27 22:58:18 +02:00
2025-07-28 19:28:39 +02:00
@router.get("/health")
async def health_check():
    """
    Report the training service's health status.

    Returns a dict with the fixed ``status``, ``service`` and ``version``
    values plus the current local time as an ISO-8601 ``timestamp``.
    """
    checked_at = datetime.now().isoformat()
    payload = dict(status="healthy", service="training", version="1.0.0")
    payload["timestamp"] = checked_at
    return payload