Files
bakery-ia/shared/clients/training_client.py

162 lines
6.0 KiB
Python
Raw Normal View History

2025-07-29 15:08:55 +02:00
# shared/clients/training_client.py
"""
Training Service Client
Handles all API calls to the training service
"""
from typing import Dict, Any, Optional, List
from .base_service_client import BaseServiceClient
from shared.config.base import BaseServiceSettings
class TrainingServiceClient(BaseServiceClient):
    """Client for communicating with the training service.

    Each public coroutine is a thin wrapper that forwards a relative
    endpoint path plus ``tenant_id`` to the ``get`` / ``post`` / ``delete``
    helpers inherited from ``BaseServiceClient``.  A ``None`` return value
    means the underlying request produced no result (i.e. it failed).
    """

    def __init__(self, config: BaseServiceSettings, calling_service_name: str = "unknown"):
        """Initialise the client.

        Args:
            config: Service settings handed to the base client.
            calling_service_name: Name of the service using this client.
        """
        # NOTE: the base class takes the service name first, then the config.
        super().__init__(calling_service_name, config)

    def get_service_base_path(self) -> str:
        """Return the base path prefix used for training-service requests."""
        return "/api/v1"

    @staticmethod
    def _unwrap_list(
        response: Optional[Dict[str, Any]],
        key: str,
    ) -> Optional[List[Dict[str, Any]]]:
        """Extract the list stored under ``key`` from a list-endpoint response.

        Only ``None`` is treated as a failed request.  An empty payload
        (for example ``{}``) is a successful answer with nothing to report
        and therefore yields ``[]`` rather than ``None``, so callers can
        tell "no entries" apart from "request failed".
        """
        if response is None:
            return None
        if not response:
            return []
        return response.get(key, [])

    # ================================================================
    # TRAINING JOBS
    # ================================================================

    async def create_training_job(
        self,
        tenant_id: str,
        include_weather: bool = True,
        include_traffic: bool = False,
        min_data_points: int = 30,
        **kwargs
    ) -> Optional[Dict[str, Any]]:
        """Create a new training job.

        Args:
            tenant_id: Tenant the job belongs to.
            include_weather: Sent as ``include_weather`` in the payload.
            include_traffic: Sent as ``include_traffic`` in the payload.
            min_data_points: Sent as ``min_data_points`` in the payload.
            **kwargs: Extra fields forwarded verbatim in the payload.

        Returns:
            The response of the POST to ``training/jobs``, or None.
        """
        data = {
            "include_weather": include_weather,
            "include_traffic": include_traffic,
            "min_data_points": min_data_points,
            **kwargs
        }
        return await self.post("training/jobs", data=data, tenant_id=tenant_id)

    async def get_training_job(self, tenant_id: str, job_id: str) -> Optional[Dict[str, Any]]:
        """Get training job details.

        Note that this queries the ``training/jobs/{job_id}/status``
        sub-resource.
        """
        return await self.get(f"training/jobs/{job_id}/status", tenant_id=tenant_id)

    async def list_training_jobs(
        self,
        tenant_id: str,
        status: Optional[str] = None,
        limit: int = 50
    ) -> Optional[List[Dict[str, Any]]]:
        """List training jobs for a tenant.

        Args:
            tenant_id: Tenant whose jobs are listed.
            status: Optional status filter; omitted from the query when falsy.
            limit: Value of the ``limit`` query parameter.

        Returns:
            The ``jobs`` list of the response (``[]`` when absent), or None
            if the request produced no response.
        """
        # Values are heterogeneous (int limit, str status), hence Dict[str, Any].
        params: Dict[str, Any] = {"limit": limit}
        if status:
            params["status"] = status

        result = await self.get("training/jobs", tenant_id=tenant_id, params=params)
        return self._unwrap_list(result, "jobs")

    async def cancel_training_job(self, tenant_id: str, job_id: str) -> Optional[Dict[str, Any]]:
        """Cancel a training job by issuing a DELETE on ``training/jobs/{job_id}``."""
        return await self.delete(f"training/jobs/{job_id}", tenant_id=tenant_id)

    # ================================================================
    # MODELS
    # ================================================================

    async def get_model(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
        """Get model details from ``training/models/{model_id}``."""
        return await self.get(f"training/models/{model_id}", tenant_id=tenant_id)

    async def list_models(
        self,
        tenant_id: str,
        status: Optional[str] = None,
        model_type: Optional[str] = None,
        limit: int = 50
    ) -> Optional[List[Dict[str, Any]]]:
        """List models for a tenant.

        Args:
            tenant_id: Tenant whose models are listed.
            status: Optional status filter; omitted from the query when falsy.
            model_type: Optional model-type filter; omitted when falsy.
            limit: Value of the ``limit`` query parameter.

        Returns:
            The ``models`` list of the response (``[]`` when absent), or
            None if the request produced no response.
        """
        params: Dict[str, Any] = {"limit": limit}
        if status:
            params["status"] = status
        if model_type:
            params["model_type"] = model_type

        result = await self.get("training/models", tenant_id=tenant_id, params=params)
        return self._unwrap_list(result, "models")

    async def get_active_model_for_product(
        self,
        tenant_id: str,
        inventory_product_id: str
    ) -> Optional[Dict[str, Any]]:
        """
        Get the active model for a specific product by inventory product ID

        This is the preferred method since models are stored per product.
        """
        return await self.get(
            f"training/models/{inventory_product_id}/active",
            tenant_id=tenant_id,
        )

    async def deploy_model(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
        """Deploy a trained model (POST with an empty payload)."""
        return await self.post(f"training/models/{model_id}/deploy", data={}, tenant_id=tenant_id)

    async def delete_model(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
        """Delete a model by issuing a DELETE on ``training/models/{model_id}``."""
        return await self.delete(f"training/models/{model_id}", tenant_id=tenant_id)

    # ================================================================
    # MODEL METRICS & PERFORMANCE
    # ================================================================

    async def get_model_metrics(self, tenant_id: str, model_id: str) -> Optional[Dict[str, Any]]:
        """Get model performance metrics from ``training/models/{model_id}/metrics``."""
        return await self.get(f"training/models/{model_id}/metrics", tenant_id=tenant_id)

    async def get_model_predictions(
        self,
        tenant_id: str,
        model_id: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> Optional[List[Dict[str, Any]]]:
        """Get model predictions for evaluation.

        Args:
            tenant_id: Tenant owning the model.
            model_id: Model whose predictions are requested.
            start_date: Optional ``start_date`` query value; omitted when falsy.
            end_date: Optional ``end_date`` query value; omitted when falsy.

        Returns:
            The ``predictions`` list of the response (``[]`` when absent),
            or None if the request produced no response.
        """
        params: Dict[str, Any] = {}
        if start_date:
            params["start_date"] = start_date
        if end_date:
            params["end_date"] = end_date

        result = await self.get(
            f"training/models/{model_id}/predictions",
            tenant_id=tenant_id,
            params=params,
        )
        return self._unwrap_list(result, "predictions")

    async def trigger_retrain(
        self,
        tenant_id: str,
        inventory_product_id: str,
        reason: str = 'manual',
        metadata: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Trigger model retraining for a specific product.

        Used by orchestrator when forecast accuracy degrades.

        Args:
            tenant_id: Tenant UUID
            inventory_product_id: Product UUID to retrain model for
            reason: Reason for retraining (accuracy_degradation, manual, scheduled, etc.)
            metadata: Optional metadata (e.g., previous_mape, validation_date, etc.)

        Returns:
            Training job details or None if failed
        """
        data = {
            "inventory_product_id": inventory_product_id,
            "reason": reason,
            # A missing (or empty) metadata argument is sent as an empty object.
            "metadata": metadata or {},
            # Fixed training options; same values as create_training_job's defaults.
            "include_weather": True,
            "include_traffic": False,
            "min_data_points": 30
        }
        return await self.post("training/models/retrain", data=data, tenant_id=tenant_id)