Initial commit - production deployment
0    services/forecasting/app/__init__.py    Normal file

27   services/forecasting/app/api/__init__.py    Normal file
@@ -0,0 +1,27 @@
"""
Forecasting API Layer
HTTP endpoints for demand forecasting and prediction operations
"""

from .forecasts import router as forecasts_router
from .forecasting_operations import router as forecasting_operations_router
from .analytics import router as analytics_router
from .validation import router as validation_router
from .historical_validation import router as historical_validation_router
from .webhooks import router as webhooks_router
from .performance_monitoring import router as performance_monitoring_router
from .retraining import router as retraining_router
from .enterprise_forecasting import router as enterprise_forecasting_router


__all__ = [
    "forecasts_router",
    "forecasting_operations_router",
    "analytics_router",
    "validation_router",
    "historical_validation_router",
    "webhooks_router",
    "performance_monitoring_router",
    "retraining_router",
    "enterprise_forecasting_router",
]
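Wiring sketch (illustrative, not part of this commit): the service's main application presumably includes these routers on a FastAPI app; the app module name and route comments below are assumptions.

# Hypothetical app wiring -- the real app factory lives elsewhere in the service.
from fastapi import FastAPI
from app.api import forecasts_router, analytics_router, validation_router

app = FastAPI(title="forecasting-service")
app.include_router(forecasts_router)    # forecast CRUD endpoints
app.include_router(analytics_router)    # analytics/reporting endpoints
app.include_router(validation_router)   # validation endpoints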
55   services/forecasting/app/api/analytics.py    Normal file
@@ -0,0 +1,55 @@
# services/forecasting/app/api/analytics.py
"""
Forecasting Analytics API - Reporting, statistics, and insights
"""

import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path
from datetime import date
from typing import Optional

from app.services.prediction_service import PredictionService
from shared.database.base import create_database_manager
from app.core.config import settings
from shared.routing import RouteBuilder
from shared.auth.access_control import analytics_tier_required

route_builder = RouteBuilder('forecasting')
logger = structlog.get_logger()
router = APIRouter(tags=["forecasting-analytics"])


def get_enhanced_prediction_service():
    """Dependency injection for enhanced PredictionService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    return PredictionService(database_manager)


@router.get(
    route_builder.build_analytics_route("predictions-performance")
)
@analytics_tier_required
async def get_predictions_performance(
    tenant_id: str = Path(..., description="Tenant ID"),
    start_date: Optional[date] = Query(None),
    end_date: Optional[date] = Query(None),
    prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
    """Get predictions performance analytics (Professional+ tier required)"""
    try:
        logger.info("Getting predictions performance", tenant_id=tenant_id)

        performance = await prediction_service.get_performance_metrics(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date
        )

        return performance

    except Exception as e:
        logger.error("Failed to get predictions performance", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve predictions performance"
        )
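Usage sketch (illustrative, not part of this file): a client call against the analytics endpoint above. The base URL and concrete path segments are assumptions, since RouteBuilder derives the real route.

# Minimal client sketch using httpx; path shape and host are assumed.
import httpx

async def fetch_predictions_performance(tenant_id: str, token: str) -> dict:
    async with httpx.AsyncClient(base_url="http://forecasting-service:8000") as client:
        resp = await client.get(
            f"/tenants/{tenant_id}/forecasting/analytics/predictions-performance",
            params={"start_date": "2025-01-01", "end_date": "2025-01-31"},
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        return resp.json()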
237  services/forecasting/app/api/audit.py    Normal file
@@ -0,0 +1,237 @@
# services/forecasting/app/api/audit.py
"""
Audit Logs API - Retrieve audit trail for forecasting service
"""

from fastapi import APIRouter, Depends, HTTPException, Query, Path, status
from typing import Optional, Dict, Any
from uuid import UUID
from datetime import datetime
import structlog
from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import AuditLog
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role
from shared.routing import RouteBuilder
from shared.models.audit_log_schemas import (
    AuditLogResponse,
    AuditLogListResponse,
    AuditLogStatsResponse
)
from app.core.database import database_manager

route_builder = RouteBuilder('forecasting')
router = APIRouter(tags=["audit-logs"])
logger = structlog.get_logger()


async def get_db():
    """Database session dependency"""
    async with database_manager.get_session() as session:
        yield session


@router.get(
    route_builder.build_base_route("audit-logs"),
    response_model=AuditLogListResponse
)
@require_user_role(['admin', 'owner'])
async def get_audit_logs(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    start_date: Optional[datetime] = Query(None, description="Filter logs from this date"),
    end_date: Optional[datetime] = Query(None, description="Filter logs until this date"),
    user_id: Optional[UUID] = Query(None, description="Filter by user ID"),
    action: Optional[str] = Query(None, description="Filter by action type"),
    resource_type: Optional[str] = Query(None, description="Filter by resource type"),
    severity: Optional[str] = Query(None, description="Filter by severity level"),
    search: Optional[str] = Query(None, description="Search in description field"),
    limit: int = Query(100, ge=1, le=1000, description="Number of records to return"),
    offset: int = Query(0, ge=0, description="Number of records to skip"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get audit logs for forecasting service.
    Requires admin or owner role.
    """
    try:
        logger.info(
            "Retrieving audit logs",
            tenant_id=tenant_id,
            user_id=current_user.get("user_id"),
            filters={
                "start_date": start_date,
                "end_date": end_date,
                "action": action,
                "resource_type": resource_type,
                "severity": severity
            }
        )

        # Build query filters
        filters = [AuditLog.tenant_id == tenant_id]

        if start_date:
            filters.append(AuditLog.created_at >= start_date)
        if end_date:
            filters.append(AuditLog.created_at <= end_date)
        if user_id:
            filters.append(AuditLog.user_id == user_id)
        if action:
            filters.append(AuditLog.action == action)
        if resource_type:
            filters.append(AuditLog.resource_type == resource_type)
        if severity:
            filters.append(AuditLog.severity == severity)
        if search:
            filters.append(AuditLog.description.ilike(f"%{search}%"))

        # Count total matching records
        count_query = select(func.count()).select_from(AuditLog).where(and_(*filters))
        total_result = await db.execute(count_query)
        total = total_result.scalar() or 0

        # Fetch paginated results
        query = (
            select(AuditLog)
            .where(and_(*filters))
            .order_by(AuditLog.created_at.desc())
            .limit(limit)
            .offset(offset)
        )

        result = await db.execute(query)
        audit_logs = result.scalars().all()

        # Convert to response models
        items = [AuditLogResponse.from_orm(log) for log in audit_logs]

        logger.info(
            "Successfully retrieved audit logs",
            tenant_id=tenant_id,
            total=total,
            returned=len(items)
        )

        return AuditLogListResponse(
            items=items,
            total=total,
            limit=limit,
            offset=offset,
            has_more=(offset + len(items)) < total
        )

    except Exception as e:
        logger.error(
            "Failed to retrieve audit logs",
            error=str(e),
            tenant_id=tenant_id
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to retrieve audit logs: {str(e)}"
        )


@router.get(
    route_builder.build_base_route("audit-logs/stats"),
    response_model=AuditLogStatsResponse
)
@require_user_role(['admin', 'owner'])
async def get_audit_log_stats(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    start_date: Optional[datetime] = Query(None, description="Filter logs from this date"),
    end_date: Optional[datetime] = Query(None, description="Filter logs until this date"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get audit log statistics for forecasting service.
    Requires admin or owner role.
    """
    try:
        logger.info(
            "Retrieving audit log statistics",
            tenant_id=tenant_id,
            user_id=current_user.get("user_id")
        )

        # Build base filters
        filters = [AuditLog.tenant_id == tenant_id]
        if start_date:
            filters.append(AuditLog.created_at >= start_date)
        if end_date:
            filters.append(AuditLog.created_at <= end_date)

        # Total events
        count_query = select(func.count()).select_from(AuditLog).where(and_(*filters))
        total_result = await db.execute(count_query)
        total_events = total_result.scalar() or 0

        # Events by action
        action_query = (
            select(AuditLog.action, func.count().label('count'))
            .where(and_(*filters))
            .group_by(AuditLog.action)
        )
        action_result = await db.execute(action_query)
        events_by_action = {row.action: row.count for row in action_result}

        # Events by severity
        severity_query = (
            select(AuditLog.severity, func.count().label('count'))
            .where(and_(*filters))
            .group_by(AuditLog.severity)
        )
        severity_result = await db.execute(severity_query)
        events_by_severity = {row.severity: row.count for row in severity_result}

        # Events by resource type
        resource_query = (
            select(AuditLog.resource_type, func.count().label('count'))
            .where(and_(*filters))
            .group_by(AuditLog.resource_type)
        )
        resource_result = await db.execute(resource_query)
        events_by_resource_type = {row.resource_type: row.count for row in resource_result}

        # Date range
        date_range_query = (
            select(
                func.min(AuditLog.created_at).label('min_date'),
                func.max(AuditLog.created_at).label('max_date')
            )
            .where(and_(*filters))
        )
        date_result = await db.execute(date_range_query)
        date_row = date_result.one()

        logger.info(
            "Successfully retrieved audit log statistics",
            tenant_id=tenant_id,
            total_events=total_events
        )

        return AuditLogStatsResponse(
            total_events=total_events,
            events_by_action=events_by_action,
            events_by_severity=events_by_severity,
            events_by_resource_type=events_by_resource_type,
            date_range={
                "min": date_row.min_date,
                "max": date_row.max_date
            }
        )

    except Exception as e:
        logger.error(
            "Failed to retrieve audit log statistics",
            error=str(e),
            tenant_id=tenant_id
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to retrieve audit log statistics: {str(e)}"
        )
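Note on the pattern used by both endpoints above: optional SQLAlchemy conditions are accumulated in a list and applied with and_(*filters) to the count query and the page query alike. A condensed sketch of that pattern; the helper name is hypothetical and not part of the file.

# Condensed sketch of the filter-accumulation pattern (helper name is illustrative).
from sqlalchemy import and_, func, select

def build_audit_filters(tenant_id, start_date=None, end_date=None, severity=None):
    filters = [AuditLog.tenant_id == tenant_id]   # tenant scoping is always applied
    if start_date:
        filters.append(AuditLog.created_at >= start_date)
    if end_date:
        filters.append(AuditLog.created_at <= end_date)
    if severity:
        filters.append(AuditLog.severity == severity)
    return filters

# Both queries then reuse the same condition list:
#   select(func.count()).select_from(AuditLog).where(and_(*filters))
#   select(AuditLog).where(and_(*filters)).order_by(AuditLog.created_at.desc())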
108  services/forecasting/app/api/enterprise_forecasting.py    Normal file
@@ -0,0 +1,108 @@
"""
Enterprise forecasting API endpoints
"""

from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Optional
from datetime import date
import structlog

from app.services.enterprise_forecasting_service import EnterpriseForecastingService
from shared.auth.tenant_access import verify_tenant_permission_dep
from shared.clients import get_forecast_client, get_tenant_client
import shared.redis_utils
from app.core.config import settings

logger = structlog.get_logger()
router = APIRouter()

# Global Redis client
_redis_client = None


async def get_forecasting_redis_client():
    """Get or create Redis client"""
    global _redis_client
    try:
        if _redis_client is None:
            _redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
            logger.info("Redis client initialized for enterprise forecasting")
        return _redis_client
    except Exception as e:
        logger.warning("Failed to initialize Redis client, enterprise forecasting will work with limited functionality", error=str(e))
        return None


async def get_enterprise_forecasting_service(
    redis_client = Depends(get_forecasting_redis_client)
) -> EnterpriseForecastingService:
    """Dependency injection for EnterpriseForecastingService"""
    forecast_client = get_forecast_client(settings, "forecasting-service")
    tenant_client = get_tenant_client(settings, "forecasting-service")
    return EnterpriseForecastingService(
        forecast_client=forecast_client,
        tenant_client=tenant_client,
        redis_client=redis_client
    )


@router.get("/tenants/{tenant_id}/forecasting/enterprise/aggregated")
async def get_aggregated_forecast(
    tenant_id: str,
    start_date: date = Query(..., description="Start date for forecast aggregation"),
    end_date: date = Query(..., description="End date for forecast aggregation"),
    product_id: Optional[str] = Query(None, description="Optional product ID to filter by"),
    enterprise_forecasting_service: EnterpriseForecastingService = Depends(get_enterprise_forecasting_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get aggregated forecasts across parent and child tenants
    """
    try:
        # Check if this tenant is a parent tenant
        tenant_info = await enterprise_forecasting_service.tenant_client.get_tenant(tenant_id)
        if tenant_info.get('tenant_type') != 'parent':
            raise HTTPException(
                status_code=403,
                detail="Only parent tenants can access aggregated enterprise forecasts"
            )

        result = await enterprise_forecasting_service.get_aggregated_forecast(
            parent_tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            product_id=product_id
        )
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get aggregated forecast: {str(e)}")


@router.get("/tenants/{tenant_id}/forecasting/enterprise/network-performance")
async def get_network_performance_metrics(
    tenant_id: str,
    start_date: date = Query(..., description="Start date for metrics"),
    end_date: date = Query(..., description="End date for metrics"),
    enterprise_forecasting_service: EnterpriseForecastingService = Depends(get_enterprise_forecasting_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get aggregated performance metrics across tenant network
    """
    try:
        # Check if this tenant is a parent tenant
        tenant_info = await enterprise_forecasting_service.tenant_client.get_tenant(tenant_id)
        if tenant_info.get('tenant_type') != 'parent':
            raise HTTPException(
                status_code=403,
                detail="Only parent tenants can access network performance metrics"
            )

        result = await enterprise_forecasting_service.get_network_performance_metrics(
            parent_tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date
        )
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get network performance: {str(e)}")
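Usage sketch (illustrative, not part of this file): a parent-tenant client call against the aggregated endpoint above; the base URL is an assumption. Note that the broad except Exception in these handlers re-wraps the 403 raised for non-parent tenants as a 500.

# Minimal client sketch using httpx; host is assumed, the path is taken from the route above.
import httpx

async def fetch_aggregated_forecast(parent_tenant_id: str, token: str) -> dict:
    async with httpx.AsyncClient(base_url="http://forecasting-service:8000") as client:
        resp = await client.get(
            f"/tenants/{parent_tenant_id}/forecasting/enterprise/aggregated",
            params={"start_date": "2025-02-01", "end_date": "2025-02-28"},
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        return resp.json()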
417  services/forecasting/app/api/forecast_feedback.py    Normal file
@@ -0,0 +1,417 @@
# services/forecasting/app/api/forecast_feedback.py
"""
Forecast Feedback API - Endpoints for collecting and analyzing forecast feedback
"""

import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Body
from typing import List, Optional, Dict, Any
from datetime import date, datetime
import uuid
import enum
from pydantic import BaseModel, Field

from app.services.forecast_feedback_service import ForecastFeedbackService
from shared.database.base import create_database_manager
from app.core.config import settings
from shared.routing import RouteBuilder
from shared.auth.tenant_access import verify_tenant_permission_dep

route_builder = RouteBuilder('forecasting')
logger = structlog.get_logger()
router = APIRouter(tags=["forecast-feedback"])


# Enums for feedback types
class FeedbackType(str, enum.Enum):
    """Type of feedback on forecast accuracy"""
    TOO_HIGH = "too_high"
    TOO_LOW = "too_low"
    ACCURATE = "accurate"
    UNCERTAIN = "uncertain"


class FeedbackConfidence(str, enum.Enum):
    """Confidence level of the feedback provider"""
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"


# Pydantic models
from pydantic import BaseModel, Field

class ForecastFeedbackRequest(BaseModel):
    """Request model for submitting forecast feedback"""
    feedback_type: FeedbackType = Field(..., description="Type of feedback on forecast accuracy")
    confidence: FeedbackConfidence = Field(..., description="Confidence level of the feedback provider")
    actual_value: Optional[float] = Field(None, description="Actual observed value")
    notes: Optional[str] = Field(None, description="Additional notes about the feedback")
    feedback_data: Optional[Dict[str, Any]] = Field(None, description="Additional feedback data")


class ForecastFeedbackResponse(BaseModel):
    """Response model for forecast feedback"""
    feedback_id: str = Field(..., description="Unique feedback ID")
    forecast_id: str = Field(..., description="Forecast ID this feedback relates to")
    tenant_id: str = Field(..., description="Tenant ID")
    feedback_type: FeedbackType = Field(..., description="Type of feedback")
    confidence: FeedbackConfidence = Field(..., description="Confidence level")
    actual_value: Optional[float] = Field(None, description="Actual value observed")
    notes: Optional[str] = Field(None, description="Feedback notes")
    feedback_data: Dict[str, Any] = Field(..., description="Additional feedback data")
    created_at: datetime = Field(..., description="When feedback was created")
    created_by: Optional[str] = Field(None, description="Who created the feedback")


class ForecastAccuracyMetrics(BaseModel):
    """Accuracy metrics for a forecast"""
    forecast_id: str = Field(..., description="Forecast ID")
    total_feedback_count: int = Field(..., description="Total feedback received")
    accuracy_score: float = Field(..., description="Calculated accuracy score (0-100)")
    feedback_distribution: Dict[str, int] = Field(..., description="Distribution of feedback types")
    average_confidence: float = Field(..., description="Average confidence score")
    last_feedback_date: Optional[datetime] = Field(None, description="Most recent feedback date")


class ForecasterPerformanceMetrics(BaseModel):
    """Performance metrics for the forecasting system"""
    overall_accuracy: float = Field(..., description="Overall system accuracy score")
    total_forecasts_with_feedback: int = Field(..., description="Total forecasts with feedback")
    accuracy_by_product: Dict[str, float] = Field(..., description="Accuracy by product type")
    accuracy_trend: str = Field(..., description="Trend direction: improving, declining, stable")
    improvement_suggestions: List[str] = Field(..., description="AI-generated improvement suggestions")


def get_forecast_feedback_service():
    """Dependency injection for ForecastFeedbackService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    return ForecastFeedbackService(database_manager)


@router.post(
    route_builder.build_nested_resource_route("forecasts", "forecast_id", "feedback"),
    response_model=ForecastFeedbackResponse,
    status_code=status.HTTP_201_CREATED
)
async def submit_forecast_feedback(
    tenant_id: str = Path(..., description="Tenant ID"),
    forecast_id: str = Path(..., description="Forecast ID"),
    feedback_request: ForecastFeedbackRequest = Body(...),
    forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Submit feedback on forecast accuracy

    Allows users to provide feedback on whether forecasts were accurate, too high, or too low.
    This feedback is used to improve future forecast accuracy through continuous learning.
    """
    try:
        logger.info("Submitting forecast feedback",
                    tenant_id=tenant_id, forecast_id=forecast_id,
                    feedback_type=feedback_request.feedback_type)

        # Validate forecast exists
        forecast_exists = await forecast_feedback_service.forecast_exists(tenant_id, forecast_id)
        if not forecast_exists:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Forecast not found"
            )

        # Submit feedback
        feedback = await forecast_feedback_service.submit_feedback(
            tenant_id=tenant_id,
            forecast_id=forecast_id,
            feedback_type=feedback_request.feedback_type,
            confidence=feedback_request.confidence,
            actual_value=feedback_request.actual_value,
            notes=feedback_request.notes,
            feedback_data=feedback_request.feedback_data
        )

        return {
            'feedback_id': str(feedback.feedback_id),
            'forecast_id': str(feedback.forecast_id),
            'tenant_id': feedback.tenant_id,
            'feedback_type': feedback.feedback_type,
            'confidence': feedback.confidence,
            'actual_value': feedback.actual_value,
            'notes': feedback.notes,
            'feedback_data': feedback.feedback_data or {},
            'created_at': feedback.created_at,
            'created_by': feedback.created_by
        }

    except HTTPException:
        raise
    except ValueError as e:
        logger.error("Invalid forecast ID", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid forecast ID format"
        )
    except Exception as e:
        logger.error("Failed to submit forecast feedback", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to submit feedback"
        )


@router.get(
    route_builder.build_nested_resource_route("forecasts", "forecast_id", "feedback"),
    response_model=List[ForecastFeedbackResponse]
)
async def get_forecast_feedback(
    tenant_id: str = Path(..., description="Tenant ID"),
    forecast_id: str = Path(..., description="Forecast ID"),
    limit: int = Query(50, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get all feedback for a specific forecast

    Retrieves historical feedback submissions for analysis and auditing.
    """
    try:
        logger.info("Getting forecast feedback", tenant_id=tenant_id, forecast_id=forecast_id)

        feedback_list = await forecast_feedback_service.get_feedback_for_forecast(
            tenant_id=tenant_id,
            forecast_id=forecast_id,
            limit=limit,
            offset=offset
        )

        return [
            ForecastFeedbackResponse(
                feedback_id=str(f.feedback_id),
                forecast_id=str(f.forecast_id),
                tenant_id=f.tenant_id,
                feedback_type=f.feedback_type,
                confidence=f.confidence,
                actual_value=f.actual_value,
                notes=f.notes,
                feedback_data=f.feedback_data or {},
                created_at=f.created_at,
                created_by=f.created_by
            ) for f in feedback_list
        ]

    except Exception as e:
        logger.error("Failed to get forecast feedback", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve feedback"
        )


@router.get(
    route_builder.build_nested_resource_route("forecasts", "forecast_id", "accuracy"),
    response_model=ForecastAccuracyMetrics
)
async def get_forecast_accuracy_metrics(
    tenant_id: str = Path(..., description="Tenant ID"),
    forecast_id: str = Path(..., description="Forecast ID"),
    forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get accuracy metrics for a specific forecast

    Calculates accuracy scores based on feedback and actual vs predicted values.
    """
    try:
        logger.info("Getting forecast accuracy metrics", tenant_id=tenant_id, forecast_id=forecast_id)

        metrics = await forecast_feedback_service.calculate_accuracy_metrics(
            tenant_id=tenant_id,
            forecast_id=forecast_id
        )

        if not metrics:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No accuracy metrics available for this forecast"
            )

        return {
            'forecast_id': metrics.forecast_id,
            'total_feedback_count': metrics.total_feedback_count,
            'accuracy_score': metrics.accuracy_score,
            'feedback_distribution': metrics.feedback_distribution,
            'average_confidence': metrics.average_confidence,
            'last_feedback_date': metrics.last_feedback_date
        }

    except Exception as e:
        logger.error("Failed to get forecast accuracy metrics", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to calculate accuracy metrics"
        )


@router.get(
    route_builder.build_base_route("forecasts", "accuracy-summary"),
    response_model=ForecasterPerformanceMetrics
)
async def get_forecaster_performance_summary(
    tenant_id: str = Path(..., description="Tenant ID"),
    start_date: Optional[date] = Query(None, description="Start date filter"),
    end_date: Optional[date] = Query(None, description="End date filter"),
    product_id: Optional[str] = Query(None, description="Filter by product ID"),
    forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get overall forecaster performance summary

    Aggregates accuracy metrics across all forecasts to assess overall system performance
    and identify areas for improvement.
    """
    try:
        logger.info("Getting forecaster performance summary", tenant_id=tenant_id)

        metrics = await forecast_feedback_service.calculate_performance_summary(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            product_id=product_id
        )

        return {
            'overall_accuracy': metrics.overall_accuracy,
            'total_forecasts_with_feedback': metrics.total_forecasts_with_feedback,
            'accuracy_by_product': metrics.accuracy_by_product,
            'accuracy_trend': metrics.accuracy_trend,
            'improvement_suggestions': metrics.improvement_suggestions
        }

    except Exception as e:
        logger.error("Failed to get forecaster performance summary", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to calculate performance summary"
        )


@router.get(
    route_builder.build_base_route("forecasts", "feedback-trends")
)
async def get_feedback_trends(
    tenant_id: str = Path(..., description="Tenant ID"),
    days: int = Query(30, ge=7, le=365, description="Number of days to analyze"),
    product_id: Optional[str] = Query(None, description="Filter by product ID"),
    forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get feedback trends over time

    Analyzes how forecast accuracy and feedback patterns change over time.
    """
    try:
        logger.info("Getting feedback trends", tenant_id=tenant_id, days=days)

        trends = await forecast_feedback_service.get_feedback_trends(
            tenant_id=tenant_id,
            days=days,
            product_id=product_id
        )

        return {
            'success': True,
            'trends': trends,
            'period': f'Last {days} days'
        }

    except Exception as e:
        logger.error("Failed to get feedback trends", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve feedback trends"
        )


@router.post(
    route_builder.build_resource_action_route("forecasts", "forecast_id", "retrain")
)
async def trigger_retraining_from_feedback(
    tenant_id: str = Path(..., description="Tenant ID"),
    forecast_id: str = Path(..., description="Forecast ID"),
    forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Trigger model retraining based on feedback

    Initiates a retraining job using recent feedback to improve forecast accuracy.
    """
    try:
        logger.info("Triggering retraining from feedback", tenant_id=tenant_id, forecast_id=forecast_id)

        result = await forecast_feedback_service.trigger_retraining_from_feedback(
            tenant_id=tenant_id,
            forecast_id=forecast_id
        )

        return {
            'success': True,
            'message': 'Retraining job initiated successfully',
            'job_id': result.job_id,
            'forecasts_included': result.forecasts_included,
            'feedback_samples_used': result.feedback_samples_used
        }

    except Exception as e:
        logger.error("Failed to trigger retraining", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to initiate retraining"
        )


@router.get(
    route_builder.build_resource_action_route("forecasts", "forecast_id", "suggestions")
)
async def get_improvement_suggestions(
    tenant_id: str = Path(..., description="Tenant ID"),
    forecast_id: str = Path(..., description="Forecast ID"),
    forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get AI-generated improvement suggestions for a forecast

    Analyzes feedback patterns and suggests specific improvements for forecast accuracy.
    """
    try:
        logger.info("Getting improvement suggestions", tenant_id=tenant_id, forecast_id=forecast_id)

        suggestions = await forecast_feedback_service.get_improvement_suggestions(
            tenant_id=tenant_id,
            forecast_id=forecast_id
        )

        return {
            'success': True,
            'forecast_id': forecast_id,
            'suggestions': suggestions,
            'confidence_scores': [s.get('confidence', 0.8) for s in suggestions]
        }

    except Exception as e:
        logger.error("Failed to get improvement suggestions", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to generate suggestions"
        )


# Import datetime at runtime to avoid circular imports
from datetime import datetime, timedelta
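Example payload (illustrative, not part of this file) for the feedback submission endpoint above; field names mirror ForecastFeedbackRequest, and the path shape shown in the comment is an assumption since RouteBuilder constructs the real route.

# Example body for submit_forecast_feedback; values are illustrative.
feedback_payload = {
    "feedback_type": "too_high",    # one of: too_high, too_low, accurate, uncertain
    "confidence": "high",           # one of: low, medium, high
    "actual_value": 42.0,           # observed demand, optional
    "notes": "Half the batch was left over; forecast overshot weekend demand.",
    "feedback_data": {"source": "store-manager-app"},  # free-form extra context
}
# POSTed to the nested feedback route, e.g. (assumed path shape):
# POST /tenants/{tenant_id}/forecasting/forecasts/{forecast_id}/feedback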
1038 services/forecasting/app/api/forecasting_operations.py    Normal file
File diff suppressed because it is too large.
145  services/forecasting/app/api/forecasts.py    Normal file
@@ -0,0 +1,145 @@
# services/forecasting/app/api/forecasts.py
"""
Forecasts API - Atomic CRUD operations on Forecast model
"""

import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path
from typing import List, Optional
from datetime import date, datetime
import uuid

from app.services.forecasting_service import EnhancedForecastingService
from app.schemas.forecasts import ForecastResponse
from shared.database.base import create_database_manager
from app.core.config import settings
from shared.routing import RouteBuilder

route_builder = RouteBuilder('forecasting')
logger = structlog.get_logger()
router = APIRouter(tags=["forecasts"])


def get_enhanced_forecasting_service():
    """Dependency injection for EnhancedForecastingService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    return EnhancedForecastingService(database_manager)


@router.get(
    route_builder.build_base_route("forecasts"),
    response_model=List[ForecastResponse]
)
async def list_forecasts(
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: Optional[str] = Query(None, description="Filter by product ID"),
    start_date: Optional[date] = Query(None, description="Start date filter"),
    end_date: Optional[date] = Query(None, description="End date filter"),
    limit: int = Query(50, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """List forecasts with optional filters"""
    try:
        logger.info("Listing forecasts", tenant_id=tenant_id)

        forecasts = await enhanced_forecasting_service.list_forecasts(
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
            offset=offset
        )

        return forecasts

    except Exception as e:
        logger.error("Failed to list forecasts", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve forecasts"
        )


@router.get(
    route_builder.build_resource_detail_route("forecasts", "forecast_id"),
    response_model=ForecastResponse
)
async def get_forecast(
    tenant_id: str = Path(..., description="Tenant ID"),
    forecast_id: str = Path(..., description="Forecast ID"),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Get a specific forecast by ID"""
    try:
        logger.info("Getting forecast", tenant_id=tenant_id, forecast_id=forecast_id)

        forecast = await enhanced_forecasting_service.get_forecast(
            tenant_id=tenant_id,
            forecast_id=uuid.UUID(forecast_id)
        )

        if not forecast:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Forecast not found"
            )

        return forecast

    except HTTPException:
        raise
    except ValueError as e:
        logger.error("Invalid forecast ID", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid forecast ID format"
        )
    except Exception as e:
        logger.error("Failed to get forecast", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve forecast"
        )


@router.delete(
    route_builder.build_resource_detail_route("forecasts", "forecast_id")
)
async def delete_forecast(
    tenant_id: str = Path(..., description="Tenant ID"),
    forecast_id: str = Path(..., description="Forecast ID"),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Delete a specific forecast"""
    try:
        logger.info("Deleting forecast", tenant_id=tenant_id, forecast_id=forecast_id)

        success = await enhanced_forecasting_service.delete_forecast(
            tenant_id=tenant_id,
            forecast_id=uuid.UUID(forecast_id)
        )

        if not success:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Forecast not found"
            )

        return {"message": "Forecast deleted successfully"}

    except HTTPException:
        raise
    except ValueError as e:
        logger.error("Invalid forecast ID", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid forecast ID format"
        )
    except Exception as e:
        logger.error("Failed to delete forecast", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete forecast"
        )
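Usage sketch (illustrative, not part of this file): paging through list_forecasts with the limit/offset parameters defined above. The host and concrete path are assumptions; RouteBuilder produces the real route.

# Minimal pagination sketch using httpx; path shape and host are assumed.
import httpx

async def list_tenant_forecasts(tenant_id: str, token: str, page_size: int = 50) -> list:
    forecasts = []
    offset = 0
    async with httpx.AsyncClient(base_url="http://forecasting-service:8000") as client:
        while True:
            resp = await client.get(
                f"/tenants/{tenant_id}/forecasting/forecasts",
                params={"limit": page_size, "offset": offset},
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            page = resp.json()
            forecasts.extend(page)
            if len(page) < page_size:   # short page means no more results
                break
            offset += page_size
    return forecasts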
304  services/forecasting/app/api/historical_validation.py    Normal file
@@ -0,0 +1,304 @@
# ================================================================
# services/forecasting/app/api/historical_validation.py
# ================================================================
"""
Historical Validation API - Backfill validation for late-arriving sales data
"""

from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
from typing import Dict, Any, List, Optional
from uuid import UUID
from datetime import date
import structlog

from pydantic import BaseModel, Field
from app.services.historical_validation_service import HistoricalValidationService
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role
from shared.routing import RouteBuilder
from app.core.database import get_db
from sqlalchemy.ext.asyncio import AsyncSession

route_builder = RouteBuilder('forecasting')
router = APIRouter(tags=["historical-validation"])
logger = structlog.get_logger()


# ================================================================
# Request/Response Schemas
# ================================================================

class DetectGapsRequest(BaseModel):
    """Request model for gap detection"""
    lookback_days: int = Field(default=90, ge=1, le=365, description="Days to look back")


class BackfillRequest(BaseModel):
    """Request model for manual backfill"""
    start_date: date = Field(..., description="Start date for backfill")
    end_date: date = Field(..., description="End date for backfill")


class SalesDataUpdateRequest(BaseModel):
    """Request model for registering sales data update"""
    start_date: date = Field(..., description="Start date of updated data")
    end_date: date = Field(..., description="End date of updated data")
    records_affected: int = Field(..., ge=0, description="Number of records affected")
    update_source: str = Field(default="import", description="Source of update")
    import_job_id: Optional[str] = Field(None, description="Import job ID if applicable")
    auto_trigger_validation: bool = Field(default=True, description="Auto-trigger validation")


class AutoBackfillRequest(BaseModel):
    """Request model for automatic backfill"""
    lookback_days: int = Field(default=90, ge=1, le=365, description="Days to look back")
    max_gaps_to_process: int = Field(default=10, ge=1, le=50, description="Max gaps to process")


# ================================================================
# Endpoints
# ================================================================

@router.post(
    route_builder.build_base_route("validation/detect-gaps"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def detect_validation_gaps(
    request: DetectGapsRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Detect date ranges where forecasts exist but haven't been validated yet

    Returns list of gap periods that need validation backfill.
    """
    try:
        logger.info(
            "Detecting validation gaps",
            tenant_id=tenant_id,
            lookback_days=request.lookback_days,
            user_id=current_user.get("user_id")
        )

        service = HistoricalValidationService(db)

        gaps = await service.detect_validation_gaps(
            tenant_id=tenant_id,
            lookback_days=request.lookback_days
        )

        return {
            "gaps_found": len(gaps),
            "lookback_days": request.lookback_days,
            "gaps": [
                {
                    "start_date": gap["start_date"].isoformat(),
                    "end_date": gap["end_date"].isoformat(),
                    "days_count": gap["days_count"]
                }
                for gap in gaps
            ]
        }

    except Exception as e:
        logger.error(
            "Failed to detect validation gaps",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to detect validation gaps: {str(e)}"
        )


@router.post(
    route_builder.build_base_route("validation/backfill"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def backfill_validation(
    request: BackfillRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Manually trigger validation backfill for a specific date range

    Validates forecasts against sales data for historical periods.
    """
    try:
        logger.info(
            "Manual validation backfill requested",
            tenant_id=tenant_id,
            start_date=request.start_date.isoformat(),
            end_date=request.end_date.isoformat(),
            user_id=current_user.get("user_id")
        )

        service = HistoricalValidationService(db)

        result = await service.backfill_validation(
            tenant_id=tenant_id,
            start_date=request.start_date,
            end_date=request.end_date,
            triggered_by="manual"
        )

        return result

    except Exception as e:
        logger.error(
            "Failed to backfill validation",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to backfill validation: {str(e)}"
        )


@router.post(
    route_builder.build_base_route("validation/auto-backfill"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def auto_backfill_validation_gaps(
    request: AutoBackfillRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Automatically detect and backfill validation gaps

    Finds all date ranges with missing validations and processes them.
    """
    try:
        logger.info(
            "Auto backfill requested",
            tenant_id=tenant_id,
            lookback_days=request.lookback_days,
            max_gaps=request.max_gaps_to_process,
            user_id=current_user.get("user_id")
        )

        service = HistoricalValidationService(db)

        result = await service.auto_backfill_gaps(
            tenant_id=tenant_id,
            lookback_days=request.lookback_days,
            max_gaps_to_process=request.max_gaps_to_process
        )

        return result

    except Exception as e:
        logger.error(
            "Failed to auto backfill",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to auto backfill: {str(e)}"
        )


@router.post(
    route_builder.build_base_route("validation/register-sales-update"),
    status_code=status.HTTP_201_CREATED
)
@require_user_role(['admin', 'owner', 'member'])
async def register_sales_data_update(
    request: SalesDataUpdateRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Register a sales data update and optionally trigger validation

    Call this endpoint after importing historical sales data to automatically
    trigger validation for the affected date range.
    """
    try:
        logger.info(
            "Registering sales data update",
            tenant_id=tenant_id,
            date_range=f"{request.start_date} to {request.end_date}",
            records_affected=request.records_affected,
            user_id=current_user.get("user_id")
        )

        service = HistoricalValidationService(db)

        result = await service.register_sales_data_update(
            tenant_id=tenant_id,
            start_date=request.start_date,
            end_date=request.end_date,
            records_affected=request.records_affected,
            update_source=request.update_source,
            import_job_id=request.import_job_id,
            auto_trigger_validation=request.auto_trigger_validation
        )

        return result

    except Exception as e:
        logger.error(
            "Failed to register sales data update",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to register sales data update: {str(e)}"
        )


@router.get(
    route_builder.build_base_route("validation/pending"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_pending_validations(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get pending sales data updates awaiting validation

    Returns list of sales data updates that have been registered
    but not yet validated.
    """
    try:
        service = HistoricalValidationService(db)

        pending = await service.get_pending_validations(
            tenant_id=tenant_id,
            limit=limit
        )

        return {
            "pending_count": len(pending),
            "pending_validations": [record.to_dict() for record in pending]
        }

    except Exception as e:
        logger.error(
            "Failed to get pending validations",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get pending validations: {str(e)}"
        )
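Example payload (illustrative, not part of this file) for register_sales_data_update above; fields mirror SalesDataUpdateRequest, values and the job ID are made up, and the path in the comment is an assumption since RouteBuilder constructs the real route.

# Example body for register_sales_data_update; values are illustrative.
sales_update_payload = {
    "start_date": "2025-01-01",
    "end_date": "2025-01-07",
    "records_affected": 412,
    "update_source": "import",
    "import_job_id": "job-7f3a",        # illustrative ID
    "auto_trigger_validation": True,    # kick off validation backfill for the affected range
}
# POSTed after a historical sales import, e.g. (assumed path shape):
# POST /tenants/{tenant_id}/forecasting/validation/register-sales-update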
477
services/forecasting/app/api/internal_demo.py
Normal file
477
services/forecasting/app/api/internal_demo.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
Internal Demo Cloning API for Forecasting Service
|
||||
Service-to-service endpoint for cloning forecast data
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import structlog
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
|
||||
from shared.utils.demo_dates import adjust_date_for_demo, resolve_time_marker
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.forecasts import Forecast, PredictionBatch
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter(prefix="/internal/demo", tags=["internal"])
|
||||
|
||||
# Base demo tenant IDs
|
||||
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
|
||||
|
||||
|
||||
def parse_date_field(date_value, session_time: datetime, field_name: str = "date") -> Optional[datetime]:
|
||||
"""
|
||||
Parse date field, handling both ISO strings and BASE_TS markers.
|
||||
|
||||
Supports:
|
||||
- BASE_TS markers: "BASE_TS + 1h30m", "BASE_TS - 2d"
|
||||
- ISO 8601 strings: "2025-01-15T06:00:00Z"
|
||||
- None values (returns None)
|
||||
|
||||
Returns timezone-aware datetime or None.
|
||||
"""
|
||||
if not date_value:
|
||||
return None
|
||||
|
||||
# Check if it's a BASE_TS marker
|
||||
if isinstance(date_value, str) and date_value.startswith("BASE_TS"):
|
||||
try:
|
||||
return resolve_time_marker(date_value, session_time)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
f"Invalid BASE_TS marker in {field_name}",
|
||||
marker=date_value,
|
||||
error=str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
# Handle regular ISO date strings
|
||||
try:
|
||||
if isinstance(date_value, str):
|
||||
original_date = datetime.fromisoformat(date_value.replace('Z', '+00:00'))
|
||||
elif hasattr(date_value, 'isoformat'):
|
||||
original_date = date_value
|
||||
else:
|
||||
logger.warning(f"Unsupported date format in {field_name}", date_value=date_value)
|
||||
return None
|
||||
|
||||
return adjust_date_for_demo(original_date, session_time)
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(
|
||||
f"Invalid date format in {field_name}",
|
||||
date_value=date_value,
|
||||
error=str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def align_to_week_start(target_date: datetime) -> datetime:
|
||||
"""Align forecast date to Monday (start of week)"""
|
||||
if target_date:
|
||||
days_since_monday = target_date.weekday()
|
||||
return target_date - timedelta(days=days_since_monday)
|
||||
return target_date
|
||||
|
||||
|
||||
@router.post("/clone")
|
||||
async def clone_demo_data(
|
||||
base_tenant_id: str,
|
||||
virtual_tenant_id: str,
|
||||
demo_account_type: str,
|
||||
session_id: Optional[str] = None,
|
||||
session_created_at: Optional[str] = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Clone forecasting service data for a virtual demo tenant
|
||||
|
||||
This endpoint creates fresh demo data by:
|
||||
1. Loading seed data from JSON files
|
||||
2. Applying XOR-based ID transformation
|
||||
3. Adjusting dates relative to session creation time
|
||||
4. Creating records in the virtual tenant
|
||||
|
||||
Args:
|
||||
base_tenant_id: Template tenant UUID (for reference)
|
||||
virtual_tenant_id: Target virtual tenant UUID
|
||||
demo_account_type: Type of demo account
|
||||
session_id: Originating session ID for tracing
|
||||
session_created_at: Session creation timestamp for date adjustment
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary with cloning results
|
||||
|
||||
Raises:
|
||||
HTTPException: On validation or cloning errors
|
||||
"""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
# Validate UUIDs
|
||||
virtual_uuid = uuid.UUID(virtual_tenant_id)
|
||||
|
||||
# Parse session creation time for date adjustment
|
||||
if session_created_at:
|
||||
try:
|
||||
session_time = datetime.fromisoformat(session_created_at.replace('Z', '+00:00'))
|
||||
except (ValueError, AttributeError):
|
||||
session_time = start_time
|
||||
else:
|
||||
session_time = start_time
|
||||
|
||||
logger.info(
|
||||
"Starting forecasting data cloning with date adjustment",
|
||||
base_tenant_id=base_tenant_id,
|
||||
virtual_tenant_id=str(virtual_uuid),
|
||||
demo_account_type=demo_account_type,
|
||||
session_id=session_id,
|
||||
session_time=session_time.isoformat()
|
||||
)
|
||||
|
||||
# Load seed data using shared utility
|
||||
try:
|
||||
from shared.utils.seed_data_paths import get_seed_data_path
|
||||
|
||||
if demo_account_type == "enterprise":
|
||||
profile = "enterprise"
|
||||
else:
|
||||
profile = "professional"
|
||||
|
||||
json_file = get_seed_data_path(profile, "10-forecasting.json")
|
||||
|
||||
except ImportError:
|
||||
# Fallback to original path
|
||||
seed_data_dir = Path(__file__).parent.parent.parent.parent / "shared" / "demo" / "fixtures"
|
||||
if demo_account_type == "enterprise":
|
||||
json_file = seed_data_dir / "enterprise" / "parent" / "10-forecasting.json"
|
||||
else:
|
||||
json_file = seed_data_dir / "professional" / "10-forecasting.json"
|
||||
|
||||
if not json_file.exists():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Seed data file not found: {json_file}"
|
||||
)
|
||||
|
||||
# Load JSON data
|
||||
with open(json_file, 'r', encoding='utf-8') as f:
|
||||
seed_data = json.load(f)
|
||||
|
||||
# Check if data already exists for this virtual tenant (idempotency)
|
||||
existing_check = await db.execute(
|
||||
select(Forecast).where(Forecast.tenant_id == virtual_uuid).limit(1)
|
||||
)
|
||||
existing_forecast = existing_check.scalar_one_or_none()
|
||||
|
||||
if existing_forecast:
|
||||
logger.warning(
|
||||
"Demo data already exists, skipping clone",
|
||||
virtual_tenant_id=str(virtual_uuid)
|
||||
)
|
||||
return {
|
||||
"status": "skipped",
|
||||
"reason": "Data already exists",
|
||||
"records_cloned": 0
|
||||
}
|
||||
|
||||
# Track cloning statistics
|
||||
stats = {
|
||||
"forecasts": 0,
|
||||
"prediction_batches": 0
|
||||
}
|
||||
|
||||
# Transform and insert forecasts
|
||||
for forecast_data in seed_data.get('forecasts', []):
|
||||
# Transform ID using XOR
|
||||
from shared.utils.demo_id_transformer import transform_id
try:
    # Parse both UUIDs up front purely to validate their format;
    # transform_id performs the actual XOR mapping.
    uuid.UUID(forecast_data['id'])
    tenant_uuid = uuid.UUID(virtual_tenant_id)
    transformed_id = transform_id(forecast_data['id'], tenant_uuid)
|
||||
except ValueError as e:
|
||||
logger.error("Failed to parse UUIDs for ID transformation",
|
||||
forecast_id=forecast_data['id'],
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid UUID format in forecast data: {str(e)}"
|
||||
)
|
||||
|
||||
# Transform dates using the proper parse_date_field function
for date_field in ['forecast_date', 'created_at']:
    if date_field in forecast_data:
        original_value = forecast_data[date_field]
        try:
            parsed_date = parse_date_field(
                original_value,
                session_time,
                date_field
            )
            if parsed_date:
                forecast_data[date_field] = parsed_date
            else:
                # Parsing failed: fall back to session_time, but log the
                # original value (not the fallback we just wrote).
                forecast_data[date_field] = session_time
                logger.warning("Using fallback date for failed parsing",
                               date_field=date_field,
                               original_value=original_value)
        except Exception as e:
            logger.warning("Failed to parse date, using fallback",
                           date_field=date_field,
                           date_value=original_value,
                           error=str(e))
            forecast_data[date_field] = session_time
|
||||
|
||||
# Create forecast
|
||||
# Map product_id to inventory_product_id if needed
|
||||
inventory_product_id_str = forecast_data.get('inventory_product_id') or forecast_data.get('product_id')
|
||||
# Convert to UUID if it's a string
|
||||
if isinstance(inventory_product_id_str, str):
|
||||
inventory_product_id = uuid.UUID(inventory_product_id_str)
|
||||
else:
|
||||
inventory_product_id = inventory_product_id_str
|
||||
|
||||
# Map predicted_quantity to predicted_demand if needed
|
||||
predicted_demand = forecast_data.get('predicted_demand') or forecast_data.get('predicted_quantity')
|
||||
|
||||
# Set default location if not provided in seed data
|
||||
location = forecast_data.get('location') or "Main Bakery"
|
||||
|
||||
# Get or calculate forecast date
|
||||
forecast_date = forecast_data.get('forecast_date')
|
||||
if not forecast_date:
|
||||
forecast_date = session_time
|
||||
|
||||
# Calculate day_of_week from forecast_date if not provided
# day_of_week should be 0-6 (Monday=0, Sunday=6)
day_of_week = forecast_data.get('day_of_week')
if day_of_week is None and forecast_date:
    day_of_week = forecast_date.weekday()

# Calculate is_weekend from day_of_week only when the seed data does not
# supply it, so an explicit is_weekend value is never overwritten.
is_weekend = forecast_data.get('is_weekend')
if is_weekend is None:
    is_weekend = day_of_week >= 5 if day_of_week is not None else False  # Saturday=5, Sunday=6
|
||||
|
||||
new_forecast = Forecast(
|
||||
id=transformed_id,
|
||||
tenant_id=virtual_uuid,
|
||||
inventory_product_id=inventory_product_id,
|
||||
product_name=forecast_data.get('product_name'),
|
||||
location=location,
|
||||
forecast_date=forecast_date,
|
||||
created_at=forecast_data.get('created_at', session_time),
|
||||
predicted_demand=predicted_demand,
|
||||
confidence_lower=forecast_data.get('confidence_lower', max(0.0, float(predicted_demand or 0.0) * 0.8)),
|
||||
confidence_upper=forecast_data.get('confidence_upper', max(0.0, float(predicted_demand or 0.0) * 1.2)),
|
||||
confidence_level=forecast_data.get('confidence_level', 0.8),
|
||||
model_id=forecast_data.get('model_id') or 'default-fallback-model',
|
||||
model_version=forecast_data.get('model_version') or '1.0',
|
||||
algorithm=forecast_data.get('algorithm', 'prophet'),
|
||||
business_type=forecast_data.get('business_type', 'individual'),
|
||||
day_of_week=day_of_week,
|
||||
is_holiday=forecast_data.get('is_holiday', False),
|
||||
is_weekend=is_weekend,
|
||||
weather_temperature=forecast_data.get('weather_temperature'),
|
||||
weather_precipitation=forecast_data.get('weather_precipitation'),
|
||||
weather_description=forecast_data.get('weather_description'),
|
||||
traffic_volume=forecast_data.get('traffic_volume'),
|
||||
processing_time_ms=forecast_data.get('processing_time_ms'),
|
||||
features_used=forecast_data.get('features_used')
|
||||
)
|
||||
db.add(new_forecast)
|
||||
stats["forecasts"] += 1
|
||||
|
||||
# Transform and insert prediction batches
|
||||
for batch_data in seed_data.get('prediction_batches', []):
|
||||
# Transform ID using XOR
|
||||
from shared.utils.demo_id_transformer import transform_id
try:
    # Parse both UUIDs up front purely to validate their format;
    # transform_id performs the actual XOR mapping.
    uuid.UUID(batch_data['id'])
    tenant_uuid = uuid.UUID(virtual_tenant_id)
    transformed_id = transform_id(batch_data['id'], tenant_uuid)
|
||||
except ValueError as e:
|
||||
logger.error("Failed to parse UUIDs for ID transformation",
|
||||
batch_id=batch_data['id'],
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid UUID format in batch data: {str(e)}"
|
||||
)
|
||||
|
||||
# Create prediction batch
|
||||
# Handle field mapping: batch_id -> batch_name, total_forecasts -> total_products
|
||||
batch_name = batch_data.get('batch_name') or batch_data.get('batch_id') or f"Batch-{transformed_id}"
|
||||
total_products = batch_data.get('total_products') or batch_data.get('total_forecasts') or 0
|
||||
completed_products = batch_data.get('completed_products') or (total_products if (batch_data.get('status') or '').upper() == 'COMPLETED' else 0)
|
||||
|
||||
# Parse dates (handle created_at or prediction_date for requested_at)
|
||||
requested_at_raw = batch_data.get('requested_at') or batch_data.get('created_at') or batch_data.get('prediction_date')
|
||||
requested_at = parse_date_field(requested_at_raw, session_time, 'requested_at') if requested_at_raw else session_time
|
||||
|
||||
completed_at_raw = batch_data.get('completed_at')
|
||||
completed_at = parse_date_field(completed_at_raw, session_time, 'completed_at') if completed_at_raw else None
|
||||
|
||||
new_batch = PredictionBatch(
|
||||
id=transformed_id,
|
||||
tenant_id=virtual_uuid,
|
||||
batch_name=batch_name,
|
||||
requested_at=requested_at,
|
||||
completed_at=completed_at,
|
||||
status=batch_data.get('status', 'completed'),
|
||||
total_products=total_products,
|
||||
completed_products=completed_products,
|
||||
failed_products=batch_data.get('failed_products', 0),
|
||||
forecast_days=batch_data.get('forecast_days', 7),
|
||||
business_type=batch_data.get('business_type', 'individual'),
|
||||
error_message=batch_data.get('error_message'),
|
||||
processing_time_ms=batch_data.get('processing_time_ms'),
|
||||
cancelled_by=batch_data.get('cancelled_by')
|
||||
)
|
||||
db.add(new_batch)
|
||||
stats["prediction_batches"] += 1
|
||||
|
||||
# Commit all changes
|
||||
await db.commit()
|
||||
|
||||
total_records = sum(stats.values())
|
||||
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
|
||||
|
||||
logger.info(
|
||||
"Forecasting data cloned successfully",
|
||||
virtual_tenant_id=str(virtual_uuid),
|
||||
records_cloned=total_records,
|
||||
duration_ms=duration_ms,
|
||||
forecasts_cloned=stats["forecasts"],
|
||||
batches_cloned=stats["prediction_batches"]
|
||||
)
|
||||
|
||||
return {
|
||||
"service": "forecasting",
|
||||
"status": "completed",
|
||||
"records_cloned": total_records,
|
||||
"duration_ms": duration_ms,
|
||||
"details": {
|
||||
"forecasts": stats["forecasts"],
|
||||
"prediction_batches": stats["prediction_batches"],
|
||||
"virtual_tenant_id": str(virtual_uuid)
|
||||
}
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("Invalid UUID format", error=str(e), virtual_tenant_id=virtual_tenant_id)
|
||||
raise HTTPException(status_code=400, detail=f"Invalid UUID: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to clone forecasting data",
|
||||
error=str(e),
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# Rollback on error
|
||||
await db.rollback()
|
||||
|
||||
return {
|
||||
"service": "forecasting",
|
||||
"status": "failed",
|
||||
"records_cloned": 0,
|
||||
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
|
||||
"error": str(e)
|
||||
}
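
# ----------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of this service): one way the
# XOR-based ID transformation used by the clone endpoint above could work.
# The real helper lives in shared.utils.demo_id_transformer.transform_id and
# may differ; the name below is an assumption for illustration only. The
# referenced parse_date_field helper similarly shifts seed dates relative to
# the session creation time.
def _transform_id_sketch(seed_id: str, virtual_tenant_id: uuid.UUID) -> uuid.UUID:
    """Deterministically map a seed UUID into a virtual tenant's ID space.

    XOR-ing the 128-bit seed UUID with the tenant UUID is reversible and
    repeatable, so re-running the clone yields the same transformed IDs.
    """
    seed_int = uuid.UUID(seed_id).int
    return uuid.UUID(int=seed_int ^ virtual_tenant_id.int)
# ----------------------------------------------------------------------------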
|
||||
|
||||
|
||||
@router.get("/clone/health")
|
||||
async def clone_health_check():
|
||||
"""
|
||||
Health check for internal cloning endpoint
|
||||
Used by orchestrator to verify service availability
|
||||
"""
|
||||
return {
|
||||
"service": "forecasting",
|
||||
"clone_endpoint": "available",
|
||||
"version": "2.0.0"
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/tenant/{virtual_tenant_id}")
|
||||
async def delete_demo_tenant_data(
|
||||
virtual_tenant_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Delete all demo data for a virtual tenant.
|
||||
This endpoint is idempotent - safe to call multiple times.
|
||||
"""
|
||||
from sqlalchemy import delete
|
||||
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
records_deleted = {
|
||||
"forecasts": 0,
|
||||
"prediction_batches": 0,
|
||||
"total": 0
|
||||
}
|
||||
|
||||
try:
|
||||
# Delete in reverse dependency order
|
||||
|
||||
# 1. Delete prediction batches
|
||||
result = await db.execute(
|
||||
delete(PredictionBatch)
|
||||
.where(PredictionBatch.tenant_id == virtual_tenant_id)
|
||||
)
|
||||
records_deleted["prediction_batches"] = result.rowcount
|
||||
|
||||
# 2. Delete forecasts
|
||||
result = await db.execute(
|
||||
delete(Forecast)
|
||||
.where(Forecast.tenant_id == virtual_tenant_id)
|
||||
)
|
||||
records_deleted["forecasts"] = result.rowcount
|
||||
|
||||
records_deleted["total"] = records_deleted["forecasts"] + records_deleted["prediction_batches"]
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(
|
||||
"demo_data_deleted",
|
||||
service="forecasting",
|
||||
virtual_tenant_id=str(virtual_tenant_id),
|
||||
records_deleted=records_deleted
|
||||
)
|
||||
|
||||
return {
|
||||
"service": "forecasting",
|
||||
"status": "deleted",
|
||||
"virtual_tenant_id": str(virtual_tenant_id),
|
||||
"records_deleted": records_deleted,
|
||||
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
"demo_data_deletion_failed",
|
||||
service="forecasting",
|
||||
virtual_tenant_id=str(virtual_tenant_id),
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete demo data: {str(e)}"
|
||||
)
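
# ----------------------------------------------------------------------------
# Illustrative usage sketch (hypothetical, not called anywhere): how an
# orchestrator might exercise the idempotent cleanup endpoint above. The base
# URL and router prefix are assumptions; only the "/tenant/{virtual_tenant_id}"
# suffix is defined in this module. Calling it twice is safe: the second call
# simply reports zero rows deleted.
async def _cleanup_demo_tenant_sketch(base_url: str, virtual_tenant_id: str) -> dict:
    import httpx  # local import keeps the sketch self-contained

    async with httpx.AsyncClient() as client:
        resp = await client.delete(f"{base_url}/tenant/{virtual_tenant_id}")
        resp.raise_for_status()
        return resp.json()  # e.g. {"status": "deleted", "records_deleted": {...}}
# ----------------------------------------------------------------------------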
|
||||
959
services/forecasting/app/api/ml_insights.py
Normal file
959
services/forecasting/app/api/ml_insights.py
Normal file
@@ -0,0 +1,959 @@
|
||||
"""
|
||||
ML Insights API Endpoints for Forecasting Service
|
||||
|
||||
Provides endpoints to trigger ML insight generation for:
|
||||
- Dynamic business rules learning
|
||||
- Demand pattern analysis
|
||||
- Seasonal trend detection
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
import pandas as pd
|
||||
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/v1/tenants/{tenant_id}/forecasting/ml/insights",
|
||||
tags=["ML Insights"]
|
||||
)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# REQUEST/RESPONSE SCHEMAS
|
||||
# ================================================================
|
||||
|
||||
class RulesGenerationRequest(BaseModel):
|
||||
"""Request schema for rules generation"""
|
||||
product_ids: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Specific product IDs to analyze. If None, analyzes all products"
|
||||
)
|
||||
lookback_days: int = Field(
|
||||
90,
|
||||
description="Days of historical data to analyze",
|
||||
ge=30,
|
||||
le=365
|
||||
)
|
||||
min_samples: int = Field(
|
||||
10,
|
||||
description="Minimum samples required for rule learning",
|
||||
ge=5,
|
||||
le=100
|
||||
)
|
||||
|
||||
|
||||
class RulesGenerationResponse(BaseModel):
|
||||
"""Response schema for rules generation"""
|
||||
success: bool
|
||||
message: str
|
||||
tenant_id: str
|
||||
products_analyzed: int
|
||||
total_insights_generated: int
|
||||
total_insights_posted: int
|
||||
insights_by_product: dict
|
||||
errors: List[str] = []
|
||||
|
||||
|
||||
class DemandAnalysisRequest(BaseModel):
|
||||
"""Request schema for demand analysis"""
|
||||
product_ids: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Specific product IDs to analyze. If None, analyzes all products"
|
||||
)
|
||||
lookback_days: int = Field(
|
||||
90,
|
||||
description="Days of historical data to analyze",
|
||||
ge=30,
|
||||
le=365
|
||||
)
|
||||
forecast_horizon_days: int = Field(
|
||||
30,
|
||||
description="Days to forecast ahead",
|
||||
ge=7,
|
||||
le=90
|
||||
)
|
||||
|
||||
|
||||
class DemandAnalysisResponse(BaseModel):
|
||||
"""Response schema for demand analysis"""
|
||||
success: bool
|
||||
message: str
|
||||
tenant_id: str
|
||||
products_analyzed: int
|
||||
total_insights_generated: int
|
||||
total_insights_posted: int
|
||||
insights_by_product: dict
|
||||
errors: List[str] = []
|
||||
|
||||
|
||||
class BusinessRulesAnalysisRequest(BaseModel):
|
||||
"""Request schema for business rules analysis"""
|
||||
product_ids: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Specific product IDs to analyze. If None, analyzes all products"
|
||||
)
|
||||
lookback_days: int = Field(
|
||||
90,
|
||||
description="Days of historical data to analyze",
|
||||
ge=30,
|
||||
le=365
|
||||
)
|
||||
min_samples: int = Field(
|
||||
10,
|
||||
description="Minimum samples required for rule analysis",
|
||||
ge=5,
|
||||
le=100
|
||||
)
|
||||
|
||||
|
||||
class BusinessRulesAnalysisResponse(BaseModel):
|
||||
"""Response schema for business rules analysis"""
|
||||
success: bool
|
||||
message: str
|
||||
tenant_id: str
|
||||
products_analyzed: int
|
||||
total_insights_generated: int
|
||||
total_insights_posted: int
|
||||
insights_by_product: dict
|
||||
errors: List[str] = []
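
# ----------------------------------------------------------------------------
# Example payload (illustration only) for the request schemas above. With
# product_ids omitted, the endpoints analyze all products (capped at 10),
# using 90 days of history and requiring at least 10 samples per product.
EXAMPLE_RULES_GENERATION_REQUEST = {
    "product_ids": None,   # or a list of inventory product UUIDs as strings
    "lookback_days": 90,   # must be between 30 and 365
    "min_samples": 10,     # must be between 5 and 100
}
# ----------------------------------------------------------------------------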
|
||||
|
||||
|
||||
# ================================================================
|
||||
# API ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.post("/generate-rules", response_model=RulesGenerationResponse)
|
||||
async def trigger_rules_generation(
|
||||
tenant_id: str,
|
||||
request_data: RulesGenerationRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Trigger dynamic business rules learning from historical sales data.
|
||||
|
||||
This endpoint:
|
||||
1. Fetches historical sales data for specified products
|
||||
2. Runs the RulesOrchestrator to learn patterns
|
||||
3. Generates insights about optimal business rules
|
||||
4. Posts insights to AI Insights Service
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
request_data: Rules generation parameters
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
RulesGenerationResponse with generation results
|
||||
"""
|
||||
logger.info(
|
||||
"ML insights rules generation requested",
|
||||
tenant_id=tenant_id,
|
||||
product_ids=request_data.product_ids,
|
||||
lookback_days=request_data.lookback_days
|
||||
)
|
||||
|
||||
try:
|
||||
# Import ML orchestrator and clients
|
||||
from app.ml.rules_orchestrator import RulesOrchestrator
|
||||
from shared.clients.sales_client import SalesServiceClient
|
||||
from shared.clients.inventory_client import InventoryServiceClient
|
||||
from app.core.config import settings
|
||||
|
||||
# Get event publisher from app state
|
||||
event_publisher = getattr(request.app.state, 'event_publisher', None)
|
||||
|
||||
# Initialize orchestrator and clients
|
||||
orchestrator = RulesOrchestrator(event_publisher=event_publisher)
|
||||
inventory_client = InventoryServiceClient(settings)
|
||||
|
||||
# Get products to analyze from inventory service via API
|
||||
if request_data.product_ids:
|
||||
# Fetch specific products
|
||||
products = []
|
||||
for product_id in request_data.product_ids:
|
||||
product = await inventory_client.get_ingredient_by_id(
|
||||
ingredient_id=UUID(product_id),
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
if product:
|
||||
products.append(product)
|
||||
else:
|
||||
# Fetch all products for tenant (limit to 10)
|
||||
all_products = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
|
||||
products = all_products[:10] # Limit to prevent timeout
|
||||
|
||||
if not products:
|
||||
return RulesGenerationResponse(
|
||||
success=False,
|
||||
message="No products found for analysis",
|
||||
tenant_id=tenant_id,
|
||||
products_analyzed=0,
|
||||
total_insights_generated=0,
|
||||
total_insights_posted=0,
|
||||
insights_by_product={},
|
||||
errors=["No products found"]
|
||||
)
|
||||
|
||||
# Initialize sales client to fetch historical data
|
||||
sales_client = SalesServiceClient(config=settings, calling_service_name="forecasting")
|
||||
|
||||
# Calculate date range
|
||||
end_date = datetime.utcnow()
|
||||
start_date = end_date - timedelta(days=request_data.lookback_days)
|
||||
|
||||
# Process each product
|
||||
total_insights_generated = 0
|
||||
total_insights_posted = 0
|
||||
insights_by_product = {}
|
||||
errors = []
|
||||
|
||||
for product in products:
|
||||
try:
|
||||
product_id = str(product['id'])
|
||||
product_name = product.get('name', 'Unknown')
|
||||
logger.info(f"Analyzing product {product_name} ({product_id})")
|
||||
|
||||
# Fetch sales data for product
|
||||
sales_data = await sales_client.get_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
product_id=product_id,
|
||||
start_date=start_date.strftime('%Y-%m-%d'),
|
||||
end_date=end_date.strftime('%Y-%m-%d')
|
||||
)
|
||||
|
||||
if not sales_data:
|
||||
logger.warning(f"No sales data for product {product_id}")
|
||||
continue
|
||||
|
||||
# Convert to DataFrame
|
||||
sales_df = pd.DataFrame(sales_data)
|
||||
|
||||
if len(sales_df) < request_data.min_samples:
|
||||
logger.warning(
|
||||
f"Insufficient data for product {product_id}: "
|
||||
f"{len(sales_df)} samples < {request_data.min_samples} required"
|
||||
)
|
||||
continue
|
||||
|
||||
# Check what columns are available and map to expected format
|
||||
logger.debug(f"Sales data columns for product {product_id}: {sales_df.columns.tolist()}")
|
||||
|
||||
# Map common field names to 'quantity' and 'date'
|
||||
if 'quantity' not in sales_df.columns:
|
||||
if 'total_quantity' in sales_df.columns:
|
||||
sales_df['quantity'] = sales_df['total_quantity']
|
||||
elif 'amount' in sales_df.columns:
|
||||
sales_df['quantity'] = sales_df['amount']
|
||||
else:
|
||||
logger.warning(f"No quantity field found for product {product_id}, skipping")
|
||||
continue
|
||||
|
||||
if 'date' not in sales_df.columns:
|
||||
if 'sale_date' in sales_df.columns:
|
||||
sales_df['date'] = sales_df['sale_date']
|
||||
else:
|
||||
logger.warning(f"No date field found for product {product_id}, skipping")
|
||||
continue
|
||||
|
||||
# Prepare sales data with required columns
|
||||
sales_df['date'] = pd.to_datetime(sales_df['date'])
|
||||
sales_df['quantity'] = sales_df['quantity'].astype(float)
|
||||
sales_df['day_of_week'] = sales_df['date'].dt.dayofweek
|
||||
|
||||
# NOTE: Holiday detection for historical data requires:
|
||||
# 1. Tenant location context (calendar_id)
|
||||
# 2. Bulk holiday check API (currently single-date only)
|
||||
# 3. Historical calendar data
|
||||
# For real-time forecasts, holiday detection IS implemented via data_client.py
|
||||
sales_df['is_holiday'] = False
|
||||
|
||||
# NOTE: Weather data for historical analysis requires:
|
||||
# 1. Historical weather API integration
|
||||
# 2. Tenant location coordinates
|
||||
# For real-time forecasts, weather data IS fetched via external service
|
||||
sales_df['weather'] = 'unknown'
|
||||
|
||||
# Run rules learning
|
||||
results = await orchestrator.learn_and_post_rules(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=product_id,
|
||||
sales_data=sales_df,
|
||||
external_data=None,
|
||||
min_samples=request_data.min_samples
|
||||
)
|
||||
|
||||
# Track results
|
||||
total_insights_generated += results['insights_generated']
|
||||
total_insights_posted += results['insights_posted']
|
||||
insights_by_product[product_id] = {
|
||||
'product_name': product_name,
|
||||
'insights_posted': results['insights_posted'],
|
||||
'rules_learned': len(results['rules'])
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Product {product_id} analysis complete",
|
||||
insights_posted=results['insights_posted']
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error analyzing product {product_id}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
errors.append(error_msg)
|
||||
|
||||
# Close orchestrator
|
||||
await orchestrator.close()
|
||||
|
||||
# Build response
|
||||
response = RulesGenerationResponse(
|
||||
success=total_insights_posted > 0,
|
||||
message=f"Successfully generated {total_insights_posted} insights from {len(products)} products",
|
||||
tenant_id=tenant_id,
|
||||
products_analyzed=len(products),
|
||||
total_insights_generated=total_insights_generated,
|
||||
total_insights_posted=total_insights_posted,
|
||||
insights_by_product=insights_by_product,
|
||||
errors=errors
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ML insights rules generation complete",
|
||||
tenant_id=tenant_id,
|
||||
total_insights=total_insights_posted
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"ML insights rules generation failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Rules generation failed: {str(e)}"
|
||||
)
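
# ----------------------------------------------------------------------------
# Hypothetical refactor sketch: the column mapping and preparation performed
# inline above (and repeated in the demand and business-rules endpoints below)
# could be consolidated into one helper along these lines. The field names
# mirror exactly what the inline code already handles; nothing new is assumed.
def _prepare_sales_dataframe_sketch(sales_data: list) -> Optional[pd.DataFrame]:
    """Return a DataFrame with 'date', 'quantity' and 'day_of_week', or None if unusable."""
    df = pd.DataFrame(sales_data)
    if 'quantity' not in df.columns:
        if 'total_quantity' in df.columns:
            df['quantity'] = df['total_quantity']
        elif 'amount' in df.columns:
            df['quantity'] = df['amount']
        else:
            return None
    if 'date' not in df.columns:
        if 'sale_date' in df.columns:
            df['date'] = df['sale_date']
        else:
            return None
    df['date'] = pd.to_datetime(df['date'])
    df['quantity'] = df['quantity'].astype(float)
    df['day_of_week'] = df['date'].dt.dayofweek
    return df
# ----------------------------------------------------------------------------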
|
||||
|
||||
|
||||
@router.post("/analyze-demand", response_model=DemandAnalysisResponse)
|
||||
async def trigger_demand_analysis(
|
||||
tenant_id: str,
|
||||
request_data: DemandAnalysisRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Trigger demand pattern analysis from historical sales data.
|
||||
|
||||
This endpoint:
|
||||
1. Fetches historical sales data for specified products
|
||||
2. Runs the DemandInsightsOrchestrator to analyze patterns
|
||||
3. Generates insights about demand forecasting optimization
|
||||
4. Posts insights to AI Insights Service
|
||||
5. Publishes events to RabbitMQ
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
request_data: Demand analysis parameters
|
||||
request: FastAPI request object to access app state
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
DemandAnalysisResponse with analysis results
|
||||
"""
|
||||
logger.info(
|
||||
"ML insights demand analysis requested",
|
||||
tenant_id=tenant_id,
|
||||
product_ids=request_data.product_ids,
|
||||
lookback_days=request_data.lookback_days
|
||||
)
|
||||
|
||||
try:
|
||||
# Import ML orchestrator and clients
|
||||
from app.ml.demand_insights_orchestrator import DemandInsightsOrchestrator
|
||||
from shared.clients.sales_client import SalesServiceClient
|
||||
from shared.clients.inventory_client import InventoryServiceClient
|
||||
from app.core.config import settings
|
||||
|
||||
# Get event publisher from app state
|
||||
event_publisher = getattr(request.app.state, 'event_publisher', None)
|
||||
|
||||
# Initialize orchestrator and clients
|
||||
orchestrator = DemandInsightsOrchestrator(event_publisher=event_publisher)
|
||||
inventory_client = InventoryServiceClient(settings)
|
||||
|
||||
# Get products to analyze from inventory service via API
|
||||
if request_data.product_ids:
|
||||
# Fetch specific products
|
||||
products = []
|
||||
for product_id in request_data.product_ids:
|
||||
product = await inventory_client.get_ingredient_by_id(
|
||||
ingredient_id=UUID(product_id),
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
if product:
|
||||
products.append(product)
|
||||
else:
|
||||
# Fetch all products for tenant (limit to 10)
|
||||
all_products = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
|
||||
products = all_products[:10] # Limit to prevent timeout
|
||||
|
||||
if not products:
|
||||
return DemandAnalysisResponse(
|
||||
success=False,
|
||||
message="No products found for analysis",
|
||||
tenant_id=tenant_id,
|
||||
products_analyzed=0,
|
||||
total_insights_generated=0,
|
||||
total_insights_posted=0,
|
||||
insights_by_product={},
|
||||
errors=["No products found"]
|
||||
)
|
||||
|
||||
# Initialize sales client to fetch historical data
|
||||
sales_client = SalesServiceClient(config=settings, calling_service_name="forecasting")
|
||||
|
||||
# Calculate date range
|
||||
end_date = datetime.utcnow()
|
||||
start_date = end_date - timedelta(days=request_data.lookback_days)
|
||||
|
||||
# Process each product
|
||||
total_insights_generated = 0
|
||||
total_insights_posted = 0
|
||||
insights_by_product = {}
|
||||
errors = []
|
||||
|
||||
for product in products:
|
||||
try:
|
||||
product_id = str(product['id'])
|
||||
product_name = product.get('name', 'Unknown')
|
||||
logger.info(f"Analyzing product {product_name} ({product_id})")
|
||||
|
||||
# Fetch sales data for product
|
||||
sales_data = await sales_client.get_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
product_id=product_id,
|
||||
start_date=start_date.strftime('%Y-%m-%d'),
|
||||
end_date=end_date.strftime('%Y-%m-%d')
|
||||
)
|
||||
|
||||
if not sales_data:
|
||||
logger.warning(f"No sales data for product {product_id}")
|
||||
continue
|
||||
|
||||
# Convert to DataFrame
|
||||
sales_df = pd.DataFrame(sales_data)
|
||||
|
||||
if len(sales_df) < 30: # Minimum for demand analysis
|
||||
logger.warning(
|
||||
f"Insufficient data for product {product_id}: "
|
||||
f"{len(sales_df)} samples < 30 required"
|
||||
)
|
||||
continue
|
||||
|
||||
# Check what columns are available and map to expected format
|
||||
logger.debug(f"Sales data columns for product {product_id}: {sales_df.columns.tolist()}")
|
||||
|
||||
# Map common field names to 'quantity' and 'date'
|
||||
if 'quantity' not in sales_df.columns:
|
||||
if 'total_quantity' in sales_df.columns:
|
||||
sales_df['quantity'] = sales_df['total_quantity']
|
||||
elif 'amount' in sales_df.columns:
|
||||
sales_df['quantity'] = sales_df['amount']
|
||||
else:
|
||||
logger.warning(f"No quantity field found for product {product_id}, skipping")
|
||||
continue
|
||||
|
||||
if 'date' not in sales_df.columns:
|
||||
if 'sale_date' in sales_df.columns:
|
||||
sales_df['date'] = sales_df['sale_date']
|
||||
else:
|
||||
logger.warning(f"No date field found for product {product_id}, skipping")
|
||||
continue
|
||||
|
||||
# Prepare sales data with required columns
|
||||
sales_df['date'] = pd.to_datetime(sales_df['date'])
|
||||
sales_df['quantity'] = sales_df['quantity'].astype(float)
|
||||
sales_df['day_of_week'] = sales_df['date'].dt.dayofweek
|
||||
|
||||
# Run demand analysis
|
||||
results = await orchestrator.analyze_and_post_demand_insights(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=product_id,
|
||||
sales_data=sales_df,
|
||||
forecast_horizon_days=request_data.forecast_horizon_days,
|
||||
min_history_days=request_data.lookback_days
|
||||
)
|
||||
|
||||
# Track results
|
||||
total_insights_generated += results['insights_generated']
|
||||
total_insights_posted += results['insights_posted']
|
||||
insights_by_product[product_id] = {
|
||||
'product_name': product_name,
|
||||
'insights_posted': results['insights_posted'],
|
||||
'trend_analysis': results.get('trend_analysis', {})
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Product {product_id} demand analysis complete",
|
||||
insights_posted=results['insights_posted']
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error analyzing product {product_id}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
errors.append(error_msg)
|
||||
|
||||
# Close orchestrator
|
||||
await orchestrator.close()
|
||||
|
||||
# Build response
|
||||
response = DemandAnalysisResponse(
|
||||
success=total_insights_posted > 0,
|
||||
message=f"Successfully generated {total_insights_posted} insights from {len(products)} products",
|
||||
tenant_id=tenant_id,
|
||||
products_analyzed=len(products),
|
||||
total_insights_generated=total_insights_generated,
|
||||
total_insights_posted=total_insights_posted,
|
||||
insights_by_product=insights_by_product,
|
||||
errors=errors
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ML insights demand analysis complete",
|
||||
tenant_id=tenant_id,
|
||||
total_insights=total_insights_posted
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"ML insights demand analysis failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Demand analysis failed: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/analyze-business-rules", response_model=BusinessRulesAnalysisResponse)
|
||||
async def trigger_business_rules_analysis(
|
||||
tenant_id: str,
|
||||
request_data: BusinessRulesAnalysisRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Trigger business rules optimization analysis from historical sales data.
|
||||
|
||||
This endpoint:
|
||||
1. Fetches historical sales data for specified products
|
||||
2. Runs the BusinessRulesInsightsOrchestrator to analyze rules
|
||||
3. Generates insights about business rule optimization
|
||||
4. Posts insights to AI Insights Service
|
||||
5. Publishes events to RabbitMQ
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
request_data: Business rules analysis parameters
|
||||
request: FastAPI request object to access app state
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
BusinessRulesAnalysisResponse with analysis results
|
||||
"""
|
||||
logger.info(
|
||||
"ML insights business rules analysis requested",
|
||||
tenant_id=tenant_id,
|
||||
product_ids=request_data.product_ids,
|
||||
lookback_days=request_data.lookback_days
|
||||
)
|
||||
|
||||
try:
|
||||
# Import ML orchestrator and clients
|
||||
from app.ml.business_rules_insights_orchestrator import BusinessRulesInsightsOrchestrator
|
||||
from shared.clients.sales_client import SalesServiceClient
|
||||
from shared.clients.inventory_client import InventoryServiceClient
|
||||
from app.core.config import settings
|
||||
|
||||
# Get event publisher from app state
|
||||
event_publisher = getattr(request.app.state, 'event_publisher', None)
|
||||
|
||||
# Initialize orchestrator and clients
|
||||
orchestrator = BusinessRulesInsightsOrchestrator(event_publisher=event_publisher)
|
||||
inventory_client = InventoryServiceClient(settings)
|
||||
|
||||
# Get products to analyze from inventory service via API
|
||||
if request_data.product_ids:
|
||||
# Fetch specific products
|
||||
products = []
|
||||
for product_id in request_data.product_ids:
|
||||
product = await inventory_client.get_ingredient_by_id(
|
||||
ingredient_id=UUID(product_id),
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
if product:
|
||||
products.append(product)
|
||||
else:
|
||||
# Fetch all products for tenant (limit to 10)
|
||||
all_products = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
|
||||
products = all_products[:10] # Limit to prevent timeout
|
||||
|
||||
if not products:
|
||||
return BusinessRulesAnalysisResponse(
|
||||
success=False,
|
||||
message="No products found for analysis",
|
||||
tenant_id=tenant_id,
|
||||
products_analyzed=0,
|
||||
total_insights_generated=0,
|
||||
total_insights_posted=0,
|
||||
insights_by_product={},
|
||||
errors=["No products found"]
|
||||
)
|
||||
|
||||
# Initialize sales client to fetch historical data
|
||||
sales_client = SalesServiceClient(config=settings, calling_service_name="forecasting")
|
||||
|
||||
# Calculate date range
|
||||
end_date = datetime.utcnow()
|
||||
start_date = end_date - timedelta(days=request_data.lookback_days)
|
||||
|
||||
# Process each product
|
||||
total_insights_generated = 0
|
||||
total_insights_posted = 0
|
||||
insights_by_product = {}
|
||||
errors = []
|
||||
|
||||
for product in products:
|
||||
try:
|
||||
product_id = str(product['id'])
|
||||
product_name = product.get('name', 'Unknown')
|
||||
logger.info(f"Analyzing product {product_name} ({product_id})")
|
||||
|
||||
# Fetch sales data for product
|
||||
sales_data = await sales_client.get_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
product_id=product_id,
|
||||
start_date=start_date.strftime('%Y-%m-%d'),
|
||||
end_date=end_date.strftime('%Y-%m-%d')
|
||||
)
|
||||
|
||||
if not sales_data:
|
||||
logger.warning(f"No sales data for product {product_id}")
|
||||
continue
|
||||
|
||||
# Convert to DataFrame
|
||||
sales_df = pd.DataFrame(sales_data)
|
||||
|
||||
if len(sales_df) < request_data.min_samples:
|
||||
logger.warning(
|
||||
f"Insufficient data for product {product_id}: "
|
||||
f"{len(sales_df)} samples < {request_data.min_samples} required"
|
||||
)
|
||||
continue
|
||||
|
||||
# Check what columns are available and map to expected format
|
||||
logger.debug(f"Sales data columns for product {product_id}: {sales_df.columns.tolist()}")
|
||||
|
||||
# Map common field names to 'quantity' and 'date'
|
||||
if 'quantity' not in sales_df.columns:
|
||||
if 'total_quantity' in sales_df.columns:
|
||||
sales_df['quantity'] = sales_df['total_quantity']
|
||||
elif 'amount' in sales_df.columns:
|
||||
sales_df['quantity'] = sales_df['amount']
|
||||
else:
|
||||
logger.warning(f"No quantity field found for product {product_id}, skipping")
|
||||
continue
|
||||
|
||||
if 'date' not in sales_df.columns:
|
||||
if 'sale_date' in sales_df.columns:
|
||||
sales_df['date'] = sales_df['sale_date']
|
||||
else:
|
||||
logger.warning(f"No date field found for product {product_id}, skipping")
|
||||
continue
|
||||
|
||||
# Prepare sales data with required columns
|
||||
sales_df['date'] = pd.to_datetime(sales_df['date'])
|
||||
sales_df['quantity'] = sales_df['quantity'].astype(float)
|
||||
sales_df['day_of_week'] = sales_df['date'].dt.dayofweek
|
||||
|
||||
# Run business rules analysis
|
||||
results = await orchestrator.analyze_and_post_business_rules_insights(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=product_id,
|
||||
sales_data=sales_df,
|
||||
min_samples=request_data.min_samples
|
||||
)
|
||||
|
||||
# Track results
|
||||
total_insights_generated += results['insights_generated']
|
||||
total_insights_posted += results['insights_posted']
|
||||
insights_by_product[product_id] = {
|
||||
'product_name': product_name,
|
||||
'insights_posted': results['insights_posted'],
|
||||
'rules_learned': len(results.get('rules', {}))
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Product {product_id} business rules analysis complete",
|
||||
insights_posted=results['insights_posted']
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error analyzing product {product_id}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
errors.append(error_msg)
|
||||
|
||||
# Close orchestrator
|
||||
await orchestrator.close()
|
||||
|
||||
# Build response
|
||||
response = BusinessRulesAnalysisResponse(
|
||||
success=total_insights_posted > 0,
|
||||
message=f"Successfully generated {total_insights_posted} insights from {len(products)} products",
|
||||
tenant_id=tenant_id,
|
||||
products_analyzed=len(products),
|
||||
total_insights_generated=total_insights_generated,
|
||||
total_insights_posted=total_insights_posted,
|
||||
insights_by_product=insights_by_product,
|
||||
errors=errors
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"ML insights business rules analysis complete",
|
||||
tenant_id=tenant_id,
|
||||
total_insights=total_insights_posted
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"ML insights business rules analysis failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Business rules analysis failed: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def ml_insights_health():
|
||||
"""Health check for ML insights endpoints"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "forecasting-ml-insights",
|
||||
"endpoints": [
|
||||
"POST /ml/insights/generate-rules",
|
||||
"POST /ml/insights/analyze-demand",
|
||||
"POST /ml/insights/analyze-business-rules"
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# ================================================================
|
||||
# INTERNAL ML INSIGHTS ENDPOINTS (for demo session service)
|
||||
# ================================================================
|
||||
|
||||
internal_router = APIRouter(tags=["Internal ML"])
|
||||
|
||||
|
||||
@internal_router.post("/api/v1/tenants/{tenant_id}/forecasting/internal/ml/generate-demand-insights")
|
||||
async def trigger_demand_insights_internal(
|
||||
tenant_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Internal endpoint to trigger demand forecasting insights for a tenant.
|
||||
|
||||
This endpoint is called by the demo-session service after cloning to generate
|
||||
AI insights from the seeded forecast data.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
request: FastAPI request object to access app state
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Dict with insights generation results
|
||||
"""
|
||||
logger.info(
|
||||
"Internal demand insights generation triggered",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
try:
|
||||
# Import ML orchestrator and clients
|
||||
from app.ml.demand_insights_orchestrator import DemandInsightsOrchestrator
|
||||
from shared.clients.sales_client import SalesServiceClient
|
||||
from shared.clients.inventory_client import InventoryServiceClient
|
||||
from app.core.config import settings
|
||||
|
||||
# Get event publisher from app state
|
||||
event_publisher = getattr(request.app.state, 'event_publisher', None)
|
||||
|
||||
# Initialize orchestrator and clients
|
||||
orchestrator = DemandInsightsOrchestrator(event_publisher=event_publisher)
|
||||
inventory_client = InventoryServiceClient(settings)
|
||||
|
||||
# Get all products for tenant (limit to 10 for performance)
|
||||
all_products = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
|
||||
products = all_products[:10] if all_products else []
|
||||
|
||||
logger.info(
|
||||
"Retrieved products from inventory service",
|
||||
tenant_id=tenant_id,
|
||||
product_count=len(products)
|
||||
)
|
||||
|
||||
if not products:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "No products found for analysis",
|
||||
"tenant_id": tenant_id,
|
||||
"products_analyzed": 0,
|
||||
"insights_posted": 0
|
||||
}
|
||||
|
||||
# Initialize sales client
|
||||
sales_client = SalesServiceClient(config=settings, calling_service_name="forecasting")
|
||||
|
||||
# Calculate date range (90 days lookback)
|
||||
end_date = datetime.utcnow()
|
||||
start_date = end_date - timedelta(days=90)
|
||||
|
||||
# Process each product
|
||||
total_insights_generated = 0
|
||||
total_insights_posted = 0
|
||||
|
||||
for product in products:
|
||||
try:
|
||||
product_id = str(product['id'])
|
||||
product_name = product.get('name', 'Unknown Product')
|
||||
|
||||
logger.debug(
|
||||
"Analyzing demand for product",
|
||||
tenant_id=tenant_id,
|
||||
product_id=product_id,
|
||||
product_name=product_name
|
||||
)
|
||||
|
||||
# Fetch historical sales data
|
||||
sales_data_raw = await sales_client.get_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
product_id=product_id,
|
||||
start_date=start_date.strftime('%Y-%m-%d'),
|
||||
end_date=end_date.strftime('%Y-%m-%d')
|
||||
)
|
||||
|
||||
if not sales_data_raw or len(sales_data_raw) < 10:
|
||||
logger.debug(
|
||||
"Insufficient sales data for product",
|
||||
product_id=product_id,
|
||||
sales_records=len(sales_data_raw) if sales_data_raw else 0
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert to DataFrame
|
||||
sales_df = pd.DataFrame(sales_data_raw)
|
||||
|
||||
# Map field names to expected format
|
||||
if 'quantity' not in sales_df.columns:
|
||||
if 'total_quantity' in sales_df.columns:
|
||||
sales_df['quantity'] = sales_df['total_quantity']
|
||||
elif 'quantity_sold' in sales_df.columns:
|
||||
sales_df['quantity'] = sales_df['quantity_sold']
|
||||
else:
|
||||
logger.warning(
|
||||
"No quantity field found for product",
|
||||
product_id=product_id
|
||||
)
|
||||
continue
|
||||
|
||||
if 'date' not in sales_df.columns:
|
||||
if 'sale_date' in sales_df.columns:
|
||||
sales_df['date'] = sales_df['sale_date']
|
||||
else:
|
||||
logger.warning(
|
||||
"No date field found for product",
|
||||
product_id=product_id
|
||||
)
|
||||
continue
|
||||
|
||||
# Run demand insights orchestrator
|
||||
results = await orchestrator.analyze_and_post_demand_insights(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=product_id,
|
||||
sales_data=sales_df,
|
||||
forecast_horizon_days=30,
|
||||
min_history_days=90
|
||||
)
|
||||
|
||||
total_insights_generated += results['insights_generated']
|
||||
total_insights_posted += results['insights_posted']
|
||||
|
||||
logger.info(
|
||||
"Demand insights generated for product",
|
||||
tenant_id=tenant_id,
|
||||
product_id=product_id,
|
||||
insights_posted=results['insights_posted']
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to analyze product demand (non-fatal)",
|
||||
tenant_id=tenant_id,
|
||||
product_id=product_id,
|
||||
error=str(e)
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
"Internal demand insights generation complete",
|
||||
tenant_id=tenant_id,
|
||||
products_analyzed=len(products),
|
||||
insights_generated=total_insights_generated,
|
||||
insights_posted=total_insights_posted
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Generated {total_insights_posted} demand forecasting insights",
|
||||
"tenant_id": tenant_id,
|
||||
"products_analyzed": len(products),
|
||||
"insights_posted": total_insights_posted
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Internal demand insights generation failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Demand insights generation failed: {str(e)}",
|
||||
"tenant_id": tenant_id,
|
||||
"products_analyzed": 0,
|
||||
"insights_posted": 0
|
||||
}
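
# ----------------------------------------------------------------------------
# Illustrative sketch (hypothetical): how the demo-session service might call
# the internal endpoint above once cloning completes. The host name, port and
# timeout are assumptions; the path matches the internal_router registration.
async def _trigger_demo_insights_sketch(tenant_id: str,
                                        base_url: str = "http://forecasting:8000") -> dict:
    import httpx  # local import keeps the sketch self-contained

    url = f"{base_url}/api/v1/tenants/{tenant_id}/forecasting/internal/ml/generate-demand-insights"
    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.post(url)
        resp.raise_for_status()
        return resp.json()  # {"success": ..., "insights_posted": ..., ...}
# ----------------------------------------------------------------------------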
|
||||
287
services/forecasting/app/api/performance_monitoring.py
Normal file
287
services/forecasting/app/api/performance_monitoring.py
Normal file
@@ -0,0 +1,287 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/performance_monitoring.py
|
||||
# ================================================================
|
||||
"""
|
||||
Performance Monitoring API - Track and analyze forecast accuracy over time
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.performance_monitoring_service import PerformanceMonitoringService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
route_builder = RouteBuilder('forecasting')
|
||||
router = APIRouter(tags=["performance-monitoring"])
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class AccuracySummaryRequest(BaseModel):
|
||||
"""Request model for accuracy summary"""
|
||||
days: int = Field(default=30, ge=1, le=365, description="Analysis period in days")
|
||||
|
||||
|
||||
class DegradationAnalysisRequest(BaseModel):
|
||||
"""Request model for degradation analysis"""
|
||||
lookback_days: int = Field(default=30, ge=7, le=365, description="Days to analyze")
|
||||
|
||||
|
||||
class ModelAgeCheckRequest(BaseModel):
|
||||
"""Request model for model age check"""
|
||||
max_age_days: int = Field(default=30, ge=1, le=90, description="Max acceptable model age")
|
||||
|
||||
|
||||
class PerformanceReportRequest(BaseModel):
|
||||
"""Request model for comprehensive performance report"""
|
||||
days: int = Field(default=30, ge=1, le=365, description="Analysis period in days")
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("monitoring/accuracy-summary"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def get_accuracy_summary(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
days: int = Query(30, ge=1, le=365, description="Analysis period in days"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get forecast accuracy summary for recent period
|
||||
|
||||
Returns overall metrics, validation coverage, and health status.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Getting accuracy summary",
|
||||
tenant_id=tenant_id,
|
||||
days=days,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = PerformanceMonitoringService(db)
|
||||
|
||||
summary = await service.get_accuracy_summary(
|
||||
tenant_id=tenant_id,
|
||||
days=days
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get accuracy summary",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get accuracy summary: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("monitoring/degradation-analysis"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def analyze_performance_degradation(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
lookback_days: int = Query(30, ge=7, le=365, description="Days to analyze"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Detect if forecast performance is degrading over time
|
||||
|
||||
Compares first half vs second half of period and identifies poor performers.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Analyzing performance degradation",
|
||||
tenant_id=tenant_id,
|
||||
lookback_days=lookback_days,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = PerformanceMonitoringService(db)
|
||||
|
||||
analysis = await service.detect_performance_degradation(
|
||||
tenant_id=tenant_id,
|
||||
lookback_days=lookback_days
|
||||
)
|
||||
|
||||
return analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to analyze degradation",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to analyze degradation: {str(e)}"
|
||||
)
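
# ----------------------------------------------------------------------------
# Illustrative sketch (hypothetical): the kind of first-half vs second-half
# comparison the degradation analysis above describes. The real logic lives in
# PerformanceMonitoringService.detect_performance_degradation and may differ;
# the 10% threshold below is an assumption.
def _degradation_sketch(daily_mape: list, threshold_pct: float = 10.0) -> dict:
    """Flag degradation when mean MAPE worsens by more than threshold_pct between halves."""
    if len(daily_mape) < 4:
        return {"status": "insufficient_data"}
    mid = len(daily_mape) // 2
    first, second = daily_mape[:mid], daily_mape[mid:]
    first_avg = sum(first) / len(first)
    second_avg = sum(second) / len(second)
    change_pct = ((second_avg - first_avg) / first_avg * 100) if first_avg else 0.0
    return {
        "first_half_mape": round(first_avg, 2),
        "second_half_mape": round(second_avg, 2),
        "change_pct": round(change_pct, 2),
        "degrading": change_pct > threshold_pct,
    }
# ----------------------------------------------------------------------------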
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("monitoring/model-age"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def check_model_age(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
max_age_days: int = Query(30, ge=1, le=90, description="Max acceptable model age"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Check if models are outdated and need retraining
|
||||
|
||||
Returns models in use and identifies those needing updates.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Checking model age",
|
||||
tenant_id=tenant_id,
|
||||
max_age_days=max_age_days,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = PerformanceMonitoringService(db)
|
||||
|
||||
analysis = await service.check_model_age(
|
||||
tenant_id=tenant_id,
|
||||
max_age_days=max_age_days
|
||||
)
|
||||
|
||||
return analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to check model age",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to check model age: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("monitoring/performance-report"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def generate_performance_report(
|
||||
request: PerformanceReportRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Generate comprehensive performance report
|
||||
|
||||
Combines accuracy summary, degradation analysis, and model age check
|
||||
with actionable recommendations.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Generating performance report",
|
||||
tenant_id=tenant_id,
|
||||
days=request.days,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = PerformanceMonitoringService(db)
|
||||
|
||||
report = await service.generate_performance_report(
|
||||
tenant_id=tenant_id,
|
||||
days=request.days
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to generate performance report",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to generate performance report: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("monitoring/health"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def get_health_status(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get quick health status for dashboards
|
||||
|
||||
Returns simplified health metrics for UI display.
|
||||
"""
|
||||
try:
|
||||
service = PerformanceMonitoringService(db)
|
||||
|
||||
# Get 7-day summary for quick health check
|
||||
summary = await service.get_accuracy_summary(
|
||||
tenant_id=tenant_id,
|
||||
days=7
|
||||
)
|
||||
|
||||
if summary.get("status") == "no_data":
|
||||
return {
|
||||
"status": "unknown",
|
||||
"message": "No recent validation data available",
|
||||
"health_status": "unknown"
|
||||
}
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"health_status": summary.get("health_status"),
|
||||
"current_mape": summary["average_metrics"].get("mape"),
|
||||
"accuracy_percentage": summary["average_metrics"].get("accuracy_percentage"),
|
||||
"validation_coverage": summary.get("coverage_percentage"),
|
||||
"last_7_days": {
|
||||
"validation_runs": summary.get("validation_runs"),
|
||||
"forecasts_evaluated": summary.get("total_forecasts_evaluated")
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get health status",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get health status: {str(e)}"
|
||||
)
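
# ----------------------------------------------------------------------------
# Illustrative sketch (hypothetical): the metrics referenced above. MAPE is the
# mean absolute percentage error between predicted and actual demand, and the
# accuracy percentage is commonly derived as 100 - MAPE floored at zero. The
# real computation lives in PerformanceMonitoringService and may differ.
def _mape_sketch(actual: list, predicted: list) -> dict:
    """Compute MAPE and a derived accuracy percentage, skipping zero actuals."""
    pairs = [(a, p) for a, p in zip(actual, predicted) if a]
    if not pairs:
        return {"mape": None, "accuracy_percentage": None}
    mape = sum(abs(a - p) / abs(a) for a, p in pairs) / len(pairs) * 100
    return {"mape": round(mape, 2), "accuracy_percentage": round(max(0.0, 100 - mape), 2)}
# ----------------------------------------------------------------------------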
|
||||
297
services/forecasting/app/api/retraining.py
Normal file
297
services/forecasting/app/api/retraining.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/retraining.py
|
||||
# ================================================================
|
||||
"""
|
||||
Retraining API - Trigger and manage model retraining based on performance
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any, List
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.retraining_trigger_service import RetrainingTriggerService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
route_builder = RouteBuilder('forecasting')
|
||||
router = APIRouter(tags=["retraining"])
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class EvaluateRetrainingRequest(BaseModel):
|
||||
"""Request model for retraining evaluation"""
|
||||
auto_trigger: bool = Field(
|
||||
default=False,
|
||||
description="Automatically trigger retraining for poor performers"
|
||||
)
|
||||
|
||||
|
||||
class TriggerProductRetrainingRequest(BaseModel):
|
||||
"""Request model for single product retraining"""
|
||||
inventory_product_id: UUID = Field(..., description="Product to retrain")
|
||||
reason: str = Field(..., description="Reason for retraining")
|
||||
priority: str = Field(
|
||||
default="normal",
|
||||
description="Priority level: low, normal, high"
|
||||
)
|
||||
|
||||
|
||||
class TriggerBulkRetrainingRequest(BaseModel):
|
||||
"""Request model for bulk retraining"""
|
||||
product_ids: List[UUID] = Field(..., description="List of products to retrain")
|
||||
reason: str = Field(
|
||||
default="Bulk retraining requested",
|
||||
description="Reason for bulk retraining"
|
||||
)
|
||||
|
||||
|
||||
class ScheduledRetrainingCheckRequest(BaseModel):
|
||||
"""Request model for scheduled retraining check"""
|
||||
max_model_age_days: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
le=90,
|
||||
description="Maximum acceptable model age"
|
||||
)
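
# ----------------------------------------------------------------------------
# Illustrative sketch (hypothetical): the kind of per-product decision rule an
# evaluate-and-trigger flow might apply. Thresholds are assumptions; the real
# logic lives in RetrainingTriggerService and may differ.
def _needs_retraining_sketch(mape: float, model_age_days: int,
                             mape_threshold: float = 25.0,
                             max_age_days: int = 30) -> bool:
    """Retrain when accuracy is poor or the model is older than the allowed age."""
    return mape > mape_threshold or model_age_days > max_age_days
# ----------------------------------------------------------------------------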
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("retraining/evaluate"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner'])
|
||||
async def evaluate_retraining_needs(
|
||||
request: EvaluateRetrainingRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Evaluate performance and optionally trigger retraining
|
||||
|
||||
Analyzes 30-day performance and identifies products needing retraining.
|
||||
If auto_trigger=true, automatically triggers retraining for poor performers.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Evaluating retraining needs",
|
||||
tenant_id=tenant_id,
|
||||
auto_trigger=request.auto_trigger,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = RetrainingTriggerService(db)
|
||||
|
||||
result = await service.evaluate_and_trigger_retraining(
|
||||
tenant_id=tenant_id,
|
||||
auto_trigger=request.auto_trigger
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to evaluate retraining needs",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to evaluate retraining: {str(e)}"
|
||||
)
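# Illustrative caller for the endpoint above (a sketch, not part of the service).
# The concrete path is produced by RouteBuilder.build_base_route("retraining/evaluate")
# at startup, so it is passed in rather than hard-coded; base_url, token and the
# helper name are assumptions for the example.
import httpx


async def _example_trigger_evaluation(base_url: str, evaluate_path: str, token: str) -> Dict[str, Any]:
    async with httpx.AsyncClient(base_url=base_url) as client:
        response = await client.post(
            evaluate_path,
            json={"auto_trigger": False},  # matches EvaluateRetrainingRequest above
            headers={"Authorization": f"Bearer {token}"},
        )
        response.raise_for_status()
        return response.json()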
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("retraining/trigger-product"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner'])
|
||||
async def trigger_product_retraining(
|
||||
request: TriggerProductRetrainingRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Trigger retraining for a specific product
|
||||
|
||||
Manually trigger model retraining for a single product.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Triggering product retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_id=request.inventory_product_id,
|
||||
reason=request.reason,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = RetrainingTriggerService(db)
|
||||
|
||||
result = await service._trigger_product_retraining(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=request.inventory_product_id,
|
||||
reason=request.reason,
|
||||
priority=request.priority
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to trigger product retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_id=request.inventory_product_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to trigger retraining: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("retraining/trigger-bulk"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner'])
|
||||
async def trigger_bulk_retraining(
|
||||
request: TriggerBulkRetrainingRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Trigger retraining for multiple products
|
||||
|
||||
Bulk retraining operation for multiple products at once.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Triggering bulk retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_count=len(request.product_ids),
|
||||
reason=request.reason,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = RetrainingTriggerService(db)
|
||||
|
||||
result = await service.trigger_bulk_retraining(
|
||||
tenant_id=tenant_id,
|
||||
product_ids=request.product_ids,
|
||||
reason=request.reason
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to trigger bulk retraining",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to trigger bulk retraining: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("retraining/recommendations"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def get_retraining_recommendations(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get retraining recommendations without triggering
|
||||
|
||||
Returns recommendations for manual review and decision-making.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Getting retraining recommendations",
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = RetrainingTriggerService(db)
|
||||
|
||||
recommendations = await service.get_retraining_recommendations(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
return recommendations
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get recommendations",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get recommendations: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("retraining/check-scheduled"),
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner'])
|
||||
async def check_scheduled_retraining(
|
||||
request: ScheduledRetrainingCheckRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Check for models needing scheduled retraining based on age
|
||||
|
||||
Identifies models that haven't been updated in max_model_age_days.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Checking scheduled retraining needs",
|
||||
tenant_id=tenant_id,
|
||||
max_model_age_days=request.max_model_age_days,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
service = RetrainingTriggerService(db)
|
||||
|
||||
result = await service.check_and_trigger_scheduled_retraining(
|
||||
tenant_id=tenant_id,
|
||||
max_model_age_days=request.max_model_age_days
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to check scheduled retraining",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to check scheduled retraining: {str(e)}"
|
||||
)
|
||||
455
services/forecasting/app/api/scenario_operations.py
Normal file
@@ -0,0 +1,455 @@
|
||||
"""
|
||||
Scenario Simulation Operations API - PROFESSIONAL/ENTERPRISE ONLY
|
||||
Business operations for "what-if" scenario testing and strategic planning
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Path, Request
|
||||
from typing import List, Dict, Any
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
import uuid
|
||||
|
||||
from app.schemas.forecasts import (
|
||||
ScenarioSimulationRequest,
|
||||
ScenarioSimulationResponse,
|
||||
ScenarioComparisonRequest,
|
||||
ScenarioComparisonResponse,
|
||||
ScenarioType,
|
||||
ScenarioImpact,
|
||||
ForecastResponse,
|
||||
ForecastRequest
|
||||
)
|
||||
from app.services.forecasting_service import EnhancedForecastingService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.database.base import create_database_manager
|
||||
from shared.monitoring.decorators import track_execution_time
|
||||
from shared.monitoring.metrics import get_metrics_collector
|
||||
from app.core.config import settings
|
||||
from shared.routing import RouteBuilder
|
||||
from shared.auth.access_control import require_user_role, analytics_tier_required
|
||||
from shared.clients.tenant_client import TenantServiceClient
|
||||
|
||||
route_builder = RouteBuilder('forecasting')
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter(tags=["scenario-simulation"])
|
||||
|
||||
|
||||
def get_enhanced_forecasting_service():
|
||||
"""Dependency injection for EnhancedForecastingService"""
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
|
||||
return EnhancedForecastingService(database_manager)
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_analytics_route("scenario-simulation"),
|
||||
response_model=ScenarioSimulationResponse
|
||||
)
|
||||
@require_user_role(['admin', 'owner'])
|
||||
@analytics_tier_required
|
||||
@track_execution_time("scenario_simulation_duration_seconds", "forecasting-service")
|
||||
async def simulate_scenario(
|
||||
request: ScenarioSimulationRequest,
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
request_obj: Request = None,
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
||||
):
|
||||
"""
|
||||
Run a "what-if" scenario simulation on forecasts
|
||||
|
||||
This endpoint allows users to test how different scenarios might impact demand:
|
||||
- Weather events (heatwaves, cold snaps, rain)
|
||||
- Competition (new competitors opening nearby)
|
||||
- Events (festivals, concerts, sports events)
|
||||
- Pricing changes
|
||||
- Promotions
|
||||
- Supply disruptions
|
||||
|
||||
    **PROFESSIONAL/ENTERPRISE TIER ONLY - Admin+ role required**
|
||||
"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
logger.info("Starting scenario simulation",
|
||||
tenant_id=tenant_id,
|
||||
scenario_name=request.scenario_name,
|
||||
scenario_type=request.scenario_type.value,
|
||||
products=len(request.inventory_product_ids))
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter(f"scenario_simulations_total")
|
||||
metrics.increment_counter(f"scenario_simulations_{request.scenario_type.value}_total")
|
||||
|
||||
# Generate simulation ID
|
||||
simulation_id = str(uuid.uuid4())
|
||||
end_date = request.start_date + timedelta(days=request.duration_days - 1)
|
||||
|
||||
# Step 1: Generate baseline forecasts
|
||||
baseline_forecasts = []
|
||||
if request.include_baseline:
|
||||
logger.info("Generating baseline forecasts", tenant_id=tenant_id)
|
||||
|
||||
# Get tenant location (city) from tenant service
|
||||
location = "default"
|
||||
try:
|
||||
tenant_client = TenantServiceClient(settings)
|
||||
tenant_info = await tenant_client.get_tenant(tenant_id)
|
||||
if tenant_info and tenant_info.get('city'):
|
||||
location = tenant_info['city']
|
||||
logger.info("Using tenant location for forecasts", tenant_id=tenant_id, location=location)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to get tenant location, using default", error=str(e), tenant_id=tenant_id)
|
||||
|
||||
for product_id in request.inventory_product_ids:
|
||||
forecast_request = ForecastRequest(
|
||||
inventory_product_id=product_id,
|
||||
forecast_date=request.start_date,
|
||||
forecast_days=request.duration_days,
|
||||
location=location
|
||||
)
|
||||
multi_day_result = await forecasting_service.generate_multi_day_forecast(
|
||||
tenant_id=tenant_id,
|
||||
request=forecast_request
|
||||
)
|
||||
# Convert forecast dictionaries to ForecastResponse objects
|
||||
forecast_dicts = multi_day_result.get("forecasts", [])
|
||||
for forecast_dict in forecast_dicts:
|
||||
if isinstance(forecast_dict, dict):
|
||||
baseline_forecasts.append(ForecastResponse(**forecast_dict))
|
||||
else:
|
||||
baseline_forecasts.append(forecast_dict)
|
||||
|
||||
# Step 2: Apply scenario adjustments to generate scenario forecasts
|
||||
scenario_forecasts = await _apply_scenario_adjustments(
|
||||
tenant_id=tenant_id,
|
||||
request=request,
|
||||
baseline_forecasts=baseline_forecasts if request.include_baseline else [],
|
||||
forecasting_service=forecasting_service
|
||||
)
|
||||
|
||||
# Step 3: Calculate impacts
|
||||
product_impacts = _calculate_product_impacts(
|
||||
baseline_forecasts,
|
||||
scenario_forecasts,
|
||||
request.inventory_product_ids
|
||||
)
|
||||
|
||||
# Step 4: Calculate totals
|
||||
total_baseline_demand = sum(f.predicted_demand for f in baseline_forecasts) if baseline_forecasts else 0
|
||||
total_scenario_demand = sum(f.predicted_demand for f in scenario_forecasts)
|
||||
overall_impact_percent = (
|
||||
((total_scenario_demand - total_baseline_demand) / total_baseline_demand * 100)
|
||||
if total_baseline_demand > 0 else 0
|
||||
)
|
||||
|
||||
# Step 5: Generate insights and recommendations
|
||||
insights, recommendations, risk_level = _generate_insights(
|
||||
request.scenario_type,
|
||||
request,
|
||||
product_impacts,
|
||||
overall_impact_percent
|
||||
)
|
||||
|
||||
# Calculate processing time
|
||||
processing_time_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("scenario_simulations_success_total")
|
||||
metrics.observe_histogram("scenario_simulation_processing_time_ms", processing_time_ms)
|
||||
|
||||
logger.info("Scenario simulation completed successfully",
|
||||
tenant_id=tenant_id,
|
||||
simulation_id=simulation_id,
|
||||
overall_impact=f"{overall_impact_percent:.2f}%",
|
||||
processing_time_ms=processing_time_ms)
|
||||
|
||||
return ScenarioSimulationResponse(
|
||||
id=simulation_id,
|
||||
tenant_id=tenant_id,
|
||||
scenario_name=request.scenario_name,
|
||||
scenario_type=request.scenario_type,
|
||||
start_date=request.start_date,
|
||||
end_date=end_date,
|
||||
duration_days=request.duration_days,
|
||||
baseline_forecasts=baseline_forecasts if request.include_baseline else None,
|
||||
scenario_forecasts=scenario_forecasts,
|
||||
total_baseline_demand=total_baseline_demand,
|
||||
total_scenario_demand=total_scenario_demand,
|
||||
overall_impact_percent=overall_impact_percent,
|
||||
product_impacts=product_impacts,
|
||||
insights=insights,
|
||||
recommendations=recommendations,
|
||||
risk_level=risk_level,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
processing_time_ms=processing_time_ms
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("scenario_simulation_validation_errors_total")
|
||||
logger.error("Scenario simulation validation error", error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
if metrics:
|
||||
metrics.increment_counter("scenario_simulations_errors_total")
|
||||
logger.error("Scenario simulation failed", error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Scenario simulation failed"
|
||||
)
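# Illustrative request body for the endpoint above (a sketch; field names follow how
# ScenarioSimulationRequest is used in this module, while the enum value "weather"
# and all concrete values are assumptions):
#
#   {
#     "scenario_name": "July heatwave",
#     "scenario_type": "weather",
#     "start_date": "2025-07-01",
#     "duration_days": 7,
#     "include_baseline": true,
#     "inventory_product_ids": ["<product-uuid>"],
#     "weather_params": {"temperature_change": 12}
#   }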
|
||||
|
||||
|
||||
async def _apply_scenario_adjustments(
|
||||
tenant_id: str,
|
||||
request: ScenarioSimulationRequest,
|
||||
baseline_forecasts: List[ForecastResponse],
|
||||
forecasting_service: EnhancedForecastingService
|
||||
) -> List[ForecastResponse]:
|
||||
"""
|
||||
Apply scenario-specific adjustments to forecasts
|
||||
"""
|
||||
scenario_forecasts = []
|
||||
|
||||
# If no baseline, generate fresh forecasts
|
||||
if not baseline_forecasts:
|
||||
for product_id in request.inventory_product_ids:
|
||||
forecast_request = ForecastRequest(
|
||||
inventory_product_id=product_id,
|
||||
forecast_date=request.start_date,
|
||||
forecast_days=request.duration_days,
|
||||
location="default"
|
||||
)
|
||||
multi_day_result = await forecasting_service.generate_multi_day_forecast(
|
||||
tenant_id=tenant_id,
|
||||
request=forecast_request
|
||||
)
|
||||
            forecast_dicts = multi_day_result.get("forecasts", [])
            # Accumulate results across products and normalize dicts to
            # ForecastResponse, mirroring the conversion in the baseline path above
            baseline_forecasts.extend(
                ForecastResponse(**f) if isinstance(f, dict) else f
                for f in forecast_dicts
            )
|
||||
|
||||
# Apply multipliers based on scenario type
|
||||
for forecast in baseline_forecasts:
|
||||
adjusted_forecast = forecast.copy()
|
||||
multiplier = _get_scenario_multiplier(request)
|
||||
|
||||
# Adjust predicted demand
|
||||
adjusted_forecast.predicted_demand *= multiplier
|
||||
adjusted_forecast.confidence_lower *= multiplier
|
||||
adjusted_forecast.confidence_upper *= multiplier
|
||||
|
||||
scenario_forecasts.append(adjusted_forecast)
|
||||
|
||||
return scenario_forecasts
|
||||
|
||||
|
||||
def _get_scenario_multiplier(request: ScenarioSimulationRequest) -> float:
|
||||
"""
|
||||
Calculate demand multiplier based on scenario type and parameters
|
||||
"""
|
||||
if request.scenario_type == ScenarioType.WEATHER:
|
||||
if request.weather_params:
|
||||
# Heatwave increases demand for cold items, decreases for hot items
|
||||
if request.weather_params.temperature_change and request.weather_params.temperature_change > 10:
|
||||
return 1.25 # 25% increase during heatwave
|
||||
elif request.weather_params.temperature_change and request.weather_params.temperature_change < -10:
|
||||
return 0.85 # 15% decrease during cold snap
|
||||
elif request.weather_params.precipitation_change and request.weather_params.precipitation_change > 10:
|
||||
return 0.90 # 10% decrease during heavy rain
|
||||
return 1.0
|
||||
|
||||
elif request.scenario_type == ScenarioType.COMPETITION:
|
||||
if request.competition_params:
|
||||
# New competition reduces demand based on market share loss
|
||||
return 1.0 - request.competition_params.estimated_market_share_loss
|
||||
return 0.85 # Default 15% reduction
|
||||
|
||||
elif request.scenario_type == ScenarioType.EVENT:
|
||||
if request.event_params:
|
||||
# Events increase demand based on attendance and proximity
|
||||
if request.event_params.distance_km < 1.0:
|
||||
return 1.5 # 50% increase for very close events
|
||||
elif request.event_params.distance_km < 5.0:
|
||||
return 1.2 # 20% increase for nearby events
|
||||
return 1.15 # Default 15% increase
|
||||
|
||||
elif request.scenario_type == ScenarioType.PRICING:
|
||||
if request.pricing_params:
|
||||
# Price elasticity: typically -0.5 to -2.0
|
||||
# 10% price increase = 5-20% demand decrease
|
||||
elasticity = -1.0 # Average elasticity
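            # Worked example: a +10% price change with elasticity -1.0 gives
            # 1.0 + (10 / 100) * (-1.0) = 0.90, i.e. roughly a 10% drop in demand.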
|
||||
return 1.0 + (request.pricing_params.price_change_percent / 100) * elasticity
|
||||
return 1.0
|
||||
|
||||
elif request.scenario_type == ScenarioType.PROMOTION:
|
||||
if request.promotion_params:
|
||||
# Promotions increase traffic and conversion
|
||||
traffic_boost = 1.0 + request.promotion_params.expected_traffic_increase
|
||||
discount_boost = 1.0 + (request.promotion_params.discount_percent / 100) * 0.5
|
||||
return traffic_boost * discount_boost
|
||||
return 1.3 # Default 30% increase
|
||||
|
||||
elif request.scenario_type == ScenarioType.SUPPLY_DISRUPTION:
|
||||
return 0.6 # 40% reduction due to limited supply
|
||||
|
||||
elif request.scenario_type == ScenarioType.CUSTOM:
|
||||
if request.custom_multipliers and 'demand' in request.custom_multipliers:
|
||||
return request.custom_multipliers['demand']
|
||||
return 1.0
|
||||
|
||||
return 1.0
|
||||
|
||||
|
||||
def _calculate_product_impacts(
|
||||
baseline_forecasts: List[ForecastResponse],
|
||||
scenario_forecasts: List[ForecastResponse],
|
||||
product_ids: List[str]
|
||||
) -> List[ScenarioImpact]:
|
||||
"""
|
||||
Calculate per-product impact of the scenario
|
||||
"""
|
||||
impacts = []
|
||||
|
||||
for product_id in product_ids:
|
||||
baseline_total = sum(
|
||||
f.predicted_demand for f in baseline_forecasts
|
||||
if f.inventory_product_id == product_id
|
||||
)
|
||||
scenario_total = sum(
|
||||
f.predicted_demand for f in scenario_forecasts
|
||||
if f.inventory_product_id == product_id
|
||||
)
|
||||
|
||||
if baseline_total > 0:
|
||||
change_percent = ((scenario_total - baseline_total) / baseline_total) * 100
|
||||
else:
|
||||
change_percent = 0
|
||||
|
||||
# Get confidence ranges
|
||||
scenario_product_forecasts = [
|
||||
f for f in scenario_forecasts if f.inventory_product_id == product_id
|
||||
]
|
||||
avg_lower = sum(f.confidence_lower for f in scenario_product_forecasts) / len(scenario_product_forecasts) if scenario_product_forecasts else 0
|
||||
avg_upper = sum(f.confidence_upper for f in scenario_product_forecasts) / len(scenario_product_forecasts) if scenario_product_forecasts else 0
|
||||
|
||||
impacts.append(ScenarioImpact(
|
||||
inventory_product_id=product_id,
|
||||
baseline_demand=baseline_total,
|
||||
simulated_demand=scenario_total,
|
||||
demand_change_percent=change_percent,
|
||||
confidence_range=(avg_lower, avg_upper),
|
||||
impact_factors={"primary_driver": "scenario_adjustment"}
|
||||
))
|
||||
|
||||
return impacts
|
||||
|
||||
|
||||
def _generate_insights(
|
||||
scenario_type: ScenarioType,
|
||||
request: ScenarioSimulationRequest,
|
||||
impacts: List[ScenarioImpact],
|
||||
overall_impact: float
|
||||
) -> tuple[List[str], List[str], str]:
|
||||
"""
|
||||
Generate AI-powered insights and recommendations
|
||||
"""
|
||||
insights = []
|
||||
recommendations = []
|
||||
risk_level = "low"
|
||||
|
||||
# Determine risk level
|
||||
if abs(overall_impact) > 30:
|
||||
risk_level = "high"
|
||||
elif abs(overall_impact) > 15:
|
||||
risk_level = "medium"
|
||||
|
||||
# Generate scenario-specific insights
|
||||
if scenario_type == ScenarioType.WEATHER:
|
||||
if request.weather_params:
|
||||
if request.weather_params.temperature_change and request.weather_params.temperature_change > 10:
|
||||
insights.append(f"Heatwave of +{request.weather_params.temperature_change}°C expected to increase demand by {overall_impact:.1f}%")
|
||||
recommendations.append("Increase inventory of cold beverages and refrigerated items")
|
||||
recommendations.append("Extend operating hours to capture increased evening traffic")
|
||||
elif request.weather_params.temperature_change and request.weather_params.temperature_change < -10:
|
||||
insights.append(f"Cold snap of {request.weather_params.temperature_change}°C expected to decrease demand by {abs(overall_impact):.1f}%")
|
||||
recommendations.append("Increase production of warm comfort foods")
|
||||
recommendations.append("Reduce inventory of cold items")
|
||||
|
||||
elif scenario_type == ScenarioType.COMPETITION:
|
||||
insights.append(f"New competitor expected to reduce demand by {abs(overall_impact):.1f}%")
|
||||
recommendations.append("Consider launching loyalty program to retain customers")
|
||||
recommendations.append("Differentiate with unique product offerings")
|
||||
recommendations.append("Focus on customer service excellence")
|
||||
|
||||
elif scenario_type == ScenarioType.EVENT:
|
||||
insights.append(f"Local event expected to increase demand by {overall_impact:.1f}%")
|
||||
recommendations.append("Increase staffing for the event period")
|
||||
recommendations.append("Stock additional inventory of popular items")
|
||||
recommendations.append("Consider event-specific promotions")
|
||||
|
||||
elif scenario_type == ScenarioType.PRICING:
|
||||
if overall_impact < 0:
|
||||
insights.append(f"Price increase expected to reduce demand by {abs(overall_impact):.1f}%")
|
||||
recommendations.append("Consider smaller price increases")
|
||||
recommendations.append("Communicate value proposition to customers")
|
||||
else:
|
||||
insights.append(f"Price decrease expected to increase demand by {overall_impact:.1f}%")
|
||||
recommendations.append("Ensure adequate inventory to meet increased demand")
|
||||
|
||||
elif scenario_type == ScenarioType.PROMOTION:
|
||||
insights.append(f"Promotion expected to increase demand by {overall_impact:.1f}%")
|
||||
recommendations.append("Stock additional inventory before promotion starts")
|
||||
recommendations.append("Increase staffing during promotion period")
|
||||
recommendations.append("Prepare marketing materials and signage")
|
||||
|
||||
# Add product-specific insights
|
||||
high_impact_products = [
|
||||
impact for impact in impacts
|
||||
if abs(impact.demand_change_percent) > 20
|
||||
]
|
||||
if high_impact_products:
|
||||
insights.append(f"{len(high_impact_products)} products show significant impact (>20% change)")
|
||||
|
||||
# Add general recommendation
|
||||
if risk_level == "high":
|
||||
recommendations.append("⚠️ High-impact scenario - review and adjust operational plans immediately")
|
||||
elif risk_level == "medium":
|
||||
recommendations.append("Monitor situation closely and prepare contingency plans")
|
||||
|
||||
return insights, recommendations, risk_level
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_analytics_route("scenario-comparison"),
|
||||
response_model=ScenarioComparisonResponse
|
||||
)
|
||||
@require_user_role(['viewer', 'member', 'admin', 'owner'])
|
||||
@analytics_tier_required
|
||||
async def compare_scenarios(
|
||||
request: ScenarioComparisonRequest,
|
||||
tenant_id: str = Path(..., description="Tenant ID")
|
||||
):
|
||||
"""
|
||||
Compare multiple scenario simulations
|
||||
|
||||
**PROFESSIONAL/ENTERPRISE ONLY**
|
||||
|
||||
**STATUS**: Not yet implemented - requires scenario persistence layer
|
||||
|
||||
**Future implementation would**:
|
||||
1. Retrieve saved scenarios by ID from database
|
||||
2. Use ScenarioPlanner.compare_scenarios() to analyze them
|
||||
3. Return comparison matrix with best/worst case analysis
|
||||
|
||||
**Prerequisites**:
|
||||
- Scenario storage/retrieval database layer
|
||||
- Scenario CRUD endpoints
|
||||
- UI for scenario management
|
||||
"""
|
||||
# NOTE: HTTP 501 Not Implemented is the correct response for unimplemented optional features
|
||||
# The ML logic exists in scenario_planner.py but requires a persistence layer
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Scenario comparison requires scenario persistence layer (future feature)"
|
||||
)
|
||||
346
services/forecasting/app/api/validation.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/validation.py
|
||||
# ================================================================
|
||||
"""
|
||||
Validation API - Forecast validation endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any, List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.validation_service import ValidationService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
route_builder = RouteBuilder('forecasting')
|
||||
router = APIRouter(tags=["validation"])
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class ValidationRequest(BaseModel):
|
||||
"""Request model for validation"""
|
||||
start_date: datetime = Field(..., description="Start date for validation period")
|
||||
end_date: datetime = Field(..., description="End date for validation period")
|
||||
orchestration_run_id: Optional[UUID] = Field(None, description="Optional orchestration run ID")
|
||||
triggered_by: str = Field(default="manual", description="Trigger source")
|
||||
|
||||
|
||||
class ValidationResponse(BaseModel):
|
||||
"""Response model for validation results"""
|
||||
validation_run_id: str
|
||||
status: str
|
||||
forecasts_evaluated: int
|
||||
forecasts_with_actuals: int
|
||||
forecasts_without_actuals: int
|
||||
metrics_created: int
|
||||
overall_metrics: Optional[Dict[str, float]] = None
|
||||
total_predicted_demand: Optional[float] = None
|
||||
total_actual_demand: Optional[float] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class ValidationRunResponse(BaseModel):
|
||||
"""Response model for validation run details"""
|
||||
id: str
|
||||
tenant_id: str
|
||||
orchestration_run_id: Optional[str]
|
||||
validation_start_date: str
|
||||
validation_end_date: str
|
||||
started_at: str
|
||||
completed_at: Optional[str]
|
||||
duration_seconds: Optional[float]
|
||||
status: str
|
||||
total_forecasts_evaluated: int
|
||||
forecasts_with_actuals: int
|
||||
forecasts_without_actuals: int
|
||||
overall_mae: Optional[float]
|
||||
overall_mape: Optional[float]
|
||||
overall_rmse: Optional[float]
|
||||
overall_r2_score: Optional[float]
|
||||
overall_accuracy_percentage: Optional[float]
|
||||
total_predicted_demand: float
|
||||
total_actual_demand: float
|
||||
metrics_by_product: Optional[Dict[str, Any]]
|
||||
metrics_by_location: Optional[Dict[str, Any]]
|
||||
metrics_records_created: int
|
||||
error_message: Optional[str]
|
||||
triggered_by: str
|
||||
execution_mode: str
|
||||
|
||||
|
||||
class AccuracyTrendResponse(BaseModel):
|
||||
"""Response model for accuracy trends"""
|
||||
period_days: int
|
||||
total_runs: int
|
||||
average_mape: Optional[float]
|
||||
average_accuracy: Optional[float]
|
||||
trends: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("validation/validate-date-range"),
|
||||
response_model=ValidationResponse,
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def validate_date_range(
|
||||
validation_request: ValidationRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Validate forecasts against actual sales for a date range
|
||||
|
||||
This endpoint:
|
||||
- Fetches forecasts for the specified date range
|
||||
- Retrieves corresponding actual sales data
|
||||
- Calculates accuracy metrics (MAE, MAPE, RMSE, R², accuracy %)
|
||||
- Stores performance metrics in the database
|
||||
- Returns validation summary
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Starting date range validation",
|
||||
tenant_id=tenant_id,
|
||||
start_date=validation_request.start_date.isoformat(),
|
||||
end_date=validation_request.end_date.isoformat(),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
validation_service = ValidationService(db)
|
||||
|
||||
result = await validation_service.validate_date_range(
|
||||
tenant_id=tenant_id,
|
||||
start_date=validation_request.start_date,
|
||||
end_date=validation_request.end_date,
|
||||
orchestration_run_id=validation_request.orchestration_run_id,
|
||||
triggered_by=validation_request.triggered_by
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Date range validation completed",
|
||||
tenant_id=tenant_id,
|
||||
validation_run_id=result.get("validation_run_id"),
|
||||
forecasts_evaluated=result.get("forecasts_evaluated")
|
||||
)
|
||||
|
||||
return ValidationResponse(**result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to validate date range",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to validate forecasts: {str(e)}"
|
||||
)
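# For reference, a minimal sketch of the accuracy metrics listed in the docstring
# above, computed from paired lists of predicted and actual demand. The real
# calculations live in ValidationService; this toy version only makes the
# definitions concrete.
def _sketch_accuracy_metrics(predicted: List[float], actual: List[float]) -> Dict[str, float]:
    if not actual or len(predicted) != len(actual):
        return {}
    n = len(actual)
    errors = [p - a for p, a in zip(predicted, actual)]
    mae = sum(abs(e) for e in errors) / n
    rmse = (sum(e * e for e in errors) / n) ** 0.5
    nonzero = [(e, a) for e, a in zip(errors, actual) if a != 0]
    mape = (sum(abs(e) / abs(a) for e, a in nonzero) / len(nonzero) * 100) if nonzero else 0.0
    mean_actual = sum(actual) / n
    ss_tot = sum((a - mean_actual) ** 2 for a in actual)
    r2 = 1 - sum(e * e for e in errors) / ss_tot if ss_tot else 0.0
    return {
        "mae": mae,
        "mape": mape,
        "rmse": rmse,
        "r2_score": r2,
        "accuracy_percentage": max(0.0, 100.0 - mape),
    }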
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("validation/validate-yesterday"),
|
||||
response_model=ValidationResponse,
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def validate_yesterday(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
orchestration_run_id: Optional[UUID] = Query(None, description="Optional orchestration run ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Validate yesterday's forecasts against actual sales
|
||||
|
||||
Convenience endpoint for validating the most recent day's forecasts.
|
||||
This is typically called by the orchestrator as part of the daily workflow.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Starting yesterday validation",
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
validation_service = ValidationService(db)
|
||||
|
||||
result = await validation_service.validate_yesterday(
|
||||
tenant_id=tenant_id,
|
||||
orchestration_run_id=orchestration_run_id,
|
||||
triggered_by="manual"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Yesterday validation completed",
|
||||
tenant_id=tenant_id,
|
||||
validation_run_id=result.get("validation_run_id"),
|
||||
forecasts_evaluated=result.get("forecasts_evaluated")
|
||||
)
|
||||
|
||||
return ValidationResponse(**result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to validate yesterday",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to validate yesterday's forecasts: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("validation/runs/{validation_run_id}"),
|
||||
response_model=ValidationRunResponse,
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def get_validation_run(
|
||||
validation_run_id: UUID = Path(..., description="Validation run ID"),
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get details of a specific validation run
|
||||
|
||||
Returns complete information about a validation execution including:
|
||||
- Summary statistics
|
||||
- Overall accuracy metrics
|
||||
- Breakdown by product and location
|
||||
- Execution metadata
|
||||
"""
|
||||
try:
|
||||
validation_service = ValidationService(db)
|
||||
|
||||
validation_run = await validation_service.get_validation_run(validation_run_id)
|
||||
|
||||
if not validation_run:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Validation run {validation_run_id} not found"
|
||||
)
|
||||
|
||||
if validation_run.tenant_id != tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to this validation run"
|
||||
)
|
||||
|
||||
return ValidationRunResponse(**validation_run.to_dict())
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get validation run",
|
||||
validation_run_id=validation_run_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get validation run: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("validation/runs"),
|
||||
response_model=List[ValidationRunResponse],
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def get_validation_runs(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
|
||||
skip: int = Query(0, ge=0, description="Number of records to skip"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get validation runs for a tenant
|
||||
|
||||
Returns a list of validation executions with pagination support.
|
||||
"""
|
||||
try:
|
||||
validation_service = ValidationService(db)
|
||||
|
||||
runs = await validation_service.get_validation_runs_by_tenant(
|
||||
tenant_id=tenant_id,
|
||||
limit=limit,
|
||||
skip=skip
|
||||
)
|
||||
|
||||
return [ValidationRunResponse(**run.to_dict()) for run in runs]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get validation runs",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get validation runs: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("validation/trends"),
|
||||
response_model=AccuracyTrendResponse,
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def get_accuracy_trends(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
days: int = Query(30, ge=1, le=365, description="Number of days to analyze"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get accuracy trends over time
|
||||
|
||||
Returns validation accuracy metrics over the specified time period.
|
||||
Useful for monitoring model performance degradation and improvement.
|
||||
"""
|
||||
try:
|
||||
validation_service = ValidationService(db)
|
||||
|
||||
trends = await validation_service.get_accuracy_trends(
|
||||
tenant_id=tenant_id,
|
||||
days=days
|
||||
)
|
||||
|
||||
return AccuracyTrendResponse(**trends)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get accuracy trends",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get accuracy trends: {str(e)}"
|
||||
)
|
||||
174
services/forecasting/app/api/webhooks.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/webhooks.py
|
||||
# ================================================================
|
||||
"""
|
||||
Webhooks API - Receive events from other services
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status, Header
|
||||
from typing import Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
from datetime import date
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.jobs.sales_data_listener import (
|
||||
handle_sales_import_completion,
|
||||
handle_pos_sync_completion
|
||||
)
|
||||
from shared.routing import RouteBuilder
|
||||
|
||||
route_builder = RouteBuilder('forecasting')
|
||||
router = APIRouter(tags=["webhooks"])
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request Schemas
|
||||
# ================================================================
|
||||
|
||||
class SalesImportWebhook(BaseModel):
|
||||
"""Webhook payload for sales data import completion"""
|
||||
tenant_id: UUID = Field(..., description="Tenant ID")
|
||||
import_job_id: str = Field(..., description="Import job ID")
|
||||
start_date: date = Field(..., description="Start date of imported data")
|
||||
end_date: date = Field(..., description="End date of imported data")
|
||||
records_count: int = Field(..., ge=0, description="Number of records imported")
|
||||
import_source: str = Field(default="import", description="Source of import")
|
||||
|
||||
|
||||
class POSSyncWebhook(BaseModel):
|
||||
"""Webhook payload for POS sync completion"""
|
||||
tenant_id: UUID = Field(..., description="Tenant ID")
|
||||
sync_log_id: str = Field(..., description="POS sync log ID")
|
||||
sync_date: date = Field(..., description="Date of synced data")
|
||||
records_synced: int = Field(..., ge=0, description="Number of records synced")
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
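# A minimal sketch of the signature check referenced in the endpoints below. It
# assumes the sending service signs the raw request body with a shared secret using
# HMAC-SHA256 and sends the hex digest in the X-Webhook-Signature header; the secret
# variable name and signature format are assumptions, not part of this codebase.
# Note that the commented-out calls below pass the parsed payload; a real check
# would need access to the raw request body.
import hashlib
import hmac
import os


def verify_webhook_signature(signature: Optional[str], raw_body: bytes) -> bool:
    """Return True when the HMAC-SHA256 digest of the raw body matches the signature."""
    secret = os.getenv("FORECASTING_WEBHOOK_SECRET", "")
    if not signature or not secret:
        return False
    expected = hmac.new(secret.encode(), raw_body, hashlib.sha256).hexdigest()
    # constant-time comparison to avoid timing side channels
    return hmac.compare_digest(expected, signature)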
|
||||
|
||||
@router.post(
|
||||
"/webhooks/sales-import-completed",
|
||||
status_code=status.HTTP_202_ACCEPTED
|
||||
)
|
||||
async def sales_import_completed_webhook(
|
||||
payload: SalesImportWebhook,
|
||||
x_webhook_signature: Optional[str] = Header(None, description="Webhook signature for verification")
|
||||
):
|
||||
"""
|
||||
Webhook endpoint for sales data import completion
|
||||
|
||||
Called by the sales service when a data import completes.
|
||||
Triggers validation backfill for the imported date range.
|
||||
|
||||
Note: In production, this should verify the webhook signature
|
||||
to ensure the request comes from a trusted source.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Received sales import completion webhook",
|
||||
tenant_id=payload.tenant_id,
|
||||
import_job_id=payload.import_job_id,
|
||||
date_range=f"{payload.start_date} to {payload.end_date}"
|
||||
)
|
||||
|
||||
# In production, verify webhook signature here
|
||||
# if not verify_webhook_signature(x_webhook_signature, payload):
|
||||
# raise HTTPException(status_code=401, detail="Invalid webhook signature")
|
||||
|
||||
# Handle the import completion asynchronously
|
||||
result = await handle_sales_import_completion(
|
||||
tenant_id=payload.tenant_id,
|
||||
import_job_id=payload.import_job_id,
|
||||
start_date=payload.start_date,
|
||||
end_date=payload.end_date,
|
||||
records_count=payload.records_count,
|
||||
import_source=payload.import_source
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "accepted",
|
||||
"message": "Sales import completion event received and processing",
|
||||
"result": result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to process sales import webhook",
|
||||
payload=payload.dict(),
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to process webhook: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/webhooks/pos-sync-completed",
|
||||
status_code=status.HTTP_202_ACCEPTED
|
||||
)
|
||||
async def pos_sync_completed_webhook(
|
||||
payload: POSSyncWebhook,
|
||||
x_webhook_signature: Optional[str] = Header(None, description="Webhook signature for verification")
|
||||
):
|
||||
"""
|
||||
Webhook endpoint for POS sync completion
|
||||
|
||||
Called by the POS service when data synchronization completes.
|
||||
Triggers validation for the synced date.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Received POS sync completion webhook",
|
||||
tenant_id=payload.tenant_id,
|
||||
sync_log_id=payload.sync_log_id,
|
||||
sync_date=payload.sync_date.isoformat()
|
||||
)
|
||||
|
||||
# In production, verify webhook signature here
|
||||
# if not verify_webhook_signature(x_webhook_signature, payload):
|
||||
# raise HTTPException(status_code=401, detail="Invalid webhook signature")
|
||||
|
||||
# Handle the sync completion
|
||||
result = await handle_pos_sync_completion(
|
||||
tenant_id=payload.tenant_id,
|
||||
sync_log_id=payload.sync_log_id,
|
||||
sync_date=payload.sync_date,
|
||||
records_synced=payload.records_synced
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "accepted",
|
||||
"message": "POS sync completion event received and processing",
|
||||
"result": result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to process POS sync webhook",
|
||||
payload=payload.dict(),
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to process webhook: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/webhooks/health",
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
async def webhook_health_check():
|
||||
"""Health check endpoint for webhook receiver"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "forecasting-webhooks",
|
||||
"endpoints": [
|
||||
"/webhooks/sales-import-completed",
|
||||
"/webhooks/pos-sync-completed"
|
||||
]
|
||||
}
|
||||
253
services/forecasting/app/clients/ai_insights_client.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
AI Insights Service HTTP Client
|
||||
Posts insights from forecasting service to AI Insights Service
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from typing import Dict, List, Any, Optional
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
from datetime import datetime
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class AIInsightsClient:
|
||||
"""
|
||||
HTTP client for AI Insights Service.
|
||||
Allows forecasting service to post detected patterns and insights.
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str, timeout: int = 30):
|
||||
"""
|
||||
Initialize AI Insights client.
|
||||
|
||||
Args:
|
||||
base_url: Base URL of AI Insights Service (e.g., http://ai-insights-service:8000)
|
||||
timeout: Request timeout in seconds
|
||||
"""
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.timeout = timeout
|
||||
self.client = httpx.AsyncClient(timeout=self.timeout)
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client."""
|
||||
await self.client.aclose()
|
||||
|
||||
async def create_insight(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
insight_data: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Create a new insight in AI Insights Service.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
insight_data: Insight data dictionary
|
||||
|
||||
Returns:
|
||||
Created insight dict or None if failed
|
||||
"""
|
||||
url = f"{self.base_url}/api/v1/ai-insights/tenants/{tenant_id}/insights"
|
||||
|
||||
try:
|
||||
# Ensure tenant_id is in the data
|
||||
insight_data['tenant_id'] = str(tenant_id)
|
||||
|
||||
response = await self.client.post(url, json=insight_data)
|
||||
|
||||
if response.status_code == 201:
|
||||
logger.info(
|
||||
"Insight created successfully",
|
||||
tenant_id=str(tenant_id),
|
||||
insight_title=insight_data.get('title')
|
||||
)
|
||||
return response.json()
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to create insight",
|
||||
status_code=response.status_code,
|
||||
response=response.text,
|
||||
insight_title=insight_data.get('title')
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error creating insight",
|
||||
error=str(e),
|
||||
tenant_id=str(tenant_id)
|
||||
)
|
||||
return None
|
||||
|
||||
async def create_insights_bulk(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
insights: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create multiple insights in bulk.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
insights: List of insight data dictionaries
|
||||
|
||||
Returns:
|
||||
Dictionary with success/failure counts
|
||||
"""
|
||||
results = {
|
||||
'total': len(insights),
|
||||
'successful': 0,
|
||||
'failed': 0,
|
||||
'created_insights': []
|
||||
}
|
||||
|
||||
for insight_data in insights:
|
||||
result = await self.create_insight(tenant_id, insight_data)
|
||||
if result:
|
||||
results['successful'] += 1
|
||||
results['created_insights'].append(result)
|
||||
else:
|
||||
results['failed'] += 1
|
||||
|
||||
logger.info(
|
||||
"Bulk insight creation complete",
|
||||
total=results['total'],
|
||||
successful=results['successful'],
|
||||
failed=results['failed']
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def get_insights(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get insights for a tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
filters: Optional filters (category, priority, etc.)
|
||||
|
||||
Returns:
|
||||
Paginated insights response or None if failed
|
||||
"""
|
||||
url = f"{self.base_url}/api/v1/ai-insights/tenants/{tenant_id}/insights"
|
||||
|
||||
try:
|
||||
response = await self.client.get(url, params=filters or {})
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to get insights",
|
||||
status_code=response.status_code
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting insights", error=str(e))
|
||||
return None
|
||||
|
||||
async def get_orchestration_ready_insights(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
target_date: datetime,
|
||||
min_confidence: int = 70
|
||||
) -> Optional[Dict[str, List[Dict[str, Any]]]]:
|
||||
"""
|
||||
Get insights ready for orchestration workflow.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
target_date: Target date for orchestration
|
||||
min_confidence: Minimum confidence threshold
|
||||
|
||||
Returns:
|
||||
Categorized insights or None if failed
|
||||
"""
|
||||
url = f"{self.base_url}/api/v1/ai-insights/tenants/{tenant_id}/insights/orchestration-ready"
|
||||
|
||||
params = {
|
||||
'target_date': target_date.isoformat(),
|
||||
'min_confidence': min_confidence
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self.client.get(url, params=params)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to get orchestration insights",
|
||||
status_code=response.status_code
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting orchestration insights", error=str(e))
|
||||
return None
|
||||
|
||||
async def record_feedback(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
insight_id: UUID,
|
||||
feedback_data: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Record feedback for an applied insight.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
insight_id: Insight UUID
|
||||
feedback_data: Feedback data
|
||||
|
||||
Returns:
|
||||
Feedback response or None if failed
|
||||
"""
|
||||
url = f"{self.base_url}/api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}/feedback"
|
||||
|
||||
try:
|
||||
feedback_data['insight_id'] = str(insight_id)
|
||||
|
||||
response = await self.client.post(url, json=feedback_data)
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
logger.info(
|
||||
"Feedback recorded",
|
||||
insight_id=str(insight_id),
|
||||
success=feedback_data.get('success')
|
||||
)
|
||||
return response.json()
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to record feedback",
|
||||
status_code=response.status_code
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error recording feedback", error=str(e))
|
||||
return None
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
Check if AI Insights Service is healthy.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
url = f"{self.base_url}/health"
|
||||
|
||||
try:
|
||||
response = await self.client.get(url)
|
||||
return response.status_code == 200
|
||||
|
||||
except Exception as e:
|
||||
logger.error("AI Insights Service health check failed", error=str(e))
|
||||
return False
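# Minimal usage sketch (illustrative). The base URL and the insight fields shown
# here are assumptions for the example; the actual payload schema is owned by the
# AI Insights Service.
async def _example_post_insight(tenant_id: UUID) -> None:
    client = AIInsightsClient(base_url="http://ai-insights-service:8000")
    try:
        await client.create_insight(
            tenant_id,
            {
                "title": "Weekend demand spike detected",
                "category": "demand_pattern",
                "confidence": 82,
            },
        )
    finally:
        await client.close()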
|
||||
187
services/forecasting/app/consumers/forecast_event_consumer.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Forecast event consumer for the forecasting service
|
||||
Handles events that should trigger cache invalidation for aggregated forecasts
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
import json
|
||||
import redis.asyncio as redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ForecastEventConsumer:
|
||||
"""
|
||||
Consumer for forecast events that may trigger cache invalidation
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client: redis.Redis):
|
||||
self.redis_client = redis_client
|
||||
|
||||
async def handle_forecast_updated(self, event_data: Dict[str, Any]):
|
||||
"""
|
||||
Handle forecast updated event
|
||||
Invalidate parent tenant's aggregated forecast cache if this tenant is a child
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Handling forecast updated event: {event_data}")
|
||||
|
||||
tenant_id = event_data.get('tenant_id')
|
||||
forecast_date = event_data.get('forecast_date')
|
||||
product_id = event_data.get('product_id')
|
||||
updated_at = event_data.get('updated_at', None)
|
||||
|
||||
if not tenant_id:
|
||||
logger.error("Missing tenant_id in forecast event")
|
||||
return
|
||||
|
||||
# Check if this tenant is a child tenant (has parent)
|
||||
# In a real implementation, this would call the tenant service to check hierarchy
|
||||
parent_tenant_id = await self._get_parent_tenant_id(tenant_id)
|
||||
|
||||
if parent_tenant_id:
|
||||
# Invalidate parent's aggregated forecast cache
|
||||
await self._invalidate_parent_aggregated_cache(
|
||||
parent_tenant_id=parent_tenant_id,
|
||||
child_tenant_id=tenant_id,
|
||||
forecast_date=forecast_date,
|
||||
product_id=product_id
|
||||
)
|
||||
|
||||
logger.info(f"Forecast updated event processed for tenant {tenant_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling forecast updated event: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
    async def handle_forecast_created(self, event_data: Dict[str, Any]):
        """
        Handle forecast created event

        Similar to update, may affect parent tenant's aggregated forecasts
        """
        await self.handle_forecast_updated(event_data)

    async def handle_forecast_deleted(self, event_data: Dict[str, Any]):
        """
        Handle forecast deleted event

        Similar to update, may affect parent tenant's aggregated forecasts
        """
        try:
            logger.info(f"Handling forecast deleted event: {event_data}")

            tenant_id = event_data.get('tenant_id')
            forecast_date = event_data.get('forecast_date')
            product_id = event_data.get('product_id')

            if not tenant_id:
                logger.error("Missing tenant_id in forecast delete event")
                return

            # Check if this tenant is a child tenant
            parent_tenant_id = await self._get_parent_tenant_id(tenant_id)

            if parent_tenant_id:
                # Invalidate parent's aggregated forecast cache
                await self._invalidate_parent_aggregated_cache(
                    parent_tenant_id=parent_tenant_id,
                    child_tenant_id=tenant_id,
                    forecast_date=forecast_date,
                    product_id=product_id
                )

            logger.info(f"Forecast deleted event processed for tenant {tenant_id}")

        except Exception as e:
            logger.error(f"Error handling forecast deleted event: {e}", exc_info=True)
            raise

    async def _get_parent_tenant_id(self, tenant_id: str) -> Optional[str]:
        """
        Get parent tenant ID for a child tenant using the tenant service
        """
        try:
            from shared.clients.tenant_client import TenantServiceClient
            from shared.config.base import get_settings

            # Create tenant client
            config = get_settings()
            tenant_client = TenantServiceClient(config)

            # Get parent tenant information
            parent_tenant = await tenant_client.get_parent_tenant(tenant_id)

            if parent_tenant:
                parent_tenant_id = parent_tenant.get('id')
                logger.info(f"Found parent tenant {parent_tenant_id} for child tenant {tenant_id}")
                return parent_tenant_id
            else:
                logger.debug(f"No parent tenant found for tenant {tenant_id} (tenant may be standalone or parent)")
                return None

        except Exception as e:
            logger.error(f"Error getting parent tenant ID for {tenant_id}: {e}")
            return None

    async def _invalidate_parent_aggregated_cache(
        self,
        parent_tenant_id: str,
        child_tenant_id: str,
        forecast_date: Optional[str] = None,
        product_id: Optional[str] = None
    ):
        """
        Invalidate parent tenant's aggregated forecast cache
        """
        try:
            # Pattern to match all aggregated forecast cache keys for this parent
            # Format: agg_forecast:{parent_tenant_id}:{start_date}:{end_date}:{product_id}
            pattern = f"agg_forecast:{parent_tenant_id}:*:*:*"

            # Find all matching keys and delete them
            keys_to_delete = []
            async for key in self.redis_client.scan_iter(match=pattern):
                if isinstance(key, bytes):
                    key = key.decode('utf-8')
                keys_to_delete.append(key)

            if keys_to_delete:
                await self.redis_client.delete(*keys_to_delete)
                logger.info(f"Invalidated {len(keys_to_delete)} aggregated forecast cache entries for parent tenant {parent_tenant_id}")
            else:
                logger.info(f"No aggregated forecast cache entries found to invalidate for parent tenant {parent_tenant_id}")

        except Exception as e:
            logger.error(f"Error invalidating parent aggregated cache: {e}", exc_info=True)
            raise

    async def handle_tenant_hierarchy_changed(self, event_data: Dict[str, Any]):
        """
        Handle tenant hierarchy change event

        This could be when a tenant becomes a child of another, or when the hierarchy changes
        """
        try:
            logger.info(f"Handling tenant hierarchy change event: {event_data}")

            tenant_id = event_data.get('tenant_id')
            parent_tenant_id = event_data.get('parent_tenant_id')
            action = event_data.get('action')  # 'added', 'removed', 'changed'

            # Invalidate any cached aggregated forecasts that might be affected
            if parent_tenant_id:
                # If this child tenant changed, invalidate parent's cache
                await self._invalidate_parent_aggregated_cache(
                    parent_tenant_id=parent_tenant_id,
                    child_tenant_id=tenant_id
                )

            # If this was a former parent tenant that's no longer a parent,
            # its aggregated cache might need to be invalidated differently
            if action == 'removed' and event_data.get('was_parent'):
                # Invalidate its own aggregated cache since it's no longer a parent
                # This would be handled by tenant service events
                pass

        except Exception as e:
            logger.error(f"Error handling tenant hierarchy change event: {e}", exc_info=True)
            raise
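# Illustrative usage sketch (not part of this commit): feeding a deserialized
# forecast event into the handlers above. `listener` stands in for the event
# consumer instance that owns these methods; the payload keys shown are the
# ones the handlers actually read (tenant_id, forecast_date, product_id), and
# the identifier values are placeholders.
async def _example_process_forecast_deleted_event(listener) -> None:
    event_data = {
        "tenant_id": "child-tenant-uuid",    # placeholder identifier
        "forecast_date": "2024-05-01",
        "product_id": "product-uuid",        # placeholder identifier
    }
    # When the tenant is part of a hierarchy, the parent's agg_forecast:* cache
    # entries are invalidated; standalone tenants only produce log output.
    await listener.handle_forecast_deleted(event_data)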
0
services/forecasting/app/core/__init__.py
Normal file
0
services/forecasting/app/core/__init__.py
Normal file
81
services/forecasting/app/core/config.py
Normal file
81
services/forecasting/app/core/config.py
Normal file
@@ -0,0 +1,81 @@
# ================================================================
# FORECASTING SERVICE CONFIGURATION
# services/forecasting/app/core/config.py
# ================================================================

"""
Forecasting service configuration
Demand prediction and forecasting
"""

from shared.config.base import BaseServiceSettings
import os


class ForecastingSettings(BaseServiceSettings):
    """Forecasting service specific settings"""

    # Service Identity
    APP_NAME: str = "Forecasting Service"
    SERVICE_NAME: str = "forecasting-service"
    DESCRIPTION: str = "Demand prediction and forecasting service"

    # Database configuration (secure approach - build from components)
    @property
    def DATABASE_URL(self) -> str:
        """Build database URL from secure components"""
        # Try complete URL first (for backward compatibility)
        complete_url = os.getenv("FORECASTING_DATABASE_URL")
        if complete_url:
            return complete_url

        # Build from components (secure approach)
        user = os.getenv("FORECASTING_DB_USER", "forecasting_user")
        password = os.getenv("FORECASTING_DB_PASSWORD", "forecasting_pass123")
        host = os.getenv("FORECASTING_DB_HOST", "localhost")
        port = os.getenv("FORECASTING_DB_PORT", "5432")
        name = os.getenv("FORECASTING_DB_NAME", "forecasting_db")

        return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"

    # Redis Database (dedicated for prediction cache)
    REDIS_DB: int = 2

    # Prediction Configuration
    MAX_FORECAST_DAYS: int = int(os.getenv("MAX_FORECAST_DAYS", "30"))
    MIN_HISTORICAL_DAYS: int = int(os.getenv("MIN_HISTORICAL_DAYS", "60"))
    PREDICTION_CONFIDENCE_THRESHOLD: float = float(os.getenv("PREDICTION_CONFIDENCE_THRESHOLD", "0.8"))

    # Caching Configuration
    PREDICTION_CACHE_TTL_HOURS: int = int(os.getenv("PREDICTION_CACHE_TTL_HOURS", "6"))
    FORECAST_BATCH_SIZE: int = int(os.getenv("FORECAST_BATCH_SIZE", "100"))

    # MinIO Configuration
    MINIO_ENDPOINT: str = os.getenv("MINIO_ENDPOINT", "minio.bakery-ia.svc.cluster.local:9000")
    MINIO_ACCESS_KEY: str = os.getenv("FORECASTING_MINIO_ACCESS_KEY", "forecasting-service")
    MINIO_SECRET_KEY: str = os.getenv("FORECASTING_MINIO_SECRET_KEY", "forecasting-secret-key")
    MINIO_USE_SSL: bool = os.getenv("MINIO_USE_SSL", "true").lower() == "true"
    MINIO_MODEL_BUCKET: str = os.getenv("MINIO_MODEL_BUCKET", "training-models")
    MINIO_CONSOLE_PORT: str = os.getenv("MINIO_CONSOLE_PORT", "9001")
    MINIO_API_PORT: str = os.getenv("MINIO_API_PORT", "9000")
    MINIO_REGION: str = os.getenv("MINIO_REGION", "us-east-1")
    MINIO_MODEL_LIFECYCLE_DAYS: int = int(os.getenv("MINIO_MODEL_LIFECYCLE_DAYS", "90"))
    MINIO_CACHE_TTL_SECONDS: int = int(os.getenv("MINIO_CACHE_TTL_SECONDS", "3600"))

    # Real-time Forecasting
    REALTIME_FORECASTING_ENABLED: bool = os.getenv("REALTIME_FORECASTING_ENABLED", "true").lower() == "true"
    FORECAST_UPDATE_INTERVAL_HOURS: int = int(os.getenv("FORECAST_UPDATE_INTERVAL_HOURS", "6"))

    # Business Rules for Spanish Bakeries
    BUSINESS_HOUR_START: int = 7   # 7 AM
    BUSINESS_HOUR_END: int = 20    # 8 PM
    WEEKEND_ADJUSTMENT_FACTOR: float = float(os.getenv("WEEKEND_ADJUSTMENT_FACTOR", "0.8"))
    HOLIDAY_ADJUSTMENT_FACTOR: float = float(os.getenv("HOLIDAY_ADJUSTMENT_FACTOR", "0.5"))

    # Weather Impact Modeling
    WEATHER_IMPACT_ENABLED: bool = os.getenv("WEATHER_IMPACT_ENABLED", "true").lower() == "true"
    TEMPERATURE_THRESHOLD_COLD: float = float(os.getenv("TEMPERATURE_THRESHOLD_COLD", "10.0"))
    TEMPERATURE_THRESHOLD_HOT: float = float(os.getenv("TEMPERATURE_THRESHOLD_HOT", "30.0"))
    RAIN_IMPACT_FACTOR: float = float(os.getenv("RAIN_IMPACT_FACTOR", "0.7"))


settings = ForecastingSettings()
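# Illustrative sketch (not part of this commit): how DATABASE_URL resolves when
# only the individual FORECASTING_DB_* components are set. The values below are
# placeholders, not real credentials, and the sketch assumes the remaining
# base-service settings resolve from their defaults.
def _example_database_url_resolution() -> str:
    os.environ.pop("FORECASTING_DATABASE_URL", None)   # force the component-based build
    os.environ["FORECASTING_DB_USER"] = "forecasting_user"
    os.environ["FORECASTING_DB_PASSWORD"] = "example-password"
    os.environ["FORECASTING_DB_HOST"] = "forecasting-db"
    os.environ["FORECASTING_DB_PORT"] = "5432"
    os.environ["FORECASTING_DB_NAME"] = "forecasting_db"

    # -> "postgresql+asyncpg://forecasting_user:example-password@forecasting-db:5432/forecasting_db"
    return ForecastingSettings().DATABASE_URL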
121
services/forecasting/app/core/database.py
Normal file
121
services/forecasting/app/core/database.py
Normal file
@@ -0,0 +1,121 @@
# ================================================================
# services/forecasting/app/core/database.py
# ================================================================
"""
Database configuration for forecasting service
"""

import structlog
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy import text
from typing import AsyncGenerator

from app.core.config import settings
from shared.database.base import Base, DatabaseManager

logger = structlog.get_logger()

# Create async engine
async_engine = create_async_engine(
    settings.DATABASE_URL,
    echo=settings.DEBUG,
    pool_size=10,
    max_overflow=20,
    pool_pre_ping=True,
    pool_recycle=3600
)

# Create async session factory
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False
)


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Get database session"""
    async with AsyncSessionLocal() as session:
        try:
            yield session
        except Exception as e:
            await session.rollback()
            logger.error("Database session error", error=str(e))
            raise
        finally:
            await session.close()


async def init_database():
    """Initialize database tables"""
    try:
        async with async_engine.begin() as conn:
            # Import all models to ensure they are registered
            from app.models.forecast import ForecastBatch, Forecast
            from app.models.prediction import PredictionBatch, Prediction

            # Create all tables
            await conn.run_sync(Base.metadata.create_all)

        logger.info("Forecasting database initialized successfully")
    except Exception as e:
        logger.error("Failed to initialize forecasting database", error=str(e))
        raise


async def get_db_health() -> bool:
    """Check database health"""
    try:
        async with async_engine.begin() as conn:
            await conn.execute(text("SELECT 1"))
        return True
    except Exception as e:
        logger.error("Database health check failed", error=str(e))
        return False


async def get_connection_pool_stats() -> dict:
    """
    Get current connection pool statistics for monitoring.

    Returns:
        Dictionary with pool statistics including usage and capacity
    """
    try:
        pool = async_engine.pool

        # Get pool stats
        stats = {
            "pool_size": pool.size(),
            "checked_in_connections": pool.checkedin(),
            "checked_out_connections": pool.checkedout(),
            "overflow_connections": pool.overflow(),
            "total_connections": pool.size() + pool.overflow(),
            "max_capacity": 10 + 20,  # pool_size + max_overflow
            "usage_percentage": round(((pool.size() + pool.overflow()) / 30) * 100, 2)
        }

        # Add health status
        if stats["usage_percentage"] > 90:
            stats["status"] = "critical"
            stats["message"] = "Connection pool near capacity"
        elif stats["usage_percentage"] > 80:
            stats["status"] = "warning"
            stats["message"] = "Connection pool usage high"
        else:
            stats["status"] = "healthy"
            stats["message"] = "Connection pool healthy"

        return stats
    except Exception as e:
        logger.error("Failed to get connection pool stats", error=str(e))
        return {
            "status": "error",
            "message": f"Failed to get pool stats: {str(e)}"
        }


# Database manager instance for service_base compatibility
database_manager = DatabaseManager(
    database_url=settings.DATABASE_URL,
    service_name="forecasting-service",
    pool_size=10,
    max_overflow=20,
    echo=settings.DEBUG
)
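# Illustrative sketch (not part of this commit): surfacing the pool statistics
# above from a periodic monitoring task. The check mirrors the status values
# returned by get_connection_pool_stats(); how and where it is scheduled is an
# assumption left to the caller.
async def _example_log_pool_pressure() -> None:
    stats = await get_connection_pool_stats()
    if stats.get("status") in ("warning", "critical"):
        logger.warning(
            "Connection pool pressure detected",
            usage_percentage=stats.get("usage_percentage"),
            message=stats.get("message"),
        )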
29
services/forecasting/app/jobs/__init__.py
Normal file
29
services/forecasting/app/jobs/__init__.py
Normal file
@@ -0,0 +1,29 @@
"""
Forecasting Service Jobs Package
Scheduled and background jobs for the forecasting service
"""

from .daily_validation import daily_validation_job, validate_date_range_job
from .sales_data_listener import (
    handle_sales_import_completion,
    handle_pos_sync_completion,
    process_pending_validations
)
from .auto_backfill_job import (
    auto_backfill_all_tenants,
    process_all_pending_validations,
    daily_validation_maintenance_job,
    run_validation_maintenance_for_tenant
)

__all__ = [
    "daily_validation_job",
    "validate_date_range_job",
    "handle_sales_import_completion",
    "handle_pos_sync_completion",
    "process_pending_validations",
    "auto_backfill_all_tenants",
    "process_all_pending_validations",
    "daily_validation_maintenance_job",
    "run_validation_maintenance_for_tenant",
]
275
services/forecasting/app/jobs/auto_backfill_job.py
Normal file
275
services/forecasting/app/jobs/auto_backfill_job.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/jobs/auto_backfill_job.py
|
||||
# ================================================================
|
||||
"""
|
||||
Automated Backfill Job
|
||||
|
||||
Scheduled job to automatically detect and backfill validation gaps.
|
||||
Can be run daily or weekly to ensure all historical forecasts are validated.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.services.historical_validation_service import HistoricalValidationService
|
||||
from app.core.database import database_manager
|
||||
from app.jobs.sales_data_listener import process_pending_validations
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def auto_backfill_all_tenants(
|
||||
tenant_ids: List[uuid.UUID],
|
||||
lookback_days: int = 90,
|
||||
max_gaps_per_tenant: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run auto backfill for multiple tenants
|
||||
|
||||
Args:
|
||||
tenant_ids: List of tenant IDs to process
|
||||
lookback_days: How far back to check for gaps
|
||||
max_gaps_per_tenant: Maximum number of gaps to process per tenant
|
||||
|
||||
Returns:
|
||||
Summary of backfill operations across all tenants
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Starting auto backfill for all tenants",
|
||||
tenant_count=len(tenant_ids),
|
||||
lookback_days=lookback_days
|
||||
)
|
||||
|
||||
results = []
|
||||
total_gaps_found = 0
|
||||
total_gaps_processed = 0
|
||||
total_successful = 0
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
try:
|
||||
async with database_manager.get_session() as db:
|
||||
service = HistoricalValidationService(db)
|
||||
|
||||
result = await service.auto_backfill_gaps(
|
||||
tenant_id=tenant_id,
|
||||
lookback_days=lookback_days,
|
||||
max_gaps_to_process=max_gaps_per_tenant
|
||||
)
|
||||
|
||||
results.append({
|
||||
"tenant_id": str(tenant_id),
|
||||
"status": "success",
|
||||
**result
|
||||
})
|
||||
|
||||
total_gaps_found += result.get("gaps_found", 0)
|
||||
total_gaps_processed += result.get("gaps_processed", 0)
|
||||
total_successful += result.get("validations_completed", 0)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to auto backfill for tenant",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
results.append({
|
||||
"tenant_id": str(tenant_id),
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Auto backfill completed for all tenants",
|
||||
tenant_count=len(tenant_ids),
|
||||
total_gaps_found=total_gaps_found,
|
||||
total_gaps_processed=total_gaps_processed,
|
||||
total_successful=total_successful
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"tenants_processed": len(tenant_ids),
|
||||
"total_gaps_found": total_gaps_found,
|
||||
"total_gaps_processed": total_gaps_processed,
|
||||
"total_validations_completed": total_successful,
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Auto backfill job failed",
|
||||
error=str(e)
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
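# Illustrative sketch (not part of this commit): invoking the backfill above for
# a couple of tenants from an async entry point. The tenant UUIDs are
# placeholders; real callers would load them from the tenant service.
async def _example_run_auto_backfill() -> None:
    tenant_ids = [
        uuid.UUID("00000000-0000-0000-0000-000000000001"),
        uuid.UUID("00000000-0000-0000-0000-000000000002"),
    ]
    summary = await auto_backfill_all_tenants(
        tenant_ids=tenant_ids,
        lookback_days=90,
        max_gaps_per_tenant=5,
    )
    logger.info(
        "Auto backfill summary",
        status=summary.get("status"),
        gaps_found=summary.get("total_gaps_found"),
        validations_completed=summary.get("total_validations_completed"),
    )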
|
||||
|
||||
async def process_all_pending_validations(
|
||||
tenant_ids: List[uuid.UUID],
|
||||
max_per_tenant: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process all pending validations for multiple tenants
|
||||
|
||||
Args:
|
||||
tenant_ids: List of tenant IDs to process
|
||||
max_per_tenant: Maximum pending validations to process per tenant
|
||||
|
||||
Returns:
|
||||
Summary of processing results
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Processing pending validations for all tenants",
|
||||
tenant_count=len(tenant_ids)
|
||||
)
|
||||
|
||||
results = []
|
||||
total_pending = 0
|
||||
total_processed = 0
|
||||
total_successful = 0
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
try:
|
||||
result = await process_pending_validations(
|
||||
tenant_id=tenant_id,
|
||||
max_to_process=max_per_tenant
|
||||
)
|
||||
|
||||
results.append({
|
||||
"tenant_id": str(tenant_id),
|
||||
**result
|
||||
})
|
||||
|
||||
total_pending += result.get("pending_count", 0)
|
||||
total_processed += result.get("processed", 0)
|
||||
total_successful += result.get("successful", 0)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to process pending validations for tenant",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
results.append({
|
||||
"tenant_id": str(tenant_id),
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Pending validations processed for all tenants",
|
||||
tenant_count=len(tenant_ids),
|
||||
total_pending=total_pending,
|
||||
total_processed=total_processed,
|
||||
total_successful=total_successful
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"tenants_processed": len(tenant_ids),
|
||||
"total_pending": total_pending,
|
||||
"total_processed": total_processed,
|
||||
"total_successful": total_successful,
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to process all pending validations",
|
||||
error=str(e)
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
async def daily_validation_maintenance_job(
|
||||
tenant_ids: List[uuid.UUID]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Daily validation maintenance job
|
||||
|
||||
Combines gap detection/backfill and pending validation processing.
|
||||
Recommended to run once daily (e.g., 6:00 AM after orchestrator completes).
|
||||
|
||||
Args:
|
||||
tenant_ids: List of tenant IDs to process
|
||||
|
||||
Returns:
|
||||
Summary of all maintenance operations
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Starting daily validation maintenance",
|
||||
tenant_count=len(tenant_ids),
|
||||
timestamp=datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
|
||||
# Step 1: Process pending validations (retry failures)
|
||||
pending_result = await process_all_pending_validations(
|
||||
tenant_ids=tenant_ids,
|
||||
max_per_tenant=10
|
||||
)
|
||||
|
||||
# Step 2: Auto backfill detected gaps
|
||||
backfill_result = await auto_backfill_all_tenants(
|
||||
tenant_ids=tenant_ids,
|
||||
lookback_days=90,
|
||||
max_gaps_per_tenant=5
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Daily validation maintenance completed",
|
||||
pending_validations_processed=pending_result.get("total_processed", 0),
|
||||
gaps_backfilled=backfill_result.get("total_validations_completed", 0)
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"tenants_processed": len(tenant_ids),
|
||||
"pending_validations": pending_result,
|
||||
"gap_backfill": backfill_result,
|
||||
"summary": {
|
||||
"total_pending_processed": pending_result.get("total_processed", 0),
|
||||
"total_gaps_backfilled": backfill_result.get("total_validations_completed", 0),
|
||||
"total_validations": (
|
||||
pending_result.get("total_processed", 0) +
|
||||
backfill_result.get("total_validations_completed", 0)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Daily validation maintenance failed",
|
||||
error=str(e)
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
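# Illustrative sketch (not part of this commit): scheduling the maintenance job
# above at 06:00 daily, as suggested in its docstring. APScheduler and the
# Europe/Madrid timezone are assumptions here; the service's actual scheduler /
# orchestrator wiring is not shown in this file.
def _example_schedule_daily_maintenance(tenant_ids: List[uuid.UUID]) -> None:
    from apscheduler.schedulers.asyncio import AsyncIOScheduler
    from apscheduler.triggers.cron import CronTrigger

    scheduler = AsyncIOScheduler(timezone="Europe/Madrid")
    scheduler.add_job(
        daily_validation_maintenance_job,
        CronTrigger(hour=6, minute=0),
        args=[tenant_ids],
        id="daily_validation_maintenance",
        replace_existing=True,
    )
    scheduler.start()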
|
||||
# Convenience function for single tenant
|
||||
async def run_validation_maintenance_for_tenant(
|
||||
tenant_id: uuid.UUID
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run validation maintenance for a single tenant
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
|
||||
Returns:
|
||||
Maintenance results
|
||||
"""
|
||||
return await daily_validation_maintenance_job([tenant_id])
|
||||
147
services/forecasting/app/jobs/daily_validation.py
Normal file
147
services/forecasting/app/jobs/daily_validation.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/jobs/daily_validation.py
|
||||
# ================================================================
|
||||
"""
|
||||
Daily Validation Job
|
||||
|
||||
Scheduled job to validate previous day's forecasts against actual sales.
|
||||
This job is called by the orchestrator as part of the daily workflow.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.services.validation_service import ValidationService
|
||||
from app.core.database import database_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def daily_validation_job(
|
||||
tenant_id: uuid.UUID,
|
||||
orchestration_run_id: Optional[uuid.UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate yesterday's forecasts against actual sales
|
||||
|
||||
This function is designed to be called by the orchestrator as part of
|
||||
the daily workflow (Step 5: validate_previous_forecasts).
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
orchestration_run_id: Optional orchestration run ID for tracking
|
||||
|
||||
Returns:
|
||||
Dictionary with validation results
|
||||
"""
|
||||
async with database_manager.get_session() as db:
|
||||
try:
|
||||
logger.info(
|
||||
"Starting daily validation job",
|
||||
tenant_id=tenant_id,
|
||||
orchestration_run_id=orchestration_run_id
|
||||
)
|
||||
|
||||
validation_service = ValidationService(db)
|
||||
|
||||
# Validate yesterday's forecasts
|
||||
result = await validation_service.validate_yesterday(
|
||||
tenant_id=tenant_id,
|
||||
orchestration_run_id=orchestration_run_id,
|
||||
triggered_by="orchestrator"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Daily validation job completed",
|
||||
tenant_id=tenant_id,
|
||||
validation_run_id=result.get("validation_run_id"),
|
||||
forecasts_evaluated=result.get("forecasts_evaluated"),
|
||||
forecasts_with_actuals=result.get("forecasts_with_actuals"),
|
||||
overall_mape=result.get("overall_metrics", {}).get("mape")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Daily validation job failed",
|
||||
tenant_id=tenant_id,
|
||||
orchestration_run_id=orchestration_run_id,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"tenant_id": str(tenant_id),
|
||||
"orchestration_run_id": str(orchestration_run_id) if orchestration_run_id else None
|
||||
}
|
||||
|
||||
|
||||
async def validate_date_range_job(
|
||||
tenant_id: uuid.UUID,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
orchestration_run_id: Optional[uuid.UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate forecasts for a specific date range
|
||||
|
||||
Useful for backfilling validation metrics when historical data is uploaded.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
start_date: Start of validation period
|
||||
end_date: End of validation period
|
||||
orchestration_run_id: Optional orchestration run ID for tracking
|
||||
|
||||
Returns:
|
||||
Dictionary with validation results
|
||||
"""
|
||||
async with database_manager.get_session() as db:
|
||||
try:
|
||||
logger.info(
|
||||
"Starting date range validation job",
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date.isoformat(),
|
||||
end_date=end_date.isoformat(),
|
||||
orchestration_run_id=orchestration_run_id
|
||||
)
|
||||
|
||||
validation_service = ValidationService(db)
|
||||
|
||||
result = await validation_service.validate_date_range(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
orchestration_run_id=orchestration_run_id,
|
||||
triggered_by="scheduled"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Date range validation job completed",
|
||||
tenant_id=tenant_id,
|
||||
validation_run_id=result.get("validation_run_id"),
|
||||
forecasts_evaluated=result.get("forecasts_evaluated"),
|
||||
forecasts_with_actuals=result.get("forecasts_with_actuals")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Date range validation job failed",
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date.isoformat(),
|
||||
end_date=end_date.isoformat(),
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"tenant_id": str(tenant_id),
|
||||
"orchestration_run_id": str(orchestration_run_id) if orchestration_run_id else None
|
||||
}
|
||||
276
services/forecasting/app/jobs/sales_data_listener.py
Normal file
276
services/forecasting/app/jobs/sales_data_listener.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/jobs/sales_data_listener.py
|
||||
# ================================================================
|
||||
"""
|
||||
Sales Data Listener
|
||||
|
||||
Listens for sales data import completions and triggers validation backfill.
|
||||
Can be called via webhook, message queue, or direct API call from sales service.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, date
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.services.historical_validation_service import HistoricalValidationService
|
||||
from app.core.database import database_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def handle_sales_import_completion(
|
||||
tenant_id: uuid.UUID,
|
||||
import_job_id: str,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
records_count: int,
|
||||
import_source: str = "import"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle sales data import completion event
|
||||
|
||||
This function is called when the sales service completes a data import.
|
||||
It registers the update and triggers validation for the imported date range.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
import_job_id: Sales import job ID
|
||||
start_date: Start date of imported data
|
||||
end_date: End date of imported data
|
||||
records_count: Number of records imported
|
||||
import_source: Source of import (csv, xlsx, api, pos_sync)
|
||||
|
||||
Returns:
|
||||
Dictionary with registration and validation results
|
||||
"""
|
||||
async with database_manager.get_session() as db:
|
||||
try:
|
||||
logger.info(
|
||||
"Handling sales import completion",
|
||||
tenant_id=tenant_id,
|
||||
import_job_id=import_job_id,
|
||||
date_range=f"{start_date} to {end_date}",
|
||||
records_count=records_count
|
||||
)
|
||||
|
||||
service = HistoricalValidationService(db)
|
||||
|
||||
# Register the sales data update and trigger validation
|
||||
result = await service.register_sales_data_update(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
records_affected=records_count,
|
||||
update_source=import_source,
|
||||
import_job_id=import_job_id,
|
||||
auto_trigger_validation=True
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Sales import completion handled",
|
||||
tenant_id=tenant_id,
|
||||
import_job_id=import_job_id,
|
||||
update_id=result.get("update_id"),
|
||||
validation_triggered=result.get("validation_triggered")
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"tenant_id": str(tenant_id),
|
||||
"import_job_id": import_job_id,
|
||||
**result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to handle sales import completion",
|
||||
tenant_id=tenant_id,
|
||||
import_job_id=import_job_id,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"tenant_id": str(tenant_id),
|
||||
"import_job_id": import_job_id
|
||||
}
|
||||
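# Illustrative sketch (not part of this commit): exposing the handler above as a
# webhook endpoint, one of the integration styles mentioned in the module
# docstring. The route path, payload model, and use of pydantic v2's
# model_dump() are assumptions; the service's real webhook routes live in its
# API layer.
def _example_build_sales_import_webhook_router():
    from fastapi import APIRouter
    from pydantic import BaseModel

    router = APIRouter()

    class SalesImportCompletedPayload(BaseModel):
        tenant_id: uuid.UUID
        import_job_id: str
        start_date: date
        end_date: date
        records_count: int
        import_source: str = "import"

    @router.post("/internal/sales-import-completed")
    async def sales_import_completed(payload: SalesImportCompletedPayload):
        # Delegate straight to the listener function defined above
        return await handle_sales_import_completion(**payload.model_dump())

    return router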
|
||||
|
||||
async def handle_pos_sync_completion(
|
||||
tenant_id: uuid.UUID,
|
||||
sync_log_id: str,
|
||||
sync_date: date,
|
||||
records_synced: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle POS sync completion event
|
||||
|
||||
Called when POS data is synchronized to the sales service.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
sync_log_id: POS sync log ID
|
||||
sync_date: Date of synced data
|
||||
records_synced: Number of records synced
|
||||
|
||||
Returns:
|
||||
Dictionary with registration and validation results
|
||||
"""
|
||||
async with database_manager.get_session() as db:
|
||||
try:
|
||||
logger.info(
|
||||
"Handling POS sync completion",
|
||||
tenant_id=tenant_id,
|
||||
sync_log_id=sync_log_id,
|
||||
sync_date=sync_date.isoformat(),
|
||||
records_synced=records_synced
|
||||
)
|
||||
|
||||
service = HistoricalValidationService(db)
|
||||
|
||||
# For POS syncs, we typically validate just the sync date
|
||||
result = await service.register_sales_data_update(
|
||||
tenant_id=tenant_id,
|
||||
start_date=sync_date,
|
||||
end_date=sync_date,
|
||||
records_affected=records_synced,
|
||||
update_source="pos_sync",
|
||||
import_job_id=sync_log_id,
|
||||
auto_trigger_validation=True
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"POS sync completion handled",
|
||||
tenant_id=tenant_id,
|
||||
sync_log_id=sync_log_id,
|
||||
update_id=result.get("update_id")
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"tenant_id": str(tenant_id),
|
||||
"sync_log_id": sync_log_id,
|
||||
**result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to handle POS sync completion",
|
||||
tenant_id=tenant_id,
|
||||
sync_log_id=sync_log_id,
|
||||
error=str(e)
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"tenant_id": str(tenant_id),
|
||||
"sync_log_id": sync_log_id
|
||||
}
|
||||
|
||||
|
||||
async def process_pending_validations(
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
max_to_process: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process pending validation requests
|
||||
|
||||
Can be run as a scheduled job to handle any pending validations
|
||||
that failed to trigger automatically.
|
||||
|
||||
Args:
|
||||
tenant_id: Optional tenant ID to filter (process all tenants if None)
|
||||
max_to_process: Maximum number of pending validations to process
|
||||
|
||||
Returns:
|
||||
Summary of processing results
|
||||
"""
|
||||
async with database_manager.get_session() as db:
|
||||
try:
|
||||
logger.info(
|
||||
"Processing pending validations",
|
||||
tenant_id=tenant_id,
|
||||
max_to_process=max_to_process
|
||||
)
|
||||
|
||||
service = HistoricalValidationService(db)
|
||||
|
||||
if tenant_id:
|
||||
# Process specific tenant
|
||||
pending = await service.get_pending_validations(
|
||||
tenant_id=tenant_id,
|
||||
limit=max_to_process
|
||||
)
|
||||
else:
|
||||
# Would need to implement get_all_pending_validations for all tenants
|
||||
# For now, require tenant_id
|
||||
logger.warning("Processing all tenants not implemented, tenant_id required")
|
||||
return {
|
||||
"status": "skipped",
|
||||
"message": "tenant_id required"
|
||||
}
|
||||
|
||||
if not pending:
|
||||
logger.info("No pending validations found")
|
||||
return {
|
||||
"status": "success",
|
||||
"pending_count": 0,
|
||||
"processed": 0
|
||||
}
|
||||
|
||||
results = []
|
||||
for update_record in pending:
|
||||
try:
|
||||
result = await service.backfill_validation(
|
||||
tenant_id=update_record.tenant_id,
|
||||
start_date=update_record.update_date_start,
|
||||
end_date=update_record.update_date_end,
|
||||
triggered_by="pending_processor",
|
||||
sales_data_update_id=update_record.id
|
||||
)
|
||||
results.append({
|
||||
"update_id": str(update_record.id),
|
||||
"status": "success",
|
||||
"validation_run_id": result.get("validation_run_id")
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to process pending validation",
|
||||
update_id=update_record.id,
|
||||
error=str(e)
|
||||
)
|
||||
results.append({
|
||||
"update_id": str(update_record.id),
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
successful = sum(1 for r in results if r["status"] == "success")
|
||||
|
||||
logger.info(
|
||||
"Pending validations processed",
|
||||
pending_count=len(pending),
|
||||
processed=len(results),
|
||||
successful=successful
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"pending_count": len(pending),
|
||||
"processed": len(results),
|
||||
"successful": successful,
|
||||
"failed": len(results) - successful,
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to process pending validations",
|
||||
error=str(e)
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
208
services/forecasting/app/main.py
Normal file
208
services/forecasting/app/main.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/main.py
|
||||
# ================================================================
|
||||
"""
|
||||
Forecasting Service Main Application
|
||||
Demand prediction and forecasting service for bakery operations
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy import text
|
||||
from app.core.config import settings
|
||||
from app.core.database import database_manager
|
||||
from app.services.forecasting_alert_service import ForecastingAlertService
|
||||
from shared.service_base import StandardFastAPIService
|
||||
|
||||
# Import API routers
|
||||
from app.api import forecasts, forecasting_operations, analytics, scenario_operations, audit, ml_insights, validation, historical_validation, webhooks, performance_monitoring, retraining, enterprise_forecasting, internal_demo, forecast_feedback
|
||||
|
||||
|
||||
class ForecastingService(StandardFastAPIService):
|
||||
"""Forecasting Service with standardized setup"""
|
||||
|
||||
expected_migration_version = "00003"
|
||||
|
||||
async def verify_migrations(self):
|
||||
"""Verify database schema matches the latest migrations."""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
result = await session.execute(text("SELECT version_num FROM alembic_version"))
|
||||
version = result.scalar()
|
||||
if version != self.expected_migration_version:
|
||||
self.logger.error(f"Migration version mismatch: expected {self.expected_migration_version}, got {version}")
|
||||
raise RuntimeError(f"Migration version mismatch: expected {self.expected_migration_version}, got {version}")
|
||||
self.logger.info(f"Migration verification successful: {version}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Migration verification failed: {e}")
|
||||
raise
|
||||
|
||||
def __init__(self):
|
||||
# Define expected database tables for health checks
|
||||
forecasting_expected_tables = [
|
||||
'forecasts', 'prediction_batches', 'model_performance_metrics', 'prediction_cache', 'validation_runs', 'sales_data_updates'
|
||||
]
|
||||
|
||||
self.alert_service = None
|
||||
self.rabbitmq_client = None
|
||||
self.event_publisher = None
|
||||
|
||||
# Create custom checks for alert service
|
||||
async def alert_service_check():
|
||||
"""Custom health check for forecasting alert service"""
|
||||
return await self.alert_service.health_check() if self.alert_service else False
|
||||
|
||||
# Define custom metrics for forecasting service
|
||||
forecasting_custom_metrics = {
|
||||
"forecasts_generated_total": {
|
||||
"type": "counter",
|
||||
"description": "Total forecasts generated"
|
||||
},
|
||||
"predictions_served_total": {
|
||||
"type": "counter",
|
||||
"description": "Total predictions served"
|
||||
},
|
||||
"prediction_errors_total": {
|
||||
"type": "counter",
|
||||
"description": "Total prediction errors"
|
||||
},
|
||||
"forecast_processing_time_seconds": {
|
||||
"type": "histogram",
|
||||
"description": "Time to process forecast request"
|
||||
},
|
||||
"prediction_processing_time_seconds": {
|
||||
"type": "histogram",
|
||||
"description": "Time to process prediction request"
|
||||
},
|
||||
"model_cache_hits_total": {
|
||||
"type": "counter",
|
||||
"description": "Total model cache hits"
|
||||
},
|
||||
"model_cache_misses_total": {
|
||||
"type": "counter",
|
||||
"description": "Total model cache misses"
|
||||
}
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
service_name="forecasting-service",
|
||||
app_name="Bakery Forecasting Service",
|
||||
description="AI-powered demand prediction and forecasting service for bakery operations",
|
||||
version="1.0.0",
|
||||
log_level=settings.LOG_LEVEL,
|
||||
cors_origins=settings.CORS_ORIGINS_LIST,
|
||||
api_prefix="", # Empty because RouteBuilder already includes /api/v1
|
||||
database_manager=database_manager,
|
||||
expected_tables=forecasting_expected_tables,
|
||||
custom_health_checks={"alert_service": alert_service_check},
|
||||
enable_messaging=True,
|
||||
custom_metrics=forecasting_custom_metrics
|
||||
)
|
||||
|
||||
async def _setup_messaging(self):
|
||||
"""Setup messaging for forecasting service using unified messaging"""
|
||||
from shared.messaging import UnifiedEventPublisher, RabbitMQClient
|
||||
try:
|
||||
self.rabbitmq_client = RabbitMQClient(settings.RABBITMQ_URL, service_name="forecasting-service")
|
||||
await self.rabbitmq_client.connect()
|
||||
# Create unified event publisher
|
||||
self.event_publisher = UnifiedEventPublisher(self.rabbitmq_client, "forecasting-service")
|
||||
self.logger.info("Forecasting service unified messaging setup completed")
|
||||
except Exception as e:
|
||||
self.logger.error("Failed to setup forecasting unified messaging", error=str(e))
|
||||
raise
|
||||
|
||||
async def _cleanup_messaging(self):
|
||||
"""Cleanup messaging for forecasting service"""
|
||||
try:
|
||||
if self.rabbitmq_client:
|
||||
await self.rabbitmq_client.disconnect()
|
||||
self.logger.info("Forecasting service messaging cleanup completed")
|
||||
except Exception as e:
|
||||
self.logger.error("Error during forecasting messaging cleanup", error=str(e))
|
||||
|
||||
    async def on_startup(self, app: FastAPI):
        """Custom startup logic for forecasting service, including migration verification"""
        # Verify the schema before completing startup; this is the single
        # on_startup override (a duplicate definition would silently shadow it).
        await self.verify_migrations()
        await super().on_startup(app)
|
||||
|
||||
# Initialize forecasting alert service with EventPublisher
|
||||
if self.event_publisher:
|
||||
self.alert_service = ForecastingAlertService(self.event_publisher)
|
||||
await self.alert_service.start()
|
||||
self.logger.info("Forecasting alert service initialized")
|
||||
else:
|
||||
self.logger.error("Event publisher not initialized, alert service unavailable")
|
||||
|
||||
# Store the event publisher in app state for internal API access
|
||||
app.state.event_publisher = self.event_publisher
|
||||
|
||||
|
||||
async def on_shutdown(self, app: FastAPI):
|
||||
"""Custom shutdown logic for forecasting service"""
|
||||
# Cleanup alert service
|
||||
if self.alert_service:
|
||||
await self.alert_service.stop()
|
||||
self.logger.info("Alert service cleanup completed")
|
||||
|
||||
def get_service_features(self):
|
||||
"""Return forecasting-specific features"""
|
||||
return [
|
||||
"demand_prediction",
|
||||
"ai_forecasting",
|
||||
"model_performance_tracking",
|
||||
"prediction_caching",
|
||||
"alert_notifications",
|
||||
"messaging_integration"
|
||||
]
|
||||
|
||||
def setup_custom_endpoints(self):
|
||||
"""Setup custom endpoints for forecasting service"""
|
||||
@self.app.get("/alert-metrics")
|
||||
async def get_alert_metrics():
|
||||
"""Alert service metrics endpoint"""
|
||||
if self.alert_service:
|
||||
return self.alert_service.get_metrics()
|
||||
return {"error": "Alert service not initialized"}
|
||||
|
||||
|
||||
# Create service instance
|
||||
service = ForecastingService()
|
||||
|
||||
# Create FastAPI app with standardized setup
|
||||
app = service.create_app(
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc"
|
||||
)
|
||||
|
||||
# Setup standard endpoints
|
||||
service.setup_standard_endpoints()
|
||||
|
||||
# Setup custom endpoints
|
||||
service.setup_custom_endpoints()
|
||||
|
||||
# Include API routers
|
||||
# IMPORTANT: Register audit router FIRST to avoid route matching conflicts
|
||||
service.add_router(audit.router)
|
||||
service.add_router(forecasts.router)
|
||||
service.add_router(forecasting_operations.router)
|
||||
service.add_router(analytics.router)
|
||||
service.add_router(scenario_operations.router)
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"])
|
||||
service.add_router(ml_insights.router) # ML insights endpoint
|
||||
service.add_router(ml_insights.internal_router) # Internal ML insights endpoint
|
||||
service.add_router(validation.router) # Validation endpoint
|
||||
service.add_router(historical_validation.router) # Historical validation endpoint
|
||||
service.add_router(webhooks.router) # Webhooks endpoint
|
||||
service.add_router(performance_monitoring.router) # Performance monitoring endpoint
|
||||
service.add_router(retraining.router) # Retraining endpoint
|
||||
service.add_router(enterprise_forecasting.router) # Enterprise forecasting endpoint
|
||||
service.add_router(forecast_feedback.router) # Forecast feedback endpoint
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
11
services/forecasting/app/ml/__init__.py
Normal file
11
services/forecasting/app/ml/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
ML Components for Forecasting
|
||||
Machine learning prediction and forecasting components
|
||||
"""
|
||||
|
||||
from .predictor import BakeryPredictor, BakeryForecaster
|
||||
|
||||
__all__ = [
|
||||
"BakeryPredictor",
|
||||
"BakeryForecaster"
|
||||
]
|
||||
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
Business Rules Insights Orchestrator
|
||||
Coordinates business rules optimization and insight posting
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Any, Optional
|
||||
import structlog
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add shared clients to path
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../..'))
|
||||
from shared.clients.ai_insights_client import AIInsightsClient
|
||||
from shared.messaging import UnifiedEventPublisher
|
||||
|
||||
from app.ml.dynamic_rules_engine import DynamicRulesEngine
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class BusinessRulesInsightsOrchestrator:
|
||||
"""
|
||||
Orchestrates business rules analysis and insight generation workflow.
|
||||
|
||||
Workflow:
|
||||
1. Analyze dynamic business rule performance
|
||||
2. Generate insights for rule optimization
|
||||
3. Post insights to AI Insights Service
|
||||
4. Publish recommendation events to RabbitMQ
|
||||
5. Provide rule optimization for forecasting
|
||||
6. Track rule effectiveness and improvements
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ai_insights_base_url: str = "http://ai-insights-service:8000",
|
||||
event_publisher: Optional[UnifiedEventPublisher] = None
|
||||
):
|
||||
self.rules_engine = DynamicRulesEngine()
|
||||
self.ai_insights_client = AIInsightsClient(ai_insights_base_url)
|
||||
self.event_publisher = event_publisher
|
||||
|
||||
async def analyze_and_post_business_rules_insights(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
min_samples: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Complete workflow: Analyze business rules and post insights.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
sales_data: Historical sales data
|
||||
min_samples: Minimum samples for rule analysis
|
||||
|
||||
Returns:
|
||||
Workflow results with analysis and posted insights
|
||||
"""
|
||||
logger.info(
|
||||
"Starting business rules analysis workflow",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
samples=len(sales_data)
|
||||
)
|
||||
|
||||
# Step 1: Learn and analyze rules
|
||||
rules_results = await self.rules_engine.learn_all_rules(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
sales_data=sales_data,
|
||||
external_data=None,
|
||||
min_samples=min_samples
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Business rules analysis complete",
|
||||
insights_generated=len(rules_results.get('insights', [])),
|
||||
rules_learned=len(rules_results.get('rules', {}))
|
||||
)
|
||||
|
||||
# Step 2: Enrich insights with tenant_id and product context
|
||||
enriched_insights = self._enrich_insights(
|
||||
rules_results.get('insights', []),
|
||||
tenant_id,
|
||||
inventory_product_id
|
||||
)
|
||||
|
||||
# Step 3: Post insights to AI Insights Service
|
||||
if enriched_insights:
|
||||
post_results = await self.ai_insights_client.create_insights_bulk(
|
||||
tenant_id=UUID(tenant_id),
|
||||
insights=enriched_insights
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Business rules insights posted to AI Insights Service",
|
||||
inventory_product_id=inventory_product_id,
|
||||
total=post_results['total'],
|
||||
successful=post_results['successful'],
|
||||
failed=post_results['failed']
|
||||
)
|
||||
else:
|
||||
post_results = {'total': 0, 'successful': 0, 'failed': 0}
|
||||
logger.info("No insights to post for product", inventory_product_id=inventory_product_id)
|
||||
|
||||
# Step 4: Publish insight events to RabbitMQ
|
||||
created_insights = post_results.get('created_insights', [])
|
||||
if created_insights:
|
||||
product_context = {'inventory_product_id': inventory_product_id}
|
||||
await self._publish_insight_events(
|
||||
tenant_id=tenant_id,
|
||||
insights=created_insights,
|
||||
product_context=product_context
|
||||
)
|
||||
|
||||
# Step 5: Return comprehensive results
|
||||
return {
|
||||
'tenant_id': tenant_id,
|
||||
'inventory_product_id': inventory_product_id,
|
||||
'learned_at': rules_results['learned_at'],
|
||||
'rules': rules_results.get('rules', {}),
|
||||
'insights_generated': len(enriched_insights),
|
||||
'insights_posted': post_results['successful'],
|
||||
'insights_failed': post_results['failed'],
|
||||
'created_insights': post_results.get('created_insights', [])
|
||||
}
|
||||
|
||||
def _enrich_insights(
|
||||
self,
|
||||
insights: List[Dict[str, Any]],
|
||||
tenant_id: str,
|
||||
inventory_product_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Enrich insights with required fields for AI Insights Service.
|
||||
|
||||
Args:
|
||||
insights: Raw insights from rules engine
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
|
||||
Returns:
|
||||
Enriched insights ready for posting
|
||||
"""
|
||||
enriched = []
|
||||
|
||||
for insight in insights:
|
||||
# Add required tenant_id
|
||||
enriched_insight = insight.copy()
|
||||
enriched_insight['tenant_id'] = tenant_id
|
||||
|
||||
# Add product context to metrics
|
||||
if 'metrics_json' not in enriched_insight:
|
||||
enriched_insight['metrics_json'] = {}
|
||||
|
||||
enriched_insight['metrics_json']['inventory_product_id'] = inventory_product_id
|
||||
|
||||
# Add source metadata
|
||||
enriched_insight['source_service'] = 'forecasting'
|
||||
enriched_insight['source_model'] = 'dynamic_rules_engine'
|
||||
enriched_insight['detected_at'] = datetime.utcnow().isoformat()
|
||||
|
||||
enriched.append(enriched_insight)
|
||||
|
||||
return enriched
|
||||
|
||||
async def analyze_all_business_rules(
|
||||
self,
|
||||
tenant_id: str,
|
||||
products_data: Dict[str, pd.DataFrame],
|
||||
min_samples: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze all products for business rules optimization and generate comparative insights.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
products_data: Dict of {inventory_product_id: sales_data DataFrame}
|
||||
min_samples: Minimum samples for rule analysis
|
||||
|
||||
Returns:
|
||||
Comprehensive analysis with rule optimization insights
|
||||
"""
|
||||
logger.info(
|
||||
"Analyzing business rules for all products",
|
||||
tenant_id=tenant_id,
|
||||
products=len(products_data)
|
||||
)
|
||||
|
||||
all_results = []
|
||||
total_insights_posted = 0
|
||||
|
||||
# Analyze each product
|
||||
for inventory_product_id, sales_data in products_data.items():
|
||||
try:
|
||||
results = await self.analyze_and_post_business_rules_insights(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
sales_data=sales_data,
|
||||
min_samples=min_samples
|
||||
)
|
||||
|
||||
all_results.append(results)
|
||||
total_insights_posted += results['insights_posted']
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error analyzing business rules for product",
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
# Generate summary insight
|
||||
if total_insights_posted > 0:
|
||||
summary_insight = self._generate_portfolio_summary_insight(
|
||||
tenant_id, all_results
|
||||
)
|
||||
|
||||
if summary_insight:
|
||||
enriched_summary = self._enrich_insights(
|
||||
[summary_insight], tenant_id, 'all_products'
|
||||
)
|
||||
|
||||
post_results = await self.ai_insights_client.create_insights_bulk(
|
||||
tenant_id=UUID(tenant_id),
|
||||
insights=enriched_summary
|
||||
)
|
||||
|
||||
total_insights_posted += post_results['successful']
|
||||
|
||||
logger.info(
|
||||
"All business rules analysis complete",
|
||||
tenant_id=tenant_id,
|
||||
products_analyzed=len(all_results),
|
||||
total_insights_posted=total_insights_posted
|
||||
)
|
||||
|
||||
return {
|
||||
'tenant_id': tenant_id,
|
||||
'analyzed_at': datetime.utcnow().isoformat(),
|
||||
'products_analyzed': len(all_results),
|
||||
'product_results': all_results,
|
||||
'total_insights_posted': total_insights_posted
|
||||
}
|
||||
|
||||
def _generate_portfolio_summary_insight(
|
||||
self,
|
||||
tenant_id: str,
|
||||
all_results: List[Dict[str, Any]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Generate portfolio-level business rules summary insight.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
all_results: All product analysis results
|
||||
|
||||
Returns:
|
||||
Summary insight or None
|
||||
"""
|
||||
if not all_results:
|
||||
return None
|
||||
|
||||
# Calculate summary statistics
|
||||
total_products = len(all_results)
|
||||
total_rules = sum(len(r.get('rules', {})) for r in all_results)
|
||||
|
||||
# Count products with significant rule improvements
|
||||
significant_improvements = sum(1 for r in all_results
|
||||
if any('improvement' in str(v).lower() for v in r.get('rules', {}).values()))
|
||||
|
||||
return {
|
||||
'type': 'recommendation',
|
||||
'priority': 'high' if significant_improvements > total_products * 0.3 else 'medium',
|
||||
'category': 'forecasting',
|
||||
'title': f'Business Rule Optimization: {total_products} Products Analyzed',
|
||||
'description': f'Learned {total_rules} dynamic rules across {total_products} products. Identified {significant_improvements} products with significant rule improvements.',
|
||||
'impact_type': 'operational_efficiency',
|
||||
'impact_value': total_rules,
|
||||
'impact_unit': 'rules',
|
||||
'confidence': 80,
|
||||
'metrics_json': {
|
||||
'total_products': total_products,
|
||||
'total_rules': total_rules,
|
||||
'significant_improvements': significant_improvements,
|
||||
'rules_per_product': round(total_rules / total_products, 2)
|
||||
},
|
||||
'actionable': True,
|
||||
'recommendation_actions': [
|
||||
{
|
||||
'label': 'Review Learned Rules',
|
||||
'action': 'review_business_rules',
|
||||
'params': {'tenant_id': tenant_id}
|
||||
},
|
||||
{
|
||||
'label': 'Implement Optimized Rules',
|
||||
'action': 'implement_business_rules',
|
||||
'params': {'tenant_id': tenant_id}
|
||||
}
|
||||
],
|
||||
'source_service': 'forecasting',
|
||||
'source_model': 'dynamic_rules_engine'
|
||||
}
|
||||
|
||||
async def get_learned_rules(
|
||||
self,
|
||||
inventory_product_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get cached learned rules for a product.
|
||||
|
||||
Args:
|
||||
inventory_product_id: Product identifier
|
||||
|
||||
Returns:
|
||||
Learned rules or None if not analyzed
|
||||
"""
|
||||
return self.rules_engine.get_all_rules(inventory_product_id)
|
||||
|
||||
async def _publish_insight_events(self, tenant_id, insights, product_context=None):
|
||||
"""
|
||||
Publish insight events to RabbitMQ for alert processing.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
insights: List of created insights
|
||||
product_context: Additional context about the product
|
||||
"""
|
||||
if not self.event_publisher:
|
||||
logger.warning("No event publisher available for business rules insights")
|
||||
return
|
||||
|
||||
for insight in insights:
|
||||
# Determine severity based on confidence and priority
|
||||
confidence = insight.get('confidence', 0)
|
||||
priority = insight.get('priority', 'medium')
|
||||
|
||||
# Map priority to severity, with confidence as tiebreaker
|
||||
if priority == 'critical' or (priority == 'high' and confidence >= 70):
|
||||
severity = 'high'
|
||||
elif priority == 'high' or (priority == 'medium' and confidence >= 80):
|
||||
severity = 'medium'
|
||||
else:
|
||||
severity = 'low'
|
||||
|
||||
# Prepare the event data
|
||||
event_data = {
|
||||
'insight_id': insight.get('id'),
|
||||
'type': insight.get('type'),
|
||||
'title': insight.get('title'),
|
||||
'description': insight.get('description'),
|
||||
'category': insight.get('category'),
|
||||
'priority': insight.get('priority'),
|
||||
'confidence': confidence,
|
||||
'recommendation': insight.get('recommendation_actions', []),
|
||||
'impact_type': insight.get('impact_type'),
|
||||
'impact_value': insight.get('impact_value'),
|
||||
'inventory_product_id': product_context.get('inventory_product_id') if product_context else None,
|
||||
'timestamp': insight.get('detected_at', datetime.utcnow().isoformat()),
|
||||
'source_service': 'forecasting',
|
||||
'source_model': 'dynamic_rules_engine'
|
||||
}
|
||||
|
||||
try:
|
||||
await self.event_publisher.publish_recommendation(
|
||||
event_type='ai_business_rule',
|
||||
tenant_id=tenant_id,
|
||||
severity=severity,
|
||||
data=event_data
|
||||
)
|
||||
logger.info(
|
||||
"Published business rules insight event",
|
||||
tenant_id=tenant_id,
|
||||
insight_id=insight.get('id'),
|
||||
severity=severity
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to publish business rules insight event",
|
||||
tenant_id=tenant_id,
|
||||
insight_id=insight.get('id'),
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Close HTTP client connections."""
|
||||
await self.ai_insights_client.close()
|
||||
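# Illustrative usage sketch (not part of this commit): running the workflow
# described in the class docstring for a single product. The tenant/product
# identifiers and the sales dataframe columns are placeholders; the real caller
# is expected to pass genuine historical sales data in whatever shape the
# DynamicRulesEngine expects.
async def _example_run_business_rules_workflow() -> Dict[str, Any]:
    orchestrator = BusinessRulesInsightsOrchestrator(
        ai_insights_base_url="http://ai-insights-service:8000",
        event_publisher=None,  # without a publisher, step 4 (RabbitMQ events) logs a warning and is skipped
    )
    try:
        sales_data = pd.DataFrame(
            {
                "date": pd.date_range("2024-01-01", periods=60, freq="D"),
                "quantity": [20 + (i % 7) * 3 for i in range(60)],  # placeholder demand pattern
            }
        )
        return await orchestrator.analyze_and_post_business_rules_insights(
            tenant_id="00000000-0000-0000-0000-000000000001",
            inventory_product_id="00000000-0000-0000-0000-0000000000aa",
            sales_data=sales_data,
            min_samples=10,
        )
    finally:
        await orchestrator.close()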
235
services/forecasting/app/ml/calendar_features.py
Normal file
235
services/forecasting/app/ml/calendar_features.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
Calendar-based Feature Engineering for Forecasting Service
|
||||
Generates calendar features for future date predictions
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import structlog
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime, date, time, timedelta
|
||||
from app.services.data_client import data_client
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ForecastCalendarFeatures:
|
||||
"""
|
||||
Generates calendar-based features for future predictions
|
||||
Optimized for forecasting service (future dates only)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.calendar_cache = {} # Cache calendar data per tenant
|
||||
|
||||
async def get_calendar_for_tenant(
|
||||
self,
|
||||
tenant_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get cached calendar for tenant"""
|
||||
if tenant_id in self.calendar_cache:
|
||||
return self.calendar_cache[tenant_id]
|
||||
|
||||
calendar = await data_client.fetch_tenant_calendar(tenant_id)
|
||||
if calendar:
|
||||
self.calendar_cache[tenant_id] = calendar
|
||||
|
||||
return calendar
|
||||
|
||||
def _is_date_in_holiday_period(
|
||||
self,
|
||||
check_date: date,
|
||||
holiday_periods: List[Dict[str, Any]]
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
"""Check if date is within any holiday period"""
|
||||
for period in holiday_periods:
|
||||
start = datetime.strptime(period["start_date"], "%Y-%m-%d").date()
|
||||
end = datetime.strptime(period["end_date"], "%Y-%m-%d").date()
|
||||
|
||||
if start <= check_date <= end:
|
||||
return True, period["name"]
|
||||
|
||||
return False, None
|
||||
|
||||
def _is_school_hours_active(
|
||||
self,
|
||||
check_datetime: datetime,
|
||||
school_hours: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Check if datetime falls during school operating hours"""
|
||||
# Only weekdays
|
||||
if check_datetime.weekday() >= 5:
|
||||
return False
|
||||
|
||||
check_time = check_datetime.time()
|
||||
|
||||
# Morning session
|
||||
morning_start = datetime.strptime(
|
||||
school_hours["morning_start"], "%H:%M"
|
||||
).time()
|
||||
morning_end = datetime.strptime(
|
||||
school_hours["morning_end"], "%H:%M"
|
||||
).time()
|
||||
|
||||
if morning_start <= check_time <= morning_end:
|
||||
return True
|
||||
|
||||
# Afternoon session if exists
|
||||
if school_hours.get("has_afternoon_session", False):
|
||||
afternoon_start = datetime.strptime(
|
||||
school_hours["afternoon_start"], "%H:%M"
|
||||
).time()
|
||||
afternoon_end = datetime.strptime(
|
||||
school_hours["afternoon_end"], "%H:%M"
|
||||
).time()
|
||||
|
||||
if afternoon_start <= check_time <= afternoon_end:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _calculate_school_proximity_intensity(
|
||||
self,
|
||||
check_datetime: datetime,
|
||||
school_hours: Dict[str, Any]
|
||||
) -> float:
|
||||
"""
|
||||
Calculate school proximity impact intensity
|
||||
Returns 0.0-1.0 based on drop-off/pick-up times
|
||||
"""
|
||||
# Only weekdays
|
||||
if check_datetime.weekday() >= 5:
|
||||
return 0.0
|
||||
|
||||
check_time = check_datetime.time()
|
||||
|
||||
morning_start = datetime.strptime(
|
||||
school_hours["morning_start"], "%H:%M"
|
||||
).time()
|
||||
morning_end = datetime.strptime(
|
||||
school_hours["morning_end"], "%H:%M"
|
||||
).time()
|
||||
|
||||
# Morning drop-off peak (30 min before to 15 min after start)
|
||||
drop_off_start = (
|
||||
datetime.combine(date.today(), morning_start) - timedelta(minutes=30)
|
||||
).time()
|
||||
drop_off_end = (
|
||||
datetime.combine(date.today(), morning_start) + timedelta(minutes=15)
|
||||
).time()
|
||||
|
||||
if drop_off_start <= check_time <= drop_off_end:
|
||||
return 1.0 # Peak
|
||||
|
||||
# Morning pick-up peak (15 min before to 30 min after end)
|
||||
pickup_start = (
|
||||
datetime.combine(date.today(), morning_end) - timedelta(minutes=15)
|
||||
).time()
|
||||
pickup_end = (
|
||||
datetime.combine(date.today(), morning_end) + timedelta(minutes=30)
|
||||
).time()
|
||||
|
||||
if pickup_start <= check_time <= pickup_end:
|
||||
return 1.0 # Peak
|
||||
|
||||
# During school hours (moderate)
|
||||
if morning_start <= check_time <= morning_end:
|
||||
return 0.3
|
||||
|
||||
return 0.0
|
||||
|
||||
async def add_calendar_features(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
tenant_id: str,
|
||||
date_column: str = "ds"
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Add calendar features to forecast dataframe
|
||||
|
||||
Args:
|
||||
df: Forecast dataframe with future dates
|
||||
tenant_id: Tenant ID to fetch calendar
|
||||
date_column: Name of date column (default 'ds' for Prophet)
|
||||
|
||||
Returns:
|
||||
DataFrame with calendar features added
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Adding calendar features to forecast",
|
||||
tenant_id=tenant_id,
|
||||
rows=len(df)
|
||||
)
|
||||
|
||||
# Get calendar
|
||||
calendar = await self.get_calendar_for_tenant(tenant_id)
|
||||
|
||||
if not calendar:
|
||||
logger.info(
|
||||
"No calendar available, using zero features",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
df["is_school_holiday"] = 0
|
||||
df["school_hours_active"] = 0
|
||||
df["school_proximity_intensity"] = 0.0
|
||||
return df
|
||||
|
||||
holiday_periods = calendar.get("holiday_periods", [])
|
||||
school_hours = calendar.get("school_hours", {})
|
||||
|
||||
# Initialize feature lists
|
||||
school_holidays = []
|
||||
hours_active = []
|
||||
proximity_intensity = []
|
||||
|
||||
# Process each row
|
||||
for idx, row in df.iterrows():
|
||||
row_date = pd.to_datetime(row[date_column])
|
||||
|
||||
# Check holiday
|
||||
is_holiday, _ = self._is_date_in_holiday_period(
|
||||
row_date.date(),
|
||||
holiday_periods
|
||||
)
|
||||
school_holidays.append(1 if is_holiday else 0)
|
||||
|
||||
                # Check school hours and proximity only when the timestamp carries a real
                # time component (pd.Timestamp always exposes .hour, so test the time value
                # itself rather than using hasattr)
                if row_date.time() != time(0, 0):
|
||||
hours_active.append(
|
||||
1 if self._is_school_hours_active(row_date, school_hours) else 0
|
||||
)
|
||||
proximity_intensity.append(
|
||||
self._calculate_school_proximity_intensity(row_date, school_hours)
|
||||
)
|
||||
else:
|
||||
hours_active.append(0)
|
||||
proximity_intensity.append(0.0)
|
||||
|
||||
# Add features
|
||||
df["is_school_holiday"] = school_holidays
|
||||
df["school_hours_active"] = hours_active
|
||||
df["school_proximity_intensity"] = proximity_intensity
|
||||
|
||||
logger.info(
|
||||
"Calendar features added to forecast",
|
||||
tenant_id=tenant_id,
|
||||
holidays_in_forecast=sum(school_holidays)
|
||||
)
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error adding calendar features to forecast",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
# Return with zero features on error
|
||||
df["is_school_holiday"] = 0
|
||||
df["school_hours_active"] = 0
|
||||
df["school_proximity_intensity"] = 0.0
|
||||
return df
|
||||
|
||||
|
||||
# Global instance
|
||||
forecast_calendar_features = ForecastCalendarFeatures()
|
||||
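A minimal usage sketch for the calendar feature module above, assuming a Prophet-style future dataframe; the tenant id and date range are illustrative, and the shared data_client must be reachable for real calendar data, otherwise the zero-feature fallback shown above applies.

# Illustrative usage only; the tenant id and date range are invented for the example.
import asyncio

import pandas as pd

from app.ml.calendar_features import forecast_calendar_features


async def demo() -> pd.DataFrame:
    # Prophet-style future frame: one row per future date in column 'ds'
    future = pd.DataFrame({"ds": pd.date_range("2025-01-06", periods=7, freq="D")})
    # Adds is_school_holiday, school_hours_active, school_proximity_intensity
    return await forecast_calendar_features.add_calendar_features(future, tenant_id="tenant-123")


if __name__ == "__main__":
    print(asyncio.run(demo())[["ds", "is_school_holiday", "school_proximity_intensity"]])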
403
services/forecasting/app/ml/demand_insights_orchestrator.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""
|
||||
Demand Insights Orchestrator
|
||||
Coordinates demand forecasting analysis and insight posting
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Any, Optional
|
||||
import structlog
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add shared clients to path
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../..'))
|
||||
from shared.clients.ai_insights_client import AIInsightsClient
|
||||
from shared.messaging import UnifiedEventPublisher
|
||||
|
||||
from app.ml.predictor import BakeryForecaster
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class DemandInsightsOrchestrator:
|
||||
"""
|
||||
Orchestrates demand forecasting analysis and insight generation workflow.
|
||||
|
||||
Workflow:
|
||||
1. Analyze historical demand patterns from sales data
|
||||
2. Generate insights for demand optimization
|
||||
3. Post insights to AI Insights Service
|
||||
4. Publish recommendation events to RabbitMQ
|
||||
5. Provide demand pattern analysis for forecasting
|
||||
6. Track demand forecasting performance
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ai_insights_base_url: str = "http://ai-insights-service:8000",
|
||||
event_publisher: Optional[UnifiedEventPublisher] = None
|
||||
):
|
||||
self.forecaster = BakeryForecaster()
|
||||
self.ai_insights_client = AIInsightsClient(ai_insights_base_url)
|
||||
self.event_publisher = event_publisher
|
||||
|
||||
async def analyze_and_post_demand_insights(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
forecast_horizon_days: int = 30,
|
||||
min_history_days: int = 90
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Complete workflow: Analyze demand and post insights.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
sales_data: Historical sales data
|
||||
forecast_horizon_days: Days to forecast ahead
|
||||
min_history_days: Minimum days of history required
|
||||
|
||||
Returns:
|
||||
Workflow results with analysis and posted insights
|
||||
"""
|
||||
logger.info(
|
||||
"Starting demand forecasting analysis workflow",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
history_days=len(sales_data)
|
||||
)
|
||||
|
||||
# Step 1: Analyze demand patterns
|
||||
analysis_results = await self.forecaster.analyze_demand_patterns(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
sales_data=sales_data,
|
||||
forecast_horizon_days=forecast_horizon_days,
|
||||
min_history_days=min_history_days
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Demand analysis complete",
|
||||
inventory_product_id=inventory_product_id,
|
||||
insights_generated=len(analysis_results.get('insights', []))
|
||||
)
|
||||
|
||||
# Step 2: Enrich insights with tenant_id and product context
|
||||
enriched_insights = self._enrich_insights(
|
||||
analysis_results.get('insights', []),
|
||||
tenant_id,
|
||||
inventory_product_id
|
||||
)
|
||||
|
||||
# Step 3: Post insights to AI Insights Service
|
||||
if enriched_insights:
|
||||
post_results = await self.ai_insights_client.create_insights_bulk(
|
||||
tenant_id=UUID(tenant_id),
|
||||
insights=enriched_insights
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Demand insights posted to AI Insights Service",
|
||||
inventory_product_id=inventory_product_id,
|
||||
total=post_results['total'],
|
||||
successful=post_results['successful'],
|
||||
failed=post_results['failed']
|
||||
)
|
||||
else:
|
||||
post_results = {'total': 0, 'successful': 0, 'failed': 0}
|
||||
logger.info("No insights to post for product", inventory_product_id=inventory_product_id)
|
||||
|
||||
# Step 4: Publish insight events to RabbitMQ
|
||||
created_insights = post_results.get('created_insights', [])
|
||||
if created_insights:
|
||||
product_context = {'inventory_product_id': inventory_product_id}
|
||||
await self._publish_insight_events(
|
||||
tenant_id=tenant_id,
|
||||
insights=created_insights,
|
||||
product_context=product_context
|
||||
)
|
||||
|
||||
# Step 5: Return comprehensive results
|
||||
return {
|
||||
'tenant_id': tenant_id,
|
||||
'inventory_product_id': inventory_product_id,
|
||||
'analyzed_at': analysis_results['analyzed_at'],
|
||||
'history_days': analysis_results['history_days'],
|
||||
'demand_patterns': analysis_results.get('patterns', {}),
|
||||
'trend_analysis': analysis_results.get('trend_analysis', {}),
|
||||
'seasonal_factors': analysis_results.get('seasonal_factors', {}),
|
||||
'insights_generated': len(enriched_insights),
|
||||
'insights_posted': post_results['successful'],
|
||||
'insights_failed': post_results['failed'],
|
||||
'created_insights': post_results.get('created_insights', [])
|
||||
}
|
||||
|
||||
def _enrich_insights(
|
||||
self,
|
||||
insights: List[Dict[str, Any]],
|
||||
tenant_id: str,
|
||||
inventory_product_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Enrich insights with required fields for AI Insights Service.
|
||||
|
||||
Args:
|
||||
insights: Raw insights from forecaster
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
|
||||
Returns:
|
||||
Enriched insights ready for posting
|
||||
"""
|
||||
enriched = []
|
||||
|
||||
for insight in insights:
|
||||
# Add required tenant_id
|
||||
enriched_insight = insight.copy()
|
||||
enriched_insight['tenant_id'] = tenant_id
|
||||
|
||||
# Add product context to metrics
|
||||
if 'metrics_json' not in enriched_insight:
|
||||
enriched_insight['metrics_json'] = {}
|
||||
|
||||
enriched_insight['metrics_json']['inventory_product_id'] = inventory_product_id
|
||||
|
||||
# Add source metadata
|
||||
enriched_insight['source_service'] = 'forecasting'
|
||||
enriched_insight['source_model'] = 'demand_analyzer'
|
||||
enriched_insight['detected_at'] = datetime.utcnow().isoformat()
|
||||
|
||||
enriched.append(enriched_insight)
|
||||
|
||||
return enriched
|
||||
|
||||
async def analyze_all_products(
|
||||
self,
|
||||
tenant_id: str,
|
||||
products_data: Dict[str, pd.DataFrame],
|
||||
forecast_horizon_days: int = 30,
|
||||
min_history_days: int = 90
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze all products for a tenant and generate comparative insights.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
products_data: Dict of {inventory_product_id: sales_data DataFrame}
|
||||
forecast_horizon_days: Days to forecast
|
||||
min_history_days: Minimum history required
|
||||
|
||||
Returns:
|
||||
Comprehensive analysis with product comparison
|
||||
"""
|
||||
logger.info(
|
||||
"Analyzing all products for tenant",
|
||||
tenant_id=tenant_id,
|
||||
products=len(products_data)
|
||||
)
|
||||
|
||||
all_results = []
|
||||
total_insights_posted = 0
|
||||
|
||||
# Analyze each product
|
||||
for inventory_product_id, sales_data in products_data.items():
|
||||
try:
|
||||
results = await self.analyze_and_post_demand_insights(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
sales_data=sales_data,
|
||||
forecast_horizon_days=forecast_horizon_days,
|
||||
min_history_days=min_history_days
|
||||
)
|
||||
|
||||
all_results.append(results)
|
||||
total_insights_posted += results['insights_posted']
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error analyzing product",
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
# Generate summary insight
|
||||
if total_insights_posted > 0:
|
||||
summary_insight = self._generate_portfolio_summary_insight(
|
||||
tenant_id, all_results
|
||||
)
|
||||
|
||||
if summary_insight:
|
||||
enriched_summary = self._enrich_insights(
|
||||
[summary_insight], tenant_id, 'all_products'
|
||||
)
|
||||
|
||||
post_results = await self.ai_insights_client.create_insights_bulk(
|
||||
tenant_id=UUID(tenant_id),
|
||||
insights=enriched_summary
|
||||
)
|
||||
|
||||
total_insights_posted += post_results['successful']
|
||||
|
||||
logger.info(
|
||||
"All products analysis complete",
|
||||
tenant_id=tenant_id,
|
||||
products_analyzed=len(all_results),
|
||||
total_insights_posted=total_insights_posted
|
||||
)
|
||||
|
||||
return {
|
||||
'tenant_id': tenant_id,
|
||||
'analyzed_at': datetime.utcnow().isoformat(),
|
||||
'products_analyzed': len(all_results),
|
||||
'product_results': all_results,
|
||||
'total_insights_posted': total_insights_posted
|
||||
}
|
||||
|
||||
def _generate_portfolio_summary_insight(
|
||||
self,
|
||||
tenant_id: str,
|
||||
all_results: List[Dict[str, Any]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Generate portfolio-level summary insight.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
all_results: All product analysis results
|
||||
|
||||
Returns:
|
||||
Summary insight or None
|
||||
"""
|
||||
if not all_results:
|
||||
return None
|
||||
|
||||
# Calculate summary statistics
|
||||
total_products = len(all_results)
|
||||
high_demand_products = sum(1 for r in all_results if r.get('trend_analysis', {}).get('is_increasing', False))
|
||||
|
||||
        # Average only over products that actually report a peak ratio, so products
        # without seasonal data do not drag the mean down.
        peak_ratios = [
            r.get('seasonal_factors', {}).get('peak_ratio', 1.0)
            for r in all_results
            if r.get('seasonal_factors', {}).get('peak_ratio')
        ]
        avg_seasonal_factor = sum(peak_ratios) / max(1, len(peak_ratios))
|
||||
|
||||
return {
|
||||
'type': 'recommendation',
|
||||
'priority': 'medium' if high_demand_products > total_products * 0.5 else 'low',
|
||||
'category': 'forecasting',
|
||||
'title': f'Demand Pattern Summary: {total_products} Products Analyzed',
|
||||
'description': f'Detected {high_demand_products} products with increasing demand trends. Average seasonal peak ratio: {avg_seasonal_factor:.2f}x.',
|
||||
'impact_type': 'demand_optimization',
|
||||
'impact_value': high_demand_products,
|
||||
'impact_unit': 'products',
|
||||
'confidence': 75,
|
||||
'metrics_json': {
|
||||
'total_products': total_products,
|
||||
'high_demand_products': high_demand_products,
|
||||
'avg_seasonal_factor': round(avg_seasonal_factor, 2),
|
||||
'trend_strength': 'strong' if high_demand_products > total_products * 0.7 else 'moderate'
|
||||
},
|
||||
'actionable': True,
|
||||
'recommendation_actions': [
|
||||
{
|
||||
'label': 'Review Production Schedule',
|
||||
'action': 'review_production_schedule',
|
||||
'params': {'tenant_id': tenant_id}
|
||||
},
|
||||
{
|
||||
'label': 'Adjust Inventory Levels',
|
||||
'action': 'adjust_inventory_levels',
|
||||
'params': {'tenant_id': tenant_id}
|
||||
}
|
||||
],
|
||||
'source_service': 'forecasting',
|
||||
'source_model': 'demand_analyzer'
|
||||
}
|
||||
|
||||
async def get_demand_patterns(
|
||||
self,
|
||||
inventory_product_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get cached demand patterns for a product.
|
||||
|
||||
Args:
|
||||
inventory_product_id: Product identifier
|
||||
|
||||
Returns:
|
||||
Demand patterns or None if not analyzed
|
||||
"""
|
||||
return self.forecaster.get_cached_demand_patterns(inventory_product_id)
|
||||
|
||||
async def _publish_insight_events(self, tenant_id, insights, product_context=None):
|
||||
"""
|
||||
Publish insight events to RabbitMQ for alert processing.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
insights: List of created insights
|
||||
product_context: Additional context about the product
|
||||
"""
|
||||
if not self.event_publisher:
|
||||
logger.warning("No event publisher available for demand insights")
|
||||
return
|
||||
|
||||
for insight in insights:
|
||||
# Determine severity based on confidence and priority
|
||||
confidence = insight.get('confidence', 0)
|
||||
priority = insight.get('priority', 'medium')
|
||||
|
||||
# Map priority to severity, with confidence as tiebreaker
|
||||
if priority == 'critical' or (priority == 'high' and confidence >= 70):
|
||||
severity = 'high'
|
||||
elif priority == 'high' or (priority == 'medium' and confidence >= 80):
|
||||
severity = 'medium'
|
||||
else:
|
||||
severity = 'low'
|
||||
|
||||
# Prepare the event data
|
||||
event_data = {
|
||||
'insight_id': insight.get('id'),
|
||||
'type': insight.get('type'),
|
||||
'title': insight.get('title'),
|
||||
'description': insight.get('description'),
|
||||
'category': insight.get('category'),
|
||||
'priority': insight.get('priority'),
|
||||
'confidence': confidence,
|
||||
'recommendation': insight.get('recommendation_actions', []),
|
||||
'impact_type': insight.get('impact_type'),
|
||||
'impact_value': insight.get('impact_value'),
|
||||
'inventory_product_id': product_context.get('inventory_product_id') if product_context else None,
|
||||
'timestamp': insight.get('detected_at', datetime.utcnow().isoformat()),
|
||||
'source_service': 'forecasting',
|
||||
'source_model': 'demand_analyzer'
|
||||
}
|
||||
|
||||
try:
|
||||
await self.event_publisher.publish_recommendation(
|
||||
event_type='ai_demand_forecast',
|
||||
tenant_id=tenant_id,
|
||||
severity=severity,
|
||||
data=event_data
|
||||
)
|
||||
logger.info(
|
||||
"Published demand insight event",
|
||||
tenant_id=tenant_id,
|
||||
insight_id=insight.get('id'),
|
||||
severity=severity
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to publish demand insight event",
|
||||
tenant_id=tenant_id,
|
||||
insight_id=insight.get('id'),
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Close HTTP client connections."""
|
||||
await self.ai_insights_client.close()
|
||||
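A hedged sketch of driving the orchestrator above for a single product; the ids and the flat sales frame are invented for the example, and with no event publisher supplied the RabbitMQ step is skipped with a warning, exactly as _publish_insight_events documents.

# Illustrative usage; tenant/product ids and sales values are made up.
import asyncio

import pandas as pd

from app.ml.demand_insights_orchestrator import DemandInsightsOrchestrator


async def run_single_product() -> None:
    orchestrator = DemandInsightsOrchestrator()
    sales = pd.DataFrame({
        "date": pd.date_range("2024-10-01", periods=120, freq="D"),
        "quantity": 40,  # flat demand, just enough to satisfy the required columns
    })
    try:
        results = await orchestrator.analyze_and_post_demand_insights(
            tenant_id="00000000-0000-0000-0000-000000000001",  # must parse as a UUID
            inventory_product_id="croissant-classic",
            sales_data=sales,
        )
        print(results["insights_posted"], "insights posted")
    finally:
        await orchestrator.close()


if __name__ == "__main__":
    asyncio.run(run_single_product())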
758
services/forecasting/app/ml/dynamic_rules_engine.py
Normal file
@@ -0,0 +1,758 @@
|
||||
"""
|
||||
Dynamic Business Rules Engine
|
||||
Learns optimal adjustment factors from historical data instead of using hardcoded values
|
||||
Replaces hardcoded weather multipliers, holiday adjustments, event impacts with learned values
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
import structlog
|
||||
from datetime import datetime, timedelta
|
||||
from scipy import stats
|
||||
from sklearn.linear_model import Ridge
|
||||
from collections import defaultdict
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class DynamicRulesEngine:
|
||||
"""
|
||||
Learns business rules from historical data instead of using hardcoded values.
|
||||
|
||||
Current hardcoded values to replace:
|
||||
- Weather: rain = -15%, snow = -25%, extreme_heat = -10%
|
||||
- Holidays: +50% (all holidays treated the same)
|
||||
- Events: +30% (all events treated the same)
|
||||
- Weekend: Manual assumptions
|
||||
|
||||
Dynamic approach:
|
||||
- Learn actual weather impact per weather condition per product
|
||||
- Learn holiday multipliers per holiday type
|
||||
- Learn event impact by event type
|
||||
- Learn day-of-week patterns per product
|
||||
- Generate insights when learned values differ from hardcoded assumptions
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.weather_rules = {}
|
||||
self.holiday_rules = {}
|
||||
self.event_rules = {}
|
||||
self.dow_rules = {}
|
||||
self.month_rules = {}
|
||||
|
||||
async def learn_all_rules(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
external_data: Optional[pd.DataFrame] = None,
|
||||
min_samples: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Learn all business rules from historical data.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
sales_data: Historical sales data with 'date', 'quantity' columns
|
||||
external_data: Optional weather/events/holidays data
|
||||
min_samples: Minimum samples required to learn a rule
|
||||
|
||||
Returns:
|
||||
Dictionary of learned rules and insights
|
||||
"""
|
||||
logger.info(
|
||||
"Learning dynamic business rules from historical data",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
data_points=len(sales_data)
|
||||
)
|
||||
|
||||
results = {
|
||||
'tenant_id': tenant_id,
|
||||
'inventory_product_id': inventory_product_id,
|
||||
'learned_at': datetime.utcnow().isoformat(),
|
||||
'rules': {},
|
||||
'insights': []
|
||||
}
|
||||
|
||||
# Ensure date column is datetime
|
||||
if 'date' not in sales_data.columns:
|
||||
sales_data = sales_data.copy()
|
||||
sales_data['date'] = sales_data['ds']
|
||||
|
||||
sales_data['date'] = pd.to_datetime(sales_data['date'])
|
||||
|
||||
# Learn weather impact rules
|
||||
if external_data is not None and 'weather_condition' in external_data.columns:
|
||||
weather_rules, weather_insights = await self._learn_weather_rules(
|
||||
sales_data, external_data, min_samples
|
||||
)
|
||||
results['rules']['weather'] = weather_rules
|
||||
results['insights'].extend(weather_insights)
|
||||
self.weather_rules[inventory_product_id] = weather_rules
|
||||
|
||||
# Learn holiday rules
|
||||
if external_data is not None and 'is_holiday' in external_data.columns:
|
||||
holiday_rules, holiday_insights = await self._learn_holiday_rules(
|
||||
sales_data, external_data, min_samples
|
||||
)
|
||||
results['rules']['holidays'] = holiday_rules
|
||||
results['insights'].extend(holiday_insights)
|
||||
self.holiday_rules[inventory_product_id] = holiday_rules
|
||||
|
||||
# Learn event rules
|
||||
if external_data is not None and 'event_type' in external_data.columns:
|
||||
event_rules, event_insights = await self._learn_event_rules(
|
||||
sales_data, external_data, min_samples
|
||||
)
|
||||
results['rules']['events'] = event_rules
|
||||
results['insights'].extend(event_insights)
|
||||
self.event_rules[inventory_product_id] = event_rules
|
||||
|
||||
# Learn day-of-week patterns (always available)
|
||||
dow_rules, dow_insights = await self._learn_day_of_week_rules(
|
||||
sales_data, min_samples
|
||||
)
|
||||
results['rules']['day_of_week'] = dow_rules
|
||||
results['insights'].extend(dow_insights)
|
||||
self.dow_rules[inventory_product_id] = dow_rules
|
||||
|
||||
# Learn monthly seasonality
|
||||
month_rules, month_insights = await self._learn_month_rules(
|
||||
sales_data, min_samples
|
||||
)
|
||||
results['rules']['months'] = month_rules
|
||||
results['insights'].extend(month_insights)
|
||||
self.month_rules[inventory_product_id] = month_rules
|
||||
|
||||
logger.info(
|
||||
"Dynamic rules learning complete",
|
||||
total_insights=len(results['insights']),
|
||||
rules_learned=len(results['rules'])
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def _learn_weather_rules(
|
||||
self,
|
||||
sales_data: pd.DataFrame,
|
||||
external_data: pd.DataFrame,
|
||||
min_samples: int
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Learn actual weather impact from historical data.
|
||||
|
||||
Hardcoded assumptions:
|
||||
- rain: -15%
|
||||
- snow: -25%
|
||||
- extreme_heat: -10%
|
||||
|
||||
Learn actual impact for this product.
|
||||
"""
|
||||
logger.info("Learning weather impact rules")
|
||||
|
||||
# Merge sales with weather data
|
||||
merged = sales_data.merge(
|
||||
external_data[['date', 'weather_condition', 'temperature', 'precipitation']],
|
||||
on='date',
|
||||
how='left'
|
||||
)
|
||||
|
||||
# Baseline: average sales on clear days
|
||||
clear_days = merged[
|
||||
(merged['weather_condition'].isin(['clear', 'sunny', 'partly_cloudy'])) |
|
||||
(merged['weather_condition'].isna())
|
||||
]
|
||||
baseline_avg = clear_days['quantity'].mean()
|
||||
|
||||
weather_rules = {
|
||||
'baseline_avg': float(baseline_avg),
|
||||
'conditions': {}
|
||||
}
|
||||
|
||||
insights = []
|
||||
|
||||
# Hardcoded values for comparison
|
||||
hardcoded_impacts = {
|
||||
'rain': -0.15,
|
||||
'snow': -0.25,
|
||||
'extreme_heat': -0.10
|
||||
}
|
||||
|
||||
# Learn impact for each weather condition
|
||||
for condition in ['rain', 'rainy', 'snow', 'snowy', 'extreme_heat', 'hot', 'storm', 'fog']:
|
||||
condition_days = merged[merged['weather_condition'].str.contains(condition, case=False, na=False)]
|
||||
|
||||
if len(condition_days) >= min_samples:
|
||||
condition_avg = condition_days['quantity'].mean()
|
||||
learned_impact = (condition_avg - baseline_avg) / baseline_avg
|
||||
|
||||
# Statistical significance test
|
||||
t_stat, p_value = stats.ttest_ind(
|
||||
condition_days['quantity'].values,
|
||||
clear_days['quantity'].values,
|
||||
equal_var=False
|
||||
)
|
||||
|
||||
weather_rules['conditions'][condition] = {
|
||||
'learned_multiplier': float(1 + learned_impact),
|
||||
'learned_impact_pct': float(learned_impact * 100),
|
||||
'sample_size': int(len(condition_days)),
|
||||
'avg_quantity': float(condition_avg),
|
||||
'p_value': float(p_value),
|
||||
'significant': bool(p_value < 0.05)
|
||||
}
|
||||
|
||||
# Compare with hardcoded value if exists
|
||||
if condition in hardcoded_impacts and p_value < 0.05:
|
||||
hardcoded_impact = hardcoded_impacts[condition]
|
||||
difference = abs(learned_impact - hardcoded_impact)
|
||||
|
||||
if difference > 0.05: # More than 5% difference
|
||||
insight = {
|
||||
'type': 'optimization',
|
||||
'priority': 'high' if difference > 0.15 else 'medium',
|
||||
'category': 'forecasting',
|
||||
'title': f'Weather Rule Mismatch: {condition.title()}',
|
||||
'description': f'Learned {condition} impact is {learned_impact*100:.1f}% vs hardcoded {hardcoded_impact*100:.1f}%. Updating rule could improve forecast accuracy by {difference*100:.1f}%.',
|
||||
'impact_type': 'forecast_improvement',
|
||||
'impact_value': difference * 100,
|
||||
'impact_unit': 'percentage_points',
|
||||
'confidence': self._calculate_confidence(len(condition_days), p_value),
|
||||
'metrics_json': {
|
||||
'weather_condition': condition,
|
||||
'learned_impact_pct': round(learned_impact * 100, 2),
|
||||
'hardcoded_impact_pct': round(hardcoded_impact * 100, 2),
|
||||
'difference_pct': round(difference * 100, 2),
|
||||
'baseline_avg': round(baseline_avg, 2),
|
||||
'condition_avg': round(condition_avg, 2),
|
||||
'sample_size': len(condition_days),
|
||||
'p_value': round(p_value, 4)
|
||||
},
|
||||
'actionable': True,
|
||||
'recommendation_actions': [
|
||||
{
|
||||
'label': 'Update Weather Rule',
|
||||
'action': 'update_weather_multiplier',
|
||||
'params': {
|
||||
'condition': condition,
|
||||
'new_multiplier': round(1 + learned_impact, 3)
|
||||
}
|
||||
}
|
||||
],
|
||||
'source_service': 'forecasting',
|
||||
'source_model': 'dynamic_rules_engine'
|
||||
}
|
||||
insights.append(insight)
|
||||
|
||||
logger.info(
|
||||
"Weather rule discrepancy detected",
|
||||
condition=condition,
|
||||
learned=f"{learned_impact*100:.1f}%",
|
||||
hardcoded=f"{hardcoded_impact*100:.1f}%"
|
||||
)
|
||||
|
||||
return weather_rules, insights
|
||||
|
||||
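    # Worked example (illustrative numbers): if clear-day sales average 100 units and rainy
    # days average 72, the learned impact is (72 - 100) / 100 = -0.28, i.e. a 0.72x
    # multiplier. That is 13 percentage points away from the hardcoded -15% assumption, so
    # with p < 0.05 a rule-mismatch insight is raised recommending the ~0.72 multiplier.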
async def _learn_holiday_rules(
|
||||
self,
|
||||
sales_data: pd.DataFrame,
|
||||
external_data: pd.DataFrame,
|
||||
min_samples: int
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Learn holiday impact by holiday type instead of uniform +50%.
|
||||
|
||||
Hardcoded: All holidays = +50%
|
||||
Learn: Christmas vs Easter vs National holidays have different impacts
|
||||
"""
|
||||
logger.info("Learning holiday impact rules")
|
||||
|
||||
# Merge sales with holiday data
|
||||
merged = sales_data.merge(
|
||||
external_data[['date', 'is_holiday', 'holiday_name', 'holiday_type']],
|
||||
on='date',
|
||||
how='left'
|
||||
)
|
||||
|
||||
# Baseline: non-holiday average
|
||||
non_holidays = merged[merged['is_holiday'] == False]
|
||||
baseline_avg = non_holidays['quantity'].mean()
|
||||
|
||||
holiday_rules = {
|
||||
'baseline_avg': float(baseline_avg),
|
||||
'hardcoded_multiplier': 1.5, # Current +50%
|
||||
'holiday_types': {}
|
||||
}
|
||||
|
||||
insights = []
|
||||
|
||||
# Learn impact per holiday type
|
||||
if 'holiday_type' in merged.columns:
|
||||
for holiday_type in merged[merged['is_holiday'] == True]['holiday_type'].unique():
|
||||
if pd.isna(holiday_type):
|
||||
continue
|
||||
|
||||
holiday_days = merged[merged['holiday_type'] == holiday_type]
|
||||
|
||||
if len(holiday_days) >= min_samples:
|
||||
holiday_avg = holiday_days['quantity'].mean()
|
||||
learned_multiplier = holiday_avg / baseline_avg
|
||||
learned_impact = (learned_multiplier - 1) * 100
|
||||
|
||||
# Statistical test
|
||||
t_stat, p_value = stats.ttest_ind(
|
||||
holiday_days['quantity'].values,
|
||||
non_holidays['quantity'].values,
|
||||
equal_var=False
|
||||
)
|
||||
|
||||
holiday_rules['holiday_types'][holiday_type] = {
|
||||
'learned_multiplier': float(learned_multiplier),
|
||||
'learned_impact_pct': float(learned_impact),
|
||||
'sample_size': int(len(holiday_days)),
|
||||
'avg_quantity': float(holiday_avg),
|
||||
'p_value': float(p_value),
|
||||
'significant': bool(p_value < 0.05)
|
||||
}
|
||||
|
||||
# Compare with hardcoded +50%
|
||||
hardcoded_multiplier = 1.5
|
||||
difference = abs(learned_multiplier - hardcoded_multiplier)
|
||||
|
||||
if difference > 0.1 and p_value < 0.05: # More than 10% difference
|
||||
insight = {
|
||||
'type': 'recommendation',
|
||||
'priority': 'high' if difference > 0.3 else 'medium',
|
||||
'category': 'forecasting',
|
||||
'title': f'Holiday Rule Optimization: {holiday_type}',
|
||||
'description': f'{holiday_type} shows {learned_impact:.1f}% impact vs hardcoded +50%. Using learned multiplier {learned_multiplier:.2f}x could improve forecast accuracy.',
|
||||
'impact_type': 'forecast_improvement',
|
||||
'impact_value': difference * 100,
|
||||
'impact_unit': 'percentage_points',
|
||||
'confidence': self._calculate_confidence(len(holiday_days), p_value),
|
||||
'metrics_json': {
|
||||
'holiday_type': holiday_type,
|
||||
'learned_multiplier': round(learned_multiplier, 3),
|
||||
'hardcoded_multiplier': 1.5,
|
||||
'learned_impact_pct': round(learned_impact, 2),
|
||||
'hardcoded_impact_pct': 50.0,
|
||||
'baseline_avg': round(baseline_avg, 2),
|
||||
'holiday_avg': round(holiday_avg, 2),
|
||||
'sample_size': len(holiday_days),
|
||||
'p_value': round(p_value, 4)
|
||||
},
|
||||
'actionable': True,
|
||||
'recommendation_actions': [
|
||||
{
|
||||
'label': 'Update Holiday Rule',
|
||||
'action': 'update_holiday_multiplier',
|
||||
'params': {
|
||||
'holiday_type': holiday_type,
|
||||
'new_multiplier': round(learned_multiplier, 3)
|
||||
}
|
||||
}
|
||||
],
|
||||
'source_service': 'forecasting',
|
||||
'source_model': 'dynamic_rules_engine'
|
||||
}
|
||||
insights.append(insight)
|
||||
|
||||
logger.info(
|
||||
"Holiday rule optimization identified",
|
||||
holiday_type=holiday_type,
|
||||
learned=f"{learned_multiplier:.2f}x",
|
||||
hardcoded="1.5x"
|
||||
)
|
||||
|
||||
# Overall holiday impact
|
||||
all_holidays = merged[merged['is_holiday'] == True]
|
||||
if len(all_holidays) >= min_samples:
|
||||
overall_avg = all_holidays['quantity'].mean()
|
||||
overall_multiplier = overall_avg / baseline_avg
|
||||
|
||||
holiday_rules['overall_learned_multiplier'] = float(overall_multiplier)
|
||||
holiday_rules['overall_learned_impact_pct'] = float((overall_multiplier - 1) * 100)
|
||||
|
||||
return holiday_rules, insights
|
||||
|
||||
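    # Worked example (illustrative): with a non-holiday baseline of 100 units, a Christmas-type
    # holiday averaging 180 units learns a 1.80x multiplier and a minor national holiday at
    # 115 units learns 1.15x; both differ from the uniform hardcoded 1.5x by more than 0.1,
    # so each surfaces its own optimization insight when the t-test is significant.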
async def _learn_event_rules(
|
||||
self,
|
||||
sales_data: pd.DataFrame,
|
||||
external_data: pd.DataFrame,
|
||||
min_samples: int
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Learn event impact by event type instead of uniform +30%.
|
||||
|
||||
Hardcoded: All events = +30%
|
||||
Learn: Sports events vs concerts vs festivals have different impacts
|
||||
"""
|
||||
logger.info("Learning event impact rules")
|
||||
|
||||
# Merge sales with event data
|
||||
merged = sales_data.merge(
|
||||
external_data[['date', 'event_name', 'event_type', 'event_attendance']],
|
||||
on='date',
|
||||
how='left'
|
||||
)
|
||||
|
||||
# Baseline: non-event days
|
||||
non_events = merged[merged['event_name'].isna()]
|
||||
baseline_avg = non_events['quantity'].mean()
|
||||
|
||||
event_rules = {
|
||||
'baseline_avg': float(baseline_avg),
|
||||
'hardcoded_multiplier': 1.3, # Current +30%
|
||||
'event_types': {}
|
||||
}
|
||||
|
||||
insights = []
|
||||
|
||||
# Learn impact per event type
|
||||
if 'event_type' in merged.columns:
|
||||
for event_type in merged[merged['event_type'].notna()]['event_type'].unique():
|
||||
if pd.isna(event_type):
|
||||
continue
|
||||
|
||||
event_days = merged[merged['event_type'] == event_type]
|
||||
|
||||
if len(event_days) >= min_samples:
|
||||
event_avg = event_days['quantity'].mean()
|
||||
learned_multiplier = event_avg / baseline_avg
|
||||
learned_impact = (learned_multiplier - 1) * 100
|
||||
|
||||
# Statistical test
|
||||
t_stat, p_value = stats.ttest_ind(
|
||||
event_days['quantity'].values,
|
||||
non_events['quantity'].values,
|
||||
equal_var=False
|
||||
)
|
||||
|
||||
event_rules['event_types'][event_type] = {
|
||||
'learned_multiplier': float(learned_multiplier),
|
||||
'learned_impact_pct': float(learned_impact),
|
||||
'sample_size': int(len(event_days)),
|
||||
'avg_quantity': float(event_avg),
|
||||
'p_value': float(p_value),
|
||||
'significant': bool(p_value < 0.05)
|
||||
}
|
||||
|
||||
# Compare with hardcoded +30%
|
||||
hardcoded_multiplier = 1.3
|
||||
difference = abs(learned_multiplier - hardcoded_multiplier)
|
||||
|
||||
if difference > 0.1 and p_value < 0.05:
|
||||
insight = {
|
||||
'type': 'recommendation',
|
||||
'priority': 'medium',
|
||||
'category': 'forecasting',
|
||||
'title': f'Event Rule Optimization: {event_type}',
|
||||
'description': f'{event_type} events show {learned_impact:.1f}% impact vs hardcoded +30%. Using learned multiplier could improve event forecasts.',
|
||||
'impact_type': 'forecast_improvement',
|
||||
'impact_value': difference * 100,
|
||||
'impact_unit': 'percentage_points',
|
||||
'confidence': self._calculate_confidence(len(event_days), p_value),
|
||||
'metrics_json': {
|
||||
'event_type': event_type,
|
||||
'learned_multiplier': round(learned_multiplier, 3),
|
||||
'hardcoded_multiplier': 1.3,
|
||||
'learned_impact_pct': round(learned_impact, 2),
|
||||
'hardcoded_impact_pct': 30.0,
|
||||
'baseline_avg': round(baseline_avg, 2),
|
||||
'event_avg': round(event_avg, 2),
|
||||
'sample_size': len(event_days),
|
||||
'p_value': round(p_value, 4)
|
||||
},
|
||||
'actionable': True,
|
||||
'recommendation_actions': [
|
||||
{
|
||||
'label': 'Update Event Rule',
|
||||
'action': 'update_event_multiplier',
|
||||
'params': {
|
||||
'event_type': event_type,
|
||||
'new_multiplier': round(learned_multiplier, 3)
|
||||
}
|
||||
}
|
||||
],
|
||||
'source_service': 'forecasting',
|
||||
'source_model': 'dynamic_rules_engine'
|
||||
}
|
||||
insights.append(insight)
|
||||
|
||||
return event_rules, insights
|
||||
|
||||
async def _learn_day_of_week_rules(
|
||||
self,
|
||||
sales_data: pd.DataFrame,
|
||||
min_samples: int
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Learn day-of-week patterns per product.
|
||||
Replace general assumptions with product-specific patterns.
|
||||
"""
|
||||
logger.info("Learning day-of-week patterns")
|
||||
|
||||
sales_data = sales_data.copy()
|
||||
sales_data['day_of_week'] = sales_data['date'].dt.dayofweek
|
||||
sales_data['day_name'] = sales_data['date'].dt.day_name()
|
||||
|
||||
# Calculate average per day of week
|
||||
dow_avg = sales_data.groupby('day_of_week')['quantity'].agg(['mean', 'std', 'count'])
|
||||
|
||||
overall_avg = sales_data['quantity'].mean()
|
||||
|
||||
dow_rules = {
|
||||
'overall_avg': float(overall_avg),
|
||||
'days': {}
|
||||
}
|
||||
|
||||
insights = []
|
||||
|
||||
day_names = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
|
||||
|
||||
for dow in range(7):
|
||||
if dow not in dow_avg.index or dow_avg.loc[dow, 'count'] < min_samples:
|
||||
continue
|
||||
|
||||
day_avg = dow_avg.loc[dow, 'mean']
|
||||
day_std = dow_avg.loc[dow, 'std']
|
||||
day_count = dow_avg.loc[dow, 'count']
|
||||
|
||||
multiplier = day_avg / overall_avg
|
||||
impact_pct = (multiplier - 1) * 100
|
||||
|
||||
# Coefficient of variation
|
||||
cv = (day_std / day_avg) if day_avg > 0 else 0
|
||||
|
||||
dow_rules['days'][day_names[dow]] = {
|
||||
'day_of_week': int(dow),
|
||||
'learned_multiplier': float(multiplier),
|
||||
'impact_pct': float(impact_pct),
|
||||
'avg_quantity': float(day_avg),
|
||||
'std_quantity': float(day_std),
|
||||
'sample_size': int(day_count),
|
||||
'coefficient_of_variation': float(cv)
|
||||
}
|
||||
|
||||
# Insight for significant deviations
|
||||
if abs(impact_pct) > 20: # More than 20% difference
|
||||
insight = {
|
||||
'type': 'insight',
|
||||
'priority': 'medium' if abs(impact_pct) > 30 else 'low',
|
||||
'category': 'forecasting',
|
||||
'title': f'{day_names[dow]} Pattern: {abs(impact_pct):.0f}% {"Higher" if impact_pct > 0 else "Lower"}',
|
||||
'description': f'{day_names[dow]} sales average {day_avg:.1f} units ({impact_pct:+.1f}% vs weekly average {overall_avg:.1f}). Consider this pattern in production planning.',
|
||||
'impact_type': 'operational_insight',
|
||||
'impact_value': abs(impact_pct),
|
||||
'impact_unit': 'percentage',
|
||||
'confidence': self._calculate_confidence(day_count, 0.01), # Low p-value for large samples
|
||||
'metrics_json': {
|
||||
'day_of_week': day_names[dow],
|
||||
'day_multiplier': round(multiplier, 3),
|
||||
'impact_pct': round(impact_pct, 2),
|
||||
'day_avg': round(day_avg, 2),
|
||||
'overall_avg': round(overall_avg, 2),
|
||||
'sample_size': int(day_count),
|
||||
'std': round(day_std, 2)
|
||||
},
|
||||
'actionable': True,
|
||||
'recommendation_actions': [
|
||||
{
|
||||
'label': 'Adjust Production Schedule',
|
||||
'action': 'adjust_weekly_production',
|
||||
'params': {
|
||||
'day': day_names[dow],
|
||||
'multiplier': round(multiplier, 3)
|
||||
}
|
||||
}
|
||||
],
|
||||
'source_service': 'forecasting',
|
||||
'source_model': 'dynamic_rules_engine'
|
||||
}
|
||||
insights.append(insight)
|
||||
|
||||
return dow_rules, insights
|
||||
|
||||
async def _learn_month_rules(
|
||||
self,
|
||||
sales_data: pd.DataFrame,
|
||||
min_samples: int
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Learn monthly seasonality patterns per product.
|
||||
"""
|
||||
logger.info("Learning monthly seasonality patterns")
|
||||
|
||||
sales_data = sales_data.copy()
|
||||
sales_data['month'] = sales_data['date'].dt.month
|
||||
sales_data['month_name'] = sales_data['date'].dt.month_name()
|
||||
|
||||
# Calculate average per month
|
||||
month_avg = sales_data.groupby('month')['quantity'].agg(['mean', 'std', 'count'])
|
||||
|
||||
overall_avg = sales_data['quantity'].mean()
|
||||
|
||||
month_rules = {
|
||||
'overall_avg': float(overall_avg),
|
||||
'months': {}
|
||||
}
|
||||
|
||||
insights = []
|
||||
|
||||
month_names = ['January', 'February', 'March', 'April', 'May', 'June',
|
||||
'July', 'August', 'September', 'October', 'November', 'December']
|
||||
|
||||
for month in range(1, 13):
|
||||
if month not in month_avg.index or month_avg.loc[month, 'count'] < min_samples:
|
||||
continue
|
||||
|
||||
month_mean = month_avg.loc[month, 'mean']
|
||||
month_std = month_avg.loc[month, 'std']
|
||||
month_count = month_avg.loc[month, 'count']
|
||||
|
||||
multiplier = month_mean / overall_avg
|
||||
impact_pct = (multiplier - 1) * 100
|
||||
|
||||
month_rules['months'][month_names[month - 1]] = {
|
||||
'month': int(month),
|
||||
'learned_multiplier': float(multiplier),
|
||||
'impact_pct': float(impact_pct),
|
||||
'avg_quantity': float(month_mean),
|
||||
'std_quantity': float(month_std),
|
||||
'sample_size': int(month_count)
|
||||
}
|
||||
|
||||
# Insight for significant seasonal patterns
|
||||
if abs(impact_pct) > 25: # More than 25% seasonal variation
|
||||
insight = {
|
||||
'type': 'insight',
|
||||
'priority': 'medium',
|
||||
'category': 'forecasting',
|
||||
'title': f'Seasonal Pattern: {month_names[month - 1]} {abs(impact_pct):.0f}% {"Higher" if impact_pct > 0 else "Lower"}',
|
||||
'description': f'{month_names[month - 1]} shows strong seasonality with {impact_pct:+.1f}% vs annual average. Plan inventory accordingly.',
|
||||
'impact_type': 'operational_insight',
|
||||
'impact_value': abs(impact_pct),
|
||||
'impact_unit': 'percentage',
|
||||
'confidence': self._calculate_confidence(month_count, 0.01),
|
||||
'metrics_json': {
|
||||
'month': month_names[month - 1],
|
||||
'multiplier': round(multiplier, 3),
|
||||
'impact_pct': round(impact_pct, 2),
|
||||
'month_avg': round(month_mean, 2),
|
||||
'annual_avg': round(overall_avg, 2),
|
||||
'sample_size': int(month_count)
|
||||
},
|
||||
'actionable': True,
|
||||
'recommendation_actions': [
|
||||
{
|
||||
'label': 'Adjust Seasonal Planning',
|
||||
'action': 'adjust_seasonal_forecast',
|
||||
'params': {
|
||||
'month': month_names[month - 1],
|
||||
'multiplier': round(multiplier, 3)
|
||||
}
|
||||
}
|
||||
],
|
||||
'source_service': 'forecasting',
|
||||
'source_model': 'dynamic_rules_engine'
|
||||
}
|
||||
insights.append(insight)
|
||||
|
||||
return month_rules, insights
|
||||
|
||||
def _calculate_confidence(self, sample_size: int, p_value: float) -> int:
|
||||
"""
|
||||
Calculate confidence score (0-100) based on sample size and statistical significance.
|
||||
|
||||
Args:
|
||||
sample_size: Number of observations
|
||||
p_value: Statistical significance p-value
|
||||
|
||||
Returns:
|
||||
Confidence score 0-100
|
||||
"""
|
||||
# Sample size score (0-50 points)
|
||||
if sample_size >= 100:
|
||||
sample_score = 50
|
||||
elif sample_size >= 50:
|
||||
sample_score = 40
|
||||
elif sample_size >= 30:
|
||||
sample_score = 30
|
||||
elif sample_size >= 20:
|
||||
sample_score = 20
|
||||
else:
|
||||
sample_score = 10
|
||||
|
||||
# Statistical significance score (0-50 points)
|
||||
if p_value < 0.001:
|
||||
sig_score = 50
|
||||
elif p_value < 0.01:
|
||||
sig_score = 45
|
||||
elif p_value < 0.05:
|
||||
sig_score = 35
|
||||
elif p_value < 0.1:
|
||||
sig_score = 20
|
||||
else:
|
||||
sig_score = 10
|
||||
|
||||
return min(100, sample_score + sig_score)
|
||||
|
||||
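    # Worked example: 45 samples score 30 (>= 30 bracket) and p = 0.02 scores 35 (< 0.05),
    # giving a confidence of 65; 120 samples with p < 0.001 score 50 + 50 and hit the cap of 100.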
def get_rule(
|
||||
self,
|
||||
inventory_product_id: str,
|
||||
rule_type: str,
|
||||
key: str
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Get learned rule multiplier for a specific condition.
|
||||
|
||||
Args:
|
||||
inventory_product_id: Product identifier
|
||||
rule_type: 'weather', 'holiday', 'event', 'day_of_week', 'month'
|
||||
key: Specific condition key (e.g., 'rain', 'Christmas', 'Monday')
|
||||
|
||||
Returns:
|
||||
Learned multiplier or None if not learned
|
||||
"""
|
||||
if rule_type == 'weather':
|
||||
rules = self.weather_rules.get(inventory_product_id, {})
|
||||
return rules.get('conditions', {}).get(key, {}).get('learned_multiplier')
|
||||
|
||||
elif rule_type == 'holiday':
|
||||
rules = self.holiday_rules.get(inventory_product_id, {})
|
||||
return rules.get('holiday_types', {}).get(key, {}).get('learned_multiplier')
|
||||
|
||||
elif rule_type == 'event':
|
||||
rules = self.event_rules.get(inventory_product_id, {})
|
||||
return rules.get('event_types', {}).get(key, {}).get('learned_multiplier')
|
||||
|
||||
elif rule_type == 'day_of_week':
|
||||
rules = self.dow_rules.get(inventory_product_id, {})
|
||||
return rules.get('days', {}).get(key, {}).get('learned_multiplier')
|
||||
|
||||
elif rule_type == 'month':
|
||||
rules = self.month_rules.get(inventory_product_id, {})
|
||||
return rules.get('months', {}).get(key, {}).get('learned_multiplier')
|
||||
|
||||
return None
|
||||
|
||||
def export_rules_for_prophet(
|
||||
self,
|
||||
inventory_product_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Export learned rules in format suitable for Prophet model integration.
|
||||
|
||||
Returns:
|
||||
Dictionary with multipliers for Prophet custom seasonality/regressors
|
||||
"""
|
||||
return {
|
||||
'weather': self.weather_rules.get(inventory_product_id, {}),
|
||||
'holidays': self.holiday_rules.get(inventory_product_id, {}),
|
||||
'events': self.event_rules.get(inventory_product_id, {}),
|
||||
'day_of_week': self.dow_rules.get(inventory_product_id, {}),
|
||||
'months': self.month_rules.get(inventory_product_id, {})
|
||||
}
|
||||
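A minimal sketch of exercising the engine above end to end; the dataframes, ids and synthetic values are assumptions for the example, and the external frame only needs whichever columns the individual learners look for (here just the weather columns).

# Illustrative usage; ids, columns and synthetic data are invented for the example.
import asyncio

import numpy as np
import pandas as pd

from app.ml.dynamic_rules_engine import DynamicRulesEngine


async def learn_rules_demo() -> None:
    engine = DynamicRulesEngine()
    dates = pd.date_range("2024-01-01", periods=180, freq="D")
    rng = np.random.default_rng(42)
    sales = pd.DataFrame({"date": dates, "quantity": rng.poisson(50, len(dates))})
    external = pd.DataFrame({
        "date": dates,
        "weather_condition": rng.choice(["clear", "rain", "snow"], len(dates)),
        "temperature": rng.normal(12, 6, len(dates)),
        "precipitation": rng.random(len(dates)),
    })
    results = await engine.learn_all_rules(
        tenant_id="tenant-123",
        inventory_product_id="baguette",
        sales_data=sales,
        external_data=external,
    )
    print(len(results["insights"]), "insights,", len(results["rules"]), "rule groups learned")
    # Learned multipliers can then be looked up directly or exported for Prophet regressors
    print("rain multiplier:", engine.get_rule("baguette", "weather", "rain"))
    print(engine.export_rules_for_prophet("baguette")["day_of_week"]["days"].keys())


if __name__ == "__main__":
    asyncio.run(learn_rules_demo())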
263
services/forecasting/app/ml/multi_horizon_forecaster.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""
|
||||
Multi-Horizon Forecasting System
|
||||
Generates forecasts for multiple time horizons (7, 14, 30, 90 days)
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta, date
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class MultiHorizonForecaster:
|
||||
"""
|
||||
Multi-horizon forecasting with horizon-specific models.
|
||||
|
||||
Horizons:
|
||||
- Short-term (1-7 days): High precision, detailed features
|
||||
- Medium-term (8-14 days): Balanced approach
|
||||
- Long-term (15-30 days): Focus on trends, seasonal patterns
|
||||
- Very long-term (31-90 days): Strategic planning, major trends only
|
||||
"""
|
||||
|
||||
HORIZONS = {
|
||||
'short': (1, 7),
|
||||
'medium': (8, 14),
|
||||
'long': (15, 30),
|
||||
'very_long': (31, 90)
|
||||
}
|
||||
|
||||
def __init__(self, base_forecaster=None):
|
||||
"""
|
||||
Initialize multi-horizon forecaster.
|
||||
|
||||
Args:
|
||||
base_forecaster: Base forecaster (e.g., BakeryForecaster) to use
|
||||
"""
|
||||
self.base_forecaster = base_forecaster
|
||||
|
||||
async def generate_multi_horizon_forecast(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
start_date: date,
|
||||
horizons: List[str] = None,
|
||||
include_confidence_intervals: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate forecasts for multiple horizons.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
start_date: Start date for forecasts
|
||||
horizons: List of horizons to forecast ('short', 'medium', 'long', 'very_long')
|
||||
include_confidence_intervals: Include confidence intervals
|
||||
|
||||
Returns:
|
||||
Dictionary with forecasts by horizon
|
||||
"""
|
||||
if horizons is None:
|
||||
horizons = ['short', 'medium', 'long']
|
||||
|
||||
logger.info(
|
||||
"Generating multi-horizon forecast",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
horizons=horizons
|
||||
)
|
||||
|
||||
results = {
|
||||
'tenant_id': tenant_id,
|
||||
'inventory_product_id': inventory_product_id,
|
||||
'start_date': start_date.isoformat(),
|
||||
'generated_at': datetime.now().isoformat(),
|
||||
'horizons': {}
|
||||
}
|
||||
|
||||
for horizon_name in horizons:
|
||||
if horizon_name not in self.HORIZONS:
|
||||
logger.warning(f"Unknown horizon: {horizon_name}, skipping")
|
||||
continue
|
||||
|
||||
start_day, end_day = self.HORIZONS[horizon_name]
|
||||
|
||||
# Generate forecast for this horizon
|
||||
horizon_forecast = await self._generate_horizon_forecast(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
start_date=start_date,
|
||||
days_ahead=end_day,
|
||||
horizon_name=horizon_name,
|
||||
include_confidence=include_confidence_intervals
|
||||
)
|
||||
|
||||
results['horizons'][horizon_name] = horizon_forecast
|
||||
|
||||
logger.info("Multi-horizon forecast complete",
|
||||
horizons_generated=len(results['horizons']))
|
||||
|
||||
return results
|
||||
|
||||
async def _generate_horizon_forecast(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
start_date: date,
|
||||
days_ahead: int,
|
||||
horizon_name: str,
|
||||
include_confidence: bool
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate forecast for a specific horizon.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
start_date: Start date
|
||||
days_ahead: Number of days ahead
|
||||
horizon_name: Horizon name ('short', 'medium', etc.)
|
||||
include_confidence: Include confidence intervals
|
||||
|
||||
Returns:
|
||||
Forecast data for the horizon
|
||||
"""
|
||||
# Generate date range
|
||||
dates = [start_date + timedelta(days=i) for i in range(days_ahead)]
|
||||
|
||||
# Use base forecaster if available
|
||||
if self.base_forecaster:
|
||||
# Call base forecaster for predictions
|
||||
forecasts = []
|
||||
|
||||
for forecast_date in dates:
|
||||
try:
|
||||
# This would call the actual forecasting service
|
||||
# For now, we'll return a structured response
|
||||
forecasts.append({
|
||||
'date': forecast_date.isoformat(),
|
||||
'predicted_demand': 0, # Placeholder
|
||||
'confidence_lower': 0 if include_confidence else None,
|
||||
'confidence_upper': 0 if include_confidence else None
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate forecast for {forecast_date}: {e}")
|
||||
|
||||
return {
|
||||
'horizon_name': horizon_name,
|
||||
'days_ahead': days_ahead,
|
||||
'start_date': start_date.isoformat(),
|
||||
'end_date': dates[-1].isoformat(),
|
||||
'forecasts': forecasts,
|
||||
'aggregates': self._calculate_horizon_aggregates(forecasts)
|
||||
}
|
||||
else:
|
||||
logger.warning("No base forecaster available, returning placeholder")
|
||||
return {
|
||||
'horizon_name': horizon_name,
|
||||
'days_ahead': days_ahead,
|
||||
'forecasts': [],
|
||||
'aggregates': {}
|
||||
}
|
||||
|
||||
def _calculate_horizon_aggregates(self, forecasts: List[Dict]) -> Dict[str, float]:
|
||||
"""
|
||||
Calculate aggregate statistics for a horizon.
|
||||
|
||||
Args:
|
||||
forecasts: List of daily forecasts
|
||||
|
||||
Returns:
|
||||
Aggregate statistics
|
||||
"""
|
||||
if not forecasts:
|
||||
return {}
|
||||
|
||||
demands = [f['predicted_demand'] for f in forecasts if f.get('predicted_demand')]
|
||||
|
||||
if not demands:
|
||||
return {}
|
||||
|
||||
return {
|
||||
'total_demand': sum(demands),
|
||||
'avg_daily_demand': np.mean(demands),
|
||||
'max_daily_demand': max(demands),
|
||||
'min_daily_demand': min(demands),
|
||||
'demand_volatility': np.std(demands) if len(demands) > 1 else 0
|
||||
}
|
||||
|
||||
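    # Example: daily forecasts of 40, 55 and 25 units aggregate to total_demand 120,
    # avg_daily_demand 40.0, max 55, min 25 and a demand_volatility (std dev) of roughly 12.2.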
def get_horizon_recommendation(
|
||||
self,
|
||||
horizon_name: str,
|
||||
forecast_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate recommendations based on horizon forecast.
|
||||
|
||||
Args:
|
||||
horizon_name: Horizon name
|
||||
forecast_data: Forecast data for the horizon
|
||||
|
||||
Returns:
|
||||
Recommendations dictionary
|
||||
"""
|
||||
aggregates = forecast_data.get('aggregates', {})
|
||||
total_demand = aggregates.get('total_demand', 0)
|
||||
volatility = aggregates.get('demand_volatility', 0)
|
||||
|
||||
recommendations = {
|
||||
'horizon': horizon_name,
|
||||
'actions': []
|
||||
}
|
||||
|
||||
if horizon_name == 'short':
|
||||
# Short-term: Operational recommendations
|
||||
if total_demand > 0:
|
||||
recommendations['actions'].append(f"Prepare {total_demand:.0f} units for next 7 days")
|
||||
if volatility > 10:
|
||||
recommendations['actions'].append("High volatility expected - increase safety stock")
|
||||
|
||||
elif horizon_name == 'medium':
|
||||
# Medium-term: Procurement planning
|
||||
recommendations['actions'].append(f"Order supplies for {total_demand:.0f} units (2-week demand)")
|
||||
if aggregates.get('max_daily_demand', 0) > aggregates.get('avg_daily_demand', 0) * 1.5:
|
||||
recommendations['actions'].append("Peak demand day detected - plan extra capacity")
|
||||
|
||||
elif horizon_name == 'long':
|
||||
# Long-term: Strategic planning
|
||||
avg_weekly_demand = total_demand / 4 if total_demand > 0 else 0
|
||||
recommendations['actions'].append(f"Monthly demand projection: {total_demand:.0f} units")
|
||||
recommendations['actions'].append(f"Average weekly demand: {avg_weekly_demand:.0f} units")
|
||||
|
||||
elif horizon_name == 'very_long':
|
||||
# Very long-term: Capacity planning
|
||||
recommendations['actions'].append(f"Quarterly demand projection: {total_demand:.0f} units")
|
||||
recommendations['actions'].append("Review capacity and staffing needs")
|
||||
|
||||
return recommendations
|
||||
|
||||
|
||||
def get_appropriate_horizons_for_use_case(use_case: str) -> List[str]:
|
||||
"""
|
||||
Get appropriate forecast horizons for a use case.
|
||||
|
||||
Args:
|
||||
use_case: Use case name (e.g., 'production_planning', 'procurement', 'strategic')
|
||||
|
||||
Returns:
|
||||
List of horizon names
|
||||
"""
|
||||
use_case_horizons = {
|
||||
'production_planning': ['short'],
|
||||
'procurement': ['short', 'medium'],
|
||||
'inventory_optimization': ['short', 'medium'],
|
||||
'capacity_planning': ['medium', 'long'],
|
||||
'strategic_planning': ['long', 'very_long'],
|
||||
'financial_planning': ['long', 'very_long'],
|
||||
'all': ['short', 'medium', 'long', 'very_long']
|
||||
}
|
||||
|
||||
return use_case_horizons.get(use_case, ['short', 'medium'])
|
||||
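A short sketch tying the use-case helper and the forecaster together; the tenant and product ids are invented, and since no base forecaster is attached here each horizon comes back with the empty placeholder structure described above.

# Illustrative wiring only; tenant and product ids are assumptions.
import asyncio
from datetime import date

from app.ml.multi_horizon_forecaster import (
    MultiHorizonForecaster,
    get_appropriate_horizons_for_use_case,
)


async def demo() -> None:
    horizons = get_appropriate_horizons_for_use_case("procurement")  # ['short', 'medium']
    forecaster = MultiHorizonForecaster()  # base_forecaster=None -> placeholder horizons
    result = await forecaster.generate_multi_horizon_forecast(
        tenant_id="tenant-123",
        inventory_product_id="sourdough",
        start_date=date.today(),
        horizons=horizons,
    )
    for name, data in result["horizons"].items():
        print(name, data["days_ahead"], len(data["forecasts"]), "daily forecasts")


if __name__ == "__main__":
    asyncio.run(demo())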
593
services/forecasting/app/ml/pattern_detector.py
Normal file
@@ -0,0 +1,593 @@
|
||||
"""
|
||||
Pattern Detection Engine for Sales Data
|
||||
Automatically identifies patterns and generates insights
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
from scipy import stats
|
||||
from collections import defaultdict
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SalesPatternDetector:
|
||||
"""
|
||||
Detect sales patterns and generate actionable insights.
|
||||
|
||||
Patterns detected:
|
||||
- Time-of-day patterns (hourly peaks)
|
||||
- Day-of-week patterns (weekend spikes)
|
||||
- Weekly seasonality patterns
|
||||
- Monthly patterns
|
||||
- Holiday impact patterns
|
||||
- Weather correlation patterns
|
||||
"""
|
||||
|
||||
def __init__(self, significance_threshold: float = 0.15):
|
||||
"""
|
||||
Initialize pattern detector.
|
||||
|
||||
Args:
|
||||
significance_threshold: Minimum percentage difference to consider significant (default 15%)
|
||||
"""
|
||||
self.significance_threshold = significance_threshold
|
||||
self.detected_patterns = []
|
||||
|
||||
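    # Typical call (illustrative): given a dataframe with 'date' and 'quantity' columns,
    #   insights = await SalesPatternDetector().detect_all_patterns(
    #       tenant_id, inventory_product_id, sales_df, min_confidence=70)
    # returns a list of insight dicts shaped for the AI Insights Service; optional 'hour'
    # and 'temperature' columns additionally unlock the hourly and weather detectors.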
async def detect_all_patterns(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
min_confidence: int = 70
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Detect all patterns in sales data and generate insights.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
sales_data: Sales data with columns: date, quantity, (optional: hour, temperature, etc.)
|
||||
min_confidence: Minimum confidence score for insights
|
||||
|
||||
Returns:
|
||||
List of insight dictionaries ready for AI Insights Service
|
||||
"""
|
||||
logger.info(
|
||||
"Starting pattern detection",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id,
|
||||
data_points=len(sales_data)
|
||||
)
|
||||
|
||||
insights = []
|
||||
|
||||
# Ensure date column is datetime
|
||||
if 'date' in sales_data.columns:
|
||||
sales_data['date'] = pd.to_datetime(sales_data['date'])
|
||||
|
||||
# 1. Day-of-week patterns
|
||||
dow_insights = await self._detect_day_of_week_patterns(
|
||||
tenant_id, inventory_product_id, sales_data, min_confidence
|
||||
)
|
||||
insights.extend(dow_insights)
|
||||
|
||||
# 2. Weekend vs weekday patterns
|
||||
weekend_insights = await self._detect_weekend_patterns(
|
||||
tenant_id, inventory_product_id, sales_data, min_confidence
|
||||
)
|
||||
insights.extend(weekend_insights)
|
||||
|
||||
# 3. Month-end patterns
|
||||
month_end_insights = await self._detect_month_end_patterns(
|
||||
tenant_id, inventory_product_id, sales_data, min_confidence
|
||||
)
|
||||
insights.extend(month_end_insights)
|
||||
|
||||
# 4. Hourly patterns (if hour data available)
|
||||
if 'hour' in sales_data.columns:
|
||||
hourly_insights = await self._detect_hourly_patterns(
|
||||
tenant_id, inventory_product_id, sales_data, min_confidence
|
||||
)
|
||||
insights.extend(hourly_insights)
|
||||
|
||||
# 5. Weather correlation (if temperature data available)
|
||||
if 'temperature' in sales_data.columns:
|
||||
weather_insights = await self._detect_weather_correlations(
|
||||
tenant_id, inventory_product_id, sales_data, min_confidence
|
||||
)
|
||||
insights.extend(weather_insights)
|
||||
|
||||
# 6. Trend detection
|
||||
trend_insights = await self._detect_trends(
|
||||
tenant_id, inventory_product_id, sales_data, min_confidence
|
||||
)
|
||||
insights.extend(trend_insights)
|
||||
|
||||
logger.info(
|
||||
"Pattern detection complete",
|
||||
total_insights=len(insights),
|
||||
product_id=inventory_product_id
|
||||
)
|
||||
|
||||
return insights
|
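    # Illustrative call pattern (editor sketch; the surrounding service wiring is an
    # assumption). The detector only needs a DataFrame with 'date' and 'quantity'
    # columns; 'hour' and 'temperature' unlock the optional detectors above.
    #
    #     detector = SalesPatternDetector(significance_threshold=0.15)
    #     insights = await detector.detect_all_patterns(
    #         tenant_id=tenant_id,
    #         inventory_product_id=product_id,
    #         sales_data=daily_sales_df,
    #         min_confidence=70,
    #     )
    #     # each insight dict is ready to be forwarded to the AI Insights Service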
||||
|
||||
async def _detect_day_of_week_patterns(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
min_confidence: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Detect day-of-week patterns (e.g., Friday sales spike)."""
|
||||
insights = []
|
||||
|
||||
if 'date' not in sales_data.columns or 'quantity' not in sales_data.columns:
|
||||
return insights
|
||||
|
||||
# Add day of week
|
||||
sales_data['day_of_week'] = sales_data['date'].dt.dayofweek
|
||||
sales_data['day_name'] = sales_data['date'].dt.day_name()
|
||||
|
||||
# Calculate average sales per day of week
|
||||
dow_avg = sales_data.groupby(['day_of_week', 'day_name'])['quantity'].agg(['mean', 'count']).reset_index()
|
||||
|
||||
# Only consider days with sufficient data (at least 4 observations)
|
||||
dow_avg = dow_avg[dow_avg['count'] >= 4]
|
||||
|
||||
if len(dow_avg) < 2:
|
||||
return insights
|
||||
|
||||
overall_avg = sales_data['quantity'].mean()
|
||||
|
||||
# Find days significantly above average
|
||||
for _, row in dow_avg.iterrows():
|
||||
day_avg = row['mean']
|
||||
pct_diff = ((day_avg - overall_avg) / overall_avg) * 100
|
||||
|
||||
if abs(pct_diff) > self.significance_threshold * 100:
|
||||
# Calculate confidence based on sample size and consistency
|
||||
confidence = self._calculate_pattern_confidence(
|
||||
sample_size=int(row['count']),
|
||||
effect_size=abs(pct_diff) / 100,
|
||||
variability=sales_data['quantity'].std()
|
||||
)
|
||||
|
||||
if confidence >= min_confidence:
|
||||
if pct_diff > 0:
|
||||
insight = self._create_insight(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
insight_type='pattern',
|
||||
category='sales',
|
||||
priority='medium' if pct_diff > 20 else 'low',
|
||||
title=f'{row["day_name"]} Sales Pattern Detected',
|
||||
description=f'Sales on {row["day_name"]} are {abs(pct_diff):.1f}% {"higher" if pct_diff > 0 else "lower"} than average ({day_avg:.1f} vs {overall_avg:.1f} units).',
|
||||
confidence=confidence,
|
||||
metrics={
|
||||
'day_of_week': row['day_name'],
|
||||
'avg_sales': float(day_avg),
|
||||
'overall_avg': float(overall_avg),
|
||||
'difference_pct': float(pct_diff),
|
||||
'sample_size': int(row['count'])
|
||||
},
|
||||
actionable=True,
|
||||
actions=[
|
||||
{'label': 'Adjust Production', 'action': 'adjust_daily_production'},
|
||||
{'label': 'Review Schedule', 'action': 'review_production_schedule'}
|
||||
]
|
||||
)
|
||||
insights.append(insight)
|
||||
|
||||
return insights
|
||||
|
||||
async def _detect_weekend_patterns(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
min_confidence: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Detect weekend vs weekday patterns."""
|
||||
insights = []
|
||||
|
||||
if 'date' not in sales_data.columns or 'quantity' not in sales_data.columns:
|
||||
return insights
|
||||
|
||||
# Classify weekend vs weekday
|
||||
sales_data['is_weekend'] = sales_data['date'].dt.dayofweek.isin([5, 6])
|
||||
|
||||
# Calculate averages
|
||||
weekend_avg = sales_data[sales_data['is_weekend']]['quantity'].mean()
|
||||
weekday_avg = sales_data[~sales_data['is_weekend']]['quantity'].mean()
|
||||
|
||||
weekend_count = sales_data[sales_data['is_weekend']]['quantity'].count()
|
||||
weekday_count = sales_data[~sales_data['is_weekend']]['quantity'].count()
|
||||
|
||||
if weekend_count < 4 or weekday_count < 4:
|
||||
return insights
|
||||
|
||||
pct_diff = ((weekend_avg - weekday_avg) / weekday_avg) * 100
|
||||
|
||||
if abs(pct_diff) > self.significance_threshold * 100:
|
||||
confidence = self._calculate_pattern_confidence(
|
||||
sample_size=min(weekend_count, weekday_count),
|
||||
effect_size=abs(pct_diff) / 100,
|
||||
variability=sales_data['quantity'].std()
|
||||
)
|
||||
|
||||
if confidence >= min_confidence:
|
||||
# Estimate revenue impact
|
||||
                impact_value = abs(weekend_avg - weekday_avg) * 8  # ~2 weekend days/week * 4 weeks = 8 weekend days/month
|
||||
|
||||
insight = self._create_insight(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
insight_type='recommendation',
|
||||
category='forecasting',
|
||||
priority='high' if abs(pct_diff) > 25 else 'medium',
|
||||
title=f'Weekend Demand Pattern: {abs(pct_diff):.0f}% {"Higher" if pct_diff > 0 else "Lower"}',
|
||||
description=f'Weekend sales average {weekend_avg:.1f} units vs {weekday_avg:.1f} on weekdays ({abs(pct_diff):.0f}% {"increase" if pct_diff > 0 else "decrease"}). Recommend adjusting weekend production targets.',
|
||||
confidence=confidence,
|
||||
impact_type='revenue_increase' if pct_diff > 0 else 'cost_savings',
|
||||
impact_value=float(impact_value),
|
||||
impact_unit='units/month',
|
||||
metrics={
|
||||
'weekend_avg': float(weekend_avg),
|
||||
'weekday_avg': float(weekday_avg),
|
||||
'difference_pct': float(pct_diff),
|
||||
'weekend_samples': int(weekend_count),
|
||||
'weekday_samples': int(weekday_count)
|
||||
},
|
||||
actionable=True,
|
||||
actions=[
|
||||
{'label': 'Increase Weekend Production', 'action': 'adjust_weekend_production'},
|
||||
{'label': 'Update Forecast Multiplier', 'action': 'update_forecast_rule'}
|
||||
]
|
||||
)
|
||||
insights.append(insight)
|
||||
|
||||
return insights
|
||||
|
||||
async def _detect_month_end_patterns(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
min_confidence: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Detect month-end and payday patterns."""
|
||||
insights = []
|
||||
|
||||
if 'date' not in sales_data.columns or 'quantity' not in sales_data.columns:
|
||||
return insights
|
||||
|
||||
# Identify payday periods (15th and last 3 days of month)
|
||||
sales_data['day_of_month'] = sales_data['date'].dt.day
|
||||
sales_data['is_payday'] = (
|
||||
(sales_data['day_of_month'] == 15) |
|
||||
(sales_data['date'].dt.is_month_end) |
|
||||
(sales_data['day_of_month'] >= sales_data['date'].dt.days_in_month - 2)
|
||||
)
|
||||
|
||||
payday_avg = sales_data[sales_data['is_payday']]['quantity'].mean()
|
||||
regular_avg = sales_data[~sales_data['is_payday']]['quantity'].mean()
|
||||
|
||||
payday_count = sales_data[sales_data['is_payday']]['quantity'].count()
|
||||
|
||||
if payday_count < 4:
|
||||
return insights
|
||||
|
||||
pct_diff = ((payday_avg - regular_avg) / regular_avg) * 100
|
||||
|
||||
if abs(pct_diff) > self.significance_threshold * 100:
|
||||
confidence = self._calculate_pattern_confidence(
|
||||
sample_size=payday_count,
|
||||
effect_size=abs(pct_diff) / 100,
|
||||
variability=sales_data['quantity'].std()
|
||||
)
|
||||
|
||||
if confidence >= min_confidence and pct_diff > 0:
|
||||
insight = self._create_insight(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
insight_type='pattern',
|
||||
category='sales',
|
||||
priority='medium',
|
||||
                    title='Payday Shopping Pattern Detected',
|
||||
description=f'Sales increase {pct_diff:.0f}% during payday periods (15th and month-end). Average {payday_avg:.1f} vs {regular_avg:.1f} units.',
|
||||
confidence=confidence,
|
||||
metrics={
|
||||
'payday_avg': float(payday_avg),
|
||||
'regular_avg': float(regular_avg),
|
||||
'difference_pct': float(pct_diff)
|
||||
},
|
||||
actionable=True,
|
||||
actions=[
|
||||
{'label': 'Increase Payday Stock', 'action': 'adjust_payday_production'}
|
||||
]
|
||||
)
|
||||
insights.append(insight)
|
||||
|
||||
return insights
|
||||
|
||||
async def _detect_hourly_patterns(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
min_confidence: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Detect hourly sales patterns (if POS data available)."""
|
||||
insights = []
|
||||
|
||||
if 'hour' not in sales_data.columns or 'quantity' not in sales_data.columns:
|
||||
return insights
|
||||
|
||||
hourly_avg = sales_data.groupby('hour')['quantity'].agg(['mean', 'count']).reset_index()
|
||||
hourly_avg = hourly_avg[hourly_avg['count'] >= 3] # At least 3 observations
|
||||
|
||||
if len(hourly_avg) < 3:
|
||||
return insights
|
||||
|
||||
overall_avg = sales_data['quantity'].mean()
|
||||
|
||||
# Find peak hours (top 3)
|
||||
top_hours = hourly_avg.nlargest(3, 'mean')
|
||||
|
||||
for _, row in top_hours.iterrows():
|
||||
hour_avg = row['mean']
|
||||
pct_diff = ((hour_avg - overall_avg) / overall_avg) * 100
|
||||
|
||||
if pct_diff > self.significance_threshold * 100:
|
||||
confidence = self._calculate_pattern_confidence(
|
||||
sample_size=int(row['count']),
|
||||
effect_size=pct_diff / 100,
|
||||
variability=sales_data['quantity'].std()
|
||||
)
|
||||
|
||||
if confidence >= min_confidence:
|
||||
hour = int(row['hour'])
|
||||
time_label = f"{hour:02d}:00-{(hour+1):02d}:00"
|
||||
|
||||
insight = self._create_insight(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
insight_type='pattern',
|
||||
category='sales',
|
||||
priority='low',
|
||||
title=f'Peak Sales Hour: {time_label}',
|
||||
description=f'Sales peak during {time_label} with {hour_avg:.1f} units ({pct_diff:.0f}% above average).',
|
||||
confidence=confidence,
|
||||
metrics={
|
||||
'peak_hour': hour,
|
||||
'avg_sales': float(hour_avg),
|
||||
'overall_avg': float(overall_avg),
|
||||
'difference_pct': float(pct_diff)
|
||||
},
|
||||
actionable=True,
|
||||
actions=[
|
||||
{'label': 'Ensure Fresh Stock', 'action': 'schedule_production'},
|
||||
{'label': 'Increase Staffing', 'action': 'adjust_staffing'}
|
||||
]
|
||||
)
|
||||
insights.append(insight)
|
||||
|
||||
return insights
|
||||
|
||||
async def _detect_weather_correlations(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
min_confidence: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Detect weather-sales correlations."""
|
||||
insights = []
|
||||
|
||||
if 'temperature' not in sales_data.columns or 'quantity' not in sales_data.columns:
|
||||
return insights
|
||||
|
||||
# Remove NaN values
|
||||
clean_data = sales_data[['temperature', 'quantity']].dropna()
|
||||
|
||||
if len(clean_data) < 30: # Need sufficient data
|
||||
return insights
|
||||
|
||||
# Calculate correlation
|
||||
correlation, p_value = stats.pearsonr(clean_data['temperature'], clean_data['quantity'])
|
||||
|
||||
if abs(correlation) > 0.3 and p_value < 0.05: # Moderate correlation and significant
|
||||
confidence = self._calculate_correlation_confidence(correlation, p_value, len(clean_data))
|
||||
|
||||
if confidence >= min_confidence:
|
||||
direction = 'increase' if correlation > 0 else 'decrease'
|
||||
|
||||
insight = self._create_insight(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
insight_type='insight',
|
||||
category='forecasting',
|
||||
priority='medium' if abs(correlation) > 0.5 else 'low',
|
||||
title=f'Temperature Impact on Sales: {abs(correlation):.0%} Correlation',
|
||||
description=f'Sales {direction} with temperature (correlation: {correlation:.2f}). {"Warmer" if correlation > 0 else "Colder"} weather associated with {"higher" if correlation > 0 else "lower"} sales.',
|
||||
confidence=confidence,
|
||||
metrics={
|
||||
'correlation': float(correlation),
|
||||
'p_value': float(p_value),
|
||||
'sample_size': len(clean_data),
|
||||
'direction': direction
|
||||
},
|
||||
actionable=False
|
||||
)
|
||||
insights.append(insight)
|
||||
|
||||
return insights
|
||||
|
||||
async def _detect_trends(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
min_confidence: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Detect overall trends (growing, declining, stable)."""
|
||||
insights = []
|
||||
|
||||
if 'date' not in sales_data.columns or 'quantity' not in sales_data.columns or len(sales_data) < 60:
|
||||
return insights
|
||||
|
||||
# Sort by date
|
||||
sales_data = sales_data.sort_values('date')
|
||||
|
||||
# Calculate 30-day rolling average
|
||||
sales_data['rolling_30d'] = sales_data['quantity'].rolling(window=30, min_periods=15).mean()
|
||||
|
||||
# Compare first and last 30-day averages
|
||||
first_30_avg = sales_data['rolling_30d'].iloc[:30].mean()
|
||||
last_30_avg = sales_data['rolling_30d'].iloc[-30:].mean()
|
||||
|
||||
if pd.isna(first_30_avg) or pd.isna(last_30_avg):
|
||||
return insights
|
||||
|
||||
pct_change = ((last_30_avg - first_30_avg) / first_30_avg) * 100
|
||||
|
||||
if abs(pct_change) > 10: # 10% change is significant
|
||||
confidence = min(95, 70 + int(abs(pct_change))) # Higher change = higher confidence
|
||||
|
||||
trend_type = 'growing' if pct_change > 0 else 'declining'
|
||||
|
||||
insight = self._create_insight(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
insight_type='prediction',
|
||||
category='forecasting',
|
||||
priority='high' if abs(pct_change) > 20 else 'medium',
|
||||
title=f'Sales Trend: {trend_type.title()} {abs(pct_change):.0f}%',
|
||||
description=f'Sales show a {trend_type} trend over the period. Current 30-day average: {last_30_avg:.1f} vs earlier: {first_30_avg:.1f} ({pct_change:+.0f}%).',
|
||||
confidence=confidence,
|
||||
metrics={
|
||||
'current_avg': float(last_30_avg),
|
||||
'previous_avg': float(first_30_avg),
|
||||
'change_pct': float(pct_change),
|
||||
'trend': trend_type
|
||||
},
|
||||
actionable=True,
|
||||
actions=[
|
||||
{'label': 'Adjust Forecast Model', 'action': 'update_forecast'},
|
||||
{'label': 'Review Capacity', 'action': 'review_production_capacity'}
|
||||
]
|
||||
)
|
||||
insights.append(insight)
|
||||
|
||||
return insights
|
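    # Worked example (editor sketch): if the earliest 30-day rolling average is 80
    # units and the latest is 100, pct_change = +25%, so the trend is reported as
    # 'growing' with priority 'high' (>20%) and confidence min(95, 70 + 25) = 95.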
||||
|
||||
def _calculate_pattern_confidence(
|
||||
self,
|
||||
sample_size: int,
|
||||
effect_size: float,
|
||||
variability: float
|
||||
) -> int:
|
||||
"""
|
||||
Calculate confidence score for detected pattern.
|
||||
|
||||
Args:
|
||||
sample_size: Number of observations
|
||||
effect_size: Size of the effect (e.g., 0.25 for 25% difference)
|
||||
variability: Standard deviation of data
|
||||
|
||||
Returns:
|
||||
Confidence score (0-100)
|
||||
"""
|
||||
# Base confidence from sample size
|
||||
if sample_size < 4:
|
||||
base = 50
|
||||
elif sample_size < 10:
|
||||
base = 65
|
||||
elif sample_size < 30:
|
||||
base = 75
|
||||
elif sample_size < 100:
|
||||
base = 85
|
||||
else:
|
||||
base = 90
|
||||
|
||||
# Adjust for effect size
|
||||
effect_boost = min(15, effect_size * 30)
|
||||
|
||||
# Adjust for variability (penalize high variability)
|
||||
variability_penalty = min(10, variability / 10)
|
||||
|
||||
confidence = base + effect_boost - variability_penalty
|
||||
|
||||
return int(max(0, min(100, confidence)))
|
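    # Worked example (editor sketch): sample_size=12, effect_size=0.25 and a standard
    # deviation of 20 units give base 75 + min(15, 0.25*30)=7.5 - min(10, 20/10)=2.0,
    # i.e. int(80.5) = 80, comfortably above the default min_confidence of 70.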
||||
|
||||
def _calculate_correlation_confidence(
|
||||
self,
|
||||
correlation: float,
|
||||
p_value: float,
|
||||
sample_size: int
|
||||
) -> int:
|
||||
"""Calculate confidence for correlation insights."""
|
||||
# Base confidence from correlation strength
|
||||
base = abs(correlation) * 100
|
||||
|
||||
# Boost for significance
|
||||
if p_value < 0.001:
|
||||
significance_boost = 15
|
||||
elif p_value < 0.01:
|
||||
significance_boost = 10
|
||||
elif p_value < 0.05:
|
||||
significance_boost = 5
|
||||
else:
|
||||
significance_boost = 0
|
||||
|
||||
# Boost for sample size
|
||||
if sample_size > 100:
|
||||
sample_boost = 10
|
||||
elif sample_size > 50:
|
||||
sample_boost = 5
|
||||
else:
|
||||
sample_boost = 0
|
||||
|
||||
confidence = base + significance_boost + sample_boost
|
||||
|
||||
return int(max(0, min(100, confidence)))
|
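    # Worked example (editor sketch): correlation=0.60, p_value=0.0005 and 120 samples
    # give base 60 + 15 (p < 0.001) + 10 (n > 100) = 85, so the weather insight passes
    # the default min_confidence filter of 70.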
||||
|
||||
def _create_insight(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
insight_type: str,
|
||||
category: str,
|
||||
priority: str,
|
||||
title: str,
|
||||
description: str,
|
||||
confidence: int,
|
||||
metrics: Dict[str, Any],
|
||||
actionable: bool,
|
||||
actions: List[Dict[str, str]] = None,
|
||||
impact_type: str = None,
|
||||
impact_value: float = None,
|
||||
impact_unit: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create an insight dictionary for AI Insights Service."""
|
||||
return {
|
||||
'tenant_id': tenant_id,
|
||||
'type': insight_type,
|
||||
'priority': priority,
|
||||
'category': category,
|
||||
'title': title,
|
||||
'description': description,
|
||||
'impact_type': impact_type,
|
||||
'impact_value': impact_value,
|
||||
'impact_unit': impact_unit,
|
||||
'confidence': confidence,
|
||||
'metrics_json': metrics,
|
||||
'actionable': actionable,
|
||||
'recommendation_actions': actions or [],
|
||||
'source_service': 'forecasting',
|
||||
'source_data_id': f'pattern_detection_{inventory_product_id}_{datetime.utcnow().strftime("%Y%m%d")}'
|
||||
}
|
||||
854
services/forecasting/app/ml/predictor.py
Normal file
@@ -0,0 +1,854 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/ml/predictor.py
|
||||
# ================================================================
|
||||
"""
|
||||
Enhanced predictor module with advanced forecasting capabilities
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, date, timedelta
|
||||
import pickle
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
from shared.database.base import create_database_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
metrics = MetricsCollector("forecasting-service")
|
||||
|
||||
class BakeryPredictor:
|
||||
"""
|
||||
Advanced predictor for bakery demand forecasting with dependency injection
|
||||
Handles Prophet models and business-specific logic
|
||||
"""
|
||||
|
||||
def __init__(self, database_manager=None, use_dynamic_rules=True):
|
||||
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "forecasting-service")
|
||||
self.model_cache = {}
|
||||
self.use_dynamic_rules = use_dynamic_rules
|
||||
|
||||
if use_dynamic_rules:
|
||||
try:
|
||||
from app.ml.dynamic_rules_engine import DynamicRulesEngine
|
||||
from shared.clients.ai_insights_client import AIInsightsClient
|
||||
self.rules_engine = DynamicRulesEngine()
|
||||
self.ai_insights_client = AIInsightsClient(
|
||||
base_url=settings.AI_INSIGHTS_SERVICE_URL or "http://ai-insights-service:8000"
|
||||
)
|
||||
# Also provide business_rules for consistency
|
||||
self.business_rules = BakeryBusinessRules(
|
||||
use_dynamic_rules=True,
|
||||
ai_insights_client=self.ai_insights_client
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import dynamic rules engine: {e}. Falling back to basic business rules.")
|
||||
self.use_dynamic_rules = False
|
||||
self.business_rules = BakeryBusinessRules()
|
||||
else:
|
||||
self.business_rules = BakeryBusinessRules()
|
||||
|
||||
class BakeryForecaster:
|
||||
"""
|
||||
Enhanced forecaster that integrates with repository pattern
|
||||
Uses enhanced features from training service for predictions
|
||||
"""
|
||||
|
||||
def __init__(self, database_manager=None, use_enhanced_features=True):
|
||||
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "forecasting-service")
|
||||
self.predictor = BakeryPredictor(database_manager)
|
||||
self.use_enhanced_features = use_enhanced_features
|
||||
|
||||
        # Initialize business rules; predict_demand relies on this attribute being set
|
||||
        self.business_rules = BakeryBusinessRules(use_dynamic_rules=True, ai_insights_client=getattr(self.predictor, 'ai_insights_client', None))
|
||||
|
||||
# Initialize POI feature service
|
||||
from app.services.poi_feature_service import POIFeatureService
|
||||
self.poi_feature_service = POIFeatureService()
|
||||
|
||||
# Initialize enhanced data processor from shared module
|
||||
if use_enhanced_features:
|
||||
try:
|
||||
from shared.ml.data_processor import EnhancedBakeryDataProcessor
|
||||
self.data_processor = EnhancedBakeryDataProcessor(region='MD')
|
||||
logger.info("Enhanced features enabled using shared data processor")
|
||||
except ImportError as e:
|
||||
logger.warning(
|
||||
f"Could not import EnhancedBakeryDataProcessor from shared module: {e}. "
|
||||
"Falling back to basic features."
|
||||
)
|
||||
self.use_enhanced_features = False
|
||||
self.data_processor = None
|
||||
else:
|
||||
self.data_processor = None
|
||||
|
||||
|
||||
async def predict_demand(self, model, features: Dict[str, Any],
|
||||
business_type: str = "individual") -> Dict[str, float]:
|
||||
"""Generate demand prediction with business rules applied"""
|
||||
|
||||
try:
|
||||
# Generate base prediction
|
||||
base_prediction = await self._generate_base_prediction(model, features)
|
||||
|
||||
# Apply business rules
|
||||
            adjusted_prediction = await self.business_rules.apply_rules(
|
||||
base_prediction, features, business_type
|
||||
)
|
||||
|
||||
# Add uncertainty estimation
|
||||
final_prediction = self._add_uncertainty_bands(adjusted_prediction, features)
|
||||
|
||||
return final_prediction
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in demand prediction", error=str(e))
|
||||
raise
|
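    # Illustrative call (editor sketch; the Prophet model object is assumed to be
    # loaded elsewhere, e.g. from the model registry):
    #
    #     result = await forecaster.predict_demand(
    #         model=prophet_model,
    #         features={'date': '2024-06-01', 'temperature': 22.0,
    #                   'is_weekend': True, 'precipitation': 0.0},
    #         business_type='individual',
    #     )
    #     # result contains 'demand', 'lower_bound', 'upper_bound' and 'uncertainty_factor'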
||||
|
||||
async def _generate_base_prediction(self, model, features: Dict[str, Any]) -> Dict[str, float]:
|
||||
"""Generate base prediction from Prophet model"""
|
||||
|
||||
try:
|
||||
# Convert features to Prophet DataFrame
|
||||
            df = await self._prepare_prophet_dataframe(features)
|
||||
|
||||
# Generate forecast
|
||||
forecast = model.predict(df)
|
||||
|
||||
if len(forecast) > 0:
|
||||
row = forecast.iloc[0]
|
||||
return {
|
||||
"yhat": float(row['yhat']),
|
||||
"yhat_lower": float(row['yhat_lower']),
|
||||
"yhat_upper": float(row['yhat_upper']),
|
||||
"trend": float(row.get('trend', 0)),
|
||||
"seasonal": float(row.get('seasonal', 0)),
|
||||
"weekly": float(row.get('weekly', 0)),
|
||||
"yearly": float(row.get('yearly', 0)),
|
||||
"holidays": float(row.get('holidays', 0))
|
||||
}
|
||||
else:
|
||||
raise ValueError("No prediction generated from model")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error generating base prediction", error=str(e))
|
||||
raise
|
||||
|
||||
async def _prepare_prophet_dataframe(self, features: Dict[str, Any],
|
||||
historical_data: pd.DataFrame = None) -> pd.DataFrame:
|
||||
"""
|
||||
Convert features to Prophet-compatible DataFrame.
|
||||
Uses enhanced features when available (60+ features vs basic 10).
|
||||
"""
|
||||
|
||||
try:
|
||||
if self.use_enhanced_features and self.data_processor:
|
||||
# Use enhanced data processor from training service
|
||||
logger.info("Generating enhanced features for prediction")
|
||||
|
||||
# Create future date range
|
||||
future_dates = pd.DatetimeIndex([pd.to_datetime(features['date'])])
|
||||
|
||||
# Prepare weather forecast DataFrame
|
||||
weather_df = pd.DataFrame({
|
||||
'date': [pd.to_datetime(features['date'])],
|
||||
'temperature': [features.get('temperature', 15.0)],
|
||||
'precipitation': [features.get('precipitation', 0.0)],
|
||||
'humidity': [features.get('humidity', 60.0)],
|
||||
'wind_speed': [features.get('wind_speed', 5.0)],
|
||||
'pressure': [features.get('pressure', 1013.0)]
|
||||
})
|
||||
|
||||
# Fetch POI features if tenant_id is available
|
||||
poi_features = None
|
||||
if 'tenant_id' in features:
|
||||
poi_features = await self.poi_feature_service.get_poi_features(
|
||||
features['tenant_id']
|
||||
)
|
||||
if poi_features:
|
||||
logger.info(
|
||||
f"Retrieved {len(poi_features)} POI features for prediction",
|
||||
tenant_id=features['tenant_id']
|
||||
)
|
||||
|
||||
# Use data processor to create ALL enhanced features
|
||||
df = await self.data_processor.prepare_prediction_features(
|
||||
future_dates=future_dates,
|
||||
weather_forecast=weather_df,
|
||||
traffic_forecast=None, # Will add when traffic forecasting is implemented
|
||||
poi_features=poi_features, # POI features for location-based forecasting
|
||||
historical_data=historical_data # For lagged features
|
||||
)
|
||||
|
||||
logger.info(f"Generated {len(df.columns)} enhanced features for prediction")
|
||||
return df
|
||||
|
||||
else:
|
||||
# Fallback to basic features
|
||||
logger.info("Using basic features for prediction")
|
||||
|
||||
# Create base DataFrame
|
||||
df = pd.DataFrame({
|
||||
'ds': [pd.to_datetime(features['date'])]
|
||||
})
|
||||
|
||||
# Add regressor features
|
||||
feature_mapping = {
|
||||
'temperature': 'temperature',
|
||||
'precipitation': 'precipitation',
|
||||
'humidity': 'humidity',
|
||||
'wind_speed': 'wind_speed',
|
||||
'traffic_volume': 'traffic_volume',
|
||||
'pedestrian_count': 'pedestrian_count'
|
||||
}
|
||||
|
||||
for feature_key, df_column in feature_mapping.items():
|
||||
if feature_key in features and features[feature_key] is not None:
|
||||
df[df_column] = float(features[feature_key])
|
||||
else:
|
||||
df[df_column] = 0.0
|
||||
|
||||
# Add categorical features
|
||||
df['day_of_week'] = int(features.get('day_of_week', 0))
|
||||
df['is_weekend'] = int(features.get('is_weekend', False))
|
||||
df['is_holiday'] = int(features.get('is_holiday', False))
|
||||
|
||||
# Business type
|
||||
business_type = features.get('business_type', 'individual')
|
||||
df['is_central_workshop'] = int(business_type == 'central_workshop')
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing Prophet dataframe: {e}, falling back to basic features")
|
||||
# Fallback to basic implementation on error
|
||||
df = pd.DataFrame({'ds': [pd.to_datetime(features['date'])]})
|
||||
df['temperature'] = features.get('temperature', 15.0)
|
||||
df['precipitation'] = features.get('precipitation', 0.0)
|
||||
df['is_weekend'] = int(features.get('is_weekend', False))
|
||||
df['is_holiday'] = int(features.get('is_holiday', False))
|
||||
return df
|
||||
|
||||
def _add_uncertainty_bands(self, prediction: Dict[str, float],
|
||||
features: Dict[str, Any]) -> Dict[str, float]:
|
||||
"""Add uncertainty estimation based on external factors"""
|
||||
|
||||
try:
|
||||
base_demand = prediction["yhat"]
|
||||
base_lower = prediction["yhat_lower"]
|
||||
base_upper = prediction["yhat_upper"]
|
||||
|
||||
# Weather uncertainty
|
||||
weather_uncertainty = self._calculate_weather_uncertainty(features)
|
||||
|
||||
# Holiday uncertainty
|
||||
holiday_uncertainty = self._calculate_holiday_uncertainty(features)
|
||||
|
||||
# Weekend uncertainty
|
||||
weekend_uncertainty = self._calculate_weekend_uncertainty(features)
|
||||
|
||||
# Total uncertainty factor
|
||||
total_uncertainty = 1.0 + weather_uncertainty + holiday_uncertainty + weekend_uncertainty
|
||||
|
||||
# Adjust bounds
|
||||
uncertainty_range = (base_upper - base_lower) * total_uncertainty
|
||||
center_point = base_demand
|
||||
|
||||
adjusted_lower = center_point - (uncertainty_range / 2)
|
||||
adjusted_upper = center_point + (uncertainty_range / 2)
|
||||
|
||||
return {
|
||||
"demand": max(0, base_demand), # Never predict negative demand
|
||||
"lower_bound": max(0, adjusted_lower),
|
||||
"upper_bound": adjusted_upper,
|
||||
"uncertainty_factor": total_uncertainty,
|
||||
"trend": prediction.get("trend", 0),
|
||||
"seasonal": prediction.get("seasonal", 0),
|
||||
"holiday_effect": prediction.get("holidays", 0)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error adding uncertainty bands", error=str(e))
|
||||
# Return basic prediction if uncertainty calculation fails
|
||||
return {
|
||||
"demand": max(0, prediction["yhat"]),
|
||||
"lower_bound": max(0, prediction["yhat_lower"]),
|
||||
"upper_bound": prediction["yhat_upper"],
|
||||
"uncertainty_factor": 1.0
|
||||
}
|
||||
|
||||
def _calculate_weather_uncertainty(self, features: Dict[str, Any]) -> float:
|
||||
"""Calculate weather-based uncertainty"""
|
||||
|
||||
uncertainty = 0.0
|
||||
|
||||
# Temperature extremes add uncertainty
|
||||
temp = features.get('temperature')
|
||||
if temp is not None:
|
||||
if temp < settings.TEMPERATURE_THRESHOLD_COLD or temp > settings.TEMPERATURE_THRESHOLD_HOT:
|
||||
uncertainty += 0.1
|
||||
|
||||
# Rain adds uncertainty
|
||||
precipitation = features.get('precipitation')
|
||||
if precipitation is not None and precipitation > 0:
|
||||
            uncertainty += 0.05 * min(precipitation, 10)  # cap contribution at 10 mm of rain (max +0.5)
|
||||
|
||||
return uncertainty
|
||||
|
||||
def _calculate_holiday_uncertainty(self, features: Dict[str, Any]) -> float:
|
||||
"""Calculate holiday-based uncertainty"""
|
||||
|
||||
if features.get('is_holiday', False):
|
||||
return 0.2 # 20% additional uncertainty on holidays
|
||||
return 0.0
|
||||
|
||||
def _calculate_weekend_uncertainty(self, features: Dict[str, Any]) -> float:
|
||||
"""Calculate weekend-based uncertainty"""
|
||||
|
||||
if features.get('is_weekend', False):
|
||||
return 0.1 # 10% additional uncertainty on weekends
|
||||
return 0.0
|
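    # Worked example (editor sketch) for _add_uncertainty_bands: with yhat=100,
    # yhat_lower=90, yhat_upper=110, 4 mm of rain (+0.20), a holiday (+0.20) and a
    # weekday (+0.0), total_uncertainty = 1.4, so the 20-unit band widens to 28 and
    # the returned bounds become [86, 114] around a demand of 100.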
||||
|
||||
async def analyze_demand_patterns(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
forecast_horizon_days: int = 30,
|
||||
min_history_days: int = 90
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze demand patterns by delegating to the sales service.
|
||||
|
||||
NOTE: Sales data analysis is the responsibility of the sales service.
|
||||
This method calls the sales service API to get demand pattern analysis.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
sales_data: Historical sales DataFrame (not used - kept for backward compatibility)
|
||||
forecast_horizon_days: Days to forecast ahead (not used currently)
|
||||
min_history_days: Minimum history required
|
||||
|
||||
Returns:
|
||||
Analysis results with patterns, trends, and insights from sales service
|
||||
"""
|
||||
try:
|
||||
from shared.clients.sales_client import SalesServiceClient
|
||||
from datetime import date, timedelta
|
||||
|
||||
logger.info(
|
||||
"Requesting demand pattern analysis from sales service",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id
|
||||
)
|
||||
|
||||
# Initialize sales client
|
||||
sales_client = SalesServiceClient(config=settings, calling_service_name="forecasting")
|
||||
|
||||
# Calculate date range
|
||||
end_date = date.today()
|
||||
start_date = end_date - timedelta(days=min_history_days)
|
||||
|
||||
# Call sales service for pattern analysis
|
||||
patterns = await sales_client.get_product_demand_patterns(
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
min_history_days=min_history_days
|
||||
)
|
||||
|
||||
# Generate insights from patterns
|
||||
insights = self._generate_insights_from_patterns(
|
||||
patterns,
|
||||
tenant_id,
|
||||
inventory_product_id
|
||||
)
|
||||
|
||||
# Add insights to the result
|
||||
patterns['insights'] = insights
|
||||
|
||||
logger.info(
|
||||
"Demand pattern analysis received from sales service",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
insights_generated=len(insights)
|
||||
)
|
||||
|
||||
return patterns
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error getting demand patterns from sales service",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
'analyzed_at': datetime.utcnow().isoformat(),
|
||||
'history_days': 0,
|
||||
'insights': [],
|
||||
'patterns': {},
|
||||
'trend_analysis': {},
|
||||
'seasonal_factors': {},
|
||||
'statistics': {},
|
||||
'error': str(e)
|
||||
}
|
||||
|
||||
def _generate_insights_from_patterns(
|
||||
self,
|
||||
patterns: Dict[str, Any],
|
||||
tenant_id: str,
|
||||
inventory_product_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generate actionable insights from demand patterns provided by sales service.
|
||||
|
||||
Args:
|
||||
patterns: Demand patterns from sales service
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
|
||||
Returns:
|
||||
List of insights for AI Insights Service
|
||||
"""
|
||||
insights = []
|
||||
|
||||
# Check if there was an error in pattern analysis
|
||||
if 'error' in patterns:
|
||||
return insights
|
||||
|
||||
trend = patterns.get('trend_analysis', {})
|
||||
stats = patterns.get('statistics', {})
|
||||
seasonal = patterns.get('seasonal_factors', {})
|
||||
|
||||
# Trend insight
|
||||
if trend.get('is_increasing'):
|
||||
insights.append({
|
||||
'type': 'insight',
|
||||
'priority': 'medium',
|
||||
'category': 'forecasting',
|
||||
'title': 'Increasing Demand Trend Detected',
|
||||
'description': f"Product shows {trend.get('direction', 'increasing')} demand trend. Consider increasing inventory levels.",
|
||||
'impact_type': 'demand_increase',
|
||||
'impact_value': abs(trend.get('correlation', 0) * 100),
|
||||
'impact_unit': 'percent',
|
||||
'confidence': min(int(abs(trend.get('correlation', 0)) * 100), 95),
|
||||
'metrics_json': trend,
|
||||
'actionable': True,
|
||||
'recommendation_actions': [
|
||||
{
|
||||
'label': 'Increase Safety Stock',
|
||||
'action': 'increase_safety_stock',
|
||||
'params': {'product_id': inventory_product_id, 'factor': 1.2}
|
||||
}
|
||||
]
|
||||
})
|
||||
elif trend.get('is_decreasing'):
|
||||
insights.append({
|
||||
'type': 'insight',
|
||||
'priority': 'low',
|
||||
'category': 'forecasting',
|
||||
'title': 'Decreasing Demand Trend Detected',
|
||||
'description': f"Product shows {trend.get('direction', 'decreasing')} demand trend. Consider reviewing inventory strategy.",
|
||||
'impact_type': 'demand_decrease',
|
||||
'impact_value': abs(trend.get('correlation', 0) * 100),
|
||||
'impact_unit': 'percent',
|
||||
'confidence': min(int(abs(trend.get('correlation', 0)) * 100), 95),
|
||||
'metrics_json': trend,
|
||||
'actionable': True,
|
||||
'recommendation_actions': [
|
||||
{
|
||||
'label': 'Review Inventory Levels',
|
||||
'action': 'review_inventory',
|
||||
'params': {'product_id': inventory_product_id}
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
# Volatility insight
|
||||
cv = stats.get('coefficient_of_variation', 0)
|
||||
if cv > 0.5:
|
||||
insights.append({
|
||||
'type': 'alert',
|
||||
'priority': 'medium',
|
||||
'category': 'forecasting',
|
||||
'title': 'High Demand Variability Detected',
|
||||
'description': f'Product has high demand variability (CV: {cv:.2f}). Consider dynamic safety stock levels.',
|
||||
'impact_type': 'demand_variability',
|
||||
'impact_value': round(cv * 100, 1),
|
||||
'impact_unit': 'percent',
|
||||
'confidence': 85,
|
||||
'metrics_json': stats,
|
||||
'actionable': True,
|
||||
'recommendation_actions': [
|
||||
{
|
||||
'label': 'Enable Dynamic Safety Stock',
|
||||
'action': 'enable_dynamic_safety_stock',
|
||||
'params': {'product_id': inventory_product_id}
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
# Seasonal pattern insight
|
||||
peak_ratio = seasonal.get('peak_ratio', 1.0)
|
||||
if peak_ratio > 1.5:
|
||||
pattern_data = patterns.get('patterns', {})
|
||||
peak_day = pattern_data.get('peak_day', 0)
|
||||
low_day = pattern_data.get('low_day', 0)
|
||||
insights.append({
|
||||
'type': 'insight',
|
||||
'priority': 'medium',
|
||||
'category': 'forecasting',
|
||||
'title': 'Strong Weekly Pattern Detected',
|
||||
'description': f'Demand is {peak_ratio:.1f}x higher on day {peak_day} compared to day {low_day}. Adjust production schedule accordingly.',
|
||||
'impact_type': 'seasonal_pattern',
|
||||
'impact_value': round((peak_ratio - 1) * 100, 1),
|
||||
'impact_unit': 'percent',
|
||||
'confidence': 80,
|
||||
'metrics_json': {**seasonal, **pattern_data},
|
||||
'actionable': True,
|
||||
'recommendation_actions': [
|
||||
{
|
||||
'label': 'Adjust Production Schedule',
|
||||
'action': 'adjust_production',
|
||||
'params': {'product_id': inventory_product_id, 'pattern': 'weekly'}
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
return insights
|
||||
|
||||
async def _get_dynamic_rules(self, tenant_id: str, inventory_product_id: str, rule_type: str) -> Dict[str, float]:
|
||||
"""
|
||||
Fetch learned dynamic rules from AI Insights Service.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
inventory_product_id: Product UUID
|
||||
rule_type: Type of rules (weather, temporal, holiday, etc.)
|
||||
|
||||
Returns:
|
||||
Dictionary of learned rules with factors
|
||||
"""
|
||||
try:
|
||||
from uuid import UUID
|
||||
|
||||
# Fetch latest rules insight for this product
|
||||
insights = await self.ai_insights_client.get_insights(
|
||||
tenant_id=UUID(tenant_id),
|
||||
filters={
|
||||
'category': 'forecasting',
|
||||
'actionable_only': False,
|
||||
'page_size': 100
|
||||
}
|
||||
)
|
||||
|
||||
if not insights or 'items' not in insights:
|
||||
return {}
|
||||
|
||||
# Find the most recent rules insight for this product
|
||||
for insight in insights['items']:
|
||||
if insight.get('source_model') == 'dynamic_rules_engine':
|
||||
metrics = insight.get('metrics_json', {})
|
||||
if metrics.get('inventory_product_id') == inventory_product_id:
|
||||
rules_data = metrics.get('rules', {})
|
||||
return rules_data.get(rule_type, {})
|
||||
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch dynamic rules: {e}")
|
||||
return {}
|
||||
|
||||
async def generate_forecast_with_repository(self, tenant_id: str, inventory_product_id: str,
|
||||
forecast_date: date, model_id: str = None) -> Dict[str, Any]:
|
||||
"""Generate forecast with repository integration"""
|
||||
try:
|
||||
# This would integrate with repositories for model loading and caching
|
||||
# For now, we'll implement basic forecasting logic using the forecaster's methods
|
||||
# This is a simplified approach - in production, this would use repositories
|
||||
|
||||
# For now, prepare minimal features for prediction
|
||||
features = {
|
||||
'date': forecast_date.isoformat(),
|
||||
'day_of_week': forecast_date.weekday(),
|
||||
'is_weekend': 1 if forecast_date.weekday() >= 5 else 0,
|
||||
'is_holiday': 0, # Would come from calendar service in real implementation
|
||||
# Add default weather values if needed
|
||||
'temperature': 20.0,
|
||||
'precipitation': 0.0,
|
||||
}
|
||||
|
||||
# This is a placeholder - in a full implementation, we would:
|
||||
# 1. Load the appropriate model from repository
|
||||
# 2. Use historical data to make prediction
|
||||
# 3. Apply business rules
|
||||
# For now, return the structure with basic info
|
||||
|
||||
# For more realistic implementation, we'd use self.predict_demand method
|
||||
# but that requires a model object which needs to be loaded
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"forecast_date": forecast_date.isoformat(),
|
||||
"prediction": 10.0, # Placeholder value - in reality would be calculated
|
||||
"confidence_interval": {"lower": 8.0, "upper": 12.0}, # Placeholder values
|
||||
"status": "completed",
|
||||
"repository_integration": True,
|
||||
"forecast_method": "placeholder"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("Forecast generation failed", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
class BakeryBusinessRules:
|
||||
"""
|
||||
Business rules for Spanish bakeries
|
||||
Applies domain-specific adjustments to predictions
|
||||
Supports both dynamic learned rules and hardcoded fallbacks
|
||||
"""
|
||||
|
||||
def __init__(self, use_dynamic_rules=False, ai_insights_client=None):
|
||||
self.use_dynamic_rules = use_dynamic_rules
|
||||
self.ai_insights_client = ai_insights_client
|
||||
self.rules_cache = {}
|
||||
|
||||
async def apply_rules(self, prediction: Dict[str, float], features: Dict[str, Any],
|
||||
business_type: str, tenant_id: str = None, inventory_product_id: str = None) -> Dict[str, float]:
|
||||
"""Apply all business rules to prediction (dynamic or hardcoded)"""
|
||||
|
||||
adjusted_prediction = prediction.copy()
|
||||
|
||||
# Apply weather rules
|
||||
adjusted_prediction = await self._apply_weather_rules(
|
||||
adjusted_prediction, features, tenant_id, inventory_product_id
|
||||
)
|
||||
|
||||
# Apply time-based rules
|
||||
adjusted_prediction = await self._apply_time_rules(
|
||||
adjusted_prediction, features, tenant_id, inventory_product_id
|
||||
)
|
||||
|
||||
# Apply business type rules
|
||||
adjusted_prediction = self._apply_business_type_rules(adjusted_prediction, business_type)
|
||||
|
||||
# Apply Spanish-specific rules
|
||||
adjusted_prediction = self._apply_spanish_rules(adjusted_prediction, features)
|
||||
|
||||
return adjusted_prediction
|
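    # Illustrative call (editor sketch): passing tenant/product context lets the
    # learned dynamic factors apply; without it the hardcoded settings-based
    # factors are used.
    #
    #     adjusted = await business_rules.apply_rules(
    #         prediction={'yhat': 100.0, 'yhat_lower': 90.0, 'yhat_upper': 110.0},
    #         features={'date': '2024-06-01', 'is_weekend': True, 'precipitation': 2.0},
    #         business_type='individual',
    #         tenant_id=tenant_id,
    #         inventory_product_id=product_id,
    #     )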
||||
|
||||
async def _get_dynamic_rules(self, tenant_id: str, inventory_product_id: str, rule_type: str) -> Dict[str, float]:
|
||||
"""
|
||||
Fetch learned dynamic rules from AI Insights Service.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
inventory_product_id: Product UUID
|
||||
rule_type: Type of rules (weather, temporal, holiday, etc.)
|
||||
|
||||
Returns:
|
||||
Dictionary of learned rules with factors
|
||||
"""
|
||||
# Check cache first
|
||||
cache_key = f"{tenant_id}:{inventory_product_id}:{rule_type}"
|
||||
if cache_key in self.rules_cache:
|
||||
return self.rules_cache[cache_key]
|
||||
|
||||
try:
|
||||
from uuid import UUID
|
||||
|
||||
if not self.ai_insights_client:
|
||||
return {}
|
||||
|
||||
# Fetch latest rules insight for this product
|
||||
insights = await self.ai_insights_client.get_insights(
|
||||
tenant_id=UUID(tenant_id),
|
||||
filters={
|
||||
'category': 'forecasting',
|
||||
'actionable_only': False,
|
||||
'page_size': 100
|
||||
}
|
||||
)
|
||||
|
||||
if not insights or 'items' not in insights:
|
||||
return {}
|
||||
|
||||
# Find the most recent rules insight for this product
|
||||
for insight in insights['items']:
|
||||
if insight.get('source_model') == 'dynamic_rules_engine':
|
||||
metrics = insight.get('metrics_json', {})
|
||||
if metrics.get('inventory_product_id') == inventory_product_id:
|
||||
rules_data = metrics.get('rules', {})
|
||||
result = rules_data.get(rule_type, {})
|
||||
# Cache the result
|
||||
self.rules_cache[cache_key] = result
|
||||
return result
|
||||
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch dynamic rules: {e}")
|
||||
return {}
|
||||
|
||||
async def _apply_weather_rules(self, prediction: Dict[str, float],
|
||||
features: Dict[str, Any],
|
||||
tenant_id: str = None,
|
||||
inventory_product_id: str = None) -> Dict[str, float]:
|
||||
"""Apply weather-based business rules (dynamic or hardcoded fallback)"""
|
||||
|
||||
if self.use_dynamic_rules and tenant_id and inventory_product_id:
|
||||
try:
|
||||
# Fetch dynamic weather rules
|
||||
rules = await self._get_dynamic_rules(tenant_id, inventory_product_id, 'weather')
|
||||
|
||||
# Apply learned rain impact
|
||||
precipitation = features.get('precipitation', 0)
|
||||
if precipitation > 0:
|
||||
rain_factor = rules.get('rain_factor', settings.RAIN_IMPACT_FACTOR)
|
||||
prediction["yhat"] *= rain_factor
|
||||
prediction["yhat_lower"] *= rain_factor
|
||||
prediction["yhat_upper"] *= rain_factor
|
||||
|
||||
# Apply learned temperature impact
|
||||
temperature = features.get('temperature')
|
||||
if temperature is not None:
|
||||
if temperature > settings.TEMPERATURE_THRESHOLD_HOT:
|
||||
hot_factor = rules.get('temperature_hot_factor', 0.9)
|
||||
prediction["yhat"] *= hot_factor
|
||||
elif temperature < settings.TEMPERATURE_THRESHOLD_COLD:
|
||||
cold_factor = rules.get('temperature_cold_factor', 1.1)
|
||||
prediction["yhat"] *= cold_factor
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply dynamic weather rules, using fallback: {e}")
|
||||
# Fallback to hardcoded
|
||||
precipitation = features.get('precipitation', 0)
|
||||
if precipitation > 0:
|
||||
prediction["yhat"] *= settings.RAIN_IMPACT_FACTOR
|
||||
prediction["yhat_lower"] *= settings.RAIN_IMPACT_FACTOR
|
||||
prediction["yhat_upper"] *= settings.RAIN_IMPACT_FACTOR
|
||||
|
||||
temperature = features.get('temperature')
|
||||
if temperature is not None:
|
||||
if temperature > settings.TEMPERATURE_THRESHOLD_HOT:
|
||||
prediction["yhat"] *= 0.9
|
||||
elif temperature < settings.TEMPERATURE_THRESHOLD_COLD:
|
||||
prediction["yhat"] *= 1.1
|
||||
else:
|
||||
# Use hardcoded rules
|
||||
precipitation = features.get('precipitation', 0)
|
||||
if precipitation > 0:
|
||||
rain_factor = settings.RAIN_IMPACT_FACTOR
|
||||
prediction["yhat"] *= rain_factor
|
||||
prediction["yhat_lower"] *= rain_factor
|
||||
prediction["yhat_upper"] *= rain_factor
|
||||
|
||||
temperature = features.get('temperature')
|
||||
if temperature is not None:
|
||||
if temperature > settings.TEMPERATURE_THRESHOLD_HOT:
|
||||
prediction["yhat"] *= 0.9
|
||||
elif temperature < settings.TEMPERATURE_THRESHOLD_COLD:
|
||||
prediction["yhat"] *= 1.1
|
||||
|
||||
return prediction
|
||||
|
||||
async def _apply_time_rules(self, prediction: Dict[str, float],
|
||||
features: Dict[str, Any],
|
||||
tenant_id: str = None,
|
||||
inventory_product_id: str = None) -> Dict[str, float]:
|
||||
"""Apply time-based business rules (dynamic or hardcoded fallback)"""
|
||||
|
||||
if self.use_dynamic_rules and tenant_id and inventory_product_id:
|
||||
try:
|
||||
# Fetch dynamic temporal rules
|
||||
rules = await self._get_dynamic_rules(tenant_id, inventory_product_id, 'temporal')
|
||||
|
||||
# Apply learned weekend adjustment
|
||||
if features.get('is_weekend', False):
|
||||
weekend_factor = rules.get('weekend_factor', settings.WEEKEND_ADJUSTMENT_FACTOR)
|
||||
prediction["yhat"] *= weekend_factor
|
||||
prediction["yhat_lower"] *= weekend_factor
|
||||
prediction["yhat_upper"] *= weekend_factor
|
||||
|
||||
# Apply learned holiday adjustment
|
||||
if features.get('is_holiday', False):
|
||||
holiday_factor = rules.get('holiday_factor', settings.HOLIDAY_ADJUSTMENT_FACTOR)
|
||||
prediction["yhat"] *= holiday_factor
|
||||
prediction["yhat_lower"] *= holiday_factor
|
||||
prediction["yhat_upper"] *= holiday_factor
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply dynamic time rules, using fallback: {e}")
|
||||
# Fallback to hardcoded
|
||||
if features.get('is_weekend', False):
|
||||
prediction["yhat"] *= settings.WEEKEND_ADJUSTMENT_FACTOR
|
||||
prediction["yhat_lower"] *= settings.WEEKEND_ADJUSTMENT_FACTOR
|
||||
prediction["yhat_upper"] *= settings.WEEKEND_ADJUSTMENT_FACTOR
|
||||
|
||||
if features.get('is_holiday', False):
|
||||
prediction["yhat"] *= settings.HOLIDAY_ADJUSTMENT_FACTOR
|
||||
prediction["yhat_lower"] *= settings.HOLIDAY_ADJUSTMENT_FACTOR
|
||||
prediction["yhat_upper"] *= settings.HOLIDAY_ADJUSTMENT_FACTOR
|
||||
else:
|
||||
# Use hardcoded rules
|
||||
if features.get('is_weekend', False):
|
||||
weekend_factor = settings.WEEKEND_ADJUSTMENT_FACTOR
|
||||
prediction["yhat"] *= weekend_factor
|
||||
prediction["yhat_lower"] *= weekend_factor
|
||||
prediction["yhat_upper"] *= weekend_factor
|
||||
|
||||
if features.get('is_holiday', False):
|
||||
holiday_factor = settings.HOLIDAY_ADJUSTMENT_FACTOR
|
||||
prediction["yhat"] *= holiday_factor
|
||||
prediction["yhat_lower"] *= holiday_factor
|
||||
prediction["yhat_upper"] *= holiday_factor
|
||||
|
||||
return prediction
|
||||
|
||||
def _apply_business_type_rules(self, prediction: Dict[str, float],
|
||||
business_type: str) -> Dict[str, float]:
|
||||
"""Apply business type specific rules"""
|
||||
|
||||
if business_type == "central_workshop":
|
||||
# Central workshops have more stable demand
|
||||
uncertainty_reduction = 0.8
|
||||
center = prediction["yhat"]
|
||||
lower = prediction["yhat_lower"]
|
||||
upper = prediction["yhat_upper"]
|
||||
|
||||
# Reduce uncertainty band
|
||||
new_range = (upper - lower) * uncertainty_reduction
|
||||
prediction["yhat_lower"] = center - (new_range / 2)
|
||||
prediction["yhat_upper"] = center + (new_range / 2)
|
||||
|
||||
return prediction
|
||||
|
||||
def _apply_spanish_rules(self, prediction: Dict[str, float],
|
||||
features: Dict[str, Any]) -> Dict[str, float]:
|
||||
"""Apply Spanish bakery specific rules"""
|
||||
|
||||
# Spanish siesta time considerations
|
||||
date_str = features.get('date')
|
||||
if date_str:
|
||||
try:
|
||||
current_date = pd.to_datetime(date_str)
|
||||
day_of_week = current_date.weekday()
|
||||
|
||||
# Reduced activity during typical siesta hours (14:00-17:00)
|
||||
# This affects afternoon sales planning
|
||||
if day_of_week < 5: # Weekdays
|
||||
prediction["yhat"] *= 0.95 # Slight reduction for siesta effect
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing date in spanish rules: {e}")
|
||||
else:
|
||||
logger.warning("Date not provided in features, skipping Spanish rules")
|
||||
|
||||
return prediction
|
||||
312
services/forecasting/app/ml/rules_orchestrator.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
Rules Orchestrator
|
||||
Coordinates dynamic rules learning, insight posting, and integration with forecasting service
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Any, Optional
|
||||
import structlog
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from app.ml.dynamic_rules_engine import DynamicRulesEngine
|
||||
from app.clients.ai_insights_client import AIInsightsClient
|
||||
from shared.messaging import UnifiedEventPublisher
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RulesOrchestrator:
|
||||
"""
|
||||
Orchestrates dynamic rules learning and insight generation workflow.
|
||||
|
||||
Workflow:
|
||||
1. Learn dynamic rules from historical data
|
||||
2. Generate insights comparing learned vs hardcoded rules
|
||||
3. Post insights to AI Insights Service
|
||||
4. Provide learned rules for forecasting integration
|
||||
5. Track rule updates and performance
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ai_insights_base_url: str = "http://ai-insights-service:8000",
|
||||
event_publisher: Optional[UnifiedEventPublisher] = None
|
||||
):
|
||||
self.rules_engine = DynamicRulesEngine()
|
||||
self.ai_insights_client = AIInsightsClient(ai_insights_base_url)
|
||||
self.event_publisher = event_publisher
|
||||
|
||||
async def learn_and_post_rules(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
external_data: Optional[pd.DataFrame] = None,
|
||||
min_samples: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Complete workflow: Learn rules and post insights.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
sales_data: Historical sales data
|
||||
external_data: Optional weather/events/holidays data
|
||||
min_samples: Minimum samples for rule learning
|
||||
|
||||
Returns:
|
||||
Workflow results with learned rules and posted insights
|
||||
"""
|
||||
logger.info(
|
||||
"Starting dynamic rules learning workflow",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id
|
||||
)
|
||||
|
||||
# Step 1: Learn all rules from data
|
||||
rules_results = await self.rules_engine.learn_all_rules(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
sales_data=sales_data,
|
||||
external_data=external_data,
|
||||
min_samples=min_samples
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Rules learning complete",
|
||||
insights_generated=len(rules_results['insights']),
|
||||
rules_learned=len(rules_results['rules'])
|
||||
)
|
||||
|
||||
# Step 2: Enrich insights with tenant_id and product context
|
||||
enriched_insights = self._enrich_insights(
|
||||
rules_results['insights'],
|
||||
tenant_id,
|
||||
inventory_product_id
|
||||
)
|
||||
|
||||
# Step 3: Post insights to AI Insights Service
|
||||
if enriched_insights:
|
||||
post_results = await self.ai_insights_client.create_insights_bulk(
|
||||
tenant_id=UUID(tenant_id),
|
||||
insights=enriched_insights
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Insights posted to AI Insights Service",
|
||||
total=post_results['total'],
|
||||
successful=post_results['successful'],
|
||||
failed=post_results['failed']
|
||||
)
|
||||
else:
|
||||
post_results = {'total': 0, 'successful': 0, 'failed': 0}
|
||||
logger.info("No insights to post")
|
||||
|
||||
# Step 4: Publish insight events to RabbitMQ
|
||||
created_insights = post_results.get('created_insights', [])
|
||||
if created_insights:
|
||||
product_context = {'inventory_product_id': inventory_product_id}
|
||||
await self._publish_insight_events(
|
||||
tenant_id=tenant_id,
|
||||
insights=created_insights,
|
||||
product_context=product_context
|
||||
)
|
||||
|
||||
# Step 5: Return comprehensive results
|
||||
return {
|
||||
'tenant_id': tenant_id,
|
||||
'inventory_product_id': inventory_product_id,
|
||||
'learned_at': rules_results['learned_at'],
|
||||
'rules': rules_results['rules'],
|
||||
'insights_generated': len(enriched_insights),
|
||||
'insights_posted': post_results['successful'],
|
||||
'insights_failed': post_results['failed'],
|
||||
'created_insights': post_results.get('created_insights', [])
|
||||
}
|
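    # End-to-end usage sketch (editor addition; variable names are assumptions):
    #
    #     orchestrator = RulesOrchestrator(event_publisher=publisher)
    #     results = await orchestrator.learn_and_post_rules(
    #         tenant_id=tenant_id,
    #         inventory_product_id=product_id,
    #         sales_data=history_df,
    #         external_data=weather_df,
    #         min_samples=10,
    #     )
    #     # results['insights_posted'] / results['insights_failed'] summarise the run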
||||
|
||||
def _enrich_insights(
|
||||
self,
|
||||
insights: List[Dict[str, Any]],
|
||||
tenant_id: str,
|
||||
inventory_product_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Enrich insights with required fields for AI Insights Service.
|
||||
|
||||
Args:
|
||||
insights: Raw insights from rules engine
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
|
||||
Returns:
|
||||
Enriched insights ready for posting
|
||||
"""
|
||||
enriched = []
|
||||
|
||||
for insight in insights:
|
||||
# Add required tenant_id and product context
|
||||
enriched_insight = insight.copy()
|
||||
enriched_insight['tenant_id'] = tenant_id
|
||||
|
||||
# Add product context to metrics
|
||||
if 'metrics_json' not in enriched_insight:
|
||||
enriched_insight['metrics_json'] = {}
|
||||
|
||||
enriched_insight['metrics_json']['inventory_product_id'] = inventory_product_id
|
||||
|
||||
# Add source metadata
|
||||
enriched_insight['source_service'] = 'forecasting'
|
||||
enriched_insight['source_model'] = 'dynamic_rules_engine'
|
||||
enriched_insight['detected_at'] = datetime.utcnow().isoformat()
|
||||
|
||||
enriched.append(enriched_insight)
|
||||
|
||||
return enriched
|
||||
|
||||
async def get_learned_rules_for_forecasting(
|
||||
self,
|
||||
inventory_product_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get learned rules in format ready for forecasting integration.
|
||||
|
||||
Args:
|
||||
inventory_product_id: Product identifier
|
||||
|
||||
Returns:
|
||||
Dictionary with learned multipliers for all rule types
|
||||
"""
|
||||
return self.rules_engine.export_rules_for_prophet(inventory_product_id)
|
||||
|
||||
def get_rule_multiplier(
|
||||
self,
|
||||
inventory_product_id: str,
|
||||
rule_type: str,
|
||||
key: str,
|
||||
default: float = 1.0
|
||||
) -> float:
|
||||
"""
|
||||
Get learned rule multiplier with fallback to default.
|
||||
|
||||
Args:
|
||||
inventory_product_id: Product identifier
|
||||
rule_type: 'weather', 'holiday', 'event', 'day_of_week', 'month'
|
||||
key: Condition key
|
||||
default: Default multiplier if rule not learned
|
||||
|
||||
Returns:
|
||||
Learned multiplier or default
|
||||
"""
|
||||
learned = self.rules_engine.get_rule(inventory_product_id, rule_type, key)
|
||||
return learned if learned is not None else default
|
||||
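# Illustrative sketch of layering learned multipliers onto a baseline demand figure.
# The rule_type values come from the docstring above; the condition keys ("rain",
# "saturday"), the `rules_service` instance name, and the baseline figure are
# assumptions for illustration.

def example_apply_multipliers(rules_service, inventory_product_id: str) -> float:
    baseline_demand = 120.0  # hypothetical baseline from the forecaster

    weather_mult = rules_service.get_rule_multiplier(
        inventory_product_id, rule_type="weather", key="rain", default=1.0
    )
    dow_mult = rules_service.get_rule_multiplier(
        inventory_product_id, rule_type="day_of_week", key="saturday", default=1.0
    )

    # Unlearned rules fall back to the default of 1.0, so the baseline passes through unchanged
    return baseline_demand * weather_mult * dow_mult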
|
||||
async def update_rules_periodically(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
sales_data: pd.DataFrame,
|
||||
external_data: Optional[pd.DataFrame] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Update learned rules with new data (for periodic refresh).
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
sales_data: Updated historical sales data
|
||||
external_data: Updated external data
|
||||
|
||||
Returns:
|
||||
Update results
|
||||
"""
|
||||
logger.info(
|
||||
"Updating learned rules with new data",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
new_data_points=len(sales_data)
|
||||
)
|
||||
|
||||
# Re-learn rules with updated data
|
||||
results = await self.learn_and_post_rules(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
sales_data=sales_data,
|
||||
external_data=external_data
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Rules update complete",
|
||||
insights_posted=results['insights_posted']
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def _publish_insight_events(self, tenant_id, insights, product_context=None):
|
||||
"""
|
||||
Publish insight events to RabbitMQ for alert processing.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
insights: List of created insights
|
||||
product_context: Additional context about the product
|
||||
"""
|
||||
if not self.event_publisher:
|
||||
logger.warning("No event publisher available for business rules insights")
|
||||
return
|
||||
|
||||
for insight in insights:
|
||||
# Determine severity based on confidence and priority
|
||||
confidence = insight.get('confidence', 0)
|
||||
priority = insight.get('priority', 'medium')
|
||||
|
||||
# Map priority to severity, with confidence as tiebreaker
|
||||
if priority == 'critical' or (priority == 'high' and confidence >= 70):
|
||||
severity = 'high'
|
||||
elif priority == 'high' or (priority == 'medium' and confidence >= 80):
|
||||
severity = 'medium'
|
||||
else:
|
||||
severity = 'low'
|
||||
|
||||
# Prepare the event data
|
||||
event_data = {
|
||||
'insight_id': insight.get('id'),
|
||||
'type': insight.get('type'),
|
||||
'title': insight.get('title'),
|
||||
'description': insight.get('description'),
|
||||
'category': insight.get('category'),
|
||||
'priority': insight.get('priority'),
|
||||
'confidence': confidence,
|
||||
'recommendation': insight.get('recommendation_actions', []),
|
||||
'impact_type': insight.get('impact_type'),
|
||||
'impact_value': insight.get('impact_value'),
|
||||
'inventory_product_id': product_context.get('inventory_product_id') if product_context else None,
|
||||
'timestamp': insight.get('detected_at', datetime.utcnow().isoformat()),
|
||||
'source_service': 'forecasting',
|
||||
'source_model': 'dynamic_rules_engine'
|
||||
}
|
||||
|
||||
try:
|
||||
await self.event_publisher.publish_recommendation(
|
||||
event_type='ai_business_rule',
|
||||
tenant_id=tenant_id,
|
||||
severity=severity,
|
||||
data=event_data
|
||||
)
|
||||
logger.info(
|
||||
"Published business rules insight event",
|
||||
tenant_id=tenant_id,
|
||||
insight_id=insight.get('id'),
|
||||
severity=severity
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to publish business rules insight event",
|
||||
tenant_id=tenant_id,
|
||||
insight_id=insight.get('id'),
|
||||
error=str(e)
|
||||
)
|
||||
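# Illustrative restatement of the priority/confidence -> severity mapping used when
# publishing events above, with a few sample inputs. This mirrors the logic in
# _publish_insight_events and exists only to make the mapping easy to inspect.

def example_map_severity(priority: str, confidence: float) -> str:
    if priority == 'critical' or (priority == 'high' and confidence >= 70):
        return 'high'
    if priority == 'high' or (priority == 'medium' and confidence >= 80):
        return 'medium'
    return 'low'


# example_map_severity('high', 75)   -> 'high'
# example_map_severity('high', 60)   -> 'medium'
# example_map_severity('medium', 85) -> 'medium'
# example_map_severity('low', 95)    -> 'low'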
|
||||
async def close(self):
|
||||
"""Close HTTP client connections."""
|
||||
await self.ai_insights_client.close()
|
||||
385
services/forecasting/app/ml/scenario_planner.py
Normal file

@@ -0,0 +1,385 @@
|
||||
"""
|
||||
Scenario Planning System
|
||||
What-if analysis for demand forecasting
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime, date, timedelta
|
||||
import structlog
|
||||
from enum import Enum
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ScenarioType(str, Enum):
|
||||
"""Types of scenarios"""
|
||||
BASELINE = "baseline"
|
||||
OPTIMISTIC = "optimistic"
|
||||
PESSIMISTIC = "pessimistic"
|
||||
CUSTOM = "custom"
|
||||
PROMOTION = "promotion"
|
||||
EVENT = "event"
|
||||
WEATHER = "weather"
|
||||
PRICE_CHANGE = "price_change"
|
||||
|
||||
|
||||
class ScenarioPlanner:
|
||||
"""
|
||||
Scenario planning for demand forecasting.
|
||||
|
||||
Allows testing "what-if" scenarios:
|
||||
- What if we run a promotion?
|
||||
- What if there's a local festival?
|
||||
- What if weather is unusually bad?
|
||||
- What if we change prices?
|
||||
"""
|
||||
|
||||
def __init__(self, base_forecaster=None):
|
||||
"""
|
||||
Initialize scenario planner.
|
||||
|
||||
Args:
|
||||
base_forecaster: Base forecaster to use for baseline predictions
|
||||
"""
|
||||
self.base_forecaster = base_forecaster
|
||||
|
||||
async def create_scenario(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
scenario_name: str,
|
||||
scenario_type: ScenarioType,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
adjustments: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a forecast scenario with adjustments.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
scenario_name: Name for the scenario
|
||||
scenario_type: Type of scenario
|
||||
start_date: Scenario start date
|
||||
end_date: Scenario end date
|
||||
adjustments: Dictionary of adjustments to apply
|
||||
|
||||
Returns:
|
||||
Scenario forecast results
|
||||
"""
|
||||
logger.info(
|
||||
"Creating forecast scenario",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
scenario_name=scenario_name,
|
||||
scenario_type=scenario_type
|
||||
)
|
||||
|
||||
# Generate baseline forecast first
|
||||
baseline_forecast = await self._generate_baseline_forecast(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
# Apply scenario adjustments
|
||||
scenario_forecast = self._apply_scenario_adjustments(
|
||||
baseline_forecast=baseline_forecast,
|
||||
adjustments=adjustments,
|
||||
scenario_type=scenario_type
|
||||
)
|
||||
|
||||
# Calculate impact
|
||||
impact_analysis = self._calculate_scenario_impact(
|
||||
baseline_forecast=baseline_forecast,
|
||||
scenario_forecast=scenario_forecast
|
||||
)
|
||||
|
||||
return {
|
||||
'scenario_id': f"scenario_{tenant_id}_{inventory_product_id}_{datetime.now().strftime('%Y%m%d%H%M%S')}",
|
||||
'scenario_name': scenario_name,
|
||||
'scenario_type': scenario_type,
|
||||
'tenant_id': tenant_id,
|
||||
'inventory_product_id': inventory_product_id,
|
||||
'date_range': {
|
||||
'start': start_date.isoformat(),
|
||||
'end': end_date.isoformat()
|
||||
},
|
||||
'baseline_forecast': baseline_forecast,
|
||||
'scenario_forecast': scenario_forecast,
|
||||
'impact_analysis': impact_analysis,
|
||||
'adjustments_applied': adjustments,
|
||||
'created_at': datetime.now().isoformat()
|
||||
}
|
||||
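# Illustrative usage sketch: creating a promotion scenario with the planner above.
# The adjustment key mirrors those handled in _apply_scenario_adjustments; the tenant,
# product, dates, and multiplier are example values only.

from datetime import date


async def example_promotion_scenario(planner: "ScenarioPlanner") -> dict:
    return await planner.create_scenario(
        tenant_id="tenant-123",                  # hypothetical identifiers
        inventory_product_id="product-456",
        scenario_name="Weekend Promotion",
        scenario_type=ScenarioType.PROMOTION,
        start_date=date(2024, 6, 1),
        end_date=date(2024, 6, 7),
        adjustments={"demand_multiplier": 1.5},  # 50% uplift during the promotion
    )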
|
||||
async def compare_scenarios(
|
||||
self,
|
||||
scenarios: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Compare multiple scenarios side-by-side.
|
||||
|
||||
Args:
|
||||
scenarios: List of scenario results from create_scenario()
|
||||
|
||||
Returns:
|
||||
Comparison analysis
|
||||
"""
|
||||
if len(scenarios) < 2:
|
||||
return {'error': 'Need at least 2 scenarios to compare'}
|
||||
|
||||
comparison = {
|
||||
'scenarios_compared': len(scenarios),
|
||||
'scenario_names': [s['scenario_name'] for s in scenarios],
|
||||
'comparison_metrics': {}
|
||||
}
|
||||
|
||||
# Extract total demand for each scenario
|
||||
for scenario in scenarios:
|
||||
scenario_name = scenario['scenario_name']
|
||||
scenario_forecast = scenario['scenario_forecast']
|
||||
|
||||
total_demand = sum(f['predicted_demand'] for f in scenario_forecast)
|
||||
|
||||
comparison['comparison_metrics'][scenario_name] = {
|
||||
'total_demand': total_demand,
|
||||
'avg_daily_demand': total_demand / len(scenario_forecast) if scenario_forecast else 0,
|
||||
'peak_demand': max(f['predicted_demand'] for f in scenario_forecast) if scenario_forecast else 0
|
||||
}
|
||||
|
||||
# Determine best and worst scenarios
|
||||
total_demands = {
|
||||
name: metrics['total_demand']
|
||||
for name, metrics in comparison['comparison_metrics'].items()
|
||||
}
|
||||
|
||||
comparison['best_scenario'] = max(total_demands, key=total_demands.get)
|
||||
comparison['worst_scenario'] = min(total_demands, key=total_demands.get)
|
||||
|
||||
comparison['demand_range'] = {
|
||||
'min': min(total_demands.values()),
|
||||
'max': max(total_demands.values()),
|
||||
'spread': max(total_demands.values()) - min(total_demands.values())
|
||||
}
|
||||
|
||||
return comparison
|
||||
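# Illustrative usage sketch: comparing two results produced by create_scenario().
# Variable names are placeholders; at least two scenarios are required, otherwise the
# method returns an error dictionary.

async def example_compare(planner: "ScenarioPlanner", baseline_result: dict, promo_result: dict) -> None:
    comparison = await planner.compare_scenarios([baseline_result, promo_result])
    print(comparison["best_scenario"], comparison["demand_range"]["spread"])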
|
||||
async def _generate_baseline_forecast(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
start_date: date,
|
||||
end_date: date
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generate baseline forecast without adjustments.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product identifier
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
|
||||
Returns:
|
||||
List of daily forecasts
|
||||
"""
|
||||
# Generate date range
|
||||
dates = []
|
||||
current_date = start_date
|
||||
while current_date <= end_date:
|
||||
dates.append(current_date)
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
# Placeholder forecast (in real implementation, call forecasting service)
|
||||
baseline = []
|
||||
for forecast_date in dates:
|
||||
baseline.append({
|
||||
'date': forecast_date.isoformat(),
|
||||
'predicted_demand': 100, # Placeholder
|
||||
'confidence_lower': 80,
|
||||
'confidence_upper': 120
|
||||
})
|
||||
|
||||
return baseline
|
||||
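# Illustrative sketch only: the baseline above is a flat placeholder. If the injected
# base_forecaster exposes a per-day prediction method, the loop could delegate to it.
# The method name `predict_daily_demand` is an assumption for illustration, not a
# documented interface of the forecasting service.

async def example_baseline_via_forecaster(planner: "ScenarioPlanner", tenant_id: str,
                                          inventory_product_id: str, dates: list) -> list:
    baseline = []
    for forecast_date in dates:
        if planner.base_forecaster is not None:
            # hypothetical call; the real forecaster API may differ
            prediction = await planner.base_forecaster.predict_daily_demand(
                tenant_id, inventory_product_id, forecast_date
            )
        else:
            prediction = {"predicted_demand": 100, "confidence_lower": 80, "confidence_upper": 120}
        baseline.append({"date": forecast_date.isoformat(), **prediction})
    return baseline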
|
||||
def _apply_scenario_adjustments(
|
||||
self,
|
||||
baseline_forecast: List[Dict[str, Any]],
|
||||
adjustments: Dict[str, Any],
|
||||
scenario_type: ScenarioType
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Apply adjustments to baseline forecast.
|
||||
|
||||
Args:
|
||||
baseline_forecast: Baseline forecast data
|
||||
adjustments: Adjustments to apply
|
||||
scenario_type: Type of scenario
|
||||
|
||||
Returns:
|
||||
Adjusted forecast
|
||||
"""
|
||||
scenario_forecast = []
|
||||
|
||||
for day_forecast in baseline_forecast:
|
||||
adjusted_forecast = day_forecast.copy()
|
||||
|
||||
# Apply different adjustment types
|
||||
if 'demand_multiplier' in adjustments:
|
||||
# Multiply demand by factor
|
||||
multiplier = adjustments['demand_multiplier']
|
||||
adjusted_forecast['predicted_demand'] *= multiplier
|
||||
adjusted_forecast['confidence_lower'] *= multiplier
|
||||
adjusted_forecast['confidence_upper'] *= multiplier
|
||||
|
||||
if 'demand_offset' in adjustments:
|
||||
# Add/subtract fixed amount
|
||||
offset = adjustments['demand_offset']
|
||||
adjusted_forecast['predicted_demand'] += offset
|
||||
adjusted_forecast['confidence_lower'] += offset
|
||||
adjusted_forecast['confidence_upper'] += offset
|
||||
|
||||
if 'event_impact' in adjustments:
|
||||
# Apply event-specific impact
|
||||
event_multiplier = adjustments['event_impact']
|
||||
adjusted_forecast['predicted_demand'] *= event_multiplier
|
||||
|
||||
if 'weather_impact' in adjustments:
|
||||
# Apply weather adjustments
|
||||
weather_factor = adjustments['weather_impact']
|
||||
adjusted_forecast['predicted_demand'] *= weather_factor
|
||||
|
||||
if 'price_elasticity' in adjustments and 'price_change_percent' in adjustments:
|
||||
# Apply price elasticity
|
||||
elasticity = adjustments['price_elasticity']
|
||||
price_change = adjustments['price_change_percent']
|
||||
demand_change = -elasticity * price_change # Negative correlation
|
||||
adjusted_forecast['predicted_demand'] *= (1 + demand_change)
|
||||
|
||||
# Ensure non-negative demand
|
||||
adjusted_forecast['predicted_demand'] = max(0, adjusted_forecast['predicted_demand'])
|
||||
adjusted_forecast['confidence_lower'] = max(0, adjusted_forecast['confidence_lower'])
|
||||
|
||||
scenario_forecast.append(adjusted_forecast)
|
||||
|
||||
return scenario_forecast
|
||||
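# Worked example of the price-elasticity adjustment above: with price_elasticity = 1.2 and
# price_change_percent = 0.10 (a +10% price change expressed as a fraction),
# demand_change = -1.2 * 0.10 = -0.12, so predicted demand is scaled by 0.88.

def example_elasticity_adjustment(predicted_demand: float = 100.0) -> float:
    elasticity = 1.2
    price_change = 0.10
    demand_change = -elasticity * price_change          # -0.12
    return max(0, predicted_demand * (1 + demand_change))  # 100.0 -> ~88.0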
|
||||
def _calculate_scenario_impact(
|
||||
self,
|
||||
baseline_forecast: List[Dict[str, Any]],
|
||||
scenario_forecast: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate impact of scenario vs baseline.
|
||||
|
||||
Args:
|
||||
baseline_forecast: Baseline forecast
|
||||
scenario_forecast: Scenario forecast
|
||||
|
||||
Returns:
|
||||
Impact analysis
|
||||
"""
|
||||
baseline_total = sum(f['predicted_demand'] for f in baseline_forecast)
|
||||
scenario_total = sum(f['predicted_demand'] for f in scenario_forecast)
|
||||
|
||||
difference = scenario_total - baseline_total
|
||||
percent_change = (difference / baseline_total * 100) if baseline_total > 0 else 0
|
||||
|
||||
return {
|
||||
'baseline_total_demand': baseline_total,
|
||||
'scenario_total_demand': scenario_total,
|
||||
'absolute_difference': difference,
|
||||
'percent_change': percent_change,
|
||||
'impact_category': self._categorize_impact(percent_change),
|
||||
'days_analyzed': len(baseline_forecast)
|
||||
}
|
||||
|
||||
def _categorize_impact(self, percent_change: float) -> str:
|
||||
"""Categorize impact magnitude"""
|
||||
if abs(percent_change) < 5:
|
||||
return "minimal"
|
||||
elif abs(percent_change) < 15:
|
||||
return "moderate"
|
||||
elif abs(percent_change) < 30:
|
||||
return "significant"
|
||||
else:
|
||||
return "major"
|
||||
|
||||
def generate_predefined_scenarios(
|
||||
self,
|
||||
base_scenario: Dict[str, Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generate common predefined scenarios for comparison.
|
||||
|
||||
Args:
|
||||
base_scenario: Base scenario parameters
|
||||
|
||||
Returns:
|
||||
List of scenario configurations
|
||||
"""
|
||||
scenarios = []
|
||||
|
||||
# Baseline scenario
|
||||
scenarios.append({
|
||||
'scenario_name': 'Baseline',
|
||||
'scenario_type': ScenarioType.BASELINE,
|
||||
'adjustments': {}
|
||||
})
|
||||
|
||||
# Optimistic scenario
|
||||
scenarios.append({
|
||||
'scenario_name': 'Optimistic',
|
||||
'scenario_type': ScenarioType.OPTIMISTIC,
|
||||
'adjustments': {
|
||||
'demand_multiplier': 1.2, # 20% increase
|
||||
'description': '+20% demand increase'
|
||||
}
|
||||
})
|
||||
|
||||
# Pessimistic scenario
|
||||
scenarios.append({
|
||||
'scenario_name': 'Pessimistic',
|
||||
'scenario_type': ScenarioType.PESSIMISTIC,
|
||||
'adjustments': {
|
||||
'demand_multiplier': 0.8, # 20% decrease
|
||||
'description': '-20% demand decrease'
|
||||
}
|
||||
})
|
||||
|
||||
# Promotion scenario
|
||||
scenarios.append({
|
||||
'scenario_name': 'Promotion Campaign',
|
||||
'scenario_type': ScenarioType.PROMOTION,
|
||||
'adjustments': {
|
||||
'demand_multiplier': 1.5, # 50% increase
|
||||
'description': '50% promotion boost'
|
||||
}
|
||||
})
|
||||
|
||||
# Bad weather scenario
|
||||
scenarios.append({
|
||||
'scenario_name': 'Bad Weather',
|
||||
'scenario_type': ScenarioType.WEATHER,
|
||||
'adjustments': {
|
||||
'weather_impact': 0.7, # 30% decrease
|
||||
'description': 'Bad weather reduces foot traffic'
|
||||
}
|
||||
})
|
||||
|
||||
# Price increase scenario
|
||||
scenarios.append({
|
||||
'scenario_name': 'Price Increase 10%',
|
||||
'scenario_type': ScenarioType.PRICE_CHANGE,
|
||||
'adjustments': {
|
||||
'price_elasticity': 1.2, # Elastic demand
|
||||
'price_change_percent': 0.10, # +10% price change, expressed as a fraction
|
||||
'description': '10% price increase with elastic demand'
|
||||
}
|
||||
})
|
||||
|
||||
return scenarios
|
||||
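# Illustrative usage sketch: running every predefined configuration above through
# create_scenario() and then comparing the results. Identifiers and dates are example
# values; base_scenario is passed as an empty dict.

from datetime import date


async def example_run_predefined(planner: "ScenarioPlanner") -> dict:
    results = []
    for config in planner.generate_predefined_scenarios(base_scenario={}):
        results.append(await planner.create_scenario(
            tenant_id="tenant-123",
            inventory_product_id="product-456",
            scenario_name=config["scenario_name"],
            scenario_type=config["scenario_type"],
            start_date=date(2024, 6, 1),
            end_date=date(2024, 6, 14),
            adjustments=config["adjustments"],
        ))
    return await planner.compare_scenarios(results)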
29
services/forecasting/app/models/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Forecasting Service Models Package
|
||||
|
||||
Import all models to ensure they are registered with SQLAlchemy Base.
|
||||
"""
|
||||
|
||||
# Import AuditLog model for this service
|
||||
from shared.security import create_audit_log_model
|
||||
from shared.database.base import Base
|
||||
|
||||
# Create audit log model for this service
|
||||
AuditLog = create_audit_log_model(Base)
|
||||
|
||||
# Import all models to register them with the Base metadata
|
||||
from .forecasts import Forecast, PredictionBatch
|
||||
from .predictions import ModelPerformanceMetric, PredictionCache
|
||||
from .validation_run import ValidationRun
|
||||
from .sales_data_update import SalesDataUpdate
|
||||
|
||||
# List all models for easier access
|
||||
__all__ = [
|
||||
"Forecast",
|
||||
"PredictionBatch",
|
||||
"ModelPerformanceMetric",
|
||||
"PredictionCache",
|
||||
"ValidationRun",
|
||||
"SalesDataUpdate",
|
||||
"AuditLog",
|
||||
]
|
||||
101
services/forecasting/app/models/forecasts.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/models/forecasts.py
|
||||
# ================================================================
|
||||
"""
|
||||
Forecast models for the forecasting service
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text, JSON, UniqueConstraint, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
class Forecast(Base):
|
||||
"""Forecast model for storing prediction results"""
|
||||
__tablename__ = "forecasts"
|
||||
|
||||
__table_args__ = (
|
||||
# Unique constraint to prevent duplicate forecasts
|
||||
# Ensures only one forecast per (tenant, product, date, location) combination
|
||||
UniqueConstraint(
|
||||
'tenant_id', 'inventory_product_id', 'forecast_date', 'location',
|
||||
name='uq_forecast_tenant_product_date_location'
|
||||
),
|
||||
# Composite index for common query patterns
|
||||
Index('ix_forecasts_tenant_product_date', 'tenant_id', 'inventory_product_id', 'forecast_date'),
|
||||
)
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
inventory_product_id = Column(UUID(as_uuid=True), nullable=False, index=True) # Reference to inventory service
|
||||
product_name = Column(String(255), nullable=True, index=True) # Product name (optional - use inventory_product_id as reference)
|
||||
location = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Forecast period
|
||||
forecast_date = Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
# Prediction results
|
||||
predicted_demand = Column(Float, nullable=False)
|
||||
confidence_lower = Column(Float, nullable=False)
|
||||
confidence_upper = Column(Float, nullable=False)
|
||||
confidence_level = Column(Float, default=0.8)
|
||||
|
||||
# Model information
|
||||
model_id = Column(String(255), nullable=False)
|
||||
model_version = Column(String(50), nullable=False)
|
||||
algorithm = Column(String(50), default="prophet")
|
||||
|
||||
# Business context
|
||||
business_type = Column(String(50), default="individual") # individual or central_workshop
|
||||
day_of_week = Column(Integer, nullable=False)
|
||||
is_holiday = Column(Boolean, default=False)
|
||||
is_weekend = Column(Boolean, default=False)
|
||||
|
||||
# External factors
|
||||
weather_temperature = Column(Float)
|
||||
weather_precipitation = Column(Float)
|
||||
weather_description = Column(String(100))
|
||||
traffic_volume = Column(Integer)
|
||||
|
||||
# Metadata
|
||||
processing_time_ms = Column(Integer)
|
||||
features_used = Column(JSON)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Forecast(id={self.id}, inventory_product_id={self.inventory_product_id}, date={self.forecast_date})>"
|
||||
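# Illustrative sketch: the unique constraint above also enables a single-statement
# PostgreSQL upsert as an alternative to "insert, catch IntegrityError, then update".
# A minimal sketch using SQLAlchemy's postgresql insert(); whether the service adopts
# this pattern is an assumption, and `session` is assumed to be an AsyncSession.

from sqlalchemy.dialects.postgresql import insert as pg_insert


async def example_upsert_forecast(session, forecast_values: dict) -> None:
    stmt = pg_insert(Forecast).values(**forecast_values)
    stmt = stmt.on_conflict_do_update(
        constraint="uq_forecast_tenant_product_date_location",
        set_={
            "predicted_demand": stmt.excluded.predicted_demand,
            "confidence_lower": stmt.excluded.confidence_lower,
            "confidence_upper": stmt.excluded.confidence_upper,
            "model_id": stmt.excluded.model_id,
        },
    )
    await session.execute(stmt)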
|
||||
class PredictionBatch(Base):
|
||||
"""Batch prediction requests"""
|
||||
__tablename__ = "prediction_batches"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
# Batch information
|
||||
batch_name = Column(String(255), nullable=False)
|
||||
requested_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
completed_at = Column(DateTime(timezone=True))
|
||||
|
||||
# Status
|
||||
status = Column(String(50), default="pending") # pending, processing, completed, failed
|
||||
total_products = Column(Integer, default=0)
|
||||
completed_products = Column(Integer, default=0)
|
||||
failed_products = Column(Integer, default=0)
|
||||
|
||||
# Configuration
|
||||
forecast_days = Column(Integer, default=7)
|
||||
business_type = Column(String(50), default="individual")
|
||||
|
||||
# Results
|
||||
error_message = Column(Text)
|
||||
processing_time_ms = Column(Integer)
|
||||
|
||||
cancelled_by = Column(String, nullable=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PredictionBatch(id={self.id}, status={self.status})>"
|
||||
|
||||
|
||||
67
services/forecasting/app/models/predictions.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/models/predictions.py
|
||||
# ================================================================
|
||||
"""
|
||||
Additional prediction models for the forecasting service
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text, JSON
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
class ModelPerformanceMetric(Base):
|
||||
"""Track model performance over time"""
|
||||
__tablename__ = "model_performance_metrics"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
model_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
inventory_product_id = Column(UUID(as_uuid=True), nullable=False) # Reference to inventory service
|
||||
|
||||
# Performance metrics
|
||||
mae = Column(Float) # Mean Absolute Error
|
||||
mape = Column(Float) # Mean Absolute Percentage Error
|
||||
rmse = Column(Float) # Root Mean Square Error
|
||||
accuracy_score = Column(Float)
|
||||
|
||||
# Evaluation period
|
||||
evaluation_date = Column(DateTime(timezone=True), nullable=False)
|
||||
evaluation_period_start = Column(DateTime(timezone=True))
|
||||
evaluation_period_end = Column(DateTime(timezone=True))
|
||||
|
||||
# Metadata
|
||||
sample_size = Column(Integer)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ModelPerformanceMetric(model_id={self.model_id}, mae={self.mae})>"
|
||||
|
||||
class PredictionCache(Base):
|
||||
"""Cache frequently requested predictions"""
|
||||
__tablename__ = "prediction_cache"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
cache_key = Column(String(255), unique=True, nullable=False, index=True)
|
||||
|
||||
# Cached data
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
inventory_product_id = Column(UUID(as_uuid=True), nullable=False) # Reference to inventory service
|
||||
location = Column(String(255), nullable=False)
|
||||
forecast_date = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# Cached results
|
||||
predicted_demand = Column(Float, nullable=False)
|
||||
confidence_lower = Column(Float, nullable=False)
|
||||
confidence_upper = Column(Float, nullable=False)
|
||||
model_id = Column(UUID(as_uuid=True), nullable=False)
|
||||
|
||||
# Cache metadata
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
hit_count = Column(Integer, default=0)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PredictionCache(key={self.cache_key}, inventory_product_id={self.inventory_product_id})>"
|
||||
78
services/forecasting/app/models/sales_data_update.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/models/sales_data_update.py
|
||||
# ================================================================
|
||||
"""
|
||||
Sales Data Update Tracking Model
|
||||
|
||||
Tracks when sales data is added or updated for past dates,
|
||||
enabling automated historical validation backfill.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Boolean, Index, Date
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
|
||||
class SalesDataUpdate(Base):
|
||||
"""Track sales data updates for historical validation"""
|
||||
__tablename__ = "sales_data_updates"
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_sales_updates_tenant_status', 'tenant_id', 'validation_status', 'created_at'),
|
||||
Index('ix_sales_updates_date_range', 'tenant_id', 'update_date_start', 'update_date_end'),
|
||||
Index('ix_sales_updates_validation_status', 'validation_status'),
|
||||
)
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
# Date range of sales data that was added/updated
|
||||
update_date_start = Column(Date, nullable=False, index=True)
|
||||
update_date_end = Column(Date, nullable=False, index=True)
|
||||
|
||||
# Update metadata
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
update_source = Column(String(100), nullable=True) # import, manual, pos_sync
|
||||
records_affected = Column(Integer, default=0)
|
||||
|
||||
# Validation tracking
|
||||
validation_status = Column(String(50), default="pending") # pending, processing, completed, failed
|
||||
validation_run_id = Column(UUID(as_uuid=True), nullable=True)
|
||||
validated_at = Column(DateTime(timezone=True), nullable=True)
|
||||
validation_error = Column(String(500), nullable=True)
|
||||
|
||||
# Determines if this update should trigger validation
|
||||
requires_validation = Column(Boolean, default=True)
|
||||
|
||||
# Additional context
|
||||
import_job_id = Column(String(255), nullable=True) # Link to sales import job if applicable
|
||||
notes = Column(String(500), nullable=True)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<SalesDataUpdate(id={self.id}, tenant_id={self.tenant_id}, "
|
||||
f"date_range={self.update_date_start} to {self.update_date_end}, "
|
||||
f"status={self.validation_status})>"
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary for API responses"""
|
||||
return {
|
||||
'id': str(self.id),
|
||||
'tenant_id': str(self.tenant_id),
|
||||
'update_date_start': self.update_date_start.isoformat() if self.update_date_start else None,
|
||||
'update_date_end': self.update_date_end.isoformat() if self.update_date_end else None,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'update_source': self.update_source,
|
||||
'records_affected': self.records_affected,
|
||||
'validation_status': self.validation_status,
|
||||
'validation_run_id': str(self.validation_run_id) if self.validation_run_id else None,
|
||||
'validated_at': self.validated_at.isoformat() if self.validated_at else None,
|
||||
'validation_error': self.validation_error,
|
||||
'requires_validation': self.requires_validation,
|
||||
'import_job_id': self.import_job_id,
|
||||
'notes': self.notes
|
||||
}
|
||||
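# Illustrative usage sketch: recording a backfill trigger after a historical sales import.
# Only the columns defined above are used; the session handling and field values are
# example assumptions, with `session` assumed to be an AsyncSession.

from datetime import date


async def example_record_sales_update(session, tenant_id) -> "SalesDataUpdate":
    update = SalesDataUpdate(
        tenant_id=tenant_id,
        update_date_start=date(2024, 5, 1),
        update_date_end=date(2024, 5, 31),
        update_source="import",        # import, manual, pos_sync
        records_affected=420,
        requires_validation=True,
        notes="May backfill from POS export",
    )
    session.add(update)
    await session.commit()
    return update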
110
services/forecasting/app/models/validation_run.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/models/validation_run.py
|
||||
# ================================================================
|
||||
"""
|
||||
Validation run models for tracking forecast validation executions
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, Float, DateTime, Text, JSON, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
|
||||
class ValidationRun(Base):
|
||||
"""Track forecast validation execution runs"""
|
||||
__tablename__ = "validation_runs"
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_validation_runs_tenant_created', 'tenant_id', 'started_at'),
|
||||
Index('ix_validation_runs_status', 'status', 'started_at'),
|
||||
Index('ix_validation_runs_orchestration', 'orchestration_run_id'),
|
||||
)
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
# Link to orchestration run (if triggered by orchestrator)
|
||||
orchestration_run_id = Column(UUID(as_uuid=True), nullable=True)
|
||||
|
||||
# Validation period
|
||||
validation_start_date = Column(DateTime(timezone=True), nullable=False)
|
||||
validation_end_date = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# Execution metadata
|
||||
started_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
duration_seconds = Column(Float, nullable=True)
|
||||
|
||||
# Status and results
|
||||
status = Column(String(50), default="pending") # pending, running, completed, failed
|
||||
|
||||
# Validation statistics
|
||||
total_forecasts_evaluated = Column(Integer, default=0)
|
||||
forecasts_with_actuals = Column(Integer, default=0)
|
||||
forecasts_without_actuals = Column(Integer, default=0)
|
||||
|
||||
# Accuracy metrics summary (across all validated forecasts)
|
||||
overall_mae = Column(Float, nullable=True)
|
||||
overall_mape = Column(Float, nullable=True)
|
||||
overall_rmse = Column(Float, nullable=True)
|
||||
overall_r2_score = Column(Float, nullable=True)
|
||||
overall_accuracy_percentage = Column(Float, nullable=True)
|
||||
|
||||
# Additional statistics
|
||||
total_predicted_demand = Column(Float, default=0.0)
|
||||
total_actual_demand = Column(Float, default=0.0)
|
||||
|
||||
# Breakdown by product/location (JSON)
|
||||
metrics_by_product = Column(JSON, nullable=True) # {product_id: {mae, mape, ...}}
|
||||
metrics_by_location = Column(JSON, nullable=True) # {location: {mae, mape, ...}}
|
||||
|
||||
# Performance metrics created count
|
||||
metrics_records_created = Column(Integer, default=0)
|
||||
|
||||
# Error tracking
|
||||
error_message = Column(Text, nullable=True)
|
||||
error_details = Column(JSON, nullable=True)
|
||||
|
||||
# Execution context
|
||||
triggered_by = Column(String(100), default="manual") # manual, orchestrator, scheduled
|
||||
execution_mode = Column(String(50), default="batch") # batch, single_day, real_time
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<ValidationRun(id={self.id}, tenant_id={self.tenant_id}, "
|
||||
f"status={self.status}, forecasts_evaluated={self.total_forecasts_evaluated})>"
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary for API responses"""
|
||||
return {
|
||||
'id': str(self.id),
|
||||
'tenant_id': str(self.tenant_id),
|
||||
'orchestration_run_id': str(self.orchestration_run_id) if self.orchestration_run_id else None,
|
||||
'validation_start_date': self.validation_start_date.isoformat() if self.validation_start_date else None,
|
||||
'validation_end_date': self.validation_end_date.isoformat() if self.validation_end_date else None,
|
||||
'started_at': self.started_at.isoformat() if self.started_at else None,
|
||||
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
|
||||
'duration_seconds': self.duration_seconds,
|
||||
'status': self.status,
|
||||
'total_forecasts_evaluated': self.total_forecasts_evaluated,
|
||||
'forecasts_with_actuals': self.forecasts_with_actuals,
|
||||
'forecasts_without_actuals': self.forecasts_without_actuals,
|
||||
'overall_mae': self.overall_mae,
|
||||
'overall_mape': self.overall_mape,
|
||||
'overall_rmse': self.overall_rmse,
|
||||
'overall_r2_score': self.overall_r2_score,
|
||||
'overall_accuracy_percentage': self.overall_accuracy_percentage,
|
||||
'total_predicted_demand': self.total_predicted_demand,
|
||||
'total_actual_demand': self.total_actual_demand,
|
||||
'metrics_by_product': self.metrics_by_product,
|
||||
'metrics_by_location': self.metrics_by_location,
|
||||
'metrics_records_created': self.metrics_records_created,
|
||||
'error_message': self.error_message,
|
||||
'error_details': self.error_details,
|
||||
'triggered_by': self.triggered_by,
|
||||
'execution_mode': self.execution_mode,
|
||||
}
|
||||
18
services/forecasting/app/repositories/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Forecasting Service Repositories
|
||||
Repository implementations for forecasting service
|
||||
"""
|
||||
|
||||
from .base import ForecastingBaseRepository
|
||||
from .forecast_repository import ForecastRepository
|
||||
from .prediction_batch_repository import PredictionBatchRepository
|
||||
from .performance_metric_repository import PerformanceMetricRepository
|
||||
from .prediction_cache_repository import PredictionCacheRepository
|
||||
|
||||
__all__ = [
|
||||
"ForecastingBaseRepository",
|
||||
"ForecastRepository",
|
||||
"PredictionBatchRepository",
|
||||
"PerformanceMetricRepository",
|
||||
"PredictionCacheRepository"
|
||||
]
|
||||
253
services/forecasting/app/repositories/base.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
Base Repository for Forecasting Service
|
||||
Service-specific repository base class with forecasting utilities
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Type
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
from datetime import datetime, date, timedelta, timezone
|
||||
import structlog
|
||||
|
||||
from shared.database.repository import BaseRepository
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ForecastingBaseRepository(BaseRepository):
|
||||
"""Base repository for forecasting service with common forecasting operations"""
|
||||
|
||||
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Forecasting data benefits from medium cache time (10 minutes)
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get records by tenant ID"""
|
||||
if hasattr(self.model, 'tenant_id'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"tenant_id": tenant_id},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def get_by_inventory_product_id(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List:
|
||||
"""Get records by tenant and inventory product"""
|
||||
if hasattr(self.model, 'inventory_product_id'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_by_tenant_id(tenant_id, skip, limit)
|
||||
|
||||
async def get_by_date_range(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List:
|
||||
"""Get records within date range for a tenant"""
|
||||
if not hasattr(self.model, 'forecast_date') and not hasattr(self.model, 'created_at'):
|
||||
logger.warning(f"Model {self.model.__name__} has no date field for filtering")
|
||||
return []
|
||||
|
||||
try:
|
||||
table_name = self.model.__tablename__
|
||||
date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"
|
||||
|
||||
query_text = f"""
|
||||
SELECT * FROM {table_name}
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND {date_field} >= :start_date
|
||||
AND {date_field} <= :end_date
|
||||
ORDER BY {date_field} DESC
|
||||
LIMIT :limit OFFSET :skip
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"tenant_id": tenant_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"limit": limit,
|
||||
"skip": skip
|
||||
})
|
||||
|
||||
# Convert rows to model objects
|
||||
records = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
record = self.model(**record_dict)
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get records by date range",
|
||||
model=self.model.__name__,
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Date range query failed: {str(e)}")
|
||||
|
||||
async def get_recent_records(
|
||||
self,
|
||||
tenant_id: str,
|
||||
hours: int = 24,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List:
|
||||
"""Get recent records for a tenant"""
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
|
||||
return await self.get_by_date_range(
|
||||
tenant_id, cutoff_time, datetime.now(timezone.utc), skip, limit
|
||||
)
|
||||
|
||||
async def cleanup_old_records(self, days_old: int = 90) -> int:
|
||||
"""Clean up old forecasting records"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
|
||||
table_name = self.model.__tablename__
|
||||
|
||||
# Use created_at or forecast_date for cleanup
|
||||
date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"
|
||||
|
||||
query_text = f"""
|
||||
DELETE FROM {table_name}
|
||||
WHERE {date_field} < :cutoff_date
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info(f"Cleaned up old {self.model.__name__} records",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old records",
|
||||
model=self.model.__name__,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Get statistics for a tenant"""
|
||||
try:
|
||||
table_name = self.model.__tablename__
|
||||
|
||||
# Get basic counts
|
||||
total_records = await self.count(filters={"tenant_id": tenant_id})
|
||||
|
||||
# Get recent activity (records in last 7 days)
|
||||
seven_days_ago = datetime.now(timezone.utc) - timedelta(days=7)
|
||||
recent_records = len(await self.get_by_date_range(
|
||||
tenant_id, seven_days_ago, datetime.now(timezone.utc), limit=1000
|
||||
))
|
||||
|
||||
# Get records by product if applicable
|
||||
product_stats = {}
|
||||
if hasattr(self.model, 'inventory_product_id'):
|
||||
product_query = text(f"""
|
||||
SELECT inventory_product_id, COUNT(*) as count
|
||||
FROM {table_name}
|
||||
WHERE tenant_id = :tenant_id
|
||||
GROUP BY inventory_product_id
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
|
||||
product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}
|
||||
|
||||
return {
|
||||
"total_records": total_records,
|
||||
"recent_records_7d": recent_records,
|
||||
"records_by_product": product_stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant statistics",
|
||||
model=self.model.__name__,
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_records": 0,
|
||||
"recent_records_7d": 0,
|
||||
"records_by_product": {}
|
||||
}
|
||||
|
||||
def _validate_forecast_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
|
||||
"""Validate forecasting-related data"""
|
||||
errors = []
|
||||
|
||||
for field in required_fields:
|
||||
if field not in data or data[field] is None:
|
||||
errors.append(f"Missing required field: {field}")
|
||||
|
||||
# Validate tenant_id format if present
|
||||
if "tenant_id" in data and data["tenant_id"]:
|
||||
tenant_id = data["tenant_id"]
|
||||
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
|
||||
errors.append("Invalid tenant_id format")
|
||||
|
||||
# Validate inventory_product_id if present
|
||||
if "inventory_product_id" in data and data["inventory_product_id"]:
|
||||
inventory_product_id = data["inventory_product_id"]
|
||||
if not isinstance(inventory_product_id, str) or len(inventory_product_id) < 1:
|
||||
errors.append("Invalid inventory_product_id format")
|
||||
|
||||
# Validate dates if present - accept datetime objects, date objects, and date strings
|
||||
date_fields = ["forecast_date", "created_at", "evaluation_date", "expires_at"]
|
||||
for field in date_fields:
|
||||
if field in data and data[field]:
|
||||
field_value = data[field]
|
||||
field_type = type(field_value).__name__
|
||||
|
||||
if isinstance(field_value, (datetime, date)):
|
||||
logger.debug(f"Date field {field} is valid {field_type}", field_value=str(field_value))
|
||||
continue # Already a datetime or date, valid
|
||||
elif isinstance(field_value, str):
|
||||
# Try to parse the string date
|
||||
try:
|
||||
from dateutil.parser import parse
|
||||
parse(field_value) # Just validate, don't convert yet
|
||||
logger.debug(f"Date field {field} is valid string", field_value=field_value)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Date parsing failed for {field}", field_value=field_value, error=str(e))
|
||||
errors.append(f"Invalid {field} format - must be datetime or valid date string")
|
||||
else:
|
||||
logger.error(f"Date field {field} has invalid type {field_type}", field_value=str(field_value))
|
||||
errors.append(f"Invalid {field} format - must be datetime or valid date string")
|
||||
|
||||
# Validate numeric fields
|
||||
numeric_fields = [
|
||||
"predicted_demand", "confidence_lower", "confidence_upper",
|
||||
"mae", "mape", "rmse", "accuracy_score"
|
||||
]
|
||||
for field in numeric_fields:
|
||||
if field in data and data[field] is not None:
|
||||
try:
|
||||
float(data[field])
|
||||
except (ValueError, TypeError):
|
||||
errors.append(f"Invalid {field} format - must be numeric")
|
||||
|
||||
return {
|
||||
"is_valid": len(errors) == 0,
|
||||
"errors": errors
|
||||
}
|
||||
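# Illustrative usage sketch: exercising the validation helper above with a minimal record.
# The repository instance name `repo` and the field values are placeholders; the
# required_fields list is just an example.

def example_validate(repo: "ForecastingBaseRepository") -> None:
    result = repo._validate_forecast_data(
        {
            "tenant_id": "tenant-123",
            "inventory_product_id": "product-456",
            "forecast_date": "2024-06-01",     # date strings are parsed with dateutil
            "predicted_demand": "not-a-number",
        },
        required_fields=["tenant_id", "inventory_product_id", "forecast_date", "predicted_demand"],
    )
    # is_valid is False here because predicted_demand is not numeric
    print(result["is_valid"], result["errors"])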
565
services/forecasting/app/repositories/forecast_repository.py
Normal file
@@ -0,0 +1,565 @@
|
||||
"""
|
||||
Forecast Repository
|
||||
Repository for forecast operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text, desc, func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from datetime import datetime, timedelta, date, timezone
|
||||
import structlog
|
||||
|
||||
from .base import ForecastingBaseRepository
|
||||
from app.models.forecasts import Forecast
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ForecastRepository(ForecastingBaseRepository):
|
||||
"""Repository for forecast operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Forecasts are relatively stable, medium cache time (10 minutes)
|
||||
super().__init__(Forecast, session, cache_ttl)
|
||||
|
||||
async def create_forecast(self, forecast_data: Dict[str, Any]) -> Forecast:
|
||||
"""
|
||||
Create a new forecast with validation.
|
||||
|
||||
Handles duplicate forecast race condition gracefully:
|
||||
If a forecast already exists for the same (tenant, product, date, location),
|
||||
it will be updated instead of creating a duplicate.
|
||||
"""
|
||||
try:
|
||||
# Validate forecast data
|
||||
validation_result = self._validate_forecast_data(
|
||||
forecast_data,
|
||||
["tenant_id", "inventory_product_id", "location", "forecast_date",
|
||||
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid forecast data: {validation_result['errors']}")
|
||||
|
||||
# Set default values
|
||||
if "confidence_level" not in forecast_data:
|
||||
forecast_data["confidence_level"] = 0.8
|
||||
if "algorithm" not in forecast_data:
|
||||
forecast_data["algorithm"] = "prophet"
|
||||
if "business_type" not in forecast_data:
|
||||
forecast_data["business_type"] = "individual"
|
||||
|
||||
# Try to create forecast
|
||||
try:
|
||||
forecast = await self.create(forecast_data)
|
||||
|
||||
logger.info("Forecast created successfully",
|
||||
forecast_id=forecast.id,
|
||||
tenant_id=forecast.tenant_id,
|
||||
inventory_product_id=forecast.inventory_product_id,
|
||||
forecast_date=forecast.forecast_date.isoformat())
|
||||
|
||||
return forecast
|
||||
|
||||
except IntegrityError as ie:
|
||||
# Handle unique constraint violation (duplicate forecast)
|
||||
error_msg = str(ie).lower()
|
||||
if "unique constraint" in error_msg or "duplicate" in error_msg or "uq_forecast_tenant_product_date_location" in error_msg:
|
||||
logger.warning("Forecast already exists (race condition), updating instead",
|
||||
tenant_id=forecast_data.get("tenant_id"),
|
||||
inventory_product_id=forecast_data.get("inventory_product_id"),
|
||||
forecast_date=str(forecast_data.get("forecast_date")))
|
||||
|
||||
# Rollback the failed insert
|
||||
await self.session.rollback()
|
||||
|
||||
# Fetch the existing forecast
|
||||
existing_forecast = await self.get_existing_forecast(
|
||||
tenant_id=forecast_data["tenant_id"],
|
||||
inventory_product_id=forecast_data["inventory_product_id"],
|
||||
forecast_date=forecast_data["forecast_date"],
|
||||
location=forecast_data["location"]
|
||||
)
|
||||
|
||||
if existing_forecast:
|
||||
# Update existing forecast with new prediction data
|
||||
update_data = {
|
||||
"predicted_demand": forecast_data["predicted_demand"],
|
||||
"confidence_lower": forecast_data["confidence_lower"],
|
||||
"confidence_upper": forecast_data["confidence_upper"],
|
||||
"confidence_level": forecast_data.get("confidence_level", 0.8),
|
||||
"model_id": forecast_data["model_id"],
|
||||
"model_version": forecast_data.get("model_version"),
|
||||
"algorithm": forecast_data.get("algorithm", "prophet"),
|
||||
"processing_time_ms": forecast_data.get("processing_time_ms"),
|
||||
"features_used": forecast_data.get("features_used"),
|
||||
"weather_temperature": forecast_data.get("weather_temperature"),
|
||||
"weather_precipitation": forecast_data.get("weather_precipitation"),
|
||||
"weather_description": forecast_data.get("weather_description"),
|
||||
}
|
||||
|
||||
updated_forecast = await self.update(str(existing_forecast.id), update_data)
|
||||
|
||||
logger.info("Existing forecast updated after duplicate detection",
|
||||
forecast_id=updated_forecast.id,
|
||||
tenant_id=updated_forecast.tenant_id,
|
||||
inventory_product_id=updated_forecast.inventory_product_id)
|
||||
|
||||
return updated_forecast
|
||||
else:
|
||||
# This shouldn't happen, but log it
|
||||
logger.error("Duplicate forecast detected but not found in database")
|
||||
raise DatabaseError("Duplicate forecast detected but not found")
|
||||
else:
|
||||
# Different integrity error, re-raise
|
||||
raise
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except IntegrityError as ie:
|
||||
# Re-raise integrity errors that weren't handled above
|
||||
logger.error("Database integrity error creating forecast",
|
||||
tenant_id=forecast_data.get("tenant_id"),
|
||||
error=str(ie))
|
||||
raise DatabaseError(f"Database integrity error: {str(ie)}")
|
||||
except Exception as e:
|
||||
logger.error("Failed to create forecast",
|
||||
tenant_id=forecast_data.get("tenant_id"),
|
||||
inventory_product_id=forecast_data.get("inventory_product_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create forecast: {str(e)}")
|
||||
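# Illustrative usage sketch: because of the duplicate handling above, calling
# create_forecast twice for the same (tenant, product, date, location) key updates the
# existing row rather than failing. base_payload is assumed to contain all required
# fields; values are placeholders.

async def example_idempotent_create(repo: "ForecastRepository", base_payload: dict) -> None:
    first = await repo.create_forecast({**base_payload, "predicted_demand": 100.0})
    second = await repo.create_forecast({**base_payload, "predicted_demand": 110.0})
    # Same row id, refreshed prediction, given the unique-constraint handling above
    assert first.id == second.id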
|
||||
async def get_existing_forecast(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
forecast_date: datetime,
|
||||
location: str
|
||||
) -> Optional[Forecast]:
|
||||
"""Get an existing forecast by unique key (tenant, product, date, location)"""
|
||||
try:
|
||||
query = select(Forecast).where(
|
||||
and_(
|
||||
Forecast.tenant_id == tenant_id,
|
||||
Forecast.inventory_product_id == inventory_product_id,
|
||||
Forecast.forecast_date == forecast_date,
|
||||
Forecast.location == location
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Failed to get existing forecast", error=str(e))
|
||||
return None
|
||||
|
||||
async def get_forecasts_by_date_range(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
inventory_product_id: str = None,
|
||||
location: str = None
|
||||
) -> List[Forecast]:
|
||||
"""Get forecasts within a date range"""
|
||||
try:
|
||||
filters = {"tenant_id": tenant_id}
|
||||
|
||||
if inventory_product_id:
|
||||
filters["inventory_product_id"] = inventory_product_id
|
||||
if location:
|
||||
filters["location"] = location
|
||||
|
||||
# Convert dates to datetime for comparison
|
||||
start_datetime = datetime.combine(start_date, datetime.min.time())
|
||||
end_datetime = datetime.combine(end_date, datetime.max.time())
|
||||
|
||||
return await self.get_by_date_range(
|
||||
tenant_id, start_datetime, end_datetime
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get forecasts by date range",
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get forecasts: {str(e)}")
|
||||
|
||||
async def get_latest_forecast_for_product(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
location: str = None
|
||||
) -> Optional[Forecast]:
|
||||
"""Get the most recent forecast for a product"""
|
||||
try:
|
||||
filters = {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id
|
||||
}
|
||||
if location:
|
||||
filters["location"] = location
|
||||
|
||||
forecasts = await self.get_multi(
|
||||
filters=filters,
|
||||
limit=1,
|
||||
order_by="forecast_date",
|
||||
order_desc=True
|
||||
)
|
||||
|
||||
return forecasts[0] if forecasts else None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get latest forecast for product",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get latest forecast: {str(e)}")
|
||||
|
||||
async def get_forecasts_for_date(
|
||||
self,
|
||||
tenant_id: str,
|
||||
forecast_date: date,
|
||||
inventory_product_id: str = None
|
||||
) -> List[Forecast]:
|
||||
"""Get all forecasts for a specific date"""
|
||||
try:
|
||||
# Convert date to datetime range
|
||||
start_datetime = datetime.combine(forecast_date, datetime.min.time())
|
||||
end_datetime = datetime.combine(forecast_date, datetime.max.time())
|
||||
|
||||
return await self.get_by_date_range(
|
||||
tenant_id, start_datetime, end_datetime
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get forecasts for date",
|
||||
tenant_id=tenant_id,
|
||||
forecast_date=forecast_date,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get forecasts for date: {str(e)}")
|
||||
|
||||
async def get_forecast_accuracy_metrics(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str = None,
|
||||
days_back: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get forecast accuracy metrics"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
|
||||
# Build base query conditions
|
||||
conditions = ["tenant_id = :tenant_id", "forecast_date >= :cutoff_date"]
|
||||
params = {
|
||||
"tenant_id": tenant_id,
|
||||
"cutoff_date": cutoff_date
|
||||
}
|
||||
|
||||
if inventory_product_id:
|
||||
conditions.append("inventory_product_id = :inventory_product_id")
|
||||
params["inventory_product_id"] = inventory_product_id
|
||||
|
||||
query_text = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_forecasts,
|
||||
AVG(predicted_demand) as avg_predicted_demand,
|
||||
MIN(predicted_demand) as min_predicted_demand,
|
||||
MAX(predicted_demand) as max_predicted_demand,
|
||||
AVG(confidence_upper - confidence_lower) as avg_confidence_interval,
|
||||
AVG(processing_time_ms) as avg_processing_time_ms,
|
||||
COUNT(DISTINCT inventory_product_id) as unique_products,
|
||||
COUNT(DISTINCT model_id) as unique_models
|
||||
FROM forecasts
|
||||
WHERE {' AND '.join(conditions)}
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
row = result.fetchone()
|
||||
|
||||
if row and row.total_forecasts > 0:
|
||||
return {
|
||||
"total_forecasts": int(row.total_forecasts),
|
||||
"avg_predicted_demand": float(row.avg_predicted_demand or 0),
|
||||
"min_predicted_demand": float(row.min_predicted_demand or 0),
|
||||
"max_predicted_demand": float(row.max_predicted_demand or 0),
|
||||
"avg_confidence_interval": float(row.avg_confidence_interval or 0),
|
||||
"avg_processing_time_ms": float(row.avg_processing_time_ms or 0),
|
||||
"unique_products": int(row.unique_products or 0),
|
||||
"unique_models": int(row.unique_models or 0),
|
||||
"period_days": days_back
|
||||
}
|
||||
|
||||
return {
|
||||
"total_forecasts": 0,
|
||||
"avg_predicted_demand": 0.0,
|
||||
"min_predicted_demand": 0.0,
|
||||
"max_predicted_demand": 0.0,
|
||||
"avg_confidence_interval": 0.0,
|
||||
"avg_processing_time_ms": 0.0,
|
||||
"unique_products": 0,
|
||||
"unique_models": 0,
|
||||
"period_days": days_back
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get forecast accuracy metrics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_forecasts": 0,
|
||||
"avg_predicted_demand": 0.0,
|
||||
"min_predicted_demand": 0.0,
|
||||
"max_predicted_demand": 0.0,
|
||||
"avg_confidence_interval": 0.0,
|
||||
"avg_processing_time_ms": 0.0,
|
||||
"unique_products": 0,
|
||||
"unique_models": 0,
|
||||
"period_days": days_back
|
||||
}
|
||||
|
||||
async def get_demand_trends(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
days_back: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get demand trends for a product"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
|
||||
query_text = """
|
||||
SELECT
|
||||
DATE(forecast_date) as date,
|
||||
AVG(predicted_demand) as avg_demand,
|
||||
MIN(predicted_demand) as min_demand,
|
||||
MAX(predicted_demand) as max_demand,
|
||||
COUNT(*) as forecast_count
|
||||
FROM forecasts
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND inventory_product_id = :inventory_product_id
|
||||
AND forecast_date >= :cutoff_date
|
||||
GROUP BY DATE(forecast_date)
|
||||
ORDER BY date DESC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"cutoff_date": cutoff_date
|
||||
})
|
||||
|
||||
trends = []
|
||||
for row in result.fetchall():
|
||||
trends.append({
|
||||
"date": row.date.isoformat() if row.date else None,
|
||||
"avg_demand": float(row.avg_demand),
|
||||
"min_demand": float(row.min_demand),
|
||||
"max_demand": float(row.max_demand),
|
||||
"forecast_count": int(row.forecast_count)
|
||||
})
|
||||
|
||||
            # Calculate overall trend direction (trends are ordered newest first)
            if len(trends) >= 2:
                recent_avg = sum(t["avg_demand"] for t in trends[:7]) / min(7, len(trends))
                older_avg = sum(t["avg_demand"] for t in trends[-7:]) / len(trends[-7:])
                if recent_avg > older_avg:
                    trend_direction = "increasing"
                elif recent_avg < older_avg:
                    trend_direction = "decreasing"
                else:
                    trend_direction = "stable"
            else:
                trend_direction = "stable"
|
||||
|
||||
return {
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"period_days": days_back,
|
||||
"trends": trends,
|
||||
"trend_direction": trend_direction,
|
||||
"total_data_points": len(trends)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get demand trends",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"period_days": days_back,
|
||||
"trends": [],
|
||||
"trend_direction": "unknown",
|
||||
"total_data_points": 0
|
||||
}
|
||||
|
||||
async def get_model_usage_statistics(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Get statistics about model usage"""
|
||||
try:
|
||||
# Get model usage counts
|
||||
model_query = text("""
|
||||
SELECT
|
||||
model_id,
|
||||
algorithm,
|
||||
COUNT(*) as usage_count,
|
||||
AVG(predicted_demand) as avg_prediction,
|
||||
MAX(forecast_date) as last_used,
|
||||
COUNT(DISTINCT inventory_product_id) as products_covered
|
||||
FROM forecasts
|
||||
WHERE tenant_id = :tenant_id
|
||||
GROUP BY model_id, algorithm
|
||||
ORDER BY usage_count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(model_query, {"tenant_id": tenant_id})
|
||||
|
||||
model_stats = []
|
||||
for row in result.fetchall():
|
||||
model_stats.append({
|
||||
"model_id": row.model_id,
|
||||
"algorithm": row.algorithm,
|
||||
"usage_count": int(row.usage_count),
|
||||
"avg_prediction": float(row.avg_prediction),
|
||||
"last_used": row.last_used.isoformat() if row.last_used else None,
|
||||
"products_covered": int(row.products_covered)
|
||||
})
|
||||
|
||||
# Get algorithm distribution
|
||||
algorithm_query = text("""
|
||||
SELECT algorithm, COUNT(*) as count
|
||||
FROM forecasts
|
||||
WHERE tenant_id = :tenant_id
|
||||
GROUP BY algorithm
|
||||
""")
|
||||
|
||||
algorithm_result = await self.session.execute(algorithm_query, {"tenant_id": tenant_id})
|
||||
algorithm_distribution = {row.algorithm: row.count for row in algorithm_result.fetchall()}
|
||||
|
||||
return {
|
||||
"model_statistics": model_stats,
|
||||
"algorithm_distribution": algorithm_distribution,
|
||||
"total_unique_models": len(model_stats)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get model usage statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"model_statistics": [],
|
||||
"algorithm_distribution": {},
|
||||
"total_unique_models": 0
|
||||
}
|
||||
|
||||
async def cleanup_old_forecasts(self, days_old: int = 90) -> int:
|
||||
"""Clean up old forecasts"""
|
||||
return await self.cleanup_old_records(days_old=days_old)
|
||||
|
||||
async def get_forecast_summary(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Get comprehensive forecast summary for a tenant"""
|
||||
try:
|
||||
# Get basic statistics
|
||||
basic_stats = await self.get_statistics_by_tenant(tenant_id)
|
||||
|
||||
# Get accuracy metrics
|
||||
accuracy_metrics = await self.get_forecast_accuracy_metrics(tenant_id)
|
||||
|
||||
# Get model usage
|
||||
model_usage = await self.get_model_usage_statistics(tenant_id)
|
||||
|
||||
# Get recent activity
|
||||
recent_forecasts = await self.get_recent_records(tenant_id, hours=24)
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"basic_statistics": basic_stats,
|
||||
"accuracy_metrics": accuracy_metrics,
|
||||
"model_usage": model_usage,
|
||||
"recent_activity": {
|
||||
"forecasts_last_24h": len(recent_forecasts),
|
||||
"latest_forecast": recent_forecasts[0].forecast_date.isoformat() if recent_forecasts else None
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get forecast summary",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {"error": f"Failed to get forecast summary: {str(e)}"}
|
||||
|
||||
async def get_forecasts_by_date(
|
||||
self,
|
||||
tenant_id: str,
|
||||
forecast_date: date,
|
||||
inventory_product_id: str = None
|
||||
) -> List[Forecast]:
|
||||
"""
|
||||
Get all forecasts for a specific date.
|
||||
Used for forecast validation against actual sales.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
forecast_date: Date to get forecasts for
|
||||
inventory_product_id: Optional product filter
|
||||
|
||||
Returns:
|
||||
List of forecasts for the date
|
||||
"""
|
||||
try:
|
||||
query = select(Forecast).where(
|
||||
and_(
|
||||
Forecast.tenant_id == tenant_id,
|
||||
func.date(Forecast.forecast_date) == forecast_date
|
||||
)
|
||||
)
|
||||
|
||||
if inventory_product_id:
|
||||
query = query.where(Forecast.inventory_product_id == inventory_product_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
forecasts = result.scalars().all()
|
||||
|
||||
logger.info("Retrieved forecasts by date",
|
||||
tenant_id=tenant_id,
|
||||
forecast_date=forecast_date.isoformat(),
|
||||
count=len(forecasts))
|
||||
|
||||
return list(forecasts)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get forecasts by date",
|
||||
tenant_id=tenant_id,
|
||||
forecast_date=forecast_date.isoformat(),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get forecasts: {str(e)}")
|
||||
|
||||
async def bulk_create_forecasts(self, forecasts_data: List[Dict[str, Any]]) -> List[Forecast]:
|
||||
"""Bulk create multiple forecasts"""
|
||||
try:
|
||||
created_forecasts = []
|
||||
|
||||
for forecast_data in forecasts_data:
|
||||
# Validate each forecast
|
||||
validation_result = self._validate_forecast_data(
|
||||
forecast_data,
|
||||
["tenant_id", "inventory_product_id", "location", "forecast_date",
|
||||
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
logger.warning("Skipping invalid forecast data",
|
||||
errors=validation_result["errors"],
|
||||
data=forecast_data)
|
||||
continue
|
||||
|
||||
forecast = await self.create(forecast_data)
|
||||
created_forecasts.append(forecast)
|
||||
|
||||
logger.info("Bulk created forecasts",
|
||||
requested_count=len(forecasts_data),
|
||||
created_count=len(created_forecasts))
|
||||
|
||||
return created_forecasts
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to bulk create forecasts",
|
||||
requested_count=len(forecasts_data),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Bulk forecast creation failed: {str(e)}")
|
||||
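# ----------------------------------------------------------------------
# Minimal usage sketch, assuming `repo` is an instance of the forecast
# repository defined above (already bound to an AsyncSession) and that the
# tenant/product IDs are placeholders supplied by the caller. It simply
# chains the repository methods shown in this file.
async def _example_forecast_repository_usage(repo, tenant_id: str, inventory_product_id: str) -> dict:
    """Fetch the latest forecast plus 30-day accuracy and trend context."""
    latest = await repo.get_latest_forecast_for_product(
        tenant_id=tenant_id,
        inventory_product_id=inventory_product_id,
    )
    accuracy = await repo.get_forecast_accuracy_metrics(
        tenant_id=tenant_id,
        inventory_product_id=inventory_product_id,
        days_back=30,
    )
    trends = await repo.get_demand_trends(
        tenant_id=tenant_id,
        inventory_product_id=inventory_product_id,
        days_back=30,
    )
    return {
        "latest_forecast_date": latest.forecast_date.isoformat() if latest else None,
        "accuracy": accuracy,
        "trend_direction": trends["trend_direction"],
    }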
@@ -0,0 +1,214 @@
|
||||
# services/forecasting/app/repositories/forecasting_alert_repository.py
|
||||
"""
|
||||
Forecasting Alert Repository
|
||||
Data access layer for forecasting-specific alert detection and analysis
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from uuid import UUID
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ForecastingAlertRepository:
|
||||
"""Repository for forecasting alert data access"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_weekend_demand_surges(self, tenant_id: UUID) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get predicted weekend demand surges
|
||||
Returns forecasts showing significant growth over previous weeks
|
||||
"""
|
||||
try:
|
||||
query = text("""
|
||||
WITH weekend_forecast AS (
|
||||
SELECT
|
||||
f.tenant_id,
|
||||
f.inventory_product_id,
|
||||
f.product_name,
|
||||
f.predicted_demand,
|
||||
f.forecast_date,
|
||||
LAG(f.predicted_demand, 7) OVER (
|
||||
PARTITION BY f.tenant_id, f.inventory_product_id
|
||||
ORDER BY f.forecast_date
|
||||
) as prev_week_demand,
|
||||
AVG(f.predicted_demand) OVER (
|
||||
PARTITION BY f.tenant_id, f.inventory_product_id
|
||||
ORDER BY f.forecast_date
|
||||
ROWS BETWEEN 6 PRECEDING AND CURRENT ROW
|
||||
) as avg_weekly_demand
|
||||
FROM forecasts f
|
||||
WHERE f.forecast_date >= CURRENT_DATE + INTERVAL '1 day'
|
||||
AND f.forecast_date <= CURRENT_DATE + INTERVAL '3 days'
|
||||
AND EXTRACT(DOW FROM f.forecast_date) IN (6, 0)
|
||||
AND f.tenant_id = :tenant_id
|
||||
),
|
||||
surge_analysis AS (
|
||||
SELECT *,
|
||||
CASE
|
||||
WHEN prev_week_demand > 0 THEN
|
||||
(predicted_demand - prev_week_demand) / prev_week_demand * 100
|
||||
ELSE 0
|
||||
END as growth_percentage,
|
||||
CASE
|
||||
WHEN avg_weekly_demand > 0 THEN
|
||||
(predicted_demand - avg_weekly_demand) / avg_weekly_demand * 100
|
||||
ELSE 0
|
||||
END as avg_growth_percentage
|
||||
FROM weekend_forecast
|
||||
)
|
||||
SELECT * FROM surge_analysis
|
||||
WHERE growth_percentage > 50 OR avg_growth_percentage > 50
|
||||
ORDER BY growth_percentage DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {"tenant_id": tenant_id})
|
||||
return [dict(row._mapping) for row in result.fetchall()]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get weekend demand surges", error=str(e), tenant_id=str(tenant_id))
|
||||
raise
|
||||
|
||||
async def get_weather_impact_forecasts(self, tenant_id: UUID) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get weather impact on demand forecasts
|
||||
Returns forecasts with rain or significant demand changes
|
||||
"""
|
||||
try:
|
||||
query = text("""
|
||||
WITH weather_impact AS (
|
||||
SELECT
|
||||
f.tenant_id,
|
||||
f.inventory_product_id,
|
||||
f.product_name,
|
||||
f.predicted_demand,
|
||||
f.forecast_date,
|
||||
f.weather_precipitation,
|
||||
f.weather_temperature,
|
||||
f.traffic_volume,
|
||||
AVG(f.predicted_demand) OVER (
|
||||
PARTITION BY f.tenant_id, f.inventory_product_id
|
||||
ORDER BY f.forecast_date
|
||||
ROWS BETWEEN 6 PRECEDING AND CURRENT ROW
|
||||
) as avg_demand
|
||||
FROM forecasts f
|
||||
WHERE f.forecast_date >= CURRENT_DATE + INTERVAL '1 day'
|
||||
AND f.forecast_date <= CURRENT_DATE + INTERVAL '2 days'
|
||||
AND f.tenant_id = :tenant_id
|
||||
),
|
||||
rain_impact AS (
|
||||
SELECT *,
|
||||
CASE
|
||||
WHEN weather_precipitation > 2.0 THEN true
|
||||
ELSE false
|
||||
END as rain_forecast,
|
||||
CASE
|
||||
WHEN traffic_volume < 80 THEN true
|
||||
ELSE false
|
||||
END as low_traffic_expected,
|
||||
CASE
    WHEN avg_demand > 0 THEN
        (predicted_demand - avg_demand) / avg_demand * 100
    ELSE 0
END as demand_change
FROM weather_impact
|
||||
)
|
||||
SELECT * FROM rain_impact
|
||||
WHERE rain_forecast = true OR demand_change < -15
|
||||
ORDER BY demand_change ASC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {"tenant_id": tenant_id})
|
||||
return [dict(row._mapping) for row in result.fetchall()]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get weather impact forecasts", error=str(e), tenant_id=str(tenant_id))
|
||||
raise
|
||||
|
||||
async def get_holiday_demand_spikes(self, tenant_id: UUID) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get historical holiday demand spike analysis
|
||||
Returns products with significant holiday demand increases
|
||||
"""
|
||||
try:
|
||||
query = text("""
|
||||
WITH holiday_demand AS (
|
||||
SELECT
|
||||
f.tenant_id,
|
||||
f.inventory_product_id,
|
||||
f.product_name,
|
||||
AVG(CASE WHEN f.is_holiday = true THEN f.predicted_demand END) as avg_holiday_demand,
|
||||
AVG(CASE WHEN f.is_holiday = false THEN f.predicted_demand END) as avg_normal_demand,
|
||||
COUNT(*) as forecast_count
|
||||
FROM forecasts f
|
||||
WHERE f.created_at > CURRENT_DATE - INTERVAL '365 days'
|
||||
AND f.tenant_id = :tenant_id
|
||||
GROUP BY f.tenant_id, f.inventory_product_id, f.product_name
|
||||
HAVING COUNT(*) >= 10
|
||||
),
|
||||
demand_spike_analysis AS (
|
||||
SELECT *,
|
||||
CASE
|
||||
WHEN avg_normal_demand > 0 THEN
|
||||
(avg_holiday_demand - avg_normal_demand) / avg_normal_demand * 100
|
||||
ELSE 0
|
||||
END as spike_percentage
|
||||
FROM holiday_demand
|
||||
)
|
||||
SELECT * FROM demand_spike_analysis
|
||||
WHERE spike_percentage > 25
|
||||
ORDER BY spike_percentage DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {"tenant_id": tenant_id})
|
||||
return [dict(row._mapping) for row in result.fetchall()]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get holiday demand spikes", error=str(e), tenant_id=str(tenant_id))
|
||||
raise
|
||||
|
||||
async def get_demand_pattern_analysis(self, tenant_id: UUID) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get weekly demand pattern analysis for optimization
|
||||
Returns products with significant demand variations
|
||||
"""
|
||||
try:
|
||||
query = text("""
|
||||
WITH weekly_patterns AS (
|
||||
SELECT
|
||||
f.tenant_id,
|
||||
f.inventory_product_id,
|
||||
f.product_name,
|
||||
EXTRACT(DOW FROM f.forecast_date) as day_of_week,
|
||||
AVG(f.predicted_demand) as avg_demand,
|
||||
STDDEV(f.predicted_demand) as demand_variance,
|
||||
COUNT(*) as data_points
|
||||
FROM forecasts f
|
||||
WHERE f.created_at > CURRENT_DATE - INTERVAL '60 days'
|
||||
AND f.tenant_id = :tenant_id
|
||||
GROUP BY f.tenant_id, f.inventory_product_id, f.product_name, EXTRACT(DOW FROM f.forecast_date)
|
||||
HAVING COUNT(*) >= 5
|
||||
),
|
||||
pattern_analysis AS (
|
||||
SELECT
|
||||
tenant_id, inventory_product_id, product_name,
|
||||
MAX(avg_demand) as peak_demand,
|
||||
MIN(avg_demand) as min_demand,
|
||||
AVG(avg_demand) as overall_avg,
|
||||
MAX(avg_demand) - MIN(avg_demand) as demand_range
|
||||
FROM weekly_patterns
|
||||
GROUP BY tenant_id, inventory_product_id, product_name
|
||||
)
|
||||
SELECT * FROM pattern_analysis
|
||||
WHERE demand_range > overall_avg * 0.3
|
||||
AND peak_demand > overall_avg * 1.5
|
||||
ORDER BY demand_range DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {"tenant_id": tenant_id})
|
||||
return [dict(row._mapping) for row in result.fetchall()]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get demand pattern analysis", error=str(e), tenant_id=str(tenant_id))
|
||||
raise
|
||||
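# ----------------------------------------------------------------------
# Minimal usage sketch, assuming `session` comes from the service's async
# session factory and the tenant UUID is a placeholder. It gathers the four
# alert queries above into a single payload.
async def _example_collect_forecasting_alerts(session: AsyncSession, tenant_id: UUID) -> Dict[str, List[Dict[str, Any]]]:
    repo = ForecastingAlertRepository(session)
    return {
        "weekend_surges": await repo.get_weekend_demand_surges(tenant_id),
        "weather_impacts": await repo.get_weather_impact_forecasts(tenant_id),
        "holiday_spikes": await repo.get_holiday_demand_spikes(tenant_id),
        "demand_patterns": await repo.get_demand_pattern_analysis(tenant_id),
    }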
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
Performance Metric Repository
|
||||
Repository for model performance metrics in forecasting service
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
|
||||
from .base import ForecastingBaseRepository
|
||||
from app.models.predictions import ModelPerformanceMetric
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
"""Repository for model performance metrics operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
|
||||
# Performance metrics are stable, longer cache time (15 minutes)
|
||||
super().__init__(ModelPerformanceMetric, session, cache_ttl)
|
||||
|
||||
async def create_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
|
||||
"""Create a new performance metric"""
|
||||
try:
|
||||
# Validate metric data
|
||||
validation_result = self._validate_forecast_data(
|
||||
metric_data,
|
||||
["model_id", "tenant_id", "inventory_product_id", "evaluation_date"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid metric data: {validation_result['errors']}")
|
||||
|
||||
metric = await self.create(metric_data)
|
||||
|
||||
logger.info("Performance metric created",
|
||||
metric_id=metric.id,
|
||||
model_id=metric.model_id,
|
||||
tenant_id=metric.tenant_id,
|
||||
inventory_product_id=metric.inventory_product_id)
|
||||
|
||||
return metric
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create performance metric",
|
||||
model_id=metric_data.get("model_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create metric: {str(e)}")
|
||||
|
||||
async def get_metrics_by_model(
|
||||
self,
|
||||
model_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[ModelPerformanceMetric]:
|
||||
"""Get all metrics for a model"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={"model_id": model_id},
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
order_by="evaluation_date",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get metrics by model",
|
||||
model_id=model_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get metrics: {str(e)}")
|
||||
|
||||
async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
|
||||
"""Get the latest performance metric for a model"""
|
||||
try:
|
||||
metrics = await self.get_multi(
|
||||
filters={"model_id": model_id},
|
||||
limit=1,
|
||||
order_by="evaluation_date",
|
||||
order_desc=True
|
||||
)
|
||||
return metrics[0] if metrics else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get latest metric for model",
|
||||
model_id=model_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get latest metric: {str(e)}")
|
||||
|
||||
async def get_performance_trends(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str = None,
|
||||
days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get performance trends over time"""
|
||||
try:
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
|
||||
conditions = [
|
||||
"tenant_id = :tenant_id",
|
||||
"evaluation_date >= :start_date"
|
||||
]
|
||||
params = {
|
||||
"tenant_id": tenant_id,
|
||||
"start_date": start_date
|
||||
}
|
||||
|
||||
if inventory_product_id:
|
||||
conditions.append("inventory_product_id = :inventory_product_id")
|
||||
params["inventory_product_id"] = inventory_product_id
|
||||
|
||||
query_text = f"""
|
||||
SELECT
|
||||
DATE(evaluation_date) as date,
|
||||
inventory_product_id,
|
||||
AVG(mae) as avg_mae,
|
||||
AVG(mape) as avg_mape,
|
||||
AVG(rmse) as avg_rmse,
|
||||
AVG(accuracy_score) as avg_accuracy,
|
||||
COUNT(*) as measurement_count
|
||||
FROM model_performance_metrics
|
||||
WHERE {' AND '.join(conditions)}
|
||||
GROUP BY DATE(evaluation_date), inventory_product_id
|
||||
ORDER BY date DESC, inventory_product_id
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
|
||||
trends = []
|
||||
for row in result.fetchall():
|
||||
trends.append({
|
||||
"date": row.date.isoformat() if row.date else None,
|
||||
"inventory_product_id": row.inventory_product_id,
|
||||
"metrics": {
|
||||
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
|
||||
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
|
||||
"avg_rmse": float(row.avg_rmse) if row.avg_rmse else None,
|
||||
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
|
||||
},
|
||||
"measurement_count": int(row.measurement_count)
|
||||
})
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"period_days": days,
|
||||
"trends": trends,
|
||||
"total_measurements": len(trends)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get performance trends",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"period_days": days,
|
||||
"trends": [],
|
||||
"total_measurements": 0
|
||||
}
|
||||
|
||||
async def cleanup_old_metrics(self, days_old: int = 180) -> int:
|
||||
"""Clean up old performance metrics"""
|
||||
return await self.cleanup_old_records(days_old=days_old)
|
||||
|
||||
async def bulk_create_metrics(self, metrics: List[ModelPerformanceMetric]) -> int:
|
||||
"""
|
||||
Bulk insert performance metrics for validation
|
||||
|
||||
Args:
|
||||
metrics: List of ModelPerformanceMetric objects to insert
|
||||
|
||||
Returns:
|
||||
Number of metrics created
|
||||
"""
|
||||
try:
|
||||
if not metrics:
|
||||
return 0
|
||||
|
||||
self.session.add_all(metrics)
|
||||
await self.session.flush()
|
||||
|
||||
logger.info(
|
||||
"Bulk created performance metrics",
|
||||
count=len(metrics)
|
||||
)
|
||||
|
||||
return len(metrics)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to bulk create performance metrics",
|
||||
count=len(metrics),
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to bulk create metrics: {str(e)}")
|
||||
|
||||
async def get_metrics_by_date_range(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
inventory_product_id: Optional[str] = None
|
||||
) -> List[ModelPerformanceMetric]:
|
||||
"""
|
||||
Get performance metrics for a date range
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
start_date: Start of date range
|
||||
end_date: End of date range
|
||||
inventory_product_id: Optional product filter
|
||||
|
||||
Returns:
|
||||
List of performance metrics
|
||||
"""
|
||||
try:
|
||||
filters = {
|
||||
"tenant_id": tenant_id
|
||||
}
|
||||
|
||||
if inventory_product_id:
|
||||
filters["inventory_product_id"] = inventory_product_id
|
||||
|
||||
# Build custom query for date range
|
||||
query_text = """
|
||||
SELECT *
|
||||
FROM model_performance_metrics
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND evaluation_date >= :start_date
|
||||
AND evaluation_date <= :end_date
|
||||
"""
|
||||
|
||||
params = {
|
||||
"tenant_id": tenant_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date
|
||||
}
|
||||
|
||||
if inventory_product_id:
|
||||
query_text += " AND inventory_product_id = :inventory_product_id"
|
||||
params["inventory_product_id"] = inventory_product_id
|
||||
|
||||
query_text += " ORDER BY evaluation_date DESC"
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
rows = result.fetchall()
|
||||
|
||||
# Convert rows to ModelPerformanceMetric objects
|
||||
metrics = []
|
||||
for row in rows:
|
||||
metric = ModelPerformanceMetric()
|
||||
for column in row._mapping.keys():
|
||||
setattr(metric, column, row._mapping[column])
|
||||
metrics.append(metric)
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get metrics by date range",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to get metrics: {str(e)}")
|
||||
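# ----------------------------------------------------------------------
# Minimal usage sketch, assuming `session` is an AsyncSession from the
# service's database manager and that the tenant/model identifiers are
# placeholders. It combines the latest metric with recent accuracy trends.
async def _example_model_performance_report(session: AsyncSession, tenant_id: str, model_id: str) -> Dict[str, Any]:
    repo = PerformanceMetricRepository(session)
    latest = await repo.get_latest_metric_for_model(model_id)
    trends = await repo.get_performance_trends(tenant_id=tenant_id, days=30)
    return {
        "model_id": model_id,
        "latest_evaluation": latest.evaluation_date.isoformat() if latest else None,
        "trend_points": trends["total_measurements"],
    }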
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
Prediction Batch Repository
|
||||
Repository for prediction batch operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
|
||||
from .base import ForecastingBaseRepository
|
||||
from app.models.forecasts import PredictionBatch
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PredictionBatchRepository(ForecastingBaseRepository):
|
||||
"""Repository for prediction batch operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
|
||||
# Batch operations change frequently, shorter cache time (5 minutes)
|
||||
super().__init__(PredictionBatch, session, cache_ttl)
|
||||
|
||||
async def create_batch(self, batch_data: Dict[str, Any]) -> PredictionBatch:
|
||||
"""Create a new prediction batch"""
|
||||
try:
|
||||
# Validate batch data
|
||||
validation_result = self._validate_forecast_data(
|
||||
batch_data,
|
||||
["tenant_id", "batch_name"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid batch data: {validation_result['errors']}")
|
||||
|
||||
# Set default values
|
||||
if "status" not in batch_data:
|
||||
batch_data["status"] = "pending"
|
||||
if "forecast_days" not in batch_data:
|
||||
batch_data["forecast_days"] = 7
|
||||
if "business_type" not in batch_data:
|
||||
batch_data["business_type"] = "individual"
|
||||
|
||||
batch = await self.create(batch_data)
|
||||
|
||||
logger.info("Prediction batch created",
|
||||
batch_id=batch.id,
|
||||
tenant_id=batch.tenant_id,
|
||||
batch_name=batch.batch_name)
|
||||
|
||||
return batch
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create prediction batch",
|
||||
tenant_id=batch_data.get("tenant_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create batch: {str(e)}")
|
||||
|
||||
async def update_batch_progress(
|
||||
self,
|
||||
batch_id: str,
|
||||
completed_products: int = None,
|
||||
failed_products: int = None,
|
||||
total_products: int = None,
|
||||
status: str = None
|
||||
) -> Optional[PredictionBatch]:
|
||||
"""Update batch progress"""
|
||||
try:
|
||||
update_data = {}
|
||||
|
||||
if completed_products is not None:
|
||||
update_data["completed_products"] = completed_products
|
||||
if failed_products is not None:
|
||||
update_data["failed_products"] = failed_products
|
||||
if total_products is not None:
|
||||
update_data["total_products"] = total_products
|
||||
if status:
|
||||
update_data["status"] = status
|
||||
if status in ["completed", "failed"]:
|
||||
update_data["completed_at"] = datetime.now(timezone.utc)
|
||||
|
||||
if not update_data:
|
||||
return await self.get_by_id(batch_id)
|
||||
|
||||
updated_batch = await self.update(batch_id, update_data)
|
||||
|
||||
logger.debug("Batch progress updated",
|
||||
batch_id=batch_id,
|
||||
status=status,
|
||||
completed=completed_products)
|
||||
|
||||
return updated_batch
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update batch progress",
|
||||
batch_id=batch_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update batch: {str(e)}")
|
||||
|
||||
async def complete_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
processing_time_ms: int = None
|
||||
) -> Optional[PredictionBatch]:
|
||||
"""Mark batch as completed"""
|
||||
try:
|
||||
update_data = {
|
||||
"status": "completed",
|
||||
"completed_at": datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
if processing_time_ms:
|
||||
update_data["processing_time_ms"] = processing_time_ms
|
||||
|
||||
updated_batch = await self.update(batch_id, update_data)
|
||||
|
||||
logger.info("Batch completed",
|
||||
batch_id=batch_id,
|
||||
processing_time_ms=processing_time_ms)
|
||||
|
||||
return updated_batch
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to complete batch",
|
||||
batch_id=batch_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to complete batch: {str(e)}")
|
||||
|
||||
async def fail_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
error_message: str,
|
||||
processing_time_ms: int = None
|
||||
) -> Optional[PredictionBatch]:
|
||||
"""Mark batch as failed"""
|
||||
try:
|
||||
update_data = {
|
||||
"status": "failed",
|
||||
"completed_at": datetime.now(timezone.utc),
|
||||
"error_message": error_message
|
||||
}
|
||||
|
||||
if processing_time_ms:
|
||||
update_data["processing_time_ms"] = processing_time_ms
|
||||
|
||||
updated_batch = await self.update(batch_id, update_data)
|
||||
|
||||
logger.error("Batch failed",
|
||||
batch_id=batch_id,
|
||||
error_message=error_message)
|
||||
|
||||
return updated_batch
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to mark batch as failed",
|
||||
batch_id=batch_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to fail batch: {str(e)}")
|
||||
|
||||
async def cancel_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
cancelled_by: str = None
|
||||
) -> Optional[PredictionBatch]:
|
||||
"""Cancel a batch"""
|
||||
try:
|
||||
batch = await self.get_by_id(batch_id)
|
||||
if not batch:
|
||||
return None
|
||||
|
||||
if batch.status in ["completed", "failed"]:
|
||||
logger.warning("Cannot cancel finished batch",
|
||||
batch_id=batch_id,
|
||||
status=batch.status)
|
||||
return batch
|
||||
|
||||
update_data = {
|
||||
"status": "cancelled",
|
||||
"completed_at": datetime.now(timezone.utc),
|
||||
"cancelled_by": cancelled_by,
|
||||
"error_message": f"Cancelled by {cancelled_by}" if cancelled_by else "Cancelled"
|
||||
}
|
||||
|
||||
updated_batch = await self.update(batch_id, update_data)
|
||||
|
||||
logger.info("Batch cancelled",
|
||||
batch_id=batch_id,
|
||||
cancelled_by=cancelled_by)
|
||||
|
||||
return updated_batch
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cancel batch",
|
||||
batch_id=batch_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to cancel batch: {str(e)}")
|
||||
|
||||
async def get_active_batches(self, tenant_id: str = None) -> List[PredictionBatch]:
|
||||
"""Get currently active (pending/processing) batches"""
|
||||
        try:
            if tenant_id:
                # Multiple status values require a raw query rather than simple filters
query_text = """
|
||||
SELECT * FROM prediction_batches
|
||||
WHERE status IN ('pending', 'processing')
|
||||
AND tenant_id = :tenant_id
|
||||
ORDER BY requested_at DESC
|
||||
"""
|
||||
params = {"tenant_id": tenant_id}
|
||||
else:
|
||||
query_text = """
|
||||
SELECT * FROM prediction_batches
|
||||
WHERE status IN ('pending', 'processing')
|
||||
ORDER BY requested_at DESC
|
||||
"""
|
||||
params = {}
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
|
||||
batches = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
batch = self.model(**record_dict)
|
||||
batches.append(batch)
|
||||
|
||||
return batches
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get active batches",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_batch_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
|
||||
"""Get batch processing statistics"""
|
||||
try:
|
||||
base_filter = "WHERE 1=1"
|
||||
params = {}
|
||||
|
||||
if tenant_id:
|
||||
base_filter = "WHERE tenant_id = :tenant_id"
|
||||
params["tenant_id"] = tenant_id
|
||||
|
||||
# Get counts by status
|
||||
status_query = text(f"""
|
||||
SELECT
|
||||
status,
|
||||
COUNT(*) as count,
|
||||
AVG(CASE WHEN processing_time_ms IS NOT NULL THEN processing_time_ms END) as avg_processing_time_ms
|
||||
FROM prediction_batches
|
||||
{base_filter}
|
||||
GROUP BY status
|
||||
""")
|
||||
|
||||
result = await self.session.execute(status_query, params)
|
||||
|
||||
status_stats = {}
|
||||
total_batches = 0
|
||||
avg_processing_times = {}
|
||||
|
||||
for row in result.fetchall():
|
||||
status_stats[row.status] = row.count
|
||||
total_batches += row.count
|
||||
if row.avg_processing_time_ms:
|
||||
avg_processing_times[row.status] = float(row.avg_processing_time_ms)
|
||||
|
||||
# Get recent activity (batches in last 7 days)
|
||||
seven_days_ago = datetime.now(timezone.utc) - timedelta(days=7)
|
||||
recent_query = text(f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM prediction_batches
|
||||
{base_filter}
|
||||
AND requested_at >= :seven_days_ago
|
||||
""")
|
||||
|
||||
recent_result = await self.session.execute(recent_query, {
|
||||
**params,
|
||||
"seven_days_ago": seven_days_ago
|
||||
})
|
||||
recent_batches = recent_result.scalar() or 0
|
||||
|
||||
# Calculate success rate
|
||||
completed = status_stats.get("completed", 0)
|
||||
failed = status_stats.get("failed", 0)
|
||||
cancelled = status_stats.get("cancelled", 0)
|
||||
finished_batches = completed + failed + cancelled
|
||||
|
||||
success_rate = (completed / finished_batches * 100) if finished_batches > 0 else 0
|
||||
|
||||
return {
|
||||
"total_batches": total_batches,
|
||||
"batches_by_status": status_stats,
|
||||
"success_rate": round(success_rate, 2),
|
||||
"recent_batches_7d": recent_batches,
|
||||
"avg_processing_times_ms": avg_processing_times
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get batch statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_batches": 0,
|
||||
"batches_by_status": {},
|
||||
"success_rate": 0.0,
|
||||
"recent_batches_7d": 0,
|
||||
"avg_processing_times_ms": {}
|
||||
}
|
||||
|
||||
async def cleanup_old_batches(self, days_old: int = 30) -> int:
|
||||
"""Clean up old completed/failed batches"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
|
||||
|
||||
query_text = """
|
||||
DELETE FROM prediction_batches
|
||||
WHERE status IN ('completed', 'failed', 'cancelled')
|
||||
AND completed_at < :cutoff_date
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up old prediction batches",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old batches",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Batch cleanup failed: {str(e)}")
|
||||
|
||||
async def get_batch_details(self, batch_id: str) -> Dict[str, Any]:
|
||||
"""Get detailed batch information"""
|
||||
try:
|
||||
batch = await self.get_by_id(batch_id)
|
||||
if not batch:
|
||||
return {"error": "Batch not found"}
|
||||
|
||||
            # Calculate completion percentage (guard against missing or zero totals)
            completion_percentage = 0
            if batch.total_products:
                completion_percentage = ((batch.completed_products or 0) / batch.total_products) * 100
|
||||
|
||||
# Calculate elapsed time
|
||||
elapsed_time_ms = 0
|
||||
if batch.completed_at:
|
||||
elapsed_time_ms = int((batch.completed_at - batch.requested_at).total_seconds() * 1000)
|
||||
elif batch.status in ["pending", "processing"]:
|
||||
elapsed_time_ms = int((datetime.now(timezone.utc) - batch.requested_at).total_seconds() * 1000)
|
||||
|
||||
return {
|
||||
"batch_id": str(batch.id),
|
||||
"tenant_id": str(batch.tenant_id),
|
||||
"batch_name": batch.batch_name,
|
||||
"status": batch.status,
|
||||
"progress": {
|
||||
"total_products": batch.total_products,
|
||||
"completed_products": batch.completed_products,
|
||||
"failed_products": batch.failed_products,
|
||||
"completion_percentage": round(completion_percentage, 2)
|
||||
},
|
||||
"timing": {
|
||||
"requested_at": batch.requested_at.isoformat(),
|
||||
"completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
|
||||
"elapsed_time_ms": elapsed_time_ms,
|
||||
"processing_time_ms": batch.processing_time_ms
|
||||
},
|
||||
"configuration": {
|
||||
"forecast_days": batch.forecast_days,
|
||||
"business_type": batch.business_type
|
||||
},
|
||||
"error_message": batch.error_message,
|
||||
"cancelled_by": batch.cancelled_by
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get batch details",
|
||||
batch_id=batch_id,
|
||||
error=str(e))
|
||||
return {"error": f"Failed to get batch details: {str(e)}"}
|
||||
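# ----------------------------------------------------------------------
# Minimal usage sketch, assuming `session` is an AsyncSession and the counts
# below are placeholders. It walks a batch through the pending -> processing
# -> completed lifecycle implemented above.
async def _example_batch_lifecycle(session: AsyncSession, tenant_id: str) -> Dict[str, Any]:
    repo = PredictionBatchRepository(session)

    batch = await repo.create_batch({
        "tenant_id": tenant_id,
        "batch_name": "nightly-forecast-run",
        "total_products": 25,
    })

    # Mark the batch as started, then record progress as products are processed
    await repo.update_batch_progress(batch.id, status="processing")
    await repo.update_batch_progress(batch.id, completed_products=25, failed_products=0)

    await repo.complete_batch(batch.id, processing_time_ms=12500)
    return await repo.get_batch_details(batch.id)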
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
Prediction Cache Repository
|
||||
Repository for prediction cache operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
import hashlib
|
||||
|
||||
from .base import ForecastingBaseRepository
|
||||
from app.models.predictions import PredictionCache
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PredictionCacheRepository(ForecastingBaseRepository):
|
||||
"""Repository for prediction cache operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 60):
|
||||
# Cache entries change very frequently, short cache time (1 minute)
|
||||
super().__init__(PredictionCache, session, cache_ttl)
|
||||
|
||||
def _generate_cache_key(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
location: str,
|
||||
forecast_date: datetime
|
||||
) -> str:
|
||||
"""Generate cache key for prediction"""
|
||||
key_data = f"{tenant_id}:{inventory_product_id}:{location}:{forecast_date.isoformat()}"
|
||||
return hashlib.md5(key_data.encode()).hexdigest()
|
||||
|
||||
async def cache_prediction(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
location: str,
|
||||
forecast_date: datetime,
|
||||
predicted_demand: float,
|
||||
confidence_lower: float,
|
||||
confidence_upper: float,
|
||||
model_id: str,
|
||||
expires_in_hours: int = 24
|
||||
) -> PredictionCache:
|
||||
"""Cache a prediction result"""
|
||||
try:
|
||||
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=expires_in_hours)
|
||||
|
||||
cache_data = {
|
||||
"cache_key": cache_key,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"location": location,
|
||||
"forecast_date": forecast_date,
|
||||
"predicted_demand": predicted_demand,
|
||||
"confidence_lower": confidence_lower,
|
||||
"confidence_upper": confidence_upper,
|
||||
"model_id": model_id,
|
||||
"expires_at": expires_at,
|
||||
"hit_count": 0
|
||||
}
|
||||
|
||||
# Try to update existing cache entry first
|
||||
existing_cache = await self.get_by_field("cache_key", cache_key)
|
||||
if existing_cache:
|
||||
cache_entry = await self.update(existing_cache.id, cache_data)
|
||||
logger.debug("Updated cache entry", cache_key=cache_key)
|
||||
else:
|
||||
cache_entry = await self.create(cache_data)
|
||||
logger.debug("Created cache entry", cache_key=cache_key)
|
||||
|
||||
return cache_entry
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cache prediction",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to cache prediction: {str(e)}")
|
||||
|
||||
async def get_cached_prediction(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str,
|
||||
location: str,
|
||||
forecast_date: datetime
|
||||
) -> Optional[PredictionCache]:
|
||||
"""Get cached prediction if valid"""
|
||||
try:
|
||||
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
|
||||
|
||||
cache_entry = await self.get_by_field("cache_key", cache_key)
|
||||
|
||||
if not cache_entry:
|
||||
logger.debug("Cache miss", cache_key=cache_key)
|
||||
return None
|
||||
|
||||
# Check if cache entry has expired
|
||||
if cache_entry.expires_at < datetime.now(timezone.utc):
|
||||
logger.debug("Cache expired", cache_key=cache_key)
|
||||
await self.delete(cache_entry.id)
|
||||
return None
|
||||
|
||||
# Increment hit count
|
||||
await self.update(cache_entry.id, {"hit_count": cache_entry.hit_count + 1})
|
||||
|
||||
logger.debug("Cache hit",
|
||||
cache_key=cache_key,
|
||||
hit_count=cache_entry.hit_count + 1)
|
||||
|
||||
return cache_entry
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get cached prediction",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def invalidate_cache(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: str = None,
|
||||
location: str = None
|
||||
) -> int:
|
||||
"""Invalidate cache entries"""
|
||||
try:
|
||||
conditions = ["tenant_id = :tenant_id"]
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
if inventory_product_id:
|
||||
conditions.append("inventory_product_id = :inventory_product_id")
|
||||
params["inventory_product_id"] = inventory_product_id
|
||||
|
||||
if location:
|
||||
conditions.append("location = :location")
|
||||
params["location"] = location
|
||||
|
||||
query_text = f"""
|
||||
DELETE FROM prediction_cache
|
||||
WHERE {' AND '.join(conditions)}
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
invalidated_count = result.rowcount
|
||||
|
||||
logger.info("Cache invalidated",
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
location=location,
|
||||
invalidated_count=invalidated_count)
|
||||
|
||||
return invalidated_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to invalidate cache",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cache invalidation failed: {str(e)}")
|
||||
|
||||
async def cleanup_expired_cache(self) -> int:
|
||||
"""Clean up expired cache entries"""
|
||||
try:
|
||||
query_text = """
|
||||
DELETE FROM prediction_cache
|
||||
WHERE expires_at < :now
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"now": datetime.now(timezone.utc)})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up expired cache entries",
|
||||
deleted_count=deleted_count)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired cache",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cache cleanup failed: {str(e)}")
|
||||
|
||||
async def get_cache_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
|
||||
"""Get cache performance statistics"""
|
||||
try:
|
||||
base_filter = "WHERE 1=1"
|
||||
params = {}
|
||||
|
||||
if tenant_id:
|
||||
base_filter = "WHERE tenant_id = :tenant_id"
|
||||
params["tenant_id"] = tenant_id
|
||||
|
||||
# Get cache statistics
|
||||
stats_query = text(f"""
|
||||
SELECT
|
||||
COUNT(*) as total_entries,
|
||||
COUNT(CASE WHEN expires_at > :now THEN 1 END) as active_entries,
|
||||
COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_entries,
|
||||
SUM(hit_count) as total_hits,
|
||||
AVG(hit_count) as avg_hits_per_entry,
|
||||
MAX(hit_count) as max_hits,
|
||||
COUNT(DISTINCT inventory_product_id) as unique_products
|
||||
FROM prediction_cache
|
||||
{base_filter}
|
||||
""")
|
||||
|
||||
params["now"] = datetime.now(timezone.utc)
|
||||
|
||||
result = await self.session.execute(stats_query, params)
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
return {
|
||||
"total_entries": int(row.total_entries or 0),
|
||||
"active_entries": int(row.active_entries or 0),
|
||||
"expired_entries": int(row.expired_entries or 0),
|
||||
"total_hits": int(row.total_hits or 0),
|
||||
"avg_hits_per_entry": float(row.avg_hits_per_entry or 0),
|
||||
"max_hits": int(row.max_hits or 0),
|
||||
"unique_products": int(row.unique_products or 0),
|
||||
"cache_hit_ratio": round((row.total_hits / max(row.total_entries, 1)), 2)
|
||||
}
|
||||
|
||||
return {
|
||||
"total_entries": 0,
|
||||
"active_entries": 0,
|
||||
"expired_entries": 0,
|
||||
"total_hits": 0,
|
||||
"avg_hits_per_entry": 0.0,
|
||||
"max_hits": 0,
|
||||
"unique_products": 0,
|
||||
"cache_hit_ratio": 0.0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get cache statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_entries": 0,
|
||||
"active_entries": 0,
|
||||
"expired_entries": 0,
|
||||
"total_hits": 0,
|
||||
"avg_hits_per_entry": 0.0,
|
||||
"max_hits": 0,
|
||||
"unique_products": 0,
|
||||
"cache_hit_ratio": 0.0
|
||||
}
|
||||
|
||||
async def get_most_accessed_predictions(
|
||||
self,
|
||||
tenant_id: str = None,
|
||||
limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get most frequently accessed cached predictions"""
|
||||
try:
|
||||
base_filter = "WHERE hit_count > 0"
|
||||
params = {"limit": limit}
|
||||
|
||||
if tenant_id:
|
||||
base_filter = "WHERE tenant_id = :tenant_id AND hit_count > 0"
|
||||
params["tenant_id"] = tenant_id
|
||||
|
||||
query_text = f"""
|
||||
SELECT
|
||||
inventory_product_id,
|
||||
location,
|
||||
hit_count,
|
||||
predicted_demand,
|
||||
created_at,
|
||||
expires_at
|
||||
FROM prediction_cache
|
||||
{base_filter}
|
||||
ORDER BY hit_count DESC
|
||||
LIMIT :limit
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
|
||||
popular_predictions = []
|
||||
for row in result.fetchall():
|
||||
popular_predictions.append({
|
||||
"inventory_product_id": row.inventory_product_id,
|
||||
"location": row.location,
|
||||
"hit_count": int(row.hit_count),
|
||||
"predicted_demand": float(row.predicted_demand),
|
||||
"created_at": row.created_at.isoformat() if row.created_at else None,
|
||||
"expires_at": row.expires_at.isoformat() if row.expires_at else None
|
||||
})
|
||||
|
||||
return popular_predictions
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get most accessed predictions",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return []
|
||||
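# ----------------------------------------------------------------------
# Minimal usage sketch of the cache-aside flow around the repository above,
# assuming `session` is an AsyncSession and the identifiers/values are
# placeholders for what the real prediction flow would supply.
async def _example_cache_aside_lookup(
    session: AsyncSession,
    tenant_id: str,
    inventory_product_id: str,
    location: str,
    forecast_date: datetime,
) -> Optional[PredictionCache]:
    repo = PredictionCacheRepository(session)

    # 1. Try the cache first
    cached = await repo.get_cached_prediction(tenant_id, inventory_product_id, location, forecast_date)
    if cached:
        return cached

    # 2. On a miss the caller would compute a fresh prediction; the numbers
    #    below are illustrative placeholders for that result.
    return await repo.cache_prediction(
        tenant_id=tenant_id,
        inventory_product_id=inventory_product_id,
        location=location,
        forecast_date=forecast_date,
        predicted_demand=42.0,
        confidence_lower=35.0,
        confidence_upper=49.0,
        model_id="model-placeholder",
        expires_in_hours=24,
    )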
0
services/forecasting/app/schemas/__init__.py
Normal file
302
services/forecasting/app/schemas/forecasts.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/schemas/forecasts.py
|
||||
# ================================================================
|
||||
"""
|
||||
Forecast schemas for request/response validation
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from datetime import datetime, date
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
class BusinessType(str, Enum):
|
||||
INDIVIDUAL = "individual"
|
||||
CENTRAL_WORKSHOP = "central_workshop"
|
||||
|
||||
|
||||
class ForecastRequest(BaseModel):
|
||||
"""Request schema for generating forecasts"""
|
||||
inventory_product_id: str = Field(..., description="Inventory product UUID reference")
|
||||
# product_name: str = Field(..., description="Product name") # DEPRECATED - use inventory_product_id
|
||||
forecast_date: date = Field(..., description="Starting date for forecast")
|
||||
forecast_days: int = Field(1, ge=1, le=30, description="Number of days to forecast")
|
||||
location: str = Field(..., description="Location identifier")
|
||||
|
||||
# Optional parameters - internally handled
|
||||
confidence_level: float = Field(0.8, ge=0.5, le=0.95, description="Confidence level")
|
||||
|
||||
@validator('inventory_product_id')
|
||||
def validate_inventory_product_id(cls, v):
|
||||
"""Validate that inventory_product_id is a valid UUID"""
|
||||
try:
|
||||
UUID(v)
|
||||
except (ValueError, AttributeError):
|
||||
raise ValueError(f"inventory_product_id must be a valid UUID, got: {v}")
|
||||
return v
|
||||
|
||||
@validator('forecast_date')
|
||||
def validate_forecast_date(cls, v):
|
||||
if v < date.today():
|
||||
raise ValueError("Forecast date cannot be in the past")
|
||||
return v
|
||||
|
||||
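# ----------------------------------------------------------------------
# Illustrative example, assuming the UUID and location below are
# placeholders: constructs a valid single-product ForecastRequest. The
# validators above reject non-UUID product IDs and past forecast dates.
def _example_forecast_request() -> ForecastRequest:
    """Return a sample single-product forecast request (placeholder IDs)."""
    return ForecastRequest(
        inventory_product_id="2f6c1f9e-8a4b-4f63-9c2d-1a2b3c4d5e6f",
        forecast_date=date.today(),
        forecast_days=7,
        location="main-store",
    )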
class BatchForecastRequest(BaseModel):
|
||||
"""Request schema for batch forecasting"""
|
||||
tenant_id: Optional[str] = None # Optional, can be from path parameter
|
||||
batch_name: str = Field(..., description="Batch name for tracking")
|
||||
inventory_product_ids: List[str] = Field(..., description="List of inventory product IDs")
|
||||
forecast_days: int = Field(7, ge=1, le=30, description="Number of days to forecast")
|
||||
|
||||
@validator('tenant_id')
|
||||
def validate_tenant_id(cls, v):
|
||||
"""Validate that tenant_id is a valid UUID if provided"""
|
||||
if v is not None:
|
||||
try:
|
||||
UUID(v)
|
||||
except (ValueError, AttributeError):
|
||||
raise ValueError(f"tenant_id must be a valid UUID, got: {v}")
|
||||
return v
|
||||
|
||||
@validator('inventory_product_ids')
|
||||
def validate_inventory_product_ids(cls, v):
|
||||
"""Validate that all inventory_product_ids are valid UUIDs"""
|
||||
for product_id in v:
|
||||
try:
|
||||
UUID(product_id)
|
||||
except (ValueError, AttributeError):
|
||||
raise ValueError(f"All inventory_product_ids must be valid UUIDs, got invalid: {product_id}")
|
||||
return v
|
||||
|
||||
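# ----------------------------------------------------------------------
# Illustrative example, assuming the product UUIDs below are placeholders:
# a batch request covering two products. tenant_id may be omitted when it
# is taken from the URL path.
def _example_batch_forecast_request() -> BatchForecastRequest:
    """Return a sample batch forecast request (placeholder IDs)."""
    return BatchForecastRequest(
        batch_name="weekly-batch",
        inventory_product_ids=[
            "2f6c1f9e-8a4b-4f63-9c2d-1a2b3c4d5e6f",
            "7a1d2c3b-4e5f-6a7b-8c9d-0e1f2a3b4c5d",
        ],
        forecast_days=7,
    )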
class ForecastResponse(BaseModel):
|
||||
"""Response schema for forecast results"""
|
||||
id: str
|
||||
tenant_id: str
|
||||
inventory_product_id: str # Reference to inventory service
|
||||
# product_name: str # Can be fetched from inventory service if needed for display
|
||||
location: str
|
||||
forecast_date: datetime
|
||||
|
||||
# Predictions
|
||||
predicted_demand: float
|
||||
confidence_lower: float
|
||||
confidence_upper: float
|
||||
confidence_level: float
|
||||
|
||||
# Model info
|
||||
model_id: str
|
||||
model_version: str
|
||||
algorithm: str
|
||||
|
||||
# Context
|
||||
business_type: str
|
||||
is_holiday: bool
|
||||
is_weekend: bool
|
||||
day_of_week: int
|
||||
|
||||
# External factors
|
||||
weather_temperature: Optional[float]
|
||||
weather_precipitation: Optional[float]
|
||||
weather_description: Optional[str]
|
||||
traffic_volume: Optional[int]
|
||||
|
||||
# Metadata
|
||||
created_at: datetime
|
||||
processing_time_ms: Optional[int]
|
||||
features_used: Optional[Dict[str, Any]]
|
||||
|
||||
class BatchForecastResponse(BaseModel):
|
||||
"""Response schema for batch forecast requests"""
|
||||
id: str
|
||||
tenant_id: str
|
||||
batch_name: str
|
||||
status: str
|
||||
total_products: int
|
||||
completed_products: int
|
||||
failed_products: int
|
||||
|
||||
# Timing
|
||||
requested_at: datetime
|
||||
completed_at: Optional[datetime]
|
||||
processing_time_ms: Optional[int]
|
||||
|
||||
# Results
|
||||
forecasts: Optional[List[ForecastResponse]]
|
||||
error_message: Optional[str]
|
||||
|
||||
class MultiDayForecastResponse(BaseModel):
|
||||
"""Response schema for multi-day forecast results"""
|
||||
tenant_id: str = Field(..., description="Tenant ID")
|
||||
inventory_product_id: str = Field(..., description="Inventory product ID")
|
||||
forecast_start_date: date = Field(..., description="Start date of forecast period")
|
||||
forecast_days: int = Field(..., description="Number of forecasted days")
|
||||
forecasts: List[ForecastResponse] = Field(..., description="Daily forecasts")
|
||||
total_predicted_demand: float = Field(..., description="Total demand across all days")
|
||||
average_confidence_level: float = Field(..., description="Average confidence across all days")
|
||||
processing_time_ms: int = Field(..., description="Total processing time")
|
||||
|
||||
|
||||
# ================================================================
|
||||
# SCENARIO SIMULATION SCHEMAS - PROFESSIONAL/ENTERPRISE ONLY
|
||||
# ================================================================
|
||||
|
||||
class ScenarioType(str, Enum):
|
||||
"""Types of scenarios available for simulation"""
|
||||
WEATHER = "weather" # Weather impact (heatwave, cold snap, rain, etc.)
|
||||
COMPETITION = "competition" # New competitor opening nearby
|
||||
EVENT = "event" # Local event (festival, sports, concert, etc.)
|
||||
PRICING = "pricing" # Price changes
|
||||
PROMOTION = "promotion" # Promotional campaigns
|
||||
HOLIDAY = "holiday" # Holiday periods
|
||||
SUPPLY_DISRUPTION = "supply_disruption" # Supply chain issues
|
||||
CUSTOM = "custom" # Custom user-defined scenario
|
||||
|
||||
|
||||
class WeatherScenario(BaseModel):
|
||||
"""Weather scenario parameters"""
|
||||
temperature_change: Optional[float] = Field(None, ge=-30, le=30, description="Temperature change in °C")
|
||||
precipitation_change: Optional[float] = Field(None, ge=0, le=100, description="Precipitation change in mm")
|
||||
weather_type: Optional[str] = Field(None, description="Weather type (heatwave, cold_snap, rainy, etc.)")
|
||||
|
||||
|
||||
class CompetitionScenario(BaseModel):
|
||||
"""Competition scenario parameters"""
|
||||
new_competitors: int = Field(1, ge=1, le=10, description="Number of new competitors")
|
||||
distance_km: float = Field(0.5, ge=0.1, le=10, description="Distance from location in km")
|
||||
estimated_market_share_loss: float = Field(0.1, ge=0, le=0.5, description="Estimated market share loss (0-50%)")
|
||||
|
||||
|
||||
class EventScenario(BaseModel):
|
||||
"""Event scenario parameters"""
|
||||
event_type: str = Field(..., description="Type of event (festival, sports, concert, etc.)")
|
||||
expected_attendance: int = Field(..., ge=0, description="Expected attendance")
|
||||
distance_km: float = Field(0.5, ge=0, le=50, description="Distance from location in km")
|
||||
duration_days: int = Field(1, ge=1, le=30, description="Duration in days")
|
||||
|
||||
|
||||
class PricingScenario(BaseModel):
|
||||
"""Pricing scenario parameters"""
|
||||
price_change_percent: float = Field(..., ge=-50, le=100, description="Price change percentage")
|
||||
affected_products: Optional[List[str]] = Field(None, description="List of affected product IDs")
|
||||
|
||||
|
||||
class PromotionScenario(BaseModel):
|
||||
"""Promotion scenario parameters"""
|
||||
discount_percent: float = Field(..., ge=0, le=75, description="Discount percentage")
|
||||
promotion_type: str = Field(..., description="Type of promotion (bogo, discount, bundle, etc.)")
|
||||
expected_traffic_increase: float = Field(0.2, ge=0, le=2, description="Expected traffic increase (0-200%)")
|
||||
|
||||
|
||||
class ScenarioSimulationRequest(BaseModel):
|
||||
"""Request schema for scenario simulation - PROFESSIONAL/ENTERPRISE ONLY"""
|
||||
scenario_name: str = Field(..., min_length=3, max_length=200, description="Name for this scenario")
|
||||
scenario_type: ScenarioType = Field(..., description="Type of scenario to simulate")
|
||||
inventory_product_ids: List[str] = Field(..., min_items=1, description="Products to simulate")
|
||||
start_date: date = Field(..., description="Simulation start date")
|
||||
duration_days: int = Field(7, ge=1, le=30, description="Simulation duration in days")
|
||||
|
||||
# Scenario-specific parameters (one should be provided based on scenario_type)
|
||||
weather_params: Optional[WeatherScenario] = None
|
||||
competition_params: Optional[CompetitionScenario] = None
|
||||
event_params: Optional[EventScenario] = None
|
||||
pricing_params: Optional[PricingScenario] = None
|
||||
promotion_params: Optional[PromotionScenario] = None
|
||||
|
||||
# Custom scenario parameters
|
||||
custom_multipliers: Optional[Dict[str, float]] = Field(
|
||||
None,
|
||||
description="Custom multipliers for baseline forecast (e.g., {'demand': 1.2, 'traffic': 0.8})"
|
||||
)
|
||||
|
||||
# Comparison settings
|
||||
include_baseline: bool = Field(True, description="Include baseline forecast for comparison")
|
||||
|
||||
@validator('start_date')
|
||||
def validate_start_date(cls, v):
|
||||
if v < date.today():
|
||||
raise ValueError("Simulation start date cannot be in the past")
|
||||
return v
|
||||
|
||||
|
||||
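# Illustrative request sketch (values and product IDs below are placeholders, not part of the service code):
#
# request = ScenarioSimulationRequest(
#     scenario_name="Summer Heatwave Impact",
#     scenario_type=ScenarioType.WEATHER,
#     inventory_product_ids=["prod-001", "prod-002"],
#     start_date=date.today(),
#     duration_days=7,
#     weather_params=WeatherScenario(temperature_change=8.0, weather_type="heatwave"),
# )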
class ScenarioImpact(BaseModel):
|
||||
"""Impact of scenario on a specific product"""
|
||||
inventory_product_id: str
|
||||
baseline_demand: float
|
||||
simulated_demand: float
|
||||
demand_change_percent: float
|
||||
confidence_range: tuple[float, float]
|
||||
impact_factors: Dict[str, Any] # Breakdown of what drove the change
|
||||
|
||||
|
||||
class ScenarioSimulationResponse(BaseModel):
|
||||
"""Response schema for scenario simulation"""
|
||||
id: str = Field(..., description="Simulation ID")
|
||||
tenant_id: str
|
||||
scenario_name: str
|
||||
scenario_type: ScenarioType
|
||||
|
||||
# Simulation parameters
|
||||
start_date: date
|
||||
end_date: date
|
||||
duration_days: int
|
||||
|
||||
# Results
|
||||
baseline_forecasts: Optional[List[ForecastResponse]] = Field(
|
||||
None,
|
||||
description="Baseline forecasts (if requested)"
|
||||
)
|
||||
scenario_forecasts: List[ForecastResponse] = Field(..., description="Forecasts with scenario applied")
|
||||
|
||||
# Impact summary
|
||||
total_baseline_demand: float
|
||||
total_scenario_demand: float
|
||||
overall_impact_percent: float
|
||||
product_impacts: List[ScenarioImpact]
|
||||
|
||||
# Insights and recommendations
|
||||
insights: List[str] = Field(..., description="AI-generated insights about the scenario")
|
||||
recommendations: List[str] = Field(..., description="Actionable recommendations")
|
||||
risk_level: str = Field(..., description="Risk level: low, medium, high")
|
||||
|
||||
# Metadata
|
||||
created_at: datetime
|
||||
processing_time_ms: int
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"id": "scenario_123",
|
||||
"tenant_id": "tenant_456",
|
||||
"scenario_name": "Summer Heatwave Impact",
|
||||
"scenario_type": "weather",
|
||||
"overall_impact_percent": 15.5,
|
||||
"insights": [
|
||||
"Cold beverages expected to increase by 45%",
|
||||
"Bread products may decrease by 8% due to reduced appetite",
|
||||
"Ice cream demand projected to surge by 120%"
|
||||
],
|
||||
"recommendations": [
|
||||
"Increase cold beverage inventory by 40%",
|
||||
"Reduce bread production by 10%",
|
||||
"Stock additional ice cream varieties"
|
||||
],
|
||||
"risk_level": "medium"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ScenarioComparisonRequest(BaseModel):
|
||||
"""Request to compare multiple scenarios"""
|
||||
scenario_ids: List[str] = Field(..., min_items=2, max_items=5, description="Scenario IDs to compare")
|
||||
|
||||
|
||||
class ScenarioComparisonResponse(BaseModel):
|
||||
"""Response comparing multiple scenarios"""
|
||||
scenarios: List[ScenarioSimulationResponse]
|
||||
comparison_matrix: Dict[str, Dict[str, Any]]
|
||||
best_case_scenario_id: str
|
||||
worst_case_scenario_id: str
|
||||
recommended_action: str
|
||||
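# Illustrative comparison_matrix shape (an assumption; the schema does not constrain its contents):
#   {
#       "scenario_123": {"overall_impact_percent": 15.5, "risk_level": "medium"},
#       "scenario_456": {"overall_impact_percent": -4.2, "risk_level": "low"}
#   }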
|
||||
|
||||
17
services/forecasting/app/services/__init__.py
Normal file
17
services/forecasting/app/services/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Forecasting Service Layer
|
||||
Business logic services for demand forecasting and prediction
|
||||
"""
|
||||
|
||||
from .forecasting_service import ForecastingService, EnhancedForecastingService
|
||||
from .prediction_service import PredictionService
|
||||
from .model_client import ModelClient
|
||||
from .data_client import DataClient
|
||||
|
||||
__all__ = [
|
||||
"ForecastingService",
|
||||
"EnhancedForecastingService",
|
||||
"PredictionService",
|
||||
"ModelClient",
|
||||
"DataClient"
|
||||
]
|
||||
132
services/forecasting/app/services/data_client.py
Normal file
132
services/forecasting/app/services/data_client.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# services/forecasting/app/services/data_client.py
"""
Forecasting Service Data Client
Uses the shared service clients for sales and external data access
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
# Import the shared clients
|
||||
from shared.clients import get_sales_client, get_external_client, get_service_clients
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class DataClient:
"""
Data client for the forecasting service
Uses the shared sales and external data clients under the hood
"""
|
||||
|
||||
def __init__(self):
|
||||
# Get the new specialized clients
|
||||
self.sales_client = get_sales_client(settings, "forecasting")
|
||||
self.external_client = get_external_client(settings, "forecasting")
|
||||
|
||||
# Or alternatively, get all clients at once:
|
||||
# self.clients = get_service_clients(settings, "forecasting")
|
||||
# Then use: self.clients.sales.get_sales_data(...) and self.clients.external.get_weather_forecast(...)
|
||||
|
||||
|
||||
async def fetch_weather_forecast(
|
||||
self,
|
||||
tenant_id: str,
|
||||
days: int = 7,
|
||||
latitude: Optional[float] = None,
|
||||
longitude: Optional[float] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch weather forecast data
|
||||
Uses new v2.0 optimized endpoint via shared external client
|
||||
"""
|
||||
try:
|
||||
weather_data = await self.external_client.get_weather_forecast(
|
||||
tenant_id=tenant_id,
|
||||
days=days,
|
||||
latitude=latitude,
|
||||
longitude=longitude
|
||||
)
|
||||
|
||||
if weather_data:
|
||||
logger.info(f"Fetched {len(weather_data)} weather records",
|
||||
tenant_id=tenant_id)
|
||||
return weather_data
|
||||
else:
|
||||
logger.warning("No weather data returned", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching weather data: {e}", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
async def fetch_tenant_calendar(
|
||||
self,
|
||||
tenant_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch tenant's assigned school calendar
|
||||
Returns None if no calendar assigned
|
||||
"""
|
||||
try:
|
||||
location_context = await self.external_client.get_tenant_location_context(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if location_context and location_context.get("calendar"):
|
||||
logger.info(
|
||||
"Fetched calendar for tenant",
|
||||
tenant_id=tenant_id,
|
||||
calendar_name=location_context["calendar"].get("calendar_name")
|
||||
)
|
||||
return location_context["calendar"]
|
||||
else:
|
||||
logger.info("No calendar assigned to tenant", tenant_id=tenant_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching calendar: {e}", tenant_id=tenant_id)
|
||||
return None
|
||||
|
||||
async def check_school_holiday(
|
||||
self,
|
||||
calendar_id: str,
|
||||
check_date: str,
|
||||
tenant_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a date is a school holiday
|
||||
|
||||
Args:
|
||||
calendar_id: School calendar UUID
|
||||
check_date: Date in ISO format (YYYY-MM-DD)
|
||||
tenant_id: Tenant ID for auth
|
||||
|
||||
Returns:
|
||||
True if school holiday, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await self.external_client.check_is_school_holiday(
|
||||
calendar_id=calendar_id,
|
||||
check_date=check_date,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if result:
|
||||
is_holiday = result.get("is_holiday", False)
|
||||
if is_holiday:
|
||||
logger.debug(
|
||||
"School holiday detected",
|
||||
date=check_date,
|
||||
holiday_name=result.get("holiday_name")
|
||||
)
|
||||
return is_holiday
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking school holiday: {e}", date=check_date)
|
||||
return False
|
||||
|
||||
# Global instance - same as before, but much simpler implementation
|
||||
data_client = DataClient()
|
||||
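# Usage sketch (illustrative; the coroutines are defined above, the calendar field name is an assumption):
#   weather = await data_client.fetch_weather_forecast(tenant_id="tenant-123", days=7)
#   calendar = await data_client.fetch_tenant_calendar(tenant_id="tenant-123")
#   if calendar:
#       on_holiday = await data_client.check_school_holiday(calendar.get("id"), "2025-06-01", "tenant-123")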
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
Enterprise forecasting service for aggregated demand across parent-child tenants
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import date, datetime
|
||||
import json
|
||||
import redis.asyncio as redis
|
||||
|
||||
from shared.clients.forecast_client import ForecastServiceClient
|
||||
from shared.clients.tenant_client import TenantServiceClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnterpriseForecastingService:
|
||||
"""
|
||||
Service for aggregating forecasts across parent and child tenants
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
forecast_client: ForecastServiceClient,
|
||||
tenant_client: TenantServiceClient,
|
||||
redis_client: redis.Redis
|
||||
):
|
||||
self.forecast_client = forecast_client
|
||||
self.tenant_client = tenant_client
|
||||
self.redis_client = redis_client
|
||||
self.cache_ttl_seconds = 3600 # 1 hour TTL
|
||||
|
||||
async def get_aggregated_forecast(
|
||||
self,
|
||||
parent_tenant_id: str,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
product_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get aggregated forecast across parent and all child tenants
|
||||
|
||||
Args:
|
||||
parent_tenant_id: Parent tenant ID
|
||||
start_date: Start date for forecast aggregation
|
||||
end_date: End date for forecast aggregation
|
||||
product_id: Optional product ID to filter by
|
||||
|
||||
Returns:
|
||||
Dict with aggregated forecast data by date and product
|
||||
"""
|
||||
# Create cache key
|
||||
cache_key = f"agg_forecast:{parent_tenant_id}:{start_date}:{end_date}:{product_id or 'all'}"
|
||||
|
||||
# Try to get from cache first
|
||||
try:
|
||||
cached_result = await self.redis_client.get(cache_key)
|
||||
if cached_result:
|
||||
logger.info(f"Cache hit for aggregated forecast: {cache_key}")
|
||||
return json.loads(cached_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache read failed: {e}")
|
||||
|
||||
logger.info(f"Computing aggregated forecast for parent {parent_tenant_id} from {start_date} to {end_date}")
|
||||
|
||||
# Get child tenant IDs
|
||||
child_tenants = await self.tenant_client.get_child_tenants(parent_tenant_id)
|
||||
child_tenant_ids = [child['id'] for child in child_tenants]
|
||||
|
||||
# Include parent tenant in the list for complete aggregation
|
||||
all_tenant_ids = [parent_tenant_id] + child_tenant_ids
|
||||
|
||||
# Fetch forecasts for all tenants (parent + children)
|
||||
all_forecasts = {}
|
||||
tenant_contributions = {} # Track which tenant contributed to each forecast
|
||||
|
||||
for tenant_id in all_tenant_ids:
|
||||
try:
|
||||
tenant_forecasts = await self.forecast_client.get_forecasts(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_id=product_id
|
||||
)
|
||||
|
||||
for forecast_date_str, products in tenant_forecasts.items():
|
||||
if forecast_date_str not in all_forecasts:
|
||||
all_forecasts[forecast_date_str] = {}
|
||||
tenant_contributions[forecast_date_str] = {}
|
||||
|
||||
for product_id_key, forecast_data in products.items():
|
||||
if product_id_key not in all_forecasts[forecast_date_str]:
|
||||
all_forecasts[forecast_date_str][product_id_key] = {
|
||||
'predicted_demand': 0,
|
||||
'confidence_lower': 0,
|
||||
'confidence_upper': 0,
|
||||
'tenant_contributions': []
|
||||
}
|
||||
|
||||
# Aggregate the forecast values
|
||||
all_forecasts[forecast_date_str][product_id_key]['predicted_demand'] += forecast_data.get('predicted_demand', 0)
|
||||
|
||||
# For confidence intervals we simply sum the per-tenant bounds here.
# A statistically sounder combination is sketched in _combine_confidence_intervals below.
|
||||
all_forecasts[forecast_date_str][product_id_key]['confidence_lower'] += forecast_data.get('confidence_lower', 0)
|
||||
all_forecasts[forecast_date_str][product_id_key]['confidence_upper'] += forecast_data.get('confidence_upper', 0)
|
||||
|
||||
# Track contribution by tenant
|
||||
all_forecasts[forecast_date_str][product_id_key]['tenant_contributions'].append({
|
||||
'tenant_id': tenant_id,
|
||||
'demand': forecast_data.get('predicted_demand', 0),
|
||||
'confidence_lower': forecast_data.get('confidence_lower', 0),
|
||||
'confidence_upper': forecast_data.get('confidence_upper', 0)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch forecasts for tenant {tenant_id}: {e}")
|
||||
# Continue with other tenants even if one fails
|
||||
|
||||
# Prepare result
|
||||
result = {
|
||||
"parent_tenant_id": parent_tenant_id,
|
||||
"aggregated_forecasts": all_forecasts,
|
||||
"tenant_contributions": tenant_contributions,
|
||||
"child_tenant_count": len(child_tenant_ids),
|
||||
"forecast_dates": list(all_forecasts.keys()),
|
||||
"computed_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Cache the result
|
||||
try:
|
||||
await self.redis_client.setex(
|
||||
cache_key,
|
||||
self.cache_ttl_seconds,
|
||||
json.dumps(result, default=str) # Handle date serialization
|
||||
)
|
||||
logger.info(f"Forecast cached for {cache_key}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache write failed: {e}")
|
||||
|
||||
return result
|
||||
|
||||
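# Hedged sketch (not part of the original implementation): if per-tenant forecast errors
# are treated as independent and roughly symmetric, interval bounds can be combined by
# summing the midpoints and adding half-widths in quadrature rather than summing bounds.
@staticmethod
def _combine_confidence_intervals(intervals: List[tuple]) -> tuple:
    """Combine (lower, upper) intervals assuming independent, symmetric errors."""
    import math
    midpoints = [(lo + hi) / 2 for lo, hi in intervals]
    half_widths = [(hi - lo) / 2 for lo, hi in intervals]
    combined_mid = sum(midpoints)
    combined_hw = math.sqrt(sum(w * w for w in half_widths))
    return (combined_mid - combined_hw, combined_mid + combined_hw)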
async def get_network_performance_metrics(
|
||||
self,
|
||||
parent_tenant_id: str,
|
||||
start_date: date,
|
||||
end_date: date
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get aggregated performance metrics across the tenant network
|
||||
|
||||
Args:
|
||||
parent_tenant_id: Parent tenant ID
|
||||
start_date: Start date for metrics
|
||||
end_date: End date for metrics
|
||||
|
||||
Returns:
|
||||
Dict with aggregated performance metrics
|
||||
"""
|
||||
child_tenants = await self.tenant_client.get_child_tenants(parent_tenant_id)
|
||||
child_tenant_ids = [child['id'] for child in child_tenants]
|
||||
|
||||
# Include parent tenant in the list for complete aggregation
|
||||
all_tenant_ids = [parent_tenant_id] + child_tenant_ids
|
||||
|
||||
total_sales = 0
|
||||
total_forecasted = 0
|
||||
total_accuracy = 0
|
||||
tenant_count = 0
|
||||
|
||||
performance_data = {}
|
||||
|
||||
for tenant_id in all_tenant_ids:
|
||||
try:
|
||||
# Fetch sales and forecast data for the period
|
||||
sales_data = await self._fetch_sales_data(tenant_id, start_date, end_date)
|
||||
forecast_data = await self.get_aggregated_forecast(tenant_id, start_date, end_date)
|
||||
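# NOTE: get_aggregated_forecast itself rolls up a tenant's children, so the parent entry
# may double-count child demand in total_forecasted (observation only; behaviour unchanged).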
|
||||
tenant_performance = {
|
||||
'tenant_id': tenant_id,
|
||||
'sales': sales_data.get('total_sales', 0),
|
||||
'forecasted': sum(
|
||||
sum(product.get('predicted_demand', 0) for product in day.values())
|
||||
if isinstance(day, dict) else day
|
||||
for day in forecast_data.get('aggregated_forecasts', {}).values()
|
||||
),
|
||||
}
|
||||
|
||||
# Calculate accuracy if both sales and forecast data exist
|
||||
if tenant_performance['sales'] > 0 and tenant_performance['forecasted'] > 0:
|
||||
accuracy = 1 - abs(tenant_performance['forecasted'] - tenant_performance['sales']) / tenant_performance['sales']
|
||||
tenant_performance['accuracy'] = max(0, min(1, accuracy)) # Clamp between 0 and 1
|
||||
else:
|
||||
tenant_performance['accuracy'] = 0
|
||||
|
||||
performance_data[tenant_id] = tenant_performance
|
||||
total_sales += tenant_performance['sales']
|
||||
total_forecasted += tenant_performance['forecasted']
|
||||
total_accuracy += tenant_performance['accuracy']
|
||||
tenant_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch performance data for tenant {tenant_id}: {e}")
|
||||
|
||||
network_performance = {
|
||||
"parent_tenant_id": parent_tenant_id,
|
||||
"total_sales": total_sales,
|
||||
"total_forecasted": total_forecasted,
|
||||
"average_accuracy": total_accuracy / tenant_count if tenant_count > 0 else 0,
|
||||
"tenant_count": tenant_count,
|
||||
"child_tenant_count": len(child_tenant_ids),
|
||||
"tenant_performances": performance_data,
|
||||
"computed_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
return network_performance
|
||||
|
||||
async def _fetch_sales_data(self, tenant_id: str, start_date: date, end_date: date) -> Dict[str, Any]:
|
||||
"""
|
||||
Helper method to fetch sales data from the sales service
|
||||
"""
|
||||
try:
|
||||
from shared.clients.sales_client import SalesServiceClient
|
||||
from shared.config.base import get_settings
|
||||
|
||||
# Create sales client
|
||||
config = get_settings()
|
||||
sales_client = SalesServiceClient(config, calling_service_name="forecasting")
|
||||
|
||||
# Fetch sales data for the date range
|
||||
sales_data = await sales_client.get_sales_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date.isoformat(),
|
||||
end_date=end_date.isoformat(),
|
||||
aggregation="daily"
|
||||
)
|
||||
|
||||
# Calculate total sales from the retrieved data
|
||||
total_sales = 0
|
||||
if sales_data:
|
||||
for sale in sales_data:
|
||||
# Sum up quantity_sold or total_amount depending on what's available
|
||||
total_sales += sale.get('quantity_sold', 0)
|
||||
|
||||
return {
|
||||
'total_sales': total_sales,
|
||||
'date_range': f"{start_date} to {end_date}",
|
||||
'tenant_id': tenant_id,
|
||||
'record_count': len(sales_data) if sales_data else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch sales data for tenant {tenant_id}: {e}")
|
||||
# Return empty result on error
|
||||
return {
|
||||
'total_sales': 0,
|
||||
'date_range': f"{start_date} to {end_date}",
|
||||
'tenant_id': tenant_id,
|
||||
'error': str(e)
|
||||
}
|
||||
495
services/forecasting/app/services/forecast_cache.py
Normal file
495
services/forecasting/app/services/forecast_cache.py
Normal file
@@ -0,0 +1,495 @@
|
||||
# services/forecasting/app/services/forecast_cache.py
|
||||
"""
|
||||
Forecast Cache Service - Redis-based caching for forecast results
|
||||
|
||||
Provides service-level caching for forecast predictions to eliminate redundant
|
||||
computations when multiple services (Orders, Production) request the same
|
||||
forecast data within a short time window.
|
||||
|
||||
Cache Strategy:
|
||||
- Key: forecast:{tenant_id}:{product_id}:{forecast_date}
|
||||
- TTL: Until midnight of day after forecast_date
|
||||
- Invalidation: On model retraining for specific products
|
||||
- Metadata: Includes 'cached' flag for observability
|
||||
"""
|
||||
|
||||
import hashlib
import json
|
||||
from datetime import datetime, date, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
from shared.redis_utils import get_redis_client
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ForecastCacheService:
|
||||
"""Service-level caching for forecast predictions"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize forecast cache service"""
|
||||
pass
|
||||
|
||||
async def _get_redis(self):
|
||||
"""Get shared Redis client"""
|
||||
return await get_redis_client()
|
||||
|
||||
async def is_available(self) -> bool:
|
||||
"""Check if Redis cache is available"""
|
||||
try:
|
||||
client = await self._get_redis()
|
||||
await client.ping()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# ================================================================
|
||||
# FORECAST CACHING
|
||||
# ================================================================
|
||||
|
||||
def _get_forecast_key(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_id: UUID,
|
||||
forecast_date: date
|
||||
) -> str:
|
||||
"""Generate cache key for forecast"""
|
||||
return f"forecast:{tenant_id}:{product_id}:{forecast_date.isoformat()}"
|
||||
|
||||
def _get_batch_forecast_key(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_ids: List[UUID],
|
||||
forecast_date: date
|
||||
) -> str:
|
||||
"""Generate cache key for batch forecast"""
|
||||
# Sort product IDs for consistent key generation
|
||||
sorted_ids = sorted(str(pid) for pid in product_ids)
|
||||
products_hash = hash(tuple(sorted_ids))
|
||||
return f"forecast:batch:{tenant_id}:{products_hash}:{forecast_date.isoformat()}"
|
||||
def _calculate_ttl(self, forecast_date: date) -> int:
|
||||
"""
|
||||
Calculate TTL for forecast cache entry
|
||||
|
||||
Forecasts expire at midnight of the day after forecast_date.
|
||||
This ensures forecasts remain cached throughout the forecasted day
|
||||
but don't become stale.
|
||||
|
||||
Args:
|
||||
forecast_date: Date of the forecast
|
||||
|
||||
Returns:
|
||||
TTL in seconds
|
||||
"""
|
||||
# Expire at midnight after forecast_date
|
||||
expiry_datetime = datetime.combine(
|
||||
forecast_date + timedelta(days=1),
|
||||
datetime.min.time()
|
||||
)
|
||||
|
||||
now = datetime.now()
|
||||
ttl_seconds = int((expiry_datetime - now).total_seconds())
|
||||
|
||||
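# Example: caching at 10:00 on the forecast date itself leaves 14 hours until the next midnight (TTL = 50,400 seconds).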
# Minimum TTL of 1 hour, maximum of 48 hours
|
||||
return max(3600, min(ttl_seconds, 172800))
|
||||
|
||||
async def get_cached_forecast(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_id: UUID,
|
||||
forecast_date: date
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve cached forecast if available
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
product_id: Product identifier
|
||||
forecast_date: Date of forecast
|
||||
|
||||
Returns:
|
||||
Cached forecast data or None if not found
|
||||
"""
|
||||
if not await self.is_available():
|
||||
return None
|
||||
|
||||
try:
|
||||
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
|
||||
client = await self._get_redis()
|
||||
cached_data = await client.get(key)
|
||||
|
||||
if cached_data:
|
||||
forecast_data = json.loads(cached_data)
|
||||
# Add cache hit metadata
|
||||
forecast_data['cached'] = True
|
||||
forecast_data['cache_hit_at'] = datetime.now().isoformat()
|
||||
|
||||
logger.info("Forecast cache HIT",
|
||||
tenant_id=str(tenant_id),
|
||||
product_id=str(product_id),
|
||||
forecast_date=str(forecast_date))
|
||||
return forecast_data
|
||||
|
||||
logger.debug("Forecast cache MISS",
|
||||
tenant_id=str(tenant_id),
|
||||
product_id=str(product_id),
|
||||
forecast_date=str(forecast_date))
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error retrieving cached forecast",
|
||||
error=str(e),
|
||||
tenant_id=str(tenant_id))
|
||||
return None
|
||||
|
||||
async def cache_forecast(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_id: UUID,
|
||||
forecast_date: date,
|
||||
forecast_data: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
Cache forecast prediction result
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
product_id: Product identifier
|
||||
forecast_date: Date of forecast
|
||||
forecast_data: Forecast prediction data to cache
|
||||
|
||||
Returns:
|
||||
True if cached successfully, False otherwise
|
||||
"""
|
||||
if not await self.is_available():
|
||||
logger.warning("Redis not available, skipping forecast cache")
|
||||
return False
|
||||
|
||||
try:
|
||||
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
|
||||
ttl = self._calculate_ttl(forecast_date)
|
||||
|
||||
# Add caching metadata
|
||||
cache_entry = {
|
||||
**forecast_data,
|
||||
'cached_at': datetime.now().isoformat(),
|
||||
'cache_key': key,
|
||||
'ttl_seconds': ttl
|
||||
}
|
||||
|
||||
# Serialize and cache
|
||||
client = await self._get_redis()
|
||||
await client.setex(
|
||||
key,
|
||||
ttl,
|
||||
json.dumps(cache_entry, default=str)
|
||||
)
|
||||
|
||||
logger.info("Forecast cached successfully",
|
||||
tenant_id=str(tenant_id),
|
||||
product_id=str(product_id),
|
||||
forecast_date=str(forecast_date),
|
||||
ttl_hours=round(ttl / 3600, 2))
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error caching forecast",
|
||||
error=str(e),
|
||||
tenant_id=str(tenant_id))
|
||||
return False
|
||||
|
||||
async def get_cached_batch_forecast(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_ids: List[UUID],
|
||||
forecast_date: date
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve cached batch forecast
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
product_ids: List of product identifiers
|
||||
forecast_date: Date of forecast
|
||||
|
||||
Returns:
|
||||
Cached batch forecast data or None
|
||||
"""
|
||||
if not await self.is_available():
|
||||
return None
|
||||
|
||||
try:
|
||||
key = self._get_batch_forecast_key(tenant_id, product_ids, forecast_date)
|
||||
client = await self._get_redis()
|
||||
cached_data = await client.get(key)
|
||||
|
||||
if cached_data:
|
||||
forecast_data = json.loads(cached_data)
|
||||
forecast_data['cached'] = True
|
||||
forecast_data['cache_hit_at'] = datetime.now().isoformat()
|
||||
|
||||
logger.info("Batch forecast cache HIT",
|
||||
tenant_id=str(tenant_id),
|
||||
products_count=len(product_ids),
|
||||
forecast_date=str(forecast_date))
|
||||
return forecast_data
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error retrieving cached batch forecast", error=str(e))
|
||||
return None
|
||||
|
||||
async def cache_batch_forecast(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_ids: List[UUID],
|
||||
forecast_date: date,
|
||||
forecast_data: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""Cache batch forecast result"""
|
||||
if not await self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
key = self._get_batch_forecast_key(tenant_id, product_ids, forecast_date)
|
||||
ttl = self._calculate_ttl(forecast_date)
|
||||
|
||||
cache_entry = {
|
||||
**forecast_data,
|
||||
'cached_at': datetime.now().isoformat(),
|
||||
'cache_key': key,
|
||||
'ttl_seconds': ttl
|
||||
}
|
||||
|
||||
client = await self._get_redis()
|
||||
await client.setex(key, ttl, json.dumps(cache_entry, default=str))
|
||||
|
||||
logger.info("Batch forecast cached successfully",
|
||||
tenant_id=str(tenant_id),
|
||||
products_count=len(product_ids),
|
||||
ttl_hours=round(ttl / 3600, 2))
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error caching batch forecast", error=str(e))
|
||||
return False
|
||||
|
||||
# ================================================================
|
||||
# CACHE INVALIDATION
|
||||
# ================================================================
|
||||
|
||||
async def invalidate_product_forecasts(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_id: UUID
|
||||
) -> int:
|
||||
"""
|
||||
Invalidate all forecast cache entries for a product
|
||||
|
||||
Called when model is retrained for specific product.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
product_id: Product identifier
|
||||
|
||||
Returns:
|
||||
Number of cache entries invalidated
|
||||
"""
|
||||
if not await self.is_available():
|
||||
return 0
|
||||
|
||||
try:
|
||||
# Find all keys matching this product
|
||||
pattern = f"forecast:{tenant_id}:{product_id}:*"
|
||||
client = await self._get_redis()
|
||||
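# NOTE: KEYS scans the whole keyspace and can block Redis; for large deployments a SCAN-based
# iteration may be preferable (assumption about deployment scale).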
keys = await client.keys(pattern)
|
||||
|
||||
if keys:
|
||||
deleted = await client.delete(*keys)
|
||||
logger.info("Invalidated product forecast cache",
|
||||
tenant_id=str(tenant_id),
|
||||
product_id=str(product_id),
|
||||
keys_deleted=deleted)
|
||||
return deleted
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error invalidating product forecasts",
|
||||
error=str(e),
|
||||
tenant_id=str(tenant_id))
|
||||
return 0
|
||||
|
||||
async def invalidate_tenant_forecasts(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
forecast_date: Optional[date] = None
|
||||
) -> int:
|
||||
"""
|
||||
Invalidate forecast cache for tenant
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
forecast_date: Optional specific date to invalidate
|
||||
|
||||
Returns:
|
||||
Number of cache entries invalidated
|
||||
"""
|
||||
if not await self.is_available():
|
||||
return 0
|
||||
|
||||
try:
|
||||
if forecast_date:
|
||||
pattern = f"forecast:{tenant_id}:*:{forecast_date.isoformat()}"
|
||||
else:
|
||||
pattern = f"forecast:{tenant_id}:*"
|
||||
|
||||
client = await self._get_redis()
|
||||
keys = await client.keys(pattern)
|
||||
|
||||
if keys:
|
||||
deleted = await client.delete(*keys)
|
||||
logger.info("Invalidated tenant forecast cache",
|
||||
tenant_id=str(tenant_id),
|
||||
forecast_date=str(forecast_date) if forecast_date else "all",
|
||||
keys_deleted=deleted)
|
||||
return deleted
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error invalidating tenant forecasts", error=str(e))
|
||||
return 0
|
||||
|
||||
async def invalidate_all_forecasts(self) -> int:
|
||||
"""
|
||||
Invalidate all forecast cache entries (use with caution)
|
||||
|
||||
Returns:
|
||||
Number of cache entries invalidated
|
||||
"""
|
||||
if not await self.is_available():
|
||||
return 0
|
||||
|
||||
try:
|
||||
pattern = "forecast:*"
|
||||
client = await self._get_redis()
|
||||
keys = await client.keys(pattern)
|
||||
|
||||
if keys:
|
||||
deleted = await client.delete(*keys)
|
||||
logger.warning("Invalidated ALL forecast cache", keys_deleted=deleted)
|
||||
return deleted
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error invalidating all forecasts", error=str(e))
|
||||
return 0
|
||||
|
||||
# ================================================================
|
||||
# CACHE STATISTICS & MONITORING
|
||||
# ================================================================
|
||||
|
||||
async def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics for monitoring
|
||||
|
||||
Returns:
|
||||
Dictionary with cache metrics
|
||||
"""
|
||||
if not await self.is_available():
|
||||
return {"available": False}
|
||||
|
||||
try:
|
||||
client = await self._get_redis()
|
||||
info = await client.info()
|
||||
|
||||
# Get forecast-specific stats
|
||||
forecast_keys = await client.keys("forecast:*")
|
||||
batch_keys = await client.keys("forecast:batch:*")
|
||||
|
||||
return {
|
||||
"available": True,
|
||||
"total_forecast_keys": len(forecast_keys),
|
||||
"batch_forecast_keys": len(batch_keys),
|
||||
"single_forecast_keys": len(forecast_keys) - len(batch_keys),
|
||||
"used_memory": info.get("used_memory_human"),
|
||||
"connected_clients": info.get("connected_clients"),
|
||||
"keyspace_hits": info.get("keyspace_hits", 0),
|
||||
"keyspace_misses": info.get("keyspace_misses", 0),
|
||||
"hit_rate_percent": self._calculate_hit_rate(
|
||||
info.get("keyspace_hits", 0),
|
||||
info.get("keyspace_misses", 0)
|
||||
),
|
||||
"total_commands_processed": info.get("total_commands_processed", 0)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("Error getting cache stats", error=str(e))
|
||||
return {"available": False, "error": str(e)}
|
||||
|
||||
def _calculate_hit_rate(self, hits: int, misses: int) -> float:
|
||||
"""Calculate cache hit rate percentage"""
|
||||
total = hits + misses
|
||||
return round((hits / total * 100), 2) if total > 0 else 0.0
|
||||
|
||||
async def get_cached_forecast_info(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_id: UUID,
|
||||
forecast_date: date
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get metadata about cached forecast without retrieving full data
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
product_id: Product identifier
|
||||
forecast_date: Date of forecast
|
||||
|
||||
Returns:
|
||||
Cache metadata or None
|
||||
"""
|
||||
if not await self.is_available():
|
||||
return None
|
||||
|
||||
try:
|
||||
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
|
||||
client = await self._get_redis()
|
||||
ttl = await client.ttl(key)
|
||||
|
||||
if ttl > 0:
|
||||
return {
|
||||
"cached": True,
|
||||
"cache_key": key,
|
||||
"ttl_seconds": ttl,
|
||||
"ttl_hours": round(ttl / 3600, 2),
|
||||
"expires_at": (datetime.now() + timedelta(seconds=ttl)).isoformat()
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting forecast cache info", error=str(e))
|
||||
return None
|
||||
|
||||
|
||||
# Global cache service instance
|
||||
_cache_service = None
|
||||
|
||||
|
||||
def get_forecast_cache_service() -> ForecastCacheService:
|
||||
"""
|
||||
Get the global forecast cache service instance
|
||||
|
||||
Returns:
|
||||
ForecastCacheService instance
|
||||
"""
|
||||
global _cache_service
|
||||
|
||||
if _cache_service is None:
|
||||
_cache_service = ForecastCacheService()
|
||||
|
||||
return _cache_service
|
||||
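# Usage sketch (illustrative; all calls exist in this module):
#   cache = get_forecast_cache_service()
#   cached = await cache.get_cached_forecast(tenant_id, product_id, forecast_date)
#   if cached is None:
#       forecast = ...  # compute the forecast
#       await cache.cache_forecast(tenant_id, product_id, forecast_date, forecast)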
533
services/forecasting/app/services/forecast_feedback_service.py
Normal file
533
services/forecasting/app/services/forecast_feedback_service.py
Normal file
@@ -0,0 +1,533 @@
|
||||
# services/forecasting/app/services/forecast_feedback_service.py
|
||||
"""
|
||||
Forecast Feedback Service
|
||||
Business logic for collecting and analyzing forecast feedback
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta, date
|
||||
import uuid
|
||||
import structlog
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForecastFeedback:
|
||||
"""Data class for forecast feedback"""
|
||||
feedback_id: uuid.UUID
|
||||
forecast_id: uuid.UUID
|
||||
tenant_id: str
|
||||
feedback_type: str
|
||||
confidence: str
|
||||
actual_value: Optional[float]
|
||||
notes: Optional[str]
|
||||
feedback_data: Dict[str, Any]
|
||||
created_at: datetime
|
||||
created_by: Optional[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForecastAccuracyMetrics:
|
||||
"""Data class for forecast accuracy metrics"""
|
||||
forecast_id: str
|
||||
total_feedback_count: int
|
||||
accuracy_score: float
|
||||
feedback_distribution: Dict[str, int]
|
||||
average_confidence: float
|
||||
last_feedback_date: Optional[datetime]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForecasterPerformanceMetrics:
|
||||
"""Data class for forecaster performance metrics"""
|
||||
overall_accuracy: float
|
||||
total_forecasts_with_feedback: int
|
||||
accuracy_by_product: Dict[str, float]
|
||||
accuracy_trend: str
|
||||
improvement_suggestions: List[str]
|
||||
|
||||
|
||||
class ForecastFeedbackService:
|
||||
"""
|
||||
Service for managing forecast feedback and accuracy tracking
|
||||
"""
|
||||
|
||||
def __init__(self, database_manager):
|
||||
self.database_manager = database_manager
|
||||
|
||||
async def forecast_exists(self, tenant_id: str, forecast_id: str) -> bool:
|
||||
"""
|
||||
Check if a forecast exists
|
||||
"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
from sqlalchemy import text

result = await session.execute(
text(
"""
SELECT 1 FROM forecasts
WHERE tenant_id = :tenant_id AND id = :forecast_id
"""
),
{"tenant_id": tenant_id, "forecast_id": forecast_id}
)
return result.scalar() is not None
|
||||
except Exception as e:
|
||||
logger.error("Failed to check forecast existence", error=str(e))
|
||||
raise Exception(f"Failed to check forecast existence: {str(e)}")
|
||||
|
||||
async def submit_feedback(
|
||||
self,
|
||||
tenant_id: str,
|
||||
forecast_id: str,
|
||||
feedback_type: str,
|
||||
confidence: str,
|
||||
actual_value: Optional[float] = None,
|
||||
notes: Optional[str] = None,
|
||||
feedback_data: Optional[Dict[str, Any]] = None
|
||||
) -> ForecastFeedback:
|
||||
"""
|
||||
Submit feedback on forecast accuracy
|
||||
"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
# Create feedback record
|
||||
feedback_id = uuid.uuid4()
|
||||
created_at = datetime.now()
|
||||
|
||||
# In a real implementation, this would insert into a forecast_feedback table
|
||||
# For demo purposes, we'll simulate the database operation
|
||||
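# A real insert might look roughly like the following (table and column names are
# assumptions; no forecast_feedback table is defined in this commit):
#   INSERT INTO forecast_feedback
#       (id, forecast_id, tenant_id, feedback_type, confidence, actual_value, notes, feedback_data, created_at, created_by)
#   VALUES
#       (:id, :forecast_id, :tenant_id, :feedback_type, :confidence, :actual_value, :notes, :feedback_data, :created_at, :created_by)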
|
||||
feedback = ForecastFeedback(
|
||||
feedback_id=feedback_id,
|
||||
forecast_id=uuid.UUID(forecast_id),
|
||||
tenant_id=tenant_id,
|
||||
feedback_type=feedback_type,
|
||||
confidence=confidence,
|
||||
actual_value=actual_value,
|
||||
notes=notes,
|
||||
feedback_data=feedback_data or {},
|
||||
created_at=created_at,
|
||||
created_by="system" # In real implementation, this would be the user ID
|
||||
)
|
||||
|
||||
# Simulate database insert
|
||||
logger.info("Feedback submitted",
|
||||
feedback_id=str(feedback_id),
|
||||
forecast_id=forecast_id,
|
||||
feedback_type=feedback_type)
|
||||
|
||||
return feedback
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to submit feedback", error=str(e))
|
||||
raise Exception(f"Failed to submit feedback: {str(e)}")
|
||||
|
||||
async def get_feedback_for_forecast(
|
||||
self,
|
||||
tenant_id: str,
|
||||
forecast_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[ForecastFeedback]:
|
||||
"""
|
||||
Get all feedback for a specific forecast
|
||||
"""
|
||||
try:
|
||||
# In a real implementation, this would query the forecast_feedback table
|
||||
# For demo purposes, we'll return simulated data
|
||||
|
||||
# Simulate some feedback data
|
||||
simulated_feedback = []
|
||||
|
||||
for i in range(min(limit, 3)): # Return up to 3 simulated feedback items
|
||||
feedback = ForecastFeedback(
|
||||
feedback_id=uuid.uuid4(),
|
||||
forecast_id=uuid.UUID(forecast_id),
|
||||
tenant_id=tenant_id,
|
||||
feedback_type=["too_high", "too_low", "accurate"][i % 3],
|
||||
confidence=["medium", "high", "low"][i % 3],
|
||||
actual_value=150.0 + i * 20 if i < 2 else None,
|
||||
notes=f"Feedback sample {i+1}" if i == 0 else None,
|
||||
feedback_data={"sample": i+1, "demo": True},
|
||||
created_at=datetime.now() - timedelta(days=i),
|
||||
created_by="demo_user"
|
||||
)
|
||||
simulated_feedback.append(feedback)
|
||||
|
||||
return simulated_feedback
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get feedback for forecast", error=str(e))
|
||||
raise Exception(f"Failed to get feedback: {str(e)}")
|
||||
|
||||
async def calculate_accuracy_metrics(
|
||||
self,
|
||||
tenant_id: str,
|
||||
forecast_id: str
|
||||
) -> Optional[ForecastAccuracyMetrics]:
|
||||
"""
|
||||
Calculate accuracy metrics for a forecast
|
||||
"""
|
||||
try:
|
||||
# Get feedback for this forecast
|
||||
feedback_list = await self.get_feedback_for_forecast(tenant_id, forecast_id)
|
||||
|
||||
if not feedback_list:
|
||||
return None
|
||||
|
||||
# Calculate metrics
|
||||
total_feedback = len(feedback_list)
|
||||
|
||||
# Count feedback distribution
|
||||
feedback_distribution = {
|
||||
"too_high": 0,
|
||||
"too_low": 0,
|
||||
"accurate": 0,
|
||||
"uncertain": 0
|
||||
}
|
||||
|
||||
confidence_scores = {
|
||||
"low": 1,
|
||||
"medium": 2,
|
||||
"high": 3
|
||||
}
|
||||
|
||||
total_confidence = 0
|
||||
|
||||
for feedback in feedback_list:
|
||||
feedback_distribution[feedback.feedback_type] += 1
|
||||
total_confidence += confidence_scores.get(feedback.confidence, 1)
|
||||
|
||||
# Calculate accuracy score (simplified)
|
||||
accurate_count = feedback_distribution["accurate"]
|
||||
accuracy_score = (accurate_count / total_feedback) * 100
|
||||
|
||||
# Adjust for confidence
|
||||
avg_confidence = total_confidence / total_feedback
|
||||
adjusted_accuracy = accuracy_score * (avg_confidence / 3) # Normalize confidence to 0-1 range
|
||||
|
||||
return ForecastAccuracyMetrics(
|
||||
forecast_id=forecast_id,
|
||||
total_feedback_count=total_feedback,
|
||||
accuracy_score=round(adjusted_accuracy, 1),
|
||||
feedback_distribution=feedback_distribution,
|
||||
average_confidence=round(avg_confidence, 1),
|
||||
last_feedback_date=max(f.created_at for f in feedback_list)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to calculate accuracy metrics", error=str(e))
|
||||
raise Exception(f"Failed to calculate metrics: {str(e)}")
|
||||
|
||||
async def calculate_performance_summary(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
product_id: Optional[str] = None
|
||||
) -> ForecasterPerformanceMetrics:
|
||||
"""
|
||||
Calculate overall forecaster performance summary
|
||||
"""
|
||||
try:
|
||||
# In a real implementation, this would aggregate data across multiple forecasts
|
||||
# For demo purposes, we'll return simulated metrics
|
||||
|
||||
# Simulate performance data
|
||||
accuracy_by_product = {
|
||||
"baguette": 85.5,
|
||||
"croissant": 78.2,
|
||||
"pain_au_chocolat": 92.1
|
||||
}
|
||||
|
||||
if product_id and product_id in accuracy_by_product:
|
||||
# Return metrics for specific product
|
||||
product_accuracy = accuracy_by_product[product_id]
|
||||
accuracy_by_product = {product_id: product_accuracy}
|
||||
|
||||
# Calculate overall accuracy
|
||||
overall_accuracy = sum(accuracy_by_product.values()) / len(accuracy_by_product)
|
||||
|
||||
# Determine trend (simulated)
|
||||
trend_data = [82.3, 84.1, 85.5, 86.8, 88.2] # Last 5 periods
|
||||
if trend_data[-1] > trend_data[0]:
|
||||
trend = "improving"
|
||||
elif trend_data[-1] < trend_data[0]:
|
||||
trend = "declining"
|
||||
else:
|
||||
trend = "stable"
|
||||
|
||||
# Generate improvement suggestions
|
||||
suggestions = []
|
||||
|
||||
for product, accuracy in accuracy_by_product.items():
|
||||
if accuracy < 80:
|
||||
suggestions.append(f"Improve {product} forecast accuracy (current: {accuracy}%)")
|
||||
elif accuracy < 90:
|
||||
suggestions.append(f"Consider fine-tuning {product} forecast model (current: {accuracy}%)")
|
||||
|
||||
if not suggestions:
|
||||
suggestions.append("Overall forecast accuracy is excellent - maintain current approach")
|
||||
|
||||
return ForecasterPerformanceMetrics(
|
||||
overall_accuracy=round(overall_accuracy, 1),
|
||||
total_forecasts_with_feedback=42,
|
||||
accuracy_by_product=accuracy_by_product,
|
||||
accuracy_trend=trend,
|
||||
improvement_suggestions=suggestions
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to calculate performance summary", error=str(e))
|
||||
raise Exception(f"Failed to calculate summary: {str(e)}")
|
||||
|
||||
async def get_feedback_trends(
|
||||
self,
|
||||
tenant_id: str,
|
||||
days: int = 30,
|
||||
product_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get feedback trends over time
|
||||
"""
|
||||
try:
|
||||
# Simulate trend data
|
||||
trends = []
|
||||
end_date = datetime.now()
|
||||
|
||||
# Generate daily trend data
|
||||
for i in range(days):
|
||||
trend_date = end_date - timedelta(days=i)
|
||||
|
||||
# Simulate varying accuracy with weekly pattern
|
||||
base_accuracy = 85.0
|
||||
weekly_variation = 3.0 * (i % 7 / 6 - 0.5) # Weekly pattern
|
||||
daily_noise = (i % 3 - 1) * 1.5 # Daily noise
|
||||
|
||||
accuracy = max(70, min(95, base_accuracy + weekly_variation + daily_noise))
|
||||
|
||||
trends.append({
|
||||
'date': trend_date.strftime('%Y-%m-%d'),
|
||||
'accuracy_score': round(accuracy, 1),
|
||||
'feedback_count': max(1, int(5 + i % 10)),
|
||||
'confidence_score': round(2.5 + (i % 5 - 2) * 0.2, 1)
|
||||
})
|
||||
|
||||
# Sort by date (oldest first)
|
||||
trends.sort(key=lambda x: x['date'])
|
||||
|
||||
return trends
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get feedback trends", error=str(e))
|
||||
raise Exception(f"Failed to get trends: {str(e)}")
|
||||
|
||||
async def trigger_retraining_from_feedback(
|
||||
self,
|
||||
tenant_id: str,
|
||||
forecast_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Trigger model retraining based on feedback
|
||||
"""
|
||||
try:
|
||||
# In a real implementation, this would:
|
||||
# 1. Collect recent feedback data
|
||||
# 2. Prepare training dataset
|
||||
# 3. Submit retraining job to ML service
|
||||
# 4. Return job ID
|
||||
|
||||
# For demo purposes, simulate a retraining job
|
||||
job_id = str(uuid.uuid4())
|
||||
|
||||
logger.info("Retraining job triggered",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
forecast_id=forecast_id)
|
||||
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'forecasts_included': 15,
|
||||
'feedback_samples_used': 42,
|
||||
'status': 'queued',
|
||||
'estimated_completion': (datetime.now() + timedelta(minutes=30)).isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to trigger retraining", error=str(e))
|
||||
raise Exception(f"Failed to trigger retraining: {str(e)}")
|
||||
|
||||
async def get_improvement_suggestions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
forecast_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get AI-generated improvement suggestions
|
||||
"""
|
||||
try:
|
||||
# Get accuracy metrics for this forecast
|
||||
metrics = await self.calculate_accuracy_metrics(tenant_id, forecast_id)
|
||||
|
||||
if not metrics:
|
||||
return [
|
||||
{
|
||||
'suggestion': 'Insufficient feedback data to generate suggestions',
|
||||
'type': 'data',
|
||||
'priority': 'low',
|
||||
'confidence': 0.7
|
||||
}
|
||||
]
|
||||
|
||||
# Generate suggestions based on metrics
|
||||
suggestions = []
|
||||
|
||||
# Analyze feedback distribution
|
||||
feedback_dist = metrics.feedback_distribution
|
||||
total_feedback = metrics.total_feedback_count
|
||||
|
||||
if feedback_dist['too_high'] > total_feedback * 0.4:
|
||||
suggestions.append({
|
||||
'suggestion': 'Forecasts are consistently too high - consider adjusting demand estimation parameters',
|
||||
'type': 'bias',
|
||||
'priority': 'high',
|
||||
'confidence': 0.9,
|
||||
'details': {
|
||||
'too_high_percentage': feedback_dist['too_high'] / total_feedback * 100,
|
||||
'recommended_action': 'Reduce demand estimation by 10-15%'
|
||||
}
|
||||
})
|
||||
|
||||
if feedback_dist['too_low'] > total_feedback * 0.4:
|
||||
suggestions.append({
|
||||
'suggestion': 'Forecasts are consistently too low - consider increasing demand estimation parameters',
|
||||
'type': 'bias',
|
||||
'priority': 'high',
|
||||
'confidence': 0.9,
|
||||
'details': {
|
||||
'too_low_percentage': feedback_dist['too_low'] / total_feedback * 100,
|
||||
'recommended_action': 'Increase demand estimation by 10-15%'
|
||||
}
|
||||
})
|
||||
|
||||
if metrics.accuracy_score < 70:
|
||||
suggestions.append({
|
||||
'suggestion': 'Low overall accuracy - consider comprehensive model review and retraining',
|
||||
'type': 'model',
|
||||
'priority': 'critical',
|
||||
'confidence': 0.85,
|
||||
'details': {
|
||||
'current_accuracy': metrics.accuracy_score,
|
||||
'recommended_action': 'Full model retraining with expanded feature set'
|
||||
}
|
||||
})
|
||||
elif metrics.accuracy_score < 85:
|
||||
suggestions.append({
|
||||
'suggestion': 'Moderate accuracy - consider feature engineering improvements',
|
||||
'type': 'features',
|
||||
'priority': 'medium',
|
||||
'confidence': 0.8,
|
||||
'details': {
|
||||
'current_accuracy': metrics.accuracy_score,
|
||||
'recommended_action': 'Add weather data, promotions, and seasonal features'
|
||||
}
|
||||
})
|
||||
|
||||
if metrics.average_confidence < 2.0: # Average of medium (2) and high (3)
|
||||
suggestions.append({
|
||||
'suggestion': 'Low confidence in feedback - consider improving feedback collection process',
|
||||
'type': 'process',
|
||||
'priority': 'medium',
|
||||
'confidence': 0.75,
|
||||
'details': {
|
||||
'average_confidence': metrics.average_confidence,
|
||||
'recommended_action': 'Provide clearer guidance to users on feedback submission'
|
||||
}
|
||||
})
|
||||
|
||||
if not suggestions:
|
||||
suggestions.append({
|
||||
'suggestion': 'Forecast accuracy is good - consider expanding to additional products',
|
||||
'type': 'expansion',
|
||||
'priority': 'low',
|
||||
'confidence': 0.85,
|
||||
'details': {
|
||||
'current_accuracy': metrics.accuracy_score,
|
||||
'recommended_action': 'Extend forecasting to new product categories'
|
||||
}
|
||||
})
|
||||
|
||||
return suggestions
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to generate improvement suggestions", error=str(e))
|
||||
raise Exception(f"Failed to generate suggestions: {str(e)}")
|
||||
|
||||
|
||||
# Helper class for feedback analysis
|
||||
class FeedbackAnalyzer:
|
||||
"""
|
||||
Helper class for analyzing feedback patterns
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def detect_feedback_patterns(feedback_list: List[ForecastFeedback]) -> Dict[str, Any]:
|
||||
"""
|
||||
Detect patterns in feedback data
|
||||
"""
|
||||
if not feedback_list:
|
||||
return {'patterns': [], 'anomalies': []}
|
||||
|
||||
patterns = []
|
||||
anomalies = []
|
||||
|
||||
# Simple pattern detection (in real implementation, this would be more sophisticated)
|
||||
feedback_types = [f.feedback_type for f in feedback_list]
|
||||
|
||||
if len(set(feedback_types)) == 1:
|
||||
patterns.append({
|
||||
'type': 'consistent_feedback',
|
||||
'pattern': f'All feedback is "{feedback_types[0]}"',
|
||||
'confidence': 0.9
|
||||
})
|
||||
|
||||
return {'patterns': patterns, 'anomalies': anomalies}
|
||||
|
||||
|
||||
# Helper class for accuracy calculation
|
||||
class AccuracyCalculator:
|
||||
"""
|
||||
Helper class for calculating forecast accuracy metrics
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def calculate_mape(actual: float, predicted: float) -> float:
|
||||
"""
|
||||
Calculate Mean Absolute Percentage Error
|
||||
"""
|
||||
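# By convention, report 0% error when the actual value is 0 to avoid division by zero.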
if actual == 0:
|
||||
return 0.0
|
||||
return abs((actual - predicted) / actual) * 100
|
||||
|
||||
@staticmethod
|
||||
def calculate_rmse(actual: float, predicted: float) -> float:
"""
Calculate the squared error for a single observation
(average these values and take the square root to obtain RMSE)
"""
return (actual - predicted) ** 2
|
||||
@staticmethod
|
||||
def feedback_to_accuracy_score(feedback_type: str) -> float:
|
||||
"""
|
||||
Convert feedback type to accuracy score
|
||||
"""
|
||||
feedback_scores = {
|
||||
'accurate': 100,
|
||||
'too_high': 50,
|
||||
'too_low': 50,
|
||||
'uncertain': 75
|
||||
}
|
||||
return feedback_scores.get(feedback_type, 75)
|
||||
338
services/forecasting/app/services/forecasting_alert_service.py
Normal file
338
services/forecasting/app/services/forecasting_alert_service.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Forecasting Alert Service - Simplified
|
||||
|
||||
Emits minimal events using EventPublisher.
|
||||
All enrichment handled by alert_processor.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from shared.messaging import UnifiedEventPublisher
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ForecastingAlertService:
|
||||
"""Simplified forecasting alert service using EventPublisher"""
|
||||
|
||||
def __init__(self, event_publisher: UnifiedEventPublisher):
|
||||
self.publisher = event_publisher
|
||||
|
||||
async def start(self):
|
||||
"""Start the forecasting alert service"""
|
||||
logger.info("ForecastingAlertService started")
|
||||
# Add any initialization logic here if needed
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the forecasting alert service"""
|
||||
logger.info("ForecastingAlertService stopped")
|
||||
# Add any cleanup logic here if needed
|
||||
|
||||
async def health_check(self):
|
||||
"""Health check for the forecasting alert service"""
|
||||
try:
|
||||
# Check if the event publisher is available and operational
|
||||
if hasattr(self, 'publisher') and self.publisher:
|
||||
# Basic check if publisher is available
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error("ForecastingAlertService health check failed", error=str(e))
|
||||
return False
|
||||
|
||||
async def emit_demand_surge_weekend(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_name: str,
|
||||
inventory_product_id: str,
|
||||
predicted_demand: float,
|
||||
growth_percentage: float,
|
||||
forecast_date: str,
|
||||
weather_favorable: bool = False
|
||||
):
|
||||
"""Emit weekend demand surge alert"""
|
||||
|
||||
# Determine severity based on growth magnitude
|
||||
if growth_percentage > 100:
|
||||
severity = 'high'
|
||||
elif growth_percentage > 75:
|
||||
severity = 'medium'
|
||||
else:
|
||||
severity = 'low'
|
||||
|
||||
metadata = {
|
||||
"product_name": product_name,
|
||||
"inventory_product_id": str(inventory_product_id),
|
||||
"predicted_demand": float(predicted_demand),
|
||||
"growth_percentage": float(growth_percentage),
|
||||
"forecast_date": forecast_date,
|
||||
"weather_favorable": weather_favorable
|
||||
}
|
||||
|
||||
await self.publisher.publish_alert(
|
||||
event_type="forecasting.demand_surge_weekend",
|
||||
tenant_id=tenant_id,
|
||||
severity=severity,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"demand_surge_weekend_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
product_name=product_name,
|
||||
growth_percentage=growth_percentage
|
||||
)
|
||||
|
||||
async def emit_weather_impact_alert(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
forecast_date: str,
|
||||
precipitation: float,
|
||||
expected_demand_change: float,
|
||||
traffic_volume: int,
|
||||
weather_type: str = "general",
|
||||
product_name: Optional[str] = None
|
||||
):
|
||||
"""Emit weather impact alert"""
|
||||
|
||||
# Determine severity based on impact
|
||||
if expected_demand_change < -20:
|
||||
severity = 'high'
|
||||
elif expected_demand_change < -10:
|
||||
severity = 'medium'
|
||||
else:
|
||||
severity = 'low'
|
||||
|
||||
metadata = {
|
||||
"forecast_date": forecast_date,
|
||||
"precipitation_mm": float(precipitation),
|
||||
"expected_demand_change": float(expected_demand_change),
|
||||
"traffic_volume": traffic_volume,
|
||||
"weather_type": weather_type
|
||||
}
|
||||
|
||||
if product_name:
|
||||
metadata["product_name"] = product_name
|
||||
|
||||
# Add triggers information
|
||||
triggers = ['weather_conditions', 'demand_forecast']
|
||||
if precipitation > 0:
|
||||
triggers.append('rain_forecast')
|
||||
if expected_demand_change < -15:
|
||||
triggers.append('outdoor_events_cancelled')
|
||||
|
||||
metadata["triggers"] = triggers
|
||||
|
||||
await self.publisher.publish_alert(
|
||||
event_type="forecasting.weather_impact_alert",
|
||||
tenant_id=tenant_id,
|
||||
severity=severity,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"weather_impact_alert_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
weather_type=weather_type,
|
||||
expected_demand_change=expected_demand_change
|
||||
)
|
||||
|
||||
async def emit_holiday_preparation(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
holiday_name: str,
|
||||
days_until_holiday: int,
|
||||
product_name: str,
|
||||
spike_percentage: float,
|
||||
avg_holiday_demand: float,
|
||||
avg_normal_demand: float,
|
||||
holiday_date: str
|
||||
):
|
||||
"""Emit holiday preparation alert"""
|
||||
|
||||
# Determine severity based on spike magnitude and preparation time
|
||||
if spike_percentage > 75 and days_until_holiday <= 3:
|
||||
severity = 'high'
|
||||
elif spike_percentage > 50 or days_until_holiday <= 3:
|
||||
severity = 'medium'
|
||||
else:
|
||||
severity = 'low'
|
||||
|
||||
metadata = {
|
||||
"holiday_name": holiday_name,
|
||||
"days_until_holiday": days_until_holiday,
|
||||
"product_name": product_name,
|
||||
"spike_percentage": float(spike_percentage),
|
||||
"avg_holiday_demand": float(avg_holiday_demand),
|
||||
"avg_normal_demand": float(avg_normal_demand),
|
||||
"holiday_date": holiday_date
|
||||
}
|
||||
|
||||
# Add triggers information
|
||||
triggers = [f'spanish_holiday_in_{days_until_holiday}_days']
|
||||
if spike_percentage > 25:
|
||||
triggers.append('historical_demand_spike')
|
||||
|
||||
metadata["triggers"] = triggers
|
||||
|
||||
await self.publisher.publish_alert(
|
||||
event_type="forecasting.holiday_preparation",
|
||||
tenant_id=tenant_id,
|
||||
severity=severity,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"holiday_preparation_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
holiday_name=holiday_name,
|
||||
spike_percentage=spike_percentage
|
||||
)
|
||||
|
||||
async def emit_demand_optimization_recommendation(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_name: str,
|
||||
optimization_potential: float,
|
||||
peak_demand: float,
|
||||
min_demand: float,
|
||||
demand_range: float
|
||||
):
|
||||
"""Emit demand pattern optimization recommendation"""
|
||||
|
||||
metadata = {
|
||||
"product_name": product_name,
|
||||
"optimization_potential": float(optimization_potential),
|
||||
"peak_demand": float(peak_demand),
|
||||
"min_demand": float(min_demand),
|
||||
"demand_range": float(demand_range)
|
||||
}
|
||||
|
||||
await self.publisher.publish_recommendation(
|
||||
event_type="forecasting.demand_pattern_optimization",
|
||||
tenant_id=tenant_id,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"demand_pattern_optimization_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
product_name=product_name,
|
||||
optimization_potential=optimization_potential
|
||||
)
|
||||
|
||||
async def emit_demand_spike_detected(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_name: str,
|
||||
spike_percentage: float
|
||||
):
|
||||
"""Emit demand spike detected event"""
|
||||
|
||||
# Determine severity based on spike magnitude
|
||||
if spike_percentage > 50:
|
||||
severity = 'high'
|
||||
elif spike_percentage > 20:
|
||||
severity = 'medium'
|
||||
else:
|
||||
severity = 'low'
|
||||
|
||||
metadata = {
|
||||
"product_name": product_name,
|
||||
"spike_percentage": float(spike_percentage),
|
||||
"detection_source": "database"
|
||||
}
|
||||
|
||||
await self.publisher.publish_alert(
|
||||
event_type="forecasting.demand_spike_detected",
|
||||
tenant_id=tenant_id,
|
||||
severity=severity,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"demand_spike_detected_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
product_name=product_name,
|
||||
spike_percentage=spike_percentage
|
||||
)
|
||||
|
||||
async def emit_severe_weather_impact(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
weather_type: str,
|
||||
severity_level: str,
|
||||
duration_hours: int
|
||||
):
|
||||
"""Emit severe weather impact event"""
|
||||
|
||||
# Determine alert severity based on weather severity
|
||||
if severity_level == 'critical' or duration_hours > 24:
|
||||
alert_severity = 'urgent'
|
||||
elif severity_level == 'high' or duration_hours > 12:
|
||||
alert_severity = 'high'
|
||||
else:
|
||||
alert_severity = 'medium'
|
||||
|
||||
metadata = {
|
||||
"weather_type": weather_type,
|
||||
"severity_level": severity_level,
|
||||
"duration_hours": duration_hours
|
||||
}
|
||||
|
||||
await self.publisher.publish_alert(
|
||||
event_type="forecasting.severe_weather_impact",
|
||||
tenant_id=tenant_id,
|
||||
severity=alert_severity,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"severe_weather_impact_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
weather_type=weather_type,
|
||||
severity_level=severity_level
|
||||
)
|
||||
|
||||
async def emit_unexpected_demand_spike(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_name: str,
|
||||
spike_percentage: float,
|
||||
current_sales: float,
|
||||
forecasted_sales: float
|
||||
):
|
||||
"""Emit unexpected sales spike event"""
|
||||
|
||||
# Determine severity based on spike magnitude
|
||||
if spike_percentage > 75:
|
||||
severity = 'high'
|
||||
elif spike_percentage > 40:
|
||||
severity = 'medium'
|
||||
else:
|
||||
severity = 'low'
|
||||
|
||||
metadata = {
|
||||
"product_name": product_name,
|
||||
"spike_percentage": float(spike_percentage),
|
||||
"current_sales": float(current_sales),
|
||||
"forecasted_sales": float(forecasted_sales)
|
||||
}
|
||||
|
||||
await self.publisher.publish_alert(
|
||||
event_type="forecasting.unexpected_demand_spike",
|
||||
tenant_id=tenant_id,
|
||||
severity=severity,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"unexpected_demand_spike_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
product_name=product_name,
|
||||
spike_percentage=spike_percentage
|
||||
)
|
||||
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
Forecasting Recommendation Service - Simplified
|
||||
|
||||
Emits minimal events using EventPublisher.
|
||||
All enrichment handled by alert_processor.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any, List
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from shared.messaging import UnifiedEventPublisher
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ForecastingRecommendationService:
|
||||
"""
|
||||
Service for emitting forecasting recommendations using EventPublisher.
|
||||
"""
|
||||
|
||||
def __init__(self, event_publisher: UnifiedEventPublisher):
|
||||
self.publisher = event_publisher
|
||||
|
||||
async def emit_demand_surge_recommendation(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
product_sku: str,
|
||||
product_name: str,
|
||||
predicted_demand: float,
|
||||
normal_demand: float,
|
||||
surge_percentage: float,
|
||||
surge_date: str,
|
||||
confidence_score: float,
|
||||
reasoning: str,
|
||||
) -> None:
|
||||
"""
|
||||
Emit RECOMMENDATION for predicted demand surge.
|
||||
"""
|
||||
metadata = {
|
||||
"product_sku": product_sku,
|
||||
"product_name": product_name,
|
||||
"predicted_demand": float(predicted_demand),
|
||||
"normal_demand": float(normal_demand),
|
||||
"surge_percentage": float(surge_percentage),
|
||||
"surge_date": surge_date,
|
||||
"confidence_score": float(confidence_score),
|
||||
"reasoning": reasoning,
|
||||
"estimated_impact": {
|
||||
"additional_revenue_eur": predicted_demand * 5, # Rough estimate
|
||||
"stockout_risk": "high" if surge_percentage > 50 else "medium",
|
||||
},
|
||||
}
|
||||
|
||||
await self.publisher.publish_recommendation(
|
||||
event_type="demand.demand_surge_predicted",
|
||||
tenant_id=tenant_id,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"demand_surge_recommendation_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
product_name=product_name,
|
||||
surge_percentage=surge_percentage
|
||||
)
|
||||
|
||||
async def emit_weather_impact_recommendation(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
weather_event: str, # 'rain', 'snow', 'heatwave', etc.
|
||||
forecast_date: str,
|
||||
affected_products: List[Dict[str, Any]],
|
||||
impact_description: str,
|
||||
confidence_score: float,
|
||||
) -> None:
|
||||
"""
|
||||
Emit RECOMMENDATION for weather impact on demand.
|
||||
"""
|
||||
metadata = {
|
||||
"weather_event": weather_event,
|
||||
"forecast_date": forecast_date,
|
||||
"affected_products": affected_products,
|
||||
"impact_description": impact_description,
|
||||
"confidence_score": float(confidence_score),
|
||||
}
|
||||
|
||||
await self.publisher.publish_recommendation(
|
||||
event_type="demand.weather_impact_forecast",
|
||||
tenant_id=tenant_id,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"weather_impact_recommendation_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
weather_event=weather_event
|
||||
)
|
||||
|
||||
async def emit_holiday_preparation_recommendation(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
holiday_name: str,
|
||||
holiday_date: str,
|
||||
days_until_holiday: int,
|
||||
recommended_products: List[Dict[str, Any]],
|
||||
preparation_tips: List[str],
|
||||
) -> None:
|
||||
"""
|
||||
Emit RECOMMENDATION for holiday preparation.
|
||||
"""
|
||||
metadata = {
|
||||
"holiday_name": holiday_name,
|
||||
"holiday_date": holiday_date,
|
||||
"days_until_holiday": days_until_holiday,
|
||||
"recommended_products": recommended_products,
|
||||
"preparation_tips": preparation_tips,
|
||||
"confidence_score": 0.9, # High confidence for known holidays
|
||||
}
|
||||
|
||||
await self.publisher.publish_recommendation(
|
||||
event_type="demand.holiday_preparation",
|
||||
tenant_id=tenant_id,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"holiday_preparation_recommendation_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
holiday=holiday_name
|
||||
)
|
||||
|
||||
async def emit_seasonal_trend_recommendation(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
season: str, # 'spring', 'summer', 'fall', 'winter'
|
||||
trend_type: str, # 'increasing', 'decreasing', 'stable'
|
||||
affected_categories: List[str],
|
||||
trend_description: str,
|
||||
suggested_actions: List[str],
|
||||
) -> None:
|
||||
"""
|
||||
Emit RECOMMENDATION for seasonal trend insight.
|
||||
"""
|
||||
metadata = {
|
||||
"season": season,
|
||||
"trend_type": trend_type,
|
||||
"affected_categories": affected_categories,
|
||||
"trend_description": trend_description,
|
||||
"suggested_actions": suggested_actions,
|
||||
"confidence_score": 0.85,
|
||||
}
|
||||
|
||||
await self.publisher.publish_recommendation(
|
||||
event_type="demand.seasonal_trend_insight",
|
||||
tenant_id=tenant_id,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"seasonal_trend_recommendation_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
season=season
|
||||
)
|
||||
|
||||
async def emit_inventory_optimization_recommendation(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
ingredient_id: str,
|
||||
ingredient_name: str,
|
||||
current_stock: float,
|
||||
optimal_stock: float,
|
||||
unit: str,
|
||||
reason: str,
|
||||
estimated_savings_eur: Optional[float] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Emit RECOMMENDATION for inventory optimization.
|
||||
"""
|
||||
difference = abs(current_stock - optimal_stock)
|
||||
action = "reduce" if current_stock > optimal_stock else "increase"
|
||||
|
||||
estimated_impact = {}
|
||||
if estimated_savings_eur:
|
||||
estimated_impact["financial_savings_eur"] = estimated_savings_eur
|
||||
|
||||
metadata = {
|
||||
"ingredient_id": ingredient_id,
|
||||
"ingredient_name": ingredient_name,
|
||||
"current_stock": float(current_stock),
|
||||
"optimal_stock": float(optimal_stock),
|
||||
"difference": float(difference),
|
||||
"action": action,
|
||||
"unit": unit,
|
||||
"reason": reason,
|
||||
"estimated_impact": estimated_impact if estimated_impact else None,
|
||||
"confidence_score": 0.75,
|
||||
}
|
||||
|
||||
await self.publisher.publish_recommendation(
|
||||
event_type="inventory.inventory_optimization_opportunity",
|
||||
tenant_id=tenant_id,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"inventory_optimization_recommendation_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
ingredient_name=ingredient_name
|
||||
)
|
||||
|
||||
async def emit_cost_reduction_recommendation(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
opportunity_type: str, # 'supplier_switch', 'bulk_purchase', 'seasonal_buying'
|
||||
title: str,
|
||||
description: str,
|
||||
estimated_savings_eur: float,
|
||||
suggested_actions: List[str],
|
||||
details: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Emit RECOMMENDATION for cost reduction opportunity.
|
||||
"""
|
||||
metadata = {
|
||||
"opportunity_type": opportunity_type,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"estimated_savings_eur": float(estimated_savings_eur),
|
||||
"suggested_actions": suggested_actions,
|
||||
"details": details,
|
||||
"confidence_score": 0.8,
|
||||
}
|
||||
|
||||
await self.publisher.publish_recommendation(
|
||||
event_type="supply_chain.cost_reduction_suggestion",
|
||||
tenant_id=tenant_id,
|
||||
data=metadata
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"cost_reduction_recommendation_emitted",
|
||||
tenant_id=str(tenant_id),
|
||||
opportunity_type=opportunity_type
|
||||
)
|
||||
1343
services/forecasting/app/services/forecasting_service.py
Normal file
1343
services/forecasting/app/services/forecasting_service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,480 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/historical_validation_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Historical Validation Service
|
||||
|
||||
Handles validation backfill when historical sales data is uploaded late.
|
||||
Detects gaps in validation coverage and automatically triggers validation
|
||||
for periods where forecasts exist but haven't been validated yet.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, func, Date, or_
|
||||
from datetime import datetime, timedelta, timezone, date
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.models.forecasts import Forecast
|
||||
from app.models.validation_run import ValidationRun
|
||||
from app.models.sales_data_update import SalesDataUpdate
|
||||
from app.services.validation_service import ValidationService
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class HistoricalValidationService:
|
||||
"""Service for backfilling historical validation when sales data arrives late"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db = db_session
|
||||
self.validation_service = ValidationService(db_session)
|
||||
|
||||
async def detect_validation_gaps(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
lookback_days: int = 90
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Detect date ranges where forecasts exist but haven't been validated
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
lookback_days: How far back to check (default 90 days)
|
||||
|
||||
Returns:
|
||||
List of gap periods with date ranges
|
||||
"""
|
||||
try:
|
||||
end_date = datetime.now(timezone.utc)
|
||||
start_date = end_date - timedelta(days=lookback_days)
|
||||
|
||||
logger.info(
|
||||
"Detecting validation gaps",
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date.isoformat(),
|
||||
end_date=end_date.isoformat()
|
||||
)
|
||||
|
||||
# Get all dates with forecasts
|
||||
forecast_query = select(
|
||||
func.cast(Forecast.forecast_date, Date).label('forecast_date')
|
||||
).where(
|
||||
and_(
|
||||
Forecast.tenant_id == tenant_id,
|
||||
Forecast.forecast_date >= start_date,
|
||||
Forecast.forecast_date <= end_date
|
||||
)
|
||||
).group_by(
|
||||
func.cast(Forecast.forecast_date, Date)
|
||||
).order_by(
|
||||
func.cast(Forecast.forecast_date, Date)
|
||||
)
|
||||
|
||||
forecast_result = await self.db.execute(forecast_query)
|
||||
forecast_dates = {row.forecast_date for row in forecast_result.fetchall()}
|
||||
|
||||
if not forecast_dates:
|
||||
logger.info("No forecasts found in lookback period", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
# Get all dates that have been validated
|
||||
validation_query = select(
|
||||
func.cast(ValidationRun.validation_start_date, Date).label('validated_date')
|
||||
).where(
|
||||
and_(
|
||||
ValidationRun.tenant_id == tenant_id,
|
||||
ValidationRun.status == "completed",
|
||||
ValidationRun.validation_start_date >= start_date,
|
||||
ValidationRun.validation_end_date <= end_date
|
||||
)
|
||||
).group_by(
|
||||
func.cast(ValidationRun.validation_start_date, Date)
|
||||
)
|
||||
|
||||
validation_result = await self.db.execute(validation_query)
|
||||
validated_dates = {row.validated_date for row in validation_result.fetchall()}
|
||||
|
||||
# Find gaps (dates with forecasts but no validation)
|
||||
gap_dates = sorted(forecast_dates - validated_dates)
|
||||
|
||||
if not gap_dates:
|
||||
logger.info("No validation gaps found", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
# Group consecutive dates into ranges
|
||||
gaps = []
|
||||
current_gap_start = gap_dates[0]
|
||||
current_gap_end = gap_dates[0]
|
||||
|
||||
for i in range(1, len(gap_dates)):
|
||||
if (gap_dates[i] - current_gap_end).days == 1:
|
||||
# Consecutive date, extend current gap
|
||||
current_gap_end = gap_dates[i]
|
||||
else:
|
||||
# Gap in dates, save current gap and start new one
|
||||
gaps.append({
|
||||
"start_date": current_gap_start,
|
||||
"end_date": current_gap_end,
|
||||
"days_count": (current_gap_end - current_gap_start).days + 1
|
||||
})
|
||||
current_gap_start = gap_dates[i]
|
||||
current_gap_end = gap_dates[i]
|
||||
|
||||
# Don't forget the last gap
|
||||
gaps.append({
|
||||
"start_date": current_gap_start,
|
||||
"end_date": current_gap_end,
|
||||
"days_count": (current_gap_end - current_gap_start).days + 1
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Validation gaps detected",
|
||||
tenant_id=tenant_id,
|
||||
gaps_count=len(gaps),
|
||||
total_days=len(gap_dates)
|
||||
)
|
||||
|
||||
return gaps
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to detect validation gaps",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to detect validation gaps: {str(e)}")
|
||||
|
||||
async def backfill_validation(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
triggered_by: str = "manual",
|
||||
sales_data_update_id: Optional[uuid.UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Backfill validation for a historical date range
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
start_date: Start date for backfill
|
||||
end_date: End date for backfill
|
||||
triggered_by: How this backfill was triggered
|
||||
sales_data_update_id: Optional link to sales data update record
|
||||
|
||||
Returns:
|
||||
Backfill results with validation summary
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Starting validation backfill",
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date.isoformat(),
|
||||
end_date=end_date.isoformat(),
|
||||
triggered_by=triggered_by
|
||||
)
|
||||
|
||||
# Convert dates to datetime
|
||||
start_datetime = datetime.combine(start_date, datetime.min.time()).replace(tzinfo=timezone.utc)
|
||||
end_datetime = datetime.combine(end_date, datetime.max.time()).replace(tzinfo=timezone.utc)
|
||||
|
||||
# Run validation for the date range
|
||||
validation_result = await self.validation_service.validate_date_range(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_datetime,
|
||||
end_date=end_datetime,
|
||||
orchestration_run_id=None,
|
||||
triggered_by=triggered_by
|
||||
)
|
||||
|
||||
# Update sales data update record if provided
|
||||
if sales_data_update_id:
|
||||
await self._update_sales_data_record(
|
||||
sales_data_update_id=sales_data_update_id,
|
||||
validation_run_id=uuid.UUID(validation_result["validation_run_id"]),
|
||||
status="completed" if validation_result["status"] == "completed" else "failed"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Validation backfill completed",
|
||||
tenant_id=tenant_id,
|
||||
validation_run_id=validation_result.get("validation_run_id"),
|
||||
forecasts_evaluated=validation_result.get("forecasts_evaluated")
|
||||
)
|
||||
|
||||
return {
|
||||
**validation_result,
|
||||
"backfill_date_range": {
|
||||
"start": start_date.isoformat(),
|
||||
"end": end_date.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Validation backfill failed",
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date.isoformat(),
|
||||
end_date=end_date.isoformat(),
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
if sales_data_update_id:
|
||||
await self._update_sales_data_record(
|
||||
sales_data_update_id=sales_data_update_id,
|
||||
validation_run_id=None,
|
||||
status="failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
raise DatabaseError(f"Validation backfill failed: {str(e)}")
|
||||
|
||||
async def auto_backfill_gaps(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
lookback_days: int = 90,
|
||||
max_gaps_to_process: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Automatically detect and backfill validation gaps
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
lookback_days: How far back to check
|
||||
max_gaps_to_process: Maximum number of gaps to process in one run
|
||||
|
||||
Returns:
|
||||
Summary of backfill operations
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Starting auto backfill",
|
||||
tenant_id=tenant_id,
|
||||
lookback_days=lookback_days
|
||||
)
|
||||
|
||||
# Detect gaps
|
||||
gaps = await self.detect_validation_gaps(tenant_id, lookback_days)
|
||||
|
||||
if not gaps:
|
||||
return {
|
||||
"gaps_found": 0,
|
||||
"gaps_processed": 0,
|
||||
"validations_completed": 0,
|
||||
"message": "No validation gaps found"
|
||||
}
|
||||
|
||||
# Limit number of gaps to process
|
||||
gaps_to_process = gaps[:max_gaps_to_process]
|
||||
|
||||
results = []
|
||||
for gap in gaps_to_process:
|
||||
try:
|
||||
result = await self.backfill_validation(
|
||||
tenant_id=tenant_id,
|
||||
start_date=gap["start_date"],
|
||||
end_date=gap["end_date"],
|
||||
triggered_by="auto_backfill"
|
||||
)
|
||||
results.append({
|
||||
"gap": gap,
|
||||
"result": result,
|
||||
"status": "success"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to backfill gap",
|
||||
gap=gap,
|
||||
error=str(e)
|
||||
)
|
||||
results.append({
|
||||
"gap": gap,
|
||||
"error": str(e),
|
||||
"status": "failed"
|
||||
})
|
||||
|
||||
successful = sum(1 for r in results if r["status"] == "success")
|
||||
|
||||
logger.info(
|
||||
"Auto backfill completed",
|
||||
tenant_id=tenant_id,
|
||||
gaps_found=len(gaps),
|
||||
gaps_processed=len(results),
|
||||
successful=successful
|
||||
)
|
||||
|
||||
return {
|
||||
"gaps_found": len(gaps),
|
||||
"gaps_processed": len(results),
|
||||
"validations_completed": successful,
|
||||
"validations_failed": len(results) - successful,
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Auto backfill failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Auto backfill failed: {str(e)}")
|
||||
|
||||
async def register_sales_data_update(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
records_affected: int,
|
||||
update_source: str = "import",
|
||||
import_job_id: Optional[str] = None,
|
||||
auto_trigger_validation: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Register a sales data update and optionally trigger validation
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
start_date: Start date of updated data
|
||||
end_date: End date of updated data
|
||||
records_affected: Number of sales records affected
|
||||
update_source: Source of update (import, manual, pos_sync)
|
||||
import_job_id: Optional import job ID
|
||||
auto_trigger_validation: Whether to automatically trigger validation
|
||||
|
||||
Returns:
|
||||
Update record and validation result if triggered
|
||||
"""
|
||||
try:
|
||||
# Create sales data update record
|
||||
update_record = SalesDataUpdate(
|
||||
tenant_id=tenant_id,
|
||||
update_date_start=start_date,
|
||||
update_date_end=end_date,
|
||||
records_affected=records_affected,
|
||||
update_source=update_source,
|
||||
import_job_id=import_job_id,
|
||||
requires_validation=auto_trigger_validation,
|
||||
validation_status="pending" if auto_trigger_validation else "not_required"
|
||||
)
|
||||
|
||||
self.db.add(update_record)
|
||||
await self.db.flush()
|
||||
|
||||
logger.info(
|
||||
"Registered sales data update",
|
||||
tenant_id=tenant_id,
|
||||
update_id=update_record.id,
|
||||
date_range=f"{start_date} to {end_date}",
|
||||
records_affected=records_affected
|
||||
)
|
||||
|
||||
result = {
|
||||
"update_id": str(update_record.id),
|
||||
"update_record": update_record.to_dict(),
|
||||
"validation_triggered": False
|
||||
}
|
||||
|
||||
# Trigger validation if requested
|
||||
if auto_trigger_validation:
|
||||
try:
|
||||
validation_result = await self.backfill_validation(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
triggered_by="sales_data_update",
|
||||
sales_data_update_id=update_record.id
|
||||
)
|
||||
|
||||
result["validation_triggered"] = True
|
||||
result["validation_result"] = validation_result
|
||||
|
||||
logger.info(
|
||||
"Validation triggered for sales data update",
|
||||
update_id=update_record.id,
|
||||
validation_run_id=validation_result.get("validation_run_id")
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to trigger validation for sales data update",
|
||||
update_id=update_record.id,
|
||||
error=str(e)
|
||||
)
|
||||
update_record.validation_status = "failed"
|
||||
update_record.validation_error = str(e)[:500]
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to register sales data update",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
await self.db.rollback()
|
||||
raise DatabaseError(f"Failed to register sales data update: {str(e)}")
|
||||
|
||||
async def _update_sales_data_record(
|
||||
self,
|
||||
sales_data_update_id: uuid.UUID,
|
||||
validation_run_id: Optional[uuid.UUID],
|
||||
status: str,
|
||||
error_message: Optional[str] = None
|
||||
):
|
||||
"""Update sales data update record with validation results"""
|
||||
try:
|
||||
query = select(SalesDataUpdate).where(SalesDataUpdate.id == sales_data_update_id)
|
||||
result = await self.db.execute(query)
|
||||
update_record = result.scalar_one_or_none()
|
||||
|
||||
if update_record:
|
||||
update_record.validation_status = status
|
||||
update_record.validation_run_id = validation_run_id
|
||||
update_record.validated_at = datetime.now(timezone.utc)
|
||||
if error_message:
|
||||
update_record.validation_error = error_message[:500]
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to update sales data record",
|
||||
sales_data_update_id=sales_data_update_id,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
async def get_pending_validations(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
limit: int = 50
|
||||
) -> List[SalesDataUpdate]:
|
||||
"""Get pending sales data updates that need validation"""
|
||||
try:
|
||||
query = (
|
||||
select(SalesDataUpdate)
|
||||
.where(
|
||||
and_(
|
||||
SalesDataUpdate.tenant_id == tenant_id,
|
||||
SalesDataUpdate.validation_status == "pending",
|
||||
SalesDataUpdate.requires_validation == True
|
||||
)
|
||||
)
|
||||
.order_by(SalesDataUpdate.created_at)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get pending validations",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to get pending validations: {str(e)}")
|
||||
240
services/forecasting/app/services/model_client.py
Normal file
240
services/forecasting/app/services/model_client.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# services/forecasting/app/services/model_client.py
|
||||
"""
|
||||
Forecast Service Model Client
|
||||
Demonstrates calling training service to get models
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
# Import shared clients - no more code duplication!
|
||||
from shared.clients import get_service_clients, get_training_client, get_sales_client
|
||||
from shared.database.base import create_database_manager
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class ModelClient:
|
||||
"""
|
||||
Client for managing models in forecasting service with dependency injection
|
||||
Shows how to call multiple services cleanly
|
||||
"""
|
||||
|
||||
def __init__(self, database_manager=None):
|
||||
self.database_manager = database_manager or create_database_manager(
|
||||
settings.DATABASE_URL, "forecasting-service"
|
||||
)
|
||||
|
||||
# Option 1: Get all clients at once
|
||||
self.clients = get_service_clients(settings, "forecasting")
|
||||
|
||||
# Option 2: Get specific clients
|
||||
# self.training_client = get_training_client(settings, "forecasting")
|
||||
# self.sales_client = get_sales_client(settings, "forecasting")
|
||||
|
||||
async def get_available_models(
|
||||
self,
|
||||
tenant_id: str,
|
||||
model_type: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get available trained models from training service
|
||||
"""
|
||||
try:
|
||||
models = await self.clients.training.list_models(
|
||||
tenant_id=tenant_id,
|
||||
status="deployed", # Only get deployed models
|
||||
model_type=model_type
|
||||
)
|
||||
|
||||
if models:
|
||||
logger.info(f"Found {len(models)} available models",
|
||||
tenant_id=tenant_id, model_type=model_type)
|
||||
return models
|
||||
else:
|
||||
logger.warning("No available models found", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching available models: {e}", tenant_id=tenant_id)
|
||||
return []
|
||||
|
||||
async def get_best_model_for_forecasting(
|
||||
self,
|
||||
tenant_id: str,
|
||||
inventory_product_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get the best model for forecasting based on performance metrics
|
||||
"""
|
||||
try:
|
||||
# Get latest model
|
||||
latest_model = await self.clients.training.get_active_model_for_product(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id
|
||||
)
|
||||
|
||||
if not latest_model:
|
||||
logger.warning("No trained models found", tenant_id=tenant_id)
|
||||
return None
|
||||
|
||||
# ✅ FIX 1: Use "model_id" instead of "id"
|
||||
model_id = latest_model.get("model_id")
|
||||
if not model_id:
|
||||
logger.error("Model response missing model_id field", tenant_id=tenant_id)
|
||||
return None
|
||||
|
||||
# ✅ FIX 2: Handle metrics endpoint failure gracefully
|
||||
try:
|
||||
# Get model metrics to validate quality
|
||||
metrics = await self.clients.training.get_model_metrics(
|
||||
tenant_id=tenant_id,
|
||||
model_id=model_id
|
||||
)
|
||||
|
||||
# If metrics call succeeded, check accuracy threshold
|
||||
if metrics and metrics.get("accuracy", 0) > 0.7: # 70% accuracy threshold
|
||||
logger.info(f"Selected model {model_id} with accuracy {metrics.get('accuracy')}",
|
||||
tenant_id=tenant_id)
|
||||
return latest_model
|
||||
elif metrics:
|
||||
logger.warning(f"Model accuracy too low: {metrics.get('accuracy', 'unknown')}",
|
||||
tenant_id=tenant_id)
|
||||
# Still return the model even if accuracy is low - better than no prediction
|
||||
logger.info("Returning model despite low accuracy - no alternative available",
|
||||
tenant_id=tenant_id)
|
||||
return latest_model
|
||||
else:
|
||||
logger.warning("No metrics returned from training service", tenant_id=tenant_id)
|
||||
# Return model anyway - metrics service might be temporarily down
|
||||
return latest_model
|
||||
|
||||
except Exception as metrics_error:
|
||||
# ✅ FIX 3: If metrics endpoint fails, still return the model
|
||||
logger.warning(f"Failed to get model metrics: {metrics_error}", tenant_id=tenant_id)
|
||||
logger.info("Proceeding with model despite metrics failure", tenant_id=tenant_id)
|
||||
return latest_model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error selecting best model: {e}", tenant_id=tenant_id)
|
||||
return None
|
||||
|
||||
async def get_any_model_for_tenant(
|
||||
self,
|
||||
tenant_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get any available model for a tenant, used as fallback when specific product models aren't found
|
||||
"""
|
||||
try:
|
||||
# First try to get any active models for this tenant
|
||||
models = await self.get_available_models(tenant_id)
|
||||
|
||||
if models:
|
||||
# Return the most recently trained model
|
||||
sorted_models = sorted(models, key=lambda x: x.get('created_at', ''), reverse=True)
|
||||
best_model = sorted_models[0]
|
||||
logger.info("Found fallback model for tenant",
|
||||
tenant_id=tenant_id,
|
||||
model_id=best_model.get('id', 'unknown'),
|
||||
inventory_product_id=best_model.get('inventory_product_id', 'unknown'))
|
||||
return best_model
|
||||
|
||||
logger.warning("No fallback models available for tenant", tenant_id=tenant_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting fallback model for tenant",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def validate_model_data_compatibility(
|
||||
self,
|
||||
tenant_id: str,
|
||||
model_id: str,
|
||||
forecast_start_date: str,
|
||||
forecast_end_date: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate that we have sufficient data for the model to make forecasts
|
||||
Demonstrates calling both training and data services
|
||||
"""
|
||||
try:
|
||||
# Get model details from training service
|
||||
model = await self.clients.training.get_model(
|
||||
tenant_id=tenant_id,
|
||||
model_id=model_id
|
||||
)
|
||||
|
||||
if not model:
|
||||
return {"is_valid": False, "error": "Model not found"}
|
||||
|
||||
# Get data statistics from data service
|
||||
data_stats = await self.clients.data.get_data_statistics(
|
||||
tenant_id=tenant_id,
|
||||
start_date=forecast_start_date,
|
||||
end_date=forecast_end_date
|
||||
)
|
||||
|
||||
if not data_stats:
|
||||
return {"is_valid": False, "error": "Could not retrieve data statistics"}
|
||||
|
||||
# Check if we have minimum required data points
|
||||
min_required = model.get("metadata", {}).get("min_data_points", 30)
|
||||
available_points = data_stats.get("total_records", 0)
|
||||
|
||||
is_valid = available_points >= min_required
|
||||
|
||||
result = {
|
||||
"is_valid": is_valid,
|
||||
"model_id": model_id,
|
||||
"required_points": min_required,
|
||||
"available_points": available_points,
|
||||
"data_coverage": data_stats.get("coverage_percentage", 0)
|
||||
}
|
||||
|
||||
if not is_valid:
|
||||
result["error"] = f"Insufficient data: need {min_required}, have {available_points}"
|
||||
|
||||
logger.info("Model data compatibility check completed",
|
||||
tenant_id=tenant_id, model_id=model_id, is_valid=is_valid)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating model compatibility: {e}",
|
||||
tenant_id=tenant_id, model_id=model_id)
|
||||
return {"is_valid": False, "error": str(e)}
|
||||
|
||||
async def trigger_model_retraining(
|
||||
self,
|
||||
tenant_id: str,
|
||||
include_weather: bool = True,
|
||||
include_traffic: bool = False
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Trigger a new training job if current model is outdated
|
||||
"""
|
||||
try:
|
||||
# Create training job through training service
|
||||
job = await self.clients.training.create_training_job(
|
||||
tenant_id=tenant_id,
|
||||
include_weather=include_weather,
|
||||
include_traffic=include_traffic,
|
||||
min_data_points=50 # Higher threshold for forecasting
|
||||
)
|
||||
|
||||
if job:
|
||||
logger.info(f"Training job created: {job['job_id']}", tenant_id=tenant_id)
|
||||
return job
|
||||
else:
|
||||
logger.error("Failed to create training job", tenant_id=tenant_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering model retraining: {e}", tenant_id=tenant_id)
|
||||
return None
|
||||
|
||||
# Global instance
|
||||
model_client = ModelClient()
|
||||
@@ -0,0 +1,435 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/performance_monitoring_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Performance Monitoring Service
|
||||
|
||||
Monitors forecast accuracy over time and triggers actions when
|
||||
performance degrades below acceptable thresholds.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, func, desc
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.models.validation_run import ValidationRun
|
||||
from app.models.predictions import ModelPerformanceMetric
|
||||
from app.models.forecasts import Forecast
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PerformanceMonitoringService:
|
||||
"""Service for monitoring forecast performance and triggering improvements"""
|
||||
|
||||
# Configurable thresholds
|
||||
MAPE_WARNING_THRESHOLD = 20.0 # Warning if MAPE > 20%
|
||||
MAPE_CRITICAL_THRESHOLD = 30.0 # Critical if MAPE > 30%
|
||||
MAPE_TREND_THRESHOLD = 5.0 # Alert if MAPE increases by > 5% over period
|
||||
MIN_SAMPLES_FOR_ALERT = 5 # Minimum validations before alerting
|
||||
TREND_LOOKBACK_DAYS = 30 # Days to analyze for trends
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db = db_session
|
||||
|
||||
async def get_accuracy_summary(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get accuracy summary for recent period
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
days: Number of days to analyze
|
||||
|
||||
Returns:
|
||||
Summary with overall metrics and trends
|
||||
"""
|
||||
try:
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
|
||||
# Get recent validation runs
|
||||
query = (
|
||||
select(ValidationRun)
|
||||
.where(
|
||||
and_(
|
||||
ValidationRun.tenant_id == tenant_id,
|
||||
ValidationRun.status == "completed",
|
||||
ValidationRun.started_at >= start_date,
|
||||
ValidationRun.forecasts_with_actuals > 0 # Only runs with actual data
|
||||
)
|
||||
)
|
||||
.order_by(desc(ValidationRun.started_at))
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
runs = result.scalars().all()
|
||||
|
||||
if not runs:
|
||||
return {
|
||||
"status": "no_data",
|
||||
"message": f"No validation runs found in last {days} days",
|
||||
"period_days": days
|
||||
}
|
||||
|
||||
# Calculate summary statistics
|
||||
total_forecasts = sum(r.total_forecasts_evaluated for r in runs)
|
||||
total_with_actuals = sum(r.forecasts_with_actuals for r in runs)
|
||||
|
||||
mape_values = [r.overall_mape for r in runs if r.overall_mape is not None]
|
||||
mae_values = [r.overall_mae for r in runs if r.overall_mae is not None]
|
||||
rmse_values = [r.overall_rmse for r in runs if r.overall_rmse is not None]
|
||||
|
||||
avg_mape = sum(mape_values) / len(mape_values) if mape_values else None
|
||||
avg_mae = sum(mae_values) / len(mae_values) if mae_values else None
|
||||
avg_rmse = sum(rmse_values) / len(rmse_values) if rmse_values else None
|
||||
|
||||
# Determine health status
|
||||
health_status = "healthy"
|
||||
if avg_mape and avg_mape > self.MAPE_CRITICAL_THRESHOLD:
|
||||
health_status = "critical"
|
||||
elif avg_mape and avg_mape > self.MAPE_WARNING_THRESHOLD:
|
||||
health_status = "warning"
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"period_days": days,
|
||||
"validation_runs": len(runs),
|
||||
"total_forecasts_evaluated": total_forecasts,
|
||||
"total_forecasts_with_actuals": total_with_actuals,
|
||||
"coverage_percentage": round(
|
||||
(total_with_actuals / total_forecasts * 100) if total_forecasts > 0 else 0, 2
|
||||
),
|
||||
"average_metrics": {
|
||||
"mape": round(avg_mape, 2) if avg_mape else None,
|
||||
"mae": round(avg_mae, 2) if avg_mae else None,
|
||||
"rmse": round(avg_rmse, 2) if avg_rmse else None,
|
||||
"accuracy_percentage": round(100 - avg_mape, 2) if avg_mape else None
|
||||
},
|
||||
"health_status": health_status,
|
||||
"thresholds": {
|
||||
"warning": self.MAPE_WARNING_THRESHOLD,
|
||||
"critical": self.MAPE_CRITICAL_THRESHOLD
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get accuracy summary",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to get accuracy summary: {str(e)}")
|
||||
|
||||
async def detect_performance_degradation(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
lookback_days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Detect if forecast performance is degrading over time
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
lookback_days: Days to analyze for trends
|
||||
|
||||
Returns:
|
||||
Degradation analysis with recommendations
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Detecting performance degradation",
|
||||
tenant_id=tenant_id,
|
||||
lookback_days=lookback_days
|
||||
)
|
||||
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=lookback_days)
|
||||
|
||||
# Get validation runs ordered by time
|
||||
query = (
|
||||
select(ValidationRun)
|
||||
.where(
|
||||
and_(
|
||||
ValidationRun.tenant_id == tenant_id,
|
||||
ValidationRun.status == "completed",
|
||||
ValidationRun.started_at >= start_date,
|
||||
ValidationRun.forecasts_with_actuals > 0
|
||||
)
|
||||
)
|
||||
.order_by(ValidationRun.started_at)
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
runs = list(result.scalars().all())
|
||||
|
||||
if len(runs) < self.MIN_SAMPLES_FOR_ALERT:
|
||||
return {
|
||||
"status": "insufficient_data",
|
||||
"message": f"Need at least {self.MIN_SAMPLES_FOR_ALERT} validation runs",
|
||||
"runs_found": len(runs)
|
||||
}
|
||||
|
||||
# Split into first half and second half
|
||||
midpoint = len(runs) // 2
|
||||
first_half = runs[:midpoint]
|
||||
second_half = runs[midpoint:]
|
||||
|
||||
# Calculate average MAPE for each half
|
||||
first_half_mape = sum(
|
||||
r.overall_mape for r in first_half if r.overall_mape
|
||||
) / len([r for r in first_half if r.overall_mape])
|
||||
|
||||
second_half_mape = sum(
|
||||
r.overall_mape for r in second_half if r.overall_mape
|
||||
) / len([r for r in second_half if r.overall_mape])
|
||||
|
||||
mape_change = second_half_mape - first_half_mape
|
||||
mape_change_percentage = (mape_change / first_half_mape * 100) if first_half_mape > 0 else 0
|
||||
|
||||
# Determine if degradation is significant
|
||||
is_degrading = mape_change > self.MAPE_TREND_THRESHOLD
|
||||
severity = "none"
|
||||
|
||||
if is_degrading:
|
||||
if mape_change > self.MAPE_TREND_THRESHOLD * 2:
|
||||
severity = "high"
|
||||
elif mape_change > self.MAPE_TREND_THRESHOLD:
|
||||
severity = "medium"
|
||||
|
||||
# Get products with worst performance
|
||||
poor_products = await self._identify_poor_performers(tenant_id, lookback_days)
|
||||
|
||||
result = {
|
||||
"status": "analyzed",
|
||||
"period_days": lookback_days,
|
||||
"samples_analyzed": len(runs),
|
||||
"is_degrading": is_degrading,
|
||||
"severity": severity,
|
||||
"metrics": {
|
||||
"first_period_mape": round(first_half_mape, 2),
|
||||
"second_period_mape": round(second_half_mape, 2),
|
||||
"mape_change": round(mape_change, 2),
|
||||
"mape_change_percentage": round(mape_change_percentage, 2)
|
||||
},
|
||||
"poor_performers": poor_products,
|
||||
"recommendations": []
|
||||
}
|
||||
|
||||
# Add recommendations
|
||||
if is_degrading:
|
||||
result["recommendations"].append({
|
||||
"action": "retrain_models",
|
||||
"priority": "high" if severity == "high" else "medium",
|
||||
"reason": f"MAPE increased by {abs(mape_change):.1f}% over {lookback_days} days"
|
||||
})
|
||||
|
||||
if poor_products:
|
||||
result["recommendations"].append({
|
||||
"action": "retrain_poor_performers",
|
||||
"priority": "high",
|
||||
"reason": f"{len(poor_products)} products with MAPE > {self.MAPE_CRITICAL_THRESHOLD}%",
|
||||
"products": poor_products[:10] # Top 10 worst
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Performance degradation analysis complete",
|
||||
tenant_id=tenant_id,
|
||||
is_degrading=is_degrading,
|
||||
severity=severity,
|
||||
poor_performers=len(poor_products)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to detect performance degradation",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to detect degradation: {str(e)}")
|
||||
|
||||
async def _identify_poor_performers(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
lookback_days: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Identify products/locations with poor accuracy"""
|
||||
try:
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=lookback_days)
|
||||
|
||||
# Get recent performance metrics grouped by product
|
||||
query = select(
|
||||
ModelPerformanceMetric.inventory_product_id,
|
||||
func.avg(ModelPerformanceMetric.mape).label('avg_mape'),
|
||||
func.avg(ModelPerformanceMetric.mae).label('avg_mae'),
|
||||
func.count(ModelPerformanceMetric.id).label('sample_count')
|
||||
).where(
|
||||
and_(
|
||||
ModelPerformanceMetric.tenant_id == tenant_id,
|
||||
ModelPerformanceMetric.created_at >= start_date,
|
||||
ModelPerformanceMetric.mape.isnot(None)
|
||||
)
|
||||
).group_by(
|
||||
ModelPerformanceMetric.inventory_product_id
|
||||
).having(
|
||||
func.avg(ModelPerformanceMetric.mape) > self.MAPE_CRITICAL_THRESHOLD
|
||||
).order_by(
|
||||
desc(func.avg(ModelPerformanceMetric.mape))
|
||||
).limit(20)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
poor_performers = []
|
||||
|
||||
for row in result.fetchall():
|
||||
poor_performers.append({
|
||||
"inventory_product_id": str(row.inventory_product_id),
|
||||
"avg_mape": round(row.avg_mape, 2),
|
||||
"avg_mae": round(row.avg_mae, 2),
|
||||
"sample_count": row.sample_count,
|
||||
"requires_retraining": True
|
||||
})
|
||||
|
||||
return poor_performers
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to identify poor performers", error=str(e))
|
||||
return []
|
||||
|
||||
async def check_model_age(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
max_age_days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Check if models are outdated and need retraining
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
max_age_days: Maximum acceptable model age
|
||||
|
||||
Returns:
|
||||
Analysis of model ages
|
||||
"""
|
||||
try:
|
||||
# Get distinct models used in recent forecasts
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=7)
|
||||
|
||||
query = select(
|
||||
Forecast.model_id,
|
||||
Forecast.model_version,
|
||||
Forecast.inventory_product_id,
|
||||
func.max(Forecast.created_at).label('last_used'),
|
||||
func.count(Forecast.id).label('forecast_count')
|
||||
).where(
|
||||
and_(
|
||||
Forecast.tenant_id == tenant_id,
|
||||
Forecast.created_at >= cutoff_date
|
||||
)
|
||||
).group_by(
|
||||
Forecast.model_id,
|
||||
Forecast.model_version,
|
||||
Forecast.inventory_product_id
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
|
||||
models_info = []
|
||||
outdated_count = 0
|
||||
|
||||
for row in result.fetchall():
|
||||
# Check age against training service (would need to query training service)
|
||||
# For now, assume models older than max_age_days need retraining
|
||||
models_info.append({
|
||||
"model_id": row.model_id,
|
||||
"model_version": row.model_version,
|
||||
"inventory_product_id": str(row.inventory_product_id),
|
||||
"last_used": row.last_used.isoformat(),
|
||||
"forecast_count": row.forecast_count
|
||||
})
|
||||
|
||||
return {
|
||||
"status": "analyzed",
|
||||
"models_in_use": len(models_info),
|
||||
"outdated_models": outdated_count,
|
||||
"max_age_days": max_age_days,
|
||||
"models": models_info[:20] # Top 20
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to check model age", error=str(e))
|
||||
raise DatabaseError(f"Failed to check model age: {str(e)}")
|
||||
|
||||
async def generate_performance_report(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate comprehensive performance report
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
days: Analysis period
|
||||
|
||||
Returns:
|
||||
Complete performance report with recommendations
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Generating performance report",
|
||||
tenant_id=tenant_id,
|
||||
days=days
|
||||
)
|
||||
|
||||
# Get all analyses
|
||||
summary = await self.get_accuracy_summary(tenant_id, days)
|
||||
degradation = await self.detect_performance_degradation(tenant_id, days)
|
||||
model_age = await self.check_model_age(tenant_id)
|
||||
|
||||
# Compile recommendations
|
||||
all_recommendations = []
|
||||
|
||||
if summary.get("health_status") == "critical":
|
||||
all_recommendations.append({
|
||||
"priority": "critical",
|
||||
"action": "immediate_review",
|
||||
"reason": f"Overall MAPE is {summary['average_metrics']['mape']}%",
|
||||
"details": "Forecast accuracy is critically low"
|
||||
})
|
||||
|
||||
all_recommendations.extend(degradation.get("recommendations", []))
|
||||
|
||||
report = {
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"tenant_id": str(tenant_id),
|
||||
"analysis_period_days": days,
|
||||
"summary": summary,
|
||||
"degradation_analysis": degradation,
|
||||
"model_age_analysis": model_age,
|
||||
"recommendations": all_recommendations,
|
||||
"requires_action": len(all_recommendations) > 0
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Performance report generated",
|
||||
tenant_id=tenant_id,
|
||||
health_status=summary.get("health_status"),
|
||||
recommendations_count=len(all_recommendations)
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to generate performance report",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to generate report: {str(e)}")
|
||||
99
services/forecasting/app/services/poi_feature_service.py
Normal file
99
services/forecasting/app/services/poi_feature_service.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
POI Feature Service for Forecasting
|
||||
|
||||
Fetches POI features for use in demand forecasting predictions.
|
||||
Ensures feature consistency between training and prediction.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
import structlog
|
||||
|
||||
from shared.clients.external_client import ExternalServiceClient
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class POIFeatureService:
|
||||
"""
|
||||
POI feature service for forecasting.
|
||||
|
||||
Fetches POI context from External service to ensure
|
||||
prediction uses the same features as training.
|
||||
"""
|
||||
|
||||
def __init__(self, external_client: ExternalServiceClient = None):
|
||||
"""
|
||||
Initialize POI feature service.
|
||||
|
||||
Args:
|
||||
external_client: External service client instance (optional)
|
||||
"""
|
||||
if external_client is None:
|
||||
from app.core.config import settings
|
||||
self.external_client = ExternalServiceClient(settings, "forecasting-service")
|
||||
else:
|
||||
self.external_client = external_client
|
||||
|
||||
async def get_poi_features(
|
||||
self,
|
||||
tenant_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get POI features for tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
|
||||
Returns:
|
||||
Dictionary with POI features or empty dict if not available
|
||||
"""
|
||||
try:
|
||||
result = await self.external_client.get_poi_context(tenant_id)
|
||||
|
||||
if result:
|
||||
poi_context = result.get("poi_context", {})
|
||||
ml_features = poi_context.get("ml_features", {})
|
||||
|
||||
logger.info(
|
||||
"POI features retrieved for forecasting",
|
||||
tenant_id=tenant_id,
|
||||
feature_count=len(ml_features)
|
||||
)
|
||||
|
||||
return ml_features
|
||||
else:
|
||||
logger.warning(
|
||||
"No POI context found for tenant",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to fetch POI features for forecasting",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
return {}
|
||||
|
||||
async def check_poi_service_health(self) -> bool:
|
||||
"""
|
||||
Check if POI service is accessible through the external client.
|
||||
|
||||
Returns:
|
||||
True if service is healthy, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Test the external service health by attempting to get POI context for a dummy tenant
|
||||
# This will go through the proper authentication and routing
|
||||
dummy_context = await self.external_client.get_poi_context("test-tenant")
|
||||
# If we can successfully make a request (even if it returns None for missing tenant),
|
||||
# it means the service is accessible
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"POI service health check failed",
|
||||
error=str(e)
|
||||
)
|
||||
return False
|
||||
1212
services/forecasting/app/services/prediction_service.py
Normal file
1212
services/forecasting/app/services/prediction_service.py
Normal file
File diff suppressed because it is too large
Load Diff
486
services/forecasting/app/services/retraining_trigger_service.py
Normal file
486
services/forecasting/app/services/retraining_trigger_service.py
Normal file
@@ -0,0 +1,486 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/retraining_trigger_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Retraining Trigger Service
|
||||
|
||||
Automatically triggers model retraining based on performance metrics,
|
||||
accuracy degradation, or data availability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.services.performance_monitoring_service import PerformanceMonitoringService
|
||||
from shared.clients.training_client import TrainingServiceClient
|
||||
from shared.config.base import BaseServiceSettings
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RetrainingTriggerService:
|
||||
"""Service for triggering automatic model retraining"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db = db_session
|
||||
self.performance_service = PerformanceMonitoringService(db_session)
|
||||
|
||||
# Initialize training client
|
||||
config = BaseServiceSettings()
|
||||
self.training_client = TrainingServiceClient(config, calling_service_name="forecasting")
|
||||
|
||||
async def evaluate_and_trigger_retraining(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
auto_trigger: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Evaluate performance and trigger retraining if needed
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
auto_trigger: Whether to automatically trigger retraining
|
||||
|
||||
Returns:
|
||||
Evaluation results and retraining actions taken
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Evaluating retraining needs",
|
||||
tenant_id=tenant_id,
|
||||
auto_trigger=auto_trigger
|
||||
)
|
||||
|
||||
# Generate performance report
|
||||
report = await self.performance_service.generate_performance_report(
|
||||
tenant_id=tenant_id,
|
||||
days=30
|
||||
)
|
||||
|
||||
if not report.get("requires_action"):
|
||||
logger.info(
|
||||
"No retraining required",
|
||||
tenant_id=tenant_id,
|
||||
health_status=report["summary"].get("health_status")
|
||||
)
|
||||
return {
|
||||
"status": "no_action_needed",
|
||||
"tenant_id": str(tenant_id),
|
||||
"health_status": report["summary"].get("health_status"),
|
||||
"report": report
|
||||
}
|
||||
|
||||
# Extract products that need retraining
|
||||
products_to_retrain = []
|
||||
recommendations = report.get("recommendations", [])
|
||||
|
||||
for rec in recommendations:
|
||||
if rec.get("action") == "retrain_poor_performers":
|
||||
products_to_retrain.extend(rec.get("products", []))
|
||||
|
||||
if not products_to_retrain and auto_trigger:
|
||||
# If degradation detected but no specific products, consider retraining all
|
||||
degradation = report.get("degradation_analysis", {})
|
||||
if degradation.get("is_degrading") and degradation.get("severity") in ["high", "medium"]:
|
||||
logger.info(
|
||||
"General degradation detected, considering full retraining",
|
||||
tenant_id=tenant_id,
|
||||
severity=degradation.get("severity")
|
||||
)
|
||||
|
||||
retraining_results = []
|
||||
|
||||
if auto_trigger and products_to_retrain:
|
||||
# Trigger retraining for poor performers
|
||||
for product in products_to_retrain:
|
||||
try:
|
||||
result = await self._trigger_product_retraining(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=uuid.UUID(product["inventory_product_id"]),
|
||||
reason=f"MAPE {product['avg_mape']}% exceeds threshold",
|
||||
priority="high"
|
||||
)
|
||||
retraining_results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to trigger retraining for product",
|
||||
product_id=product["inventory_product_id"],
|
||||
error=str(e)
|
||||
)
|
||||
retraining_results.append({
|
||||
"product_id": product["inventory_product_id"],
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Retraining evaluation complete",
|
||||
tenant_id=tenant_id,
|
||||
products_evaluated=len(products_to_retrain),
|
||||
retraining_triggered=len(retraining_results)
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "evaluated",
|
||||
"tenant_id": str(tenant_id),
|
||||
"requires_action": report.get("requires_action"),
|
||||
"products_needing_retraining": len(products_to_retrain),
|
||||
"retraining_triggered": len(retraining_results),
|
||||
"auto_trigger_enabled": auto_trigger,
|
||||
"retraining_results": retraining_results,
|
||||
"performance_report": report
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to evaluate and trigger retraining",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to evaluate retraining: {str(e)}")
|
||||
|
||||
async def _trigger_product_retraining(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
inventory_product_id: uuid.UUID,
|
||||
reason: str,
|
||||
priority: str = "normal"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Trigger retraining for a specific product
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product to retrain
|
||||
reason: Reason for retraining
|
||||
priority: Priority level (low, normal, high)
|
||||
|
||||
Returns:
|
||||
Retraining trigger result
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Triggering product retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id,
|
||||
reason=reason,
|
||||
priority=priority
|
||||
)
|
||||
|
||||
# Call training service to trigger retraining
|
||||
result = await self.training_client.trigger_retrain(
|
||||
tenant_id=str(tenant_id),
|
||||
inventory_product_id=str(inventory_product_id),
|
||||
reason=reason,
|
||||
priority=priority
|
||||
)
|
||||
|
||||
if result:
|
||||
logger.info(
|
||||
"Retraining triggered successfully",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id,
|
||||
training_job_id=result.get("training_job_id")
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "triggered",
|
||||
"product_id": str(inventory_product_id),
|
||||
"training_job_id": result.get("training_job_id"),
|
||||
"reason": reason,
|
||||
"priority": priority,
|
||||
"triggered_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
else:
|
||||
logger.warning(
|
||||
"Retraining trigger returned no result",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id
|
||||
)
|
||||
return {
|
||||
"status": "no_response",
|
||||
"product_id": str(inventory_product_id),
|
||||
"reason": reason
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to trigger product retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id,
|
||||
error=str(e)
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"product_id": str(inventory_product_id),
|
||||
"error": str(e),
|
||||
"reason": reason
|
||||
}
|
||||
|
||||
async def trigger_bulk_retraining(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
product_ids: List[uuid.UUID],
|
||||
reason: str = "Bulk retraining requested"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Trigger retraining for multiple products
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
product_ids: List of products to retrain
|
||||
reason: Reason for bulk retraining
|
||||
|
||||
Returns:
|
||||
Bulk retraining results
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Triggering bulk retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_count=len(product_ids)
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
for product_id in product_ids:
|
||||
result = await self._trigger_product_retraining(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=product_id,
|
||||
reason=reason,
|
||||
priority="normal"
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
successful = sum(1 for r in results if r["status"] == "triggered")
|
||||
|
||||
logger.info(
|
||||
"Bulk retraining completed",
|
||||
tenant_id=tenant_id,
|
||||
total=len(product_ids),
|
||||
successful=successful,
|
||||
failed=len(product_ids) - successful
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"tenant_id": str(tenant_id),
|
||||
"total_products": len(product_ids),
|
||||
"successful": successful,
|
||||
"failed": len(product_ids) - successful,
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Bulk retraining failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Bulk retraining failed: {str(e)}")
|
||||
|
||||
async def check_and_trigger_scheduled_retraining(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
max_model_age_days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Check model ages and trigger retraining for outdated models
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
max_model_age_days: Maximum acceptable model age
|
||||
|
||||
Returns:
|
||||
Scheduled retraining results
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Checking for scheduled retraining needs",
|
||||
tenant_id=tenant_id,
|
||||
max_model_age_days=max_model_age_days
|
||||
)
|
||||
|
||||
# Get model age analysis
|
||||
model_age_analysis = await self.performance_service.check_model_age(
|
||||
tenant_id=tenant_id,
|
||||
max_age_days=max_model_age_days
|
||||
)
|
||||
|
||||
outdated_count = model_age_analysis.get("outdated_models", 0)
|
||||
|
||||
if outdated_count == 0:
|
||||
logger.info(
|
||||
"No outdated models found",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return {
|
||||
"status": "no_action_needed",
|
||||
"tenant_id": str(tenant_id),
|
||||
"outdated_models": 0
|
||||
}
|
||||
|
||||
# Trigger retraining for outdated models
|
||||
try:
|
||||
from shared.clients.training_client import TrainingServiceClient
|
||||
from shared.config.base import get_settings
|
||||
from shared.messaging import get_rabbitmq_client
|
||||
|
||||
config = get_settings()
|
||||
training_client = TrainingServiceClient(config, "forecasting")
|
||||
|
||||
# Get list of models that need retraining
|
||||
outdated_models = await training_client.get_outdated_models(
|
||||
tenant_id=str(tenant_id),
|
||||
max_age_days=max_model_age_days,
|
||||
min_accuracy=0.85, # Configurable threshold
|
||||
min_new_data_points=1000 # Configurable threshold
|
||||
)
|
||||
|
||||
if not outdated_models:
|
||||
logger.info("No specific models returned for retraining", tenant_id=tenant_id)
|
||||
return {
|
||||
"status": "no_models_found",
|
||||
"tenant_id": str(tenant_id),
|
||||
"outdated_models": outdated_count
|
||||
}
|
||||
|
||||
# Publish retraining events to RabbitMQ for each model
|
||||
rabbitmq_client = get_rabbitmq_client()
|
||||
triggered_models = []
|
||||
|
||||
if rabbitmq_client:
|
||||
for model in outdated_models:
|
||||
try:
|
||||
import uuid as uuid_module
|
||||
from datetime import datetime
|
||||
|
||||
retraining_event = {
|
||||
"event_id": str(uuid_module.uuid4()),
|
||||
"event_type": "training.retrain.requested",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"tenant_id": str(tenant_id),
|
||||
"data": {
|
||||
"model_id": model.get('id'),
|
||||
"product_id": model.get('product_id'),
|
||||
"model_type": model.get('model_type'),
|
||||
"current_accuracy": model.get('accuracy'),
|
||||
"model_age_days": model.get('age_days'),
|
||||
"new_data_points": model.get('new_data_points', 0),
|
||||
"trigger_reason": model.get('trigger_reason', 'scheduled_check'),
|
||||
"priority": model.get('priority', 'normal'),
|
||||
"requested_by": "system_scheduled_check"
|
||||
}
|
||||
}
|
||||
|
||||
await rabbitmq_client.publish_event(
|
||||
exchange_name="training.events",
|
||||
routing_key="training.retrain.requested",
|
||||
event_data=retraining_event
|
||||
)
|
||||
|
||||
triggered_models.append({
|
||||
'model_id': model.get('id'),
|
||||
'product_id': model.get('product_id'),
|
||||
'event_id': retraining_event['event_id']
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Published retraining request",
|
||||
model_id=model.get('id'),
|
||||
product_id=model.get('product_id'),
|
||||
event_id=retraining_event['event_id'],
|
||||
trigger_reason=model.get('trigger_reason')
|
||||
)
|
||||
|
||||
except Exception as publish_error:
|
||||
logger.error(
|
||||
"Failed to publish retraining event",
|
||||
model_id=model.get('id'),
|
||||
error=str(publish_error)
|
||||
)
|
||||
# Continue with other models even if one fails
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
"RabbitMQ client not available, cannot trigger retraining",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "retraining_triggered",
|
||||
"tenant_id": str(tenant_id),
|
||||
"outdated_models": outdated_count,
|
||||
"triggered_count": len(triggered_models),
|
||||
"triggered_models": triggered_models,
|
||||
"message": f"Triggered retraining for {len(triggered_models)} models"
|
||||
}
|
||||
|
||||
except Exception as trigger_error:
|
||||
logger.error(
|
||||
"Failed to trigger retraining",
|
||||
tenant_id=tenant_id,
|
||||
error=str(trigger_error),
|
||||
exc_info=True
|
||||
)
|
||||
# Return analysis result even if triggering failed
|
||||
return {
|
||||
"status": "trigger_failed",
|
||||
"tenant_id": str(tenant_id),
|
||||
"outdated_models": outdated_count,
|
||||
"error": str(trigger_error),
|
||||
"message": "Analysis complete but failed to trigger retraining"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Scheduled retraining check failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Scheduled retraining check failed: {str(e)}")
|
||||
|
||||
async def get_retraining_recommendations(
|
||||
self,
|
||||
tenant_id: uuid.UUID
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get retraining recommendations without triggering
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
|
||||
Returns:
|
||||
Recommendations for manual review
|
||||
"""
|
||||
try:
|
||||
# Evaluate without auto-triggering
|
||||
result = await self.evaluate_and_trigger_retraining(
|
||||
tenant_id=tenant_id,
|
||||
auto_trigger=False
|
||||
)
|
||||
|
||||
# Extract just the recommendations
|
||||
report = result.get("performance_report", {})
|
||||
recommendations = report.get("recommendations", [])
|
||||
|
||||
return {
|
||||
"tenant_id": str(tenant_id),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"requires_action": result.get("requires_action", False),
|
||||
"recommendations": recommendations,
|
||||
"summary": report.get("summary", {}),
|
||||
"degradation_detected": report.get("degradation_analysis", {}).get("is_degrading", False)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get retraining recommendations",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to get recommendations: {str(e)}")
|
||||
97
services/forecasting/app/services/sales_client.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/sales_client.py
|
||||
# ================================================================
|
||||
"""
|
||||
Sales Client for Forecasting Service
|
||||
Wrapper around shared sales client with forecasting-specific methods
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from shared.clients.sales_client import SalesServiceClient
|
||||
from shared.config.base import BaseServiceSettings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SalesClient:
|
||||
"""Client for fetching sales data from sales service"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize sales client"""
|
||||
# Load configuration
|
||||
config = BaseServiceSettings()
|
||||
self.client = SalesServiceClient(config, calling_service_name="forecasting")
|
||||
|
||||
async def get_sales_by_date_range(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
product_id: Optional[uuid.UUID] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get sales data for a date range
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
start_date: Start of date range
|
||||
end_date: End of date range
|
||||
product_id: Optional product filter
|
||||
|
||||
Returns:
|
||||
List of sales records
|
||||
"""
|
||||
try:
|
||||
# Convert datetime to ISO format strings
|
||||
start_date_str = start_date.isoformat() if start_date else None
|
||||
end_date_str = end_date.isoformat() if end_date else None
|
||||
product_id_str = str(product_id) if product_id else None
|
||||
|
||||
# Use the paginated method to get all sales data
|
||||
sales_data = await self.client.get_all_sales_data(
|
||||
tenant_id=str(tenant_id),
|
||||
start_date=start_date_str,
|
||||
end_date=end_date_str,
|
||||
product_id=product_id_str,
|
||||
aggregation="none", # Get raw data without aggregation
|
||||
page_size=1000,
|
||||
max_pages=100
|
||||
)
|
||||
|
||||
if not sales_data:
|
||||
logger.info(
|
||||
"No sales data found for date range",
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date.isoformat(),
|
||||
end_date=end_date.isoformat()
|
||||
)
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
"Retrieved sales data",
|
||||
tenant_id=tenant_id,
|
||||
records_count=len(sales_data),
|
||||
start_date=start_date.isoformat(),
|
||||
end_date=end_date.isoformat()
|
||||
)
|
||||
|
||||
return sales_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to fetch sales data",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
# Return empty list instead of raising to allow validation to continue
|
||||
return []
|
||||
|
||||
async def close(self):
|
||||
"""Close the client connection"""
|
||||
if hasattr(self.client, 'close'):
|
||||
await self.client.close()
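# ----------------------------------------------------------------
# Illustrative usage sketch (not part of the committed file above): fetching a
# week of sales and totalling quantity per product. Field names follow the
# records consumed by validation_service below; the aggregation helper itself
# is hypothetical.
# ----------------------------------------------------------------
import uuid
from collections import defaultdict
from datetime import datetime, timedelta, timezone

from app.services.sales_client import SalesClient


async def weekly_sales_by_product(tenant_id: uuid.UUID) -> dict:
    """Return total quantity sold per product over the last 7 days."""
    client = SalesClient()
    try:
        end = datetime.now(timezone.utc)
        records = await client.get_sales_by_date_range(
            tenant_id=tenant_id,
            start_date=end - timedelta(days=7),
            end_date=end,
        )
        totals = defaultdict(float)
        for record in records:
            totals[str(record["inventory_product_id"])] += record["quantity_sold"]
        return dict(totals)
    finally:
        await client.close()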
|
||||
240
services/forecasting/app/services/tenant_deletion_service.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# services/forecasting/app/services/tenant_deletion_service.py
|
||||
"""
|
||||
Tenant Data Deletion Service for Forecasting Service
|
||||
Handles deletion of all forecasting-related data for a tenant
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
from sqlalchemy import select, func, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import structlog
|
||||
|
||||
from shared.services.tenant_deletion import (
|
||||
BaseTenantDataDeletionService,
|
||||
TenantDataDeletionResult
|
||||
)
|
||||
from app.models import (
|
||||
Forecast,
|
||||
PredictionBatch,
|
||||
ModelPerformanceMetric,
|
||||
PredictionCache,
|
||||
AuditLog
|
||||
)
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class ForecastingTenantDeletionService(BaseTenantDataDeletionService):
|
||||
"""Service for deleting all forecasting-related data for a tenant"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.service_name = "forecasting"
|
||||
|
||||
async def get_tenant_data_preview(self, tenant_id: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get counts of what would be deleted for a tenant (dry-run)
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID to preview deletion for
|
||||
|
||||
Returns:
|
||||
Dictionary with entity names and their counts
|
||||
"""
|
||||
logger.info("forecasting.tenant_deletion.preview", tenant_id=tenant_id)
|
||||
preview = {}
|
||||
|
||||
try:
|
||||
# Count forecasts
|
||||
forecast_count = await self.db.scalar(
|
||||
select(func.count(Forecast.id)).where(
|
||||
Forecast.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
preview["forecasts"] = forecast_count or 0
|
||||
|
||||
# Count prediction batches
|
||||
batch_count = await self.db.scalar(
|
||||
select(func.count(PredictionBatch.id)).where(
|
||||
PredictionBatch.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
preview["prediction_batches"] = batch_count or 0
|
||||
|
||||
# Count model performance metrics
|
||||
metric_count = await self.db.scalar(
|
||||
select(func.count(ModelPerformanceMetric.id)).where(
|
||||
ModelPerformanceMetric.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
preview["model_performance_metrics"] = metric_count or 0
|
||||
|
||||
# Count prediction cache entries
|
||||
cache_count = await self.db.scalar(
|
||||
select(func.count(PredictionCache.id)).where(
|
||||
PredictionCache.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
preview["prediction_cache"] = cache_count or 0
|
||||
|
||||
# Count audit logs
|
||||
audit_count = await self.db.scalar(
|
||||
select(func.count(AuditLog.id)).where(
|
||||
AuditLog.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
preview["audit_logs"] = audit_count or 0
|
||||
|
||||
logger.info(
|
||||
"forecasting.tenant_deletion.preview_complete",
|
||||
tenant_id=tenant_id,
|
||||
preview=preview
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"forecasting.tenant_deletion.preview_error",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
return preview
|
||||
|
||||
async def delete_tenant_data(self, tenant_id: str) -> TenantDataDeletionResult:
|
||||
"""
|
||||
Permanently delete all forecasting data for a tenant
|
||||
|
||||
Deletion order:
|
||||
1. PredictionCache (independent)
|
||||
2. ModelPerformanceMetric (independent)
|
||||
3. PredictionBatch (independent)
|
||||
4. Forecast (independent)
|
||||
5. AuditLog (independent)
|
||||
|
||||
Note: All tables are independent with no foreign key relationships
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID to delete data for
|
||||
|
||||
Returns:
|
||||
TenantDataDeletionResult with deletion counts and any errors
|
||||
"""
|
||||
logger.info("forecasting.tenant_deletion.started", tenant_id=tenant_id)
|
||||
result = TenantDataDeletionResult(tenant_id=tenant_id, service_name=self.service_name)
|
||||
|
||||
try:
|
||||
# Step 1: Delete prediction cache
|
||||
logger.info("forecasting.tenant_deletion.deleting_cache", tenant_id=tenant_id)
|
||||
cache_result = await self.db.execute(
|
||||
delete(PredictionCache).where(
|
||||
PredictionCache.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result.deleted_counts["prediction_cache"] = cache_result.rowcount
|
||||
logger.info(
|
||||
"forecasting.tenant_deletion.cache_deleted",
|
||||
tenant_id=tenant_id,
|
||||
count=cache_result.rowcount
|
||||
)
|
||||
|
||||
# Step 2: Delete model performance metrics
|
||||
logger.info("forecasting.tenant_deletion.deleting_metrics", tenant_id=tenant_id)
|
||||
metrics_result = await self.db.execute(
|
||||
delete(ModelPerformanceMetric).where(
|
||||
ModelPerformanceMetric.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result.deleted_counts["model_performance_metrics"] = metrics_result.rowcount
|
||||
logger.info(
|
||||
"forecasting.tenant_deletion.metrics_deleted",
|
||||
tenant_id=tenant_id,
|
||||
count=metrics_result.rowcount
|
||||
)
|
||||
|
||||
# Step 3: Delete prediction batches
|
||||
logger.info("forecasting.tenant_deletion.deleting_batches", tenant_id=tenant_id)
|
||||
batches_result = await self.db.execute(
|
||||
delete(PredictionBatch).where(
|
||||
PredictionBatch.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result.deleted_counts["prediction_batches"] = batches_result.rowcount
|
||||
logger.info(
|
||||
"forecasting.tenant_deletion.batches_deleted",
|
||||
tenant_id=tenant_id,
|
||||
count=batches_result.rowcount
|
||||
)
|
||||
|
||||
# Step 4: Delete forecasts
|
||||
logger.info("forecasting.tenant_deletion.deleting_forecasts", tenant_id=tenant_id)
|
||||
forecasts_result = await self.db.execute(
|
||||
delete(Forecast).where(
|
||||
Forecast.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result.deleted_counts["forecasts"] = forecasts_result.rowcount
|
||||
logger.info(
|
||||
"forecasting.tenant_deletion.forecasts_deleted",
|
||||
tenant_id=tenant_id,
|
||||
count=forecasts_result.rowcount
|
||||
)
|
||||
|
||||
# Step 5: Delete audit logs
|
||||
logger.info("forecasting.tenant_deletion.deleting_audit_logs", tenant_id=tenant_id)
|
||||
audit_result = await self.db.execute(
|
||||
delete(AuditLog).where(
|
||||
AuditLog.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result.deleted_counts["audit_logs"] = audit_result.rowcount
|
||||
logger.info(
|
||||
"forecasting.tenant_deletion.audit_logs_deleted",
|
||||
tenant_id=tenant_id,
|
||||
count=audit_result.rowcount
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
await self.db.commit()
|
||||
|
||||
# Calculate total deleted
|
||||
total_deleted = sum(result.deleted_counts.values())
|
||||
|
||||
logger.info(
|
||||
"forecasting.tenant_deletion.completed",
|
||||
tenant_id=tenant_id,
|
||||
total_deleted=total_deleted,
|
||||
breakdown=result.deleted_counts
|
||||
)
|
||||
|
||||
result.success = True
|
||||
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
error_msg = f"Failed to delete forecasting data for tenant {tenant_id}: {str(e)}"
|
||||
logger.error(
|
||||
"forecasting.tenant_deletion.failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
result.errors.append(error_msg)
|
||||
result.success = False
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_forecasting_tenant_deletion_service(
|
||||
db: AsyncSession
|
||||
) -> ForecastingTenantDeletionService:
|
||||
"""
|
||||
Factory function to create ForecastingTenantDeletionService instance
|
||||
|
||||
Args:
|
||||
db: AsyncSession database session
|
||||
|
||||
Returns:
|
||||
ForecastingTenantDeletionService instance
|
||||
"""
|
||||
return ForecastingTenantDeletionService(db)
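# ----------------------------------------------------------------
# Illustrative usage sketch (not part of the committed file above): previewing
# a tenant deletion before committing to it. The `purge_tenant` wrapper is
# hypothetical; the service API is as defined above.
# ----------------------------------------------------------------
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.tenant_deletion_service import get_forecasting_tenant_deletion_service


async def purge_tenant(db: AsyncSession, tenant_id: str, dry_run: bool = True) -> dict:
    """Report what would be deleted; only delete when dry_run is False."""
    service = get_forecasting_tenant_deletion_service(db)

    preview = await service.get_tenant_data_preview(tenant_id)
    if dry_run:
        return {"tenant_id": tenant_id, "dry_run": True, "would_delete": preview}

    result = await service.delete_tenant_data(tenant_id)
    return {
        "tenant_id": tenant_id,
        "dry_run": False,
        "success": result.success,
        "deleted_counts": result.deleted_counts,
        "errors": result.errors,
    }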
|
||||
586
services/forecasting/app/services/validation_service.py
Normal file
@@ -0,0 +1,586 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/validation_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Forecast Validation Service
|
||||
|
||||
Compares historical forecasts with actual sales data to:
|
||||
1. Calculate accuracy metrics (MAE, MAPE, RMSE, R², accuracy percentage)
|
||||
2. Store performance metrics in the database
|
||||
3. Track validation runs for audit purposes
|
||||
4. Enable continuous model improvement
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, func, Date
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
import math
|
||||
import uuid
|
||||
|
||||
from app.models.forecasts import Forecast
|
||||
from app.models.predictions import ModelPerformanceMetric
|
||||
from app.models.validation_run import ValidationRun
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ValidationService:
|
||||
"""Service for validating forecasts against actual sales data"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db = db_session
|
||||
|
||||
async def validate_date_range(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
orchestration_run_id: Optional[uuid.UUID] = None,
|
||||
triggered_by: str = "manual"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate forecasts against actual sales for a date range
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
start_date: Start of validation period
|
||||
end_date: End of validation period
|
||||
orchestration_run_id: Optional link to orchestration run
|
||||
triggered_by: How this validation was triggered (manual, orchestrator, scheduled)
|
||||
|
||||
Returns:
|
||||
Dictionary with validation results and metrics
|
||||
"""
|
||||
validation_run = None
|
||||
|
||||
try:
|
||||
# Create validation run record
|
||||
validation_run = ValidationRun(
|
||||
tenant_id=tenant_id,
|
||||
orchestration_run_id=orchestration_run_id,
|
||||
validation_start_date=start_date,
|
||||
validation_end_date=end_date,
|
||||
status="running",
|
||||
triggered_by=triggered_by,
|
||||
execution_mode="batch"
|
||||
)
|
||||
self.db.add(validation_run)
|
||||
await self.db.flush()
|
||||
|
||||
logger.info(
|
||||
"Starting forecast validation",
|
||||
validation_run_id=validation_run.id,
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date.isoformat(),
|
||||
end_date=end_date.isoformat()
|
||||
)
|
||||
|
||||
# Fetch forecasts with matching sales data
|
||||
forecasts_with_sales = await self._fetch_forecasts_with_sales(
|
||||
tenant_id, start_date, end_date
|
||||
)
|
||||
|
||||
if not forecasts_with_sales:
|
||||
logger.warning(
|
||||
"No forecasts with matching sales data found",
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date.isoformat(),
|
||||
end_date=end_date.isoformat()
|
||||
)
|
||||
validation_run.status = "completed"
|
||||
validation_run.completed_at = datetime.now(timezone.utc)
|
||||
validation_run.duration_seconds = (
|
||||
validation_run.completed_at - validation_run.started_at
|
||||
).total_seconds()
|
||||
await self.db.commit()
|
||||
|
||||
return {
|
||||
"validation_run_id": str(validation_run.id),
|
||||
"status": "completed",
|
||||
"message": "No forecasts with matching sales data found",
|
||||
"forecasts_evaluated": 0,
|
||||
"metrics_created": 0
|
||||
}
|
||||
|
||||
# Calculate metrics and create performance records
|
||||
metrics_results = await self._calculate_and_store_metrics(
|
||||
forecasts_with_sales, validation_run.id
|
||||
)
|
||||
|
||||
# Update validation run with results
|
||||
validation_run.total_forecasts_evaluated = metrics_results["total_evaluated"]
|
||||
validation_run.forecasts_with_actuals = metrics_results["with_actuals"]
|
||||
validation_run.forecasts_without_actuals = metrics_results["without_actuals"]
|
||||
validation_run.overall_mae = metrics_results["overall_mae"]
|
||||
validation_run.overall_mape = metrics_results["overall_mape"]
|
||||
validation_run.overall_rmse = metrics_results["overall_rmse"]
|
||||
validation_run.overall_r2_score = metrics_results["overall_r2_score"]
|
||||
validation_run.overall_accuracy_percentage = metrics_results["overall_accuracy_percentage"]
|
||||
validation_run.total_predicted_demand = metrics_results["total_predicted"]
|
||||
validation_run.total_actual_demand = metrics_results["total_actual"]
|
||||
validation_run.metrics_by_product = metrics_results["metrics_by_product"]
|
||||
validation_run.metrics_by_location = metrics_results["metrics_by_location"]
|
||||
validation_run.metrics_records_created = metrics_results["metrics_created"]
|
||||
validation_run.status = "completed"
|
||||
validation_run.completed_at = datetime.now(timezone.utc)
|
||||
validation_run.duration_seconds = (
|
||||
validation_run.completed_at - validation_run.started_at
|
||||
).total_seconds()
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
logger.info(
|
||||
"Forecast validation completed successfully",
|
||||
validation_run_id=validation_run.id,
|
||||
forecasts_evaluated=validation_run.total_forecasts_evaluated,
|
||||
metrics_created=validation_run.metrics_records_created,
|
||||
overall_mape=validation_run.overall_mape,
|
||||
duration_seconds=validation_run.duration_seconds
|
||||
)
|
||||
|
||||
# Extract poor accuracy products (MAPE > 30%)
|
||||
poor_accuracy_products = []
|
||||
if validation_run.metrics_by_product:
|
||||
for product_id, product_metrics in validation_run.metrics_by_product.items():
|
||||
if product_metrics.get("mape", 0) > 30:
|
||||
poor_accuracy_products.append({
|
||||
"product_id": product_id,
|
||||
"mape": product_metrics.get("mape"),
|
||||
"mae": product_metrics.get("mae"),
|
||||
"accuracy_percentage": product_metrics.get("accuracy_percentage")
|
||||
})
|
||||
|
||||
return {
|
||||
"validation_run_id": str(validation_run.id),
|
||||
"status": "completed",
|
||||
"forecasts_evaluated": validation_run.total_forecasts_evaluated,
|
||||
"forecasts_with_actuals": validation_run.forecasts_with_actuals,
|
||||
"forecasts_without_actuals": validation_run.forecasts_without_actuals,
|
||||
"metrics_created": validation_run.metrics_records_created,
|
||||
"overall_metrics": {
|
||||
"mae": validation_run.overall_mae,
|
||||
"mape": validation_run.overall_mape,
|
||||
"rmse": validation_run.overall_rmse,
|
||||
"r2_score": validation_run.overall_r2_score,
|
||||
"accuracy_percentage": validation_run.overall_accuracy_percentage
|
||||
},
|
||||
"total_predicted_demand": validation_run.total_predicted_demand,
|
||||
"total_actual_demand": validation_run.total_actual_demand,
|
||||
"duration_seconds": validation_run.duration_seconds,
|
||||
"poor_accuracy_products": poor_accuracy_products,
|
||||
"metrics_by_product": validation_run.metrics_by_product,
|
||||
"metrics_by_location": validation_run.metrics_by_location
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Forecast validation failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
|
||||
if validation_run:
|
||||
validation_run.status = "failed"
|
||||
validation_run.error_message = str(e)
|
||||
validation_run.error_details = {"error_type": type(e).__name__}
|
||||
validation_run.completed_at = datetime.now(timezone.utc)
|
||||
validation_run.duration_seconds = (
|
||||
validation_run.completed_at - validation_run.started_at
|
||||
).total_seconds()
|
||||
await self.db.commit()
|
||||
|
||||
raise DatabaseError(f"Forecast validation failed: {str(e)}")
|
||||
|
||||
async def validate_yesterday(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
orchestration_run_id: Optional[uuid.UUID] = None,
|
||||
triggered_by: str = "orchestrator"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convenience method to validate yesterday's forecasts
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
orchestration_run_id: Optional link to orchestration run
|
||||
triggered_by: How this validation was triggered
|
||||
|
||||
Returns:
|
||||
Dictionary with validation results
|
||||
"""
|
||||
yesterday = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
start_date = yesterday.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = yesterday.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
|
||||
return await self.validate_date_range(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
orchestration_run_id=orchestration_run_id,
|
||||
triggered_by=triggered_by
|
||||
)
|
||||
|
||||
async def _fetch_forecasts_with_sales(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
start_date: datetime,
|
||||
end_date: datetime
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch forecasts with their corresponding actual sales data
|
||||
|
||||
Returns list of dictionaries containing forecast and sales data
|
||||
"""
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from app.services.sales_client import SalesClient
|
||||
|
||||
# Query to get all forecasts in the date range
|
||||
query = select(Forecast).where(
|
||||
and_(
|
||||
Forecast.tenant_id == tenant_id,
|
||||
func.cast(Forecast.forecast_date, Date) >= start_date.date(),
|
||||
func.cast(Forecast.forecast_date, Date) <= end_date.date()
|
||||
)
|
||||
).order_by(Forecast.forecast_date, Forecast.inventory_product_id)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
forecasts = result.scalars().all()
|
||||
|
||||
if not forecasts:
|
||||
return []
|
||||
|
||||
# Fetch actual sales data from sales service
|
||||
sales_client = SalesClient()
|
||||
sales_data = await sales_client.get_sales_by_date_range(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
# Create lookup dict: (product_id, date) -> sales quantity
|
||||
sales_lookup = {}
|
||||
for sale in sales_data:
|
||||
sale_date = sale['date']
|
||||
if isinstance(sale_date, str):
|
||||
sale_date = datetime.fromisoformat(sale_date.replace('Z', '+00:00'))
|
||||
|
||||
key = (str(sale['inventory_product_id']), sale_date.date())
|
||||
|
||||
# Sum quantities if multiple sales records for same product/date
|
||||
if key in sales_lookup:
|
||||
sales_lookup[key]['quantity_sold'] += sale['quantity_sold']
|
||||
else:
|
||||
sales_lookup[key] = sale
|
||||
|
||||
# Match forecasts with sales data
|
||||
forecasts_with_sales = []
|
||||
for forecast in forecasts:
|
||||
forecast_date = forecast.forecast_date.date() if hasattr(forecast.forecast_date, 'date') else forecast.forecast_date
|
||||
key = (str(forecast.inventory_product_id), forecast_date)
|
||||
|
||||
sales_record = sales_lookup.get(key)
|
||||
|
||||
forecasts_with_sales.append({
|
||||
"forecast_id": forecast.id,
|
||||
"tenant_id": forecast.tenant_id,
|
||||
"inventory_product_id": forecast.inventory_product_id,
|
||||
"product_name": forecast.product_name,
|
||||
"location": forecast.location,
|
||||
"forecast_date": forecast.forecast_date,
|
||||
"predicted_demand": forecast.predicted_demand,
|
||||
"confidence_lower": forecast.confidence_lower,
|
||||
"confidence_upper": forecast.confidence_upper,
|
||||
"model_id": forecast.model_id,
|
||||
"model_version": forecast.model_version,
|
||||
"algorithm": forecast.algorithm,
|
||||
"created_at": forecast.created_at,
|
||||
"actual_sales": sales_record['quantity_sold'] if sales_record else None,
|
||||
"has_actual_data": sales_record is not None
|
||||
})
|
||||
|
||||
return forecasts_with_sales
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to fetch forecasts with sales data",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to fetch forecast and sales data: {str(e)}")
|
||||
|
||||
async def _calculate_and_store_metrics(
|
||||
self,
|
||||
forecasts_with_sales: List[Dict[str, Any]],
|
||||
validation_run_id: uuid.UUID
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate accuracy metrics and store them in the database
|
||||
|
||||
Returns summary of metrics calculated
|
||||
"""
|
||||
# Separate forecasts with and without actual data
|
||||
forecasts_with_actuals = [f for f in forecasts_with_sales if f["has_actual_data"]]
|
||||
forecasts_without_actuals = [f for f in forecasts_with_sales if not f["has_actual_data"]]
|
||||
|
||||
if not forecasts_with_actuals:
|
||||
return {
|
||||
"total_evaluated": len(forecasts_with_sales),
|
||||
"with_actuals": 0,
|
||||
"without_actuals": len(forecasts_without_actuals),
|
||||
"metrics_created": 0,
|
||||
"overall_mae": None,
|
||||
"overall_mape": None,
|
||||
"overall_rmse": None,
|
||||
"overall_r2_score": None,
|
||||
"overall_accuracy_percentage": None,
|
||||
"total_predicted": 0.0,
|
||||
"total_actual": 0.0,
|
||||
"metrics_by_product": {},
|
||||
"metrics_by_location": {}
|
||||
}
|
||||
|
||||
# Calculate individual metrics and prepare for bulk insert
|
||||
performance_metrics = []
|
||||
errors = []
|
||||
|
||||
for forecast in forecasts_with_actuals:
|
||||
predicted = forecast["predicted_demand"]
|
||||
actual = forecast["actual_sales"]
|
||||
|
||||
# Calculate error
|
||||
error = abs(predicted - actual)
|
||||
errors.append(error)
|
||||
|
||||
# Calculate percentage error (avoid division by zero)
|
||||
percentage_error = (error / actual * 100) if actual > 0 else 0.0
|
||||
|
||||
# Calculate individual metrics
|
||||
mae = error
|
||||
mape = percentage_error
|
||||
rmse = error # Will be squared and averaged later
|
||||
|
||||
# Calculate accuracy percentage (100% - MAPE, capped at 0)
|
||||
accuracy_percentage = max(0.0, 100.0 - mape)
|
||||
|
||||
# Create performance metric record
|
||||
metric = ModelPerformanceMetric(
|
||||
model_id=uuid.UUID(forecast["model_id"]) if isinstance(forecast["model_id"], str) else forecast["model_id"],
|
||||
tenant_id=forecast["tenant_id"],
|
||||
inventory_product_id=forecast["inventory_product_id"],
|
||||
mae=mae,
|
||||
mape=mape,
|
||||
rmse=rmse,
|
||||
accuracy_score=accuracy_percentage / 100.0, # Store as 0-1 scale
|
||||
evaluation_date=forecast["forecast_date"],
|
||||
evaluation_period_start=forecast["forecast_date"],
|
||||
evaluation_period_end=forecast["forecast_date"],
|
||||
sample_size=1
|
||||
)
|
||||
performance_metrics.append(metric)
|
||||
|
||||
# Bulk insert all performance metrics
|
||||
if performance_metrics:
|
||||
self.db.add_all(performance_metrics)
|
||||
await self.db.flush()
|
||||
|
||||
# Calculate overall metrics
|
||||
overall_metrics = self._calculate_overall_metrics(forecasts_with_actuals)
|
||||
|
||||
# Calculate metrics by product
|
||||
metrics_by_product = self._calculate_metrics_by_dimension(
|
||||
forecasts_with_actuals, "inventory_product_id"
|
||||
)
|
||||
|
||||
# Calculate metrics by location
|
||||
metrics_by_location = self._calculate_metrics_by_dimension(
|
||||
forecasts_with_actuals, "location"
|
||||
)
|
||||
|
||||
return {
|
||||
"total_evaluated": len(forecasts_with_sales),
|
||||
"with_actuals": len(forecasts_with_actuals),
|
||||
"without_actuals": len(forecasts_without_actuals),
|
||||
"metrics_created": len(performance_metrics),
|
||||
"overall_mae": overall_metrics["mae"],
|
||||
"overall_mape": overall_metrics["mape"],
|
||||
"overall_rmse": overall_metrics["rmse"],
|
||||
"overall_r2_score": overall_metrics["r2_score"],
|
||||
"overall_accuracy_percentage": overall_metrics["accuracy_percentage"],
|
||||
"total_predicted": overall_metrics["total_predicted"],
|
||||
"total_actual": overall_metrics["total_actual"],
|
||||
"metrics_by_product": metrics_by_product,
|
||||
"metrics_by_location": metrics_by_location
|
||||
}
|
||||
|
||||
def _calculate_overall_metrics(self, forecasts: List[Dict[str, Any]]) -> Dict[str, float]:
|
||||
"""Calculate aggregated metrics across all forecasts"""
|
||||
if not forecasts:
|
||||
return {
|
||||
"mae": None, "mape": None, "rmse": None,
|
||||
"r2_score": None, "accuracy_percentage": None,
|
||||
"total_predicted": 0.0, "total_actual": 0.0
|
||||
}
|
||||
|
||||
predicted_values = [f["predicted_demand"] for f in forecasts]
|
||||
actual_values = [f["actual_sales"] for f in forecasts]
|
||||
|
||||
# MAE: Mean Absolute Error
|
||||
mae = sum(abs(p - a) for p, a in zip(predicted_values, actual_values)) / len(forecasts)
|
||||
|
||||
# MAPE: Mean Absolute Percentage Error (handle division by zero)
|
||||
mape_values = [
|
||||
abs(p - a) / a * 100 if a > 0 else 0.0
|
||||
for p, a in zip(predicted_values, actual_values)
|
||||
]
|
||||
mape = sum(mape_values) / len(mape_values)
|
||||
|
||||
# RMSE: Root Mean Square Error
|
||||
squared_errors = [(p - a) ** 2 for p, a in zip(predicted_values, actual_values)]
|
||||
rmse = math.sqrt(sum(squared_errors) / len(squared_errors))
|
||||
|
||||
# R² Score (coefficient of determination)
|
||||
mean_actual = sum(actual_values) / len(actual_values)
|
||||
ss_total = sum((a - mean_actual) ** 2 for a in actual_values)
|
||||
ss_residual = sum((a - p) ** 2 for a, p in zip(actual_values, predicted_values))
|
||||
r2_score = 1 - (ss_residual / ss_total) if ss_total > 0 else 0.0
|
||||
|
||||
# Accuracy percentage (100% - MAPE)
|
||||
accuracy_percentage = max(0.0, 100.0 - mape)
|
||||
|
||||
return {
|
||||
"mae": round(mae, 2),
|
||||
"mape": round(mape, 2),
|
||||
"rmse": round(rmse, 2),
|
||||
"r2_score": round(r2_score, 4),
|
||||
"accuracy_percentage": round(accuracy_percentage, 2),
|
||||
"total_predicted": round(sum(predicted_values), 2),
|
||||
"total_actual": round(sum(actual_values), 2)
|
||||
}
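# Worked example of the formulas above (illustrative numbers only):
#   predicted = [100, 80], actual = [90, 100]
#   MAE  = (|100-90| + |80-100|) / 2                  = (10 + 20) / 2 = 15.0
#   MAPE = (10/90*100 + 20/100*100) / 2               = (11.11 + 20.00) / 2 = 15.56
#   RMSE = sqrt((10^2 + 20^2) / 2)                    = sqrt(250) = 15.81
#   R^2  = 1 - SS_res / SS_tot  with mean_actual = 95
#        = 1 - ((90-100)^2 + (100-80)^2) / ((90-95)^2 + (100-95)^2)
#        = 1 - 500 / 50 = -9.0   (negative R^2 = worse than predicting the mean)
#   accuracy_percentage = max(0, 100 - MAPE)          = 84.44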
|
||||
|
||||
def _calculate_metrics_by_dimension(
|
||||
self,
|
||||
forecasts: List[Dict[str, Any]],
|
||||
dimension_key: str
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""Calculate metrics grouped by a dimension (product_id or location)"""
|
||||
dimension_groups = {}
|
||||
|
||||
# Group forecasts by dimension
|
||||
for forecast in forecasts:
|
||||
key = str(forecast[dimension_key])
|
||||
if key not in dimension_groups:
|
||||
dimension_groups[key] = []
|
||||
dimension_groups[key].append(forecast)
|
||||
|
||||
# Calculate metrics for each group
|
||||
metrics_by_dimension = {}
|
||||
for key, group_forecasts in dimension_groups.items():
|
||||
metrics = self._calculate_overall_metrics(group_forecasts)
|
||||
metrics_by_dimension[key] = {
|
||||
"count": len(group_forecasts),
|
||||
"mae": metrics["mae"],
|
||||
"mape": metrics["mape"],
|
||||
"rmse": metrics["rmse"],
|
||||
"accuracy_percentage": metrics["accuracy_percentage"],
|
||||
"total_predicted": metrics["total_predicted"],
|
||||
"total_actual": metrics["total_actual"]
|
||||
}
|
||||
|
||||
return metrics_by_dimension
|
||||
|
||||
async def get_validation_run(self, validation_run_id: uuid.UUID) -> Optional[ValidationRun]:
|
||||
"""Get a validation run by ID"""
|
||||
try:
|
||||
query = select(ValidationRun).where(ValidationRun.id == validation_run_id)
|
||||
result = await self.db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Failed to get validation run", validation_run_id=validation_run_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to get validation run: {str(e)}")
|
||||
|
||||
async def get_validation_runs_by_tenant(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
limit: int = 50,
|
||||
skip: int = 0
|
||||
) -> List[ValidationRun]:
|
||||
"""Get validation runs for a tenant"""
|
||||
try:
|
||||
query = (
|
||||
select(ValidationRun)
|
||||
.where(ValidationRun.tenant_id == tenant_id)
|
||||
.order_by(ValidationRun.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(skip)
|
||||
)
|
||||
result = await self.db.execute(query)
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error("Failed to get validation runs", tenant_id=tenant_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to get validation runs: {str(e)}")
|
||||
|
||||
async def get_accuracy_trends(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get accuracy trends over time"""
|
||||
try:
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
|
||||
query = (
|
||||
select(ValidationRun)
|
||||
.where(
|
||||
and_(
|
||||
ValidationRun.tenant_id == tenant_id,
|
||||
ValidationRun.status == "completed",
|
||||
ValidationRun.created_at >= start_date
|
||||
)
|
||||
)
|
||||
.order_by(ValidationRun.created_at)
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
runs = result.scalars().all()
|
||||
|
||||
if not runs:
|
||||
return {
|
||||
"period_days": days,
|
||||
"total_runs": 0,
|
||||
"trends": []
|
||||
}
|
||||
|
||||
trends = [
|
||||
{
|
||||
"date": run.validation_start_date.isoformat(),
|
||||
"mae": run.overall_mae,
|
||||
"mape": run.overall_mape,
|
||||
"rmse": run.overall_rmse,
|
||||
"accuracy_percentage": run.overall_accuracy_percentage,
|
||||
"forecasts_evaluated": run.total_forecasts_evaluated,
|
||||
"forecasts_with_actuals": run.forecasts_with_actuals
|
||||
}
|
||||
for run in runs
|
||||
]
|
||||
|
||||
# Calculate averages
|
||||
valid_runs = [r for r in runs if r.overall_mape is not None]
|
||||
avg_mape = sum(r.overall_mape for r in valid_runs) / len(valid_runs) if valid_runs else None
|
||||
avg_accuracy = sum(r.overall_accuracy_percentage for r in valid_runs) / len(valid_runs) if valid_runs else None
|
||||
|
||||
return {
|
||||
"period_days": days,
|
||||
"total_runs": len(runs),
|
||||
"average_mape": round(avg_mape, 2) if avg_mape else None,
|
||||
"average_accuracy": round(avg_accuracy, 2) if avg_accuracy else None,
|
||||
"trends": trends
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get accuracy trends", tenant_id=tenant_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to get accuracy trends: {str(e)}")
|
||||
3
services/forecasting/app/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Utility modules for forecasting service
|
||||
"""
|
||||
258
services/forecasting/app/utils/distributed_lock.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Distributed Locking Mechanisms for Forecasting Service
|
||||
Prevents concurrent forecast generation for the same product/date
|
||||
"""
|
||||
|
||||
import asyncio
import hashlib
|
||||
import time
|
||||
from typing import Optional
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LockAcquisitionError(Exception):
|
||||
"""Raised when lock cannot be acquired"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseLock:
|
||||
"""
|
||||
Database-based distributed lock using PostgreSQL advisory locks.
|
||||
Works across multiple service instances.
|
||||
"""
|
||||
|
||||
def __init__(self, lock_name: str, timeout: float = 30.0):
|
||||
"""
|
||||
Initialize database lock.
|
||||
|
||||
Args:
|
||||
lock_name: Unique identifier for the lock
|
||||
timeout: Maximum seconds to wait for lock acquisition
|
||||
"""
|
||||
self.lock_name = lock_name
|
||||
self.timeout = timeout
|
||||
self.lock_id = self._hash_lock_name(lock_name)
|
||||
|
||||
def _hash_lock_name(self, name: str) -> int:
"""Convert lock name to a stable integer ID for PostgreSQL advisory locks"""
# Python's built-in hash() is salted per process, so every service instance
# would derive a different lock ID and the "distributed" lock would never be
# shared. A deterministic digest keeps the ID identical across instances;
# the modulo keeps it inside the positive 32-bit range advisory locks expect.
digest = hashlib.sha256(name.encode("utf-8")).hexdigest()
return int(digest, 16) % (2**31)
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire(self, session: AsyncSession):
|
||||
"""
|
||||
Acquire distributed lock as async context manager.
|
||||
|
||||
Args:
|
||||
session: Database session for lock operations
|
||||
|
||||
Raises:
|
||||
LockAcquisitionError: If lock cannot be acquired within timeout
|
||||
"""
|
||||
acquired = False
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Try to acquire lock with timeout
|
||||
while time.time() - start_time < self.timeout:
|
||||
# Try non-blocking lock acquisition
|
||||
result = await session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:lock_id)"),
|
||||
{"lock_id": self.lock_id}
|
||||
)
|
||||
acquired = result.scalar()
|
||||
|
||||
if acquired:
|
||||
logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
|
||||
break
|
||||
|
||||
# Wait a bit before retrying
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if not acquired:
|
||||
raise LockAcquisitionError(
|
||||
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
if acquired:
|
||||
# Release lock
|
||||
await session.execute(
|
||||
text("SELECT pg_advisory_unlock(:lock_id)"),
|
||||
{"lock_id": self.lock_id}
|
||||
)
|
||||
logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")
|
||||
|
||||
|
||||
class SimpleDatabaseLock:
|
||||
"""
|
||||
Simple table-based distributed lock.
|
||||
Alternative to advisory locks, uses a dedicated locks table.
|
||||
"""
|
||||
|
||||
def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
|
||||
"""
|
||||
Initialize simple database lock.
|
||||
|
||||
Args:
|
||||
lock_name: Unique identifier for the lock
|
||||
timeout: Maximum seconds to wait for lock acquisition
|
||||
ttl: Time-to-live for stale lock cleanup (seconds)
|
||||
"""
|
||||
self.lock_name = lock_name
|
||||
self.timeout = timeout
|
||||
self.ttl = ttl
|
||||
|
||||
async def _ensure_lock_table(self, session: AsyncSession):
|
||||
"""Ensure locks table exists"""
|
||||
create_table_sql = """
|
||||
CREATE TABLE IF NOT EXISTS distributed_locks (
|
||||
lock_name VARCHAR(255) PRIMARY KEY,
|
||||
acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
acquired_by VARCHAR(255),
|
||||
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
|
||||
)
|
||||
"""
|
||||
await session.execute(text(create_table_sql))
|
||||
await session.commit()
|
||||
|
||||
async def _cleanup_stale_locks(self, session: AsyncSession):
|
||||
"""Remove expired locks"""
|
||||
cleanup_sql = """
|
||||
DELETE FROM distributed_locks
|
||||
WHERE expires_at < :now
|
||||
"""
|
||||
await session.execute(
|
||||
text(cleanup_sql),
|
||||
{"now": datetime.now(timezone.utc)}
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
@asynccontextmanager
|
||||
async def acquire(self, session: AsyncSession, owner: str = "forecasting-service"):
|
||||
"""
|
||||
Acquire simple database lock.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
owner: Identifier for lock owner
|
||||
|
||||
Raises:
|
||||
LockAcquisitionError: If lock cannot be acquired
|
||||
"""
|
||||
await self._ensure_lock_table(session)
|
||||
await self._cleanup_stale_locks(session)
|
||||
|
||||
acquired = False
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Try to acquire lock
|
||||
while time.time() - start_time < self.timeout:
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=self.ttl)
|
||||
|
||||
try:
|
||||
# Try to insert lock record
|
||||
insert_sql = """
|
||||
INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
|
||||
VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
|
||||
ON CONFLICT (lock_name) DO NOTHING
|
||||
RETURNING lock_name
|
||||
"""
|
||||
|
||||
result = await session.execute(
|
||||
text(insert_sql),
|
||||
{
|
||||
"lock_name": self.lock_name,
|
||||
"acquired_at": now,
|
||||
"acquired_by": owner,
|
||||
"expires_at": expires_at
|
||||
}
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
if result.rowcount > 0:
|
||||
acquired = True
|
||||
logger.info(f"Acquired simple lock: {self.lock_name}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Lock acquisition attempt failed: {e}")
|
||||
await session.rollback()
|
||||
|
||||
# Wait before retrying
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if not acquired:
|
||||
raise LockAcquisitionError(
|
||||
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
if acquired:
|
||||
# Release lock
|
||||
delete_sql = """
|
||||
DELETE FROM distributed_locks
|
||||
WHERE lock_name = :lock_name
|
||||
"""
|
||||
await session.execute(
|
||||
text(delete_sql),
|
||||
{"lock_name": self.lock_name}
|
||||
)
|
||||
await session.commit()
|
||||
logger.info(f"Released simple lock: {self.lock_name}")
|
||||
|
||||
|
||||
def get_forecast_lock(
|
||||
tenant_id: str,
|
||||
product_id: str,
|
||||
forecast_date: str,
|
||||
use_advisory: bool = True
|
||||
) -> DatabaseLock:
|
||||
"""
|
||||
Get distributed lock for generating a forecast for a specific product and date.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
product_id: Product identifier
|
||||
forecast_date: Forecast date (ISO format)
|
||||
use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)
|
||||
|
||||
Returns:
|
||||
Lock instance
|
||||
"""
|
||||
lock_name = f"forecast:{tenant_id}:{product_id}:{forecast_date}"
|
||||
|
||||
if use_advisory:
|
||||
return DatabaseLock(lock_name, timeout=30.0)
|
||||
else:
|
||||
return SimpleDatabaseLock(lock_name, timeout=30.0, ttl=300.0)
|
||||
|
||||
|
||||
def get_batch_forecast_lock(tenant_id: str, use_advisory: bool = True) -> DatabaseLock:
|
||||
"""
|
||||
Get distributed lock for batch forecast generation for a tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)
|
||||
|
||||
Returns:
|
||||
Lock instance
|
||||
"""
|
||||
lock_name = f"forecast_batch:{tenant_id}"
|
||||
|
||||
if use_advisory:
|
||||
return DatabaseLock(lock_name, timeout=60.0)
|
||||
else:
|
||||
return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)
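# ----------------------------------------------------------------
# Illustrative usage sketch (not part of the committed file above): guarding a
# single-product forecast run with the advisory lock. `generate_forecast` is a
# placeholder for whatever prediction routine the caller already has.
# ----------------------------------------------------------------
from sqlalchemy.ext.asyncio import AsyncSession


async def generate_with_lock(
    session: AsyncSession,
    tenant_id: str,
    product_id: str,
    forecast_date: str,
    generate_forecast,  # placeholder async callable supplied by the caller
):
    """Run the forecast only if no other instance holds the lock for this product/date."""
    lock = get_forecast_lock(tenant_id, product_id, forecast_date, use_advisory=True)
    try:
        async with lock.acquire(session):
            # Critical section: only one instance generates this product/date
            return await generate_forecast(session, tenant_id, product_id, forecast_date)
    except LockAcquisitionError:
        logger.warning(
            "Forecast for %s/%s/%s is already being generated elsewhere; skipping",
            tenant_id, product_id, forecast_date,
        )
        return None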
|
||||