Checking onboardin flow - fix 2

This commit is contained in:
Urtzi Alfaro
2025-07-27 10:30:42 +02:00
parent cb3ae4d78b
commit 30ac945058
3 changed files with 78 additions and 18 deletions

View File

@@ -57,17 +57,18 @@ async def start_training_job(
"""Start a new training job for all products"""
try:
tenant_id_str = str(tenant_id)
new_job_id = str(uuid4())
logger.info("Starting training job",
tenant_id=tenant_id,
tenant_id=tenant_id_str,
job_id=uuid4(),
config=request.dict())
# Create training job
job = await training_service.create_training_job(
db, # Pass db here
tenant_id=tenant_id,
tenant_id=tenant_id_str,
job_id=new_job_id,
config=request.dict()
)
@@ -76,7 +77,7 @@ async def start_training_job(
try:
await publish_job_started(
job_id=new_job_id,
tenant_id=tenant_id,
tenant_id=tenant_id_str,
config=request.dict()
)
except Exception as e:
@@ -99,7 +100,7 @@ async def start_training_job(
job_id=job.job_id,
status=TrainingStatus.PENDING,
message="Training job created successfully",
tenant_id=tenant_id,
tenant_id=tenant_id_str,
created_at=job.created_at,
estimated_duration_minutes=30
)
@@ -107,7 +108,7 @@ async def start_training_job(
except Exception as e:
logger.error("Failed to start training job",
error=str(e),
tenant_id=tenant_id)
tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
@router.get("/tenants/{tenant_id}/training/jobs", response_model=List[TrainingJobResponse])
@@ -121,14 +122,17 @@ async def get_training_jobs(
):
"""Get training jobs for tenant"""
try:
tenant_id_str = str(tenant_id)
logger.debug("Getting training jobs",
tenant_id=tenant_id,
tenant_id=tenant_id_str,
status=status,
limit=limit,
offset=offset)
jobs = await training_service.get_training_jobs(
tenant_id=tenant_id,
tenant_id=tenant_id_str,
status=status,
limit=limit,
offset=offset
@@ -136,14 +140,14 @@ async def get_training_jobs(
logger.debug("Retrieved training jobs",
count=len(jobs),
tenant_id=tenant_id)
tenant_id=tenant_id_str)
return jobs
except Exception as e:
logger.error("Failed to get training jobs",
error=str(e),
tenant_id=tenant_id)
tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}")
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}", response_model=TrainingJobResponse)
@@ -155,9 +159,12 @@ async def get_training_job(
):
"""Get specific training job details"""
try:
tenant_id_str = str(tenant_id)
logger.debug("Getting training job",
job_id=job_id,
tenant_id=tenant_id)
tenant_id=tenant_id_str)
job = await training_service.get_training_job(job_id)
@@ -165,7 +172,7 @@ async def get_training_job(
if job.tenant_id != tenant_id:
logger.warning("Unauthorized job access attempt",
job_id=job_id,
tenant_id=tenant_id,
tenant_id=str(tenant_id),
job_tenant_id=job.tenant_id)
raise HTTPException(status_code=404, detail="Job not found")
@@ -188,13 +195,16 @@ async def get_training_progress(
):
"""Get real-time training progress"""
try:
tenant_id_str = str(tenant_id)
logger.debug("Getting training progress",
job_id=job_id,
tenant_id=tenant_id)
tenant_id=tenant_id_str)
# Verify job belongs to tenant
job = await training_service.get_training_job(job_id)
if job.tenant_id != tenant_id:
if job.tenant_id != tenant_id_str:
raise HTTPException(status_code=404, detail="Job not found")
progress = await training_service.get_job_progress(job_id)

View File

@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Any, Union
from datetime import datetime
from enum import Enum
from uuid import UUID
class TrainingStatus(str, Enum):
@@ -58,7 +59,6 @@ class SingleProductTrainingRequest(BaseModel):
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
class TrainingJobResponse(BaseModel):
"""Response schema for training job creation"""
job_id: str = Field(..., description="Unique training job identifier")
@@ -67,7 +67,17 @@ class TrainingJobResponse(BaseModel):
tenant_id: str = Field(..., description="Tenant identifier")
created_at: datetime = Field(..., description="Job creation timestamp")
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
# ✅ FIX: Add custom validator to convert UUID to string
@validator('tenant_id', 'job_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class TrainingJobStatus(BaseModel):
"""Response schema for training job status checks"""
@@ -81,6 +91,16 @@ class TrainingJobStatus(BaseModel):
products_completed: int = Field(0, description="Number of products completed")
products_failed: int = Field(0, description="Number of products that failed")
error_message: Optional[str] = Field(None, description="Error message if failed")
@validator('job_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class TrainingJobProgress(BaseModel):
@@ -94,6 +114,16 @@ class TrainingJobProgress(BaseModel):
products_total: int = Field(0, description="Total number of products")
estimated_time_remaining_minutes: Optional[int] = Field(None, description="Estimated time remaining")
timestamp: datetime = Field(default_factory=datetime.now, description="Progress update timestamp")
@validator('job_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class DataValidationRequest(BaseModel):
@@ -157,6 +187,16 @@ class TrainingResultsResponse(BaseModel):
training_results: Dict[str, ProductTrainingResult] = Field(..., description="Per-product results")
summary: Dict[str, Any] = Field(..., description="Training summary statistics")
completed_at: datetime = Field(..., description="Job completion timestamp")
@validator('tenant_id', 'job_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class TrainingValidationResult(BaseModel):
@@ -232,6 +272,16 @@ class TrainedModelResponse(BaseModel):
created_at: datetime = Field(..., description="Model creation timestamp")
data_period_start: Optional[datetime] = Field(None, description="Training data start date")
data_period_end: Optional[datetime] = Field(None, description="Training data end date")
@validator('tenant_id', 'model_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class ModelTrainingStats(BaseModel):

View File

@@ -425,7 +425,7 @@ class TrainingService:
params["limit"] = limit
response = await client.get(
f"{settings.DATA_SERVICE_URL}/api/v1/sales/",
f"{settings.DATA_SERVICE_URL}/api/v1/tenants/{tenant_id}/sales/",
params=params,
headers=headers,
timeout=30.0
@@ -479,7 +479,7 @@ class TrainingService:
"""Fetch weather data from data service"""
try:
async with httpx.AsyncClient() as client:
params = {"tenant_id": tenant_id}
params = { }
if hasattr(request, 'start_date') and request.start_date:
params["start_date"] = request.start_date.isoformat()
@@ -507,7 +507,7 @@ class TrainingService:
"""Fetch traffic data from data service"""
try:
async with httpx.AsyncClient() as client:
params = {"tenant_id": tenant_id}
params = { }
if hasattr(request, 'start_date') and request.start_date:
params["start_date"] = request.start_date.isoformat()