Start fixing forecast service API 3
This commit is contained in:
@@ -1,81 +0,0 @@
|
||||
import time
from typing import Any, Dict, Optional

import structlog

from app.core.config import settings
from shared.auth.jwt_handler import JWTHandler
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class ServiceAuthenticator:
    """Handles service-to-service authentication via gateway.

    Creates a JWT service token the gateway middleware accepts, caches it
    until shortly before expiry, and builds the standard headers for
    outbound service requests.
    """

    def __init__(self):
        # JWT handler is configured with the shared secret so tokens it
        # signs validate at the gateway.
        self.jwt_handler = JWTHandler(settings.JWT_SECRET_KEY)
        # Cached token and its absolute expiry (unix seconds); 0 means
        # "no token created yet".
        self._cached_token: Optional[str] = None
        self._token_expires_at: int = 0

    async def get_service_token(self) -> str:
        """
        Get a valid service token, using cache when possible.

        Creates JWT tokens that the gateway will accept.

        Returns:
            Signed JWT access token for this service.

        Raises:
            ValueError: If the JWT handler fails to create the token.
        """
        current_time = int(time.time())

        # Return cached token if still valid (with 5 min buffer)
        if (self._cached_token and
            self._token_expires_at > current_time + 300):
            return self._cached_token

        # Create new service token
        token_expires_at = current_time + 3600  # 1 hour

        service_payload = {
            # Required fields for gateway middleware
            "sub": "training-service",
            "user_id": "training-service",
            "email": "training-service@internal",
            "type": "access",  # Must be "access" for gateway

            # Expiration and timing
            "exp": token_expires_at,
            "iat": current_time,
            "iss": "training-service",

            # Service identification
            "service": "training",
            "full_name": "Training Service",
            "is_verified": True,
            "is_active": True,

            # Optional tenant context (can be overridden per request)
            "tenant_id": None
        }

        try:
            # FIX: keep the try body minimal — only the call that can raise.
            token = self.jwt_handler.create_access_token_from_payload(service_payload)
        except Exception as e:
            logger.error(f"Failed to create service token: {e}")
            # FIX: chain the original exception so the root cause survives.
            raise ValueError(f"Service token creation failed: {e}") from e

        # Cache the token
        self._cached_token = token
        self._token_expires_at = token_expires_at

        logger.debug("Created new service token", expires_at=token_expires_at)
        return token

    def get_request_headers(self, tenant_id: Optional[str] = None) -> Dict[str, str]:
        """Get standard headers for service requests.

        Args:
            tenant_id: Optional tenant identifier; when truthy it is added
                as the X-Tenant-ID header.
        """
        headers = {
            "X-Service": "training-service",
            "User-Agent": "training-service/1.0.0"
        }

        if tenant_id:
            headers["X-Tenant-ID"] = str(tenant_id)

        return headers
|
||||
|
||||
# Global authenticator instance
|
||||
service_auth = ServiceAuthenticator()
|
||||
@@ -1,219 +1,89 @@
|
||||
import httpx
|
||||
# services/training/app/services/data_client.py
|
||||
"""
|
||||
Training Service Data Client
|
||||
Migrated to use shared service clients - much simpler now!
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
# Import the shared clients
|
||||
from shared.clients import get_data_client, get_service_clients
|
||||
from app.core.config import settings
|
||||
from app.core.service_auth import service_auth
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class DataClient:
    """
    Data client for training service.

    Now uses the shared data service client under the hood — transport,
    pagination, error handling and retry logic live in the shared client.
    (FIX: removed leftover httpx-based DataServiceClient implementation
    that the diff had interleaved with this migrated version.)
    """

    def __init__(self):
        # Get the shared data client configured for this service.
        self.data_client = get_data_client(settings, "training")
        # Or alternatively, get all clients at once:
        #   self.clients = get_service_clients(settings, "training")
        #   then use: self.clients.data.get_sales_data(...)
|
||||
async def fetch_weather_data(
|
||||
self,
|
||||
# Or alternatively, get all clients at once:
|
||||
# self.clients = get_service_clients(settings, "training")
|
||||
# Then use: self.clients.data.get_sales_data(...)
|
||||
|
||||
async def fetch_sales_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
latitude: Optional[float] = None,
|
||||
longitude: Optional[float] = None
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
product_id: Optional[str] = None,
|
||||
fetch_all: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch historical weather data for training via API Gateway using POST
|
||||
Fetch sales data for training
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
start_date: Start date in ISO format
|
||||
end_date: End date in ISO format
|
||||
product_id: Optional product filter
|
||||
fetch_all: If True, fetches ALL records using pagination (original behavior)
|
||||
If False, fetches limited records (standard API response)
|
||||
"""
|
||||
try:
|
||||
# Get service token
|
||||
token = await service_auth.get_service_token()
|
||||
|
||||
# Prepare headers
|
||||
headers = service_auth.get_request_headers(tenant_id)
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
# Prepare request payload with proper date handling
|
||||
payload = {
|
||||
"start_date": start_date, # Already in ISO format from calling code
|
||||
"end_date": end_date, # Already in ISO format from calling code
|
||||
"latitude": latitude or 40.4168, # Default Madrid coordinates
|
||||
"longitude": longitude or -3.7038
|
||||
}
|
||||
|
||||
logger.info(f"Weather request payload: {payload}", tenant_id=tenant_id)
|
||||
|
||||
# Make POST request via gateway
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/tenants/{tenant_id}/weather/historical",
|
||||
headers=headers,
|
||||
json=payload
|
||||
if fetch_all:
|
||||
# Use paginated method to get ALL records (original behavior)
|
||||
sales_data = await self.data_client.get_all_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id,
|
||||
aggregation="daily",
|
||||
page_size=5000, # Match original page size
|
||||
max_pages=100 # Safety limit (500k records max)
|
||||
)
|
||||
else:
|
||||
# Use standard method for limited results
|
||||
sales_data = await self.data_client.get_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id,
|
||||
aggregation="daily"
|
||||
)
|
||||
sales_data = sales_data or []
|
||||
|
||||
if sales_data:
|
||||
logger.info(f"Fetched {len(sales_data)} sales records",
|
||||
tenant_id=tenant_id, product_id=product_id, fetch_all=fetch_all)
|
||||
return sales_data
|
||||
else:
|
||||
logger.warning("No sales data returned", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
logger.info(f"Weather data request: {response.status_code}",
|
||||
tenant_id=tenant_id,
|
||||
url=response.url)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.info(f"Successfully fetched {len(data)} weather records")
|
||||
return data
|
||||
elif response.status_code == 400:
|
||||
error_details = response.text
|
||||
logger.error(f"Weather API validation error (400): {error_details}")
|
||||
|
||||
# Try to parse the error and provide helpful info
|
||||
try:
|
||||
error_json = response.json()
|
||||
if 'detail' in error_json:
|
||||
detail = error_json['detail']
|
||||
if 'End date must be after start date' in str(detail):
|
||||
logger.error(f"Date range issue: start={start_date}, end={end_date}")
|
||||
elif 'Date range cannot exceed 90 days' in str(detail):
|
||||
logger.error(f"Date range too large: {start_date} to {end_date}")
|
||||
except:
|
||||
pass
|
||||
|
||||
return []
|
||||
elif response.status_code == 401:
|
||||
logger.error("Authentication failed for weather API")
|
||||
return []
|
||||
else:
|
||||
logger.error(f"Failed to fetch weather data: {response.status_code} - {response.text}")
|
||||
return []
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("Timeout when fetching weather data")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching weather data: {str(e)}")
|
||||
logger.error(f"Error fetching sales data: {e}", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
async def fetch_traffic_data(
|
||||
self,
|
||||
|
||||
async def fetch_weather_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
@@ -221,65 +91,90 @@ class DataServiceClient:
|
||||
longitude: Optional[float] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch historical traffic data for training via API Gateway using POST
|
||||
Fetch weather data for training
|
||||
All the error handling and retry logic is now in the base client!
|
||||
"""
|
||||
try:
|
||||
# Get service token
|
||||
token = await service_auth.get_service_token()
|
||||
|
||||
# Prepare headers
|
||||
headers = service_auth.get_request_headers(tenant_id)
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
# Prepare request payload
|
||||
payload = {
|
||||
"start_date": start_date, # Already in ISO format from calling code
|
||||
"end_date": end_date, # Already in ISO format from calling code
|
||||
"latitude": latitude or 40.4168, # Default Madrid coordinates
|
||||
"longitude": longitude or -3.7038
|
||||
}
|
||||
|
||||
logger.info(f"Traffic request payload: {payload}", tenant_id=tenant_id)
|
||||
|
||||
# Madrid traffic data can take 5-10 minutes to download and process
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=30.0, # Connection timeout
|
||||
read=600.0, # Read timeout: 10 minutes (was 30s)
|
||||
write=30.0, # Write timeout
|
||||
pool=30.0 # Pool timeout
|
||||
weather_data = await self.data_client.get_weather_historical(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
latitude=latitude,
|
||||
longitude=longitude
|
||||
)
|
||||
|
||||
# Make POST request via gateway
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/tenants/{tenant_id}/traffic/historical",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
if weather_data:
|
||||
logger.info(f"Fetched {len(weather_data)} weather records",
|
||||
tenant_id=tenant_id)
|
||||
return weather_data
|
||||
else:
|
||||
logger.warning("No weather data returned", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
logger.info(f"Traffic data request: {response.status_code}",
|
||||
tenant_id=tenant_id,
|
||||
url=response.url)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.info(f"Successfully fetched {len(data)} traffic records")
|
||||
return data
|
||||
elif response.status_code == 400:
|
||||
error_details = response.text
|
||||
logger.error(f"Traffic API validation error (400): {error_details}")
|
||||
return []
|
||||
elif response.status_code == 401:
|
||||
logger.error("Authentication failed for traffic API")
|
||||
return []
|
||||
else:
|
||||
logger.error(f"Failed to fetch traffic data: {response.status_code} - {response.text}")
|
||||
return []
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("Timeout when fetching traffic data")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching traffic data: {str(e)}")
|
||||
return []
|
||||
logger.error(f"Error fetching weather data: {e}", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
async def fetch_traffic_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
latitude: Optional[float] = None,
|
||||
longitude: Optional[float] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch traffic data for training
|
||||
"""
|
||||
try:
|
||||
traffic_data = await self.data_client.get_traffic_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
latitude=latitude,
|
||||
longitude=longitude
|
||||
)
|
||||
|
||||
if traffic_data:
|
||||
logger.info(f"Fetched {len(traffic_data)} traffic records",
|
||||
tenant_id=tenant_id)
|
||||
return traffic_data
|
||||
else:
|
||||
logger.warning("No traffic data returned", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching traffic data: {e}", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
async def validate_data_quality(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate data quality before training
|
||||
"""
|
||||
try:
|
||||
validation_result = await self.data_client.validate_data_quality(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
if validation_result:
|
||||
logger.info("Data validation completed",
|
||||
tenant_id=tenant_id,
|
||||
is_valid=validation_result.get("is_valid", False))
|
||||
return validation_result
|
||||
else:
|
||||
logger.warning("Data validation failed", tenant_id=tenant_id)
|
||||
return {"is_valid": False, "errors": ["Validation service unavailable"]}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating data: {e}", tenant_id=tenant_id)
|
||||
return {"is_valid": False, "errors": [str(e)]}
|
||||
|
||||
# Global instance - same as before, but much simpler implementation
|
||||
data_client = DataClient()
|
||||
@@ -13,7 +13,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timezone
|
||||
import pandas as pd
|
||||
|
||||
from app.services.data_client import DataServiceClient
|
||||
from app.services.data_client import DataClient
|
||||
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -37,7 +37,7 @@ class TrainingDataOrchestrator:
|
||||
madrid_client=None,
|
||||
weather_client=None,
|
||||
date_alignment_service: DateAlignmentService = None):
|
||||
self.data_client = DataServiceClient()
|
||||
self.data_client = DataClient()
|
||||
self.date_alignment_service = date_alignment_service or DateAlignmentService()
|
||||
self.max_concurrent_requests = 3
|
||||
|
||||
|
||||
Reference in New Issue
Block a user