Checking onboardin flow - fix 2

This commit is contained in:
Urtzi Alfaro
2025-07-27 10:30:42 +02:00
parent cb3ae4d78b
commit 30ac945058
3 changed files with 78 additions and 18 deletions

View File

@@ -57,17 +57,18 @@ async def start_training_job(
"""Start a new training job for all products""" """Start a new training job for all products"""
try: try:
tenant_id_str = str(tenant_id)
new_job_id = str(uuid4()) new_job_id = str(uuid4())
logger.info("Starting training job", logger.info("Starting training job",
tenant_id=tenant_id, tenant_id=tenant_id_str,
job_id=uuid4(), job_id=uuid4(),
config=request.dict()) config=request.dict())
# Create training job # Create training job
job = await training_service.create_training_job( job = await training_service.create_training_job(
db, # Pass db here db, # Pass db here
tenant_id=tenant_id, tenant_id=tenant_id_str,
job_id=new_job_id, job_id=new_job_id,
config=request.dict() config=request.dict()
) )
@@ -76,7 +77,7 @@ async def start_training_job(
try: try:
await publish_job_started( await publish_job_started(
job_id=new_job_id, job_id=new_job_id,
tenant_id=tenant_id, tenant_id=tenant_id_str,
config=request.dict() config=request.dict()
) )
except Exception as e: except Exception as e:
@@ -99,7 +100,7 @@ async def start_training_job(
job_id=job.job_id, job_id=job.job_id,
status=TrainingStatus.PENDING, status=TrainingStatus.PENDING,
message="Training job created successfully", message="Training job created successfully",
tenant_id=tenant_id, tenant_id=tenant_id_str,
created_at=job.created_at, created_at=job.created_at,
estimated_duration_minutes=30 estimated_duration_minutes=30
) )
@@ -107,7 +108,7 @@ async def start_training_job(
except Exception as e: except Exception as e:
logger.error("Failed to start training job", logger.error("Failed to start training job",
error=str(e), error=str(e),
tenant_id=tenant_id) tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
@router.get("/tenants/{tenant_id}/training/jobs", response_model=List[TrainingJobResponse]) @router.get("/tenants/{tenant_id}/training/jobs", response_model=List[TrainingJobResponse])
@@ -121,14 +122,17 @@ async def get_training_jobs(
): ):
"""Get training jobs for tenant""" """Get training jobs for tenant"""
try: try:
tenant_id_str = str(tenant_id)
logger.debug("Getting training jobs", logger.debug("Getting training jobs",
tenant_id=tenant_id, tenant_id=tenant_id_str,
status=status, status=status,
limit=limit, limit=limit,
offset=offset) offset=offset)
jobs = await training_service.get_training_jobs( jobs = await training_service.get_training_jobs(
tenant_id=tenant_id, tenant_id=tenant_id_str,
status=status, status=status,
limit=limit, limit=limit,
offset=offset offset=offset
@@ -136,14 +140,14 @@ async def get_training_jobs(
logger.debug("Retrieved training jobs", logger.debug("Retrieved training jobs",
count=len(jobs), count=len(jobs),
tenant_id=tenant_id) tenant_id=tenant_id_str)
return jobs return jobs
except Exception as e: except Exception as e:
logger.error("Failed to get training jobs", logger.error("Failed to get training jobs",
error=str(e), error=str(e),
tenant_id=tenant_id) tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}")
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}", response_model=TrainingJobResponse) @router.get("/tenants/{tenant_id}/training/jobs/{job_id}", response_model=TrainingJobResponse)
@@ -155,9 +159,12 @@ async def get_training_job(
): ):
"""Get specific training job details""" """Get specific training job details"""
try: try:
tenant_id_str = str(tenant_id)
logger.debug("Getting training job", logger.debug("Getting training job",
job_id=job_id, job_id=job_id,
tenant_id=tenant_id) tenant_id=tenant_id_str)
job = await training_service.get_training_job(job_id) job = await training_service.get_training_job(job_id)
@@ -165,7 +172,7 @@ async def get_training_job(
if job.tenant_id != tenant_id: if job.tenant_id != tenant_id:
logger.warning("Unauthorized job access attempt", logger.warning("Unauthorized job access attempt",
job_id=job_id, job_id=job_id,
tenant_id=tenant_id, tenant_id=str(tenant_id),
job_tenant_id=job.tenant_id) job_tenant_id=job.tenant_id)
raise HTTPException(status_code=404, detail="Job not found") raise HTTPException(status_code=404, detail="Job not found")
@@ -188,13 +195,16 @@ async def get_training_progress(
): ):
"""Get real-time training progress""" """Get real-time training progress"""
try: try:
tenant_id_str = str(tenant_id)
logger.debug("Getting training progress", logger.debug("Getting training progress",
job_id=job_id, job_id=job_id,
tenant_id=tenant_id) tenant_id=tenant_id_str)
# Verify job belongs to tenant # Verify job belongs to tenant
job = await training_service.get_training_job(job_id) job = await training_service.get_training_job(job_id)
if job.tenant_id != tenant_id: if job.tenant_id != tenant_id_str:
raise HTTPException(status_code=404, detail="Job not found") raise HTTPException(status_code=404, detail="Job not found")
progress = await training_service.get_job_progress(job_id) progress = await training_service.get_job_progress(job_id)

View File

@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Any, Union from typing import List, Optional, Dict, Any, Union
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from uuid import UUID
class TrainingStatus(str, Enum): class TrainingStatus(str, Enum):
@@ -58,7 +59,6 @@ class SingleProductTrainingRequest(BaseModel):
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality") weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality") yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
class TrainingJobResponse(BaseModel): class TrainingJobResponse(BaseModel):
"""Response schema for training job creation""" """Response schema for training job creation"""
job_id: str = Field(..., description="Unique training job identifier") job_id: str = Field(..., description="Unique training job identifier")
@@ -68,6 +68,16 @@ class TrainingJobResponse(BaseModel):
created_at: datetime = Field(..., description="Job creation timestamp") created_at: datetime = Field(..., description="Job creation timestamp")
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes") estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
# ✅ FIX: Add custom validator to convert UUID to string
@validator('tenant_id', 'job_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class TrainingJobStatus(BaseModel): class TrainingJobStatus(BaseModel):
"""Response schema for training job status checks""" """Response schema for training job status checks"""
@@ -82,6 +92,16 @@ class TrainingJobStatus(BaseModel):
products_failed: int = Field(0, description="Number of products that failed") products_failed: int = Field(0, description="Number of products that failed")
error_message: Optional[str] = Field(None, description="Error message if failed") error_message: Optional[str] = Field(None, description="Error message if failed")
@validator('job_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class TrainingJobProgress(BaseModel): class TrainingJobProgress(BaseModel):
"""Schema for real-time training job progress updates""" """Schema for real-time training job progress updates"""
@@ -95,6 +115,16 @@ class TrainingJobProgress(BaseModel):
estimated_time_remaining_minutes: Optional[int] = Field(None, description="Estimated time remaining") estimated_time_remaining_minutes: Optional[int] = Field(None, description="Estimated time remaining")
timestamp: datetime = Field(default_factory=datetime.now, description="Progress update timestamp") timestamp: datetime = Field(default_factory=datetime.now, description="Progress update timestamp")
@validator('job_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class DataValidationRequest(BaseModel): class DataValidationRequest(BaseModel):
"""Request schema for validating training data""" """Request schema for validating training data"""
@@ -158,6 +188,16 @@ class TrainingResultsResponse(BaseModel):
summary: Dict[str, Any] = Field(..., description="Training summary statistics") summary: Dict[str, Any] = Field(..., description="Training summary statistics")
completed_at: datetime = Field(..., description="Job completion timestamp") completed_at: datetime = Field(..., description="Job completion timestamp")
@validator('tenant_id', 'job_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class TrainingValidationResult(BaseModel): class TrainingValidationResult(BaseModel):
"""Schema for training data validation results""" """Schema for training data validation results"""
@@ -233,6 +273,16 @@ class TrainedModelResponse(BaseModel):
data_period_start: Optional[datetime] = Field(None, description="Training data start date") data_period_start: Optional[datetime] = Field(None, description="Training data start date")
data_period_end: Optional[datetime] = Field(None, description="Training data end date") data_period_end: Optional[datetime] = Field(None, description="Training data end date")
@validator('tenant_id', 'model_id', pre=True)
def convert_uuid_to_string(cls, v):
"""Convert UUID objects to strings for JSON serialization"""
if isinstance(v, UUID):
return str(v)
return v
class Config:
from_attributes = True
class ModelTrainingStats(BaseModel): class ModelTrainingStats(BaseModel):
"""Schema for model training statistics""" """Schema for model training statistics"""

View File

@@ -425,7 +425,7 @@ class TrainingService:
params["limit"] = limit params["limit"] = limit
response = await client.get( response = await client.get(
f"{settings.DATA_SERVICE_URL}/api/v1/sales/", f"{settings.DATA_SERVICE_URL}/api/v1/tenants/{tenant_id}/sales/",
params=params, params=params,
headers=headers, headers=headers,
timeout=30.0 timeout=30.0
@@ -479,7 +479,7 @@ class TrainingService:
"""Fetch weather data from data service""" """Fetch weather data from data service"""
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
params = {"tenant_id": tenant_id} params = { }
if hasattr(request, 'start_date') and request.start_date: if hasattr(request, 'start_date') and request.start_date:
params["start_date"] = request.start_date.isoformat() params["start_date"] = request.start_date.isoformat()
@@ -507,7 +507,7 @@ class TrainingService:
"""Fetch traffic data from data service""" """Fetch traffic data from data service"""
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
params = {"tenant_id": tenant_id} params = { }
if hasattr(request, 'start_date') and request.start_date: if hasattr(request, 'start_date') and request.start_date:
params["start_date"] = request.start_date.isoformat() params["start_date"] = request.start_date.isoformat()