# shared/clients/training_client.py
"""
Training Service Client
Handles all API calls to the training service
"""
|
||
|
|
|
||
|
|
from typing import Dict, Any, Optional, List
|
||
|
|
from .base_service_client import BaseServiceClient
|
||
|
|
from shared.config.base import BaseServiceSettings
|
||
|
|
|
||
|
|
|
||
|
|
class TrainingServiceClient(BaseServiceClient):
    """Client for communicating with the training service.

    Every public coroutine builds a request payload or query string and
    delegates to ``self.get`` / ``self.post`` / ``self.delete``. Those helpers
    are not defined in this class; presumably they are inherited from
    ``BaseServiceClient`` and return the decoded response, or ``None`` when the
    call fails -- confirm against the base class.
    """

    def __init__(self, config: BaseServiceSettings, calling_service_name: str = "unknown"):
        """Initialise the client.

        Args:
            config: Service settings forwarded to the base client.
            calling_service_name: Name of the service using this client;
                forwarded to the base client as its first argument.
        """
        super().__init__(calling_service_name, config)

    def get_service_base_path(self) -> str:
        """Return the base path used for training service endpoints."""
        return "/api/v1"

    @staticmethod
    def _extract_list(result: Any, key: str) -> Optional[List[Dict[str, Any]]]:
        """Pull the list stored under ``key`` out of a response payload.

        Args:
            result: Value returned by the request helper.
            key: Name of the field expected to hold the list.

        Returns:
            ``None`` when there is no response (``result is None``) or the
            payload has an unexpected type; the payload itself when it is
            already a list; otherwise ``result[key]``, or ``[]`` when the key
            is absent.
        """
        if result is None:
            return None
        if isinstance(result, dict):
            # An empty dict is a response without the key, not a failed call,
            # so it yields [] rather than None.
            return result.get(key, [])
        if isinstance(result, list):
            # NOTE(review): assumes a bare JSON array is the collection
            # itself -- confirm against the training service's responses.
            return result
        return None

    # ================================================================
    # TRAINING JOBS
    # ================================================================

    async def create_training_job(
        self,
        tenant_id: str,
        include_weather: bool = True,
        include_traffic: bool = False,
        min_data_points: int = 30,
        **kwargs
    ) -> Optional[Dict[str, Any]]:
        """Create a new training job.

        Args:
            tenant_id: Tenant identifier forwarded to the request helper.
            include_weather: Sent as ``include_weather`` in the request body.
            include_traffic: Sent as ``include_traffic`` in the request body.
            min_data_points: Sent as ``min_data_points`` in the request body.
            **kwargs: Extra fields merged into the request body as-is.

        Returns:
            Whatever ``self.post`` returns for ``training/jobs``.
        """
        data = {
            "include_weather": include_weather,
            "include_traffic": include_traffic,
            "min_data_points": min_data_points,
            **kwargs,
        }
        return await self.post("training/jobs", data=data, tenant_id=tenant_id)

    async def get_training_job(self, tenant_id: str, job_id: str) -> Optional[Dict[str, Any]]:
        """Get training job details from the job's ``status`` endpoint.

        Args:
            tenant_id: Tenant identifier forwarded to the request helper.
            job_id: Identifier of the job, interpolated into the URL path.

        Returns:
            Whatever ``self.get`` returns for ``training/jobs/{job_id}/status``.
        """
        return await self.get(f"training/jobs/{job_id}/status", tenant_id=tenant_id)

    async def list_training_jobs(
        self,
        tenant_id: str,
        status: Optional[str] = None,
        limit: int = 50
    ) -> Optional[List[Dict[str, Any]]]:
        """List training jobs for a tenant.

        Args:
            tenant_id: Tenant identifier forwarded to the request helper.
            status: Optional filter; only sent when truthy.
            limit: Sent as the ``limit`` query parameter.

        Returns:
            The ``jobs`` list from the response, ``[]`` when the response has
            no ``jobs`` field, or ``None`` when there is no response.
        """
        params: Dict[str, Any] = {"limit": limit}
        if status:
            params["status"] = status

        result = await self.get("training/jobs", tenant_id=tenant_id, params=params)
        return self._extract_list(result, "jobs")

    async def cancel_training_job(self, tenant_id: str, job_id: str) -> Optional[Dict[str, Any]]:
        """Cancel a training job by issuing a delete on the job resource.

        Args:
            tenant_id: Tenant identifier forwarded to the request helper.
            job_id: Identifier of the job, interpolated into the URL path.

        Returns:
            Whatever ``self.delete`` returns for ``training/jobs/{job_id}``.
        """
        return await self.delete(f"training/jobs/{job_id}", tenant_id=tenant_id)

    # ================================================================
    # MODELS
    # ================================================================

    async def get_model(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
        """Get model details.

        Args:
            tenant_id: Tenant identifier forwarded to the request helper.
            model_id: Identifier of the model, interpolated into the URL path.

        Returns:
            Whatever ``self.get`` returns for ``training/models/{model_id}``.
        """
        return await self.get(f"training/models/{model_id}", tenant_id=tenant_id)

    async def list_models(
        self,
        tenant_id: str,
        status: Optional[str] = None,
        model_type: Optional[str] = None,
        limit: int = 50
    ) -> Optional[List[Dict[str, Any]]]:
        """List models for a tenant.

        Args:
            tenant_id: Tenant identifier forwarded to the request helper.
            status: Optional filter; only sent when truthy.
            model_type: Optional filter; only sent when truthy.
            limit: Sent as the ``limit`` query parameter.

        Returns:
            The ``models`` list from the response, ``[]`` when the response
            has no ``models`` field, or ``None`` when there is no response.
        """
        params: Dict[str, Any] = {"limit": limit}
        if status:
            params["status"] = status
        if model_type:
            params["model_type"] = model_type

        result = await self.get("training/models", tenant_id=tenant_id, params=params)
        return self._extract_list(result, "models")

    async def get_active_model_for_product(
        self,
        tenant_id: str,
        inventory_product_id: str
    ) -> Optional[Dict[str, Any]]:
        """
        Get the active model for a specific product by inventory product ID

        This is the preferred method since models are stored per product.

        Args:
            tenant_id: Tenant identifier forwarded to the request helper.
            inventory_product_id: Product identifier, interpolated into the
                URL path.

        Returns:
            Whatever ``self.get`` returns for
            ``training/models/{inventory_product_id}/active``.
        """
        return await self.get(
            f"training/models/{inventory_product_id}/active", tenant_id=tenant_id
        )

    async def deploy_model(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
        """Deploy a trained model.

        Args:
            tenant_id: Tenant identifier forwarded to the request helper.
            model_id: Identifier of the model, interpolated into the URL path.

        Returns:
            Whatever ``self.post`` returns for
            ``training/models/{model_id}/deploy``; the request body is empty.
        """
        return await self.post(f"training/models/{model_id}/deploy", data={}, tenant_id=tenant_id)

    async def delete_model(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
        """Delete a model.

        Args:
            tenant_id: Tenant identifier forwarded to the request helper.
            model_id: Identifier of the model, interpolated into the URL path.

        Returns:
            Whatever ``self.delete`` returns for ``training/models/{model_id}``.
        """
        return await self.delete(f"training/models/{model_id}", tenant_id=tenant_id)

    # ================================================================
    # MODEL METRICS & PERFORMANCE
    # ================================================================

    async def get_model_metrics(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
        """Get model performance metrics.

        Args:
            tenant_id: Tenant identifier forwarded to the request helper.
            model_id: Identifier of the model, interpolated into the URL path.

        Returns:
            Whatever ``self.get`` returns for
            ``training/models/{model_id}/metrics``.
        """
        return await self.get(f"training/models/{model_id}/metrics", tenant_id=tenant_id)

    async def get_model_predictions(
        self,
        tenant_id: str,
        model_id: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> Optional[List[Dict[str, Any]]]:
        """Get model predictions for evaluation.

        Args:
            tenant_id: Tenant identifier forwarded to the request helper.
            model_id: Identifier of the model, interpolated into the URL path.
            start_date: Optional lower bound; only sent when truthy.
            end_date: Optional upper bound; only sent when truthy.

        Returns:
            The ``predictions`` list from the response, ``[]`` when the
            response has no ``predictions`` field, or ``None`` when there is
            no response.
        """
        params: Dict[str, Any] = {}
        if start_date:
            params["start_date"] = start_date
        if end_date:
            params["end_date"] = end_date

        result = await self.get(
            f"training/models/{model_id}/predictions", tenant_id=tenant_id, params=params
        )
        return self._extract_list(result, "predictions")

    async def trigger_retrain(
        self,
        tenant_id: str,
        inventory_product_id: str,
        reason: str = 'manual',
        metadata: Optional[Dict[str, Any]] = None,
        *,
        include_weather: bool = True,
        include_traffic: bool = False,
        min_data_points: int = 30
    ) -> Optional[Dict[str, Any]]:
        """
        Trigger model retraining for a specific product.
        Used by orchestrator when forecast accuracy degrades.

        Args:
            tenant_id: Tenant UUID
            inventory_product_id: Product UUID to retrain model for
            reason: Reason for retraining (accuracy_degradation, manual, scheduled, etc.)
            metadata: Optional metadata (e.g., previous_mape, validation_date, etc.)
            include_weather: Sent as ``include_weather`` in the request body
                (keyword-only; defaults to the previously hard-coded True)
            include_traffic: Sent as ``include_traffic`` in the request body
                (keyword-only; defaults to the previously hard-coded False)
            min_data_points: Sent as ``min_data_points`` in the request body
                (keyword-only; defaults to the previously hard-coded 30)

        Returns:
            Training job details or None if failed
        """
        data = {
            "inventory_product_id": inventory_product_id,
            "reason": reason,
            # A missing (or empty) metadata argument is sent as an empty object.
            "metadata": metadata or {},
            "include_weather": include_weather,
            "include_traffic": include_traffic,
            "min_data_points": min_data_points,
        }
        return await self.post("training/models/retrain", data=data, tenant_id=tenant_id)
|