From cd6fd875f7d5c4b984c84da66bebb2aea1736e04 Mon Sep 17 00:00:00 2001 From: Urtzi Alfaro Date: Tue, 29 Jul 2025 12:45:39 +0200 Subject: [PATCH] Improve training test 3 --- gateway/app/routes/tenant.py | 9 +- services/data/app/api/sales.py | 32 +++- services/training/app/ml/prophet_manager.py | 15 +- services/training/app/services/data_client.py | 158 ++++++++++++------ 4 files changed, 154 insertions(+), 60 deletions(-) diff --git a/gateway/app/routes/tenant.py b/gateway/app/routes/tenant.py index a6ba4e63..815624b9 100644 --- a/gateway/app/routes/tenant.py +++ b/gateway/app/routes/tenant.py @@ -173,7 +173,14 @@ async def _proxy_request(request: Request, target_path: str, service_url: str): # Add query parameters params = dict(request.query_params) - async with httpx.AsyncClient(timeout=30.0) as client: + timeout_config = httpx.Timeout( + connect=30.0, # Connection timeout + read=600.0, # Read timeout: 10 minutes (was 30s) + write=30.0, # Write timeout + pool=30.0 # Pool timeout + ) + + async with httpx.AsyncClient(timeout=timeout_config) as client: response = await client.request( method=request.method, url=url, diff --git a/services/data/app/api/sales.py b/services/data/app/api/sales.py index ae188428..c57c4c64 100644 --- a/services/data/app/api/sales.py +++ b/services/data/app/api/sales.py @@ -134,29 +134,53 @@ async def get_sales_data( start_date: Optional[datetime] = Query(None, description="Start date filter"), end_date: Optional[datetime] = Query(None, description="End date filter"), product_name: Optional[str] = Query(None, description="Product name filter"), + # ✅ FIX: Add missing pagination parameters + limit: Optional[int] = Query(1000, le=5000, description="Maximum number of records to return"), + offset: Optional[int] = Query(0, ge=0, description="Number of records to skip"), + # ✅ FIX: Add additional filtering parameters + product_names: Optional[List[str]] = Query(None, description="Multiple product name filters"), + location_ids: Optional[List[str]] = Query(None, description="Location ID filters"), + sources: Optional[List[str]] = Query(None, description="Source filters"), + min_quantity: Optional[int] = Query(None, description="Minimum quantity filter"), + max_quantity: Optional[int] = Query(None, description="Maximum quantity filter"), + min_revenue: Optional[float] = Query(None, description="Minimum revenue filter"), + max_revenue: Optional[float] = Query(None, description="Maximum revenue filter"), current_user: Dict[str, Any] = Depends(get_current_user_dep), db: AsyncSession = Depends(get_db) ): - """Get sales data for tenant with filters""" + """Get sales data for tenant with filters and pagination""" try: logger.debug("Querying sales data", tenant_id=tenant_id, start_date=start_date, end_date=end_date, - product_name=product_name) + product_name=product_name, + limit=limit, + offset=offset) + # ✅ FIX: Create complete SalesDataQuery with all parameters query = SalesDataQuery( tenant_id=tenant_id, start_date=start_date, end_date=end_date, - product_name=product_name + product_names=[product_name] if product_name else product_names, + location_ids=location_ids, + sources=sources, + min_quantity=min_quantity, + max_quantity=max_quantity, + min_revenue=min_revenue, + max_revenue=max_revenue, + limit=limit, # ✅ Now properly passed from query params + offset=offset # ✅ Now properly passed from query params ) records = await SalesService.get_sales_data(query, db) logger.debug("Successfully retrieved sales data", count=len(records), - tenant_id=tenant_id) + tenant_id=tenant_id, + limit=limit, + offset=offset) return records except Exception as e: diff --git a/services/training/app/ml/prophet_manager.py b/services/training/app/ml/prophet_manager.py index eff47a2e..df8ba2e4 100644 --- a/services/training/app/ml/prophet_manager.py +++ b/services/training/app/ml/prophet_manager.py @@ -22,6 +22,7 @@ import warnings warnings.filterwarnings('ignore') from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import text from app.models.training import TrainedModel from app.core.database import get_db_session @@ -565,19 +566,27 @@ class BakeryProphetManager: """Deactivate previous models for the same product""" if self.db_session: try: - # Update previous models to inactive - query = """ + # ✅ FIX: Wrap SQL string with text() for SQLAlchemy 2.0 + query = text(""" UPDATE trained_models SET is_active = false, is_production = false WHERE tenant_id = :tenant_id AND product_name = :product_name - """ + """) + await self.db_session.execute(query, { "tenant_id": tenant_id, "product_name": product_name }) + # ✅ ADD: Commit the transaction + await self.db_session.commit() + + logger.info(f"Successfully deactivated previous models for {product_name}") + except Exception as e: logger.error(f"Failed to deactivate previous models: {str(e)}") + # ✅ ADD: Rollback on error + await self.db_session.rollback() # Keep all existing methods unchanged async def generate_forecast(self, diff --git a/services/training/app/services/data_client.py b/services/training/app/services/data_client.py index a82b8b67..1bf29abd 100644 --- a/services/training/app/services/data_client.py +++ b/services/training/app/services/data_client.py @@ -12,17 +12,11 @@ class DataServiceClient: def __init__(self): self.base_url = settings.API_GATEWAY_URL self.timeout = 2000.0 - - async def fetch_sales_data( - self, - tenant_id: str, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - product_name: Optional[str] = None - ) -> List[Dict[str, Any]]: + + async def fetch_sales_data(self, tenant_id: str) -> List[Dict[str, Any]]: """ - Fetch sales data for training via API Gateway - ✅ Uses proper service authentication + Fetch all sales data for training (no pagination limits) + FIXED: Retrieves ALL records instead of being limited to 1000 """ try: # Get service token @@ -32,51 +26,104 @@ class DataServiceClient: headers = service_auth.get_request_headers(tenant_id) headers["Authorization"] = f"Bearer {token}" - # Prepare query parameters - params = {} - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - if product_name: - params["product_name"] = product_name + all_records = [] + page = 0 + page_size = 5000 # Use maximum allowed by API - # Make request via gateway - async with httpx.AsyncClient(timeout=self.timeout) as client: - response = await client.get( - f"{self.base_url}/api/v1/tenants/{tenant_id}/sales", - headers=headers, - params=params - ) + while True: + # Prepare query parameters for pagination + params = { + "limit": page_size, + "offset": page * page_size + } - logger.info(f"Sales data request: {response.status_code}", - tenant_id=tenant_id, - url=response.url) + logger.info(f"Fetching sales data page {page + 1} (offset: {page * page_size})", + tenant_id=tenant_id) - if response.status_code == 200: - data = response.json() - logger.info(f"Successfully fetched {len(data)} sales records via gateway", - tenant_id=tenant_id) - return data + # Make GET request via gateway with pagination + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.get( + f"{self.base_url}/api/v1/tenants/{tenant_id}/sales", + headers=headers, + params=params + ) - elif response.status_code == 401: - logger.error("Authentication failed with gateway", - tenant_id=tenant_id, - response_text=response.text) - return [] - - elif response.status_code == 404: - logger.warning("Sales data endpoint not found", - tenant_id=tenant_id, - url=response.url) - return [] - - else: - logger.error(f"Gateway request failed: HTTP {response.status_code}", - tenant_id=tenant_id, - response_text=response.text) - return [] + if response.status_code == 200: + page_data = response.json() + + # Handle different response formats + if isinstance(page_data, list): + # Direct list response (no pagination metadata) + records = page_data + logger.info(f"Retrieved {len(records)} records from page {page + 1} (direct list)") + + # For direct list responses, we need to check if we got the max possible + # If we got less than page_size, we're done + if len(records) == 0: + logger.info("No records in response, pagination complete") + break + elif len(records) < page_size: + # Got fewer than requested, this is the last page + all_records.extend(records) + logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}") + break + else: + # Got full page, there might be more + all_records.extend(records) + logger.info(f"Full page retrieved: {len(records)} records, continuing to next page") + + elif isinstance(page_data, dict): + # Paginated response format + records = page_data.get('records', page_data.get('data', [])) + total_available = page_data.get('total', 0) + + logger.info(f"Retrieved {len(records)} records from page {page + 1} (paginated response)") + + if not records: + logger.info("No more records found in paginated response") + break + + all_records.extend(records) + + # Check if we've got all available records + if len(all_records) >= total_available: + logger.info(f"Retrieved all available records: {len(all_records)}/{total_available}") + break + + else: + logger.warning(f"Unexpected response format: {type(page_data)}") + records = [] + break + + page += 1 + + # Safety break to prevent infinite loops + if page > 100: # Max 500,000 records (100 * 5000) + logger.warning("Reached maximum page limit, stopping pagination") + break + elif response.status_code == 401: + logger.error("Authentication failed with gateway", + tenant_id=tenant_id, + response_text=response.text) + return [] + + elif response.status_code == 404: + logger.warning("Sales data endpoint not found", + tenant_id=tenant_id, + url=response.url) + return [] + + else: + logger.error(f"Gateway request failed: HTTP {response.status_code}", + tenant_id=tenant_id, + response_text=response.text) + return [] + + logger.info(f"Successfully fetched {len(all_records)} total sales records via gateway", + tenant_id=tenant_id) + return all_records + except httpx.TimeoutException: logger.error("Timeout when fetching sales data via gateway", tenant_id=tenant_id) @@ -86,7 +133,7 @@ class DataServiceClient: logger.error(f"Error fetching sales data via gateway: {e}", tenant_id=tenant_id) return [] - + async def fetch_weather_data( self, @@ -195,8 +242,15 @@ class DataServiceClient: logger.info(f"Traffic request payload: {payload}", tenant_id=tenant_id) + # Madrid traffic data can take 5-10 minutes to download and process + timeout_config = httpx.Timeout( + connect=30.0, # Connection timeout + read=600.0, # Read timeout: 10 minutes (was 30s) + write=30.0, # Write timeout + pool=30.0 # Pool timeout + ) + # Make POST request via gateway - timeout_config = httpx.Timeout(connect=30.0, read=self.timeout, write=30.0, pool=30.0) async with httpx.AsyncClient(timeout=timeout_config) as client: response = await client.post( f"{self.base_url}/api/v1/tenants/{tenant_id}/traffic/historical",