From 30ac9450587d9b354339f9323c370b443132df41 Mon Sep 17 00:00:00 2001 From: Urtzi Alfaro Date: Sun, 27 Jul 2025 10:30:42 +0200 Subject: [PATCH] Checking onboardin flow - fix 2 --- services/training/app/api/training.py | 36 ++++++++----- services/training/app/schemas/training.py | 54 ++++++++++++++++++- .../training/app/services/training_service.py | 6 +-- 3 files changed, 78 insertions(+), 18 deletions(-) diff --git a/services/training/app/api/training.py b/services/training/app/api/training.py index 158411ac..52715bc8 100644 --- a/services/training/app/api/training.py +++ b/services/training/app/api/training.py @@ -57,17 +57,18 @@ async def start_training_job( """Start a new training job for all products""" try: + tenant_id_str = str(tenant_id) new_job_id = str(uuid4()) logger.info("Starting training job", - tenant_id=tenant_id, + tenant_id=tenant_id_str, job_id=uuid4(), config=request.dict()) # Create training job job = await training_service.create_training_job( db, # Pass db here - tenant_id=tenant_id, + tenant_id=tenant_id_str, job_id=new_job_id, config=request.dict() ) @@ -76,7 +77,7 @@ async def start_training_job( try: await publish_job_started( job_id=new_job_id, - tenant_id=tenant_id, + tenant_id=tenant_id_str, config=request.dict() ) except Exception as e: @@ -99,7 +100,7 @@ async def start_training_job( job_id=job.job_id, status=TrainingStatus.PENDING, message="Training job created successfully", - tenant_id=tenant_id, + tenant_id=tenant_id_str, created_at=job.created_at, estimated_duration_minutes=30 ) @@ -107,7 +108,7 @@ async def start_training_job( except Exception as e: logger.error("Failed to start training job", error=str(e), - tenant_id=tenant_id) + tenant_id=str(tenant_id)) raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}") @router.get("/tenants/{tenant_id}/training/jobs", response_model=List[TrainingJobResponse]) @@ -121,14 +122,17 @@ async def get_training_jobs( ): """Get training jobs for tenant""" try: + + tenant_id_str = str(tenant_id) + logger.debug("Getting training jobs", - tenant_id=tenant_id, + tenant_id=tenant_id_str, status=status, limit=limit, offset=offset) jobs = await training_service.get_training_jobs( - tenant_id=tenant_id, + tenant_id=tenant_id_str, status=status, limit=limit, offset=offset @@ -136,14 +140,14 @@ async def get_training_jobs( logger.debug("Retrieved training jobs", count=len(jobs), - tenant_id=tenant_id) + tenant_id=tenant_id_str) return jobs except Exception as e: logger.error("Failed to get training jobs", error=str(e), - tenant_id=tenant_id) + tenant_id=str(tenant_id)) raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}") @router.get("/tenants/{tenant_id}/training/jobs/{job_id}", response_model=TrainingJobResponse) @@ -155,9 +159,12 @@ async def get_training_job( ): """Get specific training job details""" try: + + tenant_id_str = str(tenant_id) + logger.debug("Getting training job", job_id=job_id, - tenant_id=tenant_id) + tenant_id=tenant_id_str) job = await training_service.get_training_job(job_id) @@ -165,7 +172,7 @@ async def get_training_job( if job.tenant_id != tenant_id: logger.warning("Unauthorized job access attempt", job_id=job_id, - tenant_id=tenant_id, + tenant_id=str(tenant_id), job_tenant_id=job.tenant_id) raise HTTPException(status_code=404, detail="Job not found") @@ -188,13 +195,16 @@ async def get_training_progress( ): """Get real-time training progress""" try: + + tenant_id_str = str(tenant_id) + logger.debug("Getting training progress", job_id=job_id, - tenant_id=tenant_id) + tenant_id=tenant_id_str) # Verify job belongs to tenant job = await training_service.get_training_job(job_id) - if job.tenant_id != tenant_id: + if job.tenant_id != tenant_id_str: raise HTTPException(status_code=404, detail="Job not found") progress = await training_service.get_job_progress(job_id) diff --git a/services/training/app/schemas/training.py b/services/training/app/schemas/training.py index a9c898a6..379a403c 100644 --- a/services/training/app/schemas/training.py +++ b/services/training/app/schemas/training.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, validator from typing import List, Optional, Dict, Any, Union from datetime import datetime from enum import Enum +from uuid import UUID class TrainingStatus(str, Enum): @@ -58,7 +59,6 @@ class SingleProductTrainingRequest(BaseModel): weekly_seasonality: bool = Field(True, description="Enable weekly seasonality") yearly_seasonality: bool = Field(True, description="Enable yearly seasonality") - class TrainingJobResponse(BaseModel): """Response schema for training job creation""" job_id: str = Field(..., description="Unique training job identifier") @@ -67,7 +67,17 @@ class TrainingJobResponse(BaseModel): tenant_id: str = Field(..., description="Tenant identifier") created_at: datetime = Field(..., description="Job creation timestamp") estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes") - + + # ✅ FIX: Add custom validator to convert UUID to string + @validator('tenant_id', 'job_id', pre=True) + def convert_uuid_to_string(cls, v): + """Convert UUID objects to strings for JSON serialization""" + if isinstance(v, UUID): + return str(v) + return v + + class Config: + from_attributes = True class TrainingJobStatus(BaseModel): """Response schema for training job status checks""" @@ -81,6 +91,16 @@ class TrainingJobStatus(BaseModel): products_completed: int = Field(0, description="Number of products completed") products_failed: int = Field(0, description="Number of products that failed") error_message: Optional[str] = Field(None, description="Error message if failed") + + @validator('job_id', pre=True) + def convert_uuid_to_string(cls, v): + """Convert UUID objects to strings for JSON serialization""" + if isinstance(v, UUID): + return str(v) + return v + + class Config: + from_attributes = True class TrainingJobProgress(BaseModel): @@ -94,6 +114,16 @@ class TrainingJobProgress(BaseModel): products_total: int = Field(0, description="Total number of products") estimated_time_remaining_minutes: Optional[int] = Field(None, description="Estimated time remaining") timestamp: datetime = Field(default_factory=datetime.now, description="Progress update timestamp") + + @validator('job_id', pre=True) + def convert_uuid_to_string(cls, v): + """Convert UUID objects to strings for JSON serialization""" + if isinstance(v, UUID): + return str(v) + return v + + class Config: + from_attributes = True class DataValidationRequest(BaseModel): @@ -157,6 +187,16 @@ class TrainingResultsResponse(BaseModel): training_results: Dict[str, ProductTrainingResult] = Field(..., description="Per-product results") summary: Dict[str, Any] = Field(..., description="Training summary statistics") completed_at: datetime = Field(..., description="Job completion timestamp") + + @validator('tenant_id', 'job_id', pre=True) + def convert_uuid_to_string(cls, v): + """Convert UUID objects to strings for JSON serialization""" + if isinstance(v, UUID): + return str(v) + return v + + class Config: + from_attributes = True class TrainingValidationResult(BaseModel): @@ -232,6 +272,16 @@ class TrainedModelResponse(BaseModel): created_at: datetime = Field(..., description="Model creation timestamp") data_period_start: Optional[datetime] = Field(None, description="Training data start date") data_period_end: Optional[datetime] = Field(None, description="Training data end date") + + @validator('tenant_id', 'model_id', pre=True) + def convert_uuid_to_string(cls, v): + """Convert UUID objects to strings for JSON serialization""" + if isinstance(v, UUID): + return str(v) + return v + + class Config: + from_attributes = True class ModelTrainingStats(BaseModel): diff --git a/services/training/app/services/training_service.py b/services/training/app/services/training_service.py index 004ffe65..9c3faa39 100644 --- a/services/training/app/services/training_service.py +++ b/services/training/app/services/training_service.py @@ -425,7 +425,7 @@ class TrainingService: params["limit"] = limit response = await client.get( - f"{settings.DATA_SERVICE_URL}/api/v1/sales/", + f"{settings.DATA_SERVICE_URL}/api/v1/tenants/{tenant_id}/sales/", params=params, headers=headers, timeout=30.0 @@ -479,7 +479,7 @@ class TrainingService: """Fetch weather data from data service""" try: async with httpx.AsyncClient() as client: - params = {"tenant_id": tenant_id} + params = { } if hasattr(request, 'start_date') and request.start_date: params["start_date"] = request.start_date.isoformat() @@ -507,7 +507,7 @@ class TrainingService: """Fetch traffic data from data service""" try: async with httpx.AsyncClient() as client: - params = {"tenant_id": tenant_id} + params = { } if hasattr(request, 'start_date') and request.start_date: params["start_date"] = request.start_date.isoformat()