Fix new services implementation 3
This commit is contained in:
@@ -34,21 +34,21 @@ class ForecastingBaseRepository(BaseRepository):
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def get_by_product_name(
|
||||
async def get_by_inventory_product_id(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List:
|
||||
"""Get records by tenant and product"""
|
||||
if hasattr(self.model, 'product_name'):
|
||||
"""Get records by tenant and inventory product"""
|
||||
if hasattr(self.model, 'inventory_product_id'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name
|
||||
"inventory_product_id": inventory_product_id
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
@@ -163,17 +163,17 @@ class ForecastingBaseRepository(BaseRepository):
|
||||
|
||||
# Get records by product if applicable
|
||||
product_stats = {}
|
||||
if hasattr(self.model, 'product_name'):
|
||||
if hasattr(self.model, 'inventory_product_id'):
|
||||
product_query = text(f"""
|
||||
SELECT product_name, COUNT(*) as count
|
||||
SELECT inventory_product_id, COUNT(*) as count
|
||||
FROM {table_name}
|
||||
WHERE tenant_id = :tenant_id
|
||||
GROUP BY product_name
|
||||
GROUP BY inventory_product_id
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
|
||||
product_stats = {row.product_name: row.count for row in result.fetchall()}
|
||||
product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}
|
||||
|
||||
return {
|
||||
"total_records": total_records,
|
||||
@@ -206,11 +206,11 @@ class ForecastingBaseRepository(BaseRepository):
|
||||
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
|
||||
errors.append("Invalid tenant_id format")
|
||||
|
||||
# Validate product_name if present
|
||||
if "product_name" in data and data["product_name"]:
|
||||
product_name = data["product_name"]
|
||||
if not isinstance(product_name, str) or len(product_name) < 1:
|
||||
errors.append("Invalid product_name format")
|
||||
# Validate inventory_product_id if present
|
||||
if "inventory_product_id" in data and data["inventory_product_id"]:
|
||||
inventory_product_id = data["inventory_product_id"]
|
||||
if not isinstance(inventory_product_id, str) or len(inventory_product_id) < 1:
|
||||
errors.append("Invalid inventory_product_id format")
|
||||
|
||||
# Validate dates if present - accept datetime objects, date objects, and date strings
|
||||
date_fields = ["forecast_date", "created_at", "evaluation_date", "expires_at"]
|
||||
|
||||
@@ -29,7 +29,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
# Validate forecast data
|
||||
validation_result = self._validate_forecast_data(
|
||||
forecast_data,
|
||||
["tenant_id", "product_name", "location", "forecast_date",
|
||||
["tenant_id", "inventory_product_id", "location", "forecast_date",
|
||||
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
|
||||
)
|
||||
|
||||
@@ -50,7 +50,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
logger.info("Forecast created successfully",
|
||||
forecast_id=forecast.id,
|
||||
tenant_id=forecast.tenant_id,
|
||||
product_name=forecast.product_name,
|
||||
inventory_product_id=forecast.inventory_product_id,
|
||||
forecast_date=forecast.forecast_date.isoformat())
|
||||
|
||||
return forecast
|
||||
@@ -60,7 +60,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
except Exception as e:
|
||||
logger.error("Failed to create forecast",
|
||||
tenant_id=forecast_data.get("tenant_id"),
|
||||
product_name=forecast_data.get("product_name"),
|
||||
inventory_product_id=forecast_data.get("inventory_product_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create forecast: {str(e)}")
|
||||
|
||||
@@ -69,15 +69,15 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
tenant_id: str,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
product_name: str = None,
|
||||
inventory_product_id: str = None,
|
||||
location: str = None
|
||||
) -> List[Forecast]:
|
||||
"""Get forecasts within a date range"""
|
||||
try:
|
||||
filters = {"tenant_id": tenant_id}
|
||||
|
||||
if product_name:
|
||||
filters["product_name"] = product_name
|
||||
if inventory_product_id:
|
||||
filters["inventory_product_id"] = inventory_product_id
|
||||
if location:
|
||||
filters["location"] = location
|
||||
|
||||
@@ -100,14 +100,14 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
async def get_latest_forecast_for_product(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
location: str = None
|
||||
) -> Optional[Forecast]:
|
||||
"""Get the most recent forecast for a product"""
|
||||
try:
|
||||
filters = {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name
|
||||
"inventory_product_id": inventory_product_id
|
||||
}
|
||||
if location:
|
||||
filters["location"] = location
|
||||
@@ -124,7 +124,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
except Exception as e:
|
||||
logger.error("Failed to get latest forecast for product",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get latest forecast: {str(e)}")
|
||||
|
||||
@@ -132,7 +132,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
self,
|
||||
tenant_id: str,
|
||||
forecast_date: date,
|
||||
product_name: str = None
|
||||
inventory_product_id: str = None
|
||||
) -> List[Forecast]:
|
||||
"""Get all forecasts for a specific date"""
|
||||
try:
|
||||
@@ -154,7 +154,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
async def get_forecast_accuracy_metrics(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str = None,
|
||||
inventory_product_id: str = None,
|
||||
days_back: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get forecast accuracy metrics"""
|
||||
@@ -168,9 +168,9 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
"cutoff_date": cutoff_date
|
||||
}
|
||||
|
||||
if product_name:
|
||||
conditions.append("product_name = :product_name")
|
||||
params["product_name"] = product_name
|
||||
if inventory_product_id:
|
||||
conditions.append("inventory_product_id = :inventory_product_id")
|
||||
params["inventory_product_id"] = inventory_product_id
|
||||
|
||||
query_text = f"""
|
||||
SELECT
|
||||
@@ -180,7 +180,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
MAX(predicted_demand) as max_predicted_demand,
|
||||
AVG(confidence_upper - confidence_lower) as avg_confidence_interval,
|
||||
AVG(processing_time_ms) as avg_processing_time_ms,
|
||||
COUNT(DISTINCT product_name) as unique_products,
|
||||
COUNT(DISTINCT inventory_product_id) as unique_products,
|
||||
COUNT(DISTINCT model_id) as unique_models
|
||||
FROM forecasts
|
||||
WHERE {' AND '.join(conditions)}
|
||||
@@ -233,7 +233,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
async def get_demand_trends(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
days_back: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get demand trends for a product"""
|
||||
@@ -249,7 +249,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
COUNT(*) as forecast_count
|
||||
FROM forecasts
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND product_name = :product_name
|
||||
AND inventory_product_id = :inventory_product_id
|
||||
AND forecast_date >= :cutoff_date
|
||||
GROUP BY DATE(forecast_date)
|
||||
ORDER BY date DESC
|
||||
@@ -257,7 +257,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"cutoff_date": cutoff_date
|
||||
})
|
||||
|
||||
@@ -280,7 +280,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
trend_direction = "stable"
|
||||
|
||||
return {
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"period_days": days_back,
|
||||
"trends": trends,
|
||||
"trend_direction": trend_direction,
|
||||
@@ -290,10 +290,10 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
except Exception as e:
|
||||
logger.error("Failed to get demand trends",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"period_days": days_back,
|
||||
"trends": [],
|
||||
"trend_direction": "unknown",
|
||||
@@ -311,7 +311,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
COUNT(*) as usage_count,
|
||||
AVG(predicted_demand) as avg_prediction,
|
||||
MAX(forecast_date) as last_used,
|
||||
COUNT(DISTINCT product_name) as products_covered
|
||||
COUNT(DISTINCT inventory_product_id) as products_covered
|
||||
FROM forecasts
|
||||
WHERE tenant_id = :tenant_id
|
||||
GROUP BY model_id, algorithm
|
||||
@@ -403,7 +403,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
# Validate each forecast
|
||||
validation_result = self._validate_forecast_data(
|
||||
forecast_data,
|
||||
["tenant_id", "product_name", "location", "forecast_date",
|
||||
["tenant_id", "inventory_product_id", "location", "forecast_date",
|
||||
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
|
||||
)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
# Validate metric data
|
||||
validation_result = self._validate_forecast_data(
|
||||
metric_data,
|
||||
["model_id", "tenant_id", "product_name", "evaluation_date"]
|
||||
["model_id", "tenant_id", "inventory_product_id", "evaluation_date"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
@@ -41,7 +41,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
metric_id=metric.id,
|
||||
model_id=metric.model_id,
|
||||
tenant_id=metric.tenant_id,
|
||||
product_name=metric.product_name)
|
||||
inventory_product_id=metric.inventory_product_id)
|
||||
|
||||
return metric
|
||||
|
||||
@@ -93,7 +93,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
async def get_performance_trends(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str = None,
|
||||
inventory_product_id: str = None,
|
||||
days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get performance trends over time"""
|
||||
@@ -109,14 +109,14 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
"start_date": start_date
|
||||
}
|
||||
|
||||
if product_name:
|
||||
conditions.append("product_name = :product_name")
|
||||
params["product_name"] = product_name
|
||||
if inventory_product_id:
|
||||
conditions.append("inventory_product_id = :inventory_product_id")
|
||||
params["inventory_product_id"] = inventory_product_id
|
||||
|
||||
query_text = f"""
|
||||
SELECT
|
||||
DATE(evaluation_date) as date,
|
||||
product_name,
|
||||
inventory_product_id,
|
||||
AVG(mae) as avg_mae,
|
||||
AVG(mape) as avg_mape,
|
||||
AVG(rmse) as avg_rmse,
|
||||
@@ -124,8 +124,8 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
COUNT(*) as measurement_count
|
||||
FROM model_performance_metrics
|
||||
WHERE {' AND '.join(conditions)}
|
||||
GROUP BY DATE(evaluation_date), product_name
|
||||
ORDER BY date DESC, product_name
|
||||
GROUP BY DATE(evaluation_date), inventory_product_id
|
||||
ORDER BY date DESC, inventory_product_id
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
@@ -134,7 +134,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
for row in result.fetchall():
|
||||
trends.append({
|
||||
"date": row.date.isoformat() if row.date else None,
|
||||
"product_name": row.product_name,
|
||||
"inventory_product_id": row.inventory_product_id,
|
||||
"metrics": {
|
||||
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
|
||||
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
|
||||
@@ -146,7 +146,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"period_days": days,
|
||||
"trends": trends,
|
||||
"total_measurements": len(trends)
|
||||
@@ -155,11 +155,11 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
except Exception as e:
|
||||
logger.error("Failed to get performance trends",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"period_days": days,
|
||||
"trends": [],
|
||||
"total_measurements": 0
|
||||
|
||||
@@ -27,18 +27,18 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
def _generate_cache_key(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
location: str,
|
||||
forecast_date: datetime
|
||||
) -> str:
|
||||
"""Generate cache key for prediction"""
|
||||
key_data = f"{tenant_id}:{product_name}:{location}:{forecast_date.isoformat()}"
|
||||
key_data = f"{tenant_id}:{inventory_product_id}:{location}:{forecast_date.isoformat()}"
|
||||
return hashlib.md5(key_data.encode()).hexdigest()
|
||||
|
||||
async def cache_prediction(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
location: str,
|
||||
forecast_date: datetime,
|
||||
predicted_demand: float,
|
||||
@@ -49,13 +49,13 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
) -> PredictionCache:
|
||||
"""Cache a prediction result"""
|
||||
try:
|
||||
cache_key = self._generate_cache_key(tenant_id, product_name, location, forecast_date)
|
||||
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
|
||||
expires_at = datetime.utcnow() + timedelta(hours=expires_in_hours)
|
||||
|
||||
cache_data = {
|
||||
"cache_key": cache_key,
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"location": location,
|
||||
"forecast_date": forecast_date,
|
||||
"predicted_demand": predicted_demand,
|
||||
@@ -80,20 +80,20 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
except Exception as e:
|
||||
logger.error("Failed to cache prediction",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to cache prediction: {str(e)}")
|
||||
|
||||
async def get_cached_prediction(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
location: str,
|
||||
forecast_date: datetime
|
||||
) -> Optional[PredictionCache]:
|
||||
"""Get cached prediction if valid"""
|
||||
try:
|
||||
cache_key = self._generate_cache_key(tenant_id, product_name, location, forecast_date)
|
||||
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
|
||||
|
||||
cache_entry = await self.get_by_field("cache_key", cache_key)
|
||||
|
||||
@@ -119,14 +119,14 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
except Exception as e:
|
||||
logger.error("Failed to get cached prediction",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def invalidate_cache(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str = None,
|
||||
inventory_product_id: str = None,
|
||||
location: str = None
|
||||
) -> int:
|
||||
"""Invalidate cache entries"""
|
||||
@@ -134,9 +134,9 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
conditions = ["tenant_id = :tenant_id"]
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
if product_name:
|
||||
conditions.append("product_name = :product_name")
|
||||
params["product_name"] = product_name
|
||||
if inventory_product_id:
|
||||
conditions.append("inventory_product_id = :inventory_product_id")
|
||||
params["inventory_product_id"] = inventory_product_id
|
||||
|
||||
if location:
|
||||
conditions.append("location = :location")
|
||||
@@ -152,7 +152,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
|
||||
logger.info("Cache invalidated",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
location=location,
|
||||
invalidated_count=invalidated_count)
|
||||
|
||||
@@ -204,7 +204,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
SUM(hit_count) as total_hits,
|
||||
AVG(hit_count) as avg_hits_per_entry,
|
||||
MAX(hit_count) as max_hits,
|
||||
COUNT(DISTINCT product_name) as unique_products
|
||||
COUNT(DISTINCT inventory_product_id) as unique_products
|
||||
FROM prediction_cache
|
||||
{base_filter}
|
||||
""")
|
||||
@@ -268,7 +268,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
|
||||
query_text = f"""
|
||||
SELECT
|
||||
product_name,
|
||||
inventory_product_id,
|
||||
location,
|
||||
hit_count,
|
||||
predicted_demand,
|
||||
@@ -285,7 +285,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
popular_predictions = []
|
||||
for row in result.fetchall():
|
||||
popular_predictions.append({
|
||||
"product_name": row.product_name,
|
||||
"inventory_product_id": row.inventory_product_id,
|
||||
"location": row.location,
|
||||
"hit_count": int(row.hit_count),
|
||||
"predicted_demand": float(row.predicted_demand),
|
||||
|
||||
Reference in New Issue
Block a user