Fix new services implementation 3

This commit is contained in:
Urtzi Alfaro
2025-08-14 16:47:34 +02:00
parent 0951547e92
commit 03737430ee
51 changed files with 657 additions and 982 deletions

View File

@@ -34,21 +34,21 @@ class ForecastingBaseRepository(BaseRepository):
)
return await self.get_multi(skip=skip, limit=limit)
async def get_by_product_name(
async def get_by_inventory_product_id(
self,
tenant_id: str,
product_name: str,
inventory_product_id: str,
skip: int = 0,
limit: int = 100
) -> List:
"""Get records by tenant and product"""
if hasattr(self.model, 'product_name'):
"""Get records by tenant and inventory product"""
if hasattr(self.model, 'inventory_product_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={
"tenant_id": tenant_id,
"product_name": product_name
"inventory_product_id": inventory_product_id
},
order_by="created_at",
order_desc=True
@@ -163,17 +163,17 @@ class ForecastingBaseRepository(BaseRepository):
# Get records by product if applicable
product_stats = {}
if hasattr(self.model, 'product_name'):
if hasattr(self.model, 'inventory_product_id'):
product_query = text(f"""
SELECT product_name, COUNT(*) as count
SELECT inventory_product_id, COUNT(*) as count
FROM {table_name}
WHERE tenant_id = :tenant_id
GROUP BY product_name
GROUP BY inventory_product_id
ORDER BY count DESC
""")
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
product_stats = {row.product_name: row.count for row in result.fetchall()}
product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}
return {
"total_records": total_records,
@@ -206,11 +206,11 @@ class ForecastingBaseRepository(BaseRepository):
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
errors.append("Invalid tenant_id format")
# Validate product_name if present
if "product_name" in data and data["product_name"]:
product_name = data["product_name"]
if not isinstance(product_name, str) or len(product_name) < 1:
errors.append("Invalid product_name format")
# Validate inventory_product_id if present
if "inventory_product_id" in data and data["inventory_product_id"]:
inventory_product_id = data["inventory_product_id"]
if not isinstance(inventory_product_id, str) or len(inventory_product_id) < 1:
errors.append("Invalid inventory_product_id format")
# Validate dates if present - accept datetime objects, date objects, and date strings
date_fields = ["forecast_date", "created_at", "evaluation_date", "expires_at"]

View File

@@ -29,7 +29,7 @@ class ForecastRepository(ForecastingBaseRepository):
# Validate forecast data
validation_result = self._validate_forecast_data(
forecast_data,
["tenant_id", "product_name", "location", "forecast_date",
["tenant_id", "inventory_product_id", "location", "forecast_date",
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
)
@@ -50,7 +50,7 @@ class ForecastRepository(ForecastingBaseRepository):
logger.info("Forecast created successfully",
forecast_id=forecast.id,
tenant_id=forecast.tenant_id,
product_name=forecast.product_name,
inventory_product_id=forecast.inventory_product_id,
forecast_date=forecast.forecast_date.isoformat())
return forecast
@@ -60,7 +60,7 @@ class ForecastRepository(ForecastingBaseRepository):
except Exception as e:
logger.error("Failed to create forecast",
tenant_id=forecast_data.get("tenant_id"),
product_name=forecast_data.get("product_name"),
inventory_product_id=forecast_data.get("inventory_product_id"),
error=str(e))
raise DatabaseError(f"Failed to create forecast: {str(e)}")
@@ -69,15 +69,15 @@ class ForecastRepository(ForecastingBaseRepository):
tenant_id: str,
start_date: date,
end_date: date,
product_name: str = None,
inventory_product_id: str = None,
location: str = None
) -> List[Forecast]:
"""Get forecasts within a date range"""
try:
filters = {"tenant_id": tenant_id}
if product_name:
filters["product_name"] = product_name
if inventory_product_id:
filters["inventory_product_id"] = inventory_product_id
if location:
filters["location"] = location
@@ -100,14 +100,14 @@ class ForecastRepository(ForecastingBaseRepository):
async def get_latest_forecast_for_product(
self,
tenant_id: str,
product_name: str,
inventory_product_id: str,
location: str = None
) -> Optional[Forecast]:
"""Get the most recent forecast for a product"""
try:
filters = {
"tenant_id": tenant_id,
"product_name": product_name
"inventory_product_id": inventory_product_id
}
if location:
filters["location"] = location
@@ -124,7 +124,7 @@ class ForecastRepository(ForecastingBaseRepository):
except Exception as e:
logger.error("Failed to get latest forecast for product",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
error=str(e))
raise DatabaseError(f"Failed to get latest forecast: {str(e)}")
@@ -132,7 +132,7 @@ class ForecastRepository(ForecastingBaseRepository):
self,
tenant_id: str,
forecast_date: date,
product_name: str = None
inventory_product_id: str = None
) -> List[Forecast]:
"""Get all forecasts for a specific date"""
try:
@@ -154,7 +154,7 @@ class ForecastRepository(ForecastingBaseRepository):
async def get_forecast_accuracy_metrics(
self,
tenant_id: str,
product_name: str = None,
inventory_product_id: str = None,
days_back: int = 30
) -> Dict[str, Any]:
"""Get forecast accuracy metrics"""
@@ -168,9 +168,9 @@ class ForecastRepository(ForecastingBaseRepository):
"cutoff_date": cutoff_date
}
if product_name:
conditions.append("product_name = :product_name")
params["product_name"] = product_name
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
query_text = f"""
SELECT
@@ -180,7 +180,7 @@ class ForecastRepository(ForecastingBaseRepository):
MAX(predicted_demand) as max_predicted_demand,
AVG(confidence_upper - confidence_lower) as avg_confidence_interval,
AVG(processing_time_ms) as avg_processing_time_ms,
COUNT(DISTINCT product_name) as unique_products,
COUNT(DISTINCT inventory_product_id) as unique_products,
COUNT(DISTINCT model_id) as unique_models
FROM forecasts
WHERE {' AND '.join(conditions)}
@@ -233,7 +233,7 @@ class ForecastRepository(ForecastingBaseRepository):
async def get_demand_trends(
self,
tenant_id: str,
product_name: str,
inventory_product_id: str,
days_back: int = 30
) -> Dict[str, Any]:
"""Get demand trends for a product"""
@@ -249,7 +249,7 @@ class ForecastRepository(ForecastingBaseRepository):
COUNT(*) as forecast_count
FROM forecasts
WHERE tenant_id = :tenant_id
AND product_name = :product_name
AND inventory_product_id = :inventory_product_id
AND forecast_date >= :cutoff_date
GROUP BY DATE(forecast_date)
ORDER BY date DESC
@@ -257,7 +257,7 @@ class ForecastRepository(ForecastingBaseRepository):
result = await self.session.execute(text(query_text), {
"tenant_id": tenant_id,
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"cutoff_date": cutoff_date
})
@@ -280,7 +280,7 @@ class ForecastRepository(ForecastingBaseRepository):
trend_direction = "stable"
return {
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"period_days": days_back,
"trends": trends,
"trend_direction": trend_direction,
@@ -290,10 +290,10 @@ class ForecastRepository(ForecastingBaseRepository):
except Exception as e:
logger.error("Failed to get demand trends",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
error=str(e))
return {
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"period_days": days_back,
"trends": [],
"trend_direction": "unknown",
@@ -311,7 +311,7 @@ class ForecastRepository(ForecastingBaseRepository):
COUNT(*) as usage_count,
AVG(predicted_demand) as avg_prediction,
MAX(forecast_date) as last_used,
COUNT(DISTINCT product_name) as products_covered
COUNT(DISTINCT inventory_product_id) as products_covered
FROM forecasts
WHERE tenant_id = :tenant_id
GROUP BY model_id, algorithm
@@ -403,7 +403,7 @@ class ForecastRepository(ForecastingBaseRepository):
# Validate each forecast
validation_result = self._validate_forecast_data(
forecast_data,
["tenant_id", "product_name", "location", "forecast_date",
["tenant_id", "inventory_product_id", "location", "forecast_date",
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
)

View File

@@ -29,7 +29,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
# Validate metric data
validation_result = self._validate_forecast_data(
metric_data,
["model_id", "tenant_id", "product_name", "evaluation_date"]
["model_id", "tenant_id", "inventory_product_id", "evaluation_date"]
)
if not validation_result["is_valid"]:
@@ -41,7 +41,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
metric_id=metric.id,
model_id=metric.model_id,
tenant_id=metric.tenant_id,
product_name=metric.product_name)
inventory_product_id=metric.inventory_product_id)
return metric
@@ -93,7 +93,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
async def get_performance_trends(
self,
tenant_id: str,
product_name: str = None,
inventory_product_id: str = None,
days: int = 30
) -> Dict[str, Any]:
"""Get performance trends over time"""
@@ -109,14 +109,14 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
"start_date": start_date
}
if product_name:
conditions.append("product_name = :product_name")
params["product_name"] = product_name
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
query_text = f"""
SELECT
DATE(evaluation_date) as date,
product_name,
inventory_product_id,
AVG(mae) as avg_mae,
AVG(mape) as avg_mape,
AVG(rmse) as avg_rmse,
@@ -124,8 +124,8 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
COUNT(*) as measurement_count
FROM model_performance_metrics
WHERE {' AND '.join(conditions)}
GROUP BY DATE(evaluation_date), product_name
ORDER BY date DESC, product_name
GROUP BY DATE(evaluation_date), inventory_product_id
ORDER BY date DESC, inventory_product_id
"""
result = await self.session.execute(text(query_text), params)
@@ -134,7 +134,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
for row in result.fetchall():
trends.append({
"date": row.date.isoformat() if row.date else None,
"product_name": row.product_name,
"inventory_product_id": row.inventory_product_id,
"metrics": {
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
@@ -146,7 +146,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
return {
"tenant_id": tenant_id,
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"period_days": days,
"trends": trends,
"total_measurements": len(trends)
@@ -155,11 +155,11 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
except Exception as e:
logger.error("Failed to get performance trends",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
error=str(e))
return {
"tenant_id": tenant_id,
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"period_days": days,
"trends": [],
"total_measurements": 0

View File

@@ -27,18 +27,18 @@ class PredictionCacheRepository(ForecastingBaseRepository):
def _generate_cache_key(
self,
tenant_id: str,
product_name: str,
inventory_product_id: str,
location: str,
forecast_date: datetime
) -> str:
"""Generate cache key for prediction"""
key_data = f"{tenant_id}:{product_name}:{location}:{forecast_date.isoformat()}"
key_data = f"{tenant_id}:{inventory_product_id}:{location}:{forecast_date.isoformat()}"
return hashlib.md5(key_data.encode()).hexdigest()
async def cache_prediction(
self,
tenant_id: str,
product_name: str,
inventory_product_id: str,
location: str,
forecast_date: datetime,
predicted_demand: float,
@@ -49,13 +49,13 @@ class PredictionCacheRepository(ForecastingBaseRepository):
) -> PredictionCache:
"""Cache a prediction result"""
try:
cache_key = self._generate_cache_key(tenant_id, product_name, location, forecast_date)
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
expires_at = datetime.utcnow() + timedelta(hours=expires_in_hours)
cache_data = {
"cache_key": cache_key,
"tenant_id": tenant_id,
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"location": location,
"forecast_date": forecast_date,
"predicted_demand": predicted_demand,
@@ -80,20 +80,20 @@ class PredictionCacheRepository(ForecastingBaseRepository):
except Exception as e:
logger.error("Failed to cache prediction",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
error=str(e))
raise DatabaseError(f"Failed to cache prediction: {str(e)}")
async def get_cached_prediction(
self,
tenant_id: str,
product_name: str,
inventory_product_id: str,
location: str,
forecast_date: datetime
) -> Optional[PredictionCache]:
"""Get cached prediction if valid"""
try:
cache_key = self._generate_cache_key(tenant_id, product_name, location, forecast_date)
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
cache_entry = await self.get_by_field("cache_key", cache_key)
@@ -119,14 +119,14 @@ class PredictionCacheRepository(ForecastingBaseRepository):
except Exception as e:
logger.error("Failed to get cached prediction",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
error=str(e))
return None
async def invalidate_cache(
self,
tenant_id: str,
product_name: str = None,
inventory_product_id: str = None,
location: str = None
) -> int:
"""Invalidate cache entries"""
@@ -134,9 +134,9 @@ class PredictionCacheRepository(ForecastingBaseRepository):
conditions = ["tenant_id = :tenant_id"]
params = {"tenant_id": tenant_id}
if product_name:
conditions.append("product_name = :product_name")
params["product_name"] = product_name
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
if location:
conditions.append("location = :location")
@@ -152,7 +152,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
logger.info("Cache invalidated",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
location=location,
invalidated_count=invalidated_count)
@@ -204,7 +204,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
SUM(hit_count) as total_hits,
AVG(hit_count) as avg_hits_per_entry,
MAX(hit_count) as max_hits,
COUNT(DISTINCT product_name) as unique_products
COUNT(DISTINCT inventory_product_id) as unique_products
FROM prediction_cache
{base_filter}
""")
@@ -268,7 +268,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
query_text = f"""
SELECT
product_name,
inventory_product_id,
location,
hit_count,
predicted_demand,
@@ -285,7 +285,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
popular_predictions = []
for row in result.fetchall():
popular_predictions.append({
"product_name": row.product_name,
"inventory_product_id": row.inventory_product_id,
"location": row.location,
"hit_count": int(row.hit_count),
"predicted_demand": float(row.predicted_demand),