Checking onboarding flow - fix 4

This commit is contained in:
Urtzi Alfaro
2025-07-27 16:29:53 +02:00
parent 0b14cf9eb2
commit e63a99b818
8 changed files with 497 additions and 8 deletions

View File

@@ -0,0 +1,82 @@
import time
from typing import Any, Dict, Optional

import structlog

from shared.auth.jwt_handler import JWTHandler
from app.core.config import settings
logger = structlog.get_logger()
class ServiceAuthenticator:
    """Handles service-to-service authentication via gateway.

    Mints JWT access tokens that the gateway middleware accepts for
    internal service calls, and caches each token until shortly before
    it expires so we do not re-sign on every request.
    """

    # Lifetime of each minted token, in seconds.
    TOKEN_TTL = 3600
    # Re-mint this many seconds before the cached token would expire.
    EXPIRY_BUFFER = 300

    def __init__(self):
        self.jwt_handler = JWTHandler(settings.JWT_SECRET_KEY)
        # Cached token plus its expiry (epoch seconds); 0 forces a mint
        # on first use.
        self._cached_token: Optional[str] = None
        self._token_expires_at = 0

    async def get_service_token(self) -> str:
        """
        Get a valid service token, using cache when possible.

        Creates JWT tokens that the gateway will accept.

        Returns:
            A signed access token string.

        Raises:
            ValueError: If token creation fails (original cause chained).
        """
        current_time = int(time.time())

        # Return cached token if still valid (with buffer so a token
        # never expires mid-request).
        if (self._cached_token and
                self._token_expires_at > current_time + self.EXPIRY_BUFFER):
            return self._cached_token

        # Create new service token
        token_expires_at = current_time + self.TOKEN_TTL
        service_payload = {
            # Required fields for gateway middleware
            "sub": "training-service",
            "user_id": "training-service",
            "email": "training-service@internal",
            "type": "access",  # Must be "access" for gateway
            # Expiration and timing
            "exp": token_expires_at,
            "iat": current_time,
            "iss": "training-service",
            # Service identification
            "service": "training",
            "full_name": "Training Service",
            "is_verified": True,
            "is_active": True,
            # Tenant context is supplied per request via X-Tenant-ID,
            # not baked into the token.
            "tenant_id": None
        }

        try:
            token = self.jwt_handler.create_access_token_from_payload(service_payload)
            # Cache the token so subsequent calls skip signing.
            self._cached_token = token
            self._token_expires_at = token_expires_at
            logger.debug("Created new service token", expires_at=token_expires_at)
            return token
        except Exception as e:
            logger.error(f"Failed to create service token: {e}")
            # Chain the original exception so callers can see why
            # signing failed.
            raise ValueError(f"Service token creation failed: {e}") from e

    def get_request_headers(self, tenant_id: Optional[str] = None) -> Dict[str, str]:
        """Get standard headers for service requests.

        Args:
            tenant_id: Optional tenant scope; when provided it is sent
                as the X-Tenant-ID header (stringified).

        Returns:
            Header dict ready to pass to an HTTP client.
        """
        headers = {
            "Content-Type": "application/json",
            "X-Service": "training-service",
            "User-Agent": "training-service/1.0.0"
        }
        if tenant_id:
            headers["X-Tenant-ID"] = str(tenant_id)
        return headers
# Global authenticator instance shared module-wide; its only state is
# the cached service token.
service_auth = ServiceAuthenticator()

View File

@@ -18,6 +18,7 @@ import joblib
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import json
from pathlib import Path
import math
from app.core.config import settings
@@ -177,8 +178,16 @@ class BakeryProphetManager:
"""Prepare data for Prophet training"""
prophet_data = df.copy()
# Ensure ds column is datetime
prophet_data['ds'] = pd.to_datetime(prophet_data['ds'])
# Prophet column mapping
if 'date' in prophet_data.columns:
prophet_data['ds'] = prophet_data['date']
if 'quantity' in prophet_data.columns:
prophet_data['y'] = prophet_data['quantity']
# ✅ CRITICAL FIX: Remove timezone from ds column
if 'ds' in prophet_data.columns:
prophet_data['ds'] = pd.to_datetime(prophet_data['ds']).dt.tz_localize(None)
logger.info(f"Removed timezone from ds column")
# Handle missing values in target
if prophet_data['y'].isna().any():
@@ -345,7 +354,14 @@ class BakeryProphetManager:
rmse = np.sqrt(mse)
# MAPE (Mean Absolute Percentage Error)
mape = np.mean(np.abs((y_true - y_pred) / y_true)) * 100
non_zero_mask = y_true != 0
if np.sum(non_zero_mask) == 0:
mape = 0.0 # Return 0 instead of Infinity
else:
mape_values = np.abs((y_true[non_zero_mask] - y_pred[non_zero_mask]) / y_true[non_zero_mask])
mape = np.mean(mape_values) * 100
if math.isinf(mape) or math.isnan(mape):
mape = 0.0
# R-squared
r2 = r2_score(y_true, y_pred)

View File

@@ -244,6 +244,10 @@ class BakeryMLTrainer:
if sales_df.empty:
raise ValueError(f"No sales data provided for tenant {tenant_id}")
if 'quantity_sold' in sales_df.columns and 'quantity' not in sales_df.columns:
sales_df['quantity'] = sales_df['quantity_sold']
logger.info("Mapped 'quantity_sold' to 'quantity' column")
required_columns = ['date', 'product_name', 'quantity']
missing_columns = [col for col in required_columns if col not in sales_df.columns]
if missing_columns:

View File

@@ -0,0 +1,140 @@
import httpx
import structlog
from typing import List, Dict, Any, Optional
from app.core.config import settings
from app.core.service_auth import service_auth
logger = structlog.get_logger()
class DataServiceClient:
    """Client for fetching data through the API Gateway.

    Every request is authenticated with a service JWT from
    ``service_auth`` and tenant-scoped via the X-Tenant-ID header.
    All fetch methods are best-effort: on any failure they log and
    return an empty list instead of raising.
    """

    def __init__(self):
        self.base_url = settings.API_GATEWAY_URL
        self.timeout = 30.0  # seconds, applied to every gateway request

    async def _auth_headers(self, tenant_id: str) -> Dict[str, str]:
        """Standard service headers plus a bearer token for the gateway."""
        token = await service_auth.get_service_token()
        headers = service_auth.get_request_headers(tenant_id)
        headers["Authorization"] = f"Bearer {token}"
        return headers

    async def fetch_sales_data(
        self,
        tenant_id: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        product_name: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Fetch sales data for training via API Gateway.

        Uses proper service authentication; optional date/product
        filters are passed through as query parameters.

        Returns:
            List of sales-record dicts, or [] on any error.
        """
        try:
            headers = await self._auth_headers(tenant_id)

            # Only include filters the caller actually supplied.
            params = {}
            if start_date:
                params["start_date"] = start_date
            if end_date:
                params["end_date"] = end_date
            if product_name:
                params["product_name"] = product_name

            # Make request via gateway
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.get(
                    f"{self.base_url}/api/v1/tenants/{tenant_id}/sales",
                    headers=headers,
                    params=params
                )

                logger.info(f"Sales data request: {response.status_code}",
                            tenant_id=tenant_id,
                            url=response.url)

                if response.status_code == 200:
                    data = response.json()
                    logger.info(f"Successfully fetched {len(data)} sales records via gateway",
                                tenant_id=tenant_id)
                    return data
                elif response.status_code == 401:
                    logger.error("Authentication failed with gateway",
                                 tenant_id=tenant_id,
                                 response_text=response.text)
                    return []
                elif response.status_code == 404:
                    logger.warning("Sales data endpoint not found",
                                   tenant_id=tenant_id,
                                   url=response.url)
                    return []
                else:
                    logger.error(f"Gateway request failed: HTTP {response.status_code}",
                                 tenant_id=tenant_id,
                                 response_text=response.text)
                    return []
        except httpx.TimeoutException:
            logger.error("Timeout when fetching sales data via gateway",
                         tenant_id=tenant_id)
            return []
        except Exception as e:
            logger.error(f"Error fetching sales data via gateway: {e}",
                         tenant_id=tenant_id)
            return []

    async def _fetch_auxiliary(self, tenant_id: str, path: str, label: str) -> List[Dict[str, Any]]:
        """Shared best-effort fetch for auxiliary datasets (weather/traffic).

        Args:
            tenant_id: Tenant whose data to fetch.
            path: URL path under the tenant prefix, e.g. "weather/history".
            label: Capitalized dataset name used in log messages
                ("Weather" or "Traffic").

        Returns:
            List of record dicts, or [] on any non-200 response or error.
        """
        try:
            headers = await self._auth_headers(tenant_id)
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.get(
                    f"{self.base_url}/api/v1/tenants/{tenant_id}/{path}",
                    headers=headers
                )
                if response.status_code == 200:
                    data = response.json()
                    logger.info(f"Fetched {len(data)} {label.lower()} records", tenant_id=tenant_id)
                    return data
                else:
                    logger.warning(f"{label} data fetch failed: {response.status_code}",
                                   tenant_id=tenant_id)
                    return []
        except Exception as e:
            logger.warning(f"Error fetching {label.lower()} data: {e}", tenant_id=tenant_id)
            return []

    async def fetch_weather_data(self, tenant_id: str) -> List[Dict[str, Any]]:
        """Fetch weather data via API Gateway"""
        return await self._fetch_auxiliary(tenant_id, "weather/history", "Weather")

    async def fetch_traffic_data(self, tenant_id: str) -> List[Dict[str, Any]]:
        """Fetch traffic data via API Gateway"""
        return await self._fetch_auxiliary(tenant_id, "traffic/historical", "Traffic")

View File

@@ -19,6 +19,7 @@ from app.schemas.training import TrainingJobRequest, SingleProductTrainingReques
from app.services.messaging import publish_job_completed, publish_job_failed
from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector
from app.services.data_client import DataServiceClient
logger = logging.getLogger(__name__)
metrics = MetricsCollector("training-service")
@@ -31,6 +32,7 @@ class TrainingService:
def __init__(self):
    """Wire up the ML trainer and the gateway-backed data client."""
    self.ml_trainer = BakeryMLTrainer()
    self.data_client = DataServiceClient()
async def execute_training_job_simple(self, job_id: str, tenant_id_str: str, request: TrainingJobRequest):
"""Simple wrapper that creates its own database session"""
@@ -136,7 +138,7 @@ class TrainingService:
await self._update_job_status(db, job_id, "running", 5, "Fetching training data")
# Fetch sales data from data service
sales_data = await self._fetch_sales_data(tenant_id, request)
sales_data = await self.data_client.fetch_sales_data(tenant_id)
# Fetch external data if requested
weather_data = []
@@ -144,11 +146,11 @@ class TrainingService:
if request.include_weather:
await self._update_job_status(db, job_id, "running", 15, "Fetching weather data")
weather_data = await self._fetch_weather_data(tenant_id, request)
weather_data = await self.data_client.fetch_weather_data(tenant_id)
if request.include_traffic:
await self._update_job_status(db, job_id, "running", 25, "Fetching traffic data")
traffic_data = await self._fetch_traffic_data(tenant_id, request)
traffic_data = await self.data_client.fetch_traffic_data(tenant_id)
# Execute ML training
await self._update_job_status(db, job_id, "running", 35, "Processing training data")