Improve training test 3
@@ -22,6 +22,7 @@ import warnings
 warnings.filterwarnings('ignore')
 
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import text
 from app.models.training import TrainedModel
 from app.core.database import get_db_session
 
@@ -565,19 +566,27 @@ class BakeryProphetManager:
         """Deactivate previous models for the same product"""
         if self.db_session:
             try:
                 # Update previous models to inactive
-                query = """
+                # ✅ FIX: Wrap SQL string with text() for SQLAlchemy 2.0
+                query = text("""
                     UPDATE trained_models
                     SET is_active = false, is_production = false
                     WHERE tenant_id = :tenant_id AND product_name = :product_name
-                """
+                """)
 
                 await self.db_session.execute(query, {
                     "tenant_id": tenant_id,
                     "product_name": product_name
                 })
 
+                # ✅ ADD: Commit the transaction
+                await self.db_session.commit()
+
                 logger.info(f"Successfully deactivated previous models for {product_name}")
 
             except Exception as e:
                 logger.error(f"Failed to deactivate previous models: {str(e)}")
+                # ✅ ADD: Rollback on error
+                await self.db_session.rollback()
 
     # Keep all existing methods unchanged
     async def generate_forecast(self,
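For context on the hunk above: SQLAlchemy 2.0 no longer accepts raw SQL strings in Session.execute(), so textual queries must be wrapped in text(), and the explicit commit()/rollback() pair makes the UPDATE durable while keeping the session usable after a failure. A minimal standalone sketch of that pattern (the function name and error handling are illustrative, not from the repo):

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

async def deactivate_previous_models(session: AsyncSession, tenant_id: str, product_name: str) -> None:
    # SQLAlchemy 2.0 requires textual SQL to be wrapped in text()
    query = text("""
        UPDATE trained_models
        SET is_active = false, is_production = false
        WHERE tenant_id = :tenant_id AND product_name = :product_name
    """)
    try:
        await session.execute(query, {"tenant_id": tenant_id, "product_name": product_name})
        await session.commit()    # the UPDATE is not durable until committed
    except Exception:
        await session.rollback()  # reset the failed transaction so the session stays usable
        raise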
@@ -12,17 +12,11 @@ class DataServiceClient:
     def __init__(self):
         self.base_url = settings.API_GATEWAY_URL
         self.timeout = 2000.0
 
-    async def fetch_sales_data(
-        self,
-        tenant_id: str,
-        start_date: Optional[str] = None,
-        end_date: Optional[str] = None,
-        product_name: Optional[str] = None
-    ) -> List[Dict[str, Any]]:
+    async def fetch_sales_data(self, tenant_id: str) -> List[Dict[str, Any]]:
         """
-        Fetch sales data for training via API Gateway
-        ✅ Uses proper service authentication
+        Fetch all sales data for training (no pagination limits)
+        FIXED: Retrieves ALL records instead of being limited to 1000
         """
         try:
             # Get service token
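Note that the simplified signature silently drops the start_date, end_date, and product_name filters, which is a breaking change for any caller that passed them. A hypothetical before/after at a call site:

client = DataServiceClient()

# Before: callers could scope the query server-side
# sales = await client.fetch_sales_data(tenant_id, start_date="2024-01-01", product_name="croissant")

# After: fetch everything for the tenant and filter client-side if needed
sales = await client.fetch_sales_data(tenant_id)
croissant_sales = [r for r in sales if r.get("product_name") == "croissant"]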
@@ -32,51 +26,104 @@ class DataServiceClient:
             headers = service_auth.get_request_headers(tenant_id)
             headers["Authorization"] = f"Bearer {token}"
 
-            # Prepare query parameters
-            params = {}
-            if start_date:
-                params["start_date"] = start_date
-            if end_date:
-                params["end_date"] = end_date
-            if product_name:
-                params["product_name"] = product_name
+            all_records = []
+            page = 0
+            page_size = 5000  # Use maximum allowed by API
 
-            # Make request via gateway
-            async with httpx.AsyncClient(timeout=self.timeout) as client:
-                response = await client.get(
-                    f"{self.base_url}/api/v1/tenants/{tenant_id}/sales",
-                    headers=headers,
-                    params=params
-                )
+            while True:
+                # Prepare query parameters for pagination
+                params = {
+                    "limit": page_size,
+                    "offset": page * page_size
+                }
 
-                logger.info(f"Sales data request: {response.status_code}",
-                           tenant_id=tenant_id,
-                           url=response.url)
+                logger.info(f"Fetching sales data page {page + 1} (offset: {page * page_size})",
+                           tenant_id=tenant_id)
 
-                if response.status_code == 200:
-                    data = response.json()
-                    logger.info(f"Successfully fetched {len(data)} sales records via gateway",
-                               tenant_id=tenant_id)
-                    return data
+                # Make GET request via gateway with pagination
+                async with httpx.AsyncClient(timeout=self.timeout) as client:
+                    response = await client.get(
+                        f"{self.base_url}/api/v1/tenants/{tenant_id}/sales",
+                        headers=headers,
+                        params=params
+                    )
 
-                elif response.status_code == 401:
-                    logger.error("Authentication failed with gateway",
-                                tenant_id=tenant_id,
-                                response_text=response.text)
-                    return []
-
-                elif response.status_code == 404:
-                    logger.warning("Sales data endpoint not found",
-                                  tenant_id=tenant_id,
-                                  url=response.url)
-                    return []
-
-                else:
-                    logger.error(f"Gateway request failed: HTTP {response.status_code}",
-                                tenant_id=tenant_id,
-                                response_text=response.text)
-                    return []
+                if response.status_code == 200:
+                    page_data = response.json()
+
+                    # Handle different response formats
+                    if isinstance(page_data, list):
+                        # Direct list response (no pagination metadata)
+                        records = page_data
+                        logger.info(f"Retrieved {len(records)} records from page {page + 1} (direct list)")
+
+                        # For direct list responses, we need to check if we got the max possible
+                        # If we got less than page_size, we're done
+                        if len(records) == 0:
+                            logger.info("No records in response, pagination complete")
+                            break
+                        elif len(records) < page_size:
+                            # Got fewer than requested, this is the last page
+                            all_records.extend(records)
+                            logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}")
+                            break
+                        else:
+                            # Got full page, there might be more
+                            all_records.extend(records)
+                            logger.info(f"Full page retrieved: {len(records)} records, continuing to next page")
+
+                    elif isinstance(page_data, dict):
+                        # Paginated response format
+                        records = page_data.get('records', page_data.get('data', []))
+                        total_available = page_data.get('total', 0)
+
+                        logger.info(f"Retrieved {len(records)} records from page {page + 1} (paginated response)")
+
+                        if not records:
+                            logger.info("No more records found in paginated response")
+                            break
+
+                        all_records.extend(records)
+
+                        # Check if we've got all available records
+                        if len(all_records) >= total_available:
+                            logger.info(f"Retrieved all available records: {len(all_records)}/{total_available}")
+                            break
+
+                    else:
+                        logger.warning(f"Unexpected response format: {type(page_data)}")
+                        records = []
+                        break
+
+                    page += 1
+
+                    # Safety break to prevent infinite loops
+                    if page > 100:  # Max 500,000 records (100 * 5000)
+                        logger.warning("Reached maximum page limit, stopping pagination")
+                        break
+
+                elif response.status_code == 401:
+                    logger.error("Authentication failed with gateway",
+                                tenant_id=tenant_id,
+                                response_text=response.text)
+                    return []
+
+                elif response.status_code == 404:
+                    logger.warning("Sales data endpoint not found",
+                                  tenant_id=tenant_id,
+                                  url=response.url)
+                    return []
+
+                else:
+                    logger.error(f"Gateway request failed: HTTP {response.status_code}",
+                                tenant_id=tenant_id,
+                                response_text=response.text)
+                    return []
+
+            logger.info(f"Successfully fetched {len(all_records)} total sales records via gateway",
+                       tenant_id=tenant_id)
+            return all_records
 
         except httpx.TimeoutException:
             logger.error("Timeout when fetching sales data via gateway",
                         tenant_id=tenant_id)
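Stripped of the logging and the two response formats, the loop above is standard limit/offset pagination: request full pages until a short or empty page signals the end. A condensed sketch of just that termination logic (endpoint and names are illustrative, and unlike the diff it reuses one AsyncClient across pages rather than opening a new one per request):

import httpx

async def fetch_all(url: str, headers: dict, page_size: int = 5000) -> list:
    records: list = []
    page = 0
    async with httpx.AsyncClient(timeout=2000.0) as client:
        while page <= 100:  # same safety cap as the diff: at most 100 pages
            resp = await client.get(url, headers=headers,
                                    params={"limit": page_size, "offset": page * page_size})
            resp.raise_for_status()
            batch = resp.json()
            records.extend(batch)
            if len(batch) < page_size:  # short or empty page means the data is drained
                break
            page += 1
    return records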
@@ -86,7 +133,7 @@ class DataServiceClient:
             logger.error(f"Error fetching sales data via gateway: {e}",
                         tenant_id=tenant_id)
             return []
-
+
 
     async def fetch_weather_data(
         self,
@@ -195,8 +242,15 @@ class DataServiceClient:
 
             logger.info(f"Traffic request payload: {payload}", tenant_id=tenant_id)
 
+            # Madrid traffic data can take 5-10 minutes to download and process
+            timeout_config = httpx.Timeout(
+                connect=30.0,   # Connection timeout
+                read=600.0,     # Read timeout: 10 minutes (was 30s)
+                write=30.0,     # Write timeout
+                pool=30.0       # Pool timeout
+            )
+
             # Make POST request via gateway
-            timeout_config = httpx.Timeout(connect=30.0, read=self.timeout, write=30.0, pool=30.0)
             async with httpx.AsyncClient(timeout=timeout_config) as client:
                 response = await client.post(
                     f"{self.base_url}/api/v1/tenants/{tenant_id}/traffic/historical",
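For background: httpx applies a single default timeout to the connect, read, write, and pool-acquire phases alike, so raising only the read timeout via httpx.Timeout lets a slow server-side export run for minutes while connection failures are still detected in seconds. A small sketch of the same idea (the URL and payload are placeholders):

import httpx

# Fail fast on connection problems, but tolerate a long server-side computation
LONG_READ = httpx.Timeout(connect=30.0, read=600.0, write=30.0, pool=30.0)

async def post_with_long_read(url: str, payload: dict) -> httpx.Response:
    async with httpx.AsyncClient(timeout=LONG_READ) as client:
        return await client.post(url, json=payload)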