Checking onboardin flow - fix 2
This commit is contained in:
@@ -57,17 +57,18 @@ async def start_training_job(
|
|||||||
"""Start a new training job for all products"""
|
"""Start a new training job for all products"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
tenant_id_str = str(tenant_id)
|
||||||
new_job_id = str(uuid4())
|
new_job_id = str(uuid4())
|
||||||
|
|
||||||
logger.info("Starting training job",
|
logger.info("Starting training job",
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id_str,
|
||||||
job_id=uuid4(),
|
job_id=uuid4(),
|
||||||
config=request.dict())
|
config=request.dict())
|
||||||
|
|
||||||
# Create training job
|
# Create training job
|
||||||
job = await training_service.create_training_job(
|
job = await training_service.create_training_job(
|
||||||
db, # Pass db here
|
db, # Pass db here
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id_str,
|
||||||
job_id=new_job_id,
|
job_id=new_job_id,
|
||||||
config=request.dict()
|
config=request.dict()
|
||||||
)
|
)
|
||||||
@@ -76,7 +77,7 @@ async def start_training_job(
|
|||||||
try:
|
try:
|
||||||
await publish_job_started(
|
await publish_job_started(
|
||||||
job_id=new_job_id,
|
job_id=new_job_id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id_str,
|
||||||
config=request.dict()
|
config=request.dict()
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -99,7 +100,7 @@ async def start_training_job(
|
|||||||
job_id=job.job_id,
|
job_id=job.job_id,
|
||||||
status=TrainingStatus.PENDING,
|
status=TrainingStatus.PENDING,
|
||||||
message="Training job created successfully",
|
message="Training job created successfully",
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id_str,
|
||||||
created_at=job.created_at,
|
created_at=job.created_at,
|
||||||
estimated_duration_minutes=30
|
estimated_duration_minutes=30
|
||||||
)
|
)
|
||||||
@@ -107,7 +108,7 @@ async def start_training_job(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to start training job",
|
logger.error("Failed to start training job",
|
||||||
error=str(e),
|
error=str(e),
|
||||||
tenant_id=tenant_id)
|
tenant_id=str(tenant_id))
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
|
||||||
|
|
||||||
@router.get("/tenants/{tenant_id}/training/jobs", response_model=List[TrainingJobResponse])
|
@router.get("/tenants/{tenant_id}/training/jobs", response_model=List[TrainingJobResponse])
|
||||||
@@ -121,14 +122,17 @@ async def get_training_jobs(
|
|||||||
):
|
):
|
||||||
"""Get training jobs for tenant"""
|
"""Get training jobs for tenant"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
tenant_id_str = str(tenant_id)
|
||||||
|
|
||||||
logger.debug("Getting training jobs",
|
logger.debug("Getting training jobs",
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id_str,
|
||||||
status=status,
|
status=status,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
offset=offset)
|
offset=offset)
|
||||||
|
|
||||||
jobs = await training_service.get_training_jobs(
|
jobs = await training_service.get_training_jobs(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id_str,
|
||||||
status=status,
|
status=status,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
offset=offset
|
offset=offset
|
||||||
@@ -136,14 +140,14 @@ async def get_training_jobs(
|
|||||||
|
|
||||||
logger.debug("Retrieved training jobs",
|
logger.debug("Retrieved training jobs",
|
||||||
count=len(jobs),
|
count=len(jobs),
|
||||||
tenant_id=tenant_id)
|
tenant_id=tenant_id_str)
|
||||||
|
|
||||||
return jobs
|
return jobs
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to get training jobs",
|
logger.error("Failed to get training jobs",
|
||||||
error=str(e),
|
error=str(e),
|
||||||
tenant_id=tenant_id)
|
tenant_id=str(tenant_id))
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}")
|
||||||
|
|
||||||
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}", response_model=TrainingJobResponse)
|
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}", response_model=TrainingJobResponse)
|
||||||
@@ -155,9 +159,12 @@ async def get_training_job(
|
|||||||
):
|
):
|
||||||
"""Get specific training job details"""
|
"""Get specific training job details"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
tenant_id_str = str(tenant_id)
|
||||||
|
|
||||||
logger.debug("Getting training job",
|
logger.debug("Getting training job",
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
tenant_id=tenant_id)
|
tenant_id=tenant_id_str)
|
||||||
|
|
||||||
job = await training_service.get_training_job(job_id)
|
job = await training_service.get_training_job(job_id)
|
||||||
|
|
||||||
@@ -165,7 +172,7 @@ async def get_training_job(
|
|||||||
if job.tenant_id != tenant_id:
|
if job.tenant_id != tenant_id:
|
||||||
logger.warning("Unauthorized job access attempt",
|
logger.warning("Unauthorized job access attempt",
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=str(tenant_id),
|
||||||
job_tenant_id=job.tenant_id)
|
job_tenant_id=job.tenant_id)
|
||||||
raise HTTPException(status_code=404, detail="Job not found")
|
raise HTTPException(status_code=404, detail="Job not found")
|
||||||
|
|
||||||
@@ -188,13 +195,16 @@ async def get_training_progress(
|
|||||||
):
|
):
|
||||||
"""Get real-time training progress"""
|
"""Get real-time training progress"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
tenant_id_str = str(tenant_id)
|
||||||
|
|
||||||
logger.debug("Getting training progress",
|
logger.debug("Getting training progress",
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
tenant_id=tenant_id)
|
tenant_id=tenant_id_str)
|
||||||
|
|
||||||
# Verify job belongs to tenant
|
# Verify job belongs to tenant
|
||||||
job = await training_service.get_training_job(job_id)
|
job = await training_service.get_training_job(job_id)
|
||||||
if job.tenant_id != tenant_id:
|
if job.tenant_id != tenant_id_str:
|
||||||
raise HTTPException(status_code=404, detail="Job not found")
|
raise HTTPException(status_code=404, detail="Job not found")
|
||||||
|
|
||||||
progress = await training_service.get_job_progress(job_id)
|
progress = await training_service.get_job_progress(job_id)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, validator
|
|||||||
from typing import List, Optional, Dict, Any, Union
|
from typing import List, Optional, Dict, Any, Union
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
class TrainingStatus(str, Enum):
|
class TrainingStatus(str, Enum):
|
||||||
@@ -58,7 +59,6 @@ class SingleProductTrainingRequest(BaseModel):
|
|||||||
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
||||||
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
||||||
|
|
||||||
|
|
||||||
class TrainingJobResponse(BaseModel):
|
class TrainingJobResponse(BaseModel):
|
||||||
"""Response schema for training job creation"""
|
"""Response schema for training job creation"""
|
||||||
job_id: str = Field(..., description="Unique training job identifier")
|
job_id: str = Field(..., description="Unique training job identifier")
|
||||||
@@ -68,6 +68,16 @@ class TrainingJobResponse(BaseModel):
|
|||||||
created_at: datetime = Field(..., description="Job creation timestamp")
|
created_at: datetime = Field(..., description="Job creation timestamp")
|
||||||
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
|
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
|
||||||
|
|
||||||
|
# ✅ FIX: Add custom validator to convert UUID to string
|
||||||
|
@validator('tenant_id', 'job_id', pre=True)
|
||||||
|
def convert_uuid_to_string(cls, v):
|
||||||
|
"""Convert UUID objects to strings for JSON serialization"""
|
||||||
|
if isinstance(v, UUID):
|
||||||
|
return str(v)
|
||||||
|
return v
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
class TrainingJobStatus(BaseModel):
|
class TrainingJobStatus(BaseModel):
|
||||||
"""Response schema for training job status checks"""
|
"""Response schema for training job status checks"""
|
||||||
@@ -82,6 +92,16 @@ class TrainingJobStatus(BaseModel):
|
|||||||
products_failed: int = Field(0, description="Number of products that failed")
|
products_failed: int = Field(0, description="Number of products that failed")
|
||||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||||
|
|
||||||
|
@validator('job_id', pre=True)
|
||||||
|
def convert_uuid_to_string(cls, v):
|
||||||
|
"""Convert UUID objects to strings for JSON serialization"""
|
||||||
|
if isinstance(v, UUID):
|
||||||
|
return str(v)
|
||||||
|
return v
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class TrainingJobProgress(BaseModel):
|
class TrainingJobProgress(BaseModel):
|
||||||
"""Schema for real-time training job progress updates"""
|
"""Schema for real-time training job progress updates"""
|
||||||
@@ -95,6 +115,16 @@ class TrainingJobProgress(BaseModel):
|
|||||||
estimated_time_remaining_minutes: Optional[int] = Field(None, description="Estimated time remaining")
|
estimated_time_remaining_minutes: Optional[int] = Field(None, description="Estimated time remaining")
|
||||||
timestamp: datetime = Field(default_factory=datetime.now, description="Progress update timestamp")
|
timestamp: datetime = Field(default_factory=datetime.now, description="Progress update timestamp")
|
||||||
|
|
||||||
|
@validator('job_id', pre=True)
|
||||||
|
def convert_uuid_to_string(cls, v):
|
||||||
|
"""Convert UUID objects to strings for JSON serialization"""
|
||||||
|
if isinstance(v, UUID):
|
||||||
|
return str(v)
|
||||||
|
return v
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class DataValidationRequest(BaseModel):
|
class DataValidationRequest(BaseModel):
|
||||||
"""Request schema for validating training data"""
|
"""Request schema for validating training data"""
|
||||||
@@ -158,6 +188,16 @@ class TrainingResultsResponse(BaseModel):
|
|||||||
summary: Dict[str, Any] = Field(..., description="Training summary statistics")
|
summary: Dict[str, Any] = Field(..., description="Training summary statistics")
|
||||||
completed_at: datetime = Field(..., description="Job completion timestamp")
|
completed_at: datetime = Field(..., description="Job completion timestamp")
|
||||||
|
|
||||||
|
@validator('tenant_id', 'job_id', pre=True)
|
||||||
|
def convert_uuid_to_string(cls, v):
|
||||||
|
"""Convert UUID objects to strings for JSON serialization"""
|
||||||
|
if isinstance(v, UUID):
|
||||||
|
return str(v)
|
||||||
|
return v
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class TrainingValidationResult(BaseModel):
|
class TrainingValidationResult(BaseModel):
|
||||||
"""Schema for training data validation results"""
|
"""Schema for training data validation results"""
|
||||||
@@ -233,6 +273,16 @@ class TrainedModelResponse(BaseModel):
|
|||||||
data_period_start: Optional[datetime] = Field(None, description="Training data start date")
|
data_period_start: Optional[datetime] = Field(None, description="Training data start date")
|
||||||
data_period_end: Optional[datetime] = Field(None, description="Training data end date")
|
data_period_end: Optional[datetime] = Field(None, description="Training data end date")
|
||||||
|
|
||||||
|
@validator('tenant_id', 'model_id', pre=True)
|
||||||
|
def convert_uuid_to_string(cls, v):
|
||||||
|
"""Convert UUID objects to strings for JSON serialization"""
|
||||||
|
if isinstance(v, UUID):
|
||||||
|
return str(v)
|
||||||
|
return v
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class ModelTrainingStats(BaseModel):
|
class ModelTrainingStats(BaseModel):
|
||||||
"""Schema for model training statistics"""
|
"""Schema for model training statistics"""
|
||||||
|
|||||||
@@ -425,7 +425,7 @@ class TrainingService:
|
|||||||
params["limit"] = limit
|
params["limit"] = limit
|
||||||
|
|
||||||
response = await client.get(
|
response = await client.get(
|
||||||
f"{settings.DATA_SERVICE_URL}/api/v1/sales/",
|
f"{settings.DATA_SERVICE_URL}/api/v1/tenants/{tenant_id}/sales/",
|
||||||
params=params,
|
params=params,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=30.0
|
timeout=30.0
|
||||||
@@ -479,7 +479,7 @@ class TrainingService:
|
|||||||
"""Fetch weather data from data service"""
|
"""Fetch weather data from data service"""
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
params = {"tenant_id": tenant_id}
|
params = { }
|
||||||
|
|
||||||
if hasattr(request, 'start_date') and request.start_date:
|
if hasattr(request, 'start_date') and request.start_date:
|
||||||
params["start_date"] = request.start_date.isoformat()
|
params["start_date"] = request.start_date.isoformat()
|
||||||
@@ -507,7 +507,7 @@ class TrainingService:
|
|||||||
"""Fetch traffic data from data service"""
|
"""Fetch traffic data from data service"""
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
params = {"tenant_id": tenant_id}
|
params = { }
|
||||||
|
|
||||||
if hasattr(request, 'start_date') and request.start_date:
|
if hasattr(request, 'start_date') and request.start_date:
|
||||||
params["start_date"] = request.start_date.isoformat()
|
params["start_date"] = request.start_date.isoformat()
|
||||||
|
|||||||
Reference in New Issue
Block a user