Fix new services implementation 3

This commit is contained in:
Urtzi Alfaro
2025-08-14 16:47:34 +02:00
parent 0951547e92
commit 03737430ee
51 changed files with 657 additions and 982 deletions

View File

@@ -56,7 +56,7 @@ async def create_enhanced_single_forecast(
logger.info("Generating enhanced single forecast",
tenant_id=tenant_id,
product_name=request.product_name,
inventory_product_id=request.inventory_product_id,
forecast_date=request.forecast_date.isoformat())
# Record metrics
@@ -124,13 +124,13 @@ async def create_enhanced_batch_forecast(
logger.info("Generating enhanced batch forecasts",
tenant_id=tenant_id,
products_count=len(request.products),
forecast_dates_count=len(request.forecast_dates))
products_count=len(request.inventory_product_ids),
forecast_dates_count=request.forecast_days)
# Record metrics
if metrics:
metrics.increment_counter("enhanced_batch_forecasts_total")
metrics.histogram("enhanced_batch_forecast_products_count", len(request.products))
metrics.histogram("enhanced_batch_forecast_products_count", len(request.inventory_product_ids))
# Generate batch forecasts using enhanced service
batch_result = await enhanced_forecasting_service.generate_batch_forecasts(
@@ -174,7 +174,7 @@ async def create_enhanced_batch_forecast(
@track_execution_time("enhanced_get_forecasts_duration_seconds", "forecasting-service")
async def get_enhanced_tenant_forecasts(
tenant_id: str = Path(..., description="Tenant ID"),
product_name: Optional[str] = Query(None, description="Filter by product name"),
inventory_product_id: Optional[str] = Query(None, description="Filter by inventory product ID"),
start_date: Optional[date] = Query(None, description="Start date filter"),
end_date: Optional[date] = Query(None, description="End date filter"),
skip: int = Query(0, description="Number of records to skip"),
@@ -203,7 +203,7 @@ async def get_enhanced_tenant_forecasts(
# Get forecasts using enhanced service
forecasts = await enhanced_forecasting_service.get_tenant_forecasts(
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
start_date=start_date,
end_date=end_date,
skip=skip,
@@ -218,7 +218,7 @@ async def get_enhanced_tenant_forecasts(
"forecasts": forecasts,
"total_returned": len(forecasts),
"filters": {
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"start_date": start_date.isoformat() if start_date else None,
"end_date": end_date.isoformat() if end_date else None
},

View File

@@ -59,14 +59,14 @@ async def generate_enhanced_realtime_prediction(
logger.info("Generating enhanced real-time prediction",
tenant_id=tenant_id,
product_name=prediction_request.get("product_name"))
inventory_product_id=prediction_request.get("inventory_product_id"))
# Record metrics
if metrics:
metrics.increment_counter("enhanced_realtime_predictions_total")
# Validate required fields
required_fields = ["product_name", "model_id", "features"]
required_fields = ["inventory_product_id", "model_id", "features"]
missing_fields = [field for field in required_fields if field not in prediction_request]
if missing_fields:
raise HTTPException(
@@ -91,7 +91,7 @@ async def generate_enhanced_realtime_prediction(
return {
"tenant_id": tenant_id,
"product_name": prediction_request["product_name"],
"inventory_product_id": prediction_request["inventory_product_id"],
"model_id": prediction_request["model_id"],
"prediction": prediction_result,
"generated_at": datetime.now().isoformat(),
@@ -205,7 +205,7 @@ async def generate_enhanced_batch_predictions(
@track_execution_time("enhanced_get_prediction_cache_duration_seconds", "forecasting-service")
async def get_enhanced_prediction_cache(
tenant_id: str = Path(..., description="Tenant ID"),
product_name: Optional[str] = Query(None, description="Filter by product name"),
inventory_product_id: Optional[str] = Query(None, description="Filter by inventory product ID"),
skip: int = Query(0, description="Number of records to skip"),
limit: int = Query(100, description="Number of records to return"),
request_obj: Request = None,
@@ -232,7 +232,7 @@ async def get_enhanced_prediction_cache(
# Get cached predictions using enhanced service
cached_predictions = await enhanced_forecasting_service.get_cached_predictions(
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
skip=skip,
limit=limit
)
@@ -245,7 +245,7 @@ async def get_enhanced_prediction_cache(
"cached_predictions": cached_predictions,
"total_returned": len(cached_predictions),
"filters": {
"product_name": product_name
"inventory_product_id": inventory_product_id
},
"pagination": {
"skip": skip,
@@ -271,7 +271,7 @@ async def get_enhanced_prediction_cache(
@track_execution_time("enhanced_clear_prediction_cache_duration_seconds", "forecasting-service")
async def clear_enhanced_prediction_cache(
tenant_id: str = Path(..., description="Tenant ID"),
product_name: Optional[str] = Query(None, description="Clear cache for specific product"),
inventory_product_id: Optional[str] = Query(None, description="Clear cache for specific inventory product ID"),
request_obj: Request = None,
current_tenant: str = Depends(get_current_tenant_id_dep),
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
@@ -296,7 +296,7 @@ async def clear_enhanced_prediction_cache(
# Clear cache using enhanced service
cleared_count = await enhanced_forecasting_service.clear_prediction_cache(
tenant_id=tenant_id,
product_name=product_name
inventory_product_id=inventory_product_id
)
if metrics:
@@ -305,13 +305,13 @@ async def clear_enhanced_prediction_cache(
logger.info("Enhanced prediction cache cleared",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
cleared_count=cleared_count)
return {
"message": "Prediction cache cleared successfully",
"tenant_id": tenant_id,
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"cleared_count": cleared_count,
"enhanced_features": True,
"repository_integration": True

View File

@@ -40,7 +40,7 @@ class BakeryForecaster:
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "forecasting-service")
self.predictor = BakeryPredictor(database_manager)
async def generate_forecast_with_repository(self, tenant_id: str, product_name: str,
async def generate_forecast_with_repository(self, tenant_id: str, inventory_product_id: str,
forecast_date: date, model_id: str = None) -> Dict[str, Any]:
"""Generate forecast with repository integration"""
try:
@@ -48,7 +48,7 @@ class BakeryForecaster:
# Implementation would be added here
return {
"tenant_id": tenant_id,
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"forecast_date": forecast_date.isoformat(),
"prediction": 0.0,
"confidence_interval": {"lower": 0.0, "upper": 0.0},

View File

@@ -18,7 +18,7 @@ class Forecast(Base):
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
product_name = Column(String(255), nullable=False, index=True)
inventory_product_id = Column(UUID(as_uuid=True), nullable=False, index=True) # Reference to inventory service
location = Column(String(255), nullable=False, index=True)
# Forecast period
@@ -53,7 +53,7 @@ class Forecast(Base):
features_used = Column(JSON)
def __repr__(self):
return f"<Forecast(id={self.id}, product={self.product_name}, date={self.forecast_date})>"
return f"<Forecast(id={self.id}, inventory_product_id={self.inventory_product_id}, date={self.forecast_date})>"
class PredictionBatch(Base):
"""Batch prediction requests"""

View File

@@ -19,7 +19,7 @@ class ModelPerformanceMetric(Base):
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
model_id = Column(UUID(as_uuid=True), nullable=False, index=True)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
product_name = Column(String(255), nullable=False)
inventory_product_id = Column(UUID(as_uuid=True), nullable=False) # Reference to inventory service
# Performance metrics
mae = Column(Float) # Mean Absolute Error
@@ -48,7 +48,7 @@ class PredictionCache(Base):
# Cached data
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
product_name = Column(String(255), nullable=False)
inventory_product_id = Column(UUID(as_uuid=True), nullable=False) # Reference to inventory service
location = Column(String(255), nullable=False)
forecast_date = Column(DateTime(timezone=True), nullable=False)
@@ -64,4 +64,4 @@ class PredictionCache(Base):
hit_count = Column(Integer, default=0)
def __repr__(self):
return f"<PredictionCache(key={self.cache_key}, product={self.product_name})>"
return f"<PredictionCache(key={self.cache_key}, inventory_product_id={self.inventory_product_id})>"

View File

@@ -34,21 +34,21 @@ class ForecastingBaseRepository(BaseRepository):
)
return await self.get_multi(skip=skip, limit=limit)
async def get_by_product_name(
async def get_by_inventory_product_id(
self,
tenant_id: str,
product_name: str,
inventory_product_id: str,
skip: int = 0,
limit: int = 100
) -> List:
"""Get records by tenant and product"""
if hasattr(self.model, 'product_name'):
"""Get records by tenant and inventory product"""
if hasattr(self.model, 'inventory_product_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={
"tenant_id": tenant_id,
"product_name": product_name
"inventory_product_id": inventory_product_id
},
order_by="created_at",
order_desc=True
@@ -163,17 +163,17 @@ class ForecastingBaseRepository(BaseRepository):
# Get records by product if applicable
product_stats = {}
if hasattr(self.model, 'product_name'):
if hasattr(self.model, 'inventory_product_id'):
product_query = text(f"""
SELECT product_name, COUNT(*) as count
SELECT inventory_product_id, COUNT(*) as count
FROM {table_name}
WHERE tenant_id = :tenant_id
GROUP BY product_name
GROUP BY inventory_product_id
ORDER BY count DESC
""")
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
product_stats = {row.product_name: row.count for row in result.fetchall()}
product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}
return {
"total_records": total_records,
@@ -206,11 +206,11 @@ class ForecastingBaseRepository(BaseRepository):
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
errors.append("Invalid tenant_id format")
# Validate product_name if present
if "product_name" in data and data["product_name"]:
product_name = data["product_name"]
if not isinstance(product_name, str) or len(product_name) < 1:
errors.append("Invalid product_name format")
# Validate inventory_product_id if present
if "inventory_product_id" in data and data["inventory_product_id"]:
inventory_product_id = data["inventory_product_id"]
if not isinstance(inventory_product_id, str) or len(inventory_product_id) < 1:
errors.append("Invalid inventory_product_id format")
# Validate dates if present - accept datetime objects, date objects, and date strings
date_fields = ["forecast_date", "created_at", "evaluation_date", "expires_at"]

View File

@@ -29,7 +29,7 @@ class ForecastRepository(ForecastingBaseRepository):
# Validate forecast data
validation_result = self._validate_forecast_data(
forecast_data,
["tenant_id", "product_name", "location", "forecast_date",
["tenant_id", "inventory_product_id", "location", "forecast_date",
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
)
@@ -50,7 +50,7 @@ class ForecastRepository(ForecastingBaseRepository):
logger.info("Forecast created successfully",
forecast_id=forecast.id,
tenant_id=forecast.tenant_id,
product_name=forecast.product_name,
inventory_product_id=forecast.inventory_product_id,
forecast_date=forecast.forecast_date.isoformat())
return forecast
@@ -60,7 +60,7 @@ class ForecastRepository(ForecastingBaseRepository):
except Exception as e:
logger.error("Failed to create forecast",
tenant_id=forecast_data.get("tenant_id"),
product_name=forecast_data.get("product_name"),
inventory_product_id=forecast_data.get("inventory_product_id"),
error=str(e))
raise DatabaseError(f"Failed to create forecast: {str(e)}")
@@ -69,15 +69,15 @@ class ForecastRepository(ForecastingBaseRepository):
tenant_id: str,
start_date: date,
end_date: date,
product_name: str = None,
inventory_product_id: str = None,
location: str = None
) -> List[Forecast]:
"""Get forecasts within a date range"""
try:
filters = {"tenant_id": tenant_id}
if product_name:
filters["product_name"] = product_name
if inventory_product_id:
filters["inventory_product_id"] = inventory_product_id
if location:
filters["location"] = location
@@ -100,14 +100,14 @@ class ForecastRepository(ForecastingBaseRepository):
async def get_latest_forecast_for_product(
self,
tenant_id: str,
product_name: str,
inventory_product_id: str,
location: str = None
) -> Optional[Forecast]:
"""Get the most recent forecast for a product"""
try:
filters = {
"tenant_id": tenant_id,
"product_name": product_name
"inventory_product_id": inventory_product_id
}
if location:
filters["location"] = location
@@ -124,7 +124,7 @@ class ForecastRepository(ForecastingBaseRepository):
except Exception as e:
logger.error("Failed to get latest forecast for product",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
error=str(e))
raise DatabaseError(f"Failed to get latest forecast: {str(e)}")
@@ -132,7 +132,7 @@ class ForecastRepository(ForecastingBaseRepository):
self,
tenant_id: str,
forecast_date: date,
product_name: str = None
inventory_product_id: str = None
) -> List[Forecast]:
"""Get all forecasts for a specific date"""
try:
@@ -154,7 +154,7 @@ class ForecastRepository(ForecastingBaseRepository):
async def get_forecast_accuracy_metrics(
self,
tenant_id: str,
product_name: str = None,
inventory_product_id: str = None,
days_back: int = 30
) -> Dict[str, Any]:
"""Get forecast accuracy metrics"""
@@ -168,9 +168,9 @@ class ForecastRepository(ForecastingBaseRepository):
"cutoff_date": cutoff_date
}
if product_name:
conditions.append("product_name = :product_name")
params["product_name"] = product_name
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
query_text = f"""
SELECT
@@ -180,7 +180,7 @@ class ForecastRepository(ForecastingBaseRepository):
MAX(predicted_demand) as max_predicted_demand,
AVG(confidence_upper - confidence_lower) as avg_confidence_interval,
AVG(processing_time_ms) as avg_processing_time_ms,
COUNT(DISTINCT product_name) as unique_products,
COUNT(DISTINCT inventory_product_id) as unique_products,
COUNT(DISTINCT model_id) as unique_models
FROM forecasts
WHERE {' AND '.join(conditions)}
@@ -233,7 +233,7 @@ class ForecastRepository(ForecastingBaseRepository):
async def get_demand_trends(
self,
tenant_id: str,
product_name: str,
inventory_product_id: str,
days_back: int = 30
) -> Dict[str, Any]:
"""Get demand trends for a product"""
@@ -249,7 +249,7 @@ class ForecastRepository(ForecastingBaseRepository):
COUNT(*) as forecast_count
FROM forecasts
WHERE tenant_id = :tenant_id
AND product_name = :product_name
AND inventory_product_id = :inventory_product_id
AND forecast_date >= :cutoff_date
GROUP BY DATE(forecast_date)
ORDER BY date DESC
@@ -257,7 +257,7 @@ class ForecastRepository(ForecastingBaseRepository):
result = await self.session.execute(text(query_text), {
"tenant_id": tenant_id,
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"cutoff_date": cutoff_date
})
@@ -280,7 +280,7 @@ class ForecastRepository(ForecastingBaseRepository):
trend_direction = "stable"
return {
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"period_days": days_back,
"trends": trends,
"trend_direction": trend_direction,
@@ -290,10 +290,10 @@ class ForecastRepository(ForecastingBaseRepository):
except Exception as e:
logger.error("Failed to get demand trends",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
error=str(e))
return {
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"period_days": days_back,
"trends": [],
"trend_direction": "unknown",
@@ -311,7 +311,7 @@ class ForecastRepository(ForecastingBaseRepository):
COUNT(*) as usage_count,
AVG(predicted_demand) as avg_prediction,
MAX(forecast_date) as last_used,
COUNT(DISTINCT product_name) as products_covered
COUNT(DISTINCT inventory_product_id) as products_covered
FROM forecasts
WHERE tenant_id = :tenant_id
GROUP BY model_id, algorithm
@@ -403,7 +403,7 @@ class ForecastRepository(ForecastingBaseRepository):
# Validate each forecast
validation_result = self._validate_forecast_data(
forecast_data,
["tenant_id", "product_name", "location", "forecast_date",
["tenant_id", "inventory_product_id", "location", "forecast_date",
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
)

View File

@@ -29,7 +29,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
# Validate metric data
validation_result = self._validate_forecast_data(
metric_data,
["model_id", "tenant_id", "product_name", "evaluation_date"]
["model_id", "tenant_id", "inventory_product_id", "evaluation_date"]
)
if not validation_result["is_valid"]:
@@ -41,7 +41,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
metric_id=metric.id,
model_id=metric.model_id,
tenant_id=metric.tenant_id,
product_name=metric.product_name)
inventory_product_id=metric.inventory_product_id)
return metric
@@ -93,7 +93,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
async def get_performance_trends(
self,
tenant_id: str,
product_name: str = None,
inventory_product_id: str = None,
days: int = 30
) -> Dict[str, Any]:
"""Get performance trends over time"""
@@ -109,14 +109,14 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
"start_date": start_date
}
if product_name:
conditions.append("product_name = :product_name")
params["product_name"] = product_name
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
query_text = f"""
SELECT
DATE(evaluation_date) as date,
product_name,
inventory_product_id,
AVG(mae) as avg_mae,
AVG(mape) as avg_mape,
AVG(rmse) as avg_rmse,
@@ -124,8 +124,8 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
COUNT(*) as measurement_count
FROM model_performance_metrics
WHERE {' AND '.join(conditions)}
GROUP BY DATE(evaluation_date), product_name
ORDER BY date DESC, product_name
GROUP BY DATE(evaluation_date), inventory_product_id
ORDER BY date DESC, inventory_product_id
"""
result = await self.session.execute(text(query_text), params)
@@ -134,7 +134,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
for row in result.fetchall():
trends.append({
"date": row.date.isoformat() if row.date else None,
"product_name": row.product_name,
"inventory_product_id": row.inventory_product_id,
"metrics": {
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
@@ -146,7 +146,7 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
return {
"tenant_id": tenant_id,
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"period_days": days,
"trends": trends,
"total_measurements": len(trends)
@@ -155,11 +155,11 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
except Exception as e:
logger.error("Failed to get performance trends",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
error=str(e))
return {
"tenant_id": tenant_id,
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"period_days": days,
"trends": [],
"total_measurements": 0

View File

@@ -27,18 +27,18 @@ class PredictionCacheRepository(ForecastingBaseRepository):
def _generate_cache_key(
self,
tenant_id: str,
product_name: str,
inventory_product_id: str,
location: str,
forecast_date: datetime
) -> str:
"""Generate cache key for prediction"""
key_data = f"{tenant_id}:{product_name}:{location}:{forecast_date.isoformat()}"
key_data = f"{tenant_id}:{inventory_product_id}:{location}:{forecast_date.isoformat()}"
return hashlib.md5(key_data.encode()).hexdigest()
async def cache_prediction(
self,
tenant_id: str,
product_name: str,
inventory_product_id: str,
location: str,
forecast_date: datetime,
predicted_demand: float,
@@ -49,13 +49,13 @@ class PredictionCacheRepository(ForecastingBaseRepository):
) -> PredictionCache:
"""Cache a prediction result"""
try:
cache_key = self._generate_cache_key(tenant_id, product_name, location, forecast_date)
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
expires_at = datetime.utcnow() + timedelta(hours=expires_in_hours)
cache_data = {
"cache_key": cache_key,
"tenant_id": tenant_id,
"product_name": product_name,
"inventory_product_id": inventory_product_id,
"location": location,
"forecast_date": forecast_date,
"predicted_demand": predicted_demand,
@@ -80,20 +80,20 @@ class PredictionCacheRepository(ForecastingBaseRepository):
except Exception as e:
logger.error("Failed to cache prediction",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
error=str(e))
raise DatabaseError(f"Failed to cache prediction: {str(e)}")
async def get_cached_prediction(
self,
tenant_id: str,
product_name: str,
inventory_product_id: str,
location: str,
forecast_date: datetime
) -> Optional[PredictionCache]:
"""Get cached prediction if valid"""
try:
cache_key = self._generate_cache_key(tenant_id, product_name, location, forecast_date)
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
cache_entry = await self.get_by_field("cache_key", cache_key)
@@ -119,14 +119,14 @@ class PredictionCacheRepository(ForecastingBaseRepository):
except Exception as e:
logger.error("Failed to get cached prediction",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
error=str(e))
return None
async def invalidate_cache(
self,
tenant_id: str,
product_name: str = None,
inventory_product_id: str = None,
location: str = None
) -> int:
"""Invalidate cache entries"""
@@ -134,9 +134,9 @@ class PredictionCacheRepository(ForecastingBaseRepository):
conditions = ["tenant_id = :tenant_id"]
params = {"tenant_id": tenant_id}
if product_name:
conditions.append("product_name = :product_name")
params["product_name"] = product_name
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
if location:
conditions.append("location = :location")
@@ -152,7 +152,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
logger.info("Cache invalidated",
tenant_id=tenant_id,
product_name=product_name,
inventory_product_id=inventory_product_id,
location=location,
invalidated_count=invalidated_count)
@@ -204,7 +204,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
SUM(hit_count) as total_hits,
AVG(hit_count) as avg_hits_per_entry,
MAX(hit_count) as max_hits,
COUNT(DISTINCT product_name) as unique_products
COUNT(DISTINCT inventory_product_id) as unique_products
FROM prediction_cache
{base_filter}
""")
@@ -268,7 +268,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
query_text = f"""
SELECT
product_name,
inventory_product_id,
location,
hit_count,
predicted_demand,
@@ -285,7 +285,7 @@ class PredictionCacheRepository(ForecastingBaseRepository):
popular_predictions = []
for row in result.fetchall():
popular_predictions.append({
"product_name": row.product_name,
"inventory_product_id": row.inventory_product_id,
"location": row.location,
"hit_count": int(row.hit_count),
"predicted_demand": float(row.predicted_demand),

View File

@@ -22,7 +22,8 @@ class AlertType(str, Enum):
class ForecastRequest(BaseModel):
"""Request schema for generating forecasts"""
product_name: str = Field(..., description="Product name")
inventory_product_id: str = Field(..., description="Inventory product UUID reference")
# product_name: str = Field(..., description="Product name") # DEPRECATED - use inventory_product_id
forecast_date: date = Field(..., description="Starting date for forecast")
forecast_days: int = Field(1, ge=1, le=30, description="Number of days to forecast")
location: str = Field(..., description="Location identifier")
@@ -40,14 +41,15 @@ class BatchForecastRequest(BaseModel):
"""Request schema for batch forecasting"""
tenant_id: str = Field(..., description="Tenant ID")
batch_name: str = Field(..., description="Batch name for tracking")
products: List[str] = Field(..., description="List of product names")
inventory_product_ids: List[str] = Field(..., description="List of inventory product IDs")
forecast_days: int = Field(7, ge=1, le=30, description="Number of days to forecast")
class ForecastResponse(BaseModel):
"""Response schema for forecast results"""
id: str
tenant_id: str
product_name: str
inventory_product_id: str # Reference to inventory service
# product_name: str # Can be fetched from inventory service if needed for display
location: str
forecast_date: datetime

View File

@@ -78,7 +78,7 @@ class EnhancedForecastingService:
logger.error("Batch forecast generation failed", error=str(e))
raise
async def get_tenant_forecasts(self, tenant_id: str, product_name: str = None,
async def get_tenant_forecasts(self, tenant_id: str, inventory_product_id: str = None,
start_date: date = None, end_date: date = None,
skip: int = 0, limit: int = 100) -> List[Dict]:
"""Get tenant forecasts with filtering"""
@@ -149,7 +149,7 @@ class EnhancedForecastingService:
logger.error("Batch predictions failed", error=str(e))
raise
async def get_cached_predictions(self, tenant_id: str, product_name: str = None,
async def get_cached_predictions(self, tenant_id: str, inventory_product_id: str = None,
skip: int = 0, limit: int = 100) -> List[Dict]:
"""Get cached predictions"""
try:
@@ -159,7 +159,7 @@ class EnhancedForecastingService:
logger.error("Failed to get cached predictions", error=str(e))
raise
async def clear_prediction_cache(self, tenant_id: str, product_name: str = None) -> int:
async def clear_prediction_cache(self, tenant_id: str, inventory_product_id: str = None) -> int:
"""Clear prediction cache"""
try:
# Implementation would use repository pattern
@@ -195,7 +195,7 @@ class EnhancedForecastingService:
try:
logger.info("Generating enhanced forecast",
tenant_id=tenant_id,
product=request.product_name,
inventory_product_id=request.inventory_product_id,
date=request.forecast_date.isoformat())
# Get session and initialize repositories
@@ -204,20 +204,20 @@ class EnhancedForecastingService:
# Step 1: Check cache first
cached_prediction = await repos['cache'].get_cached_prediction(
tenant_id, request.product_name, request.location, request.forecast_date
tenant_id, request.inventory_product_id, request.location, request.forecast_date
)
if cached_prediction:
logger.debug("Using cached prediction",
tenant_id=tenant_id,
product=request.product_name)
inventory_product_id=request.inventory_product_id)
return self._create_forecast_response_from_cache(cached_prediction)
# Step 2: Get model with validation
model_data = await self._get_latest_model_with_fallback(tenant_id, request.product_name)
model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
if not model_data:
raise ValueError(f"No valid model available for product: {request.product_name}")
raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
# Step 3: Prepare features with fallbacks
features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)
@@ -244,7 +244,7 @@ class EnhancedForecastingService:
forecast_data = {
"tenant_id": tenant_id,
"product_name": request.product_name,
"inventory_product_id": request.inventory_product_id,
"location": request.location,
"forecast_date": forecast_datetime,
"predicted_demand": adjusted_prediction['prediction'],
@@ -271,7 +271,7 @@ class EnhancedForecastingService:
# Step 7: Cache the prediction
await repos['cache'].cache_prediction(
tenant_id=tenant_id,
product_name=request.product_name,
inventory_product_id=request.inventory_product_id,
location=request.location,
forecast_date=forecast_datetime,
predicted_demand=adjusted_prediction['prediction'],
@@ -296,14 +296,14 @@ class EnhancedForecastingService:
logger.error("Error generating enhanced forecast",
error=str(e),
tenant_id=tenant_id,
product=request.product_name,
inventory_product_id=request.inventory_product_id,
processing_time=processing_time)
raise
async def get_forecast_history(
self,
tenant_id: str,
product_name: Optional[str] = None,
inventory_product_id: Optional[str] = None,
start_date: Optional[date] = None,
end_date: Optional[date] = None
) -> List[Dict[str, Any]]:
@@ -314,7 +314,7 @@ class EnhancedForecastingService:
if start_date and end_date:
forecasts = await repos['forecast'].get_forecasts_by_date_range(
tenant_id, start_date, end_date, product_name
tenant_id, start_date, end_date, inventory_product_id
)
else:
# Get recent forecasts (last 30 days)
@@ -374,7 +374,7 @@ class EnhancedForecastingService:
self,
tenant_id: str,
batch_name: str,
products: List[str],
inventory_product_ids: List[str],
forecast_days: int = 7
) -> Dict[str, Any]:
"""Create batch prediction job using repository"""
@@ -386,7 +386,7 @@ class EnhancedForecastingService:
batch_data = {
"tenant_id": tenant_id,
"batch_name": batch_name,
"total_products": len(products),
"total_products": len(inventory_product_ids),
"forecast_days": forecast_days,
"status": "pending"
}
@@ -396,12 +396,12 @@ class EnhancedForecastingService:
logger.info("Batch prediction created",
batch_id=batch.id,
tenant_id=tenant_id,
total_products=len(products))
total_products=len(inventory_product_ids))
return {
"batch_id": str(batch.id),
"status": batch.status,
"total_products": len(products),
"total_products": len(inventory_product_ids),
"created_at": batch.requested_at.isoformat()
}
@@ -423,7 +423,7 @@ class EnhancedForecastingService:
"forecast_id": forecast.id,
"alert_type": "high_demand",
"severity": "high" if prediction['prediction'] > 200 else "medium",
"message": f"High demand predicted for {forecast.product_name}: {prediction['prediction']:.1f} units"
"message": f"High demand predicted for inventory product {forecast.inventory_product_id}: {prediction['prediction']:.1f} units"
})
# Check for low demand alert
@@ -433,7 +433,7 @@ class EnhancedForecastingService:
"forecast_id": forecast.id,
"alert_type": "low_demand",
"severity": "low",
"message": f"Low demand predicted for {forecast.product_name}: {prediction['prediction']:.1f} units"
"message": f"Low demand predicted for inventory product {forecast.inventory_product_id}: {prediction['prediction']:.1f} units"
})
# Check for stockout risk (very low prediction with narrow confidence interval)
@@ -444,7 +444,7 @@ class EnhancedForecastingService:
"forecast_id": forecast.id,
"alert_type": "stockout_risk",
"severity": "critical",
"message": f"Stockout risk for {forecast.product_name}: predicted {prediction['prediction']:.1f} units with high confidence"
"message": f"Stockout risk for inventory product {forecast.inventory_product_id}: predicted {prediction['prediction']:.1f} units with high confidence"
})
# Create alerts
@@ -462,7 +462,7 @@ class EnhancedForecastingService:
return ForecastResponse(
id=str(cache_entry.id),
tenant_id=str(cache_entry.tenant_id),
product_name=cache_entry.product_name,
inventory_product_id=cache_entry.inventory_product_id,
location=cache_entry.location,
forecast_date=cache_entry.forecast_date,
predicted_demand=cache_entry.predicted_demand,
@@ -486,7 +486,7 @@ class EnhancedForecastingService:
return ForecastResponse(
id=str(forecast.id),
tenant_id=str(forecast.tenant_id),
product_name=forecast.product_name,
inventory_product_id=forecast.inventory_product_id,
location=forecast.location,
forecast_date=forecast.forecast_date,
predicted_demand=forecast.predicted_demand,
@@ -514,7 +514,7 @@ class EnhancedForecastingService:
return {
"id": str(forecast.id),
"tenant_id": str(forecast.tenant_id),
"product_name": forecast.product_name,
"inventory_product_id": forecast.inventory_product_id,
"location": forecast.location,
"forecast_date": forecast.forecast_date.isoformat(),
"predicted_demand": forecast.predicted_demand,
@@ -527,17 +527,17 @@ class EnhancedForecastingService:
}
# Additional helper methods from original service
async def _get_latest_model_with_fallback(self, tenant_id: str, product_name: str) -> Optional[Dict[str, Any]]:
async def _get_latest_model_with_fallback(self, tenant_id: str, inventory_product_id: str) -> Optional[Dict[str, Any]]:
"""Get the latest trained model with fallback strategies"""
try:
model_data = await self.model_client.get_best_model_for_forecasting(
tenant_id=tenant_id,
product_name=product_name
inventory_product_id=inventory_product_id
)
if model_data:
logger.info("Found specific model for product",
product=product_name,
inventory_product_id=inventory_product_id,
model_id=model_data.get('model_id'))
return model_data

View File

@@ -62,7 +62,7 @@ class ModelClient:
async def get_best_model_for_forecasting(
self,
tenant_id: str,
product_name: Optional[str] = None
inventory_product_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""
Get the best model for forecasting based on performance metrics
@@ -71,7 +71,7 @@ class ModelClient:
# Get latest model
latest_model = await self.clients.training.get_active_model_for_product(
tenant_id=tenant_id,
product_name=product_name
inventory_product_id=inventory_product_id
)
if not latest_model:
@@ -137,7 +137,7 @@ class ModelClient:
logger.info("Found fallback model for tenant",
tenant_id=tenant_id,
model_id=best_model.get('id', 'unknown'),
product=best_model.get('product_name', 'unknown'))
inventory_product_id=best_model.get('inventory_product_id', 'unknown'))
return best_model
logger.warning("No fallback models available for tenant", tenant_id=tenant_id)

View File

@@ -38,7 +38,7 @@ class PredictionService:
async def validate_prediction_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Validate prediction request"""
try:
required_fields = ["product_name", "model_id", "features"]
required_fields = ["inventory_product_id", "model_id", "features"]
missing_fields = [field for field in required_fields if field not in request]
if missing_fields: