Improve training test 3
@@ -173,7 +173,14 @@ async def _proxy_request(request: Request, target_path: str, service_url: str):
     # Add query parameters
     params = dict(request.query_params)

-    async with httpx.AsyncClient(timeout=30.0) as client:
+    timeout_config = httpx.Timeout(
+        connect=30.0,  # Connection timeout
+        read=600.0,    # Read timeout: 10 minutes (was 30s)
+        write=30.0,    # Write timeout
+        pool=30.0      # Pool timeout
+    )
+
+    async with httpx.AsyncClient(timeout=timeout_config) as client:
         response = await client.request(
             method=request.method,
             url=url,
@@ -134,29 +134,53 @@ async def get_sales_data(
     start_date: Optional[datetime] = Query(None, description="Start date filter"),
     end_date: Optional[datetime] = Query(None, description="End date filter"),
     product_name: Optional[str] = Query(None, description="Product name filter"),
+    # ✅ FIX: Add missing pagination parameters
+    limit: Optional[int] = Query(1000, le=5000, description="Maximum number of records to return"),
+    offset: Optional[int] = Query(0, ge=0, description="Number of records to skip"),
+    # ✅ FIX: Add additional filtering parameters
+    product_names: Optional[List[str]] = Query(None, description="Multiple product name filters"),
+    location_ids: Optional[List[str]] = Query(None, description="Location ID filters"),
+    sources: Optional[List[str]] = Query(None, description="Source filters"),
+    min_quantity: Optional[int] = Query(None, description="Minimum quantity filter"),
+    max_quantity: Optional[int] = Query(None, description="Maximum quantity filter"),
+    min_revenue: Optional[float] = Query(None, description="Minimum revenue filter"),
+    max_revenue: Optional[float] = Query(None, description="Maximum revenue filter"),
     current_user: Dict[str, Any] = Depends(get_current_user_dep),
     db: AsyncSession = Depends(get_db)
 ):
-    """Get sales data for tenant with filters"""
+    """Get sales data for tenant with filters and pagination"""
     try:
         logger.debug("Querying sales data",
                      tenant_id=tenant_id,
                      start_date=start_date,
                      end_date=end_date,
-                     product_name=product_name)
+                     product_name=product_name,
+                     limit=limit,
+                     offset=offset)

+        # ✅ FIX: Create complete SalesDataQuery with all parameters
         query = SalesDataQuery(
             tenant_id=tenant_id,
             start_date=start_date,
             end_date=end_date,
-            product_name=product_name
+            product_names=[product_name] if product_name else product_names,
+            location_ids=location_ids,
+            sources=sources,
+            min_quantity=min_quantity,
+            max_quantity=max_quantity,
+            min_revenue=min_revenue,
+            max_revenue=max_revenue,
+            limit=limit,  # ✅ Now properly passed from query params
+            offset=offset  # ✅ Now properly passed from query params
         )

         records = await SalesService.get_sales_data(query, db)

         logger.debug("Successfully retrieved sales data",
                      count=len(records),
-                     tenant_id=tenant_id)
+                     tenant_id=tenant_id,
+                     limit=limit,
+                     offset=offset)
         return records

     except Exception as e:
@@ -22,6 +22,7 @@ import warnings
 warnings.filterwarnings('ignore')

 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import text
 from app.models.training import TrainedModel
 from app.core.database import get_db_session

@@ -565,19 +566,27 @@ class BakeryProphetManager:
         """Deactivate previous models for the same product"""
         if self.db_session:
             try:
-                # Update previous models to inactive
-                query = """
+                # ✅ FIX: Wrap SQL string with text() for SQLAlchemy 2.0
+                query = text("""
                     UPDATE trained_models
                     SET is_active = false, is_production = false
                     WHERE tenant_id = :tenant_id AND product_name = :product_name
-                """
+                """)

                 await self.db_session.execute(query, {
                     "tenant_id": tenant_id,
                     "product_name": product_name
                 })
+
+                # ✅ ADD: Commit the transaction
+                await self.db_session.commit()
+
+                logger.info(f"Successfully deactivated previous models for {product_name}")
+
             except Exception as e:
                 logger.error(f"Failed to deactivate previous models: {str(e)}")
+                # ✅ ADD: Rollback on error
+                await self.db_session.rollback()

     # Keep all existing methods unchanged
     async def generate_forecast(self,
@@ -13,16 +13,10 @@ class DataServiceClient:
         self.base_url = settings.API_GATEWAY_URL
         self.timeout = 2000.0

-    async def fetch_sales_data(
-        self,
-        tenant_id: str,
-        start_date: Optional[str] = None,
-        end_date: Optional[str] = None,
-        product_name: Optional[str] = None
-    ) -> List[Dict[str, Any]]:
+    async def fetch_sales_data(self, tenant_id: str) -> List[Dict[str, Any]]:
         """
-        Fetch sales data for training via API Gateway
-        ✅ Uses proper service authentication
+        Fetch all sales data for training (no pagination limits)
+        FIXED: Retrieves ALL records instead of being limited to 1000
         """
         try:
             # Get service token
@@ -32,50 +26,103 @@ class DataServiceClient:
             headers = service_auth.get_request_headers(tenant_id)
             headers["Authorization"] = f"Bearer {token}"

-            # Prepare query parameters
-            params = {}
-            if start_date:
-                params["start_date"] = start_date
-            if end_date:
-                params["end_date"] = end_date
-            if product_name:
-                params["product_name"] = product_name
-
-            # Make request via gateway
-            async with httpx.AsyncClient(timeout=self.timeout) as client:
-                response = await client.get(
-                    f"{self.base_url}/api/v1/tenants/{tenant_id}/sales",
-                    headers=headers,
-                    params=params
-                )
-
-                logger.info(f"Sales data request: {response.status_code}",
-                            tenant_id=tenant_id,
-                            url=response.url)
-
-                if response.status_code == 200:
-                    data = response.json()
-                    logger.info(f"Successfully fetched {len(data)} sales records via gateway",
-                                tenant_id=tenant_id)
-                    return data
-
-                elif response.status_code == 401:
-                    logger.error("Authentication failed with gateway",
-                                 tenant_id=tenant_id,
-                                 response_text=response.text)
-                    return []
-
-                elif response.status_code == 404:
-                    logger.warning("Sales data endpoint not found",
-                                   tenant_id=tenant_id,
-                                   url=response.url)
-                    return []
-
-                else:
-                    logger.error(f"Gateway request failed: HTTP {response.status_code}",
-                                 tenant_id=tenant_id,
-                                 response_text=response.text)
-                    return []
+            all_records = []
+            page = 0
+            page_size = 5000  # Use maximum allowed by API
+
+            while True:
+                # Prepare query parameters for pagination
+                params = {
+                    "limit": page_size,
+                    "offset": page * page_size
+                }
+
+                logger.info(f"Fetching sales data page {page + 1} (offset: {page * page_size})",
+                            tenant_id=tenant_id)
+
+                # Make GET request via gateway with pagination
+                async with httpx.AsyncClient(timeout=self.timeout) as client:
+                    response = await client.get(
+                        f"{self.base_url}/api/v1/tenants/{tenant_id}/sales",
+                        headers=headers,
+                        params=params
+                    )
+
+                    if response.status_code == 200:
+                        page_data = response.json()
+
+                        # Handle different response formats
+                        if isinstance(page_data, list):
+                            # Direct list response (no pagination metadata)
+                            records = page_data
+                            logger.info(f"Retrieved {len(records)} records from page {page + 1} (direct list)")
+
+                            # For direct list responses, we need to check if we got the max possible
+                            # If we got less than page_size, we're done
+                            if len(records) == 0:
+                                logger.info("No records in response, pagination complete")
+                                break
+                            elif len(records) < page_size:
+                                # Got fewer than requested, this is the last page
+                                all_records.extend(records)
+                                logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}")
+                                break
+                            else:
+                                # Got full page, there might be more
+                                all_records.extend(records)
+                                logger.info(f"Full page retrieved: {len(records)} records, continuing to next page")
+
+                        elif isinstance(page_data, dict):
+                            # Paginated response format
+                            records = page_data.get('records', page_data.get('data', []))
+                            total_available = page_data.get('total', 0)
+
+                            logger.info(f"Retrieved {len(records)} records from page {page + 1} (paginated response)")
+
+                            if not records:
+                                logger.info("No more records found in paginated response")
+                                break
+
+                            all_records.extend(records)
+
+                            # Check if we've got all available records
+                            if len(all_records) >= total_available:
+                                logger.info(f"Retrieved all available records: {len(all_records)}/{total_available}")
+                                break
+
+                        else:
+                            logger.warning(f"Unexpected response format: {type(page_data)}")
+                            records = []
+                            break
+
+                        page += 1
+
+                        # Safety break to prevent infinite loops
+                        if page > 100:  # Max 500,000 records (100 * 5000)
+                            logger.warning("Reached maximum page limit, stopping pagination")
+                            break
+
+                    elif response.status_code == 401:
+                        logger.error("Authentication failed with gateway",
+                                     tenant_id=tenant_id,
+                                     response_text=response.text)
+                        return []
+
+                    elif response.status_code == 404:
+                        logger.warning("Sales data endpoint not found",
+                                       tenant_id=tenant_id,
+                                       url=response.url)
+                        return []
+
+                    else:
+                        logger.error(f"Gateway request failed: HTTP {response.status_code}",
+                                     tenant_id=tenant_id,
+                                     response_text=response.text)
+                        return []
+
+            logger.info(f"Successfully fetched {len(all_records)} total sales records via gateway",
+                        tenant_id=tenant_id)
+            return all_records

         except httpx.TimeoutException:
             logger.error("Timeout when fetching sales data via gateway",
@@ -195,8 +242,15 @@ class DataServiceClient:

         logger.info(f"Traffic request payload: {payload}", tenant_id=tenant_id)

+        # Madrid traffic data can take 5-10 minutes to download and process
+        timeout_config = httpx.Timeout(
+            connect=30.0,  # Connection timeout
+            read=600.0,    # Read timeout: 10 minutes (was 30s)
+            write=30.0,    # Write timeout
+            pool=30.0      # Pool timeout
+        )
+
         # Make POST request via gateway
-        timeout_config = httpx.Timeout(connect=30.0, read=self.timeout, write=30.0, pool=30.0)
         async with httpx.AsyncClient(timeout=timeout_config) as client:
             response = await client.post(
                 f"{self.base_url}/api/v1/tenants/{tenant_id}/traffic/historical",