Improve training test 3

Urtzi Alfaro
2025-07-29 12:45:39 +02:00
parent ef62f05031
commit cd6fd875f7
4 changed files with 154 additions and 60 deletions

View File

@@ -173,7 +173,14 @@ async def _proxy_request(request: Request, target_path: str, service_url: str):
     # Add query parameters
     params = dict(request.query_params)
 
-    async with httpx.AsyncClient(timeout=30.0) as client:
+    timeout_config = httpx.Timeout(
+        connect=30.0,  # Connection timeout
+        read=600.0,    # Read timeout: 10 minutes (was 30s)
+        write=30.0,    # Write timeout
+        pool=30.0      # Pool timeout
+    )
+    async with httpx.AsyncClient(timeout=timeout_config) as client:
         response = await client.request(
             method=request.method,
             url=url,

View File

@@ -134,29 +134,53 @@ async def get_sales_data(
     start_date: Optional[datetime] = Query(None, description="Start date filter"),
     end_date: Optional[datetime] = Query(None, description="End date filter"),
     product_name: Optional[str] = Query(None, description="Product name filter"),
+    # ✅ FIX: Add missing pagination parameters
+    limit: Optional[int] = Query(1000, le=5000, description="Maximum number of records to return"),
+    offset: Optional[int] = Query(0, ge=0, description="Number of records to skip"),
+    # ✅ FIX: Add additional filtering parameters
+    product_names: Optional[List[str]] = Query(None, description="Multiple product name filters"),
+    location_ids: Optional[List[str]] = Query(None, description="Location ID filters"),
+    sources: Optional[List[str]] = Query(None, description="Source filters"),
+    min_quantity: Optional[int] = Query(None, description="Minimum quantity filter"),
+    max_quantity: Optional[int] = Query(None, description="Maximum quantity filter"),
+    min_revenue: Optional[float] = Query(None, description="Minimum revenue filter"),
+    max_revenue: Optional[float] = Query(None, description="Maximum revenue filter"),
     current_user: Dict[str, Any] = Depends(get_current_user_dep),
     db: AsyncSession = Depends(get_db)
 ):
-    """Get sales data for tenant with filters"""
+    """Get sales data for tenant with filters and pagination"""
     try:
         logger.debug("Querying sales data",
                      tenant_id=tenant_id,
                      start_date=start_date,
                      end_date=end_date,
-                     product_name=product_name)
+                     product_name=product_name,
+                     limit=limit,
+                     offset=offset)
 
+        # ✅ FIX: Create complete SalesDataQuery with all parameters
         query = SalesDataQuery(
             tenant_id=tenant_id,
             start_date=start_date,
             end_date=end_date,
-            product_name=product_name
+            product_names=[product_name] if product_name else product_names,
+            location_ids=location_ids,
+            sources=sources,
+            min_quantity=min_quantity,
+            max_quantity=max_quantity,
+            min_revenue=min_revenue,
+            max_revenue=max_revenue,
+            limit=limit,   # ✅ Now properly passed from query params
+            offset=offset  # ✅ Now properly passed from query params
         )
 
         records = await SalesService.get_sales_data(query, db)
 
         logger.debug("Successfully retrieved sales data",
                      count=len(records),
-                     tenant_id=tenant_id)
+                     tenant_id=tenant_id,
+                     limit=limit,
+                     offset=offset)
 
         return records
 
     except Exception as e:
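
For illustration, a hypothetical client call exercising the new pagination and filter parameters (host, tenant ID, and token are placeholders; the route shape follows the handler above):

import httpx

response = httpx.get(
    "http://localhost:8000/api/v1/tenants/acme/sales",  # placeholder host and tenant
    params={"limit": 500, "offset": 1000, "min_revenue": 10.0},
    headers={"Authorization": "Bearer <token>"},  # placeholder token
)
response.raise_for_status()
print(f"{len(response.json())} records")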

View File

@@ -22,6 +22,7 @@ import warnings
 warnings.filterwarnings('ignore')
 
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import text
 from app.models.training import TrainedModel
 from app.core.database import get_db_session
@@ -565,19 +566,27 @@ class BakeryProphetManager:
"""Deactivate previous models for the same product"""
if self.db_session:
try:
# Update previous models to inactive
query = """
# ✅ FIX: Wrap SQL string with text() for SQLAlchemy 2.0
query = text("""
UPDATE trained_models
SET is_active = false, is_production = false
WHERE tenant_id = :tenant_id AND product_name = :product_name
"""
""")
await self.db_session.execute(query, {
"tenant_id": tenant_id,
"product_name": product_name
})
# ✅ ADD: Commit the transaction
await self.db_session.commit()
logger.info(f"Successfully deactivated previous models for {product_name}")
except Exception as e:
logger.error(f"Failed to deactivate previous models: {str(e)}")
# ✅ ADD: Rollback on error
await self.db_session.rollback()
# Keep all existing methods unchanged
async def generate_forecast(self,
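
Background on the text() fix: SQLAlchemy 2.0 no longer accepts bare SQL strings in execute(); they raise ObjectNotExecutableError and must be wrapped in text(). A minimal standalone sketch (in-memory SQLite via the aiosqlite driver, schema reduced for illustration):

import asyncio
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

async def main():
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.execute(text(
            "CREATE TABLE trained_models (tenant_id TEXT, product_name TEXT, is_active INTEGER)"
        ))
        # Bound parameters require a text() construct; passing a plain string
        # raises sqlalchemy.exc.ObjectNotExecutableError in 2.0.
        await conn.execute(
            text("UPDATE trained_models SET is_active = 0 WHERE tenant_id = :tenant_id"),
            {"tenant_id": "acme"},
        )

asyncio.run(main())

Note that engine.begin() commits on success and rolls back on error automatically; the explicit commit()/rollback() pair added above restores the same guarantee for the long-lived session.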

View File

@@ -12,17 +12,11 @@ class DataServiceClient:
     def __init__(self):
         self.base_url = settings.API_GATEWAY_URL
         self.timeout = 2000.0
 
-    async def fetch_sales_data(
-        self,
-        tenant_id: str,
-        start_date: Optional[str] = None,
-        end_date: Optional[str] = None,
-        product_name: Optional[str] = None
-    ) -> List[Dict[str, Any]]:
+    async def fetch_sales_data(self, tenant_id: str) -> List[Dict[str, Any]]:
         """
-        Fetch sales data for training via API Gateway
-        ✅ Uses proper service authentication
+        Fetch all sales data for training (no pagination limits)
+        FIXED: Retrieves ALL records instead of being limited to 1000
         """
         try:
             # Get service token
@@ -32,51 +26,104 @@ class DataServiceClient:
             headers = service_auth.get_request_headers(tenant_id)
             headers["Authorization"] = f"Bearer {token}"
 
-            # Prepare query parameters
-            params = {}
-            if start_date:
-                params["start_date"] = start_date
-            if end_date:
-                params["end_date"] = end_date
-            if product_name:
-                params["product_name"] = product_name
-
-            # Make request via gateway
-            async with httpx.AsyncClient(timeout=self.timeout) as client:
-                response = await client.get(
-                    f"{self.base_url}/api/v1/tenants/{tenant_id}/sales",
-                    headers=headers,
-                    params=params
-                )
-
-                logger.info(f"Sales data request: {response.status_code}",
-                            tenant_id=tenant_id,
-                            url=response.url)
-
-                if response.status_code == 200:
-                    data = response.json()
-                    logger.info(f"Successfully fetched {len(data)} sales records via gateway",
-                                tenant_id=tenant_id)
-                    return data
-                elif response.status_code == 401:
-                    logger.error("Authentication failed with gateway",
-                                 tenant_id=tenant_id,
-                                 response_text=response.text)
-                    return []
-                elif response.status_code == 404:
-                    logger.warning("Sales data endpoint not found",
-                                   tenant_id=tenant_id,
-                                   url=response.url)
-                    return []
-                else:
-                    logger.error(f"Gateway request failed: HTTP {response.status_code}",
-                                 tenant_id=tenant_id,
-                                 response_text=response.text)
-                    return []
+            all_records = []
+            page = 0
+            page_size = 5000  # Use maximum allowed by API
+
+            while True:
+                # Prepare query parameters for pagination
+                params = {
+                    "limit": page_size,
+                    "offset": page * page_size
+                }
+
+                logger.info(f"Fetching sales data page {page + 1} (offset: {page * page_size})",
+                            tenant_id=tenant_id)
+
+                # Make GET request via gateway with pagination
+                async with httpx.AsyncClient(timeout=self.timeout) as client:
+                    response = await client.get(
+                        f"{self.base_url}/api/v1/tenants/{tenant_id}/sales",
+                        headers=headers,
+                        params=params
+                    )
+
+                if response.status_code == 200:
+                    page_data = response.json()
+
+                    # Handle different response formats
+                    if isinstance(page_data, list):
+                        # Direct list response (no pagination metadata)
+                        records = page_data
+                        logger.info(f"Retrieved {len(records)} records from page {page + 1} (direct list)")
+
+                        # For direct list responses, we need to check if we got the max possible
+                        # If we got less than page_size, we're done
+                        if len(records) == 0:
+                            logger.info("No records in response, pagination complete")
+                            break
+                        elif len(records) < page_size:
+                            # Got fewer than requested, this is the last page
+                            all_records.extend(records)
+                            logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}")
+                            break
+                        else:
+                            # Got full page, there might be more
+                            all_records.extend(records)
+                            logger.info(f"Full page retrieved: {len(records)} records, continuing to next page")
+
+                    elif isinstance(page_data, dict):
+                        # Paginated response format
+                        records = page_data.get('records', page_data.get('data', []))
+                        total_available = page_data.get('total', 0)
+                        logger.info(f"Retrieved {len(records)} records from page {page + 1} (paginated response)")
+
+                        if not records:
+                            logger.info("No more records found in paginated response")
+                            break
+
+                        all_records.extend(records)
+
+                        # Check if we've got all available records
+                        if len(all_records) >= total_available:
+                            logger.info(f"Retrieved all available records: {len(all_records)}/{total_available}")
+                            break
+                    else:
+                        logger.warning(f"Unexpected response format: {type(page_data)}")
+                        records = []
+                        break
+
+                    page += 1
+
+                    # Safety break to prevent infinite loops
+                    if page > 100:  # Max 500,000 records (100 * 5000)
+                        logger.warning("Reached maximum page limit, stopping pagination")
+                        break
+
+                elif response.status_code == 401:
+                    logger.error("Authentication failed with gateway",
+                                 tenant_id=tenant_id,
+                                 response_text=response.text)
+                    return []
+                elif response.status_code == 404:
+                    logger.warning("Sales data endpoint not found",
+                                   tenant_id=tenant_id,
+                                   url=response.url)
+                    return []
+                else:
+                    logger.error(f"Gateway request failed: HTTP {response.status_code}",
+                                 tenant_id=tenant_id,
+                                 response_text=response.text)
+                    return []
+
+            logger.info(f"Successfully fetched {len(all_records)} total sales records via gateway",
+                        tenant_id=tenant_id)
+            return all_records
 
         except httpx.TimeoutException:
             logger.error("Timeout when fetching sales data via gateway",
                          tenant_id=tenant_id)
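
The loop above is standard offset pagination: request fixed-size pages until a short or empty page arrives (or, for dict responses, until the reported total is reached). A condensed sketch of the same termination logic, where fetch_page is a hypothetical stand-in for the gateway call:

from typing import Any, Callable, Dict, List

def fetch_all(fetch_page: Callable[[int, int], List[Dict[str, Any]]],
              page_size: int = 5000, max_pages: int = 100) -> List[Dict[str, Any]]:
    all_records: List[Dict[str, Any]] = []
    for page in range(max_pages):  # hard cap mirrors the safety break above
        records = fetch_page(page_size, page * page_size)  # (limit, offset)
        all_records.extend(records)
        if len(records) < page_size:  # a short page means no more data
            break
    return all_records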
@@ -86,7 +133,7 @@ class DataServiceClient:
logger.error(f"Error fetching sales data via gateway: {e}",
tenant_id=tenant_id)
return []
async def fetch_weather_data(
self,
@@ -195,8 +242,15 @@ class DataServiceClient:
logger.info(f"Traffic request payload: {payload}", tenant_id=tenant_id)
# Madrid traffic data can take 5-10 minutes to download and process
timeout_config = httpx.Timeout(
connect=30.0, # Connection timeout
read=600.0, # Read timeout: 10 minutes (was 30s)
write=30.0, # Write timeout
pool=30.0 # Pool timeout
)
# Make POST request via gateway
timeout_config = httpx.Timeout(connect=30.0, read=self.timeout, write=30.0, pool=30.0)
async with httpx.AsyncClient(timeout=timeout_config) as client:
response = await client.post(
f"{self.base_url}/api/v1/tenants/{tenant_id}/traffic/historical",