Fix new services implementation 3
This commit is contained in:
@@ -56,7 +56,7 @@ async def create_enhanced_single_forecast(
|
||||
|
||||
logger.info("Generating enhanced single forecast",
|
||||
tenant_id=tenant_id,
|
||||
product_name=request.product_name,
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
forecast_date=request.forecast_date.isoformat())
|
||||
|
||||
# Record metrics
|
||||
@@ -124,13 +124,13 @@ async def create_enhanced_batch_forecast(
|
||||
|
||||
logger.info("Generating enhanced batch forecasts",
|
||||
tenant_id=tenant_id,
|
||||
products_count=len(request.products),
|
||||
forecast_dates_count=len(request.forecast_dates))
|
||||
products_count=len(request.inventory_product_ids),
|
||||
forecast_dates_count=request.forecast_days)
|
||||
|
||||
# Record metrics
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_batch_forecasts_total")
|
||||
metrics.histogram("enhanced_batch_forecast_products_count", len(request.products))
|
||||
metrics.histogram("enhanced_batch_forecast_products_count", len(request.inventory_product_ids))
|
||||
|
||||
# Generate batch forecasts using enhanced service
|
||||
batch_result = await enhanced_forecasting_service.generate_batch_forecasts(
|
||||
@@ -174,7 +174,7 @@ async def create_enhanced_batch_forecast(
|
||||
@track_execution_time("enhanced_get_forecasts_duration_seconds", "forecasting-service")
|
||||
async def get_enhanced_tenant_forecasts(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
product_name: Optional[str] = Query(None, description="Filter by product name"),
|
||||
inventory_product_id: Optional[str] = Query(None, description="Filter by inventory product ID"),
|
||||
start_date: Optional[date] = Query(None, description="Start date filter"),
|
||||
end_date: Optional[date] = Query(None, description="End date filter"),
|
||||
skip: int = Query(0, description="Number of records to skip"),
|
||||
@@ -203,7 +203,7 @@ async def get_enhanced_tenant_forecasts(
|
||||
# Get forecasts using enhanced service
|
||||
forecasts = await enhanced_forecasting_service.get_tenant_forecasts(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
skip=skip,
|
||||
@@ -218,7 +218,7 @@ async def get_enhanced_tenant_forecasts(
|
||||
"forecasts": forecasts,
|
||||
"total_returned": len(forecasts),
|
||||
"filters": {
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"start_date": start_date.isoformat() if start_date else None,
|
||||
"end_date": end_date.isoformat() if end_date else None
|
||||
},
|
||||
|
||||
@@ -59,14 +59,14 @@ async def generate_enhanced_realtime_prediction(
|
||||
|
||||
logger.info("Generating enhanced real-time prediction",
|
||||
tenant_id=tenant_id,
|
||||
product_name=prediction_request.get("product_name"))
|
||||
inventory_product_id=prediction_request.get("inventory_product_id"))
|
||||
|
||||
# Record metrics
|
||||
if metrics:
|
||||
metrics.increment_counter("enhanced_realtime_predictions_total")
|
||||
|
||||
# Validate required fields
|
||||
required_fields = ["product_name", "model_id", "features"]
|
||||
required_fields = ["inventory_product_id", "model_id", "features"]
|
||||
missing_fields = [field for field in required_fields if field not in prediction_request]
|
||||
if missing_fields:
|
||||
raise HTTPException(
|
||||
@@ -91,7 +91,7 @@ async def generate_enhanced_realtime_prediction(
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": prediction_request["product_name"],
|
||||
"inventory_product_id": prediction_request["inventory_product_id"],
|
||||
"model_id": prediction_request["model_id"],
|
||||
"prediction": prediction_result,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
@@ -205,7 +205,7 @@ async def generate_enhanced_batch_predictions(
|
||||
@track_execution_time("enhanced_get_prediction_cache_duration_seconds", "forecasting-service")
|
||||
async def get_enhanced_prediction_cache(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
product_name: Optional[str] = Query(None, description="Filter by product name"),
|
||||
inventory_product_id: Optional[str] = Query(None, description="Filter by inventory product ID"),
|
||||
skip: int = Query(0, description="Number of records to skip"),
|
||||
limit: int = Query(100, description="Number of records to return"),
|
||||
request_obj: Request = None,
|
||||
@@ -232,7 +232,7 @@ async def get_enhanced_prediction_cache(
|
||||
# Get cached predictions using enhanced service
|
||||
cached_predictions = await enhanced_forecasting_service.get_cached_predictions(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
@@ -245,7 +245,7 @@ async def get_enhanced_prediction_cache(
|
||||
"cached_predictions": cached_predictions,
|
||||
"total_returned": len(cached_predictions),
|
||||
"filters": {
|
||||
"product_name": product_name
|
||||
"inventory_product_id": inventory_product_id
|
||||
},
|
||||
"pagination": {
|
||||
"skip": skip,
|
||||
@@ -271,7 +271,7 @@ async def get_enhanced_prediction_cache(
|
||||
@track_execution_time("enhanced_clear_prediction_cache_duration_seconds", "forecasting-service")
|
||||
async def clear_enhanced_prediction_cache(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
product_name: Optional[str] = Query(None, description="Clear cache for specific product"),
|
||||
inventory_product_id: Optional[str] = Query(None, description="Clear cache for specific inventory product ID"),
|
||||
request_obj: Request = None,
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
||||
@@ -296,7 +296,7 @@ async def clear_enhanced_prediction_cache(
|
||||
# Clear cache using enhanced service
|
||||
cleared_count = await enhanced_forecasting_service.clear_prediction_cache(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name
|
||||
inventory_product_id=inventory_product_id
|
||||
)
|
||||
|
||||
if metrics:
|
||||
@@ -305,13 +305,13 @@ async def clear_enhanced_prediction_cache(
|
||||
|
||||
logger.info("Enhanced prediction cache cleared",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
cleared_count=cleared_count)
|
||||
|
||||
return {
|
||||
"message": "Prediction cache cleared successfully",
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"cleared_count": cleared_count,
|
||||
"enhanced_features": True,
|
||||
"repository_integration": True
|
||||
|
||||
@@ -40,7 +40,7 @@ class BakeryForecaster:
|
||||
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "forecasting-service")
|
||||
self.predictor = BakeryPredictor(database_manager)
|
||||
|
||||
async def generate_forecast_with_repository(self, tenant_id: str, product_name: str,
|
||||
async def generate_forecast_with_repository(self, tenant_id: str, inventory_product_id: str,
|
||||
forecast_date: date, model_id: str = None) -> Dict[str, Any]:
|
||||
"""Generate forecast with repository integration"""
|
||||
try:
|
||||
@@ -48,7 +48,7 @@ class BakeryForecaster:
|
||||
# Implementation would be added here
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"forecast_date": forecast_date.isoformat(),
|
||||
"prediction": 0.0,
|
||||
"confidence_interval": {"lower": 0.0, "upper": 0.0},
|
||||
|
||||
@@ -18,7 +18,7 @@ class Forecast(Base):
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
product_name = Column(String(255), nullable=False, index=True)
|
||||
inventory_product_id = Column(UUID(as_uuid=True), nullable=False, index=True) # Reference to inventory service
|
||||
location = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Forecast period
|
||||
@@ -53,7 +53,7 @@ class Forecast(Base):
|
||||
features_used = Column(JSON)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Forecast(id={self.id}, product={self.product_name}, date={self.forecast_date})>"
|
||||
return f"<Forecast(id={self.id}, inventory_product_id={self.inventory_product_id}, date={self.forecast_date})>"
|
||||
|
||||
class PredictionBatch(Base):
|
||||
"""Batch prediction requests"""
|
||||
|
||||
@@ -19,7 +19,7 @@ class ModelPerformanceMetric(Base):
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
model_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
product_name = Column(String(255), nullable=False)
|
||||
inventory_product_id = Column(UUID(as_uuid=True), nullable=False) # Reference to inventory service
|
||||
|
||||
# Performance metrics
|
||||
mae = Column(Float) # Mean Absolute Error
|
||||
@@ -48,7 +48,7 @@ class PredictionCache(Base):
|
||||
|
||||
# Cached data
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
product_name = Column(String(255), nullable=False)
|
||||
inventory_product_id = Column(UUID(as_uuid=True), nullable=False) # Reference to inventory service
|
||||
location = Column(String(255), nullable=False)
|
||||
forecast_date = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
@@ -64,4 +64,4 @@ class PredictionCache(Base):
|
||||
hit_count = Column(Integer, default=0)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PredictionCache(key={self.cache_key}, product={self.product_name})>"
|
||||
return f"<PredictionCache(key={self.cache_key}, inventory_product_id={self.inventory_product_id})>"
|
||||
|
||||
@@ -34,21 +34,21 @@ class ForecastingBaseRepository(BaseRepository):
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def get_by_product_name(
|
||||
async def get_by_inventory_product_id(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List:
|
||||
"""Get records by tenant and product"""
|
||||
if hasattr(self.model, 'product_name'):
|
||||
"""Get records by tenant and inventory product"""
|
||||
if hasattr(self.model, 'inventory_product_id'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name
|
||||
"inventory_product_id": inventory_product_id
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
@@ -163,17 +163,17 @@ class ForecastingBaseRepository(BaseRepository):
|
||||
|
||||
# Get records by product if applicable
|
||||
product_stats = {}
|
||||
if hasattr(self.model, 'product_name'):
|
||||
if hasattr(self.model, 'inventory_product_id'):
|
||||
product_query = text(f"""
|
||||
SELECT product_name, COUNT(*) as count
|
||||
SELECT inventory_product_id, COUNT(*) as count
|
||||
FROM {table_name}
|
||||
WHERE tenant_id = :tenant_id
|
||||
GROUP BY product_name
|
||||
GROUP BY inventory_product_id
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
|
||||
product_stats = {row.product_name: row.count for row in result.fetchall()}
|
||||
product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}
|
||||
|
||||
return {
|
||||
"total_records": total_records,
|
||||
@@ -206,11 +206,11 @@ class ForecastingBaseRepository(BaseRepository):
|
||||
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
|
||||
errors.append("Invalid tenant_id format")
|
||||
|
||||
# Validate product_name if present
|
||||
if "product_name" in data and data["product_name"]:
|
||||
product_name = data["product_name"]
|
||||
if not isinstance(product_name, str) or len(product_name) < 1:
|
||||
errors.append("Invalid product_name format")
|
||||
# Validate inventory_product_id if present
|
||||
if "inventory_product_id" in data and data["inventory_product_id"]:
|
||||
inventory_product_id = data["inventory_product_id"]
|
||||
if not isinstance(inventory_product_id, str) or len(inventory_product_id) < 1:
|
||||
errors.append("Invalid inventory_product_id format")
|
||||
|
||||
# Validate dates if present - accept datetime objects, date objects, and date strings
|
||||
date_fields = ["forecast_date", "created_at", "evaluation_date", "expires_at"]
|
||||
|
||||
@@ -29,7 +29,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
# Validate forecast data
|
||||
validation_result = self._validate_forecast_data(
|
||||
forecast_data,
|
||||
["tenant_id", "product_name", "location", "forecast_date",
|
||||
["tenant_id", "inventory_product_id", "location", "forecast_date",
|
||||
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
|
||||
)
|
||||
|
||||
@@ -50,7 +50,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
logger.info("Forecast created successfully",
|
||||
forecast_id=forecast.id,
|
||||
tenant_id=forecast.tenant_id,
|
||||
product_name=forecast.product_name,
|
||||
inventory_product_id=forecast.inventory_product_id,
|
||||
forecast_date=forecast.forecast_date.isoformat())
|
||||
|
||||
return forecast
|
||||
@@ -60,7 +60,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
except Exception as e:
|
||||
logger.error("Failed to create forecast",
|
||||
tenant_id=forecast_data.get("tenant_id"),
|
||||
product_name=forecast_data.get("product_name"),
|
||||
inventory_product_id=forecast_data.get("inventory_product_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create forecast: {str(e)}")
|
||||
|
||||
@@ -69,15 +69,15 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
tenant_id: str,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
product_name: str = None,
|
||||
inventory_product_id: str = None,
|
||||
location: str = None
|
||||
) -> List[Forecast]:
|
||||
"""Get forecasts within a date range"""
|
||||
try:
|
||||
filters = {"tenant_id": tenant_id}
|
||||
|
||||
if product_name:
|
||||
filters["product_name"] = product_name
|
||||
if inventory_product_id:
|
||||
filters["inventory_product_id"] = inventory_product_id
|
||||
if location:
|
||||
filters["location"] = location
|
||||
|
||||
@@ -100,14 +100,14 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
async def get_latest_forecast_for_product(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
location: str = None
|
||||
) -> Optional[Forecast]:
|
||||
"""Get the most recent forecast for a product"""
|
||||
try:
|
||||
filters = {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name
|
||||
"inventory_product_id": inventory_product_id
|
||||
}
|
||||
if location:
|
||||
filters["location"] = location
|
||||
@@ -124,7 +124,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
except Exception as e:
|
||||
logger.error("Failed to get latest forecast for product",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get latest forecast: {str(e)}")
|
||||
|
||||
@@ -132,7 +132,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
self,
|
||||
tenant_id: str,
|
||||
forecast_date: date,
|
||||
product_name: str = None
|
||||
inventory_product_id: str = None
|
||||
) -> List[Forecast]:
|
||||
"""Get all forecasts for a specific date"""
|
||||
try:
|
||||
@@ -154,7 +154,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
async def get_forecast_accuracy_metrics(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str = None,
|
||||
inventory_product_id: str = None,
|
||||
days_back: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get forecast accuracy metrics"""
|
||||
@@ -168,9 +168,9 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
"cutoff_date": cutoff_date
|
||||
}
|
||||
|
||||
if product_name:
|
||||
conditions.append("product_name = :product_name")
|
||||
params["product_name"] = product_name
|
||||
if inventory_product_id:
|
||||
conditions.append("inventory_product_id = :inventory_product_id")
|
||||
params["inventory_product_id"] = inventory_product_id
|
||||
|
||||
query_text = f"""
|
||||
SELECT
|
||||
@@ -180,7 +180,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
MAX(predicted_demand) as max_predicted_demand,
|
||||
AVG(confidence_upper - confidence_lower) as avg_confidence_interval,
|
||||
AVG(processing_time_ms) as avg_processing_time_ms,
|
||||
COUNT(DISTINCT product_name) as unique_products,
|
||||
COUNT(DISTINCT inventory_product_id) as unique_products,
|
||||
COUNT(DISTINCT model_id) as unique_models
|
||||
FROM forecasts
|
||||
WHERE {' AND '.join(conditions)}
|
||||
@@ -233,7 +233,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
async def get_demand_trends(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
days_back: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get demand trends for a product"""
|
||||
@@ -249,7 +249,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
COUNT(*) as forecast_count
|
||||
FROM forecasts
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND product_name = :product_name
|
||||
AND inventory_product_id = :inventory_product_id
|
||||
AND forecast_date >= :cutoff_date
|
||||
GROUP BY DATE(forecast_date)
|
||||
ORDER BY date DESC
|
||||
@@ -257,7 +257,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"cutoff_date": cutoff_date
|
||||
})
|
||||
|
||||
@@ -280,7 +280,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
trend_direction = "stable"
|
||||
|
||||
return {
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"period_days": days_back,
|
||||
"trends": trends,
|
||||
"trend_direction": trend_direction,
|
||||
@@ -290,10 +290,10 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
except Exception as e:
|
||||
logger.error("Failed to get demand trends",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"period_days": days_back,
|
||||
"trends": [],
|
||||
"trend_direction": "unknown",
|
||||
@@ -311,7 +311,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
COUNT(*) as usage_count,
|
||||
AVG(predicted_demand) as avg_prediction,
|
||||
MAX(forecast_date) as last_used,
|
||||
COUNT(DISTINCT product_name) as products_covered
|
||||
COUNT(DISTINCT inventory_product_id) as products_covered
|
||||
FROM forecasts
|
||||
WHERE tenant_id = :tenant_id
|
||||
GROUP BY model_id, algorithm
|
||||
@@ -403,7 +403,7 @@ class ForecastRepository(ForecastingBaseRepository):
|
||||
# Validate each forecast
|
||||
validation_result = self._validate_forecast_data(
|
||||
forecast_data,
|
||||
["tenant_id", "product_name", "location", "forecast_date",
|
||||
["tenant_id", "inventory_product_id", "location", "forecast_date",
|
||||
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
|
||||
)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
# Validate metric data
|
||||
validation_result = self._validate_forecast_data(
|
||||
metric_data,
|
||||
["model_id", "tenant_id", "product_name", "evaluation_date"]
|
||||
["model_id", "tenant_id", "inventory_product_id", "evaluation_date"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
@@ -41,7 +41,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
metric_id=metric.id,
|
||||
model_id=metric.model_id,
|
||||
tenant_id=metric.tenant_id,
|
||||
product_name=metric.product_name)
|
||||
inventory_product_id=metric.inventory_product_id)
|
||||
|
||||
return metric
|
||||
|
||||
@@ -93,7 +93,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
async def get_performance_trends(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str = None,
|
||||
inventory_product_id: str = None,
|
||||
days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get performance trends over time"""
|
||||
@@ -109,14 +109,14 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
"start_date": start_date
|
||||
}
|
||||
|
||||
if product_name:
|
||||
conditions.append("product_name = :product_name")
|
||||
params["product_name"] = product_name
|
||||
if inventory_product_id:
|
||||
conditions.append("inventory_product_id = :inventory_product_id")
|
||||
params["inventory_product_id"] = inventory_product_id
|
||||
|
||||
query_text = f"""
|
||||
SELECT
|
||||
DATE(evaluation_date) as date,
|
||||
product_name,
|
||||
inventory_product_id,
|
||||
AVG(mae) as avg_mae,
|
||||
AVG(mape) as avg_mape,
|
||||
AVG(rmse) as avg_rmse,
|
||||
@@ -124,8 +124,8 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
COUNT(*) as measurement_count
|
||||
FROM model_performance_metrics
|
||||
WHERE {' AND '.join(conditions)}
|
||||
GROUP BY DATE(evaluation_date), product_name
|
||||
ORDER BY date DESC, product_name
|
||||
GROUP BY DATE(evaluation_date), inventory_product_id
|
||||
ORDER BY date DESC, inventory_product_id
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
@@ -134,7 +134,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
for row in result.fetchall():
|
||||
trends.append({
|
||||
"date": row.date.isoformat() if row.date else None,
|
||||
"product_name": row.product_name,
|
||||
"inventory_product_id": row.inventory_product_id,
|
||||
"metrics": {
|
||||
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
|
||||
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
|
||||
@@ -146,7 +146,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"period_days": days,
|
||||
"trends": trends,
|
||||
"total_measurements": len(trends)
|
||||
@@ -155,11 +155,11 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
except Exception as e:
|
||||
logger.error("Failed to get performance trends",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"period_days": days,
|
||||
"trends": [],
|
||||
"total_measurements": 0
|
||||
|
||||
@@ -27,18 +27,18 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
def _generate_cache_key(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
location: str,
|
||||
forecast_date: datetime
|
||||
) -> str:
|
||||
"""Generate cache key for prediction"""
|
||||
key_data = f"{tenant_id}:{product_name}:{location}:{forecast_date.isoformat()}"
|
||||
key_data = f"{tenant_id}:{inventory_product_id}:{location}:{forecast_date.isoformat()}"
|
||||
return hashlib.md5(key_data.encode()).hexdigest()
|
||||
|
||||
async def cache_prediction(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
location: str,
|
||||
forecast_date: datetime,
|
||||
predicted_demand: float,
|
||||
@@ -49,13 +49,13 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
) -> PredictionCache:
|
||||
"""Cache a prediction result"""
|
||||
try:
|
||||
cache_key = self._generate_cache_key(tenant_id, product_name, location, forecast_date)
|
||||
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
|
||||
expires_at = datetime.utcnow() + timedelta(hours=expires_in_hours)
|
||||
|
||||
cache_data = {
|
||||
"cache_key": cache_key,
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"location": location,
|
||||
"forecast_date": forecast_date,
|
||||
"predicted_demand": predicted_demand,
|
||||
@@ -80,20 +80,20 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
except Exception as e:
|
||||
logger.error("Failed to cache prediction",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to cache prediction: {str(e)}")
|
||||
|
||||
async def get_cached_prediction(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
location: str,
|
||||
forecast_date: datetime
|
||||
) -> Optional[PredictionCache]:
|
||||
"""Get cached prediction if valid"""
|
||||
try:
|
||||
cache_key = self._generate_cache_key(tenant_id, product_name, location, forecast_date)
|
||||
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
|
||||
|
||||
cache_entry = await self.get_by_field("cache_key", cache_key)
|
||||
|
||||
@@ -119,14 +119,14 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
except Exception as e:
|
||||
logger.error("Failed to get cached prediction",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def invalidate_cache(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str = None,
|
||||
inventory_product_id: str = None,
|
||||
location: str = None
|
||||
) -> int:
|
||||
"""Invalidate cache entries"""
|
||||
@@ -134,9 +134,9 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
conditions = ["tenant_id = :tenant_id"]
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
if product_name:
|
||||
conditions.append("product_name = :product_name")
|
||||
params["product_name"] = product_name
|
||||
if inventory_product_id:
|
||||
conditions.append("inventory_product_id = :inventory_product_id")
|
||||
params["inventory_product_id"] = inventory_product_id
|
||||
|
||||
if location:
|
||||
conditions.append("location = :location")
|
||||
@@ -152,7 +152,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
|
||||
logger.info("Cache invalidated",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
location=location,
|
||||
invalidated_count=invalidated_count)
|
||||
|
||||
@@ -204,7 +204,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
SUM(hit_count) as total_hits,
|
||||
AVG(hit_count) as avg_hits_per_entry,
|
||||
MAX(hit_count) as max_hits,
|
||||
COUNT(DISTINCT product_name) as unique_products
|
||||
COUNT(DISTINCT inventory_product_id) as unique_products
|
||||
FROM prediction_cache
|
||||
{base_filter}
|
||||
""")
|
||||
@@ -268,7 +268,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
|
||||
query_text = f"""
|
||||
SELECT
|
||||
product_name,
|
||||
inventory_product_id,
|
||||
location,
|
||||
hit_count,
|
||||
predicted_demand,
|
||||
@@ -285,7 +285,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
popular_predictions = []
|
||||
for row in result.fetchall():
|
||||
popular_predictions.append({
|
||||
"product_name": row.product_name,
|
||||
"inventory_product_id": row.inventory_product_id,
|
||||
"location": row.location,
|
||||
"hit_count": int(row.hit_count),
|
||||
"predicted_demand": float(row.predicted_demand),
|
||||
|
||||
@@ -22,7 +22,8 @@ class AlertType(str, Enum):
|
||||
|
||||
class ForecastRequest(BaseModel):
|
||||
"""Request schema for generating forecasts"""
|
||||
product_name: str = Field(..., description="Product name")
|
||||
inventory_product_id: str = Field(..., description="Inventory product UUID reference")
|
||||
# product_name: str = Field(..., description="Product name") # DEPRECATED - use inventory_product_id
|
||||
forecast_date: date = Field(..., description="Starting date for forecast")
|
||||
forecast_days: int = Field(1, ge=1, le=30, description="Number of days to forecast")
|
||||
location: str = Field(..., description="Location identifier")
|
||||
@@ -40,14 +41,15 @@ class BatchForecastRequest(BaseModel):
|
||||
"""Request schema for batch forecasting"""
|
||||
tenant_id: str = Field(..., description="Tenant ID")
|
||||
batch_name: str = Field(..., description="Batch name for tracking")
|
||||
products: List[str] = Field(..., description="List of product names")
|
||||
inventory_product_ids: List[str] = Field(..., description="List of inventory product IDs")
|
||||
forecast_days: int = Field(7, ge=1, le=30, description="Number of days to forecast")
|
||||
|
||||
class ForecastResponse(BaseModel):
|
||||
"""Response schema for forecast results"""
|
||||
id: str
|
||||
tenant_id: str
|
||||
product_name: str
|
||||
inventory_product_id: str # Reference to inventory service
|
||||
# product_name: str # Can be fetched from inventory service if needed for display
|
||||
location: str
|
||||
forecast_date: datetime
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ class EnhancedForecastingService:
|
||||
logger.error("Batch forecast generation failed", error=str(e))
|
||||
raise
|
||||
|
||||
async def get_tenant_forecasts(self, tenant_id: str, product_name: str = None,
|
||||
async def get_tenant_forecasts(self, tenant_id: str, inventory_product_id: str = None,
|
||||
start_date: date = None, end_date: date = None,
|
||||
skip: int = 0, limit: int = 100) -> List[Dict]:
|
||||
"""Get tenant forecasts with filtering"""
|
||||
@@ -149,7 +149,7 @@ class EnhancedForecastingService:
|
||||
logger.error("Batch predictions failed", error=str(e))
|
||||
raise
|
||||
|
||||
async def get_cached_predictions(self, tenant_id: str, product_name: str = None,
|
||||
async def get_cached_predictions(self, tenant_id: str, inventory_product_id: str = None,
|
||||
skip: int = 0, limit: int = 100) -> List[Dict]:
|
||||
"""Get cached predictions"""
|
||||
try:
|
||||
@@ -159,7 +159,7 @@ class EnhancedForecastingService:
|
||||
logger.error("Failed to get cached predictions", error=str(e))
|
||||
raise
|
||||
|
||||
async def clear_prediction_cache(self, tenant_id: str, product_name: str = None) -> int:
|
||||
async def clear_prediction_cache(self, tenant_id: str, inventory_product_id: str = None) -> int:
|
||||
"""Clear prediction cache"""
|
||||
try:
|
||||
# Implementation would use repository pattern
|
||||
@@ -195,7 +195,7 @@ class EnhancedForecastingService:
|
||||
try:
|
||||
logger.info("Generating enhanced forecast",
|
||||
tenant_id=tenant_id,
|
||||
product=request.product_name,
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
date=request.forecast_date.isoformat())
|
||||
|
||||
# Get session and initialize repositories
|
||||
@@ -204,20 +204,20 @@ class EnhancedForecastingService:
|
||||
|
||||
# Step 1: Check cache first
|
||||
cached_prediction = await repos['cache'].get_cached_prediction(
|
||||
tenant_id, request.product_name, request.location, request.forecast_date
|
||||
tenant_id, request.inventory_product_id, request.location, request.forecast_date
|
||||
)
|
||||
|
||||
if cached_prediction:
|
||||
logger.debug("Using cached prediction",
|
||||
tenant_id=tenant_id,
|
||||
product=request.product_name)
|
||||
inventory_product_id=request.inventory_product_id)
|
||||
return self._create_forecast_response_from_cache(cached_prediction)
|
||||
|
||||
# Step 2: Get model with validation
|
||||
model_data = await self._get_latest_model_with_fallback(tenant_id, request.product_name)
|
||||
model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
|
||||
|
||||
if not model_data:
|
||||
raise ValueError(f"No valid model available for product: {request.product_name}")
|
||||
raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
|
||||
|
||||
# Step 3: Prepare features with fallbacks
|
||||
features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)
|
||||
@@ -244,7 +244,7 @@ class EnhancedForecastingService:
|
||||
|
||||
forecast_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": request.product_name,
|
||||
"inventory_product_id": request.inventory_product_id,
|
||||
"location": request.location,
|
||||
"forecast_date": forecast_datetime,
|
||||
"predicted_demand": adjusted_prediction['prediction'],
|
||||
@@ -271,7 +271,7 @@ class EnhancedForecastingService:
|
||||
# Step 7: Cache the prediction
|
||||
await repos['cache'].cache_prediction(
|
||||
tenant_id=tenant_id,
|
||||
product_name=request.product_name,
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
location=request.location,
|
||||
forecast_date=forecast_datetime,
|
||||
predicted_demand=adjusted_prediction['prediction'],
|
||||
@@ -296,14 +296,14 @@ class EnhancedForecastingService:
|
||||
logger.error("Error generating enhanced forecast",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id,
|
||||
product=request.product_name,
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
processing_time=processing_time)
|
||||
raise
|
||||
|
||||
async def get_forecast_history(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: Optional[str] = None,
|
||||
inventory_product_id: Optional[str] = None,
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
@@ -314,7 +314,7 @@ class EnhancedForecastingService:
|
||||
|
||||
if start_date and end_date:
|
||||
forecasts = await repos['forecast'].get_forecasts_by_date_range(
|
||||
tenant_id, start_date, end_date, product_name
|
||||
tenant_id, start_date, end_date, inventory_product_id
|
||||
)
|
||||
else:
|
||||
# Get recent forecasts (last 30 days)
|
||||
@@ -374,7 +374,7 @@ class EnhancedForecastingService:
|
||||
self,
|
||||
tenant_id: str,
|
||||
batch_name: str,
|
||||
products: List[str],
|
||||
inventory_product_ids: List[str],
|
||||
forecast_days: int = 7
|
||||
) -> Dict[str, Any]:
|
||||
"""Create batch prediction job using repository"""
|
||||
@@ -386,7 +386,7 @@ class EnhancedForecastingService:
|
||||
batch_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"batch_name": batch_name,
|
||||
"total_products": len(products),
|
||||
"total_products": len(inventory_product_ids),
|
||||
"forecast_days": forecast_days,
|
||||
"status": "pending"
|
||||
}
|
||||
@@ -396,12 +396,12 @@ class EnhancedForecastingService:
|
||||
logger.info("Batch prediction created",
|
||||
batch_id=batch.id,
|
||||
tenant_id=tenant_id,
|
||||
total_products=len(products))
|
||||
total_products=len(inventory_product_ids))
|
||||
|
||||
return {
|
||||
"batch_id": str(batch.id),
|
||||
"status": batch.status,
|
||||
"total_products": len(products),
|
||||
"total_products": len(inventory_product_ids),
|
||||
"created_at": batch.requested_at.isoformat()
|
||||
}
|
||||
|
||||
@@ -423,7 +423,7 @@ class EnhancedForecastingService:
|
||||
"forecast_id": forecast.id,
|
||||
"alert_type": "high_demand",
|
||||
"severity": "high" if prediction['prediction'] > 200 else "medium",
|
||||
"message": f"High demand predicted for {forecast.product_name}: {prediction['prediction']:.1f} units"
|
||||
"message": f"High demand predicted for inventory product {forecast.inventory_product_id}: {prediction['prediction']:.1f} units"
|
||||
})
|
||||
|
||||
# Check for low demand alert
|
||||
@@ -433,7 +433,7 @@ class EnhancedForecastingService:
|
||||
"forecast_id": forecast.id,
|
||||
"alert_type": "low_demand",
|
||||
"severity": "low",
|
||||
"message": f"Low demand predicted for {forecast.product_name}: {prediction['prediction']:.1f} units"
|
||||
"message": f"Low demand predicted for inventory product {forecast.inventory_product_id}: {prediction['prediction']:.1f} units"
|
||||
})
|
||||
|
||||
# Check for stockout risk (very low prediction with narrow confidence interval)
|
||||
@@ -444,7 +444,7 @@ class EnhancedForecastingService:
|
||||
"forecast_id": forecast.id,
|
||||
"alert_type": "stockout_risk",
|
||||
"severity": "critical",
|
||||
"message": f"Stockout risk for {forecast.product_name}: predicted {prediction['prediction']:.1f} units with high confidence"
|
||||
"message": f"Stockout risk for inventory product {forecast.inventory_product_id}: predicted {prediction['prediction']:.1f} units with high confidence"
|
||||
})
|
||||
|
||||
# Create alerts
|
||||
@@ -462,7 +462,7 @@ class EnhancedForecastingService:
|
||||
return ForecastResponse(
|
||||
id=str(cache_entry.id),
|
||||
tenant_id=str(cache_entry.tenant_id),
|
||||
product_name=cache_entry.product_name,
|
||||
inventory_product_id=cache_entry.inventory_product_id,
|
||||
location=cache_entry.location,
|
||||
forecast_date=cache_entry.forecast_date,
|
||||
predicted_demand=cache_entry.predicted_demand,
|
||||
@@ -486,7 +486,7 @@ class EnhancedForecastingService:
|
||||
return ForecastResponse(
|
||||
id=str(forecast.id),
|
||||
tenant_id=str(forecast.tenant_id),
|
||||
product_name=forecast.product_name,
|
||||
inventory_product_id=forecast.inventory_product_id,
|
||||
location=forecast.location,
|
||||
forecast_date=forecast.forecast_date,
|
||||
predicted_demand=forecast.predicted_demand,
|
||||
@@ -514,7 +514,7 @@ class EnhancedForecastingService:
|
||||
return {
|
||||
"id": str(forecast.id),
|
||||
"tenant_id": str(forecast.tenant_id),
|
||||
"product_name": forecast.product_name,
|
||||
"inventory_product_id": forecast.inventory_product_id,
|
||||
"location": forecast.location,
|
||||
"forecast_date": forecast.forecast_date.isoformat(),
|
||||
"predicted_demand": forecast.predicted_demand,
|
||||
@@ -527,17 +527,17 @@ class EnhancedForecastingService:
|
||||
}
|
||||
|
||||
# Additional helper methods from original service
|
||||
async def _get_latest_model_with_fallback(self, tenant_id: str, product_name: str) -> Optional[Dict[str, Any]]:
|
||||
async def _get_latest_model_with_fallback(self, tenant_id: str, inventory_product_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get the latest trained model with fallback strategies"""
|
||||
try:
|
||||
model_data = await self.model_client.get_best_model_for_forecasting(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name
|
||||
inventory_product_id=inventory_product_id
|
||||
)
|
||||
|
||||
if model_data:
|
||||
logger.info("Found specific model for product",
|
||||
product=product_name,
|
||||
inventory_product_id=inventory_product_id,
|
||||
model_id=model_data.get('model_id'))
|
||||
return model_data
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ class ModelClient:
|
||||
async def get_best_model_for_forecasting(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: Optional[str] = None
|
||||
inventory_product_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get the best model for forecasting based on performance metrics
|
||||
@@ -71,7 +71,7 @@ class ModelClient:
|
||||
# Get latest model
|
||||
latest_model = await self.clients.training.get_active_model_for_product(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name
|
||||
inventory_product_id=inventory_product_id
|
||||
)
|
||||
|
||||
if not latest_model:
|
||||
@@ -137,7 +137,7 @@ class ModelClient:
|
||||
logger.info("Found fallback model for tenant",
|
||||
tenant_id=tenant_id,
|
||||
model_id=best_model.get('id', 'unknown'),
|
||||
product=best_model.get('product_name', 'unknown'))
|
||||
inventory_product_id=best_model.get('inventory_product_id', 'unknown'))
|
||||
return best_model
|
||||
|
||||
logger.warning("No fallback models available for tenant", tenant_id=tenant_id)
|
||||
|
||||
@@ -38,7 +38,7 @@ class PredictionService:
|
||||
async def validate_prediction_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate prediction request"""
|
||||
try:
|
||||
required_fields = ["product_name", "model_id", "features"]
|
||||
required_fields = ["inventory_product_id", "model_id", "features"]
|
||||
missing_fields = [field for field in required_fields if field not in request]
|
||||
|
||||
if missing_fields:
|
||||
|
||||
Reference in New Issue
Block a user