Start fixing forecast service API 3
@@ -47,7 +47,7 @@ async def create_single_forecast(
     )
 
     # Generate forecast
-    forecast = await forecasting_service.generate_forecast(request, db)
+    forecast = await forecasting_service.generate_forecast(tenant_id, request, db)
 
     # Convert to response model
     return ForecastResponse(
@@ -24,14 +24,7 @@ class ForecastRequest(BaseModel):
     """Request schema for generating forecasts"""
     tenant_id: str = Field(..., description="Tenant ID")
     product_name: str = Field(..., description="Product name")
-    location: str = Field(..., description="Location identifier")
     forecast_date: date = Field(..., description="Date for which to generate forecast")
-    business_type: BusinessType = Field(BusinessType.INDIVIDUAL, description="Business model type")
-
-    # Optional context
-    include_weather: bool = Field(True, description="Include weather data in forecast")
-    include_traffic: bool = Field(True, description="Include traffic data in forecast")
-    confidence_level: float = Field(0.8, ge=0.5, le=0.95, description="Confidence level for intervals")
 
     @validator('forecast_date')
     def validate_forecast_date(cls, v):
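
After this change, a forecast request body only needs the remaining fields, e.g. (illustrative values):

    {
        "tenant_id": "t-123",
        "product_name": "croissant",
        "forecast_date": "2024-05-17"
    }

Location, business type, and the weather/traffic/confidence options are no longer part of the request schema.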
services/forecasting/app/services/data_client.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+# services/forecasting/app/services/data_client.py
+"""
+Forecasting Service Data Client
+Migrated to use shared service clients - much simpler now!
+"""
+
+import structlog
+from typing import Dict, Any, List, Optional
+from datetime import datetime
+
+# Import the shared clients
+from shared.clients import get_data_client, get_service_clients
+from app.core.config import settings
+
+logger = structlog.get_logger()
+
+
+class DataClient:
+    """
+    Data client for the forecasting service
+    Now uses the shared data service client under the hood
+    """
+
+    def __init__(self):
+        # Get the shared data client configured for this service
+        self.data_client = get_data_client(settings, "forecasting")
+
+        # Or alternatively, get all clients at once:
+        # self.clients = get_service_clients(settings, "forecasting")
+        # Then use: self.clients.data.get_sales_data(...)
+
+    async def fetch_weather_forecast(
+        self,
+        tenant_id: str,
+        days: int,
+        latitude: Optional[float] = None,
+        longitude: Optional[float] = None
+    ) -> List[Dict[str, Any]]:
+        """
+        Fetch weather data for forecasts
+        All the error handling and retry logic is now in the base client!
+        """
+        try:
+            weather_data = await self.data_client.get_weather_forecast(
+                tenant_id=tenant_id,
+                days=days,
+                latitude=latitude,
+                longitude=longitude
+            )
+
+            if weather_data:
+                logger.info(f"Fetched {len(weather_data)} weather records",
+                            tenant_id=tenant_id)
+                return weather_data
+            else:
+                logger.warning("No weather data returned", tenant_id=tenant_id)
+                return []
+
+        except Exception as e:
+            logger.error(f"Error fetching weather data: {e}", tenant_id=tenant_id)
+            return []
+
+
+# Global instance - same as before, but much simpler implementation
+data_client = DataClient()
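
A minimal usage sketch for the new client (the wrapper function and tenant ID below are illustrative, not part of the commit):

    # Hypothetical call site inside the forecasting service
    from app.services.data_client import data_client

    async def build_weather_features(tenant_id: str) -> dict:
        # fetch_weather_forecast returns [] on any upstream failure,
        # so callers can degrade gracefully
        records = await data_client.fetch_weather_forecast(tenant_id, days=1)
        return {"weather_record_count": len(records)}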
@@ -21,6 +21,8 @@ from app.services.prediction_service import PredictionService
 from app.services.messaging import publish_forecast_completed, publish_alert_created
 from app.core.config import settings
 from shared.monitoring.metrics import MetricsCollector
+from app.services.model_client import ModelClient
+from app.services.data_client import DataClient
 
 logger = structlog.get_logger()
 metrics = MetricsCollector("forecasting-service")
@@ -33,6 +35,8 @@ class ForecastingService:
 
     def __init__(self):
         self.prediction_service = PredictionService()
+        self.model_client = ModelClient()
+        self.data_client = DataClient()
 
     async def generate_forecast(self, request: ForecastRequest, db: AsyncSession) -> Forecast:
         """Generate a single forecast for a product"""
@@ -48,7 +52,6 @@ class ForecastingService:
         model_info = await self._get_latest_model(
             request.tenant_id,
             request.product_name,
-            request.location
         )
 
         if not model_info:
@@ -66,10 +69,9 @@ class ForecastingService:
 
         # Create forecast record
         forecast = Forecast(
-            tenant_id=uuid.UUID(request.tenant_id),
-            product_name=request.product_name,
-            location=request.location,
-            forecast_date=datetime.combine(request.forecast_date, datetime.min.time()),
+            tenant_id=uuid.UUID(tenant_id),
+            product_name=product_name,
+            forecast_date=datetime.combine(forecast_date, datetime.min.time()),
 
             # Prediction results
             predicted_demand=prediction_result["demand"],
@@ -243,27 +245,12 @@ class ForecastingService:
             logger.error("Error retrieving forecasts", error=str(e))
             raise
 
-    async def _get_latest_model(self, tenant_id: str, product_name: str, location: str) -> Optional[Dict[str, Any]]:
+    async def _get_latest_model(self, tenant_id: str, product_name: str) -> Optional[Dict[str, Any]]:
         """Get the latest trained model for a tenant/product combination"""
 
         try:
-            # Call training service to get model information
-            async with httpx.AsyncClient() as client:
-                response = await client.get(
-                    f"{settings.TRAINING_SERVICE_URL}/tenants/{tenant_id}/models/{product_name}/active",
-                    params={},
-                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
-                )
-
-                if response.status_code == 200:
-                    return response.json()
-                elif response.status_code == 404:
-                    logger.warning("No model found",
-                                   tenant_id=tenant_id,
-                                   product=product_name)
-                    return None
-                else:
-                    response.raise_for_status()
+            model_data = await self.model_client.get_best_model_for_forecasting(tenant_id, product_name)
+            return model_data
 
         except Exception as e:
             logger.error("Error getting latest model", error=str(e))
@@ -275,22 +262,15 @@ class ForecastingService:
         features = {
             "date": request.forecast_date.isoformat(),
             "day_of_week": request.forecast_date.weekday(),
-            "is_weekend": request.forecast_date.weekday() >= 5,
-            "business_type": request.business_type.value
+            "is_weekend": request.forecast_date.weekday() >= 5
         }
 
         # Add Spanish holidays
         features["is_holiday"] = await self._is_spanish_holiday(request.forecast_date)
 
-        # Add weather data if requested
-        if request.include_weather:
-            weather_data = await self._get_weather_forecast(request.forecast_date)
-            features.update(weather_data)
-
-        # Add traffic data if requested
-        if request.include_traffic:
-            traffic_data = await self._get_traffic_forecast(request.forecast_date, request.location)
-            features.update(traffic_data)
+        weather_data = await self._get_weather_forecast(request.tenant_id, 1)
+        features.update(weather_data)
 
         return features
 
@@ -315,61 +295,16 @@ class ForecastingService:
             logger.warning("Error checking holiday status", error=str(e))
             return False
 
-    async def _get_weather_forecast(self, forecast_date: date) -> Dict[str, Any]:
+    async def _get_weather_forecast(self, tenant_id: str, days: int) -> Dict[str, Any]:
         """Get weather forecast for the date"""
 
         try:
-            # Call data service for weather forecast
-            async with httpx.AsyncClient() as client:
-                response = await client.get(
-                    f"{settings.DATA_SERVICE_URL}/api/v1/weather/forecast",
-                    params={"date": forecast_date.isoformat()},
-                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
-                )
-
-                if response.status_code == 200:
-                    weather = response.json()
-                    return {
-                        "temperature": weather.get("temperature"),
-                        "precipitation": weather.get("precipitation"),
-                        "humidity": weather.get("humidity"),
-                        "weather_description": weather.get("description")
-                    }
-                else:
-                    return {}
+            weather_data = await self.data_client.fetch_weather_forecast(tenant_id, days)
+            return weather_data
 
         except Exception as e:
             logger.warning("Error getting weather forecast", error=str(e))
             return {}
 
-    async def _get_traffic_forecast(self, forecast_date: date, location: str) -> Dict[str, Any]:
-        """Get traffic forecast for the date and location"""
-
-        try:
-            # Call data service for traffic forecast
-            async with httpx.AsyncClient() as client:
-                response = await client.get(
-                    f"{settings.DATA_SERVICE_URL}/api/v1/traffic/forecast",
-                    params={
-                        "date": forecast_date.isoformat(),
-                        "location": location
-                    },
-                    headers={"X-Service-Auth": settings.SERVICE_AUTH_TOKEN}
-                )
-
-                if response.status_code == 200:
-                    traffic = response.json()
-                    return {
-                        "traffic_volume": traffic.get("volume"),
-                        "pedestrian_count": traffic.get("pedestrian_count")
-                    }
-                else:
-                    return {}
-
-        except Exception as e:
-            logger.warning("Error getting traffic forecast", error=str(e))
-            return {}
-
     async def _check_and_create_alerts(self, forecast: Forecast, db: AsyncSession):
         """Check forecast and create alerts if needed"""
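
For orientation, the refactored feature preparation now yields a flat dict along these lines (a sketch with assumed values, not actual service output):

    features = {
        "date": "2024-05-17",   # request.forecast_date.isoformat()
        "day_of_week": 4,       # Monday == 0
        "is_weekend": False,
        "is_holiday": False,    # from _is_spanish_holiday()
        # ...plus whatever keys fetch_weather_forecast() returns
    }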
services/forecasting/app/services/model_client.py (new file, 183 lines)
@@ -0,0 +1,183 @@
+# services/forecasting/app/services/model_client.py
+"""
+Forecast Service Model Client
+Demonstrates calling training service to get models
+"""
+
+import structlog
+from typing import Dict, Any, List, Optional
+
+# Import shared clients - no more code duplication!
+from shared.clients import get_service_clients, get_training_client, get_data_client
+from app.core.config import settings
+
+logger = structlog.get_logger()
+
+
+class ModelClient:
+    """
+    Client for managing models in forecasting service
+    Shows how to call multiple services cleanly
+    """
+
+    def __init__(self):
+        # Option 1: Get all clients at once
+        self.clients = get_service_clients(settings, "forecasting")
+
+        # Option 2: Get specific clients
+        # self.training_client = get_training_client(settings, "forecasting")
+        # self.data_client = get_data_client(settings, "forecasting")
+
+    async def get_available_models(
+        self,
+        tenant_id: str,
+        model_type: Optional[str] = None
+    ) -> List[Dict[str, Any]]:
+        """
+        Get available trained models from training service
+        """
+        try:
+            models = await self.clients.training.list_models(
+                tenant_id=tenant_id,
+                status="deployed",  # Only get deployed models
+                model_type=model_type
+            )
+
+            if models:
+                logger.info(f"Found {len(models)} available models",
+                            tenant_id=tenant_id, model_type=model_type)
+                return models
+            else:
+                logger.warning("No available models found", tenant_id=tenant_id)
+                return []
+
+        except Exception as e:
+            logger.error(f"Error fetching available models: {e}", tenant_id=tenant_id)
+            return []
+
+    async def get_best_model_for_forecasting(
+        self,
+        tenant_id: str,
+        product_id: Optional[str] = None
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Get the best model for forecasting based on performance metrics
+        """
+        try:
+            # Get latest model
+            latest_model = await self.clients.training.get_latest_model(
+                tenant_id=tenant_id,
+                model_type="forecasting"
+            )
+
+            if not latest_model:
+                logger.warning("No trained models found", tenant_id=tenant_id)
+                return None
+
+            # Get model metrics to validate quality
+            metrics = await self.clients.training.get_model_metrics(
+                tenant_id=tenant_id,
+                model_id=latest_model["id"]
+            )
+
+            if metrics and metrics.get("accuracy", 0) > 0.7:  # 70% accuracy threshold
+                logger.info(f"Selected model {latest_model['id']} with accuracy {metrics.get('accuracy')}",
+                            tenant_id=tenant_id)
+                return latest_model
+            else:
+                logger.warning(f"Model accuracy too low: {metrics.get('accuracy', 'unknown')}",
+                               tenant_id=tenant_id)
+                return None
+
+        except Exception as e:
+            logger.error(f"Error selecting best model: {e}", tenant_id=tenant_id)
+            return None
+
+    async def validate_model_data_compatibility(
+        self,
+        tenant_id: str,
+        model_id: str,
+        forecast_start_date: str,
+        forecast_end_date: str
+    ) -> Dict[str, Any]:
+        """
+        Validate that we have sufficient data for the model to make forecasts
+        Demonstrates calling both training and data services
+        """
+        try:
+            # Get model details from training service
+            model = await self.clients.training.get_model(
+                tenant_id=tenant_id,
+                model_id=model_id
+            )
+
+            if not model:
+                return {"is_valid": False, "error": "Model not found"}
+
+            # Get data statistics from data service
+            data_stats = await self.clients.data.get_data_statistics(
+                tenant_id=tenant_id,
+                start_date=forecast_start_date,
+                end_date=forecast_end_date
+            )
+
+            if not data_stats:
+                return {"is_valid": False, "error": "Could not retrieve data statistics"}
+
+            # Check if we have minimum required data points
+            min_required = model.get("metadata", {}).get("min_data_points", 30)
+            available_points = data_stats.get("total_records", 0)
+
+            is_valid = available_points >= min_required
+
+            result = {
+                "is_valid": is_valid,
+                "model_id": model_id,
+                "required_points": min_required,
+                "available_points": available_points,
+                "data_coverage": data_stats.get("coverage_percentage", 0)
+            }
+
+            if not is_valid:
+                result["error"] = f"Insufficient data: need {min_required}, have {available_points}"
+
+            logger.info("Model data compatibility check completed",
+                        tenant_id=tenant_id, model_id=model_id, is_valid=is_valid)
+
+            return result
+
+        except Exception as e:
+            logger.error(f"Error validating model compatibility: {e}",
+                         tenant_id=tenant_id, model_id=model_id)
+            return {"is_valid": False, "error": str(e)}
+
+    async def trigger_model_retraining(
+        self,
+        tenant_id: str,
+        include_weather: bool = True,
+        include_traffic: bool = False
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Trigger a new training job if current model is outdated
+        """
+        try:
+            # Create training job through training service
+            job = await self.clients.training.create_training_job(
+                tenant_id=tenant_id,
+                include_weather=include_weather,
+                include_traffic=include_traffic,
+                min_data_points=50  # Higher threshold for forecasting
+            )
+
+            if job:
+                logger.info(f"Training job created: {job['job_id']}", tenant_id=tenant_id)
+                return job
+            else:
+                logger.error("Failed to create training job", tenant_id=tenant_id)
+                return None
+
+        except Exception as e:
+            logger.error(f"Error triggering model retraining: {e}", tenant_id=tenant_id)
+            return None
+
+
+# Global instance
+model_client = ModelClient()
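
A hedged usage sketch for the model client (the helper below is illustrative; only get_best_model_for_forecasting and trigger_model_retraining come from this commit):

    from app.services.model_client import model_client

    async def pick_or_retrain(tenant_id: str):
        model = await model_client.get_best_model_for_forecasting(tenant_id)
        if model is None:
            # Either no deployed model exists or its accuracy was <= 0.7
            return await model_client.trigger_model_retraining(tenant_id)
        return model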
(deleted file, 81 lines)
@@ -1,81 +0,0 @@
-import time
-import structlog
-from typing import Dict, Any
-from shared.auth.jwt_handler import JWTHandler
-from app.core.config import settings
-
-logger = structlog.get_logger()
-
-class ServiceAuthenticator:
-    """Handles service-to-service authentication via gateway"""
-
-    def __init__(self):
-        self.jwt_handler = JWTHandler(settings.JWT_SECRET_KEY)
-        self._cached_token = None
-        self._token_expires_at = 0
-
-    async def get_service_token(self) -> str:
-        """
-        Get a valid service token, using cache when possible
-        Creates JWT tokens that the gateway will accept
-        """
-        current_time = int(time.time())
-
-        # Return cached token if still valid (with 5 min buffer)
-        if (self._cached_token and
-                self._token_expires_at > current_time + 300):
-            return self._cached_token
-
-        # Create new service token
-        token_expires_at = current_time + 3600  # 1 hour
-
-        service_payload = {
-            # ✅ Required fields for gateway middleware
-            "sub": "training-service",
-            "user_id": "training-service",
-            "email": "training-service@internal",
-            "type": "access",  # ✅ Must be "access" for gateway
-
-            # ✅ Expiration and timing
-            "exp": token_expires_at,
-            "iat": current_time,
-            "iss": "training-service",
-
-            # ✅ Service identification
-            "service": "training",
-            "full_name": "Training Service",
-            "is_verified": True,
-            "is_active": True,
-
-            # ✅ Optional tenant context (can be overridden per request)
-            "tenant_id": None
-        }
-
-        try:
-            token = self.jwt_handler.create_access_token_from_payload(service_payload)
-
-            # Cache the token
-            self._cached_token = token
-            self._token_expires_at = token_expires_at
-
-            logger.debug("Created new service token", expires_at=token_expires_at)
-            return token
-
-        except Exception as e:
-            logger.error(f"Failed to create service token: {e}")
-            raise ValueError(f"Service token creation failed: {e}")
-
-    def get_request_headers(self, tenant_id: str = None) -> Dict[str, str]:
-        """Get standard headers for service requests"""
-        headers = {
-            "X-Service": "training-service",
-            "User-Agent": "training-service/1.0.0"
-        }
-
-        if tenant_id:
-            headers["X-Tenant-ID"] = str(tenant_id)
-
-        return headers
-
-# Global authenticator instance
-service_auth = ServiceAuthenticator()
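
The per-service token minting deleted above is reimplemented generically in shared/clients/base_service_client.py (added below). For orientation only, the payload is a standard JWT; a rough equivalent with PyJWT (assuming JWTHandler signs with HS256, which this commit does not state) would be:

    import time
    import jwt  # PyJWT, an assumed stand-in for shared.auth.jwt_handler

    def mint_service_token(service: str, secret: str) -> str:
        now = int(time.time())
        payload = {
            "sub": f"{service}-service",
            "type": "access",   # the gateway requires type == "access"
            "iat": now,
            "exp": now + 3600,  # 1 hour
            "service": service,
        }
        return jwt.encode(payload, secret, algorithm="HS256")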
services/training/app/services/data_client.py
@@ -1,140 +1,87 @@
-import httpx
+# services/training/app/services/data_client.py
+"""
+Training Service Data Client
+Migrated to use shared service clients - much simpler now!
+"""
+
 import structlog
-from typing import List, Dict, Any, Optional
+from typing import Dict, Any, List, Optional
+from datetime import datetime
+
+# Import the shared clients
+from shared.clients import get_data_client, get_service_clients
 from app.core.config import settings
-from app.core.service_auth import service_auth
 
 logger = structlog.get_logger()
 
-class DataServiceClient:
-    """Client for fetching data through the API Gateway"""
+class DataClient:
+    """
+    Data client for training service
+    Now uses the shared data service client under the hood
+    """
 
     def __init__(self):
-        self.base_url = settings.API_GATEWAY_URL
-        self.timeout = 2000.0
+        # Get the shared data client configured for this service
+        self.data_client = get_data_client(settings, "training")
 
-    async def fetch_sales_data(self, tenant_id: str) -> List[Dict[str, Any]]:
+        # Or alternatively, get all clients at once:
+        # self.clients = get_service_clients(settings, "training")
+        # Then use: self.clients.data.get_sales_data(...)
+
+    async def fetch_sales_data(
+        self,
+        tenant_id: str,
+        start_date: Optional[str] = None,
+        end_date: Optional[str] = None,
+        product_id: Optional[str] = None,
+        fetch_all: bool = True
+    ) -> List[Dict[str, Any]]:
         """
-        Fetch all sales data for training (no pagination limits)
-        FIXED: Retrieves ALL records instead of being limited to 1000
+        Fetch sales data for training
+
+        Args:
+            tenant_id: Tenant identifier
+            start_date: Start date in ISO format
+            end_date: End date in ISO format
+            product_id: Optional product filter
+            fetch_all: If True, fetches ALL records using pagination (original behavior)
+                       If False, fetches limited records (standard API response)
        """
        try:
-            # Get service token
-            token = await service_auth.get_service_token()
-
-            # Prepare headers
-            headers = service_auth.get_request_headers(tenant_id)
-            headers["Authorization"] = f"Bearer {token}"
-
-            all_records = []
-            page = 0
-            page_size = 5000  # Use maximum allowed by API
-
-            while True:
-                # Prepare query parameters for pagination
-                params = {
-                    "limit": page_size,
-                    "offset": page * page_size
-                }
-
-                logger.info(f"Fetching sales data page {page + 1} (offset: {page * page_size})",
-                            tenant_id=tenant_id)
-
-                # Make GET request via gateway with pagination
-                async with httpx.AsyncClient(timeout=self.timeout) as client:
-                    response = await client.get(
-                        f"{self.base_url}/api/v1/tenants/{tenant_id}/sales",
-                        headers=headers,
-                        params=params
-                    )
-
-                if response.status_code == 200:
-                    page_data = response.json()
-
-                    # Handle different response formats
-                    if isinstance(page_data, list):
-                        # Direct list response (no pagination metadata)
-                        records = page_data
-                        logger.info(f"Retrieved {len(records)} records from page {page + 1} (direct list)")
-
-                        # For direct list responses, we need to check if we got the max possible
-                        # If we got less than page_size, we're done
-                        if len(records) == 0:
-                            logger.info("No records in response, pagination complete")
-                            break
-                        elif len(records) < page_size:
-                            # Got fewer than requested, this is the last page
-                            all_records.extend(records)
-                            logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}")
-                            break
-                        else:
-                            # Got full page, there might be more
-                            all_records.extend(records)
-                            logger.info(f"Full page retrieved: {len(records)} records, continuing to next page")
-
-                    elif isinstance(page_data, dict):
-                        # Paginated response format
-                        records = page_data.get('records', page_data.get('data', []))
-                        total_available = page_data.get('total', 0)
-
-                        logger.info(f"Retrieved {len(records)} records from page {page + 1} (paginated response)")
-
-                        if not records:
-                            logger.info("No more records found in paginated response")
-                            break
-
-                        all_records.extend(records)
-
-                        # Check if we've got all available records
-                        if len(all_records) >= total_available:
-                            logger.info(f"Retrieved all available records: {len(all_records)}/{total_available}")
-                            break
-
-                    else:
-                        logger.warning(f"Unexpected response format: {type(page_data)}")
-                        records = []
-                        break
-
-                    page += 1
-
-                    # Safety break to prevent infinite loops
-                    if page > 100:  # Max 500,000 records (100 * 5000)
-                        logger.warning("Reached maximum page limit, stopping pagination")
-                        break
-
-                elif response.status_code == 401:
-                    logger.error("Authentication failed with gateway",
-                                 tenant_id=tenant_id,
-                                 response_text=response.text)
-                    return []
-
-                elif response.status_code == 404:
-                    logger.warning("Sales data endpoint not found",
-                                   tenant_id=tenant_id,
-                                   url=response.url)
-                    return []
-
-                else:
-                    logger.error(f"Gateway request failed: HTTP {response.status_code}",
-                                 tenant_id=tenant_id,
-                                 response_text=response.text)
-                    return []
-
-            logger.info(f"Successfully fetched {len(all_records)} total sales records via gateway",
-                        tenant_id=tenant_id)
-            return all_records
-
-        except httpx.TimeoutException:
-            logger.error("Timeout when fetching sales data via gateway",
-                         tenant_id=tenant_id)
-            return []
+            if fetch_all:
+                # Use paginated method to get ALL records (original behavior)
+                sales_data = await self.data_client.get_all_sales_data(
+                    tenant_id=tenant_id,
+                    start_date=start_date,
+                    end_date=end_date,
+                    product_id=product_id,
+                    aggregation="daily",
+                    page_size=5000,  # Match original page size
+                    max_pages=100  # Safety limit (500k records max)
+                )
+            else:
+                # Use standard method for limited results
+                sales_data = await self.data_client.get_sales_data(
+                    tenant_id=tenant_id,
+                    start_date=start_date,
+                    end_date=end_date,
+                    product_id=product_id,
+                    aggregation="daily"
+                )
+                sales_data = sales_data or []
+
+            if sales_data:
+                logger.info(f"Fetched {len(sales_data)} sales records",
+                            tenant_id=tenant_id, product_id=product_id, fetch_all=fetch_all)
+                return sales_data
+            else:
+                logger.warning("No sales data returned", tenant_id=tenant_id)
+                return []
 
         except Exception as e:
-            logger.error(f"Error fetching sales data via gateway: {e}",
-                         tenant_id=tenant_id)
+            logger.error(f"Error fetching sales data: {e}", tenant_id=tenant_id)
             return []
 
     async def fetch_weather_data(
         self,
         tenant_id: str,
@@ -144,72 +91,28 @@ class DataServiceClient:
         longitude: Optional[float] = None
     ) -> List[Dict[str, Any]]:
         """
-        Fetch historical weather data for training via API Gateway using POST
+        Fetch weather data for training
+        All the error handling and retry logic is now in the base client!
         """
         try:
-            # Get service token
-            token = await service_auth.get_service_token()
-
-            # Prepare headers
-            headers = service_auth.get_request_headers(tenant_id)
-            headers["Authorization"] = f"Bearer {token}"
-            headers["Content-Type"] = "application/json"
-
-            # Prepare request payload with proper date handling
-            payload = {
-                "start_date": start_date,  # Already in ISO format from calling code
-                "end_date": end_date,  # Already in ISO format from calling code
-                "latitude": latitude or 40.4168,  # Default Madrid coordinates
-                "longitude": longitude or -3.7038
-            }
-
-            logger.info(f"Weather request payload: {payload}", tenant_id=tenant_id)
-
-            # Make POST request via gateway
-            async with httpx.AsyncClient(timeout=self.timeout) as client:
-                response = await client.post(
-                    f"{self.base_url}/api/v1/tenants/{tenant_id}/weather/historical",
-                    headers=headers,
-                    json=payload
-                )
-
-            logger.info(f"Weather data request: {response.status_code}",
-                        tenant_id=tenant_id,
-                        url=response.url)
-
-            if response.status_code == 200:
-                data = response.json()
-                logger.info(f"Successfully fetched {len(data)} weather records")
-                return data
-            elif response.status_code == 400:
-                error_details = response.text
-                logger.error(f"Weather API validation error (400): {error_details}")
-
-                # Try to parse the error and provide helpful info
-                try:
-                    error_json = response.json()
-                    if 'detail' in error_json:
-                        detail = error_json['detail']
-                        if 'End date must be after start date' in str(detail):
-                            logger.error(f"Date range issue: start={start_date}, end={end_date}")
-                        elif 'Date range cannot exceed 90 days' in str(detail):
-                            logger.error(f"Date range too large: {start_date} to {end_date}")
-                except:
-                    pass
-
-                return []
-            elif response.status_code == 401:
-                logger.error("Authentication failed for weather API")
-                return []
+            weather_data = await self.data_client.get_weather_historical(
+                tenant_id=tenant_id,
+                start_date=start_date,
+                end_date=end_date,
+                latitude=latitude,
+                longitude=longitude
+            )
+
+            if weather_data:
+                logger.info(f"Fetched {len(weather_data)} weather records",
+                            tenant_id=tenant_id)
+                return weather_data
             else:
-                logger.error(f"Failed to fetch weather data: {response.status_code} - {response.text}")
+                logger.warning("No weather data returned", tenant_id=tenant_id)
                 return []
 
-        except httpx.TimeoutException:
-            logger.error("Timeout when fetching weather data")
-            return []
         except Exception as e:
-            logger.error(f"Error fetching weather data: {str(e)}")
+            logger.error(f"Error fetching weather data: {e}", tenant_id=tenant_id)
             return []
 
     async def fetch_traffic_data(
@@ -221,65 +124,57 @@ class DataServiceClient:
         longitude: Optional[float] = None
     ) -> List[Dict[str, Any]]:
         """
-        Fetch historical traffic data for training via API Gateway using POST
+        Fetch traffic data for training
         """
         try:
-            # Get service token
-            token = await service_auth.get_service_token()
-
-            # Prepare headers
-            headers = service_auth.get_request_headers(tenant_id)
-            headers["Authorization"] = f"Bearer {token}"
-            headers["Content-Type"] = "application/json"
-
-            # Prepare request payload
-            payload = {
-                "start_date": start_date,  # Already in ISO format from calling code
-                "end_date": end_date,  # Already in ISO format from calling code
-                "latitude": latitude or 40.4168,  # Default Madrid coordinates
-                "longitude": longitude or -3.7038
-            }
-
-            logger.info(f"Traffic request payload: {payload}", tenant_id=tenant_id)
-
-            # Madrid traffic data can take 5-10 minutes to download and process
-            timeout_config = httpx.Timeout(
-                connect=30.0,  # Connection timeout
-                read=600.0,    # Read timeout: 10 minutes (was 30s)
-                write=30.0,    # Write timeout
-                pool=30.0      # Pool timeout
-            )
-
-            # Make POST request via gateway
-            async with httpx.AsyncClient(timeout=timeout_config) as client:
-                response = await client.post(
-                    f"{self.base_url}/api/v1/tenants/{tenant_id}/traffic/historical",
-                    headers=headers,
-                    json=payload
-                )
-
-            logger.info(f"Traffic data request: {response.status_code}",
-                        tenant_id=tenant_id,
-                        url=response.url)
-
-            if response.status_code == 200:
-                data = response.json()
-                logger.info(f"Successfully fetched {len(data)} traffic records")
-                return data
-            elif response.status_code == 400:
-                error_details = response.text
-                logger.error(f"Traffic API validation error (400): {error_details}")
-                return []
-            elif response.status_code == 401:
-                logger.error("Authentication failed for traffic API")
-                return []
+            traffic_data = await self.data_client.get_traffic_data(
+                tenant_id=tenant_id,
+                start_date=start_date,
+                end_date=end_date,
+                latitude=latitude,
+                longitude=longitude
+            )
+
+            if traffic_data:
+                logger.info(f"Fetched {len(traffic_data)} traffic records",
+                            tenant_id=tenant_id)
+                return traffic_data
             else:
-                logger.error(f"Failed to fetch traffic data: {response.status_code} - {response.text}")
+                logger.warning("No traffic data returned", tenant_id=tenant_id)
                 return []
 
-        except httpx.TimeoutException:
-            logger.error("Timeout when fetching traffic data")
-            return []
         except Exception as e:
-            logger.error(f"Error fetching traffic data: {str(e)}")
+            logger.error(f"Error fetching traffic data: {e}", tenant_id=tenant_id)
             return []
 
+    async def validate_data_quality(
+        self,
+        tenant_id: str,
+        start_date: str,
+        end_date: str
+    ) -> Dict[str, Any]:
+        """
+        Validate data quality before training
+        """
+        try:
+            validation_result = await self.data_client.validate_data_quality(
+                tenant_id=tenant_id,
+                start_date=start_date,
+                end_date=end_date
+            )
+
+            if validation_result:
+                logger.info("Data validation completed",
+                            tenant_id=tenant_id,
+                            is_valid=validation_result.get("is_valid", False))
+                return validation_result
+            else:
+                logger.warning("Data validation failed", tenant_id=tenant_id)
+                return {"is_valid": False, "errors": ["Validation service unavailable"]}
+
+        except Exception as e:
+            logger.error(f"Error validating data: {e}", tenant_id=tenant_id)
+            return {"is_valid": False, "errors": [str(e)]}
+
+# Global instance - same as before, but much simpler implementation
+data_client = DataClient()
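
Usage sketch for the migrated training client (tenant ID and wrapper are illustrative):

    from app.services.data_client import data_client

    async def load_training_rows(tenant_id: str):
        # fetch_all=True pages through every record,
        # capped at 100 pages x 5000 rows by the shared client
        return await data_client.fetch_sales_data(tenant_id, fetch_all=True)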
@@ -13,7 +13,7 @@ from concurrent.futures import ThreadPoolExecutor
 from datetime import timezone
 import pandas as pd
 
-from app.services.data_client import DataServiceClient
+from app.services.data_client import DataClient
 from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
 
 logger = logging.getLogger(__name__)
@@ -37,7 +37,7 @@ class TrainingDataOrchestrator:
                  madrid_client=None,
                  weather_client=None,
                  date_alignment_service: DateAlignmentService = None):
-        self.data_client = DataServiceClient()
+        self.data_client = DataClient()
         self.date_alignment_service = date_alignment_service or DateAlignmentService()
         self.max_concurrent_requests = 3
 
shared/clients/__init__.py (new file, 106 lines)
@@ -0,0 +1,106 @@
+# shared/clients/__init__.py
+"""
+Service Client Factory and Convenient Imports
+Provides easy access to all service clients
+"""
+
+from .base_service_client import BaseServiceClient, ServiceAuthenticator
+from .training_client import TrainingServiceClient
+from .data_client import DataServiceClient
+from .forecast_client import ForecastServiceClient
+
+# Import config
+from shared.config.base import BaseServiceSettings
+
+# Cache clients to avoid recreating them
+_client_cache = {}
+
+def get_training_client(config: BaseServiceSettings = None, service_name: str = "unknown") -> TrainingServiceClient:
+    """Get or create a training service client"""
+    if config is None:
+        from app.core.config import settings as config
+
+    cache_key = f"training_{service_name}"
+    if cache_key not in _client_cache:
+        _client_cache[cache_key] = TrainingServiceClient(config, service_name)
+    return _client_cache[cache_key]
+
+def get_data_client(config: BaseServiceSettings = None, service_name: str = "unknown") -> DataServiceClient:
+    """Get or create a data service client"""
+    if config is None:
+        from app.core.config import settings as config
+
+    cache_key = f"data_{service_name}"
+    if cache_key not in _client_cache:
+        _client_cache[cache_key] = DataServiceClient(config, service_name)
+    return _client_cache[cache_key]
+
+def get_forecast_client(config: BaseServiceSettings = None, service_name: str = "unknown") -> ForecastServiceClient:
+    """Get or create a forecast service client"""
+    if config is None:
+        from app.core.config import settings as config
+
+    cache_key = f"forecast_{service_name}"
+    if cache_key not in _client_cache:
+        _client_cache[cache_key] = ForecastServiceClient(config, service_name)
+    return _client_cache[cache_key]
+
+class ServiceClients:
+    """Convenient wrapper for all service clients"""
+
+    def __init__(self, config: BaseServiceSettings = None, service_name: str = "unknown"):
+        self.service_name = service_name
+        self.config = config or self._get_default_config()
+
+        # Initialize clients lazily
+        self._training_client = None
+        self._data_client = None
+        self._forecast_client = None
+
+    def _get_default_config(self):
+        """Get default config from app settings"""
+        try:
+            from app.core.config import settings
+            return settings
+        except ImportError:
+            raise ImportError("Could not import app config. Please provide config explicitly.")
+
+    @property
+    def training(self) -> TrainingServiceClient:
+        """Get training service client"""
+        if self._training_client is None:
+            self._training_client = get_training_client(self.config, self.service_name)
+        return self._training_client
+
+    @property
+    def data(self) -> DataServiceClient:
+        """Get data service client"""
+        if self._data_client is None:
+            self._data_client = get_data_client(self.config, self.service_name)
+        return self._data_client
+
+    @property
+    def forecast(self) -> ForecastServiceClient:
+        """Get forecast service client"""
+        if self._forecast_client is None:
+            self._forecast_client = get_forecast_client(self.config, self.service_name)
+        return self._forecast_client
+
+# Convenience function to get all clients
+def get_service_clients(config: BaseServiceSettings = None, service_name: str = "unknown") -> ServiceClients:
+    """Get a wrapper with all service clients"""
+    return ServiceClients(config, service_name)
+
+# Export all classes for direct import
+__all__ = [
+    'BaseServiceClient',
+    'ServiceAuthenticator',
+    'TrainingServiceClient',
+    'DataServiceClient',
+    'ForecastServiceClient',
+    'ServiceClients',
+    'get_training_client',
+    'get_data_client',
+    'get_forecast_client',
+    'get_service_clients'
+]
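
The factory caches one client per (client type, calling service) pair, so repeated lookups are cheap. A sketch of the intended import pattern (service and tenant names are placeholders):

    from shared.clients import get_service_clients
    from app.core.config import settings

    async def example(tenant_id: str):
        clients = get_service_clients(settings, "forecasting")
        # Same cached TrainingServiceClient instance on every access
        return await clients.training.list_models(tenant_id=tenant_id, status="deployed")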
shared/clients/base_service_client.py (new file, 363 lines)
@@ -0,0 +1,363 @@
|
|||||||
|
# shared/clients/base_service_client.py
|
||||||
|
"""
|
||||||
|
Base Service Client for Inter-Service Communication
|
||||||
|
Provides a reusable foundation for all service-to-service API calls
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import httpx
|
||||||
|
import structlog
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Any, Optional, List, Union
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
from shared.auth.jwt_handler import JWTHandler
|
||||||
|
from shared.config.base import BaseServiceSettings
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
class ServiceAuthenticator:
|
||||||
|
"""Handles service-to-service authentication via gateway"""
|
||||||
|
|
||||||
|
def __init__(self, service_name: str, config: BaseServiceSettings):
|
||||||
|
self.service_name = service_name
|
||||||
|
self.config = config
|
||||||
|
self.jwt_handler = JWTHandler(config.JWT_SECRET_KEY)
|
||||||
|
self._cached_token = None
|
||||||
|
self._token_expires_at = 0
|
||||||
|
|
||||||
|
async def get_service_token(self) -> str:
|
||||||
|
"""Get a valid service token, using cache when possible"""
|
||||||
|
current_time = int(time.time())
|
||||||
|
|
||||||
|
# Return cached token if still valid (with 5 min buffer)
|
||||||
|
if (self._cached_token and
|
||||||
|
self._token_expires_at > current_time + 300):
|
||||||
|
return self._cached_token
|
||||||
|
|
||||||
|
# Create new service token
|
||||||
|
token_expires_at = current_time + 3600 # 1 hour
|
||||||
|
|
||||||
|
service_payload = {
|
||||||
|
"sub": f"{self.service_name}-service",
|
||||||
|
"user_id": f"{self.service_name}-service",
|
||||||
|
"email": f"{self.service_name}-service@internal",
|
||||||
|
"type": "access",
|
||||||
|
"exp": token_expires_at,
|
||||||
|
"iat": current_time,
|
||||||
|
"iss": f"{self.service_name}-service",
|
||||||
|
"service": self.service_name,
|
||||||
|
"full_name": f"{self.service_name.title()} Service",
|
||||||
|
"is_verified": True,
|
||||||
|
"is_active": True,
|
||||||
|
"tenant_id": None
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
token = self.jwt_handler.create_access_token_from_payload(service_payload)
|
||||||
|
self._cached_token = token
|
||||||
|
self._token_expires_at = token_expires_at
|
||||||
|
|
||||||
|
logger.debug("Created new service token", service=self.service_name, expires_at=token_expires_at)
|
||||||
|
return token
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create service token: {e}", service=self.service_name)
|
||||||
|
raise ValueError(f"Service token creation failed: {e}")
|
||||||
|
|
||||||
|
def get_request_headers(self, tenant_id: Optional[str] = None) -> Dict[str, str]:
|
||||||
|
"""Get standard headers for service requests"""
|
||||||
|
headers = {
|
||||||
|
"X-Service": f"{self.service_name}-service",
|
||||||
|
"User-Agent": f"{self.service_name}-service/1.0.0"
|
||||||
|
}
|
||||||
|
|
||||||
|
if tenant_id:
|
||||||
|
headers["X-Tenant-ID"] = str(tenant_id)
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
class BaseServiceClient(ABC):
|
||||||
|
"""
|
||||||
|
Base class for all inter-service communication clients
|
||||||
|
Provides common functionality for API calls through the gateway
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, service_name: str, config: BaseServiceSettings):
|
||||||
|
self.service_name = service_name
|
||||||
|
self.config = config
|
||||||
|
self.gateway_url = config.GATEWAY_URL
|
||||||
|
self.authenticator = ServiceAuthenticator(service_name, config)
|
||||||
|
|
||||||
|
# HTTP client configuration
|
||||||
|
self.timeout = config.HTTP_TIMEOUT
|
||||||
|
self.retries = config.HTTP_RETRIES
|
||||||
|
self.retry_delay = config.HTTP_RETRY_DELAY
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_service_base_path(self) -> str:
|
||||||
|
"""Return the base path for this service's APIs"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _make_request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
endpoint: str,
|
||||||
|
tenant_id: Optional[str] = None,
|
||||||
|
data: Optional[Dict[str, Any]] = None,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
timeout: Optional[Union[int, httpx.Timeout]] = None
|
||||||
|
) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]:
|
||||||
|
"""
|
||||||
|
Make an authenticated request to another service via gateway
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP method (GET, POST, PUT, DELETE)
|
||||||
|
endpoint: API endpoint (will be prefixed with service base path)
|
||||||
|
tenant_id: Optional tenant ID for tenant-scoped requests
|
||||||
|
data: Request body data (for POST/PUT)
|
||||||
|
params: Query parameters
|
||||||
|
headers: Additional headers
|
||||||
|
timeout: Request timeout override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response data or None if request failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get service token
|
||||||
|
token = await self.authenticator.get_service_token()
|
||||||
|
|
||||||
|
# Build headers
|
||||||
|
request_headers = self.authenticator.get_request_headers(tenant_id)
|
||||||
|
request_headers["Authorization"] = f"Bearer {token}"
|
||||||
|
request_headers["Content-Type"] = "application/json"
|
||||||
|
|
||||||
|
if headers:
|
||||||
|
request_headers.update(headers)
|
||||||
|
|
||||||
|
# Build URL
|
||||||
|
base_path = self.get_service_base_path()
|
||||||
|
if tenant_id:
|
||||||
|
# For tenant-scoped endpoints
|
||||||
|
full_endpoint = f"{base_path}/tenants/{tenant_id}/{endpoint.lstrip('/')}"
|
||||||
|
else:
|
||||||
|
# For non-tenant endpoints
|
||||||
|
full_endpoint = f"{base_path}/{endpoint.lstrip('/')}"
|
||||||
|
|
||||||
|
url = urljoin(self.gateway_url, full_endpoint)
|
||||||
|
|
||||||
|
# Make request with retries
|
||||||
|
for attempt in range(self.retries + 1):
|
||||||
|
try:
|
||||||
|
# Handle different timeout configurations
|
||||||
|
if isinstance(timeout, httpx.Timeout):
|
||||||
|
client_timeout = timeout
|
||||||
|
else:
|
||||||
|
client_timeout = timeout or self.timeout
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=client_timeout) as client:
|
||||||
|
response = await client.request(
|
||||||
|
method=method,
|
||||||
|
url=url,
|
||||||
|
json=data,
|
||||||
|
params=params,
|
||||||
|
headers=request_headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.json()
|
||||||
|
elif response.status_code == 201:
|
||||||
|
return response.json()
|
||||||
|
elif response.status_code == 204:
|
||||||
|
return {} # No content success
|
||||||
|
elif response.status_code == 401:
|
||||||
|
# Token might be expired, clear cache and retry once
|
||||||
|
if attempt == 0:
|
||||||
|
self.authenticator._cached_token = None
|
||||||
|
logger.warning("Token expired, retrying with new token")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.error("Authentication failed after retry")
|
||||||
|
return None
|
||||||
|
elif response.status_code == 404:
|
||||||
|
logger.warning(f"Endpoint not found: {url}")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
error_detail = "Unknown error"
|
||||||
|
try:
|
||||||
|
error_json = response.json()
|
||||||
|
error_detail = error_json.get('detail', f"HTTP {response.status_code}")
|
||||||
|
except:
|
||||||
|
error_detail = f"HTTP {response.status_code}: {response.text}"
|
||||||
|
|
||||||
|
logger.error(f"Request failed: {error_detail}",
|
||||||
|
url=url, status_code=response.status_code)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
if attempt < self.retries:
|
||||||
|
logger.warning(f"Request timeout, retrying ({attempt + 1}/{self.retries})")
|
||||||
|
import asyncio
|
||||||
|
await asyncio.sleep(self.retry_delay * (2 ** attempt)) # Exponential backoff
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.error(f"Request timeout after {self.retries} retries", url=url)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if attempt < self.retries:
|
||||||
|
logger.warning(f"Request failed, retrying ({attempt + 1}/{self.retries}): {e}")
|
||||||
|
import asyncio
|
||||||
|
await asyncio.sleep(self.retry_delay * (2 ** attempt))
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.error(f"Request failed after {self.retries} retries: {e}", url=url)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error in _make_request: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
    async def _make_paginated_request(
        self,
        endpoint: str,
        tenant_id: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
        page_size: int = 5000,
        max_pages: int = 100,
        timeout: Optional[Union[int, httpx.Timeout]] = None
    ) -> List[Dict[str, Any]]:
        """
        Make paginated GET requests to fetch all records
        Handles both direct list and paginated object responses

        Args:
            endpoint: API endpoint
            tenant_id: Optional tenant ID
            params: Base query parameters
            page_size: Records per page (default 5000)
            max_pages: Maximum pages to fetch (safety limit)
            timeout: Request timeout override

        Returns:
            List of all records from all pages
        """
        all_records = []
        page = 0
        base_params = params or {}

        logger.info(f"Starting paginated request to {endpoint}",
                    tenant_id=tenant_id, page_size=page_size)

        while page < max_pages:
            # Prepare pagination parameters
            page_params = base_params.copy()
            page_params.update({
                "limit": page_size,
                "offset": page * page_size
            })

            logger.debug(f"Fetching page {page + 1} (offset: {page * page_size})",
                         tenant_id=tenant_id)

            # Make request for this page
            result = await self._make_request(
                "GET",
                endpoint,
                tenant_id=tenant_id,
                params=page_params,
                timeout=timeout
            )

            if result is None:
                logger.error(f"Failed to fetch page {page + 1}", tenant_id=tenant_id)
                break

            # Handle different response formats
            if isinstance(result, list):
                # Direct list response (no pagination metadata)
                records = result
                logger.debug(f"Retrieved {len(records)} records from page {page + 1} (direct list)")

                if len(records) == 0:
                    logger.info("No records in response, pagination complete")
                    break
                elif len(records) < page_size:
                    # Got fewer than requested, this is the last page
                    all_records.extend(records)
                    logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}")
                    break
                else:
                    # Got full page, there might be more
                    all_records.extend(records)
                    logger.debug(f"Full page retrieved: {len(records)} records, continuing to next page")

            elif isinstance(result, dict):
                # Paginated response format
                records = result.get('records', result.get('data', []))
                total_available = result.get('total', 0)

                logger.debug(f"Retrieved {len(records)} records from page {page + 1} (paginated response)")

                if not records:
                    logger.info("No more records found in paginated response")
                    break

                all_records.extend(records)

                # Check if we've got all available records
                if len(all_records) >= total_available:
                    logger.info(f"Retrieved all available records: {len(all_records)}/{total_available}")
                    break

            else:
                logger.warning(f"Unexpected response format: {type(result)}")
                break

            page += 1

        if page >= max_pages:
            logger.warning(f"Reached maximum page limit ({max_pages}), stopping pagination")

        logger.info(f"Pagination complete: fetched {len(all_records)} total records",
                    tenant_id=tenant_id, pages_fetched=page)

        return all_records
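Because the helper accepts either a bare JSON list or a {"records": ..., "total": ...} envelope, callers never need to know which shape a given endpoint returns. A minimal sketch of calling the public get_paginated wrapper defined just below (the endpoint name, tenant ID, and client instance are illustrative assumptions, not part of the shared library):

    # Hypothetical caller, assuming `client` is any BaseServiceClient subclass
    async def fetch_daily_sales(client) -> list:
        records = await client.get_paginated(
            "sales",                          # endpoint relative to the service base path
            tenant_id="tenant-123",           # illustrative tenant ID
            params={"aggregation": "daily"},
            page_size=1000,                   # smaller pages: more requests, less per-page memory
            max_pages=10,                     # hard cap: at most 10,000 records here
        )
        return records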
    async def get(self, endpoint: str, tenant_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
        """Make a GET request"""
        return await self._make_request("GET", endpoint, tenant_id=tenant_id, params=params)

    async def get_paginated(
        self,
        endpoint: str,
        tenant_id: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
        page_size: int = 5000,
        max_pages: int = 100,
        timeout: Optional[Union[int, httpx.Timeout]] = None
    ) -> List[Dict[str, Any]]:
        """Make a paginated GET request to fetch all records"""
        return await self._make_paginated_request(
            endpoint,
            tenant_id=tenant_id,
            params=params,
            page_size=page_size,
            max_pages=max_pages,
            timeout=timeout
        )

    async def post(self, endpoint: str, data: Dict[str, Any], tenant_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Make a POST request"""
        return await self._make_request("POST", endpoint, tenant_id=tenant_id, data=data)

    async def put(self, endpoint: str, data: Dict[str, Any], tenant_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Make a PUT request"""
        return await self._make_request("PUT", endpoint, tenant_id=tenant_id, data=data)

    async def delete(self, endpoint: str, tenant_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Make a DELETE request"""
        return await self._make_request("DELETE", endpoint, tenant_id=tenant_id)
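Concrete service clients build on these verb helpers by subclassing and supplying a base path, as the three clients added below do. A minimal hypothetical subclass showing the pattern (the inventory service and its endpoints are assumptions for illustration only):

    class InventoryServiceClient(BaseServiceClient):
        """Hypothetical client illustrating the subclassing pattern"""

        def __init__(self, config: BaseServiceSettings, calling_service_name: str = "unknown"):
            super().__init__(calling_service_name, config)

        def get_service_base_path(self) -> str:
            return "/api/v1"

        async def get_stock_levels(self, tenant_id: str) -> Optional[Dict[str, Any]]:
            # Thin wrapper over the inherited GET helper
            return await self.get("stock", tenant_id=tenant_id)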
399 shared/clients/data_client.py Normal file
@@ -0,0 +1,399 @@
# shared/clients/data_client.py
"""
Data Service Client
Handles all API calls to the data service
"""

import httpx
import structlog
from typing import Dict, Any, Optional, List, Union
from .base_service_client import BaseServiceClient
from shared.config.base import BaseServiceSettings

logger = structlog.get_logger()


class DataServiceClient(BaseServiceClient):
    """Client for communicating with the data service"""

    def __init__(self, config: BaseServiceSettings, calling_service_name: str = "unknown"):
        super().__init__(calling_service_name, config)

    def get_service_base_path(self) -> str:
        return "/api/v1"

    # ================================================================
    # SALES DATA (with advanced pagination support)
    # ================================================================

    async def get_sales_data(
        self,
        tenant_id: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        product_id: Optional[str] = None,
        aggregation: str = "daily"
    ) -> Optional[List[Dict[str, Any]]]:
        """Get sales data for a date range"""
        params = {"aggregation": aggregation}
        if start_date:
            params["start_date"] = start_date
        if end_date:
            params["end_date"] = end_date
        if product_id:
            params["product_id"] = product_id

        result = await self.get("sales", tenant_id=tenant_id, params=params)
        return result.get("sales", []) if result else None
    async def get_all_sales_data(
        self,
        tenant_id: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        product_id: Optional[str] = None,
        aggregation: str = "daily",
        page_size: int = 5000,
        max_pages: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Get ALL sales data using pagination (equivalent to the original fetch_sales_data)
        Retrieves every available record, subject only to the max_pages safety limit
        """
        params = {"aggregation": aggregation}
        if start_date:
            params["start_date"] = start_date
        if end_date:
            params["end_date"] = end_date
        if product_id:
            params["product_id"] = product_id

        # Use the inherited paginated request method
        try:
            all_records = await self.get_paginated(
                "sales",
                tenant_id=tenant_id,
                params=params,
                page_size=page_size,
                max_pages=max_pages,
                timeout=2000.0  # Match original timeout
            )

            logger.info(f"Successfully fetched {len(all_records)} total sales records via gateway",
                        tenant_id=tenant_id)
            return all_records

        except AttributeError as e:
            # Fallback: implement pagination directly if inheritance isn't working
            logger.warning(f"Using fallback pagination due to: {e}")
            return await self._fallback_paginated_sales(tenant_id, params, page_size, max_pages)

    async def _fallback_paginated_sales(
        self,
        tenant_id: str,
        base_params: Dict[str, Any],
        page_size: int = 5000,
        max_pages: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Fallback pagination implementation for sales data
        Replicates the original pagination logic directly
        """
        all_records = []
        page = 0

        logger.info("Starting fallback paginated request for sales data",
                    tenant_id=tenant_id, page_size=page_size)

        while page < max_pages:
            # Prepare pagination parameters
            params = base_params.copy()
            params.update({
                "limit": page_size,
                "offset": page * page_size
            })

            logger.info(f"Fetching sales data page {page + 1} (offset: {page * page_size})",
                        tenant_id=tenant_id)

            # Make request using the base client's _make_request method
            result = await self._make_request(
                "GET",
                "sales",
                tenant_id=tenant_id,
                params=params,
                timeout=2000.0
            )

            if result is None:
                logger.error(f"Failed to fetch page {page + 1}", tenant_id=tenant_id)
                break

            # Handle different response formats (same contract as the base client)
            if isinstance(result, list):
                # Direct list response (no pagination metadata)
                records = result
                logger.info(f"Retrieved {len(records)} records from page {page + 1} (direct list)")

                if len(records) == 0:
                    logger.info("No records in response, pagination complete")
                    break
                elif len(records) < page_size:
                    # Got fewer than requested, this is the last page
                    all_records.extend(records)
                    logger.info(f"Final page: retrieved {len(records)} records, total: {len(all_records)}")
                    break
                else:
                    # Got full page, there might be more
                    all_records.extend(records)
                    logger.info(f"Full page retrieved: {len(records)} records, continuing to next page")

            elif isinstance(result, dict):
                # Paginated response format
                records = result.get('records', result.get('data', []))
                total_available = result.get('total', 0)

                logger.info(f"Retrieved {len(records)} records from page {page + 1} (paginated response)")

                if not records:
                    logger.info("No more records found in paginated response")
                    break

                all_records.extend(records)

                # Check if we've got all available records
                if len(all_records) >= total_available:
                    logger.info(f"Retrieved all available records: {len(all_records)}/{total_available}")
                    break

            else:
                logger.warning(f"Unexpected response format: {type(result)}")
                break

            page += 1

        logger.info(f"Fallback pagination complete: fetched {len(all_records)} total records",
                    tenant_id=tenant_id, pages_fetched=page)

        return all_records
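A sketch of how a consuming service might pull a full sales history through this client; the get_data_client factory is the one the service-side wrapper imports, while the settings object and date range here are illustrative:

    from shared.clients import get_data_client

    async def load_sales_history(settings, tenant_id: str):
        client = get_data_client(settings, "forecasting")
        # Walks every page of daily aggregates in the window
        return await client.get_all_sales_data(
            tenant_id,
            start_date="2024-01-01",
            end_date="2024-12-31",
            aggregation="daily",
        )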
    async def upload_sales_data(
        self,
        tenant_id: str,
        sales_data: List[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Upload sales data"""
        data = {"sales": sales_data}
        return await self.post("sales", data=data, tenant_id=tenant_id)

    # ================================================================
    # WEATHER DATA
    # ================================================================
    async def get_weather_historical(
        self,
        tenant_id: str,
        start_date: str,
        end_date: str,
        latitude: Optional[float] = None,
        longitude: Optional[float] = None
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Get weather data for a date range and location
        Uses a POST request as per the original implementation
        """
        # Prepare request payload with proper date handling
        payload = {
            "start_date": start_date,  # Already in ISO format from calling code
            "end_date": end_date,      # Already in ISO format from calling code
            "latitude": latitude or 40.4168,  # Default Madrid coordinates
            "longitude": longitude or -3.7038
        }

        logger.info(f"Weather request payload: {payload}", tenant_id=tenant_id)

        # Use POST request with extended timeout
        result = await self._make_request(
            "POST",
            "weather/historical",
            tenant_id=tenant_id,
            data=payload,
            timeout=2000.0  # Match original timeout
        )

        if result:
            logger.info(f"Successfully fetched {len(result)} weather records")
            return result
        else:
            logger.error("Failed to fetch weather data")
            return []

    async def get_weather_forecast(
        self,
        tenant_id: str,
        days: str,
        latitude: Optional[float] = None,
        longitude: Optional[float] = None
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Get weather forecast for the next `days` days at a location
        Uses a POST request as per the original implementation
        """
        # Prepare request payload
        payload = {
            "days": days,  # Number of forecast days requested
            "latitude": latitude or 40.4168,  # Default Madrid coordinates
            "longitude": longitude or -3.7038
        }

        logger.info(f"Weather forecast request payload: {payload}", tenant_id=tenant_id)

        # Use POST request with extended timeout
        # NOTE: posts to the same "weather/historical" endpoint as the original code
        result = await self._make_request(
            "POST",
            "weather/historical",
            tenant_id=tenant_id,
            data=payload,
            timeout=2000.0  # Match original timeout
        )

        if result:
            logger.info(f"Successfully fetched {len(result)} weather forecast records for {days} days")
            return result
        else:
            logger.error("Failed to fetch weather forecast data")
            return []
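A brief usage sketch; `days` is typed as a string throughout the client, so it is passed as one here, and the empty-list-on-failure contract is handled rather than raised (the wrapper function itself is illustrative):

    async def fetch_forecast_weather(client: DataServiceClient, tenant_id: str):
        # Seven-day forecast at the default Madrid coordinates
        forecast = await client.get_weather_forecast(tenant_id, days="7")
        if not forecast:
            # The client logs and returns [] on failure instead of raising
            return []
        return forecast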
    # ================================================================
    # TRAFFIC DATA
    # ================================================================

    async def get_traffic_data(
        self,
        tenant_id: str,
        start_date: str,
        end_date: str,
        latitude: Optional[float] = None,
        longitude: Optional[float] = None
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Get traffic data for a date range and location
        Uses a POST request with an extended timeout for Madrid traffic data processing
        """
        # Prepare request payload
        payload = {
            "start_date": start_date,  # Already in ISO format from calling code
            "end_date": end_date,      # Already in ISO format from calling code
            "latitude": latitude or 40.4168,  # Default Madrid coordinates
            "longitude": longitude or -3.7038
        }

        logger.info(f"Traffic request payload: {payload}", tenant_id=tenant_id)

        # Madrid traffic data can take 5-10 minutes to download and process
        traffic_timeout = httpx.Timeout(
            connect=30.0,  # Connection timeout
            read=600.0,    # Read timeout: 10 minutes (was 30s)
            write=30.0,    # Write timeout
            pool=30.0      # Pool timeout
        )

        # Use POST request with extended timeout
        result = await self._make_request(
            "POST",
            "traffic/historical",
            tenant_id=tenant_id,
            data=payload,
            timeout=traffic_timeout
        )

        if result:
            logger.info(f"Successfully fetched {len(result)} traffic records")
            return result
        else:
            logger.error("Failed to fetch traffic data")
            return []
    # ================================================================
    # PRODUCTS
    # ================================================================

    async def get_products(self, tenant_id: str) -> Optional[List[Dict[str, Any]]]:
        """Get all products for a tenant"""
        result = await self.get("products", tenant_id=tenant_id)
        return result.get("products", []) if result else None

    async def get_product(self, tenant_id: str, product_id: str) -> Optional[Dict[str, Any]]:
        """Get a specific product"""
        return await self.get(f"products/{product_id}", tenant_id=tenant_id)

    async def create_product(
        self,
        tenant_id: str,
        name: str,
        category: str,
        price: float,
        **kwargs
    ) -> Optional[Dict[str, Any]]:
        """Create a new product"""
        data = {
            "name": name,
            "category": category,
            "price": price,
            **kwargs
        }
        return await self.post("products", data=data, tenant_id=tenant_id)

    async def update_product(
        self,
        tenant_id: str,
        product_id: str,
        **updates
    ) -> Optional[Dict[str, Any]]:
        """Update a product"""
        return await self.put(f"products/{product_id}", data=updates, tenant_id=tenant_id)

    # ================================================================
    # STORES & LOCATIONS
    # ================================================================

    async def get_stores(self, tenant_id: str) -> Optional[List[Dict[str, Any]]]:
        """Get all stores for a tenant"""
        result = await self.get("stores", tenant_id=tenant_id)
        return result.get("stores", []) if result else None

    async def get_store(self, tenant_id: str, store_id: str) -> Optional[Dict[str, Any]]:
        """Get a specific store"""
        return await self.get(f"stores/{store_id}", tenant_id=tenant_id)

    # ================================================================
    # DATA VALIDATION & HEALTH
    # ================================================================

    async def validate_data_quality(
        self,
        tenant_id: str,
        start_date: str,
        end_date: str
    ) -> Optional[Dict[str, Any]]:
        """Validate data quality for a date range"""
        params = {
            "start_date": start_date,
            "end_date": end_date
        }
        return await self.get("validation", tenant_id=tenant_id, params=params)

    async def get_data_statistics(
        self,
        tenant_id: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Get data statistics for a tenant"""
        params = {}
        if start_date:
            params["start_date"] = start_date
        if end_date:
            params["end_date"] = end_date

        return await self.get("statistics", tenant_id=tenant_id, params=params)
175 shared/clients/forecast_client.py Normal file
@@ -0,0 +1,175 @@
# shared/clients/forecast_client.py
"""
Forecast Service Client
Handles all API calls to the forecasting service
"""

from typing import Dict, Any, Optional, List
from .base_service_client import BaseServiceClient
from shared.config.base import BaseServiceSettings


class ForecastServiceClient(BaseServiceClient):
    """Client for communicating with the forecasting service"""

    def __init__(self, config: BaseServiceSettings, calling_service_name: str = "unknown"):
        super().__init__(calling_service_name, config)

    def get_service_base_path(self) -> str:
        return "/api/v1"

    # ================================================================
    # FORECASTS
    # ================================================================

    async def create_forecast(
        self,
        tenant_id: str,
        model_id: str,
        start_date: str,
        end_date: str,
        product_ids: Optional[List[str]] = None,
        include_confidence_intervals: bool = True,
        **kwargs
    ) -> Optional[Dict[str, Any]]:
        """Create a new forecast"""
        data = {
            "model_id": model_id,
            "start_date": start_date,
            "end_date": end_date,
            "include_confidence_intervals": include_confidence_intervals,
            **kwargs
        }
        if product_ids:
            data["product_ids"] = product_ids

        return await self.post("forecasts", data=data, tenant_id=tenant_id)

    async def get_forecast(self, tenant_id: str, forecast_id: str) -> Optional[Dict[str, Any]]:
        """Get forecast details"""
        return await self.get(f"forecasts/{forecast_id}", tenant_id=tenant_id)

    async def list_forecasts(
        self,
        tenant_id: str,
        status: Optional[str] = None,
        model_id: Optional[str] = None,
        limit: int = 50
    ) -> Optional[List[Dict[str, Any]]]:
        """List forecasts for a tenant"""
        params = {"limit": limit}
        if status:
            params["status"] = status
        if model_id:
            params["model_id"] = model_id

        result = await self.get("forecasts", tenant_id=tenant_id, params=params)
        return result.get("forecasts", []) if result else None

    async def delete_forecast(self, tenant_id: str, forecast_id: str) -> Optional[Dict[str, Any]]:
        """Delete a forecast"""
        return await self.delete(f"forecasts/{forecast_id}", tenant_id=tenant_id)

    # ================================================================
    # PREDICTIONS
    # ================================================================

    async def get_predictions(
        self,
        tenant_id: str,
        forecast_id: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        product_id: Optional[str] = None
    ) -> Optional[List[Dict[str, Any]]]:
        """Get predictions from a forecast"""
        params = {}
        if start_date:
            params["start_date"] = start_date
        if end_date:
            params["end_date"] = end_date
        if product_id:
            params["product_id"] = product_id

        result = await self.get(f"forecasts/{forecast_id}/predictions", tenant_id=tenant_id, params=params)
        return result.get("predictions", []) if result else None

    async def create_realtime_prediction(
        self,
        tenant_id: str,
        model_id: str,
        target_date: str,
        features: Dict[str, Any],
        **kwargs
    ) -> Optional[Dict[str, Any]]:
        """Create a real-time prediction"""
        data = {
            "model_id": model_id,
            "target_date": target_date,
            "features": features,
            **kwargs
        }
        return await self.post("predictions", data=data, tenant_id=tenant_id)

    # ================================================================
    # FORECAST VALIDATION & METRICS
    # ================================================================

    async def get_forecast_accuracy(
        self,
        tenant_id: str,
        forecast_id: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Get forecast accuracy metrics"""
        params = {}
        if start_date:
            params["start_date"] = start_date
        if end_date:
            params["end_date"] = end_date

        return await self.get(f"forecasts/{forecast_id}/accuracy", tenant_id=tenant_id, params=params)

    async def compare_forecasts(
        self,
        tenant_id: str,
        forecast_ids: List[str],
        metric: str = "mape"
    ) -> Optional[Dict[str, Any]]:
        """Compare multiple forecasts"""
        data = {
            "forecast_ids": forecast_ids,
            "metric": metric
        }
        return await self.post("forecasts/compare", data=data, tenant_id=tenant_id)

    # ================================================================
    # FORECAST SCENARIOS
    # ================================================================

    async def create_scenario_forecast(
        self,
        tenant_id: str,
        model_id: str,
        scenario_name: str,
        scenario_data: Dict[str, Any],
        start_date: str,
        end_date: str,
        **kwargs
    ) -> Optional[Dict[str, Any]]:
        """Create a scenario-based forecast"""
        data = {
            "model_id": model_id,
            "scenario_name": scenario_name,
            "scenario_data": scenario_data,
            "start_date": start_date,
            "end_date": end_date,
            **kwargs
        }
        return await self.post("scenarios", data=data, tenant_id=tenant_id)

    async def list_scenarios(self, tenant_id: str) -> Optional[List[Dict[str, Any]]]:
        """List forecast scenarios for a tenant"""
        result = await self.get("scenarios", tenant_id=tenant_id)
        return result.get("scenarios", []) if result else None
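A sketch of a create-then-read flow against this client; the model ID, the dates, and the "id" key on the response are illustrative assumptions about the payload shape, not confirmed by the service contract:

    async def run_forecast(client: ForecastServiceClient, tenant_id: str):
        created = await client.create_forecast(
            tenant_id,
            model_id="model-123",      # hypothetical model ID
            start_date="2025-09-01",
            end_date="2025-09-14",
        )
        if created is None:
            return None
        # "id" is an assumed key on the created-forecast payload
        return await client.get_predictions(tenant_id, created["id"])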
134 shared/clients/training_client.py Normal file
@@ -0,0 +1,134 @@
# shared/clients/training_client.py
"""
Training Service Client
Handles all API calls to the training service
"""

from typing import Dict, Any, Optional, List
from .base_service_client import BaseServiceClient
from shared.config.base import BaseServiceSettings


class TrainingServiceClient(BaseServiceClient):
    """Client for communicating with the training service"""

    def __init__(self, config: BaseServiceSettings, calling_service_name: str = "unknown"):
        super().__init__(calling_service_name, config)

    def get_service_base_path(self) -> str:
        return "/api/v1"

    # ================================================================
    # TRAINING JOBS
    # ================================================================

    async def create_training_job(
        self,
        tenant_id: str,
        include_weather: bool = True,
        include_traffic: bool = False,
        min_data_points: int = 30,
        **kwargs
    ) -> Optional[Dict[str, Any]]:
        """Create a new training job"""
        data = {
            "include_weather": include_weather,
            "include_traffic": include_traffic,
            "min_data_points": min_data_points,
            **kwargs
        }
        return await self.post("jobs", data=data, tenant_id=tenant_id)

    async def get_training_job(self, tenant_id: str, job_id: str) -> Optional[Dict[str, Any]]:
        """Get training job details"""
        return await self.get(f"jobs/{job_id}", tenant_id=tenant_id)

    async def list_training_jobs(
        self,
        tenant_id: str,
        status: Optional[str] = None,
        limit: int = 50
    ) -> Optional[List[Dict[str, Any]]]:
        """List training jobs for a tenant"""
        params = {"limit": limit}
        if status:
            params["status"] = status

        result = await self.get("jobs", tenant_id=tenant_id, params=params)
        return result.get("jobs", []) if result else None

    async def cancel_training_job(self, tenant_id: str, job_id: str) -> Optional[Dict[str, Any]]:
        """Cancel a training job"""
        return await self.delete(f"jobs/{job_id}", tenant_id=tenant_id)

    # ================================================================
    # MODELS
    # ================================================================

    async def get_model(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
        """Get model details"""
        return await self.get(f"models/{model_id}", tenant_id=tenant_id)

    async def list_models(
        self,
        tenant_id: str,
        status: Optional[str] = None,
        model_type: Optional[str] = None,
        limit: int = 50
    ) -> Optional[List[Dict[str, Any]]]:
        """List models for a tenant"""
        params = {"limit": limit}
        if status:
            params["status"] = status
        if model_type:
            params["model_type"] = model_type

        result = await self.get("models", tenant_id=tenant_id, params=params)
        return result.get("models", []) if result else None

    async def get_latest_model(
        self,
        tenant_id: str,
        model_type: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Get the latest trained model for a tenant"""
        params = {"latest": "true"}
        if model_type:
            params["model_type"] = model_type

        result = await self.get("models", tenant_id=tenant_id, params=params)
        models = result.get("models", []) if result else []
        return models[0] if models else None

    async def deploy_model(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
        """Deploy a trained model"""
        return await self.post(f"models/{model_id}/deploy", data={}, tenant_id=tenant_id)

    async def delete_model(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
        """Delete a model"""
        return await self.delete(f"models/{model_id}", tenant_id=tenant_id)

    # ================================================================
    # MODEL METRICS & PERFORMANCE
    # ================================================================

    async def get_model_metrics(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
        """Get model performance metrics"""
        return await self.get(f"models/{model_id}/metrics", tenant_id=tenant_id)

    async def get_model_predictions(
        self,
        tenant_id: str,
        model_id: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> Optional[List[Dict[str, Any]]]:
        """Get model predictions for evaluation"""
        params = {}
        if start_date:
            params["start_date"] = start_date
        if end_date:
            params["end_date"] = end_date

        result = await self.get(f"models/{model_id}/predictions", tenant_id=tenant_id, params=params)
        return result.get("predictions", []) if result else None
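A sketch of the train-then-deploy handshake these methods support; the polling step is elided, and the "id" key on the model payload is an assumption for illustration:

    async def train_and_deploy(client: TrainingServiceClient, tenant_id: str):
        job = await client.create_training_job(tenant_id, include_weather=True)
        if job is None:
            return None
        # ... poll get_training_job(tenant_id, job_id) until the job completes ...
        model = await client.get_latest_model(tenant_id)
        if model:
            return await client.deploy_model(tenant_id, model["id"])  # assumed key
        return None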
@@ -569,28 +569,22 @@ echo ""
 
 log_step "5.1. Testing basic dashboard functionality"
 
-# Test basic forecasting capability (if training completed)
-if [ -n "$TRAINING_TASK_ID" ]; then
-    # Use a real product name from our CSV for forecasting
-    FIRST_PRODUCT=$(echo "$REAL_PRODUCTS" | sed 's/"//g' | cut -d',' -f1)
+# Use a real product name from our CSV for forecasting
+FIRST_PRODUCT=$(echo "$REAL_PRODUCTS" | sed 's/"//g' | cut -d',' -f1)
 
-    FORECAST_RESPONSE=$(curl -s -X POST "$API_BASE/api/v1/forecasting/predict" \
-        -H "Content-Type: application/json" \
-        -H "Authorization: Bearer $ACCESS_TOKEN" \
-        -H "X-Tenant-ID: $TENANT_ID" \
-        -d "{
-            \"products\": [\"$FIRST_PRODUCT\"],
-            \"forecast_days\": 7,
-            \"date\": \"2024-01-15\"
-        }")
+FORECAST_RESPONSE=$(curl -s -X POST "$API_BASE/api/v1/tenants/$TENANT_ID/forecast/single" \
+    -H "Content-Type: application/json" \
+    -H "Authorization: Bearer $ACCESS_TOKEN" \
+    -d "{
+        \"products\": [\"$FIRST_PRODUCT\"],
+        \"forecast_days\": 7,
+        \"date\": \"2025-09-15\"
+    }")
 
-    if echo "$FORECAST_RESPONSE" | grep -q '"predictions"\|"forecast"'; then
-        log_success "Forecasting service is accessible"
-    else
-        log_warning "Forecasting may not be ready yet (model training required)"
-    fi
-else
-    log_warning "Skipping forecast test - no training task ID available"
-fi
+if echo "$FORECAST_RESPONSE" | grep -q '"predictions"\|"forecast"'; then
+    log_success "Forecasting service is accessible"
+else
+    log_warning "Forecasting may not be ready yet (model training required)"
+fi
 
 echo ""