Start fixing forecast service API 3

Urtzi Alfaro
2025-07-29 15:08:55 +02:00
parent dfb619a7b5
commit 84ed4a7a2e
14 changed files with 1607 additions and 447 deletions

106
shared/clients/__init__.py Normal file

@@ -0,0 +1,106 @@
# shared/clients/__init__.py
"""
Service Client Factory and Convenient Imports
Provides easy access to all service clients
"""
from .base_service_client import BaseServiceClient, ServiceAuthenticator
from .training_client import TrainingServiceClient
from .data_client import DataServiceClient
from .forecast_client import ForecastServiceClient
# Import config
from shared.config.base import BaseServiceSettings
# Cache clients to avoid recreating them
_client_cache = {}
def get_training_client(config: BaseServiceSettings = None, service_name: str = "unknown") -> TrainingServiceClient:
"""Get or create a training service client"""
if config is None:
from app.core.config import settings as config
cache_key = f"training_{service_name}"
if cache_key not in _client_cache:
_client_cache[cache_key] = TrainingServiceClient(config, service_name)
return _client_cache[cache_key]
def get_data_client(config: BaseServiceSettings = None, service_name: str = "unknown") -> DataServiceClient:
"""Get or create a data service client"""
if config is None:
from app.core.config import settings as config
cache_key = f"data_{service_name}"
if cache_key not in _client_cache:
_client_cache[cache_key] = DataServiceClient(config, service_name)
return _client_cache[cache_key]
def get_forecast_client(config: BaseServiceSettings = None, service_name: str = "unknown") -> ForecastServiceClient:
"""Get or create a forecast service client"""
if config is None:
from app.core.config import settings as config
cache_key = f"forecast_{service_name}"
if cache_key not in _client_cache:
_client_cache[cache_key] = ForecastServiceClient(config, service_name)
return _client_cache[cache_key]
class ServiceClients:
"""Convenient wrapper for all service clients"""
def __init__(self, config: BaseServiceSettings = None, service_name: str = "unknown"):
self.service_name = service_name
self.config = config or self._get_default_config()
# Initialize clients lazily
self._training_client = None
self._data_client = None
self._forecast_client = None
def _get_default_config(self):
"""Get default config from app settings"""
try:
from app.core.config import settings
return settings
except ImportError:
raise ImportError("Could not import app config. Please provide config explicitly.")
@property
def training(self) -> TrainingServiceClient:
"""Get training service client"""
if self._training_client is None:
self._training_client = get_training_client(self.config, self.service_name)
return self._training_client
@property
def data(self) -> DataServiceClient:
"""Get data service client"""
if self._data_client is None:
self._data_client = get_data_client(self.config, self.service_name)
return self._data_client
@property
def forecast(self) -> ForecastServiceClient:
"""Get forecast service client"""
if self._forecast_client is None:
self._forecast_client = get_forecast_client(self.config, self.service_name)
return self._forecast_client
# Convenience function to get all clients
def get_service_clients(config: BaseServiceSettings = None, service_name: str = "unknown") -> ServiceClients:
"""Get a wrapper with all service clients"""
return ServiceClients(config, service_name)
# Export all classes for direct import
__all__ = [
'BaseServiceClient',
'ServiceAuthenticator',
'TrainingServiceClient',
'DataServiceClient',
'ForecastServiceClient',
'ServiceClients',
'get_training_client',
'get_data_client',
'get_forecast_client',
'get_service_clients'
]
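
For reference, a minimal usage sketch of the factory above, as it might run inside one of the services. The tenant ID, the async entrypoint, and the `latest_model["id"]` field name are hypothetical; when no config is passed, the factories fall back to the caller's `app.core.config.settings` as shown above.

# Usage sketch (hypothetical tenant ID and response field names)
import asyncio
from shared.clients import get_service_clients

async def main():
    # One wrapper, lazily instantiated clients for all three services
    clients = get_service_clients(service_name="forecasting")
    products = await clients.data.get_products(tenant_id="tenant-123")
    latest_model = await clients.training.get_latest_model(tenant_id="tenant-123")
    if products and latest_model:
        forecast = await clients.forecast.create_forecast(
            tenant_id="tenant-123",
            model_id=latest_model["id"],  # field name assumed
            start_date="2025-08-01",
            end_date="2025-08-07",
        )
        print(forecast)

asyncio.run(main())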

363
shared/clients/base_service_client.py Normal file

@@ -0,0 +1,363 @@
# shared/clients/base_service_client.py
"""
Base Service Client for Inter-Service Communication
Provides a reusable foundation for all service-to-service API calls
"""
import time
import asyncio
import httpx
import structlog
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List, Union
from urllib.parse import urljoin
from shared.auth.jwt_handler import JWTHandler
from shared.config.base import BaseServiceSettings
logger = structlog.get_logger()
class ServiceAuthenticator:
"""Handles service-to-service authentication via gateway"""
def __init__(self, service_name: str, config: BaseServiceSettings):
self.service_name = service_name
self.config = config
self.jwt_handler = JWTHandler(config.JWT_SECRET_KEY)
self._cached_token = None
self._token_expires_at = 0
async def get_service_token(self) -> str:
"""Get a valid service token, using cache when possible"""
current_time = int(time.time())
# Return cached token if still valid (with 5 min buffer)
if (self._cached_token and
self._token_expires_at > current_time + 300):
return self._cached_token
# Create new service token
token_expires_at = current_time + 3600 # 1 hour
service_payload = {
"sub": f"{self.service_name}-service",
"user_id": f"{self.service_name}-service",
"email": f"{self.service_name}-service@internal",
"type": "access",
"exp": token_expires_at,
"iat": current_time,
"iss": f"{self.service_name}-service",
"service": self.service_name,
"full_name": f"{self.service_name.title()} Service",
"is_verified": True,
"is_active": True,
"tenant_id": None
}
try:
token = self.jwt_handler.create_access_token_from_payload(service_payload)
self._cached_token = token
self._token_expires_at = token_expires_at
logger.debug("Created new service token", service=self.service_name, expires_at=token_expires_at)
return token
except Exception as e:
logger.error(f"Failed to create service token: {e}", service=self.service_name)
raise ValueError(f"Service token creation failed: {e}")
def get_request_headers(self, tenant_id: Optional[str] = None) -> Dict[str, str]:
"""Get standard headers for service requests"""
headers = {
"X-Service": f"{self.service_name}-service",
"User-Agent": f"{self.service_name}-service/1.0.0"
}
if tenant_id:
headers["X-Tenant-ID"] = str(tenant_id)
return headers
class BaseServiceClient(ABC):
"""
Base class for all inter-service communication clients
Provides common functionality for API calls through the gateway
"""
def __init__(self, service_name: str, config: BaseServiceSettings):
self.service_name = service_name
self.config = config
self.gateway_url = config.GATEWAY_URL
self.authenticator = ServiceAuthenticator(service_name, config)
# HTTP client configuration
self.timeout = config.HTTP_TIMEOUT
self.retries = config.HTTP_RETRIES
self.retry_delay = config.HTTP_RETRY_DELAY
@abstractmethod
def get_service_base_path(self) -> str:
"""Return the base path for this service's APIs"""
pass
async def _make_request(
self,
method: str,
endpoint: str,
tenant_id: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[Union[int, httpx.Timeout]] = None
) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]:
"""
Make an authenticated request to another service via gateway
Args:
method: HTTP method (GET, POST, PUT, DELETE)
endpoint: API endpoint (will be prefixed with service base path)
tenant_id: Optional tenant ID for tenant-scoped requests
data: Request body data (for POST/PUT)
params: Query parameters
headers: Additional headers
timeout: Request timeout override
Returns:
Response data or None if request failed
"""
try:
# Get service token
token = await self.authenticator.get_service_token()
# Build headers
request_headers = self.authenticator.get_request_headers(tenant_id)
request_headers["Authorization"] = f"Bearer {token}"
request_headers["Content-Type"] = "application/json"
if headers:
request_headers.update(headers)
# Build URL
base_path = self.get_service_base_path()
if tenant_id:
# For tenant-scoped endpoints
full_endpoint = f"{base_path}/tenants/{tenant_id}/{endpoint.lstrip('/')}"
else:
# For non-tenant endpoints
full_endpoint = f"{base_path}/{endpoint.lstrip('/')}"
url = urljoin(self.gateway_url, full_endpoint)
# Make request with retries
for attempt in range(self.retries + 1):
try:
# Handle different timeout configurations
if isinstance(timeout, httpx.Timeout):
client_timeout = timeout
else:
client_timeout = timeout or self.timeout
async with httpx.AsyncClient(timeout=client_timeout) as client:
response = await client.request(
method=method,
url=url,
json=data,
params=params,
headers=request_headers
)
if response.status_code == 200:
return response.json()
elif response.status_code == 201:
return response.json()
elif response.status_code == 204:
return {} # No content success
elif response.status_code == 401:
    # Token might be expired: clear the cached token, mint a fresh one,
    # refresh the Authorization header, and retry once
    if attempt == 0:
        self.authenticator._cached_token = None
        token = await self.authenticator.get_service_token()
        request_headers["Authorization"] = f"Bearer {token}"
        logger.warning("Token rejected, retrying with new token")
        continue
    else:
        logger.error("Authentication failed after retry")
        return None
elif response.status_code == 404:
logger.warning(f"Endpoint not found: {url}")
return None
else:
error_detail = "Unknown error"
try:
error_json = response.json()
error_detail = error_json.get('detail', f"HTTP {response.status_code}")
except Exception:
error_detail = f"HTTP {response.status_code}: {response.text}"
logger.error(f"Request failed: {error_detail}",
url=url, status_code=response.status_code)
return None
except httpx.TimeoutException:
if attempt < self.retries:
logger.warning(f"Request timeout, retrying ({attempt + 1}/{self.retries})")
await asyncio.sleep(self.retry_delay * (2 ** attempt)) # Exponential backoff
continue
else:
logger.error(f"Request timeout after {self.retries} retries", url=url)
return None
except Exception as e:
if attempt < self.retries:
logger.warning(f"Request failed, retrying ({attempt + 1}/{self.retries}): {e}")
await asyncio.sleep(self.retry_delay * (2 ** attempt))
continue
else:
logger.error(f"Request failed after {self.retries} retries: {e}", url=url)
return None
except Exception as e:
logger.error(f"Unexpected error in _make_request: {e}")
return None
async def _make_paginated_request(
self,
endpoint: str,
tenant_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
page_size: int = 5000,
max_pages: int = 100,
timeout: Optional[Union[int, httpx.Timeout]] = None
) -> List[Dict[str, Any]]:
"""
Make paginated GET requests to fetch all records
Handles both direct list and paginated object responses
Args:
endpoint: API endpoint
tenant_id: Optional tenant ID
params: Base query parameters
page_size: Records per page (default 5000)
max_pages: Maximum pages to fetch (safety limit)
timeout: Request timeout override
Returns:
List of all records from all pages
"""
all_records = []
page = 0
base_params = params or {}
logger.info(f"Starting paginated request to {endpoint}",
tenant_id=tenant_id, page_size=page_size)
while page < max_pages:
# Prepare pagination parameters
page_params = base_params.copy()
page_params.update({
"limit": page_size,
"offset": page * page_size
})
logger.debug(f"Fetching page {page + 1} (offset: {page * page_size})",
tenant_id=tenant_id)
# Make request for this page
result = await self._make_request(
"GET",
endpoint,
tenant_id=tenant_id,
params=page_params,
timeout=timeout
)
if result is None:
logger.error(f"Failed to fetch page {page + 1}", tenant_id=tenant_id)
break
# Handle different response formats
if isinstance(result, list):
# Direct list response (no pagination metadata)
records = result
logger.debug(f"Retrieved {len(records)} records from page {page + 1} (direct list)")
if len(records) == 0:
logger.info("No records in response, pagination complete")
break
elif len(records) < page_size:
# Got fewer than requested, this is the last page
all_records.extend(records)
logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}")
break
else:
# Got full page, there might be more
all_records.extend(records)
logger.debug(f"Full page retrieved: {len(records)} records, continuing to next page")
elif isinstance(result, dict):
    # Paginated response format
    records = result.get('records', result.get('data', []))
    total_available = result.get('total')
    logger.debug(f"Retrieved {len(records)} records from page {page + 1} (paginated response)")
    if not records:
        logger.info("No more records found in paginated response")
        break
    all_records.extend(records)
    # Stop once the server-reported total is reached (if provided) or a short page arrives
    if total_available is not None and len(all_records) >= total_available:
        logger.info(f"Retrieved all available records: {len(all_records)}/{total_available}")
        break
    if len(records) < page_size:
        logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}")
        break
else:
logger.warning(f"Unexpected response format: {type(result)}")
break
page += 1
if page >= max_pages:
logger.warning(f"Reached maximum page limit ({max_pages}), stopping pagination")
logger.info(f"Pagination complete: fetched {len(all_records)} total records",
tenant_id=tenant_id, pages_fetched=page)
return all_records
async def get(self, endpoint: str, tenant_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
"""Make a GET request"""
return await self._make_request("GET", endpoint, tenant_id=tenant_id, params=params)
async def get_paginated(
self,
endpoint: str,
tenant_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
page_size: int = 5000,
max_pages: int = 100,
timeout: Optional[Union[int, httpx.Timeout]] = None
) -> List[Dict[str, Any]]:
"""Make a paginated GET request to fetch all records"""
return await self._make_paginated_request(
endpoint,
tenant_id=tenant_id,
params=params,
page_size=page_size,
max_pages=max_pages,
timeout=timeout
)
async def post(self, endpoint: str, data: Dict[str, Any], tenant_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Make a POST request"""
return await self._make_request("POST", endpoint, tenant_id=tenant_id, data=data)
async def put(self, endpoint: str, data: Dict[str, Any], tenant_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Make a PUT request"""
return await self._make_request("PUT", endpoint, tenant_id=tenant_id, data=data)
async def delete(self, endpoint: str, tenant_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Make a DELETE request"""
return await self._make_request("DELETE", endpoint, tenant_id=tenant_id)

399
shared/clients/data_client.py Normal file

@@ -0,0 +1,399 @@
# shared/clients/data_client.py
"""
Data Service Client
Handles all API calls to the data service
"""
import httpx
import structlog
from typing import Dict, Any, Optional, List, Union
from .base_service_client import BaseServiceClient
from shared.config.base import BaseServiceSettings
logger = structlog.get_logger()
class DataServiceClient(BaseServiceClient):
"""Client for communicating with the data service"""
def __init__(self, config: BaseServiceSettings, calling_service_name: str = "unknown"):
super().__init__(calling_service_name, config)
def get_service_base_path(self) -> str:
return "/api/v1"
# ================================================================
# SALES DATA (with advanced pagination support)
# ================================================================
async def get_sales_data(
self,
tenant_id: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
product_id: Optional[str] = None,
aggregation: str = "daily"
) -> Optional[List[Dict[str, Any]]]:
"""Get sales data for a date range"""
params = {"aggregation": aggregation}
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
if product_id:
params["product_id"] = product_id
result = await self.get("sales", tenant_id=tenant_id, params=params)
return result.get("sales", []) if result else None
async def get_all_sales_data(
self,
tenant_id: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
product_id: Optional[str] = None,
aggregation: str = "daily",
page_size: int = 5000,
max_pages: int = 100
) -> List[Dict[str, Any]]:
"""
Get ALL sales data using pagination (equivalent to original fetch_sales_data)
Retrieves all records without pagination limits
"""
params = {"aggregation": aggregation}
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
if product_id:
params["product_id"] = product_id
# Use the inherited paginated request method
try:
all_records = await self.get_paginated(
"sales",
tenant_id=tenant_id,
params=params,
page_size=page_size,
max_pages=max_pages,
timeout=2000.0 # Match original timeout
)
logger.info(f"Successfully fetched {len(all_records)} total sales records via gateway",
tenant_id=tenant_id)
return all_records
except AttributeError as e:
# Fallback: implement pagination directly if inheritance isn't working
logger.warning(f"Using fallback pagination due to: {e}")
return await self._fallback_paginated_sales(tenant_id, params, page_size, max_pages)
async def _fallback_paginated_sales(
self,
tenant_id: str,
base_params: Dict[str, Any],
page_size: int = 5000,
max_pages: int = 100
) -> List[Dict[str, Any]]:
"""
Fallback pagination implementation for sales data
This replicates the original pagination logic directly
"""
all_records = []
page = 0
logger.info(f"Starting fallback paginated request for sales data",
tenant_id=tenant_id, page_size=page_size)
while page < max_pages:
# Prepare pagination parameters
params = base_params.copy()
params.update({
"limit": page_size,
"offset": page * page_size
})
logger.info(f"Fetching sales data page {page + 1} (offset: {page * page_size})",
tenant_id=tenant_id)
# Make request using the base client's _make_request method
result = await self._make_request(
"GET",
"sales",
tenant_id=tenant_id,
params=params,
timeout=2000.0
)
if result is None:
logger.error(f"Failed to fetch page {page + 1}", tenant_id=tenant_id)
break
# Handle different response formats (from the original implementation)
if isinstance(result, list):
# Direct list response (no pagination metadata)
records = result
logger.info(f"Retrieved {len(records)} records from page {page + 1} (direct list)")
if len(records) == 0:
logger.info("No records in response, pagination complete")
break
elif len(records) < page_size:
# Got fewer than requested, this is the last page
all_records.extend(records)
logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}")
break
else:
# Got full page, there might be more
all_records.extend(records)
logger.info(f"Full page retrieved: {len(records)} records, continuing to next page")
elif isinstance(result, dict):
# Paginated response format
records = result.get('records', result.get('data', []))
total_available = result.get('total', 0)
logger.info(f"Retrieved {len(records)} records from page {page + 1} (paginated response)")
if not records:
logger.info("No more records found in paginated response")
break
all_records.extend(records)
# Check if we've got all available records
if len(all_records) >= total_available:
logger.info(f"Retrieved all available records: {len(all_records)}/{total_available}")
break
else:
logger.warning(f"Unexpected response format: {type(result)}")
break
page += 1
logger.info(f"Fallback pagination complete: fetched {len(all_records)} total records",
tenant_id=tenant_id, pages_fetched=page)
return all_records
async def upload_sales_data(
self,
tenant_id: str,
sales_data: List[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
"""Upload sales data"""
data = {"sales": sales_data}
return await self.post("sales", data=data, tenant_id=tenant_id)
# ================================================================
# WEATHER DATA
# ================================================================
async def get_weather_historical(
self,
tenant_id: str,
start_date: str,
end_date: str,
latitude: Optional[float] = None,
longitude: Optional[float] = None
) -> Optional[List[Dict[str, Any]]]:
"""
Get weather data for a date range and location
Uses POST request as per original implementation
"""
# Prepare request payload with proper date handling
payload = {
"start_date": start_date, # Already in ISO format from calling code
"end_date": end_date, # Already in ISO format from calling code
"latitude": latitude or 40.4168, # Default Madrid coordinates
"longitude": longitude or -3.7038
}
logger.info(f"Weather request payload: {payload}", tenant_id=tenant_id)
# Use POST request with extended timeout
result = await self._make_request(
"POST",
"weather/historical",
tenant_id=tenant_id,
data=payload,
timeout=2000.0 # Match original timeout
)
if result:
logger.info(f"Successfully fetched {len(result)} weather records")
return result
else:
logger.error("Failed to fetch weather data")
return []
async def get_weather_forecast(
    self,
    tenant_id: str,
    days: int,
    latitude: Optional[float] = None,
    longitude: Optional[float] = None
) -> Optional[List[Dict[str, Any]]]:
    """
    Get the weather forecast for the next `days` days at a location
    Uses POST request as per original implementation
    """
    # Prepare request payload
    payload = {
        "days": days, # Number of forecast days requested
        "latitude": latitude or 40.4168, # Default Madrid coordinates
        "longitude": longitude or -3.7038
    }
    logger.info(f"Weather forecast request payload: {payload}", tenant_id=tenant_id)
# Use POST request with extended timeout
result = await self._make_request(
"POST",
"weather/historical",
tenant_id=tenant_id,
data=payload,
timeout=2000.0 # Match original timeout
)
if result:
    logger.info(f"Successfully fetched {len(result)} weather forecast records for {days} days")
    return result
else:
    logger.error("Failed to fetch weather forecast data")
    return []
# ================================================================
# TRAFFIC DATA
# ================================================================
async def get_traffic_data(
self,
tenant_id: str,
start_date: str,
end_date: str,
latitude: Optional[float] = None,
longitude: Optional[float] = None
) -> Optional[List[Dict[str, Any]]]:
"""
Get traffic data for a date range and location
Uses POST request with extended timeout for Madrid traffic data processing
"""
# Prepare request payload
payload = {
"start_date": start_date, # Already in ISO format from calling code
"end_date": end_date, # Already in ISO format from calling code
"latitude": latitude or 40.4168, # Default Madrid coordinates
"longitude": longitude or -3.7038
}
logger.info(f"Traffic request payload: {payload}", tenant_id=tenant_id)
# Madrid traffic data can take 5-10 minutes to download and process
traffic_timeout = httpx.Timeout(
connect=30.0, # Connection timeout
read=600.0, # Read timeout: 10 minutes (was 30s)
write=30.0, # Write timeout
pool=30.0 # Pool timeout
)
# Use POST request with extended timeout
result = await self._make_request(
"POST",
"traffic/historical",
tenant_id=tenant_id,
data=payload,
timeout=traffic_timeout
)
if result:
logger.info(f"Successfully fetched {len(result)} traffic records")
return result
else:
logger.error("Failed to fetch traffic data")
return []
# ================================================================
# PRODUCTS
# ================================================================
async def get_products(self, tenant_id: str) -> Optional[List[Dict[str, Any]]]:
"""Get all products for a tenant"""
result = await self.get("products", tenant_id=tenant_id)
return result.get("products", []) if result else None
async def get_product(self, tenant_id: str, product_id: str) -> Optional[Dict[str, Any]]:
"""Get a specific product"""
return await self.get(f"products/{product_id}", tenant_id=tenant_id)
async def create_product(
self,
tenant_id: str,
name: str,
category: str,
price: float,
**kwargs
) -> Optional[Dict[str, Any]]:
"""Create a new product"""
data = {
"name": name,
"category": category,
"price": price,
**kwargs
}
return await self.post("products", data=data, tenant_id=tenant_id)
async def update_product(
self,
tenant_id: str,
product_id: str,
**updates
) -> Optional[Dict[str, Any]]:
"""Update a product"""
return await self.put(f"products/{product_id}", data=updates, tenant_id=tenant_id)
# ================================================================
# STORES & LOCATIONS
# ================================================================
async def get_stores(self, tenant_id: str) -> Optional[List[Dict[str, Any]]]:
"""Get all stores for a tenant"""
result = await self.get("stores", tenant_id=tenant_id)
return result.get("stores", []) if result else None
async def get_store(self, tenant_id: str, store_id: str) -> Optional[Dict[str, Any]]:
"""Get a specific store"""
return await self.get(f"stores/{store_id}", tenant_id=tenant_id)
# ================================================================
# DATA VALIDATION & HEALTH
# ================================================================
async def validate_data_quality(
self,
tenant_id: str,
start_date: str,
end_date: str
) -> Optional[Dict[str, Any]]:
"""Validate data quality for a date range"""
params = {
"start_date": start_date,
"end_date": end_date
}
return await self.get("validation", tenant_id=tenant_id, params=params)
async def get_data_statistics(
self,
tenant_id: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get data statistics for a tenant"""
params = {}
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
return await self.get("statistics", tenant_id=tenant_id, params=params)

175
shared/clients/forecast_client.py Normal file

@@ -0,0 +1,175 @@
# shared/clients/forecast_client.py
"""
Forecast Service Client
Handles all API calls to the forecasting service
"""
from typing import Dict, Any, Optional, List
from .base_service_client import BaseServiceClient
from shared.config.base import BaseServiceSettings
class ForecastServiceClient(BaseServiceClient):
"""Client for communicating with the forecasting service"""
def __init__(self, config: BaseServiceSettings, calling_service_name: str = "unknown"):
super().__init__(calling_service_name, config)
def get_service_base_path(self) -> str:
return "/api/v1"
# ================================================================
# FORECASTS
# ================================================================
async def create_forecast(
self,
tenant_id: str,
model_id: str,
start_date: str,
end_date: str,
product_ids: Optional[List[str]] = None,
include_confidence_intervals: bool = True,
**kwargs
) -> Optional[Dict[str, Any]]:
"""Create a new forecast"""
data = {
"model_id": model_id,
"start_date": start_date,
"end_date": end_date,
"include_confidence_intervals": include_confidence_intervals,
**kwargs
}
if product_ids:
data["product_ids"] = product_ids
return await self.post("forecasts", data=data, tenant_id=tenant_id)
async def get_forecast(self, tenant_id: str, forecast_id: str) -> Optional[Dict[str, Any]]:
"""Get forecast details"""
return await self.get(f"forecasts/{forecast_id}", tenant_id=tenant_id)
async def list_forecasts(
self,
tenant_id: str,
status: Optional[str] = None,
model_id: Optional[str] = None,
limit: int = 50
) -> Optional[List[Dict[str, Any]]]:
"""List forecasts for a tenant"""
params = {"limit": limit}
if status:
params["status"] = status
if model_id:
params["model_id"] = model_id
result = await self.get("forecasts", tenant_id=tenant_id, params=params)
return result.get("forecasts", []) if result else None
async def delete_forecast(self, tenant_id: str, forecast_id: str) -> Optional[Dict[str, Any]]:
"""Delete a forecast"""
return await self.delete(f"forecasts/{forecast_id}", tenant_id=tenant_id)
# ================================================================
# PREDICTIONS
# ================================================================
async def get_predictions(
self,
tenant_id: str,
forecast_id: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
product_id: Optional[str] = None
) -> Optional[List[Dict[str, Any]]]:
"""Get predictions from a forecast"""
params = {}
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
if product_id:
params["product_id"] = product_id
result = await self.get(f"forecasts/{forecast_id}/predictions", tenant_id=tenant_id, params=params)
return result.get("predictions", []) if result else None
async def create_realtime_prediction(
self,
tenant_id: str,
model_id: str,
target_date: str,
features: Dict[str, Any],
**kwargs
) -> Optional[Dict[str, Any]]:
"""Create a real-time prediction"""
data = {
"model_id": model_id,
"target_date": target_date,
"features": features,
**kwargs
}
return await self.post("predictions", data=data, tenant_id=tenant_id)
# ================================================================
# FORECAST VALIDATION & METRICS
# ================================================================
async def get_forecast_accuracy(
self,
tenant_id: str,
forecast_id: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get forecast accuracy metrics"""
params = {}
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
return await self.get(f"forecasts/{forecast_id}/accuracy", tenant_id=tenant_id, params=params)
async def compare_forecasts(
self,
tenant_id: str,
forecast_ids: List[str],
metric: str = "mape"
) -> Optional[Dict[str, Any]]:
"""Compare multiple forecasts"""
data = {
"forecast_ids": forecast_ids,
"metric": metric
}
return await self.post("forecasts/compare", data=data, tenant_id=tenant_id)
# ================================================================
# FORECAST SCENARIOS
# ================================================================
async def create_scenario_forecast(
self,
tenant_id: str,
model_id: str,
scenario_name: str,
scenario_data: Dict[str, Any],
start_date: str,
end_date: str,
**kwargs
) -> Optional[Dict[str, Any]]:
"""Create a scenario-based forecast"""
data = {
"model_id": model_id,
"scenario_name": scenario_name,
"scenario_data": scenario_data,
"start_date": start_date,
"end_date": end_date,
**kwargs
}
return await self.post("scenarios", data=data, tenant_id=tenant_id)
async def list_scenarios(self, tenant_id: str) -> Optional[List[Dict[str, Any]]]:
"""List forecast scenarios for a tenant"""
result = await self.get("scenarios", tenant_id=tenant_id)
return result.get("scenarios", []) if result else None

134
shared/clients/training_client.py Normal file

@@ -0,0 +1,134 @@
# shared/clients/training_client.py
"""
Training Service Client
Handles all API calls to the training service
"""
from typing import Dict, Any, Optional, List
from .base_service_client import BaseServiceClient
from shared.config.base import BaseServiceSettings
class TrainingServiceClient(BaseServiceClient):
"""Client for communicating with the training service"""
def __init__(self, config: BaseServiceSettings, calling_service_name: str = "unknown"):
super().__init__(calling_service_name, config)
def get_service_base_path(self) -> str:
return "/api/v1"
# ================================================================
# TRAINING JOBS
# ================================================================
async def create_training_job(
self,
tenant_id: str,
include_weather: bool = True,
include_traffic: bool = False,
min_data_points: int = 30,
**kwargs
) -> Optional[Dict[str, Any]]:
"""Create a new training job"""
data = {
"include_weather": include_weather,
"include_traffic": include_traffic,
"min_data_points": min_data_points,
**kwargs
}
return await self.post("jobs", data=data, tenant_id=tenant_id)
async def get_training_job(self, tenant_id: str, job_id: str) -> Optional[Dict[str, Any]]:
"""Get training job details"""
return await self.get(f"jobs/{job_id}", tenant_id=tenant_id)
async def list_training_jobs(
self,
tenant_id: str,
status: Optional[str] = None,
limit: int = 50
) -> Optional[List[Dict[str, Any]]]:
"""List training jobs for a tenant"""
params = {"limit": limit}
if status:
params["status"] = status
result = await self.get("jobs", tenant_id=tenant_id, params=params)
return result.get("jobs", []) if result else None
async def cancel_training_job(self, tenant_id: str, job_id: str) -> Optional[Dict[str, Any]]:
"""Cancel a training job"""
return await self.delete(f"jobs/{job_id}", tenant_id=tenant_id)
# ================================================================
# MODELS
# ================================================================
async def get_model(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
"""Get model details"""
return await self.get(f"models/{model_id}", tenant_id=tenant_id)
async def list_models(
self,
tenant_id: str,
status: Optional[str] = None,
model_type: Optional[str] = None,
limit: int = 50
) -> Optional[List[Dict[str, Any]]]:
"""List models for a tenant"""
params = {"limit": limit}
if status:
params["status"] = status
if model_type:
params["model_type"] = model_type
result = await self.get("models", tenant_id=tenant_id, params=params)
return result.get("models", []) if result else None
async def get_latest_model(
self,
tenant_id: str,
model_type: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get the latest trained model for a tenant"""
params = {"latest": "true"}
if model_type:
params["model_type"] = model_type
result = await self.get("models", tenant_id=tenant_id, params=params)
models = result.get("models", []) if result else []
return models[0] if models else None
async def deploy_model(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
"""Deploy a trained model"""
return await self.post(f"models/{model_id}/deploy", data={}, tenant_id=tenant_id)
async def delete_model(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
"""Delete a model"""
return await self.delete(f"models/{model_id}", tenant_id=tenant_id)
# ================================================================
# MODEL METRICS & PERFORMANCE
# ================================================================
async def get_model_metrics(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
"""Get model performance metrics"""
return await self.get(f"models/{model_id}/metrics", tenant_id=tenant_id)
async def get_model_predictions(
self,
tenant_id: str,
model_id: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None
) -> Optional[List[Dict[str, Any]]]:
"""Get model predictions for evaluation"""
params = {}
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
result = await self.get(f"models/{model_id}/predictions", tenant_id=tenant_id, params=params)
return result.get("predictions", []) if result else None