Initial commit - production deployment

2026-01-21 17:17:16 +01:00
commit c23d00dd92
2289 changed files with 638440 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
"""
Forecasting API Layer
HTTP endpoints for demand forecasting and prediction operations
"""
from .forecasts import router as forecasts_router
from .forecasting_operations import router as forecasting_operations_router
from .analytics import router as analytics_router
from .validation import router as validation_router
from .historical_validation import router as historical_validation_router
from .webhooks import router as webhooks_router
from .performance_monitoring import router as performance_monitoring_router
from .retraining import router as retraining_router
from .enterprise_forecasting import router as enterprise_forecasting_router
__all__ = [
"forecasts_router",
"forecasting_operations_router",
"analytics_router",
"validation_router",
"historical_validation_router",
"webhooks_router",
"performance_monitoring_router",
"retraining_router",
"enterprise_forecasting_router",
]
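
For context, these routers are normally mounted on the service's FastAPI application. A minimal sketch follows, assuming the package path app.api and that RouteBuilder already bakes any tenant prefix into each route (both assumptions, not confirmed by this commit).

from fastapi import FastAPI
from app.api import forecasts_router, analytics_router  # assumed import path

app = FastAPI(title="forecasting-service")
app.include_router(forecasts_router)
app.include_router(analytics_router)
# ...the remaining routers listed in __all__ are included the same way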

View File

@@ -0,0 +1,55 @@
# services/forecasting/app/api/analytics.py
"""
Forecasting Analytics API - Reporting, statistics, and insights
"""
import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path
from datetime import date
from typing import Optional
from app.services.prediction_service import PredictionService
from shared.database.base import create_database_manager
from app.core.config import settings
from shared.routing import RouteBuilder
from shared.auth.access_control import analytics_tier_required
route_builder = RouteBuilder('forecasting')
logger = structlog.get_logger()
router = APIRouter(tags=["forecasting-analytics"])
def get_enhanced_prediction_service():
"""Dependency injection for enhanced PredictionService"""
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
return PredictionService(database_manager)
@router.get(
route_builder.build_analytics_route("predictions-performance")
)
@analytics_tier_required
async def get_predictions_performance(
tenant_id: str = Path(..., description="Tenant ID"),
start_date: Optional[date] = Query(None),
end_date: Optional[date] = Query(None),
prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
"""Get predictions performance analytics (Professional+ tier required)"""
try:
logger.info("Getting predictions performance", tenant_id=tenant_id)
performance = await prediction_service.get_performance_metrics(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date
)
return performance
except Exception as e:
logger.error("Failed to get predictions performance", error=str(e), tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve predictions performance"
)
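
As a usage reference, a client call against the predictions-performance endpoint might look like the sketch below. The exact URL depends on how RouteBuilder('forecasting') builds analytics routes, so the path shown here is an assumption.

import httpx

async def fetch_predictions_performance(base_url: str, token: str, tenant_id: str) -> dict:
    # Assumed path layout; the real route comes from route_builder.build_analytics_route(...)
    url = f"{base_url}/api/v1/tenants/{tenant_id}/forecasting/analytics/predictions-performance"
    async with httpx.AsyncClient() as client:
        resp = await client.get(
            url,
            params={"start_date": "2025-01-01", "end_date": "2025-01-31"},
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        return resp.json()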

View File

@@ -0,0 +1,237 @@
# services/forecasting/app/api/audit.py
"""
Audit Logs API - Retrieve audit trail for forecasting service
"""
from fastapi import APIRouter, Depends, HTTPException, Query, Path, status
from typing import Optional, Dict, Any
from uuid import UUID
from datetime import datetime
import structlog
from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import AuditLog
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role
from shared.routing import RouteBuilder
from shared.models.audit_log_schemas import (
AuditLogResponse,
AuditLogListResponse,
AuditLogStatsResponse
)
from app.core.database import database_manager
route_builder = RouteBuilder('forecasting')
router = APIRouter(tags=["audit-logs"])
logger = structlog.get_logger()
async def get_db():
"""Database session dependency"""
async with database_manager.get_session() as session:
yield session
@router.get(
route_builder.build_base_route("audit-logs"),
response_model=AuditLogListResponse
)
@require_user_role(['admin', 'owner'])
async def get_audit_logs(
tenant_id: UUID = Path(..., description="Tenant ID"),
start_date: Optional[datetime] = Query(None, description="Filter logs from this date"),
end_date: Optional[datetime] = Query(None, description="Filter logs until this date"),
user_id: Optional[UUID] = Query(None, description="Filter by user ID"),
action: Optional[str] = Query(None, description="Filter by action type"),
resource_type: Optional[str] = Query(None, description="Filter by resource type"),
severity: Optional[str] = Query(None, description="Filter by severity level"),
search: Optional[str] = Query(None, description="Search in description field"),
limit: int = Query(100, ge=1, le=1000, description="Number of records to return"),
offset: int = Query(0, ge=0, description="Number of records to skip"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get audit logs for forecasting service.
Requires admin or owner role.
"""
try:
logger.info(
"Retrieving audit logs",
tenant_id=tenant_id,
user_id=current_user.get("user_id"),
filters={
"start_date": start_date,
"end_date": end_date,
"action": action,
"resource_type": resource_type,
"severity": severity
}
)
# Build query filters
filters = [AuditLog.tenant_id == tenant_id]
if start_date:
filters.append(AuditLog.created_at >= start_date)
if end_date:
filters.append(AuditLog.created_at <= end_date)
if user_id:
filters.append(AuditLog.user_id == user_id)
if action:
filters.append(AuditLog.action == action)
if resource_type:
filters.append(AuditLog.resource_type == resource_type)
if severity:
filters.append(AuditLog.severity == severity)
if search:
filters.append(AuditLog.description.ilike(f"%{search}%"))
# Count total matching records
count_query = select(func.count()).select_from(AuditLog).where(and_(*filters))
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# Fetch paginated results
query = (
select(AuditLog)
.where(and_(*filters))
.order_by(AuditLog.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await db.execute(query)
audit_logs = result.scalars().all()
# Convert to response models
items = [AuditLogResponse.from_orm(log) for log in audit_logs]
logger.info(
"Successfully retrieved audit logs",
tenant_id=tenant_id,
total=total,
returned=len(items)
)
return AuditLogListResponse(
items=items,
total=total,
limit=limit,
offset=offset,
has_more=(offset + len(items)) < total
)
except Exception as e:
logger.error(
"Failed to retrieve audit logs",
error=str(e),
tenant_id=tenant_id
)
raise HTTPException(
status_code=500,
detail=f"Failed to retrieve audit logs: {str(e)}"
)
@router.get(
route_builder.build_base_route("audit-logs/stats"),
response_model=AuditLogStatsResponse
)
@require_user_role(['admin', 'owner'])
async def get_audit_log_stats(
tenant_id: UUID = Path(..., description="Tenant ID"),
start_date: Optional[datetime] = Query(None, description="Filter logs from this date"),
end_date: Optional[datetime] = Query(None, description="Filter logs until this date"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get audit log statistics for forecasting service.
Requires admin or owner role.
"""
try:
logger.info(
"Retrieving audit log statistics",
tenant_id=tenant_id,
user_id=current_user.get("user_id")
)
# Build base filters
filters = [AuditLog.tenant_id == tenant_id]
if start_date:
filters.append(AuditLog.created_at >= start_date)
if end_date:
filters.append(AuditLog.created_at <= end_date)
# Total events
count_query = select(func.count()).select_from(AuditLog).where(and_(*filters))
total_result = await db.execute(count_query)
total_events = total_result.scalar() or 0
# Events by action
action_query = (
select(AuditLog.action, func.count().label('count'))
.where(and_(*filters))
.group_by(AuditLog.action)
)
action_result = await db.execute(action_query)
events_by_action = {row.action: row.count for row in action_result}
# Events by severity
severity_query = (
select(AuditLog.severity, func.count().label('count'))
.where(and_(*filters))
.group_by(AuditLog.severity)
)
severity_result = await db.execute(severity_query)
events_by_severity = {row.severity: row.count for row in severity_result}
# Events by resource type
resource_query = (
select(AuditLog.resource_type, func.count().label('count'))
.where(and_(*filters))
.group_by(AuditLog.resource_type)
)
resource_result = await db.execute(resource_query)
events_by_resource_type = {row.resource_type: row.count for row in resource_result}
# Date range
date_range_query = (
select(
func.min(AuditLog.created_at).label('min_date'),
func.max(AuditLog.created_at).label('max_date')
)
.where(and_(*filters))
)
date_result = await db.execute(date_range_query)
date_row = date_result.one()
logger.info(
"Successfully retrieved audit log statistics",
tenant_id=tenant_id,
total_events=total_events
)
return AuditLogStatsResponse(
total_events=total_events,
events_by_action=events_by_action,
events_by_severity=events_by_severity,
events_by_resource_type=events_by_resource_type,
date_range={
"min": date_row.min_date,
"max": date_row.max_date
}
)
except Exception as e:
logger.error(
"Failed to retrieve audit log statistics",
error=str(e),
tenant_id=tenant_id
)
raise HTTPException(
status_code=500,
detail=f"Failed to retrieve audit log statistics: {str(e)}"
)
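
Because the listing endpoint returns items, total, limit, offset, and has_more, a caller can page through the full audit trail. A minimal sketch, assuming a bearer token and an illustrative URL prefix (the concrete path is owned by RouteBuilder):

import httpx

async def iter_audit_logs(base_url: str, token: str, tenant_id: str):
    # Walks the paginated audit-log listing using the has_more flag from AuditLogListResponse.
    limit, offset = 100, 0
    async with httpx.AsyncClient(headers={"Authorization": f"Bearer {token}"}) as client:
        while True:
            resp = await client.get(
                f"{base_url}/api/v1/tenants/{tenant_id}/forecasting/audit-logs",  # assumed prefix
                params={"limit": limit, "offset": offset},
            )
            resp.raise_for_status()
            page = resp.json()
            for item in page["items"]:
                yield item
            if not page["has_more"]:
                break
            offset += limit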

View File

@@ -0,0 +1,108 @@
"""
Enterprise forecasting API endpoints
"""
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Optional
from datetime import date
import structlog
from app.services.enterprise_forecasting_service import EnterpriseForecastingService
from shared.auth.tenant_access import verify_tenant_permission_dep
from shared.clients import get_forecast_client, get_tenant_client
import shared.redis_utils
from app.core.config import settings
logger = structlog.get_logger()
router = APIRouter()
# Global Redis client
_redis_client = None
async def get_forecasting_redis_client():
"""Get or create Redis client"""
global _redis_client
try:
if _redis_client is None:
_redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
logger.info("Redis client initialized for enterprise forecasting")
return _redis_client
except Exception as e:
logger.warning("Failed to initialize Redis client, enterprise forecasting will work with limited functionality", error=str(e))
return None
async def get_enterprise_forecasting_service(
redis_client = Depends(get_forecasting_redis_client)
) -> EnterpriseForecastingService:
"""Dependency injection for EnterpriseForecastingService"""
forecast_client = get_forecast_client(settings, "forecasting-service")
tenant_client = get_tenant_client(settings, "forecasting-service")
return EnterpriseForecastingService(
forecast_client=forecast_client,
tenant_client=tenant_client,
redis_client=redis_client
)
@router.get("/tenants/{tenant_id}/forecasting/enterprise/aggregated")
async def get_aggregated_forecast(
tenant_id: str,
start_date: date = Query(..., description="Start date for forecast aggregation"),
end_date: date = Query(..., description="End date for forecast aggregation"),
product_id: Optional[str] = Query(None, description="Optional product ID to filter by"),
enterprise_forecasting_service: EnterpriseForecastingService = Depends(get_enterprise_forecasting_service),
verified_tenant: str = Depends(verify_tenant_permission_dep)
):
"""
Get aggregated forecasts across parent and child tenants
"""
try:
# Check if this tenant is a parent tenant
tenant_info = await enterprise_forecasting_service.tenant_client.get_tenant(tenant_id)
if tenant_info.get('tenant_type') != 'parent':
raise HTTPException(
status_code=403,
detail="Only parent tenants can access aggregated enterprise forecasts"
)
result = await enterprise_forecasting_service.get_aggregated_forecast(
parent_tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
product_id=product_id
)
return result
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get aggregated forecast", error=str(e), tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get aggregated forecast: {str(e)}")
@router.get("/tenants/{tenant_id}/forecasting/enterprise/network-performance")
async def get_network_performance_metrics(
tenant_id: str,
start_date: date = Query(..., description="Start date for metrics"),
end_date: date = Query(..., description="End date for metrics"),
enterprise_forecasting_service: EnterpriseForecastingService = Depends(get_enterprise_forecasting_service),
verified_tenant: str = Depends(verify_tenant_permission_dep)
):
"""
Get aggregated performance metrics across tenant network
"""
try:
# Check if this tenant is a parent tenant
tenant_info = await enterprise_forecasting_service.tenant_client.get_tenant(tenant_id)
if tenant_info.get('tenant_type') != 'parent':
raise HTTPException(
status_code=403,
detail="Only parent tenants can access network performance metrics"
)
result = await enterprise_forecasting_service.get_network_performance_metrics(
parent_tenant_id=tenant_id,
start_date=start_date,
end_date=end_date
)
return result
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get network performance metrics", error=str(e), tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get network performance: {str(e)}")

View File

@@ -0,0 +1,417 @@
# services/forecasting/app/api/forecast_feedback.py
"""
Forecast Feedback API - Endpoints for collecting and analyzing forecast feedback
"""
import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Body
from typing import List, Optional, Dict, Any
from datetime import date, datetime
import uuid
import enum
from pydantic import BaseModel, Field
from app.services.forecast_feedback_service import ForecastFeedbackService
from shared.database.base import create_database_manager
from app.core.config import settings
from shared.routing import RouteBuilder
from shared.auth.tenant_access import verify_tenant_permission_dep
route_builder = RouteBuilder('forecasting')
logger = structlog.get_logger()
router = APIRouter(tags=["forecast-feedback"])
# Enums for feedback types
class FeedbackType(str, enum.Enum):
"""Type of feedback on forecast accuracy"""
TOO_HIGH = "too_high"
TOO_LOW = "too_low"
ACCURATE = "accurate"
UNCERTAIN = "uncertain"
class FeedbackConfidence(str, enum.Enum):
"""Confidence level of the feedback provider"""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
# Pydantic models
class ForecastFeedbackRequest(BaseModel):
"""Request model for submitting forecast feedback"""
feedback_type: FeedbackType = Field(..., description="Type of feedback on forecast accuracy")
confidence: FeedbackConfidence = Field(..., description="Confidence level of the feedback provider")
actual_value: Optional[float] = Field(None, description="Actual observed value")
notes: Optional[str] = Field(None, description="Additional notes about the feedback")
feedback_data: Optional[Dict[str, Any]] = Field(None, description="Additional feedback data")
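# Example request body for submit_forecast_feedback (illustrative values only):
#   {
#     "feedback_type": "too_high",
#     "confidence": "medium",
#     "actual_value": 42.0,
#     "notes": "Weekend demand was overestimated"
#   }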
class ForecastFeedbackResponse(BaseModel):
"""Response model for forecast feedback"""
feedback_id: str = Field(..., description="Unique feedback ID")
forecast_id: str = Field(..., description="Forecast ID this feedback relates to")
tenant_id: str = Field(..., description="Tenant ID")
feedback_type: FeedbackType = Field(..., description="Type of feedback")
confidence: FeedbackConfidence = Field(..., description="Confidence level")
actual_value: Optional[float] = Field(None, description="Actual value observed")
notes: Optional[str] = Field(None, description="Feedback notes")
feedback_data: Dict[str, Any] = Field(..., description="Additional feedback data")
created_at: datetime = Field(..., description="When feedback was created")
created_by: Optional[str] = Field(None, description="Who created the feedback")
class ForecastAccuracyMetrics(BaseModel):
"""Accuracy metrics for a forecast"""
forecast_id: str = Field(..., description="Forecast ID")
total_feedback_count: int = Field(..., description="Total feedback received")
accuracy_score: float = Field(..., description="Calculated accuracy score (0-100)")
feedback_distribution: Dict[str, int] = Field(..., description="Distribution of feedback types")
average_confidence: float = Field(..., description="Average confidence score")
last_feedback_date: Optional[datetime] = Field(None, description="Most recent feedback date")
class ForecasterPerformanceMetrics(BaseModel):
"""Performance metrics for the forecasting system"""
overall_accuracy: float = Field(..., description="Overall system accuracy score")
total_forecasts_with_feedback: int = Field(..., description="Total forecasts with feedback")
accuracy_by_product: Dict[str, float] = Field(..., description="Accuracy by product type")
accuracy_trend: str = Field(..., description="Trend direction: improving, declining, stable")
improvement_suggestions: List[str] = Field(..., description="AI-generated improvement suggestions")
def get_forecast_feedback_service():
"""Dependency injection for ForecastFeedbackService"""
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
return ForecastFeedbackService(database_manager)
@router.post(
route_builder.build_nested_resource_route("forecasts", "forecast_id", "feedback"),
response_model=ForecastFeedbackResponse,
status_code=status.HTTP_201_CREATED
)
async def submit_forecast_feedback(
tenant_id: str = Path(..., description="Tenant ID"),
forecast_id: str = Path(..., description="Forecast ID"),
feedback_request: ForecastFeedbackRequest = Body(...),
forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
verified_tenant: str = Depends(verify_tenant_permission_dep)
):
"""
Submit feedback on forecast accuracy
Allows users to provide feedback on whether forecasts were accurate, too high, or too low.
This feedback is used to improve future forecast accuracy through continuous learning.
"""
try:
logger.info("Submitting forecast feedback",
tenant_id=tenant_id, forecast_id=forecast_id,
feedback_type=feedback_request.feedback_type)
# Validate forecast exists
forecast_exists = await forecast_feedback_service.forecast_exists(tenant_id, forecast_id)
if not forecast_exists:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Forecast not found"
)
# Submit feedback
feedback = await forecast_feedback_service.submit_feedback(
tenant_id=tenant_id,
forecast_id=forecast_id,
feedback_type=feedback_request.feedback_type,
confidence=feedback_request.confidence,
actual_value=feedback_request.actual_value,
notes=feedback_request.notes,
feedback_data=feedback_request.feedback_data
)
return {
'feedback_id': str(feedback.feedback_id),
'forecast_id': str(feedback.forecast_id),
'tenant_id': feedback.tenant_id,
'feedback_type': feedback.feedback_type,
'confidence': feedback.confidence,
'actual_value': feedback.actual_value,
'notes': feedback.notes,
'feedback_data': feedback.feedback_data or {},
'created_at': feedback.created_at,
'created_by': feedback.created_by
}
except HTTPException:
raise
except ValueError as e:
logger.error("Invalid forecast ID", error=str(e))
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid forecast ID format"
)
except Exception as e:
logger.error("Failed to submit forecast feedback", error=str(e), tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to submit feedback"
)
@router.get(
route_builder.build_nested_resource_route("forecasts", "forecast_id", "feedback"),
response_model=List[ForecastFeedbackResponse]
)
async def get_forecast_feedback(
tenant_id: str = Path(..., description="Tenant ID"),
forecast_id: str = Path(..., description="Forecast ID"),
limit: int = Query(50, ge=1, le=1000),
offset: int = Query(0, ge=0),
forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
verified_tenant: str = Depends(verify_tenant_permission_dep)
):
"""
Get all feedback for a specific forecast
Retrieves historical feedback submissions for analysis and auditing.
"""
try:
logger.info("Getting forecast feedback", tenant_id=tenant_id, forecast_id=forecast_id)
feedback_list = await forecast_feedback_service.get_feedback_for_forecast(
tenant_id=tenant_id,
forecast_id=forecast_id,
limit=limit,
offset=offset
)
return [
ForecastFeedbackResponse(
feedback_id=str(f.feedback_id),
forecast_id=str(f.forecast_id),
tenant_id=f.tenant_id,
feedback_type=f.feedback_type,
confidence=f.confidence,
actual_value=f.actual_value,
notes=f.notes,
feedback_data=f.feedback_data or {},
created_at=f.created_at,
created_by=f.created_by
) for f in feedback_list
]
except Exception as e:
logger.error("Failed to get forecast feedback", error=str(e), tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve feedback"
)
@router.get(
route_builder.build_nested_resource_route("forecasts", "forecast_id", "accuracy"),
response_model=ForecastAccuracyMetrics
)
async def get_forecast_accuracy_metrics(
tenant_id: str = Path(..., description="Tenant ID"),
forecast_id: str = Path(..., description="Forecast ID"),
forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
verified_tenant: str = Depends(verify_tenant_permission_dep)
):
"""
Get accuracy metrics for a specific forecast
Calculates accuracy scores based on feedback and actual vs predicted values.
"""
try:
logger.info("Getting forecast accuracy metrics", tenant_id=tenant_id, forecast_id=forecast_id)
metrics = await forecast_feedback_service.calculate_accuracy_metrics(
tenant_id=tenant_id,
forecast_id=forecast_id
)
if not metrics:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No accuracy metrics available for this forecast"
)
return {
'forecast_id': metrics.forecast_id,
'total_feedback_count': metrics.total_feedback_count,
'accuracy_score': metrics.accuracy_score,
'feedback_distribution': metrics.feedback_distribution,
'average_confidence': metrics.average_confidence,
'last_feedback_date': metrics.last_feedback_date
}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get forecast accuracy metrics", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to calculate accuracy metrics"
        )
@router.get(
route_builder.build_base_route("forecasts", "accuracy-summary"),
response_model=ForecasterPerformanceMetrics
)
async def get_forecaster_performance_summary(
tenant_id: str = Path(..., description="Tenant ID"),
start_date: Optional[date] = Query(None, description="Start date filter"),
end_date: Optional[date] = Query(None, description="End date filter"),
product_id: Optional[str] = Query(None, description="Filter by product ID"),
forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
verified_tenant: str = Depends(verify_tenant_permission_dep)
):
"""
Get overall forecaster performance summary
Aggregates accuracy metrics across all forecasts to assess overall system performance
and identify areas for improvement.
"""
try:
logger.info("Getting forecaster performance summary", tenant_id=tenant_id)
metrics = await forecast_feedback_service.calculate_performance_summary(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
product_id=product_id
)
return {
'overall_accuracy': metrics.overall_accuracy,
'total_forecasts_with_feedback': metrics.total_forecasts_with_feedback,
'accuracy_by_product': metrics.accuracy_by_product,
'accuracy_trend': metrics.accuracy_trend,
'improvement_suggestions': metrics.improvement_suggestions
}
except Exception as e:
logger.error("Failed to get forecaster performance summary", error=str(e), tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to calculate performance summary"
)
@router.get(
route_builder.build_base_route("forecasts", "feedback-trends")
)
async def get_feedback_trends(
tenant_id: str = Path(..., description="Tenant ID"),
days: int = Query(30, ge=7, le=365, description="Number of days to analyze"),
product_id: Optional[str] = Query(None, description="Filter by product ID"),
forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
verified_tenant: str = Depends(verify_tenant_permission_dep)
):
"""
Get feedback trends over time
Analyzes how forecast accuracy and feedback patterns change over time.
"""
try:
logger.info("Getting feedback trends", tenant_id=tenant_id, days=days)
trends = await forecast_feedback_service.get_feedback_trends(
tenant_id=tenant_id,
days=days,
product_id=product_id
)
return {
'success': True,
'trends': trends,
'period': f'Last {days} days'
}
except Exception as e:
logger.error("Failed to get feedback trends", error=str(e), tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve feedback trends"
)
@router.post(
route_builder.build_resource_action_route("forecasts", "forecast_id", "retrain")
)
async def trigger_retraining_from_feedback(
tenant_id: str = Path(..., description="Tenant ID"),
forecast_id: str = Path(..., description="Forecast ID"),
forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
verified_tenant: str = Depends(verify_tenant_permission_dep)
):
"""
Trigger model retraining based on feedback
Initiates a retraining job using recent feedback to improve forecast accuracy.
"""
try:
logger.info("Triggering retraining from feedback", tenant_id=tenant_id, forecast_id=forecast_id)
result = await forecast_feedback_service.trigger_retraining_from_feedback(
tenant_id=tenant_id,
forecast_id=forecast_id
)
return {
'success': True,
'message': 'Retraining job initiated successfully',
'job_id': result.job_id,
'forecasts_included': result.forecasts_included,
'feedback_samples_used': result.feedback_samples_used
}
except Exception as e:
logger.error("Failed to trigger retraining", error=str(e), tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to initiate retraining"
)
@router.get(
route_builder.build_resource_action_route("forecasts", "forecast_id", "suggestions")
)
async def get_improvement_suggestions(
tenant_id: str = Path(..., description="Tenant ID"),
forecast_id: str = Path(..., description="Forecast ID"),
forecast_feedback_service: ForecastFeedbackService = Depends(get_forecast_feedback_service),
verified_tenant: str = Depends(verify_tenant_permission_dep)
):
"""
Get AI-generated improvement suggestions for a forecast
Analyzes feedback patterns and suggests specific improvements for forecast accuracy.
"""
try:
logger.info("Getting improvement suggestions", tenant_id=tenant_id, forecast_id=forecast_id)
suggestions = await forecast_feedback_service.get_improvement_suggestions(
tenant_id=tenant_id,
forecast_id=forecast_id
)
return {
'success': True,
'forecast_id': forecast_id,
'suggestions': suggestions,
'confidence_scores': [s.get('confidence', 0.8) for s in suggestions]
}
except Exception as e:
logger.error("Failed to get improvement suggestions", error=str(e), tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate suggestions"
)

File diff suppressed because it is too large

View File

@@ -0,0 +1,145 @@
# services/forecasting/app/api/forecasts.py
"""
Forecasts API - Atomic CRUD operations on Forecast model
"""
import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path
from typing import List, Optional
from datetime import date, datetime
import uuid
from app.services.forecasting_service import EnhancedForecastingService
from app.schemas.forecasts import ForecastResponse
from shared.database.base import create_database_manager
from app.core.config import settings
from shared.routing import RouteBuilder
route_builder = RouteBuilder('forecasting')
logger = structlog.get_logger()
router = APIRouter(tags=["forecasts"])
def get_enhanced_forecasting_service():
"""Dependency injection for EnhancedForecastingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
return EnhancedForecastingService(database_manager)
@router.get(
route_builder.build_base_route("forecasts"),
response_model=List[ForecastResponse]
)
async def list_forecasts(
tenant_id: str = Path(..., description="Tenant ID"),
inventory_product_id: Optional[str] = Query(None, description="Filter by product ID"),
start_date: Optional[date] = Query(None, description="Start date filter"),
end_date: Optional[date] = Query(None, description="End date filter"),
limit: int = Query(50, ge=1, le=1000),
offset: int = Query(0, ge=0),
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""List forecasts with optional filters"""
try:
logger.info("Listing forecasts", tenant_id=tenant_id)
forecasts = await enhanced_forecasting_service.list_forecasts(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
start_date=start_date,
end_date=end_date,
limit=limit,
offset=offset
)
return forecasts
except Exception as e:
logger.error("Failed to list forecasts", error=str(e), tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve forecasts"
)
@router.get(
route_builder.build_resource_detail_route("forecasts", "forecast_id"),
response_model=ForecastResponse
)
async def get_forecast(
tenant_id: str = Path(..., description="Tenant ID"),
forecast_id: str = Path(..., description="Forecast ID"),
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Get a specific forecast by ID"""
try:
logger.info("Getting forecast", tenant_id=tenant_id, forecast_id=forecast_id)
forecast = await enhanced_forecasting_service.get_forecast(
tenant_id=tenant_id,
forecast_id=uuid.UUID(forecast_id)
)
if not forecast:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Forecast not found"
)
return forecast
except HTTPException:
raise
except ValueError as e:
logger.error("Invalid forecast ID", error=str(e))
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid forecast ID format"
)
except Exception as e:
logger.error("Failed to get forecast", error=str(e), tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve forecast"
)
@router.delete(
route_builder.build_resource_detail_route("forecasts", "forecast_id")
)
async def delete_forecast(
tenant_id: str = Path(..., description="Tenant ID"),
forecast_id: str = Path(..., description="Forecast ID"),
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""Delete a specific forecast"""
try:
logger.info("Deleting forecast", tenant_id=tenant_id, forecast_id=forecast_id)
success = await enhanced_forecasting_service.delete_forecast(
tenant_id=tenant_id,
forecast_id=uuid.UUID(forecast_id)
)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Forecast not found"
)
return {"message": "Forecast deleted successfully"}
except HTTPException:
raise
except ValueError as e:
logger.error("Invalid forecast ID", error=str(e))
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid forecast ID format"
)
except Exception as e:
logger.error("Failed to delete forecast", error=str(e), tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to delete forecast"
)

View File

@@ -0,0 +1,304 @@
# ================================================================
# services/forecasting/app/api/historical_validation.py
# ================================================================
"""
Historical Validation API - Backfill validation for late-arriving sales data
"""
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
from typing import Dict, Any, List, Optional
from uuid import UUID
from datetime import date
import structlog
from pydantic import BaseModel, Field
from app.services.historical_validation_service import HistoricalValidationService
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role
from shared.routing import RouteBuilder
from app.core.database import get_db
from sqlalchemy.ext.asyncio import AsyncSession
route_builder = RouteBuilder('forecasting')
router = APIRouter(tags=["historical-validation"])
logger = structlog.get_logger()
# ================================================================
# Request/Response Schemas
# ================================================================
class DetectGapsRequest(BaseModel):
"""Request model for gap detection"""
lookback_days: int = Field(default=90, ge=1, le=365, description="Days to look back")
class BackfillRequest(BaseModel):
"""Request model for manual backfill"""
start_date: date = Field(..., description="Start date for backfill")
end_date: date = Field(..., description="End date for backfill")
class SalesDataUpdateRequest(BaseModel):
"""Request model for registering sales data update"""
start_date: date = Field(..., description="Start date of updated data")
end_date: date = Field(..., description="End date of updated data")
records_affected: int = Field(..., ge=0, description="Number of records affected")
update_source: str = Field(default="import", description="Source of update")
import_job_id: Optional[str] = Field(None, description="Import job ID if applicable")
auto_trigger_validation: bool = Field(default=True, description="Auto-trigger validation")
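# Example payload for the register-sales-update endpoint (illustrative values only):
#   {
#     "start_date": "2025-01-01",
#     "end_date": "2025-01-31",
#     "records_affected": 1250,
#     "update_source": "import",
#     "auto_trigger_validation": true
#   }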
class AutoBackfillRequest(BaseModel):
"""Request model for automatic backfill"""
lookback_days: int = Field(default=90, ge=1, le=365, description="Days to look back")
max_gaps_to_process: int = Field(default=10, ge=1, le=50, description="Max gaps to process")
# ================================================================
# Endpoints
# ================================================================
@router.post(
route_builder.build_base_route("validation/detect-gaps"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def detect_validation_gaps(
request: DetectGapsRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Detect date ranges where forecasts exist but haven't been validated yet
Returns list of gap periods that need validation backfill.
"""
try:
logger.info(
"Detecting validation gaps",
tenant_id=tenant_id,
lookback_days=request.lookback_days,
user_id=current_user.get("user_id")
)
service = HistoricalValidationService(db)
gaps = await service.detect_validation_gaps(
tenant_id=tenant_id,
lookback_days=request.lookback_days
)
return {
"gaps_found": len(gaps),
"lookback_days": request.lookback_days,
"gaps": [
{
"start_date": gap["start_date"].isoformat(),
"end_date": gap["end_date"].isoformat(),
"days_count": gap["days_count"]
}
for gap in gaps
]
}
except Exception as e:
logger.error(
"Failed to detect validation gaps",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to detect validation gaps: {str(e)}"
)
@router.post(
route_builder.build_base_route("validation/backfill"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def backfill_validation(
request: BackfillRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Manually trigger validation backfill for a specific date range
Validates forecasts against sales data for historical periods.
"""
try:
logger.info(
"Manual validation backfill requested",
tenant_id=tenant_id,
start_date=request.start_date.isoformat(),
end_date=request.end_date.isoformat(),
user_id=current_user.get("user_id")
)
service = HistoricalValidationService(db)
result = await service.backfill_validation(
tenant_id=tenant_id,
start_date=request.start_date,
end_date=request.end_date,
triggered_by="manual"
)
return result
except Exception as e:
logger.error(
"Failed to backfill validation",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to backfill validation: {str(e)}"
)
@router.post(
route_builder.build_base_route("validation/auto-backfill"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def auto_backfill_validation_gaps(
request: AutoBackfillRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Automatically detect and backfill validation gaps
Finds all date ranges with missing validations and processes them.
"""
try:
logger.info(
"Auto backfill requested",
tenant_id=tenant_id,
lookback_days=request.lookback_days,
max_gaps=request.max_gaps_to_process,
user_id=current_user.get("user_id")
)
service = HistoricalValidationService(db)
result = await service.auto_backfill_gaps(
tenant_id=tenant_id,
lookback_days=request.lookback_days,
max_gaps_to_process=request.max_gaps_to_process
)
return result
except Exception as e:
logger.error(
"Failed to auto backfill",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to auto backfill: {str(e)}"
)
@router.post(
route_builder.build_base_route("validation/register-sales-update"),
status_code=status.HTTP_201_CREATED
)
@require_user_role(['admin', 'owner', 'member'])
async def register_sales_data_update(
request: SalesDataUpdateRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Register a sales data update and optionally trigger validation
Call this endpoint after importing historical sales data to automatically
trigger validation for the affected date range.
"""
try:
logger.info(
"Registering sales data update",
tenant_id=tenant_id,
date_range=f"{request.start_date} to {request.end_date}",
records_affected=request.records_affected,
user_id=current_user.get("user_id")
)
service = HistoricalValidationService(db)
result = await service.register_sales_data_update(
tenant_id=tenant_id,
start_date=request.start_date,
end_date=request.end_date,
records_affected=request.records_affected,
update_source=request.update_source,
import_job_id=request.import_job_id,
auto_trigger_validation=request.auto_trigger_validation
)
return result
except Exception as e:
logger.error(
"Failed to register sales data update",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to register sales data update: {str(e)}"
)
@router.get(
route_builder.build_base_route("validation/pending"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_pending_validations(
tenant_id: UUID = Path(..., description="Tenant ID"),
limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get pending sales data updates awaiting validation
Returns list of sales data updates that have been registered
but not yet validated.
"""
try:
service = HistoricalValidationService(db)
pending = await service.get_pending_validations(
tenant_id=tenant_id,
limit=limit
)
return {
"pending_count": len(pending),
"pending_validations": [record.to_dict() for record in pending]
}
except Exception as e:
logger.error(
"Failed to get pending validations",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get pending validations: {str(e)}"
)

View File

@@ -0,0 +1,477 @@
"""
Internal Demo Cloning API for Forecasting Service
Service-to-service endpoint for cloning forecast data
"""
from fastapi import APIRouter, Depends, HTTPException, Header
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import structlog
import uuid
from datetime import datetime, timezone, timedelta
from typing import Optional
import os
import sys
from pathlib import Path
import json
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
from shared.utils.demo_dates import adjust_date_for_demo, resolve_time_marker
from app.core.database import get_db
from app.models.forecasts import Forecast, PredictionBatch
logger = structlog.get_logger()
router = APIRouter(prefix="/internal/demo", tags=["internal"])
# Base demo tenant IDs
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
def parse_date_field(date_value, session_time: datetime, field_name: str = "date") -> Optional[datetime]:
"""
Parse date field, handling both ISO strings and BASE_TS markers.
Supports:
- BASE_TS markers: "BASE_TS + 1h30m", "BASE_TS - 2d"
- ISO 8601 strings: "2025-01-15T06:00:00Z"
- None values (returns None)
Returns timezone-aware datetime or None.
"""
if not date_value:
return None
# Check if it's a BASE_TS marker
if isinstance(date_value, str) and date_value.startswith("BASE_TS"):
try:
return resolve_time_marker(date_value, session_time)
except ValueError as e:
logger.warning(
f"Invalid BASE_TS marker in {field_name}",
marker=date_value,
error=str(e)
)
return None
# Handle regular ISO date strings
try:
if isinstance(date_value, str):
original_date = datetime.fromisoformat(date_value.replace('Z', '+00:00'))
elif hasattr(date_value, 'isoformat'):
original_date = date_value
else:
logger.warning(f"Unsupported date format in {field_name}", date_value=date_value)
return None
return adjust_date_for_demo(original_date, session_time)
except (ValueError, AttributeError) as e:
logger.warning(
f"Invalid date format in {field_name}",
date_value=date_value,
error=str(e)
)
return None
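# Illustrative behaviour (assuming resolve_time_marker understands offsets such as "+ 2d"):
#   parse_date_field("BASE_TS + 2d", session_time)         -> session_time shifted forward two days
#   parse_date_field("2025-01-15T06:00:00Z", session_time)  -> result of adjust_date_for_demo(...)
#   parse_date_field(None, session_time)                    -> None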
def align_to_week_start(target_date: datetime) -> datetime:
"""Align forecast date to Monday (start of week)"""
if target_date:
days_since_monday = target_date.weekday()
return target_date - timedelta(days=days_since_monday)
return target_date
@router.post("/clone")
async def clone_demo_data(
base_tenant_id: str,
virtual_tenant_id: str,
demo_account_type: str,
session_id: Optional[str] = None,
session_created_at: Optional[str] = None,
db: AsyncSession = Depends(get_db)
):
"""
Clone forecasting service data for a virtual demo tenant
This endpoint creates fresh demo data by:
1. Loading seed data from JSON files
2. Applying XOR-based ID transformation
3. Adjusting dates relative to session creation time
4. Creating records in the virtual tenant
Args:
base_tenant_id: Template tenant UUID (for reference)
virtual_tenant_id: Target virtual tenant UUID
demo_account_type: Type of demo account
session_id: Originating session ID for tracing
session_created_at: Session creation timestamp for date adjustment
db: Database session
Returns:
Dictionary with cloning results
Raises:
HTTPException: On validation or cloning errors
"""
start_time = datetime.now(timezone.utc)
try:
# Validate UUIDs
virtual_uuid = uuid.UUID(virtual_tenant_id)
# Parse session creation time for date adjustment
if session_created_at:
try:
session_time = datetime.fromisoformat(session_created_at.replace('Z', '+00:00'))
except (ValueError, AttributeError):
session_time = start_time
else:
session_time = start_time
logger.info(
"Starting forecasting data cloning with date adjustment",
base_tenant_id=base_tenant_id,
virtual_tenant_id=str(virtual_uuid),
demo_account_type=demo_account_type,
session_id=session_id,
session_time=session_time.isoformat()
)
# Load seed data using shared utility
try:
from shared.utils.seed_data_paths import get_seed_data_path
if demo_account_type == "enterprise":
profile = "enterprise"
else:
profile = "professional"
json_file = get_seed_data_path(profile, "10-forecasting.json")
except ImportError:
# Fallback to original path
seed_data_dir = Path(__file__).parent.parent.parent.parent / "shared" / "demo" / "fixtures"
if demo_account_type == "enterprise":
json_file = seed_data_dir / "enterprise" / "parent" / "10-forecasting.json"
else:
json_file = seed_data_dir / "professional" / "10-forecasting.json"
if not json_file.exists():
raise HTTPException(
status_code=404,
detail=f"Seed data file not found: {json_file}"
)
# Load JSON data
with open(json_file, 'r', encoding='utf-8') as f:
seed_data = json.load(f)
# Check if data already exists for this virtual tenant (idempotency)
existing_check = await db.execute(
select(Forecast).where(Forecast.tenant_id == virtual_uuid).limit(1)
)
existing_forecast = existing_check.scalar_one_or_none()
if existing_forecast:
logger.warning(
"Demo data already exists, skipping clone",
virtual_tenant_id=str(virtual_uuid)
)
return {
"status": "skipped",
"reason": "Data already exists",
"records_cloned": 0
}
# Track cloning statistics
stats = {
"forecasts": 0,
"prediction_batches": 0
}
# Transform and insert forecasts
for forecast_data in seed_data.get('forecasts', []):
# Transform ID using XOR
from shared.utils.demo_id_transformer import transform_id
try:
forecast_uuid = uuid.UUID(forecast_data['id'])
tenant_uuid = uuid.UUID(virtual_tenant_id)
transformed_id = transform_id(forecast_data['id'], tenant_uuid)
except ValueError as e:
logger.error("Failed to parse UUIDs for ID transformation",
forecast_id=forecast_data['id'],
virtual_tenant_id=virtual_tenant_id,
error=str(e))
raise HTTPException(
status_code=400,
detail=f"Invalid UUID format in forecast data: {str(e)}"
)
# Transform dates using the proper parse_date_field function
for date_field in ['forecast_date', 'created_at']:
if date_field in forecast_data:
try:
parsed_date = parse_date_field(
forecast_data[date_field],
session_time,
date_field
)
if parsed_date:
forecast_data[date_field] = parsed_date
                        else:
                            # Log the original value before overwriting it with the session_time fallback
                            logger.warning("Using fallback date for failed parsing",
                                           date_field=date_field,
                                           original_value=forecast_data[date_field])
                            forecast_data[date_field] = session_time
except Exception as e:
logger.warning("Failed to parse date, using fallback",
date_field=date_field,
date_value=forecast_data[date_field],
error=str(e))
forecast_data[date_field] = session_time
# Create forecast
# Map product_id to inventory_product_id if needed
inventory_product_id_str = forecast_data.get('inventory_product_id') or forecast_data.get('product_id')
# Convert to UUID if it's a string
if isinstance(inventory_product_id_str, str):
inventory_product_id = uuid.UUID(inventory_product_id_str)
else:
inventory_product_id = inventory_product_id_str
# Map predicted_quantity to predicted_demand if needed
predicted_demand = forecast_data.get('predicted_demand') or forecast_data.get('predicted_quantity')
# Set default location if not provided in seed data
location = forecast_data.get('location') or "Main Bakery"
# Get or calculate forecast date
forecast_date = forecast_data.get('forecast_date')
if not forecast_date:
forecast_date = session_time
# Calculate day_of_week from forecast_date if not provided
# day_of_week should be 0-6 (Monday=0, Sunday=6)
day_of_week = forecast_data.get('day_of_week')
if day_of_week is None and forecast_date:
day_of_week = forecast_date.weekday()
            # Calculate is_weekend from day_of_week only when the seed data does not provide it
            is_weekend = forecast_data.get('is_weekend')
            if is_weekend is None:
                is_weekend = day_of_week >= 5 if day_of_week is not None else False  # Saturday=5, Sunday=6
new_forecast = Forecast(
id=transformed_id,
tenant_id=virtual_uuid,
inventory_product_id=inventory_product_id,
product_name=forecast_data.get('product_name'),
location=location,
forecast_date=forecast_date,
created_at=forecast_data.get('created_at', session_time),
predicted_demand=predicted_demand,
confidence_lower=forecast_data.get('confidence_lower', max(0.0, float(predicted_demand or 0.0) * 0.8)),
confidence_upper=forecast_data.get('confidence_upper', max(0.0, float(predicted_demand or 0.0) * 1.2)),
confidence_level=forecast_data.get('confidence_level', 0.8),
model_id=forecast_data.get('model_id') or 'default-fallback-model',
model_version=forecast_data.get('model_version') or '1.0',
algorithm=forecast_data.get('algorithm', 'prophet'),
business_type=forecast_data.get('business_type', 'individual'),
day_of_week=day_of_week,
is_holiday=forecast_data.get('is_holiday', False),
is_weekend=is_weekend,
weather_temperature=forecast_data.get('weather_temperature'),
weather_precipitation=forecast_data.get('weather_precipitation'),
weather_description=forecast_data.get('weather_description'),
traffic_volume=forecast_data.get('traffic_volume'),
processing_time_ms=forecast_data.get('processing_time_ms'),
features_used=forecast_data.get('features_used')
)
db.add(new_forecast)
stats["forecasts"] += 1
# Transform and insert prediction batches
for batch_data in seed_data.get('prediction_batches', []):
# Transform ID using XOR
from shared.utils.demo_id_transformer import transform_id
try:
batch_uuid = uuid.UUID(batch_data['id'])
tenant_uuid = uuid.UUID(virtual_tenant_id)
transformed_id = transform_id(batch_data['id'], tenant_uuid)
except ValueError as e:
logger.error("Failed to parse UUIDs for ID transformation",
batch_id=batch_data['id'],
virtual_tenant_id=virtual_tenant_id,
error=str(e))
raise HTTPException(
status_code=400,
detail=f"Invalid UUID format in batch data: {str(e)}"
)
# Create prediction batch
# Handle field mapping: batch_id -> batch_name, total_forecasts -> total_products
batch_name = batch_data.get('batch_name') or batch_data.get('batch_id') or f"Batch-{transformed_id}"
total_products = batch_data.get('total_products') or batch_data.get('total_forecasts') or 0
completed_products = batch_data.get('completed_products') or (total_products if batch_data.get('status') == 'COMPLETED' else 0)
# Parse dates (handle created_at or prediction_date for requested_at)
requested_at_raw = batch_data.get('requested_at') or batch_data.get('created_at') or batch_data.get('prediction_date')
requested_at = parse_date_field(requested_at_raw, session_time, 'requested_at') if requested_at_raw else session_time
completed_at_raw = batch_data.get('completed_at')
completed_at = parse_date_field(completed_at_raw, session_time, 'completed_at') if completed_at_raw else None
new_batch = PredictionBatch(
id=transformed_id,
tenant_id=virtual_uuid,
batch_name=batch_name,
requested_at=requested_at,
completed_at=completed_at,
status=batch_data.get('status', 'completed'),
total_products=total_products,
completed_products=completed_products,
failed_products=batch_data.get('failed_products', 0),
forecast_days=batch_data.get('forecast_days', 7),
business_type=batch_data.get('business_type', 'individual'),
error_message=batch_data.get('error_message'),
processing_time_ms=batch_data.get('processing_time_ms'),
cancelled_by=batch_data.get('cancelled_by')
)
db.add(new_batch)
stats["prediction_batches"] += 1
# Commit all changes
await db.commit()
total_records = sum(stats.values())
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
logger.info(
"Forecasting data cloned successfully",
virtual_tenant_id=str(virtual_uuid),
records_cloned=total_records,
duration_ms=duration_ms,
forecasts_cloned=stats["forecasts"],
batches_cloned=stats["prediction_batches"]
)
return {
"service": "forecasting",
"status": "completed",
"records_cloned": total_records,
"duration_ms": duration_ms,
"details": {
"forecasts": stats["forecasts"],
"prediction_batches": stats["prediction_batches"],
"virtual_tenant_id": str(virtual_uuid)
}
}
except ValueError as e:
logger.error("Invalid UUID format", error=str(e), virtual_tenant_id=virtual_tenant_id)
raise HTTPException(status_code=400, detail=f"Invalid UUID: {str(e)}")
except Exception as e:
logger.error(
"Failed to clone forecasting data",
error=str(e),
virtual_tenant_id=virtual_tenant_id,
exc_info=True
)
# Rollback on error
await db.rollback()
return {
"service": "forecasting",
"status": "failed",
"records_cloned": 0,
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"error": str(e)
}
@router.get("/clone/health")
async def clone_health_check():
"""
Health check for internal cloning endpoint
Used by orchestrator to verify service availability
"""
return {
"service": "forecasting",
"clone_endpoint": "available",
"version": "2.0.0"
}
@router.delete("/tenant/{virtual_tenant_id}")
async def delete_demo_tenant_data(
virtual_tenant_id: uuid.UUID,
db: AsyncSession = Depends(get_db)
):
"""
Delete all demo data for a virtual tenant.
This endpoint is idempotent - safe to call multiple times.
"""
from sqlalchemy import delete
start_time = datetime.now(timezone.utc)
records_deleted = {
"forecasts": 0,
"prediction_batches": 0,
"total": 0
}
try:
# Delete in reverse dependency order
# 1. Delete prediction batches
result = await db.execute(
delete(PredictionBatch)
.where(PredictionBatch.tenant_id == virtual_tenant_id)
)
records_deleted["prediction_batches"] = result.rowcount
# 2. Delete forecasts
result = await db.execute(
delete(Forecast)
.where(Forecast.tenant_id == virtual_tenant_id)
)
records_deleted["forecasts"] = result.rowcount
records_deleted["total"] = sum(records_deleted.values())
await db.commit()
logger.info(
"demo_data_deleted",
service="forecasting",
virtual_tenant_id=str(virtual_tenant_id),
records_deleted=records_deleted
)
return {
"service": "forecasting",
"status": "deleted",
"virtual_tenant_id": str(virtual_tenant_id),
"records_deleted": records_deleted,
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
}
except Exception as e:
await db.rollback()
logger.error(
"demo_data_deletion_failed",
service="forecasting",
virtual_tenant_id=str(virtual_tenant_id),
error=str(e)
)
raise HTTPException(
status_code=500,
detail=f"Failed to delete demo data: {str(e)}"
)
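
The clone endpoint's docstring describes an XOR-based ID transformation handled by shared.utils.demo_id_transformer.transform_id. As a rough illustration only (the real helper may differ in detail), such a transform can be sketched as follows: it deterministically maps each seed UUID into a tenant-specific UUID and is reversible by applying the same XOR again.

import uuid

def transform_id_sketch(seed_id: str, tenant_uuid: uuid.UUID) -> uuid.UUID:
    # Hypothetical stand-in for shared.utils.demo_id_transformer.transform_id:
    # XOR the 128-bit integer forms so the same seed id always yields the same
    # per-tenant id, and different virtual tenants get disjoint id spaces.
    seed = uuid.UUID(seed_id)
    return uuid.UUID(int=seed.int ^ tenant_uuid.int)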

View File

@@ -0,0 +1,959 @@
"""
ML Insights API Endpoints for Forecasting Service
Provides endpoints to trigger ML insight generation for:
- Dynamic business rules learning
- Demand pattern analysis
- Seasonal trend detection
"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Request
from pydantic import BaseModel, Field
from typing import Optional, List
from uuid import UUID
from datetime import datetime, timedelta
import structlog
import pandas as pd
from app.core.database import get_db
from sqlalchemy.ext.asyncio import AsyncSession
logger = structlog.get_logger()
router = APIRouter(
prefix="/api/v1/tenants/{tenant_id}/forecasting/ml/insights",
tags=["ML Insights"]
)
# ================================================================
# REQUEST/RESPONSE SCHEMAS
# ================================================================
class RulesGenerationRequest(BaseModel):
"""Request schema for rules generation"""
product_ids: Optional[List[str]] = Field(
None,
description="Specific product IDs to analyze. If None, analyzes all products"
)
lookback_days: int = Field(
90,
description="Days of historical data to analyze",
ge=30,
le=365
)
min_samples: int = Field(
10,
description="Minimum samples required for rule learning",
ge=5,
le=100
)
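# Example request body (illustrative; mirrors the field defaults above):
#   {"product_ids": null, "lookback_days": 90, "min_samples": 10}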
class RulesGenerationResponse(BaseModel):
"""Response schema for rules generation"""
success: bool
message: str
tenant_id: str
products_analyzed: int
total_insights_generated: int
total_insights_posted: int
insights_by_product: dict
errors: List[str] = []
class DemandAnalysisRequest(BaseModel):
"""Request schema for demand analysis"""
product_ids: Optional[List[str]] = Field(
None,
description="Specific product IDs to analyze. If None, analyzes all products"
)
lookback_days: int = Field(
90,
description="Days of historical data to analyze",
ge=30,
le=365
)
forecast_horizon_days: int = Field(
30,
description="Days to forecast ahead",
ge=7,
le=90
)
class DemandAnalysisResponse(BaseModel):
"""Response schema for demand analysis"""
success: bool
message: str
tenant_id: str
products_analyzed: int
total_insights_generated: int
total_insights_posted: int
insights_by_product: dict
errors: List[str] = []
class BusinessRulesAnalysisRequest(BaseModel):
"""Request schema for business rules analysis"""
product_ids: Optional[List[str]] = Field(
None,
description="Specific product IDs to analyze. If None, analyzes all products"
)
lookback_days: int = Field(
90,
description="Days of historical data to analyze",
ge=30,
le=365
)
min_samples: int = Field(
10,
description="Minimum samples required for rule analysis",
ge=5,
le=100
)
class BusinessRulesAnalysisResponse(BaseModel):
"""Response schema for business rules analysis"""
success: bool
message: str
tenant_id: str
products_analyzed: int
total_insights_generated: int
total_insights_posted: int
insights_by_product: dict
errors: List[str] = []
# ================================================================
# API ENDPOINTS
# ================================================================
@router.post("/generate-rules", response_model=RulesGenerationResponse)
async def trigger_rules_generation(
tenant_id: str,
request_data: RulesGenerationRequest,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
Trigger dynamic business rules learning from historical sales data.
This endpoint:
1. Fetches historical sales data for specified products
2. Runs the RulesOrchestrator to learn patterns
3. Generates insights about optimal business rules
4. Posts insights to AI Insights Service
Args:
tenant_id: Tenant UUID
request_data: Rules generation parameters
db: Database session
Returns:
RulesGenerationResponse with generation results
"""
logger.info(
"ML insights rules generation requested",
tenant_id=tenant_id,
product_ids=request_data.product_ids,
lookback_days=request_data.lookback_days
)
try:
# Import ML orchestrator and clients
from app.ml.rules_orchestrator import RulesOrchestrator
from shared.clients.sales_client import SalesServiceClient
from shared.clients.inventory_client import InventoryServiceClient
from app.core.config import settings
# Get event publisher from app state
event_publisher = getattr(request.app.state, 'event_publisher', None)
# Initialize orchestrator and clients
orchestrator = RulesOrchestrator(event_publisher=event_publisher)
inventory_client = InventoryServiceClient(settings)
# Get products to analyze from inventory service via API
if request_data.product_ids:
# Fetch specific products
products = []
for product_id in request_data.product_ids:
product = await inventory_client.get_ingredient_by_id(
ingredient_id=UUID(product_id),
tenant_id=tenant_id
)
if product:
products.append(product)
else:
# Fetch all products for tenant (limit to 10)
all_products = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
products = all_products[:10] # Limit to prevent timeout
if not products:
return RulesGenerationResponse(
success=False,
message="No products found for analysis",
tenant_id=tenant_id,
products_analyzed=0,
total_insights_generated=0,
total_insights_posted=0,
insights_by_product={},
errors=["No products found"]
)
# Initialize sales client to fetch historical data
sales_client = SalesServiceClient(config=settings, calling_service_name="forecasting")
# Calculate date range
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=request_data.lookback_days)
# Process each product
total_insights_generated = 0
total_insights_posted = 0
insights_by_product = {}
errors = []
for product in products:
try:
product_id = str(product['id'])
product_name = product.get('name', 'Unknown')
logger.info(f"Analyzing product {product_name} ({product_id})")
# Fetch sales data for product
sales_data = await sales_client.get_sales_data(
tenant_id=tenant_id,
product_id=product_id,
start_date=start_date.strftime('%Y-%m-%d'),
end_date=end_date.strftime('%Y-%m-%d')
)
if not sales_data:
logger.warning(f"No sales data for product {product_id}")
continue
# Convert to DataFrame
sales_df = pd.DataFrame(sales_data)
if len(sales_df) < request_data.min_samples:
logger.warning(
f"Insufficient data for product {product_id}: "
f"{len(sales_df)} samples < {request_data.min_samples} required"
)
continue
# Check what columns are available and map to expected format
logger.debug(f"Sales data columns for product {product_id}: {sales_df.columns.tolist()}")
# Map common field names to 'quantity' and 'date'
if 'quantity' not in sales_df.columns:
if 'total_quantity' in sales_df.columns:
sales_df['quantity'] = sales_df['total_quantity']
elif 'amount' in sales_df.columns:
sales_df['quantity'] = sales_df['amount']
else:
logger.warning(f"No quantity field found for product {product_id}, skipping")
continue
if 'date' not in sales_df.columns:
if 'sale_date' in sales_df.columns:
sales_df['date'] = sales_df['sale_date']
else:
logger.warning(f"No date field found for product {product_id}, skipping")
continue
# Prepare sales data with required columns
sales_df['date'] = pd.to_datetime(sales_df['date'])
sales_df['quantity'] = sales_df['quantity'].astype(float)
sales_df['day_of_week'] = sales_df['date'].dt.dayofweek
# NOTE: Holiday detection for historical data requires:
# 1. Tenant location context (calendar_id)
# 2. Bulk holiday check API (currently single-date only)
# 3. Historical calendar data
# For real-time forecasts, holiday detection IS implemented via data_client.py
sales_df['is_holiday'] = False
# NOTE: Weather data for historical analysis requires:
# 1. Historical weather API integration
# 2. Tenant location coordinates
# For real-time forecasts, weather data IS fetched via external service
sales_df['weather'] = 'unknown'
# Run rules learning
results = await orchestrator.learn_and_post_rules(
tenant_id=tenant_id,
inventory_product_id=product_id,
sales_data=sales_df,
external_data=None,
min_samples=request_data.min_samples
)
# Track results
total_insights_generated += results['insights_generated']
total_insights_posted += results['insights_posted']
insights_by_product[product_id] = {
'product_name': product_name,
'insights_posted': results['insights_posted'],
'rules_learned': len(results['rules'])
}
logger.info(
f"Product {product_id} analysis complete",
insights_posted=results['insights_posted']
)
except Exception as e:
error_msg = f"Error analyzing product {product_id}: {str(e)}"
logger.error(error_msg, exc_info=True)
errors.append(error_msg)
# Close orchestrator
await orchestrator.close()
# Build response
response = RulesGenerationResponse(
success=total_insights_posted > 0,
message=f"Successfully generated {total_insights_posted} insights from {len(products)} products",
tenant_id=tenant_id,
products_analyzed=len(products),
total_insights_generated=total_insights_generated,
total_insights_posted=total_insights_posted,
insights_by_product=insights_by_product,
errors=errors
)
logger.info(
"ML insights rules generation complete",
tenant_id=tenant_id,
total_insights=total_insights_posted
)
return response
except Exception as e:
logger.error(
"ML insights rules generation failed",
tenant_id=tenant_id,
error=str(e),
exc_info=True
)
raise HTTPException(
status_code=500,
detail=f"Rules generation failed: {str(e)}"
)
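# Illustrative client sketch (not wired into this router): how an internal job
# could trigger rules generation over HTTP. The service host/port, timeout, and
# bearer-token handling are assumptions; the path follows the router prefix above.
async def _example_trigger_rules_generation(tenant_id: str, token: str) -> dict:
    import httpx  # assumed to be available in the service environment
    url = (
        f"http://forecasting:8000/api/v1/tenants/{tenant_id}"
        "/forecasting/ml/insights/generate-rules"
    )
    payload = {"product_ids": None, "lookback_days": 90, "min_samples": 10}
    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(
            url, json=payload, headers={"Authorization": f"Bearer {token}"}
        )
        response.raise_for_status()
        return response.json()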
@router.post("/analyze-demand", response_model=DemandAnalysisResponse)
async def trigger_demand_analysis(
tenant_id: str,
request_data: DemandAnalysisRequest,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
Trigger demand pattern analysis from historical sales data.
This endpoint:
1. Fetches historical sales data for specified products
2. Runs the DemandInsightsOrchestrator to analyze patterns
3. Generates insights about demand forecasting optimization
4. Posts insights to AI Insights Service
5. Publishes events to RabbitMQ
Args:
tenant_id: Tenant UUID
request_data: Demand analysis parameters
request: FastAPI request object to access app state
db: Database session
Returns:
DemandAnalysisResponse with analysis results
"""
logger.info(
"ML insights demand analysis requested",
tenant_id=tenant_id,
product_ids=request_data.product_ids,
lookback_days=request_data.lookback_days
)
try:
# Import ML orchestrator and clients
from app.ml.demand_insights_orchestrator import DemandInsightsOrchestrator
from shared.clients.sales_client import SalesServiceClient
from shared.clients.inventory_client import InventoryServiceClient
from app.core.config import settings
# Get event publisher from app state
event_publisher = getattr(request.app.state, 'event_publisher', None)
# Initialize orchestrator and clients
orchestrator = DemandInsightsOrchestrator(event_publisher=event_publisher)
inventory_client = InventoryServiceClient(settings)
# Get products to analyze from inventory service via API
if request_data.product_ids:
# Fetch specific products
products = []
for product_id in request_data.product_ids:
product = await inventory_client.get_ingredient_by_id(
ingredient_id=UUID(product_id),
tenant_id=tenant_id
)
if product:
products.append(product)
else:
# Fetch all products for tenant (limit to 10)
all_products = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
products = all_products[:10] # Limit to prevent timeout
if not products:
return DemandAnalysisResponse(
success=False,
message="No products found for analysis",
tenant_id=tenant_id,
products_analyzed=0,
total_insights_generated=0,
total_insights_posted=0,
insights_by_product={},
errors=["No products found"]
)
# Initialize sales client to fetch historical data
sales_client = SalesServiceClient(config=settings, calling_service_name="forecasting")
# Calculate date range
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=request_data.lookback_days)
# Process each product
total_insights_generated = 0
total_insights_posted = 0
insights_by_product = {}
errors = []
for product in products:
try:
product_id = str(product['id'])
product_name = product.get('name', 'Unknown')
logger.info(f"Analyzing product {product_name} ({product_id})")
# Fetch sales data for product
sales_data = await sales_client.get_sales_data(
tenant_id=tenant_id,
product_id=product_id,
start_date=start_date.strftime('%Y-%m-%d'),
end_date=end_date.strftime('%Y-%m-%d')
)
if not sales_data:
logger.warning(f"No sales data for product {product_id}")
continue
# Convert to DataFrame
sales_df = pd.DataFrame(sales_data)
if len(sales_df) < 30: # Minimum for demand analysis
logger.warning(
f"Insufficient data for product {product_id}: "
f"{len(sales_df)} samples < 30 required"
)
continue
# Check what columns are available and map to expected format
logger.debug(f"Sales data columns for product {product_id}: {sales_df.columns.tolist()}")
# Map common field names to 'quantity' and 'date'
if 'quantity' not in sales_df.columns:
if 'total_quantity' in sales_df.columns:
sales_df['quantity'] = sales_df['total_quantity']
elif 'amount' in sales_df.columns:
sales_df['quantity'] = sales_df['amount']
else:
logger.warning(f"No quantity field found for product {product_id}, skipping")
continue
if 'date' not in sales_df.columns:
if 'sale_date' in sales_df.columns:
sales_df['date'] = sales_df['sale_date']
else:
logger.warning(f"No date field found for product {product_id}, skipping")
continue
# Prepare sales data with required columns
sales_df['date'] = pd.to_datetime(sales_df['date'])
sales_df['quantity'] = sales_df['quantity'].astype(float)
sales_df['day_of_week'] = sales_df['date'].dt.dayofweek
# Run demand analysis
results = await orchestrator.analyze_and_post_demand_insights(
tenant_id=tenant_id,
inventory_product_id=product_id,
sales_data=sales_df,
forecast_horizon_days=request_data.forecast_horizon_days,
min_history_days=request_data.lookback_days
)
# Track results
total_insights_generated += results['insights_generated']
total_insights_posted += results['insights_posted']
insights_by_product[product_id] = {
'product_name': product_name,
'insights_posted': results['insights_posted'],
'trend_analysis': results.get('trend_analysis', {})
}
logger.info(
f"Product {product_id} demand analysis complete",
insights_posted=results['insights_posted']
)
except Exception as e:
error_msg = f"Error analyzing product {product_id}: {str(e)}"
logger.error(error_msg, exc_info=True)
errors.append(error_msg)
# Close orchestrator
await orchestrator.close()
# Build response
response = DemandAnalysisResponse(
success=total_insights_posted > 0,
message=f"Successfully generated {total_insights_posted} insights from {len(products)} products",
tenant_id=tenant_id,
products_analyzed=len(products),
total_insights_generated=total_insights_generated,
total_insights_posted=total_insights_posted,
insights_by_product=insights_by_product,
errors=errors
)
logger.info(
"ML insights demand analysis complete",
tenant_id=tenant_id,
total_insights=total_insights_posted
)
return response
except Exception as e:
logger.error(
"ML insights demand analysis failed",
tenant_id=tenant_id,
error=str(e),
exc_info=True
)
raise HTTPException(
status_code=500,
detail=f"Demand analysis failed: {str(e)}"
)
@router.post("/analyze-business-rules", response_model=BusinessRulesAnalysisResponse)
async def trigger_business_rules_analysis(
tenant_id: str,
request_data: BusinessRulesAnalysisRequest,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
Trigger business rules optimization analysis from historical sales data.
This endpoint:
1. Fetches historical sales data for specified products
2. Runs the BusinessRulesInsightsOrchestrator to analyze rules
3. Generates insights about business rule optimization
4. Posts insights to AI Insights Service
5. Publishes events to RabbitMQ
Args:
tenant_id: Tenant UUID
request_data: Business rules analysis parameters
request: FastAPI request object to access app state
db: Database session
Returns:
BusinessRulesAnalysisResponse with analysis results
"""
logger.info(
"ML insights business rules analysis requested",
tenant_id=tenant_id,
product_ids=request_data.product_ids,
lookback_days=request_data.lookback_days
)
try:
# Import ML orchestrator and clients
from app.ml.business_rules_insights_orchestrator import BusinessRulesInsightsOrchestrator
from shared.clients.sales_client import SalesServiceClient
from shared.clients.inventory_client import InventoryServiceClient
from app.core.config import settings
# Get event publisher from app state
event_publisher = getattr(request.app.state, 'event_publisher', None)
# Initialize orchestrator and clients
orchestrator = BusinessRulesInsightsOrchestrator(event_publisher=event_publisher)
inventory_client = InventoryServiceClient(settings)
# Get products to analyze from inventory service via API
if request_data.product_ids:
# Fetch specific products
products = []
for product_id in request_data.product_ids:
product = await inventory_client.get_ingredient_by_id(
ingredient_id=UUID(product_id),
tenant_id=tenant_id
)
if product:
products.append(product)
else:
# Fetch all products for tenant (limit to 10)
all_products = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
products = all_products[:10] # Limit to prevent timeout
if not products:
return BusinessRulesAnalysisResponse(
success=False,
message="No products found for analysis",
tenant_id=tenant_id,
products_analyzed=0,
total_insights_generated=0,
total_insights_posted=0,
insights_by_product={},
errors=["No products found"]
)
# Initialize sales client to fetch historical data
sales_client = SalesServiceClient(config=settings, calling_service_name="forecasting")
# Calculate date range
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=request_data.lookback_days)
# Process each product
total_insights_generated = 0
total_insights_posted = 0
insights_by_product = {}
errors = []
for product in products:
try:
product_id = str(product['id'])
product_name = product.get('name', 'Unknown')
logger.info(f"Analyzing product {product_name} ({product_id})")
# Fetch sales data for product
sales_data = await sales_client.get_sales_data(
tenant_id=tenant_id,
product_id=product_id,
start_date=start_date.strftime('%Y-%m-%d'),
end_date=end_date.strftime('%Y-%m-%d')
)
if not sales_data:
logger.warning(f"No sales data for product {product_id}")
continue
# Convert to DataFrame
sales_df = pd.DataFrame(sales_data)
if len(sales_df) < request_data.min_samples:
logger.warning(
f"Insufficient data for product {product_id}: "
f"{len(sales_df)} samples < {request_data.min_samples} required"
)
continue
# Check what columns are available and map to expected format
logger.debug(f"Sales data columns for product {product_id}: {sales_df.columns.tolist()}")
# Map common field names to 'quantity' and 'date'
if 'quantity' not in sales_df.columns:
if 'total_quantity' in sales_df.columns:
sales_df['quantity'] = sales_df['total_quantity']
elif 'amount' in sales_df.columns:
sales_df['quantity'] = sales_df['amount']
else:
logger.warning(f"No quantity field found for product {product_id}, skipping")
continue
if 'date' not in sales_df.columns:
if 'sale_date' in sales_df.columns:
sales_df['date'] = sales_df['sale_date']
else:
logger.warning(f"No date field found for product {product_id}, skipping")
continue
# Prepare sales data with required columns
sales_df['date'] = pd.to_datetime(sales_df['date'])
sales_df['quantity'] = sales_df['quantity'].astype(float)
sales_df['day_of_week'] = sales_df['date'].dt.dayofweek
# Run business rules analysis
results = await orchestrator.analyze_and_post_business_rules_insights(
tenant_id=tenant_id,
inventory_product_id=product_id,
sales_data=sales_df,
min_samples=request_data.min_samples
)
# Track results
total_insights_generated += results['insights_generated']
total_insights_posted += results['insights_posted']
insights_by_product[product_id] = {
'product_name': product_name,
'insights_posted': results['insights_posted'],
'rules_learned': len(results.get('rules', {}))
}
logger.info(
f"Product {product_id} business rules analysis complete",
insights_posted=results['insights_posted']
)
except Exception as e:
error_msg = f"Error analyzing product {product_id}: {str(e)}"
logger.error(error_msg, exc_info=True)
errors.append(error_msg)
# Close orchestrator
await orchestrator.close()
# Build response
response = BusinessRulesAnalysisResponse(
success=total_insights_posted > 0,
message=f"Successfully generated {total_insights_posted} insights from {len(products)} products",
tenant_id=tenant_id,
products_analyzed=len(products),
total_insights_generated=total_insights_generated,
total_insights_posted=total_insights_posted,
insights_by_product=insights_by_product,
errors=errors
)
logger.info(
"ML insights business rules analysis complete",
tenant_id=tenant_id,
total_insights=total_insights_posted
)
return response
except Exception as e:
logger.error(
"ML insights business rules analysis failed",
tenant_id=tenant_id,
error=str(e),
exc_info=True
)
raise HTTPException(
status_code=500,
detail=f"Business rules analysis failed: {str(e)}"
)
@router.get("/health")
async def ml_insights_health():
"""Health check for ML insights endpoints"""
return {
"status": "healthy",
"service": "forecasting-ml-insights",
"endpoints": [
"POST /ml/insights/generate-rules",
"POST /ml/insights/analyze-demand",
"POST /ml/insights/analyze-business-rules"
]
}
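# Illustrative helper (not called by the endpoints above): the per-product loops
# expect sales records that can be normalized to 'date' and 'quantity' columns.
# A minimal sketch of that normalization, mirroring the field mapping used in
# the endpoints in this module:
def _example_normalize_sales_records(records: list) -> pd.DataFrame:
    df = pd.DataFrame(records)
    # Fall back to alternate field names returned by the sales service
    if 'quantity' not in df.columns and 'total_quantity' in df.columns:
        df['quantity'] = df['total_quantity']
    if 'date' not in df.columns and 'sale_date' in df.columns:
        df['date'] = df['sale_date']
    df['date'] = pd.to_datetime(df['date'])
    df['quantity'] = df['quantity'].astype(float)
    df['day_of_week'] = df['date'].dt.dayofweek
    return df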
# ================================================================
# INTERNAL ML INSIGHTS ENDPOINTS (for demo session service)
# ================================================================
internal_router = APIRouter(tags=["Internal ML"])
@internal_router.post("/api/v1/tenants/{tenant_id}/forecasting/internal/ml/generate-demand-insights")
async def trigger_demand_insights_internal(
tenant_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
Internal endpoint to trigger demand forecasting insights for a tenant.
This endpoint is called by the demo-session service after cloning to generate
AI insights from the seeded forecast data.
Args:
tenant_id: Tenant UUID
request: FastAPI request object to access app state
db: Database session
Returns:
Dict with insights generation results
"""
logger.info(
"Internal demand insights generation triggered",
tenant_id=tenant_id
)
try:
# Import ML orchestrator and clients
from app.ml.demand_insights_orchestrator import DemandInsightsOrchestrator
from shared.clients.sales_client import SalesServiceClient
from shared.clients.inventory_client import InventoryServiceClient
from app.core.config import settings
# Get event publisher from app state
event_publisher = getattr(request.app.state, 'event_publisher', None)
# Initialize orchestrator and clients
orchestrator = DemandInsightsOrchestrator(event_publisher=event_publisher)
inventory_client = InventoryServiceClient(settings)
# Get all products for tenant (limit to 10 for performance)
all_products = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
products = all_products[:10] if all_products else []
logger.info(
"Retrieved products from inventory service",
tenant_id=tenant_id,
product_count=len(products)
)
if not products:
return {
"success": False,
"message": "No products found for analysis",
"tenant_id": tenant_id,
"products_analyzed": 0,
"insights_posted": 0
}
# Initialize sales client
sales_client = SalesServiceClient(config=settings, calling_service_name="forecasting")
# Calculate date range (90 days lookback)
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=90)
# Process each product
total_insights_generated = 0
total_insights_posted = 0
for product in products:
try:
product_id = str(product['id'])
product_name = product.get('name', 'Unknown Product')
logger.debug(
"Analyzing demand for product",
tenant_id=tenant_id,
product_id=product_id,
product_name=product_name
)
# Fetch historical sales data
sales_data_raw = await sales_client.get_sales_data(
tenant_id=tenant_id,
product_id=product_id,
start_date=start_date.strftime('%Y-%m-%d'),
end_date=end_date.strftime('%Y-%m-%d')
)
if not sales_data_raw or len(sales_data_raw) < 10:
logger.debug(
"Insufficient sales data for product",
product_id=product_id,
sales_records=len(sales_data_raw) if sales_data_raw else 0
)
continue
# Convert to DataFrame
sales_df = pd.DataFrame(sales_data_raw)
# Map field names to expected format
if 'quantity' not in sales_df.columns:
if 'total_quantity' in sales_df.columns:
sales_df['quantity'] = sales_df['total_quantity']
elif 'quantity_sold' in sales_df.columns:
sales_df['quantity'] = sales_df['quantity_sold']
else:
logger.warning(
"No quantity field found for product",
product_id=product_id
)
continue
if 'date' not in sales_df.columns:
if 'sale_date' in sales_df.columns:
sales_df['date'] = sales_df['sale_date']
else:
logger.warning(
"No date field found for product",
product_id=product_id
)
continue
# Run demand insights orchestrator
results = await orchestrator.analyze_and_post_demand_insights(
tenant_id=tenant_id,
inventory_product_id=product_id,
sales_data=sales_df,
forecast_horizon_days=30,
min_history_days=90
)
total_insights_generated += results['insights_generated']
total_insights_posted += results['insights_posted']
logger.info(
"Demand insights generated for product",
tenant_id=tenant_id,
product_id=product_id,
insights_posted=results['insights_posted']
)
except Exception as e:
logger.warning(
"Failed to analyze product demand (non-fatal)",
tenant_id=tenant_id,
product_id=product_id,
error=str(e)
)
continue
logger.info(
"Internal demand insights generation complete",
tenant_id=tenant_id,
products_analyzed=len(products),
insights_generated=total_insights_generated,
insights_posted=total_insights_posted
)
return {
"success": True,
"message": f"Generated {total_insights_posted} demand forecasting insights",
"tenant_id": tenant_id,
"products_analyzed": len(products),
"insights_posted": total_insights_posted
}
except Exception as e:
logger.error(
"Internal demand insights generation failed",
tenant_id=tenant_id,
error=str(e),
exc_info=True
)
return {
"success": False,
"message": f"Demand insights generation failed: {str(e)}",
"tenant_id": tenant_id,
"products_analyzed": 0,
"insights_posted": 0
}

View File

@@ -0,0 +1,287 @@
# ================================================================
# services/forecasting/app/api/performance_monitoring.py
# ================================================================
"""
Performance Monitoring API - Track and analyze forecast accuracy over time
"""
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
from typing import Dict, Any
from uuid import UUID
import structlog
from pydantic import BaseModel, Field
from app.services.performance_monitoring_service import PerformanceMonitoringService
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role
from shared.routing import RouteBuilder
from app.core.database import get_db
from sqlalchemy.ext.asyncio import AsyncSession
route_builder = RouteBuilder('forecasting')
router = APIRouter(tags=["performance-monitoring"])
logger = structlog.get_logger()
# ================================================================
# Request/Response Schemas
# ================================================================
class AccuracySummaryRequest(BaseModel):
"""Request model for accuracy summary"""
days: int = Field(default=30, ge=1, le=365, description="Analysis period in days")
class DegradationAnalysisRequest(BaseModel):
"""Request model for degradation analysis"""
lookback_days: int = Field(default=30, ge=7, le=365, description="Days to analyze")
class ModelAgeCheckRequest(BaseModel):
"""Request model for model age check"""
max_age_days: int = Field(default=30, ge=1, le=90, description="Max acceptable model age")
class PerformanceReportRequest(BaseModel):
"""Request model for comprehensive performance report"""
days: int = Field(default=30, ge=1, le=365, description="Analysis period in days")
# ================================================================
# Endpoints
# ================================================================
@router.get(
route_builder.build_base_route("monitoring/accuracy-summary"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_accuracy_summary(
tenant_id: UUID = Path(..., description="Tenant ID"),
days: int = Query(30, ge=1, le=365, description="Analysis period in days"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get forecast accuracy summary for recent period
Returns overall metrics, validation coverage, and health status.
"""
try:
logger.info(
"Getting accuracy summary",
tenant_id=tenant_id,
days=days,
user_id=current_user.get("user_id")
)
service = PerformanceMonitoringService(db)
summary = await service.get_accuracy_summary(
tenant_id=tenant_id,
days=days
)
return summary
except Exception as e:
logger.error(
"Failed to get accuracy summary",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get accuracy summary: {str(e)}"
)
@router.get(
route_builder.build_base_route("monitoring/degradation-analysis"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def analyze_performance_degradation(
tenant_id: UUID = Path(..., description="Tenant ID"),
lookback_days: int = Query(30, ge=7, le=365, description="Days to analyze"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Detect if forecast performance is degrading over time
Compares first half vs second half of period and identifies poor performers.
"""
try:
logger.info(
"Analyzing performance degradation",
tenant_id=tenant_id,
lookback_days=lookback_days,
user_id=current_user.get("user_id")
)
service = PerformanceMonitoringService(db)
analysis = await service.detect_performance_degradation(
tenant_id=tenant_id,
lookback_days=lookback_days
)
return analysis
except Exception as e:
logger.error(
"Failed to analyze degradation",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to analyze degradation: {str(e)}"
)
@router.get(
route_builder.build_base_route("monitoring/model-age"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def check_model_age(
tenant_id: UUID = Path(..., description="Tenant ID"),
max_age_days: int = Query(30, ge=1, le=90, description="Max acceptable model age"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Check if models are outdated and need retraining
Returns models in use and identifies those needing updates.
"""
try:
logger.info(
"Checking model age",
tenant_id=tenant_id,
max_age_days=max_age_days,
user_id=current_user.get("user_id")
)
service = PerformanceMonitoringService(db)
analysis = await service.check_model_age(
tenant_id=tenant_id,
max_age_days=max_age_days
)
return analysis
except Exception as e:
logger.error(
"Failed to check model age",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to check model age: {str(e)}"
)
@router.post(
route_builder.build_base_route("monitoring/performance-report"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def generate_performance_report(
request: PerformanceReportRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Generate comprehensive performance report
Combines accuracy summary, degradation analysis, and model age check
with actionable recommendations.
"""
try:
logger.info(
"Generating performance report",
tenant_id=tenant_id,
days=request.days,
user_id=current_user.get("user_id")
)
service = PerformanceMonitoringService(db)
report = await service.generate_performance_report(
tenant_id=tenant_id,
days=request.days
)
return report
except Exception as e:
logger.error(
"Failed to generate performance report",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to generate performance report: {str(e)}"
)
@router.get(
route_builder.build_base_route("monitoring/health"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_health_status(
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get quick health status for dashboards
Returns simplified health metrics for UI display.
"""
try:
service = PerformanceMonitoringService(db)
# Get 7-day summary for quick health check
summary = await service.get_accuracy_summary(
tenant_id=tenant_id,
days=7
)
if summary.get("status") == "no_data":
return {
"status": "unknown",
"message": "No recent validation data available",
"health_status": "unknown"
}
return {
"status": "ok",
"health_status": summary.get("health_status"),
"current_mape": summary["average_metrics"].get("mape"),
"accuracy_percentage": summary["average_metrics"].get("accuracy_percentage"),
"validation_coverage": summary.get("coverage_percentage"),
"last_7_days": {
"validation_runs": summary.get("validation_runs"),
"forecasts_evaluated": summary.get("total_forecasts_evaluated")
}
}
except Exception as e:
logger.error(
"Failed to get health status",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get health status: {str(e)}"
)
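# Hypothetical illustration only: the real thresholds live in
# PerformanceMonitoringService (not shown here). One plausible mapping from a
# 7-day MAPE value to the health_status string surfaced by the endpoint above:
def _example_health_status_from_mape(mape: float) -> str:
    if mape <= 10.0:
        return "excellent"
    if mape <= 20.0:
        return "good"
    if mape <= 35.0:
        return "fair"
    return "poor"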

View File

@@ -0,0 +1,297 @@
# ================================================================
# services/forecasting/app/api/retraining.py
# ================================================================
"""
Retraining API - Trigger and manage model retraining based on performance
"""
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
from typing import Dict, Any, List
from uuid import UUID
import structlog
from pydantic import BaseModel, Field
from app.services.retraining_trigger_service import RetrainingTriggerService
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role
from shared.routing import RouteBuilder
from app.core.database import get_db
from sqlalchemy.ext.asyncio import AsyncSession
route_builder = RouteBuilder('forecasting')
router = APIRouter(tags=["retraining"])
logger = structlog.get_logger()
# ================================================================
# Request/Response Schemas
# ================================================================
class EvaluateRetrainingRequest(BaseModel):
"""Request model for retraining evaluation"""
auto_trigger: bool = Field(
default=False,
description="Automatically trigger retraining for poor performers"
)
class TriggerProductRetrainingRequest(BaseModel):
"""Request model for single product retraining"""
inventory_product_id: UUID = Field(..., description="Product to retrain")
reason: str = Field(..., description="Reason for retraining")
priority: str = Field(
default="normal",
description="Priority level: low, normal, high"
)
class TriggerBulkRetrainingRequest(BaseModel):
"""Request model for bulk retraining"""
product_ids: List[UUID] = Field(..., description="List of products to retrain")
reason: str = Field(
default="Bulk retraining requested",
description="Reason for bulk retraining"
)
class ScheduledRetrainingCheckRequest(BaseModel):
"""Request model for scheduled retraining check"""
max_model_age_days: int = Field(
default=30,
ge=1,
le=90,
description="Maximum acceptable model age"
)
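# Example request bodies for the endpoints below (UUIDs are placeholders; the
# concrete paths are produced by RouteBuilder at import time):
#
#   POST .../retraining/trigger-product
#   {"inventory_product_id": "<product-uuid>", "reason": "MAPE above threshold", "priority": "high"}
#
#   POST .../retraining/trigger-bulk
#   {"product_ids": ["<product-uuid-1>", "<product-uuid-2>"], "reason": "Quarterly model refresh"}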
# ================================================================
# Endpoints
# ================================================================
@router.post(
route_builder.build_base_route("retraining/evaluate"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def evaluate_retraining_needs(
request: EvaluateRetrainingRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Evaluate performance and optionally trigger retraining
Analyzes 30-day performance and identifies products needing retraining.
If auto_trigger=true, automatically triggers retraining for poor performers.
"""
try:
logger.info(
"Evaluating retraining needs",
tenant_id=tenant_id,
auto_trigger=request.auto_trigger,
user_id=current_user.get("user_id")
)
service = RetrainingTriggerService(db)
result = await service.evaluate_and_trigger_retraining(
tenant_id=tenant_id,
auto_trigger=request.auto_trigger
)
return result
except Exception as e:
logger.error(
"Failed to evaluate retraining needs",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to evaluate retraining: {str(e)}"
)
@router.post(
route_builder.build_base_route("retraining/trigger-product"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def trigger_product_retraining(
request: TriggerProductRetrainingRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Trigger retraining for a specific product
Manually trigger model retraining for a single product.
"""
try:
logger.info(
"Triggering product retraining",
tenant_id=tenant_id,
product_id=request.inventory_product_id,
reason=request.reason,
user_id=current_user.get("user_id")
)
service = RetrainingTriggerService(db)
result = await service._trigger_product_retraining(
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
reason=request.reason,
priority=request.priority
)
return result
except Exception as e:
logger.error(
"Failed to trigger product retraining",
tenant_id=tenant_id,
product_id=request.inventory_product_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to trigger retraining: {str(e)}"
)
@router.post(
route_builder.build_base_route("retraining/trigger-bulk"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def trigger_bulk_retraining(
request: TriggerBulkRetrainingRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Trigger retraining for multiple products
Bulk retraining operation for multiple products at once.
"""
try:
logger.info(
"Triggering bulk retraining",
tenant_id=tenant_id,
product_count=len(request.product_ids),
reason=request.reason,
user_id=current_user.get("user_id")
)
service = RetrainingTriggerService(db)
result = await service.trigger_bulk_retraining(
tenant_id=tenant_id,
product_ids=request.product_ids,
reason=request.reason
)
return result
except Exception as e:
logger.error(
"Failed to trigger bulk retraining",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to trigger bulk retraining: {str(e)}"
)
@router.get(
route_builder.build_base_route("retraining/recommendations"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_retraining_recommendations(
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get retraining recommendations without triggering
Returns recommendations for manual review and decision-making.
"""
try:
logger.info(
"Getting retraining recommendations",
tenant_id=tenant_id,
user_id=current_user.get("user_id")
)
service = RetrainingTriggerService(db)
recommendations = await service.get_retraining_recommendations(
tenant_id=tenant_id
)
return recommendations
except Exception as e:
logger.error(
"Failed to get recommendations",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get recommendations: {str(e)}"
)
@router.post(
route_builder.build_base_route("retraining/check-scheduled"),
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def check_scheduled_retraining(
request: ScheduledRetrainingCheckRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Check for models needing scheduled retraining based on age
Identifies models that haven't been updated in max_model_age_days.
"""
try:
logger.info(
"Checking scheduled retraining needs",
tenant_id=tenant_id,
max_model_age_days=request.max_model_age_days,
user_id=current_user.get("user_id")
)
service = RetrainingTriggerService(db)
result = await service.check_and_trigger_scheduled_retraining(
tenant_id=tenant_id,
max_model_age_days=request.max_model_age_days
)
return result
except Exception as e:
logger.error(
"Failed to check scheduled retraining",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to check scheduled retraining: {str(e)}"
)

View File

@@ -0,0 +1,455 @@
"""
Scenario Simulation Operations API - PROFESSIONAL/ENTERPRISE ONLY
Business operations for "what-if" scenario testing and strategic planning
"""
import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Path, Request
from typing import List, Dict, Any
from datetime import date, datetime, timedelta, timezone
import uuid
from app.schemas.forecasts import (
ScenarioSimulationRequest,
ScenarioSimulationResponse,
ScenarioComparisonRequest,
ScenarioComparisonResponse,
ScenarioType,
ScenarioImpact,
ForecastResponse,
ForecastRequest
)
from app.services.forecasting_service import EnhancedForecastingService
from shared.auth.decorators import get_current_user_dep
from shared.database.base import create_database_manager
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings
from shared.routing import RouteBuilder
from shared.auth.access_control import require_user_role, analytics_tier_required
from shared.clients.tenant_client import TenantServiceClient
route_builder = RouteBuilder('forecasting')
logger = structlog.get_logger()
router = APIRouter(tags=["scenario-simulation"])
def get_enhanced_forecasting_service():
"""Dependency injection for EnhancedForecastingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
return EnhancedForecastingService(database_manager)
@router.post(
route_builder.build_analytics_route("scenario-simulation"),
response_model=ScenarioSimulationResponse
)
@require_user_role(['admin', 'owner'])
@analytics_tier_required
@track_execution_time("scenario_simulation_duration_seconds", "forecasting-service")
async def simulate_scenario(
request: ScenarioSimulationRequest,
tenant_id: str = Path(..., description="Tenant ID"),
request_obj: Request = None,
current_user: dict = Depends(get_current_user_dep),
forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
"""
Run a "what-if" scenario simulation on forecasts
This endpoint allows users to test how different scenarios might impact demand:
- Weather events (heatwaves, cold snaps, rain)
- Competition (new competitors opening nearby)
- Events (festivals, concerts, sports events)
- Pricing changes
- Promotions
- Supply disruptions
    **PROFESSIONAL/ENTERPRISE TIER ONLY - Admin+ role required**
"""
metrics = get_metrics_collector(request_obj)
start_time = datetime.now(timezone.utc)
try:
logger.info("Starting scenario simulation",
tenant_id=tenant_id,
scenario_name=request.scenario_name,
scenario_type=request.scenario_type.value,
products=len(request.inventory_product_ids))
if metrics:
metrics.increment_counter(f"scenario_simulations_total")
metrics.increment_counter(f"scenario_simulations_{request.scenario_type.value}_total")
# Generate simulation ID
simulation_id = str(uuid.uuid4())
end_date = request.start_date + timedelta(days=request.duration_days - 1)
# Step 1: Generate baseline forecasts
baseline_forecasts = []
if request.include_baseline:
logger.info("Generating baseline forecasts", tenant_id=tenant_id)
# Get tenant location (city) from tenant service
location = "default"
try:
tenant_client = TenantServiceClient(settings)
tenant_info = await tenant_client.get_tenant(tenant_id)
if tenant_info and tenant_info.get('city'):
location = tenant_info['city']
logger.info("Using tenant location for forecasts", tenant_id=tenant_id, location=location)
except Exception as e:
logger.warning("Failed to get tenant location, using default", error=str(e), tenant_id=tenant_id)
for product_id in request.inventory_product_ids:
forecast_request = ForecastRequest(
inventory_product_id=product_id,
forecast_date=request.start_date,
forecast_days=request.duration_days,
location=location
)
multi_day_result = await forecasting_service.generate_multi_day_forecast(
tenant_id=tenant_id,
request=forecast_request
)
# Convert forecast dictionaries to ForecastResponse objects
forecast_dicts = multi_day_result.get("forecasts", [])
for forecast_dict in forecast_dicts:
if isinstance(forecast_dict, dict):
baseline_forecasts.append(ForecastResponse(**forecast_dict))
else:
baseline_forecasts.append(forecast_dict)
# Step 2: Apply scenario adjustments to generate scenario forecasts
scenario_forecasts = await _apply_scenario_adjustments(
tenant_id=tenant_id,
request=request,
baseline_forecasts=baseline_forecasts if request.include_baseline else [],
forecasting_service=forecasting_service
)
# Step 3: Calculate impacts
product_impacts = _calculate_product_impacts(
baseline_forecasts,
scenario_forecasts,
request.inventory_product_ids
)
# Step 4: Calculate totals
total_baseline_demand = sum(f.predicted_demand for f in baseline_forecasts) if baseline_forecasts else 0
total_scenario_demand = sum(f.predicted_demand for f in scenario_forecasts)
overall_impact_percent = (
((total_scenario_demand - total_baseline_demand) / total_baseline_demand * 100)
if total_baseline_demand > 0 else 0
)
# Step 5: Generate insights and recommendations
insights, recommendations, risk_level = _generate_insights(
request.scenario_type,
request,
product_impacts,
overall_impact_percent
)
# Calculate processing time
processing_time_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
if metrics:
metrics.increment_counter("scenario_simulations_success_total")
metrics.observe_histogram("scenario_simulation_processing_time_ms", processing_time_ms)
logger.info("Scenario simulation completed successfully",
tenant_id=tenant_id,
simulation_id=simulation_id,
overall_impact=f"{overall_impact_percent:.2f}%",
processing_time_ms=processing_time_ms)
return ScenarioSimulationResponse(
id=simulation_id,
tenant_id=tenant_id,
scenario_name=request.scenario_name,
scenario_type=request.scenario_type,
start_date=request.start_date,
end_date=end_date,
duration_days=request.duration_days,
baseline_forecasts=baseline_forecasts if request.include_baseline else None,
scenario_forecasts=scenario_forecasts,
total_baseline_demand=total_baseline_demand,
total_scenario_demand=total_scenario_demand,
overall_impact_percent=overall_impact_percent,
product_impacts=product_impacts,
insights=insights,
recommendations=recommendations,
risk_level=risk_level,
created_at=datetime.now(timezone.utc),
processing_time_ms=processing_time_ms
)
except ValueError as e:
if metrics:
metrics.increment_counter("scenario_simulation_validation_errors_total")
logger.error("Scenario simulation validation error", error=str(e), tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("scenario_simulations_errors_total")
logger.error("Scenario simulation failed", error=str(e), tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Scenario simulation failed"
)
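# Illustrative request body (the authoritative schema is ScenarioSimulationRequest
# in app.schemas.forecasts; only fields read in this module are shown, and values
# and enum strings are placeholders):
#
#   {
#       "scenario_name": "July heatwave",
#       "scenario_type": "weather",
#       "start_date": "2026-07-01",
#       "duration_days": 7,
#       "inventory_product_ids": ["<product-uuid>"],
#       "include_baseline": true,
#       "weather_params": {"temperature_change": 12, "precipitation_change": 0}
#   }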
async def _apply_scenario_adjustments(
tenant_id: str,
request: ScenarioSimulationRequest,
baseline_forecasts: List[ForecastResponse],
forecasting_service: EnhancedForecastingService
) -> List[ForecastResponse]:
"""
Apply scenario-specific adjustments to forecasts
"""
scenario_forecasts = []
    # If no baseline was produced above, generate fresh forecasts for every product
    if not baseline_forecasts:
        for product_id in request.inventory_product_ids:
            forecast_request = ForecastRequest(
                inventory_product_id=product_id,
                forecast_date=request.start_date,
                forecast_days=request.duration_days,
                location="default"
            )
            multi_day_result = await forecasting_service.generate_multi_day_forecast(
                tenant_id=tenant_id,
                request=forecast_request
            )
            # Accumulate forecasts across products (do not overwrite) and
            # normalize raw dictionaries to ForecastResponse objects
            for forecast_dict in multi_day_result.get("forecasts", []):
                if isinstance(forecast_dict, dict):
                    baseline_forecasts.append(ForecastResponse(**forecast_dict))
                else:
                    baseline_forecasts.append(forecast_dict)
# Apply multipliers based on scenario type
for forecast in baseline_forecasts:
adjusted_forecast = forecast.copy()
multiplier = _get_scenario_multiplier(request)
# Adjust predicted demand
adjusted_forecast.predicted_demand *= multiplier
adjusted_forecast.confidence_lower *= multiplier
adjusted_forecast.confidence_upper *= multiplier
scenario_forecasts.append(adjusted_forecast)
return scenario_forecasts
def _get_scenario_multiplier(request: ScenarioSimulationRequest) -> float:
"""
Calculate demand multiplier based on scenario type and parameters
"""
if request.scenario_type == ScenarioType.WEATHER:
if request.weather_params:
# Heatwave increases demand for cold items, decreases for hot items
if request.weather_params.temperature_change and request.weather_params.temperature_change > 10:
return 1.25 # 25% increase during heatwave
elif request.weather_params.temperature_change and request.weather_params.temperature_change < -10:
return 0.85 # 15% decrease during cold snap
elif request.weather_params.precipitation_change and request.weather_params.precipitation_change > 10:
return 0.90 # 10% decrease during heavy rain
return 1.0
elif request.scenario_type == ScenarioType.COMPETITION:
if request.competition_params:
# New competition reduces demand based on market share loss
return 1.0 - request.competition_params.estimated_market_share_loss
return 0.85 # Default 15% reduction
elif request.scenario_type == ScenarioType.EVENT:
if request.event_params:
# Events increase demand based on attendance and proximity
if request.event_params.distance_km < 1.0:
return 1.5 # 50% increase for very close events
elif request.event_params.distance_km < 5.0:
return 1.2 # 20% increase for nearby events
return 1.15 # Default 15% increase
elif request.scenario_type == ScenarioType.PRICING:
if request.pricing_params:
# Price elasticity: typically -0.5 to -2.0
# 10% price increase = 5-20% demand decrease
elasticity = -1.0 # Average elasticity
return 1.0 + (request.pricing_params.price_change_percent / 100) * elasticity
return 1.0
elif request.scenario_type == ScenarioType.PROMOTION:
if request.promotion_params:
# Promotions increase traffic and conversion
traffic_boost = 1.0 + request.promotion_params.expected_traffic_increase
discount_boost = 1.0 + (request.promotion_params.discount_percent / 100) * 0.5
return traffic_boost * discount_boost
return 1.3 # Default 30% increase
elif request.scenario_type == ScenarioType.SUPPLY_DISRUPTION:
return 0.6 # 40% reduction due to limited supply
elif request.scenario_type == ScenarioType.CUSTOM:
if request.custom_multipliers and 'demand' in request.custom_multipliers:
return request.custom_multipliers['demand']
return 1.0
return 1.0
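# Worked examples of the multiplier logic above (following the branches in
# _get_scenario_multiplier; expected_traffic_increase is assumed to be a fraction):
#   - Pricing: a +10% price change with elasticity -1.0 gives
#     1.0 + (10 / 100) * -1.0 = 0.90, i.e. roughly a 10% demand decrease.
#   - Event: an event 0.5 km away returns 1.5 (+50% demand); 3 km away returns 1.2.
#   - Promotion: a 20% discount with a 0.30 expected traffic increase gives
#     (1.0 + 0.30) * (1.0 + 0.20 * 0.5) = 1.43, i.e. roughly +43% demand.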
def _calculate_product_impacts(
baseline_forecasts: List[ForecastResponse],
scenario_forecasts: List[ForecastResponse],
product_ids: List[str]
) -> List[ScenarioImpact]:
"""
Calculate per-product impact of the scenario
"""
impacts = []
for product_id in product_ids:
baseline_total = sum(
f.predicted_demand for f in baseline_forecasts
if f.inventory_product_id == product_id
)
scenario_total = sum(
f.predicted_demand for f in scenario_forecasts
if f.inventory_product_id == product_id
)
if baseline_total > 0:
change_percent = ((scenario_total - baseline_total) / baseline_total) * 100
else:
change_percent = 0
# Get confidence ranges
scenario_product_forecasts = [
f for f in scenario_forecasts if f.inventory_product_id == product_id
]
avg_lower = sum(f.confidence_lower for f in scenario_product_forecasts) / len(scenario_product_forecasts) if scenario_product_forecasts else 0
avg_upper = sum(f.confidence_upper for f in scenario_product_forecasts) / len(scenario_product_forecasts) if scenario_product_forecasts else 0
impacts.append(ScenarioImpact(
inventory_product_id=product_id,
baseline_demand=baseline_total,
simulated_demand=scenario_total,
demand_change_percent=change_percent,
confidence_range=(avg_lower, avg_upper),
impact_factors={"primary_driver": "scenario_adjustment"}
))
return impacts
def _generate_insights(
scenario_type: ScenarioType,
request: ScenarioSimulationRequest,
impacts: List[ScenarioImpact],
overall_impact: float
) -> tuple[List[str], List[str], str]:
"""
Generate AI-powered insights and recommendations
"""
insights = []
recommendations = []
risk_level = "low"
# Determine risk level
if abs(overall_impact) > 30:
risk_level = "high"
elif abs(overall_impact) > 15:
risk_level = "medium"
# Generate scenario-specific insights
if scenario_type == ScenarioType.WEATHER:
if request.weather_params:
if request.weather_params.temperature_change and request.weather_params.temperature_change > 10:
insights.append(f"Heatwave of +{request.weather_params.temperature_change}°C expected to increase demand by {overall_impact:.1f}%")
recommendations.append("Increase inventory of cold beverages and refrigerated items")
recommendations.append("Extend operating hours to capture increased evening traffic")
elif request.weather_params.temperature_change and request.weather_params.temperature_change < -10:
insights.append(f"Cold snap of {request.weather_params.temperature_change}°C expected to decrease demand by {abs(overall_impact):.1f}%")
recommendations.append("Increase production of warm comfort foods")
recommendations.append("Reduce inventory of cold items")
elif scenario_type == ScenarioType.COMPETITION:
insights.append(f"New competitor expected to reduce demand by {abs(overall_impact):.1f}%")
recommendations.append("Consider launching loyalty program to retain customers")
recommendations.append("Differentiate with unique product offerings")
recommendations.append("Focus on customer service excellence")
elif scenario_type == ScenarioType.EVENT:
insights.append(f"Local event expected to increase demand by {overall_impact:.1f}%")
recommendations.append("Increase staffing for the event period")
recommendations.append("Stock additional inventory of popular items")
recommendations.append("Consider event-specific promotions")
elif scenario_type == ScenarioType.PRICING:
if overall_impact < 0:
insights.append(f"Price increase expected to reduce demand by {abs(overall_impact):.1f}%")
recommendations.append("Consider smaller price increases")
recommendations.append("Communicate value proposition to customers")
else:
insights.append(f"Price decrease expected to increase demand by {overall_impact:.1f}%")
recommendations.append("Ensure adequate inventory to meet increased demand")
elif scenario_type == ScenarioType.PROMOTION:
insights.append(f"Promotion expected to increase demand by {overall_impact:.1f}%")
recommendations.append("Stock additional inventory before promotion starts")
recommendations.append("Increase staffing during promotion period")
recommendations.append("Prepare marketing materials and signage")
# Add product-specific insights
high_impact_products = [
impact for impact in impacts
if abs(impact.demand_change_percent) > 20
]
if high_impact_products:
insights.append(f"{len(high_impact_products)} products show significant impact (>20% change)")
# Add general recommendation
if risk_level == "high":
recommendations.append("⚠️ High-impact scenario - review and adjust operational plans immediately")
elif risk_level == "medium":
recommendations.append("Monitor situation closely and prepare contingency plans")
return insights, recommendations, risk_level
@router.post(
route_builder.build_analytics_route("scenario-comparison"),
response_model=ScenarioComparisonResponse
)
@require_user_role(['viewer', 'member', 'admin', 'owner'])
@analytics_tier_required
async def compare_scenarios(
request: ScenarioComparisonRequest,
tenant_id: str = Path(..., description="Tenant ID")
):
"""
Compare multiple scenario simulations
**PROFESSIONAL/ENTERPRISE ONLY**
**STATUS**: Not yet implemented - requires scenario persistence layer
**Future implementation would**:
1. Retrieve saved scenarios by ID from database
2. Use ScenarioPlanner.compare_scenarios() to analyze them
3. Return comparison matrix with best/worst case analysis
**Prerequisites**:
- Scenario storage/retrieval database layer
- Scenario CRUD endpoints
- UI for scenario management
"""
# NOTE: HTTP 501 Not Implemented is the correct response for unimplemented optional features
# The ML logic exists in scenario_planner.py but requires a persistence layer
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Scenario comparison requires scenario persistence layer (future feature)"
)

View File

@@ -0,0 +1,346 @@
# ================================================================
# services/forecasting/app/api/validation.py
# ================================================================
"""
Validation API - Forecast validation endpoints
"""
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
from typing import Dict, Any, List, Optional
from uuid import UUID
from datetime import datetime, timedelta, timezone
import structlog
from pydantic import BaseModel, Field
from app.services.validation_service import ValidationService
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role
from shared.routing import RouteBuilder
from app.core.database import get_db
from sqlalchemy.ext.asyncio import AsyncSession
route_builder = RouteBuilder('forecasting')
router = APIRouter(tags=["validation"])
logger = structlog.get_logger()
# ================================================================
# Request/Response Schemas
# ================================================================
class ValidationRequest(BaseModel):
"""Request model for validation"""
start_date: datetime = Field(..., description="Start date for validation period")
end_date: datetime = Field(..., description="End date for validation period")
orchestration_run_id: Optional[UUID] = Field(None, description="Optional orchestration run ID")
triggered_by: str = Field(default="manual", description="Trigger source")
class ValidationResponse(BaseModel):
"""Response model for validation results"""
validation_run_id: str
status: str
forecasts_evaluated: int
forecasts_with_actuals: int
forecasts_without_actuals: int
metrics_created: int
overall_metrics: Optional[Dict[str, float]] = None
total_predicted_demand: Optional[float] = None
total_actual_demand: Optional[float] = None
duration_seconds: Optional[float] = None
message: Optional[str] = None
class ValidationRunResponse(BaseModel):
"""Response model for validation run details"""
id: str
tenant_id: str
orchestration_run_id: Optional[str]
validation_start_date: str
validation_end_date: str
started_at: str
completed_at: Optional[str]
duration_seconds: Optional[float]
status: str
total_forecasts_evaluated: int
forecasts_with_actuals: int
forecasts_without_actuals: int
overall_mae: Optional[float]
overall_mape: Optional[float]
overall_rmse: Optional[float]
overall_r2_score: Optional[float]
overall_accuracy_percentage: Optional[float]
total_predicted_demand: float
total_actual_demand: float
metrics_by_product: Optional[Dict[str, Any]]
metrics_by_location: Optional[Dict[str, Any]]
metrics_records_created: int
error_message: Optional[str]
triggered_by: str
execution_mode: str
class AccuracyTrendResponse(BaseModel):
"""Response model for accuracy trends"""
period_days: int
total_runs: int
average_mape: Optional[float]
average_accuracy: Optional[float]
trends: List[Dict[str, Any]]
# ================================================================
# Endpoints
# ================================================================
@router.post(
route_builder.build_base_route("validation/validate-date-range"),
response_model=ValidationResponse,
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def validate_date_range(
validation_request: ValidationRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Validate forecasts against actual sales for a date range
This endpoint:
- Fetches forecasts for the specified date range
- Retrieves corresponding actual sales data
- Calculates accuracy metrics (MAE, MAPE, RMSE, R², accuracy %)
- Stores performance metrics in the database
- Returns validation summary
"""
try:
logger.info(
"Starting date range validation",
tenant_id=tenant_id,
start_date=validation_request.start_date.isoformat(),
end_date=validation_request.end_date.isoformat(),
user_id=current_user.get("user_id")
)
validation_service = ValidationService(db)
result = await validation_service.validate_date_range(
tenant_id=tenant_id,
start_date=validation_request.start_date,
end_date=validation_request.end_date,
orchestration_run_id=validation_request.orchestration_run_id,
triggered_by=validation_request.triggered_by
)
logger.info(
"Date range validation completed",
tenant_id=tenant_id,
validation_run_id=result.get("validation_run_id"),
forecasts_evaluated=result.get("forecasts_evaluated")
)
return ValidationResponse(**result)
except Exception as e:
logger.error(
"Failed to validate date range",
tenant_id=tenant_id,
error=str(e),
error_type=type(e).__name__
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to validate forecasts: {str(e)}"
)
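# Example request (illustrative; the exact path prefix comes from
# RouteBuilder('forecasting').build_base_route and may differ):
#
#   POST /api/v1/forecasting/tenants/{tenant_id}/validation/validate-date-range
#   {
#       "start_date": "2026-01-01T00:00:00Z",
#       "end_date": "2026-01-07T00:00:00Z",
#       "triggered_by": "manual"
#   }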
@router.post(
route_builder.build_base_route("validation/validate-yesterday"),
response_model=ValidationResponse,
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def validate_yesterday(
tenant_id: UUID = Path(..., description="Tenant ID"),
orchestration_run_id: Optional[UUID] = Query(None, description="Optional orchestration run ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Validate yesterday's forecasts against actual sales
Convenience endpoint for validating the most recent day's forecasts.
This is typically called by the orchestrator as part of the daily workflow.
"""
try:
logger.info(
"Starting yesterday validation",
tenant_id=tenant_id,
user_id=current_user.get("user_id")
)
validation_service = ValidationService(db)
result = await validation_service.validate_yesterday(
tenant_id=tenant_id,
orchestration_run_id=orchestration_run_id,
triggered_by="manual"
)
logger.info(
"Yesterday validation completed",
tenant_id=tenant_id,
validation_run_id=result.get("validation_run_id"),
forecasts_evaluated=result.get("forecasts_evaluated")
)
return ValidationResponse(**result)
except Exception as e:
logger.error(
"Failed to validate yesterday",
tenant_id=tenant_id,
error=str(e),
error_type=type(e).__name__
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to validate yesterday's forecasts: {str(e)}"
)
@router.get(
route_builder.build_base_route("validation/runs/{validation_run_id}"),
response_model=ValidationRunResponse,
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_validation_run(
validation_run_id: UUID = Path(..., description="Validation run ID"),
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get details of a specific validation run
Returns complete information about a validation execution including:
- Summary statistics
- Overall accuracy metrics
- Breakdown by product and location
- Execution metadata
"""
try:
validation_service = ValidationService(db)
validation_run = await validation_service.get_validation_run(validation_run_id)
if not validation_run:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Validation run {validation_run_id} not found"
)
if validation_run.tenant_id != tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to this validation run"
)
return ValidationRunResponse(**validation_run.to_dict())
except HTTPException:
raise
except Exception as e:
logger.error(
"Failed to get validation run",
validation_run_id=validation_run_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get validation run: {str(e)}"
)
@router.get(
route_builder.build_base_route("validation/runs"),
response_model=List[ValidationRunResponse],
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_validation_runs(
tenant_id: UUID = Path(..., description="Tenant ID"),
limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
skip: int = Query(0, ge=0, description="Number of records to skip"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get validation runs for a tenant
Returns a list of validation executions with pagination support.
"""
try:
validation_service = ValidationService(db)
runs = await validation_service.get_validation_runs_by_tenant(
tenant_id=tenant_id,
limit=limit,
skip=skip
)
return [ValidationRunResponse(**run.to_dict()) for run in runs]
except Exception as e:
logger.error(
"Failed to get validation runs",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get validation runs: {str(e)}"
)
@router.get(
route_builder.build_base_route("validation/trends"),
response_model=AccuracyTrendResponse,
status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_accuracy_trends(
tenant_id: UUID = Path(..., description="Tenant ID"),
days: int = Query(30, ge=1, le=365, description="Number of days to analyze"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""
Get accuracy trends over time
Returns validation accuracy metrics over the specified time period.
Useful for monitoring model performance degradation and improvement.
"""
try:
validation_service = ValidationService(db)
trends = await validation_service.get_accuracy_trends(
tenant_id=tenant_id,
days=days
)
return AccuracyTrendResponse(**trends)
except Exception as e:
logger.error(
"Failed to get accuracy trends",
tenant_id=tenant_id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get accuracy trends: {str(e)}"
)

View File

@@ -0,0 +1,174 @@
# ================================================================
# services/forecasting/app/api/webhooks.py
# ================================================================
"""
Webhooks API - Receive events from other services
"""
from fastapi import APIRouter, HTTPException, status, Header
from typing import Dict, Any, Optional
from uuid import UUID
from datetime import date
import structlog
from pydantic import BaseModel, Field
from app.jobs.sales_data_listener import (
handle_sales_import_completion,
handle_pos_sync_completion
)
from shared.routing import RouteBuilder
route_builder = RouteBuilder('forecasting')
router = APIRouter(tags=["webhooks"])
logger = structlog.get_logger()
# ================================================================
# Request Schemas
# ================================================================
class SalesImportWebhook(BaseModel):
"""Webhook payload for sales data import completion"""
tenant_id: UUID = Field(..., description="Tenant ID")
import_job_id: str = Field(..., description="Import job ID")
start_date: date = Field(..., description="Start date of imported data")
end_date: date = Field(..., description="End date of imported data")
records_count: int = Field(..., ge=0, description="Number of records imported")
import_source: str = Field(default="import", description="Source of import")
class POSSyncWebhook(BaseModel):
"""Webhook payload for POS sync completion"""
tenant_id: UUID = Field(..., description="Tenant ID")
sync_log_id: str = Field(..., description="POS sync log ID")
sync_date: date = Field(..., description="Date of synced data")
records_synced: int = Field(..., ge=0, description="Number of records synced")
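# ----------------------------------------------------------------
# Hedged sketch of the signature check referenced (commented out) in the
# endpoints below. Assumptions: a shared secret in FORECASTING_WEBHOOK_SECRET
# and an HMAC-SHA256 over the raw request body; the real contract with the
# sales/POS services may differ, and callers would need the raw body
# (e.g. via `await request.body()`) rather than the parsed payload model.
# ----------------------------------------------------------------
def verify_webhook_signature(signature: Optional[str], raw_body: bytes) -> bool:
    """Return True if the provided signature matches the HMAC-SHA256 of the raw body."""
    import hashlib
    import hmac
    import os
    secret = os.getenv("FORECASTING_WEBHOOK_SECRET", "")
    if not secret or not signature:
        return False
    expected = hmac.new(secret.encode(), raw_body, hashlib.sha256).hexdigest()
    return hmac.compare_digest(expected, signature)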
# ================================================================
# Endpoints
# ================================================================
@router.post(
"/webhooks/sales-import-completed",
status_code=status.HTTP_202_ACCEPTED
)
async def sales_import_completed_webhook(
payload: SalesImportWebhook,
x_webhook_signature: Optional[str] = Header(None, description="Webhook signature for verification")
):
"""
Webhook endpoint for sales data import completion
Called by the sales service when a data import completes.
Triggers validation backfill for the imported date range.
Note: In production, this should verify the webhook signature
to ensure the request comes from a trusted source.
"""
try:
logger.info(
"Received sales import completion webhook",
tenant_id=payload.tenant_id,
import_job_id=payload.import_job_id,
date_range=f"{payload.start_date} to {payload.end_date}"
)
# In production, verify webhook signature here
# if not verify_webhook_signature(x_webhook_signature, payload):
# raise HTTPException(status_code=401, detail="Invalid webhook signature")
# Handle the import completion asynchronously
result = await handle_sales_import_completion(
tenant_id=payload.tenant_id,
import_job_id=payload.import_job_id,
start_date=payload.start_date,
end_date=payload.end_date,
records_count=payload.records_count,
import_source=payload.import_source
)
return {
"status": "accepted",
"message": "Sales import completion event received and processing",
"result": result
}
except Exception as e:
logger.error(
"Failed to process sales import webhook",
payload=payload.dict(),
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to process webhook: {str(e)}"
)
@router.post(
"/webhooks/pos-sync-completed",
status_code=status.HTTP_202_ACCEPTED
)
async def pos_sync_completed_webhook(
payload: POSSyncWebhook,
x_webhook_signature: Optional[str] = Header(None, description="Webhook signature for verification")
):
"""
Webhook endpoint for POS sync completion
Called by the POS service when data synchronization completes.
Triggers validation for the synced date.
"""
try:
logger.info(
"Received POS sync completion webhook",
tenant_id=payload.tenant_id,
sync_log_id=payload.sync_log_id,
sync_date=payload.sync_date.isoformat()
)
# In production, verify webhook signature here
# if not verify_webhook_signature(x_webhook_signature, payload):
# raise HTTPException(status_code=401, detail="Invalid webhook signature")
# Handle the sync completion
result = await handle_pos_sync_completion(
tenant_id=payload.tenant_id,
sync_log_id=payload.sync_log_id,
sync_date=payload.sync_date,
records_synced=payload.records_synced
)
return {
"status": "accepted",
"message": "POS sync completion event received and processing",
"result": result
}
except Exception as e:
logger.error(
"Failed to process POS sync webhook",
payload=payload.dict(),
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to process webhook: {str(e)}"
)
@router.get(
"/webhooks/health",
status_code=status.HTTP_200_OK
)
async def webhook_health_check():
"""Health check endpoint for webhook receiver"""
return {
"status": "healthy",
"service": "forecasting-webhooks",
"endpoints": [
"/webhooks/sales-import-completed",
"/webhooks/pos-sync-completed"
]
}
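# Example payload (illustrative values) for POST /webhooks/sales-import-completed:
#
#   {
#       "tenant_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
#       "import_job_id": "import-2026-01-20-001",
#       "start_date": "2026-01-13",
#       "end_date": "2026-01-19",
#       "records_count": 5230,
#       "import_source": "csv"
#   }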

View File

@@ -0,0 +1,253 @@
"""
AI Insights Service HTTP Client
Posts insights from forecasting service to AI Insights Service
"""
import httpx
from typing import Dict, List, Any, Optional
from uuid import UUID
import structlog
from datetime import datetime
logger = structlog.get_logger()
class AIInsightsClient:
"""
HTTP client for AI Insights Service.
Allows forecasting service to post detected patterns and insights.
"""
def __init__(self, base_url: str, timeout: int = 30):
"""
Initialize AI Insights client.
Args:
base_url: Base URL of AI Insights Service (e.g., http://ai-insights-service:8000)
timeout: Request timeout in seconds
"""
self.base_url = base_url.rstrip('/')
self.timeout = timeout
self.client = httpx.AsyncClient(timeout=self.timeout)
async def close(self):
"""Close the HTTP client."""
await self.client.aclose()
async def create_insight(
self,
tenant_id: UUID,
insight_data: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
Create a new insight in AI Insights Service.
Args:
tenant_id: Tenant UUID
insight_data: Insight data dictionary
Returns:
Created insight dict or None if failed
"""
url = f"{self.base_url}/api/v1/ai-insights/tenants/{tenant_id}/insights"
try:
# Ensure tenant_id is in the data
insight_data['tenant_id'] = str(tenant_id)
response = await self.client.post(url, json=insight_data)
if response.status_code == 201:
logger.info(
"Insight created successfully",
tenant_id=str(tenant_id),
insight_title=insight_data.get('title')
)
return response.json()
else:
logger.error(
"Failed to create insight",
status_code=response.status_code,
response=response.text,
insight_title=insight_data.get('title')
)
return None
except Exception as e:
logger.error(
"Error creating insight",
error=str(e),
tenant_id=str(tenant_id)
)
return None
async def create_insights_bulk(
self,
tenant_id: UUID,
insights: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Create multiple insights in bulk.
Args:
tenant_id: Tenant UUID
insights: List of insight data dictionaries
Returns:
Dictionary with success/failure counts
"""
results = {
'total': len(insights),
'successful': 0,
'failed': 0,
'created_insights': []
}
for insight_data in insights:
result = await self.create_insight(tenant_id, insight_data)
if result:
results['successful'] += 1
results['created_insights'].append(result)
else:
results['failed'] += 1
logger.info(
"Bulk insight creation complete",
total=results['total'],
successful=results['successful'],
failed=results['failed']
)
return results
async def get_insights(
self,
tenant_id: UUID,
filters: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
"""
Get insights for a tenant.
Args:
tenant_id: Tenant UUID
filters: Optional filters (category, priority, etc.)
Returns:
Paginated insights response or None if failed
"""
url = f"{self.base_url}/api/v1/ai-insights/tenants/{tenant_id}/insights"
try:
response = await self.client.get(url, params=filters or {})
if response.status_code == 200:
return response.json()
else:
logger.error(
"Failed to get insights",
status_code=response.status_code
)
return None
except Exception as e:
logger.error("Error getting insights", error=str(e))
return None
async def get_orchestration_ready_insights(
self,
tenant_id: UUID,
target_date: datetime,
min_confidence: int = 70
) -> Optional[Dict[str, List[Dict[str, Any]]]]:
"""
Get insights ready for orchestration workflow.
Args:
tenant_id: Tenant UUID
target_date: Target date for orchestration
min_confidence: Minimum confidence threshold
Returns:
Categorized insights or None if failed
"""
url = f"{self.base_url}/api/v1/ai-insights/tenants/{tenant_id}/insights/orchestration-ready"
params = {
'target_date': target_date.isoformat(),
'min_confidence': min_confidence
}
try:
response = await self.client.get(url, params=params)
if response.status_code == 200:
return response.json()
else:
logger.error(
"Failed to get orchestration insights",
status_code=response.status_code
)
return None
except Exception as e:
logger.error("Error getting orchestration insights", error=str(e))
return None
async def record_feedback(
self,
tenant_id: UUID,
insight_id: UUID,
feedback_data: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
Record feedback for an applied insight.
Args:
tenant_id: Tenant UUID
insight_id: Insight UUID
feedback_data: Feedback data
Returns:
Feedback response or None if failed
"""
url = f"{self.base_url}/api/v1/ai-insights/tenants/{tenant_id}/insights/{insight_id}/feedback"
try:
feedback_data['insight_id'] = str(insight_id)
response = await self.client.post(url, json=feedback_data)
if response.status_code in [200, 201]:
logger.info(
"Feedback recorded",
insight_id=str(insight_id),
success=feedback_data.get('success')
)
return response.json()
else:
logger.error(
"Failed to record feedback",
status_code=response.status_code
)
return None
except Exception as e:
logger.error("Error recording feedback", error=str(e))
return None
async def health_check(self) -> bool:
"""
Check if AI Insights Service is healthy.
Returns:
True if healthy, False otherwise
"""
url = f"{self.base_url}/health"
try:
response = await self.client.get(url)
return response.status_code == 200
except Exception as e:
logger.error("AI Insights Service health check failed", error=str(e))
return False
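# ----------------------------------------------------------------
# Usage sketch (illustrative; the insight fields shown are assumptions
# about the AI Insights Service payload, not its documented schema):
#
#   client = AIInsightsClient("http://ai-insights-service:8000")
#   try:
#       await client.create_insight(
#           tenant_id=tenant_uuid,
#           insight_data={
#               "title": "Weekend demand consistently above forecast",
#               "category": "forecast_accuracy",
#               "confidence": 82,
#           },
#       )
#   finally:
#       await client.close()
# ----------------------------------------------------------------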

View File

@@ -0,0 +1,187 @@
"""
Forecast event consumer for the forecasting service
Handles events that should trigger cache invalidation for aggregated forecasts
"""
import logging
from typing import Dict, Any, Optional
import json
import redis.asyncio as redis
logger = logging.getLogger(__name__)
class ForecastEventConsumer:
"""
Consumer for forecast events that may trigger cache invalidation
"""
def __init__(self, redis_client: redis.Redis):
self.redis_client = redis_client
async def handle_forecast_updated(self, event_data: Dict[str, Any]):
"""
Handle forecast updated event
Invalidate parent tenant's aggregated forecast cache if this tenant is a child
"""
try:
logger.info(f"Handling forecast updated event: {event_data}")
tenant_id = event_data.get('tenant_id')
forecast_date = event_data.get('forecast_date')
product_id = event_data.get('product_id')
updated_at = event_data.get('updated_at', None)
if not tenant_id:
logger.error("Missing tenant_id in forecast event")
return
# Check if this tenant is a child tenant (has parent)
# In a real implementation, this would call the tenant service to check hierarchy
parent_tenant_id = await self._get_parent_tenant_id(tenant_id)
if parent_tenant_id:
# Invalidate parent's aggregated forecast cache
await self._invalidate_parent_aggregated_cache(
parent_tenant_id=parent_tenant_id,
child_tenant_id=tenant_id,
forecast_date=forecast_date,
product_id=product_id
)
logger.info(f"Forecast updated event processed for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error handling forecast updated event: {e}", exc_info=True)
raise
async def handle_forecast_created(self, event_data: Dict[str, Any]):
"""
Handle forecast created event
Similar to update, may affect parent tenant's aggregated forecasts
"""
await self.handle_forecast_updated(event_data)
async def handle_forecast_deleted(self, event_data: Dict[str, Any]):
"""
Handle forecast deleted event
Similar to update, may affect parent tenant's aggregated forecasts
"""
try:
logger.info(f"Handling forecast deleted event: {event_data}")
tenant_id = event_data.get('tenant_id')
forecast_date = event_data.get('forecast_date')
product_id = event_data.get('product_id')
if not tenant_id:
logger.error("Missing tenant_id in forecast delete event")
return
# Check if this tenant is a child tenant
parent_tenant_id = await self._get_parent_tenant_id(tenant_id)
if parent_tenant_id:
# Invalidate parent's aggregated forecast cache
await self._invalidate_parent_aggregated_cache(
parent_tenant_id=parent_tenant_id,
child_tenant_id=tenant_id,
forecast_date=forecast_date,
product_id=product_id
)
logger.info(f"Forecast deleted event processed for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error handling forecast deleted event: {e}", exc_info=True)
raise
async def _get_parent_tenant_id(self, tenant_id: str) -> Optional[str]:
"""
Get parent tenant ID for a child tenant using the tenant service
"""
try:
from shared.clients.tenant_client import TenantServiceClient
from shared.config.base import get_settings
# Create tenant client
config = get_settings()
tenant_client = TenantServiceClient(config)
# Get parent tenant information
parent_tenant = await tenant_client.get_parent_tenant(tenant_id)
if parent_tenant:
parent_tenant_id = parent_tenant.get('id')
logger.info(f"Found parent tenant {parent_tenant_id} for child tenant {tenant_id}")
return parent_tenant_id
else:
logger.debug(f"No parent tenant found for tenant {tenant_id} (tenant may be standalone or parent)")
return None
except Exception as e:
logger.error(f"Error getting parent tenant ID for {tenant_id}: {e}")
return None
async def _invalidate_parent_aggregated_cache(
self,
parent_tenant_id: str,
child_tenant_id: str,
forecast_date: Optional[str] = None,
product_id: Optional[str] = None
):
"""
Invalidate parent tenant's aggregated forecast cache
"""
try:
# Pattern to match all aggregated forecast cache keys for this parent
# Format: agg_forecast:{parent_tenant_id}:{start_date}:{end_date}:{product_id}
pattern = f"agg_forecast:{parent_tenant_id}:*:*:*"
# Find all matching keys and delete them
keys_to_delete = []
async for key in self.redis_client.scan_iter(match=pattern):
if isinstance(key, bytes):
key = key.decode('utf-8')
keys_to_delete.append(key)
if keys_to_delete:
await self.redis_client.delete(*keys_to_delete)
logger.info(f"Invalidated {len(keys_to_delete)} aggregated forecast cache entries for parent tenant {parent_tenant_id}")
else:
logger.info(f"No aggregated forecast cache entries found to invalidate for parent tenant {parent_tenant_id}")
except Exception as e:
logger.error(f"Error invalidating parent aggregated cache: {e}", exc_info=True)
raise
async def handle_tenant_hierarchy_changed(self, event_data: Dict[str, Any]):
"""
Handle tenant hierarchy change event
This could be when a tenant becomes a child of another, or when the hierarchy changes
"""
try:
logger.info(f"Handling tenant hierarchy change event: {event_data}")
tenant_id = event_data.get('tenant_id')
parent_tenant_id = event_data.get('parent_tenant_id')
action = event_data.get('action') # 'added', 'removed', 'changed'
# Invalidate any cached aggregated forecasts that might be affected
if parent_tenant_id:
# If this child tenant changed, invalidate parent's cache
await self._invalidate_parent_aggregated_cache(
parent_tenant_id=parent_tenant_id,
child_tenant_id=tenant_id
)
# If this was a former parent tenant that's no longer a parent,
# its aggregated cache might need to be invalidated differently
if action == 'removed' and event_data.get('was_parent'):
# Invalidate its own aggregated cache since it's no longer a parent
# This would be handled by tenant service events
pass
except Exception as e:
logger.error(f"Error handling tenant hierarchy change event: {e}", exc_info=True)
raise
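# ----------------------------------------------------------------
# Wiring sketch (illustrative): how a message-bus subscriber might route
# incoming events to this consumer. The event-type strings are assumptions;
# the real routing keys are defined by the messaging layer.
# ----------------------------------------------------------------
#   consumer = ForecastEventConsumer(redis_client)
#   handlers = {
#       "forecast.created": consumer.handle_forecast_created,
#       "forecast.updated": consumer.handle_forecast_updated,
#       "forecast.deleted": consumer.handle_forecast_deleted,
#       "tenant.hierarchy_changed": consumer.handle_tenant_hierarchy_changed,
#   }
#
#   async def dispatch(event_type: str, event_data: dict):
#       handler = handlers.get(event_type)
#       if handler:
#           await handler(event_data)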

View File

@@ -0,0 +1,81 @@
# ================================================================
# FORECASTING SERVICE CONFIGURATION
# services/forecasting/app/core/config.py
# ================================================================
"""
Forecasting service configuration
Demand prediction and forecasting
"""
from shared.config.base import BaseServiceSettings
import os
class ForecastingSettings(BaseServiceSettings):
"""Forecasting service specific settings"""
# Service Identity
APP_NAME: str = "Forecasting Service"
SERVICE_NAME: str = "forecasting-service"
DESCRIPTION: str = "Demand prediction and forecasting service"
# Database configuration (secure approach - build from components)
@property
def DATABASE_URL(self) -> str:
"""Build database URL from secure components"""
# Try complete URL first (for backward compatibility)
complete_url = os.getenv("FORECASTING_DATABASE_URL")
if complete_url:
return complete_url
# Build from components (secure approach)
user = os.getenv("FORECASTING_DB_USER", "forecasting_user")
password = os.getenv("FORECASTING_DB_PASSWORD", "forecasting_pass123")
host = os.getenv("FORECASTING_DB_HOST", "localhost")
port = os.getenv("FORECASTING_DB_PORT", "5432")
name = os.getenv("FORECASTING_DB_NAME", "forecasting_db")
return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"
# Redis Database (dedicated for prediction cache)
REDIS_DB: int = 2
# Prediction Configuration
MAX_FORECAST_DAYS: int = int(os.getenv("MAX_FORECAST_DAYS", "30"))
MIN_HISTORICAL_DAYS: int = int(os.getenv("MIN_HISTORICAL_DAYS", "60"))
PREDICTION_CONFIDENCE_THRESHOLD: float = float(os.getenv("PREDICTION_CONFIDENCE_THRESHOLD", "0.8"))
# Caching Configuration
PREDICTION_CACHE_TTL_HOURS: int = int(os.getenv("PREDICTION_CACHE_TTL_HOURS", "6"))
FORECAST_BATCH_SIZE: int = int(os.getenv("FORECAST_BATCH_SIZE", "100"))
# MinIO Configuration
MINIO_ENDPOINT: str = os.getenv("MINIO_ENDPOINT", "minio.bakery-ia.svc.cluster.local:9000")
MINIO_ACCESS_KEY: str = os.getenv("FORECASTING_MINIO_ACCESS_KEY", "forecasting-service")
MINIO_SECRET_KEY: str = os.getenv("FORECASTING_MINIO_SECRET_KEY", "forecasting-secret-key")
MINIO_USE_SSL: bool = os.getenv("MINIO_USE_SSL", "true").lower() == "true"
MINIO_MODEL_BUCKET: str = os.getenv("MINIO_MODEL_BUCKET", "training-models")
MINIO_CONSOLE_PORT: str = os.getenv("MINIO_CONSOLE_PORT", "9001")
MINIO_API_PORT: str = os.getenv("MINIO_API_PORT", "9000")
MINIO_REGION: str = os.getenv("MINIO_REGION", "us-east-1")
MINIO_MODEL_LIFECYCLE_DAYS: int = int(os.getenv("MINIO_MODEL_LIFECYCLE_DAYS", "90"))
MINIO_CACHE_TTL_SECONDS: int = int(os.getenv("MINIO_CACHE_TTL_SECONDS", "3600"))
# Real-time Forecasting
REALTIME_FORECASTING_ENABLED: bool = os.getenv("REALTIME_FORECASTING_ENABLED", "true").lower() == "true"
FORECAST_UPDATE_INTERVAL_HOURS: int = int(os.getenv("FORECAST_UPDATE_INTERVAL_HOURS", "6"))
# Business Rules for Spanish Bakeries
BUSINESS_HOUR_START: int = 7 # 7 AM
BUSINESS_HOUR_END: int = 20 # 8 PM
WEEKEND_ADJUSTMENT_FACTOR: float = float(os.getenv("WEEKEND_ADJUSTMENT_FACTOR", "0.8"))
HOLIDAY_ADJUSTMENT_FACTOR: float = float(os.getenv("HOLIDAY_ADJUSTMENT_FACTOR", "0.5"))
# Weather Impact Modeling
WEATHER_IMPACT_ENABLED: bool = os.getenv("WEATHER_IMPACT_ENABLED", "true").lower() == "true"
TEMPERATURE_THRESHOLD_COLD: float = float(os.getenv("TEMPERATURE_THRESHOLD_COLD", "10.0"))
TEMPERATURE_THRESHOLD_HOT: float = float(os.getenv("TEMPERATURE_THRESHOLD_HOT", "30.0"))
RAIN_IMPACT_FACTOR: float = float(os.getenv("RAIN_IMPACT_FACTOR", "0.7"))
settings = ForecastingSettings()
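# Example (illustrative): with only FORECASTING_DB_HOST=postgres set and the
# defaults above, settings.DATABASE_URL resolves to
#   postgresql+asyncpg://forecasting_user:<password>@postgres:5432/forecasting_db
# If FORECASTING_DATABASE_URL is set, it is returned unchanged.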

View File

@@ -0,0 +1,121 @@
# ================================================================
# services/forecasting/app/core/database.py
# ================================================================
"""
Database configuration for forecasting service
"""
import structlog
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy import text
from typing import AsyncGenerator
from app.core.config import settings
from shared.database.base import Base, DatabaseManager
logger = structlog.get_logger()
# Create async engine
async_engine = create_async_engine(
settings.DATABASE_URL,
echo=settings.DEBUG,
pool_size=10,
max_overflow=20,
pool_pre_ping=True,
pool_recycle=3600
)
# Create async session factory
AsyncSessionLocal = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""Get database session"""
async with AsyncSessionLocal() as session:
try:
yield session
except Exception as e:
await session.rollback()
logger.error("Database session error", error=str(e))
raise
finally:
await session.close()
async def init_database():
"""Initialize database tables"""
try:
async with async_engine.begin() as conn:
# Import all models to ensure they are registered
from app.models.forecast import ForecastBatch, Forecast
from app.models.prediction import PredictionBatch, Prediction
# Create all tables
await conn.run_sync(Base.metadata.create_all)
logger.info("Forecasting database initialized successfully")
except Exception as e:
logger.error("Failed to initialize forecasting database", error=str(e))
raise
async def get_db_health() -> bool:
"""Check database health"""
try:
async with async_engine.begin() as conn:
await conn.execute(text("SELECT 1"))
return True
except Exception as e:
logger.error("Database health check failed", error=str(e))
return False
async def get_connection_pool_stats() -> dict:
"""
Get current connection pool statistics for monitoring.
Returns:
Dictionary with pool statistics including usage and capacity
"""
try:
pool = async_engine.pool
# Get pool stats
stats = {
"pool_size": pool.size(),
"checked_in_connections": pool.checkedin(),
"checked_out_connections": pool.checkedout(),
"overflow_connections": pool.overflow(),
"total_connections": pool.size() + pool.overflow(),
"max_capacity": 10 + 20, # pool_size + max_overflow
"usage_percentage": round(((pool.size() + pool.overflow()) / 30) * 100, 2)
}
# Add health status
if stats["usage_percentage"] > 90:
stats["status"] = "critical"
stats["message"] = "Connection pool near capacity"
elif stats["usage_percentage"] > 80:
stats["status"] = "warning"
stats["message"] = "Connection pool usage high"
else:
stats["status"] = "healthy"
stats["message"] = "Connection pool healthy"
return stats
except Exception as e:
logger.error("Failed to get connection pool stats", error=str(e))
return {
"status": "error",
"message": f"Failed to get pool stats: {str(e)}"
}
# Database manager instance for service_base compatibility
database_manager = DatabaseManager(
database_url=settings.DATABASE_URL,
service_name="forecasting-service",
pool_size=10,
max_overflow=20,
echo=settings.DEBUG
)
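# Usage sketch (illustrative): get_connection_pool_stats() can back a small
# monitoring endpoint, e.g. registered in main.py's setup_custom_endpoints():
#
#   @self.app.get("/pool-stats")
#   async def pool_stats():
#       return await get_connection_pool_stats()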

View File

@@ -0,0 +1,29 @@
"""
Forecasting Service Jobs Package
Scheduled and background jobs for the forecasting service
"""
from .daily_validation import daily_validation_job, validate_date_range_job
from .sales_data_listener import (
handle_sales_import_completion,
handle_pos_sync_completion,
process_pending_validations
)
from .auto_backfill_job import (
auto_backfill_all_tenants,
process_all_pending_validations,
daily_validation_maintenance_job,
run_validation_maintenance_for_tenant
)
__all__ = [
"daily_validation_job",
"validate_date_range_job",
"handle_sales_import_completion",
"handle_pos_sync_completion",
"process_pending_validations",
"auto_backfill_all_tenants",
"process_all_pending_validations",
"daily_validation_maintenance_job",
"run_validation_maintenance_for_tenant",
]

View File

@@ -0,0 +1,275 @@
# ================================================================
# services/forecasting/app/jobs/auto_backfill_job.py
# ================================================================
"""
Automated Backfill Job
Scheduled job to automatically detect and backfill validation gaps.
Can be run daily or weekly to ensure all historical forecasts are validated.
"""
from typing import Dict, Any, List
from datetime import datetime, timezone
import structlog
import uuid
from app.services.historical_validation_service import HistoricalValidationService
from app.core.database import database_manager
from app.jobs.sales_data_listener import process_pending_validations
logger = structlog.get_logger()
async def auto_backfill_all_tenants(
tenant_ids: List[uuid.UUID],
lookback_days: int = 90,
max_gaps_per_tenant: int = 5
) -> Dict[str, Any]:
"""
Run auto backfill for multiple tenants
Args:
tenant_ids: List of tenant IDs to process
lookback_days: How far back to check for gaps
max_gaps_per_tenant: Maximum number of gaps to process per tenant
Returns:
Summary of backfill operations across all tenants
"""
try:
logger.info(
"Starting auto backfill for all tenants",
tenant_count=len(tenant_ids),
lookback_days=lookback_days
)
results = []
total_gaps_found = 0
total_gaps_processed = 0
total_successful = 0
for tenant_id in tenant_ids:
try:
async with database_manager.get_session() as db:
service = HistoricalValidationService(db)
result = await service.auto_backfill_gaps(
tenant_id=tenant_id,
lookback_days=lookback_days,
max_gaps_to_process=max_gaps_per_tenant
)
results.append({
"tenant_id": str(tenant_id),
"status": "success",
**result
})
total_gaps_found += result.get("gaps_found", 0)
total_gaps_processed += result.get("gaps_processed", 0)
total_successful += result.get("validations_completed", 0)
except Exception as e:
logger.error(
"Failed to auto backfill for tenant",
tenant_id=tenant_id,
error=str(e)
)
results.append({
"tenant_id": str(tenant_id),
"status": "failed",
"error": str(e)
})
logger.info(
"Auto backfill completed for all tenants",
tenant_count=len(tenant_ids),
total_gaps_found=total_gaps_found,
total_gaps_processed=total_gaps_processed,
total_successful=total_successful
)
return {
"status": "completed",
"tenants_processed": len(tenant_ids),
"total_gaps_found": total_gaps_found,
"total_gaps_processed": total_gaps_processed,
"total_validations_completed": total_successful,
"results": results
}
except Exception as e:
logger.error(
"Auto backfill job failed",
error=str(e)
)
return {
"status": "failed",
"error": str(e)
}
async def process_all_pending_validations(
tenant_ids: List[uuid.UUID],
max_per_tenant: int = 10
) -> Dict[str, Any]:
"""
Process all pending validations for multiple tenants
Args:
tenant_ids: List of tenant IDs to process
max_per_tenant: Maximum pending validations to process per tenant
Returns:
Summary of processing results
"""
try:
logger.info(
"Processing pending validations for all tenants",
tenant_count=len(tenant_ids)
)
results = []
total_pending = 0
total_processed = 0
total_successful = 0
for tenant_id in tenant_ids:
try:
result = await process_pending_validations(
tenant_id=tenant_id,
max_to_process=max_per_tenant
)
results.append({
"tenant_id": str(tenant_id),
**result
})
total_pending += result.get("pending_count", 0)
total_processed += result.get("processed", 0)
total_successful += result.get("successful", 0)
except Exception as e:
logger.error(
"Failed to process pending validations for tenant",
tenant_id=tenant_id,
error=str(e)
)
results.append({
"tenant_id": str(tenant_id),
"status": "failed",
"error": str(e)
})
logger.info(
"Pending validations processed for all tenants",
tenant_count=len(tenant_ids),
total_pending=total_pending,
total_processed=total_processed,
total_successful=total_successful
)
return {
"status": "completed",
"tenants_processed": len(tenant_ids),
"total_pending": total_pending,
"total_processed": total_processed,
"total_successful": total_successful,
"results": results
}
except Exception as e:
logger.error(
"Failed to process all pending validations",
error=str(e)
)
return {
"status": "failed",
"error": str(e)
}
async def daily_validation_maintenance_job(
tenant_ids: List[uuid.UUID]
) -> Dict[str, Any]:
"""
Daily validation maintenance job
Combines gap detection/backfill and pending validation processing.
    Recommended to run once daily (e.g., at 6:00 AM, after the orchestrator completes).
Args:
tenant_ids: List of tenant IDs to process
Returns:
Summary of all maintenance operations
"""
try:
logger.info(
"Starting daily validation maintenance",
tenant_count=len(tenant_ids),
timestamp=datetime.now(timezone.utc).isoformat()
)
# Step 1: Process pending validations (retry failures)
pending_result = await process_all_pending_validations(
tenant_ids=tenant_ids,
max_per_tenant=10
)
# Step 2: Auto backfill detected gaps
backfill_result = await auto_backfill_all_tenants(
tenant_ids=tenant_ids,
lookback_days=90,
max_gaps_per_tenant=5
)
logger.info(
"Daily validation maintenance completed",
pending_validations_processed=pending_result.get("total_processed", 0),
gaps_backfilled=backfill_result.get("total_validations_completed", 0)
)
return {
"status": "completed",
"timestamp": datetime.now(timezone.utc).isoformat(),
"tenants_processed": len(tenant_ids),
"pending_validations": pending_result,
"gap_backfill": backfill_result,
"summary": {
"total_pending_processed": pending_result.get("total_processed", 0),
"total_gaps_backfilled": backfill_result.get("total_validations_completed", 0),
"total_validations": (
pending_result.get("total_processed", 0) +
backfill_result.get("total_validations_completed", 0)
)
}
}
except Exception as e:
logger.error(
"Daily validation maintenance failed",
error=str(e)
)
return {
"status": "failed",
"timestamp": datetime.now(timezone.utc).isoformat(),
"error": str(e)
}
# Convenience function for single tenant
async def run_validation_maintenance_for_tenant(
tenant_id: uuid.UUID
) -> Dict[str, Any]:
"""
Run validation maintenance for a single tenant
Args:
tenant_id: Tenant identifier
Returns:
Maintenance results
"""
return await daily_validation_maintenance_job([tenant_id])
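# Scheduling sketch (illustrative): if APScheduler (or a similar scheduler) is
# used, the daily maintenance job could be registered roughly like this. The
# tenant ID list is an assumption; in practice it would come from the tenant
# service.
#
#   scheduler.add_job(
#       daily_validation_maintenance_job,
#       trigger="cron",
#       hour=6,
#       minute=0,
#       args=[active_tenant_ids],  # fetched from the tenant service beforehand
#   )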

View File

@@ -0,0 +1,147 @@
# ================================================================
# services/forecasting/app/jobs/daily_validation.py
# ================================================================
"""
Daily Validation Job
Scheduled job to validate previous day's forecasts against actual sales.
This job is called by the orchestrator as part of the daily workflow.
"""
from typing import Dict, Any, Optional
from datetime import datetime, timedelta, timezone
import structlog
import uuid
from app.services.validation_service import ValidationService
from app.core.database import database_manager
logger = structlog.get_logger()
async def daily_validation_job(
tenant_id: uuid.UUID,
orchestration_run_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""
Validate yesterday's forecasts against actual sales
This function is designed to be called by the orchestrator as part of
the daily workflow (Step 5: validate_previous_forecasts).
Args:
tenant_id: Tenant identifier
orchestration_run_id: Optional orchestration run ID for tracking
Returns:
Dictionary with validation results
"""
async with database_manager.get_session() as db:
try:
logger.info(
"Starting daily validation job",
tenant_id=tenant_id,
orchestration_run_id=orchestration_run_id
)
validation_service = ValidationService(db)
# Validate yesterday's forecasts
result = await validation_service.validate_yesterday(
tenant_id=tenant_id,
orchestration_run_id=orchestration_run_id,
triggered_by="orchestrator"
)
logger.info(
"Daily validation job completed",
tenant_id=tenant_id,
validation_run_id=result.get("validation_run_id"),
forecasts_evaluated=result.get("forecasts_evaluated"),
forecasts_with_actuals=result.get("forecasts_with_actuals"),
overall_mape=result.get("overall_metrics", {}).get("mape")
)
return result
except Exception as e:
logger.error(
"Daily validation job failed",
tenant_id=tenant_id,
orchestration_run_id=orchestration_run_id,
error=str(e),
error_type=type(e).__name__
)
return {
"status": "failed",
"error": str(e),
"tenant_id": str(tenant_id),
"orchestration_run_id": str(orchestration_run_id) if orchestration_run_id else None
}
async def validate_date_range_job(
tenant_id: uuid.UUID,
start_date: datetime,
end_date: datetime,
orchestration_run_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""
Validate forecasts for a specific date range
Useful for backfilling validation metrics when historical data is uploaded.
Args:
tenant_id: Tenant identifier
start_date: Start of validation period
end_date: End of validation period
orchestration_run_id: Optional orchestration run ID for tracking
Returns:
Dictionary with validation results
"""
async with database_manager.get_session() as db:
try:
logger.info(
"Starting date range validation job",
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat(),
orchestration_run_id=orchestration_run_id
)
validation_service = ValidationService(db)
result = await validation_service.validate_date_range(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
orchestration_run_id=orchestration_run_id,
triggered_by="scheduled"
)
logger.info(
"Date range validation job completed",
tenant_id=tenant_id,
validation_run_id=result.get("validation_run_id"),
forecasts_evaluated=result.get("forecasts_evaluated"),
forecasts_with_actuals=result.get("forecasts_with_actuals")
)
return result
except Exception as e:
logger.error(
"Date range validation job failed",
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat(),
error=str(e),
error_type=type(e).__name__
)
return {
"status": "failed",
"error": str(e),
"tenant_id": str(tenant_id),
"orchestration_run_id": str(orchestration_run_id) if orchestration_run_id else None
}
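# Usage sketch (illustrative): the orchestrator's "validate_previous_forecasts"
# step would await the job directly and inspect the returned summary:
#
#   result = await daily_validation_job(tenant_id=tenant_uuid, orchestration_run_id=run_id)
#   if result.get("status") == "failed":
#       # mark the orchestration step as failed; the job already logged details
#       ...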

View File

@@ -0,0 +1,276 @@
# ================================================================
# services/forecasting/app/jobs/sales_data_listener.py
# ================================================================
"""
Sales Data Listener
Listens for sales data import completions and triggers validation backfill.
Can be called via webhook, message queue, or direct API call from sales service.
"""
from typing import Dict, Any, Optional
from datetime import datetime, date
import structlog
import uuid
from app.services.historical_validation_service import HistoricalValidationService
from app.core.database import database_manager
logger = structlog.get_logger()
async def handle_sales_import_completion(
tenant_id: uuid.UUID,
import_job_id: str,
start_date: date,
end_date: date,
records_count: int,
import_source: str = "import"
) -> Dict[str, Any]:
"""
Handle sales data import completion event
This function is called when the sales service completes a data import.
It registers the update and triggers validation for the imported date range.
Args:
tenant_id: Tenant identifier
import_job_id: Sales import job ID
start_date: Start date of imported data
end_date: End date of imported data
records_count: Number of records imported
import_source: Source of import (csv, xlsx, api, pos_sync)
Returns:
Dictionary with registration and validation results
"""
async with database_manager.get_session() as db:
try:
logger.info(
"Handling sales import completion",
tenant_id=tenant_id,
import_job_id=import_job_id,
date_range=f"{start_date} to {end_date}",
records_count=records_count
)
service = HistoricalValidationService(db)
# Register the sales data update and trigger validation
result = await service.register_sales_data_update(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
records_affected=records_count,
update_source=import_source,
import_job_id=import_job_id,
auto_trigger_validation=True
)
logger.info(
"Sales import completion handled",
tenant_id=tenant_id,
import_job_id=import_job_id,
update_id=result.get("update_id"),
validation_triggered=result.get("validation_triggered")
)
return {
"status": "success",
"tenant_id": str(tenant_id),
"import_job_id": import_job_id,
**result
}
except Exception as e:
logger.error(
"Failed to handle sales import completion",
tenant_id=tenant_id,
import_job_id=import_job_id,
error=str(e),
error_type=type(e).__name__
)
return {
"status": "failed",
"error": str(e),
"tenant_id": str(tenant_id),
"import_job_id": import_job_id
}
async def handle_pos_sync_completion(
tenant_id: uuid.UUID,
sync_log_id: str,
sync_date: date,
records_synced: int
) -> Dict[str, Any]:
"""
Handle POS sync completion event
Called when POS data is synchronized to the sales service.
Args:
tenant_id: Tenant identifier
sync_log_id: POS sync log ID
sync_date: Date of synced data
records_synced: Number of records synced
Returns:
Dictionary with registration and validation results
"""
async with database_manager.get_session() as db:
try:
logger.info(
"Handling POS sync completion",
tenant_id=tenant_id,
sync_log_id=sync_log_id,
sync_date=sync_date.isoformat(),
records_synced=records_synced
)
service = HistoricalValidationService(db)
# For POS syncs, we typically validate just the sync date
result = await service.register_sales_data_update(
tenant_id=tenant_id,
start_date=sync_date,
end_date=sync_date,
records_affected=records_synced,
update_source="pos_sync",
import_job_id=sync_log_id,
auto_trigger_validation=True
)
logger.info(
"POS sync completion handled",
tenant_id=tenant_id,
sync_log_id=sync_log_id,
update_id=result.get("update_id")
)
return {
"status": "success",
"tenant_id": str(tenant_id),
"sync_log_id": sync_log_id,
**result
}
except Exception as e:
logger.error(
"Failed to handle POS sync completion",
tenant_id=tenant_id,
sync_log_id=sync_log_id,
error=str(e)
)
return {
"status": "failed",
"error": str(e),
"tenant_id": str(tenant_id),
"sync_log_id": sync_log_id
}
async def process_pending_validations(
tenant_id: Optional[uuid.UUID] = None,
max_to_process: int = 10
) -> Dict[str, Any]:
"""
Process pending validation requests
Can be run as a scheduled job to handle any pending validations
that failed to trigger automatically.
Args:
tenant_id: Optional tenant ID to filter (process all tenants if None)
max_to_process: Maximum number of pending validations to process
Returns:
Summary of processing results
"""
async with database_manager.get_session() as db:
try:
logger.info(
"Processing pending validations",
tenant_id=tenant_id,
max_to_process=max_to_process
)
service = HistoricalValidationService(db)
if tenant_id:
# Process specific tenant
pending = await service.get_pending_validations(
tenant_id=tenant_id,
limit=max_to_process
)
else:
# Would need to implement get_all_pending_validations for all tenants
# For now, require tenant_id
logger.warning("Processing all tenants not implemented, tenant_id required")
return {
"status": "skipped",
"message": "tenant_id required"
}
if not pending:
logger.info("No pending validations found")
return {
"status": "success",
"pending_count": 0,
"processed": 0
}
results = []
for update_record in pending:
try:
result = await service.backfill_validation(
tenant_id=update_record.tenant_id,
start_date=update_record.update_date_start,
end_date=update_record.update_date_end,
triggered_by="pending_processor",
sales_data_update_id=update_record.id
)
results.append({
"update_id": str(update_record.id),
"status": "success",
"validation_run_id": result.get("validation_run_id")
})
except Exception as e:
logger.error(
"Failed to process pending validation",
update_id=update_record.id,
error=str(e)
)
results.append({
"update_id": str(update_record.id),
"status": "failed",
"error": str(e)
})
successful = sum(1 for r in results if r["status"] == "success")
logger.info(
"Pending validations processed",
pending_count=len(pending),
processed=len(results),
successful=successful
)
return {
"status": "success",
"pending_count": len(pending),
"processed": len(results),
"successful": successful,
"failed": len(results) - successful,
"results": results
}
except Exception as e:
logger.error(
"Failed to process pending validations",
error=str(e)
)
return {
"status": "failed",
"error": str(e)
}

View File

@@ -0,0 +1,208 @@
# ================================================================
# services/forecasting/app/main.py
# ================================================================
"""
Forecasting Service Main Application
Demand prediction and forecasting service for bakery operations
"""
from fastapi import FastAPI
from sqlalchemy import text
from app.core.config import settings
from app.core.database import database_manager
from app.services.forecasting_alert_service import ForecastingAlertService
from shared.service_base import StandardFastAPIService
# Import API routers
from app.api import (
    forecasts,
    forecasting_operations,
    analytics,
    scenario_operations,
    audit,
    ml_insights,
    validation,
    historical_validation,
    webhooks,
    performance_monitoring,
    retraining,
    enterprise_forecasting,
    internal_demo,
    forecast_feedback,
)
class ForecastingService(StandardFastAPIService):
"""Forecasting Service with standardized setup"""
expected_migration_version = "00003"
    # Migration verification runs at the start of on_startup() (defined below),
    # before the standard service startup and alert-service initialization.
async def verify_migrations(self):
"""Verify database schema matches the latest migrations."""
try:
async with self.database_manager.get_session() as session:
result = await session.execute(text("SELECT version_num FROM alembic_version"))
version = result.scalar()
if version != self.expected_migration_version:
self.logger.error(f"Migration version mismatch: expected {self.expected_migration_version}, got {version}")
raise RuntimeError(f"Migration version mismatch: expected {self.expected_migration_version}, got {version}")
self.logger.info(f"Migration verification successful: {version}")
except Exception as e:
self.logger.error(f"Migration verification failed: {e}")
raise
def __init__(self):
# Define expected database tables for health checks
forecasting_expected_tables = [
'forecasts', 'prediction_batches', 'model_performance_metrics', 'prediction_cache', 'validation_runs', 'sales_data_updates'
]
self.alert_service = None
self.rabbitmq_client = None
self.event_publisher = None
# Create custom checks for alert service
async def alert_service_check():
"""Custom health check for forecasting alert service"""
return await self.alert_service.health_check() if self.alert_service else False
# Define custom metrics for forecasting service
forecasting_custom_metrics = {
"forecasts_generated_total": {
"type": "counter",
"description": "Total forecasts generated"
},
"predictions_served_total": {
"type": "counter",
"description": "Total predictions served"
},
"prediction_errors_total": {
"type": "counter",
"description": "Total prediction errors"
},
"forecast_processing_time_seconds": {
"type": "histogram",
"description": "Time to process forecast request"
},
"prediction_processing_time_seconds": {
"type": "histogram",
"description": "Time to process prediction request"
},
"model_cache_hits_total": {
"type": "counter",
"description": "Total model cache hits"
},
"model_cache_misses_total": {
"type": "counter",
"description": "Total model cache misses"
}
}
super().__init__(
service_name="forecasting-service",
app_name="Bakery Forecasting Service",
description="AI-powered demand prediction and forecasting service for bakery operations",
version="1.0.0",
log_level=settings.LOG_LEVEL,
cors_origins=settings.CORS_ORIGINS_LIST,
api_prefix="", # Empty because RouteBuilder already includes /api/v1
database_manager=database_manager,
expected_tables=forecasting_expected_tables,
custom_health_checks={"alert_service": alert_service_check},
enable_messaging=True,
custom_metrics=forecasting_custom_metrics
)
async def _setup_messaging(self):
"""Setup messaging for forecasting service using unified messaging"""
from shared.messaging import UnifiedEventPublisher, RabbitMQClient
try:
self.rabbitmq_client = RabbitMQClient(settings.RABBITMQ_URL, service_name="forecasting-service")
await self.rabbitmq_client.connect()
# Create unified event publisher
self.event_publisher = UnifiedEventPublisher(self.rabbitmq_client, "forecasting-service")
self.logger.info("Forecasting service unified messaging setup completed")
except Exception as e:
self.logger.error("Failed to setup forecasting unified messaging", error=str(e))
raise
async def _cleanup_messaging(self):
"""Cleanup messaging for forecasting service"""
try:
if self.rabbitmq_client:
await self.rabbitmq_client.disconnect()
self.logger.info("Forecasting service messaging cleanup completed")
except Exception as e:
self.logger.error("Error during forecasting messaging cleanup", error=str(e))
    async def on_startup(self, app: FastAPI):
        """Custom startup logic for forecasting service, including migration verification"""
        await self.verify_migrations()
        await super().on_startup(app)
# Initialize forecasting alert service with EventPublisher
if self.event_publisher:
self.alert_service = ForecastingAlertService(self.event_publisher)
await self.alert_service.start()
self.logger.info("Forecasting alert service initialized")
else:
self.logger.error("Event publisher not initialized, alert service unavailable")
# Store the event publisher in app state for internal API access
app.state.event_publisher = self.event_publisher
async def on_shutdown(self, app: FastAPI):
"""Custom shutdown logic for forecasting service"""
# Cleanup alert service
if self.alert_service:
await self.alert_service.stop()
self.logger.info("Alert service cleanup completed")
def get_service_features(self):
"""Return forecasting-specific features"""
return [
"demand_prediction",
"ai_forecasting",
"model_performance_tracking",
"prediction_caching",
"alert_notifications",
"messaging_integration"
]
def setup_custom_endpoints(self):
"""Setup custom endpoints for forecasting service"""
@self.app.get("/alert-metrics")
async def get_alert_metrics():
"""Alert service metrics endpoint"""
if self.alert_service:
return self.alert_service.get_metrics()
return {"error": "Alert service not initialized"}
# Create service instance
service = ForecastingService()
# Create FastAPI app with standardized setup
app = service.create_app(
docs_url="/docs",
redoc_url="/redoc"
)
# Setup standard endpoints
service.setup_standard_endpoints()
# Setup custom endpoints
service.setup_custom_endpoints()
# Include API routers
# IMPORTANT: Register audit router FIRST to avoid route matching conflicts
service.add_router(audit.router)
service.add_router(forecasts.router)
service.add_router(forecasting_operations.router)
service.add_router(analytics.router)
service.add_router(scenario_operations.router)
service.add_router(internal_demo.router, tags=["internal-demo"])
service.add_router(ml_insights.router) # ML insights endpoint
service.add_router(ml_insights.internal_router) # Internal ML insights endpoint
service.add_router(validation.router) # Validation endpoint
service.add_router(historical_validation.router) # Historical validation endpoint
service.add_router(webhooks.router) # Webhooks endpoint
service.add_router(performance_monitoring.router) # Performance monitoring endpoint
service.add_router(retraining.router) # Retraining endpoint
service.add_router(enterprise_forecasting.router) # Enterprise forecasting endpoint
service.add_router(forecast_feedback.router) # Forecast feedback endpoint
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -0,0 +1,11 @@
"""
ML Components for Forecasting
Machine learning prediction and forecasting components
"""
from .predictor import BakeryPredictor, BakeryForecaster
__all__ = [
"BakeryPredictor",
"BakeryForecaster"
]

View File

@@ -0,0 +1,393 @@
"""
Business Rules Insights Orchestrator
Coordinates business rules optimization and insight posting
"""
import pandas as pd
from typing import Dict, List, Any, Optional
import structlog
from datetime import datetime, timezone
from uuid import UUID
import sys
import os
# Add shared clients to path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../..'))
from shared.clients.ai_insights_client import AIInsightsClient
from shared.messaging import UnifiedEventPublisher
from app.ml.dynamic_rules_engine import DynamicRulesEngine
logger = structlog.get_logger()
class BusinessRulesInsightsOrchestrator:
"""
Orchestrates business rules analysis and insight generation workflow.
Workflow:
1. Analyze dynamic business rule performance
2. Generate insights for rule optimization
3. Post insights to AI Insights Service
4. Publish recommendation events to RabbitMQ
5. Provide rule optimization for forecasting
6. Track rule effectiveness and improvements
"""
def __init__(
self,
ai_insights_base_url: str = "http://ai-insights-service:8000",
event_publisher: Optional[UnifiedEventPublisher] = None
):
self.rules_engine = DynamicRulesEngine()
self.ai_insights_client = AIInsightsClient(ai_insights_base_url)
self.event_publisher = event_publisher
async def analyze_and_post_business_rules_insights(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
min_samples: int = 10
) -> Dict[str, Any]:
"""
Complete workflow: Analyze business rules and post insights.
Args:
tenant_id: Tenant identifier
inventory_product_id: Product identifier
sales_data: Historical sales data
min_samples: Minimum samples for rule analysis
Returns:
Workflow results with analysis and posted insights
"""
logger.info(
"Starting business rules analysis workflow",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
samples=len(sales_data)
)
# Step 1: Learn and analyze rules
rules_results = await self.rules_engine.learn_all_rules(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
sales_data=sales_data,
external_data=None,
min_samples=min_samples
)
logger.info(
"Business rules analysis complete",
insights_generated=len(rules_results.get('insights', [])),
rules_learned=len(rules_results.get('rules', {}))
)
# Step 2: Enrich insights with tenant_id and product context
enriched_insights = self._enrich_insights(
rules_results.get('insights', []),
tenant_id,
inventory_product_id
)
# Step 3: Post insights to AI Insights Service
if enriched_insights:
post_results = await self.ai_insights_client.create_insights_bulk(
tenant_id=UUID(tenant_id),
insights=enriched_insights
)
logger.info(
"Business rules insights posted to AI Insights Service",
inventory_product_id=inventory_product_id,
total=post_results['total'],
successful=post_results['successful'],
failed=post_results['failed']
)
else:
post_results = {'total': 0, 'successful': 0, 'failed': 0}
logger.info("No insights to post for product", inventory_product_id=inventory_product_id)
# Step 4: Publish insight events to RabbitMQ
created_insights = post_results.get('created_insights', [])
if created_insights:
product_context = {'inventory_product_id': inventory_product_id}
await self._publish_insight_events(
tenant_id=tenant_id,
insights=created_insights,
product_context=product_context
)
# Step 5: Return comprehensive results
return {
'tenant_id': tenant_id,
'inventory_product_id': inventory_product_id,
'learned_at': rules_results['learned_at'],
'rules': rules_results.get('rules', {}),
'insights_generated': len(enriched_insights),
'insights_posted': post_results['successful'],
'insights_failed': post_results['failed'],
'created_insights': post_results.get('created_insights', [])
}
def _enrich_insights(
self,
insights: List[Dict[str, Any]],
tenant_id: str,
inventory_product_id: str
) -> List[Dict[str, Any]]:
"""
Enrich insights with required fields for AI Insights Service.
Args:
insights: Raw insights from rules engine
tenant_id: Tenant identifier
inventory_product_id: Product identifier
Returns:
Enriched insights ready for posting
"""
enriched = []
for insight in insights:
# Add required tenant_id
enriched_insight = insight.copy()
enriched_insight['tenant_id'] = tenant_id
# Add product context to metrics
if 'metrics_json' not in enriched_insight:
enriched_insight['metrics_json'] = {}
enriched_insight['metrics_json']['inventory_product_id'] = inventory_product_id
# Add source metadata
enriched_insight['source_service'] = 'forecasting'
enriched_insight['source_model'] = 'dynamic_rules_engine'
enriched_insight['detected_at'] = datetime.utcnow().isoformat()
enriched.append(enriched_insight)
return enriched
async def analyze_all_business_rules(
self,
tenant_id: str,
products_data: Dict[str, pd.DataFrame],
min_samples: int = 10
) -> Dict[str, Any]:
"""
Analyze all products for business rules optimization and generate comparative insights.
Args:
tenant_id: Tenant identifier
products_data: Dict of {inventory_product_id: sales_data DataFrame}
min_samples: Minimum samples for rule analysis
Returns:
Comprehensive analysis with rule optimization insights
"""
logger.info(
"Analyzing business rules for all products",
tenant_id=tenant_id,
products=len(products_data)
)
all_results = []
total_insights_posted = 0
# Analyze each product
for inventory_product_id, sales_data in products_data.items():
try:
results = await self.analyze_and_post_business_rules_insights(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
sales_data=sales_data,
min_samples=min_samples
)
all_results.append(results)
total_insights_posted += results['insights_posted']
except Exception as e:
logger.error(
"Error analyzing business rules for product",
inventory_product_id=inventory_product_id,
error=str(e)
)
# Generate summary insight
if total_insights_posted > 0:
summary_insight = self._generate_portfolio_summary_insight(
tenant_id, all_results
)
if summary_insight:
enriched_summary = self._enrich_insights(
[summary_insight], tenant_id, 'all_products'
)
post_results = await self.ai_insights_client.create_insights_bulk(
tenant_id=UUID(tenant_id),
insights=enriched_summary
)
total_insights_posted += post_results['successful']
logger.info(
"All business rules analysis complete",
tenant_id=tenant_id,
products_analyzed=len(all_results),
total_insights_posted=total_insights_posted
)
return {
'tenant_id': tenant_id,
'analyzed_at': datetime.utcnow().isoformat(),
'products_analyzed': len(all_results),
'product_results': all_results,
'total_insights_posted': total_insights_posted
}
def _generate_portfolio_summary_insight(
self,
tenant_id: str,
all_results: List[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
"""
Generate portfolio-level business rules summary insight.
Args:
tenant_id: Tenant identifier
all_results: All product analysis results
Returns:
Summary insight or None
"""
if not all_results:
return None
# Calculate summary statistics
total_products = len(all_results)
total_rules = sum(len(r.get('rules', {})) for r in all_results)
        # Count products for which at least one rule-optimization insight was generated
        # (the rules engine only emits insights when a learned value differs
        # significantly from the hardcoded assumption)
        significant_improvements = sum(
            1 for r in all_results if r.get('insights_generated', 0) > 0
        )
return {
'type': 'recommendation',
'priority': 'high' if significant_improvements > total_products * 0.3 else 'medium',
'category': 'forecasting',
'title': f'Business Rule Optimization: {total_products} Products Analyzed',
'description': f'Learned {total_rules} dynamic rules across {total_products} products. Identified {significant_improvements} products with significant rule improvements.',
'impact_type': 'operational_efficiency',
'impact_value': total_rules,
'impact_unit': 'rules',
'confidence': 80,
'metrics_json': {
'total_products': total_products,
'total_rules': total_rules,
'significant_improvements': significant_improvements,
'rules_per_product': round(total_rules / total_products, 2)
},
'actionable': True,
'recommendation_actions': [
{
'label': 'Review Learned Rules',
'action': 'review_business_rules',
'params': {'tenant_id': tenant_id}
},
{
'label': 'Implement Optimized Rules',
'action': 'implement_business_rules',
'params': {'tenant_id': tenant_id}
}
],
'source_service': 'forecasting',
'source_model': 'dynamic_rules_engine'
}
async def get_learned_rules(
self,
inventory_product_id: str
) -> Optional[Dict[str, Any]]:
"""
Get cached learned rules for a product.
Args:
inventory_product_id: Product identifier
Returns:
Learned rules or None if not analyzed
"""
return self.rules_engine.get_all_rules(inventory_product_id)
async def _publish_insight_events(self, tenant_id, insights, product_context=None):
"""
Publish insight events to RabbitMQ for alert processing.
Args:
tenant_id: Tenant identifier
insights: List of created insights
product_context: Additional context about the product
"""
if not self.event_publisher:
logger.warning("No event publisher available for business rules insights")
return
for insight in insights:
# Determine severity based on confidence and priority
confidence = insight.get('confidence', 0)
priority = insight.get('priority', 'medium')
# Map priority to severity, with confidence as tiebreaker
if priority == 'critical' or (priority == 'high' and confidence >= 70):
severity = 'high'
elif priority == 'high' or (priority == 'medium' and confidence >= 80):
severity = 'medium'
else:
severity = 'low'
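            # Worked examples of the mapping above (illustrative):
            #   priority='high',   confidence=65 -> severity='medium' (misses the >=70 cutoff)
            #   priority='medium', confidence=85 -> severity='medium'
            #   priority='medium', confidence=60 -> severity='low'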
# Prepare the event data
event_data = {
'insight_id': insight.get('id'),
'type': insight.get('type'),
'title': insight.get('title'),
'description': insight.get('description'),
'category': insight.get('category'),
'priority': insight.get('priority'),
'confidence': confidence,
'recommendation': insight.get('recommendation_actions', []),
'impact_type': insight.get('impact_type'),
'impact_value': insight.get('impact_value'),
'inventory_product_id': product_context.get('inventory_product_id') if product_context else None,
'timestamp': insight.get('detected_at', datetime.utcnow().isoformat()),
'source_service': 'forecasting',
'source_model': 'dynamic_rules_engine'
}
try:
await self.event_publisher.publish_recommendation(
event_type='ai_business_rule',
tenant_id=tenant_id,
severity=severity,
data=event_data
)
logger.info(
"Published business rules insight event",
tenant_id=tenant_id,
insight_id=insight.get('id'),
severity=severity
)
except Exception as e:
logger.error(
"Failed to publish business rules insight event",
tenant_id=tenant_id,
insight_id=insight.get('id'),
error=str(e)
)
async def close(self):
"""Close HTTP client connections."""
await self.ai_insights_client.close()
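

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The tenant ID, product ID and synthetic
# sales frame below are assumptions for demonstration; they are not part of
# the service wiring, and a reachable AI Insights Service is required for the
# posting step to succeed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        orchestrator = BusinessRulesInsightsOrchestrator()
        # Synthetic daily history with a simple weekly pattern
        sales = pd.DataFrame({
            "date": pd.date_range("2025-01-01", periods=120, freq="D"),
            "quantity": [40 + (i % 7) * 5 for i in range(120)],
        })
        try:
            result = await orchestrator.analyze_and_post_business_rules_insights(
                tenant_id="00000000-0000-0000-0000-000000000001",
                inventory_product_id="demo-product",
                sales_data=sales,
            )
            print("insights posted:", result["insights_posted"])
        finally:
            await orchestrator.close()

    asyncio.run(_demo())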

View File

@@ -0,0 +1,235 @@
"""
Calendar-based Feature Engineering for Forecasting Service
Generates calendar features for future date predictions
"""
import pandas as pd
import structlog
from typing import Dict, List, Any, Optional
from datetime import datetime, date, time, timedelta
from app.services.data_client import data_client
logger = structlog.get_logger()
class ForecastCalendarFeatures:
"""
Generates calendar-based features for future predictions
Optimized for forecasting service (future dates only)
"""
def __init__(self):
self.calendar_cache = {} # Cache calendar data per tenant
async def get_calendar_for_tenant(
self,
tenant_id: str
) -> Optional[Dict[str, Any]]:
"""Get cached calendar for tenant"""
if tenant_id in self.calendar_cache:
return self.calendar_cache[tenant_id]
calendar = await data_client.fetch_tenant_calendar(tenant_id)
if calendar:
self.calendar_cache[tenant_id] = calendar
return calendar
def _is_date_in_holiday_period(
self,
check_date: date,
holiday_periods: List[Dict[str, Any]]
) -> tuple[bool, Optional[str]]:
"""Check if date is within any holiday period"""
for period in holiday_periods:
start = datetime.strptime(period["start_date"], "%Y-%m-%d").date()
end = datetime.strptime(period["end_date"], "%Y-%m-%d").date()
if start <= check_date <= end:
return True, period["name"]
return False, None
def _is_school_hours_active(
self,
check_datetime: datetime,
school_hours: Dict[str, Any]
) -> bool:
"""Check if datetime falls during school operating hours"""
# Only weekdays
if check_datetime.weekday() >= 5:
return False
check_time = check_datetime.time()
# Morning session
morning_start = datetime.strptime(
school_hours["morning_start"], "%H:%M"
).time()
morning_end = datetime.strptime(
school_hours["morning_end"], "%H:%M"
).time()
if morning_start <= check_time <= morning_end:
return True
# Afternoon session if exists
if school_hours.get("has_afternoon_session", False):
afternoon_start = datetime.strptime(
school_hours["afternoon_start"], "%H:%M"
).time()
afternoon_end = datetime.strptime(
school_hours["afternoon_end"], "%H:%M"
).time()
if afternoon_start <= check_time <= afternoon_end:
return True
return False
def _calculate_school_proximity_intensity(
self,
check_datetime: datetime,
school_hours: Dict[str, Any]
) -> float:
"""
Calculate school proximity impact intensity
Returns 0.0-1.0 based on drop-off/pick-up times
"""
# Only weekdays
if check_datetime.weekday() >= 5:
return 0.0
check_time = check_datetime.time()
morning_start = datetime.strptime(
school_hours["morning_start"], "%H:%M"
).time()
morning_end = datetime.strptime(
school_hours["morning_end"], "%H:%M"
).time()
# Morning drop-off peak (30 min before to 15 min after start)
drop_off_start = (
datetime.combine(date.today(), morning_start) - timedelta(minutes=30)
).time()
drop_off_end = (
datetime.combine(date.today(), morning_start) + timedelta(minutes=15)
).time()
if drop_off_start <= check_time <= drop_off_end:
return 1.0 # Peak
# Morning pick-up peak (15 min before to 30 min after end)
pickup_start = (
datetime.combine(date.today(), morning_end) - timedelta(minutes=15)
).time()
pickup_end = (
datetime.combine(date.today(), morning_end) + timedelta(minutes=30)
).time()
if pickup_start <= check_time <= pickup_end:
return 1.0 # Peak
# During school hours (moderate)
if morning_start <= check_time <= morning_end:
return 0.3
return 0.0
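
    # Example (illustrative): with morning_start="08:30" and morning_end="16:00",
    # intensity is 1.0 between 08:00-08:45 (drop-off) and 15:45-16:30 (pick-up),
    # 0.3 during the remaining school hours, and 0.0 otherwise or on weekends.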
async def add_calendar_features(
self,
df: pd.DataFrame,
tenant_id: str,
date_column: str = "ds"
) -> pd.DataFrame:
"""
Add calendar features to forecast dataframe
Args:
df: Forecast dataframe with future dates
tenant_id: Tenant ID to fetch calendar
date_column: Name of date column (default 'ds' for Prophet)
Returns:
DataFrame with calendar features added
"""
try:
logger.info(
"Adding calendar features to forecast",
tenant_id=tenant_id,
rows=len(df)
)
# Get calendar
calendar = await self.get_calendar_for_tenant(tenant_id)
if not calendar:
logger.info(
"No calendar available, using zero features",
tenant_id=tenant_id
)
df["is_school_holiday"] = 0
df["school_hours_active"] = 0
df["school_proximity_intensity"] = 0.0
return df
holiday_periods = calendar.get("holiday_periods", [])
school_hours = calendar.get("school_hours", {})
# Initialize feature lists
school_holidays = []
hours_active = []
proximity_intensity = []
# Process each row
for idx, row in df.iterrows():
row_date = pd.to_datetime(row[date_column])
# Check holiday
is_holiday, _ = self._is_date_in_holiday_period(
row_date.date(),
holiday_periods
)
school_holidays.append(1 if is_holiday else 0)
# Check school hours and proximity (if datetime has time component)
if hasattr(row_date, 'hour'):
hours_active.append(
1 if self._is_school_hours_active(row_date, school_hours) else 0
)
proximity_intensity.append(
self._calculate_school_proximity_intensity(row_date, school_hours)
)
else:
hours_active.append(0)
proximity_intensity.append(0.0)
# Add features
df["is_school_holiday"] = school_holidays
df["school_hours_active"] = hours_active
df["school_proximity_intensity"] = proximity_intensity
logger.info(
"Calendar features added to forecast",
tenant_id=tenant_id,
holidays_in_forecast=sum(school_holidays)
)
return df
except Exception as e:
logger.error(
"Error adding calendar features to forecast",
tenant_id=tenant_id,
error=str(e)
)
# Return with zero features on error
df["is_school_holiday"] = 0
df["school_hours_active"] = 0
df["school_proximity_intensity"] = 0.0
return df
# Global instance
forecast_calendar_features = ForecastCalendarFeatures()
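

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): enriching a Prophet-style future frame.
# The tenant ID below is an assumption; the shared `data_client` must be able
# to resolve a calendar for it, otherwise zero-valued features are returned.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        future = pd.DataFrame({
            "ds": pd.date_range("2025-06-01", periods=14, freq="D")
        })
        enriched = await forecast_calendar_features.add_calendar_features(
            future, tenant_id="00000000-0000-0000-0000-000000000001"
        )
        print(enriched[["ds", "is_school_holiday", "school_proximity_intensity"]].head())

    asyncio.run(_demo())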

View File

@@ -0,0 +1,403 @@
"""
Demand Insights Orchestrator
Coordinates demand forecasting analysis and insight posting
"""
import pandas as pd
from typing import Dict, List, Any, Optional
import structlog
from datetime import datetime
from uuid import UUID
import sys
import os
# Add shared clients to path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../..'))
from shared.clients.ai_insights_client import AIInsightsClient
from shared.messaging import UnifiedEventPublisher
from app.ml.predictor import BakeryForecaster
logger = structlog.get_logger()
class DemandInsightsOrchestrator:
"""
Orchestrates demand forecasting analysis and insight generation workflow.
Workflow:
1. Analyze historical demand patterns from sales data
2. Generate insights for demand optimization
3. Post insights to AI Insights Service
4. Publish recommendation events to RabbitMQ
5. Provide demand pattern analysis for forecasting
6. Track demand forecasting performance
"""
def __init__(
self,
ai_insights_base_url: str = "http://ai-insights-service:8000",
event_publisher: Optional[UnifiedEventPublisher] = None
):
self.forecaster = BakeryForecaster()
self.ai_insights_client = AIInsightsClient(ai_insights_base_url)
self.event_publisher = event_publisher
async def analyze_and_post_demand_insights(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
forecast_horizon_days: int = 30,
min_history_days: int = 90
) -> Dict[str, Any]:
"""
Complete workflow: Analyze demand and post insights.
Args:
tenant_id: Tenant identifier
inventory_product_id: Product identifier
sales_data: Historical sales data
forecast_horizon_days: Days to forecast ahead
min_history_days: Minimum days of history required
Returns:
Workflow results with analysis and posted insights
"""
logger.info(
"Starting demand forecasting analysis workflow",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
history_days=len(sales_data)
)
# Step 1: Analyze demand patterns
analysis_results = await self.forecaster.analyze_demand_patterns(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
sales_data=sales_data,
forecast_horizon_days=forecast_horizon_days,
min_history_days=min_history_days
)
logger.info(
"Demand analysis complete",
inventory_product_id=inventory_product_id,
insights_generated=len(analysis_results.get('insights', []))
)
# Step 2: Enrich insights with tenant_id and product context
enriched_insights = self._enrich_insights(
analysis_results.get('insights', []),
tenant_id,
inventory_product_id
)
# Step 3: Post insights to AI Insights Service
if enriched_insights:
post_results = await self.ai_insights_client.create_insights_bulk(
tenant_id=UUID(tenant_id),
insights=enriched_insights
)
logger.info(
"Demand insights posted to AI Insights Service",
inventory_product_id=inventory_product_id,
total=post_results['total'],
successful=post_results['successful'],
failed=post_results['failed']
)
else:
post_results = {'total': 0, 'successful': 0, 'failed': 0}
logger.info("No insights to post for product", inventory_product_id=inventory_product_id)
# Step 4: Publish insight events to RabbitMQ
created_insights = post_results.get('created_insights', [])
if created_insights:
product_context = {'inventory_product_id': inventory_product_id}
await self._publish_insight_events(
tenant_id=tenant_id,
insights=created_insights,
product_context=product_context
)
# Step 5: Return comprehensive results
return {
'tenant_id': tenant_id,
'inventory_product_id': inventory_product_id,
'analyzed_at': analysis_results['analyzed_at'],
'history_days': analysis_results['history_days'],
'demand_patterns': analysis_results.get('patterns', {}),
'trend_analysis': analysis_results.get('trend_analysis', {}),
'seasonal_factors': analysis_results.get('seasonal_factors', {}),
'insights_generated': len(enriched_insights),
'insights_posted': post_results['successful'],
'insights_failed': post_results['failed'],
'created_insights': post_results.get('created_insights', [])
}
def _enrich_insights(
self,
insights: List[Dict[str, Any]],
tenant_id: str,
inventory_product_id: str
) -> List[Dict[str, Any]]:
"""
Enrich insights with required fields for AI Insights Service.
Args:
insights: Raw insights from forecaster
tenant_id: Tenant identifier
inventory_product_id: Product identifier
Returns:
Enriched insights ready for posting
"""
enriched = []
for insight in insights:
# Add required tenant_id
enriched_insight = insight.copy()
enriched_insight['tenant_id'] = tenant_id
# Add product context to metrics
if 'metrics_json' not in enriched_insight:
enriched_insight['metrics_json'] = {}
enriched_insight['metrics_json']['inventory_product_id'] = inventory_product_id
# Add source metadata
enriched_insight['source_service'] = 'forecasting'
enriched_insight['source_model'] = 'demand_analyzer'
enriched_insight['detected_at'] = datetime.utcnow().isoformat()
enriched.append(enriched_insight)
return enriched
async def analyze_all_products(
self,
tenant_id: str,
products_data: Dict[str, pd.DataFrame],
forecast_horizon_days: int = 30,
min_history_days: int = 90
) -> Dict[str, Any]:
"""
Analyze all products for a tenant and generate comparative insights.
Args:
tenant_id: Tenant identifier
products_data: Dict of {inventory_product_id: sales_data DataFrame}
forecast_horizon_days: Days to forecast
min_history_days: Minimum history required
Returns:
Comprehensive analysis with product comparison
"""
logger.info(
"Analyzing all products for tenant",
tenant_id=tenant_id,
products=len(products_data)
)
all_results = []
total_insights_posted = 0
# Analyze each product
for inventory_product_id, sales_data in products_data.items():
try:
results = await self.analyze_and_post_demand_insights(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
sales_data=sales_data,
forecast_horizon_days=forecast_horizon_days,
min_history_days=min_history_days
)
all_results.append(results)
total_insights_posted += results['insights_posted']
except Exception as e:
logger.error(
"Error analyzing product",
inventory_product_id=inventory_product_id,
error=str(e)
)
# Generate summary insight
if total_insights_posted > 0:
summary_insight = self._generate_portfolio_summary_insight(
tenant_id, all_results
)
if summary_insight:
enriched_summary = self._enrich_insights(
[summary_insight], tenant_id, 'all_products'
)
post_results = await self.ai_insights_client.create_insights_bulk(
tenant_id=UUID(tenant_id),
insights=enriched_summary
)
total_insights_posted += post_results['successful']
logger.info(
"All products analysis complete",
tenant_id=tenant_id,
products_analyzed=len(all_results),
total_insights_posted=total_insights_posted
)
return {
'tenant_id': tenant_id,
'analyzed_at': datetime.utcnow().isoformat(),
'products_analyzed': len(all_results),
'product_results': all_results,
'total_insights_posted': total_insights_posted
}
def _generate_portfolio_summary_insight(
self,
tenant_id: str,
all_results: List[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
"""
Generate portfolio-level summary insight.
Args:
tenant_id: Tenant identifier
all_results: All product analysis results
Returns:
Summary insight or None
"""
if not all_results:
return None
# Calculate summary statistics
total_products = len(all_results)
high_demand_products = sum(1 for r in all_results if r.get('trend_analysis', {}).get('is_increasing', False))
        # Average the peak ratio only over products that actually report one,
        # so products without seasonal data do not drag the average toward zero
        peak_ratios = [
            r.get('seasonal_factors', {}).get('peak_ratio')
            for r in all_results
            if r.get('seasonal_factors', {}).get('peak_ratio')
        ]
        avg_seasonal_factor = sum(peak_ratios) / max(1, len(peak_ratios))
return {
'type': 'recommendation',
'priority': 'medium' if high_demand_products > total_products * 0.5 else 'low',
'category': 'forecasting',
'title': f'Demand Pattern Summary: {total_products} Products Analyzed',
'description': f'Detected {high_demand_products} products with increasing demand trends. Average seasonal peak ratio: {avg_seasonal_factor:.2f}x.',
'impact_type': 'demand_optimization',
'impact_value': high_demand_products,
'impact_unit': 'products',
'confidence': 75,
'metrics_json': {
'total_products': total_products,
'high_demand_products': high_demand_products,
'avg_seasonal_factor': round(avg_seasonal_factor, 2),
'trend_strength': 'strong' if high_demand_products > total_products * 0.7 else 'moderate'
},
'actionable': True,
'recommendation_actions': [
{
'label': 'Review Production Schedule',
'action': 'review_production_schedule',
'params': {'tenant_id': tenant_id}
},
{
'label': 'Adjust Inventory Levels',
'action': 'adjust_inventory_levels',
'params': {'tenant_id': tenant_id}
}
],
'source_service': 'forecasting',
'source_model': 'demand_analyzer'
}
async def get_demand_patterns(
self,
inventory_product_id: str
) -> Optional[Dict[str, Any]]:
"""
Get cached demand patterns for a product.
Args:
inventory_product_id: Product identifier
Returns:
Demand patterns or None if not analyzed
"""
return self.forecaster.get_cached_demand_patterns(inventory_product_id)
async def _publish_insight_events(self, tenant_id, insights, product_context=None):
"""
Publish insight events to RabbitMQ for alert processing.
Args:
tenant_id: Tenant identifier
insights: List of created insights
product_context: Additional context about the product
"""
if not self.event_publisher:
logger.warning("No event publisher available for demand insights")
return
for insight in insights:
# Determine severity based on confidence and priority
confidence = insight.get('confidence', 0)
priority = insight.get('priority', 'medium')
# Map priority to severity, with confidence as tiebreaker
if priority == 'critical' or (priority == 'high' and confidence >= 70):
severity = 'high'
elif priority == 'high' or (priority == 'medium' and confidence >= 80):
severity = 'medium'
else:
severity = 'low'
# Prepare the event data
event_data = {
'insight_id': insight.get('id'),
'type': insight.get('type'),
'title': insight.get('title'),
'description': insight.get('description'),
'category': insight.get('category'),
'priority': insight.get('priority'),
'confidence': confidence,
'recommendation': insight.get('recommendation_actions', []),
'impact_type': insight.get('impact_type'),
'impact_value': insight.get('impact_value'),
'inventory_product_id': product_context.get('inventory_product_id') if product_context else None,
'timestamp': insight.get('detected_at', datetime.utcnow().isoformat()),
'source_service': 'forecasting',
'source_model': 'demand_analyzer'
}
try:
await self.event_publisher.publish_recommendation(
event_type='ai_demand_forecast',
tenant_id=tenant_id,
severity=severity,
data=event_data
)
logger.info(
"Published demand insight event",
tenant_id=tenant_id,
insight_id=insight.get('id'),
severity=severity
)
except Exception as e:
logger.error(
"Failed to publish demand insight event",
tenant_id=tenant_id,
insight_id=insight.get('id'),
error=str(e)
)
async def close(self):
"""Close HTTP client connections."""
await self.ai_insights_client.close()
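

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): running the portfolio-level analysis for
# two hypothetical products. The IDs and synthetic sales frames are assumptions;
# posting requires a reachable AI Insights Service, and per-product failures
# are caught and logged by analyze_all_products.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    def _synthetic_sales(base: int) -> pd.DataFrame:
        # Simple daily history with a weekly pattern around `base` units
        return pd.DataFrame({
            "date": pd.date_range("2025-01-01", periods=180, freq="D"),
            "quantity": [base + (i % 7) * 3 for i in range(180)],
        })

    async def _demo() -> None:
        orchestrator = DemandInsightsOrchestrator()
        try:
            summary = await orchestrator.analyze_all_products(
                tenant_id="00000000-0000-0000-0000-000000000001",
                products_data={
                    "baguette-classic": _synthetic_sales(60),
                    "croissant-butter": _synthetic_sales(35),
                },
            )
            print("products analyzed:", summary["products_analyzed"])
        finally:
            await orchestrator.close()

    asyncio.run(_demo())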

View File

@@ -0,0 +1,758 @@
"""
Dynamic Business Rules Engine
Learns optimal adjustment factors from historical data instead of using hardcoded values
Replaces hardcoded weather multipliers, holiday adjustments, event impacts with learned values
"""
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
import structlog
from datetime import datetime, timedelta
from scipy import stats
from sklearn.linear_model import Ridge
from collections import defaultdict
logger = structlog.get_logger()
class DynamicRulesEngine:
"""
Learns business rules from historical data instead of using hardcoded values.
Current hardcoded values to replace:
- Weather: rain = -15%, snow = -25%, extreme_heat = -10%
- Holidays: +50% (all holidays treated the same)
- Events: +30% (all events treated the same)
- Weekend: Manual assumptions
Dynamic approach:
- Learn actual weather impact per weather condition per product
- Learn holiday multipliers per holiday type
- Learn event impact by event type
- Learn day-of-week patterns per product
- Generate insights when learned values differ from hardcoded assumptions
"""
def __init__(self):
self.weather_rules = {}
self.holiday_rules = {}
self.event_rules = {}
self.dow_rules = {}
self.month_rules = {}
async def learn_all_rules(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
external_data: Optional[pd.DataFrame] = None,
min_samples: int = 10
) -> Dict[str, Any]:
"""
Learn all business rules from historical data.
Args:
tenant_id: Tenant identifier
inventory_product_id: Product identifier
sales_data: Historical sales data with 'date', 'quantity' columns
external_data: Optional weather/events/holidays data
min_samples: Minimum samples required to learn a rule
Returns:
Dictionary of learned rules and insights
"""
logger.info(
"Learning dynamic business rules from historical data",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
data_points=len(sales_data)
)
results = {
'tenant_id': tenant_id,
'inventory_product_id': inventory_product_id,
'learned_at': datetime.utcnow().isoformat(),
'rules': {},
'insights': []
}
# Ensure date column is datetime
        # Work on a copy so added or converted columns do not mutate the caller's frame
        sales_data = sales_data.copy()
        if 'date' not in sales_data.columns:
            sales_data['date'] = sales_data['ds']
        sales_data['date'] = pd.to_datetime(sales_data['date'])
# Learn weather impact rules
if external_data is not None and 'weather_condition' in external_data.columns:
weather_rules, weather_insights = await self._learn_weather_rules(
sales_data, external_data, min_samples
)
results['rules']['weather'] = weather_rules
results['insights'].extend(weather_insights)
self.weather_rules[inventory_product_id] = weather_rules
# Learn holiday rules
if external_data is not None and 'is_holiday' in external_data.columns:
holiday_rules, holiday_insights = await self._learn_holiday_rules(
sales_data, external_data, min_samples
)
results['rules']['holidays'] = holiday_rules
results['insights'].extend(holiday_insights)
self.holiday_rules[inventory_product_id] = holiday_rules
# Learn event rules
if external_data is not None and 'event_type' in external_data.columns:
event_rules, event_insights = await self._learn_event_rules(
sales_data, external_data, min_samples
)
results['rules']['events'] = event_rules
results['insights'].extend(event_insights)
self.event_rules[inventory_product_id] = event_rules
# Learn day-of-week patterns (always available)
dow_rules, dow_insights = await self._learn_day_of_week_rules(
sales_data, min_samples
)
results['rules']['day_of_week'] = dow_rules
results['insights'].extend(dow_insights)
self.dow_rules[inventory_product_id] = dow_rules
# Learn monthly seasonality
month_rules, month_insights = await self._learn_month_rules(
sales_data, min_samples
)
results['rules']['months'] = month_rules
results['insights'].extend(month_insights)
self.month_rules[inventory_product_id] = month_rules
logger.info(
"Dynamic rules learning complete",
total_insights=len(results['insights']),
rules_learned=len(results['rules'])
)
return results
async def _learn_weather_rules(
self,
sales_data: pd.DataFrame,
external_data: pd.DataFrame,
min_samples: int
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""
Learn actual weather impact from historical data.
Hardcoded assumptions:
- rain: -15%
- snow: -25%
- extreme_heat: -10%
Learn actual impact for this product.
"""
logger.info("Learning weather impact rules")
# Merge sales with weather data
merged = sales_data.merge(
external_data[['date', 'weather_condition', 'temperature', 'precipitation']],
on='date',
how='left'
)
# Baseline: average sales on clear days
clear_days = merged[
(merged['weather_condition'].isin(['clear', 'sunny', 'partly_cloudy'])) |
(merged['weather_condition'].isna())
]
baseline_avg = clear_days['quantity'].mean()
weather_rules = {
'baseline_avg': float(baseline_avg),
'conditions': {}
}
insights = []
# Hardcoded values for comparison
hardcoded_impacts = {
'rain': -0.15,
'snow': -0.25,
'extreme_heat': -0.10
}
# Learn impact for each weather condition
for condition in ['rain', 'rainy', 'snow', 'snowy', 'extreme_heat', 'hot', 'storm', 'fog']:
condition_days = merged[merged['weather_condition'].str.contains(condition, case=False, na=False)]
if len(condition_days) >= min_samples:
condition_avg = condition_days['quantity'].mean()
learned_impact = (condition_avg - baseline_avg) / baseline_avg
# Statistical significance test
t_stat, p_value = stats.ttest_ind(
condition_days['quantity'].values,
clear_days['quantity'].values,
equal_var=False
)
weather_rules['conditions'][condition] = {
'learned_multiplier': float(1 + learned_impact),
'learned_impact_pct': float(learned_impact * 100),
'sample_size': int(len(condition_days)),
'avg_quantity': float(condition_avg),
'p_value': float(p_value),
'significant': bool(p_value < 0.05)
}
# Compare with hardcoded value if exists
if condition in hardcoded_impacts and p_value < 0.05:
hardcoded_impact = hardcoded_impacts[condition]
difference = abs(learned_impact - hardcoded_impact)
if difference > 0.05: # More than 5% difference
insight = {
'type': 'optimization',
'priority': 'high' if difference > 0.15 else 'medium',
'category': 'forecasting',
'title': f'Weather Rule Mismatch: {condition.title()}',
'description': f'Learned {condition} impact is {learned_impact*100:.1f}% vs hardcoded {hardcoded_impact*100:.1f}%. Updating rule could improve forecast accuracy by {difference*100:.1f}%.',
'impact_type': 'forecast_improvement',
'impact_value': difference * 100,
'impact_unit': 'percentage_points',
'confidence': self._calculate_confidence(len(condition_days), p_value),
'metrics_json': {
'weather_condition': condition,
'learned_impact_pct': round(learned_impact * 100, 2),
'hardcoded_impact_pct': round(hardcoded_impact * 100, 2),
'difference_pct': round(difference * 100, 2),
'baseline_avg': round(baseline_avg, 2),
'condition_avg': round(condition_avg, 2),
'sample_size': len(condition_days),
'p_value': round(p_value, 4)
},
'actionable': True,
'recommendation_actions': [
{
'label': 'Update Weather Rule',
'action': 'update_weather_multiplier',
'params': {
'condition': condition,
'new_multiplier': round(1 + learned_impact, 3)
}
}
],
'source_service': 'forecasting',
'source_model': 'dynamic_rules_engine'
}
insights.append(insight)
logger.info(
"Weather rule discrepancy detected",
condition=condition,
learned=f"{learned_impact*100:.1f}%",
hardcoded=f"{hardcoded_impact*100:.1f}%"
)
return weather_rules, insights
async def _learn_holiday_rules(
self,
sales_data: pd.DataFrame,
external_data: pd.DataFrame,
min_samples: int
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""
Learn holiday impact by holiday type instead of uniform +50%.
Hardcoded: All holidays = +50%
Learn: Christmas vs Easter vs National holidays have different impacts
"""
logger.info("Learning holiday impact rules")
# Merge sales with holiday data
merged = sales_data.merge(
external_data[['date', 'is_holiday', 'holiday_name', 'holiday_type']],
on='date',
how='left'
)
# Baseline: non-holiday average
non_holidays = merged[merged['is_holiday'] == False]
baseline_avg = non_holidays['quantity'].mean()
holiday_rules = {
'baseline_avg': float(baseline_avg),
'hardcoded_multiplier': 1.5, # Current +50%
'holiday_types': {}
}
insights = []
# Learn impact per holiday type
if 'holiday_type' in merged.columns:
for holiday_type in merged[merged['is_holiday'] == True]['holiday_type'].unique():
if pd.isna(holiday_type):
continue
holiday_days = merged[merged['holiday_type'] == holiday_type]
if len(holiday_days) >= min_samples:
holiday_avg = holiday_days['quantity'].mean()
learned_multiplier = holiday_avg / baseline_avg
learned_impact = (learned_multiplier - 1) * 100
# Statistical test
t_stat, p_value = stats.ttest_ind(
holiday_days['quantity'].values,
non_holidays['quantity'].values,
equal_var=False
)
holiday_rules['holiday_types'][holiday_type] = {
'learned_multiplier': float(learned_multiplier),
'learned_impact_pct': float(learned_impact),
'sample_size': int(len(holiday_days)),
'avg_quantity': float(holiday_avg),
'p_value': float(p_value),
'significant': bool(p_value < 0.05)
}
# Compare with hardcoded +50%
hardcoded_multiplier = 1.5
difference = abs(learned_multiplier - hardcoded_multiplier)
if difference > 0.1 and p_value < 0.05: # More than 10% difference
insight = {
'type': 'recommendation',
'priority': 'high' if difference > 0.3 else 'medium',
'category': 'forecasting',
'title': f'Holiday Rule Optimization: {holiday_type}',
'description': f'{holiday_type} shows {learned_impact:.1f}% impact vs hardcoded +50%. Using learned multiplier {learned_multiplier:.2f}x could improve forecast accuracy.',
'impact_type': 'forecast_improvement',
'impact_value': difference * 100,
'impact_unit': 'percentage_points',
'confidence': self._calculate_confidence(len(holiday_days), p_value),
'metrics_json': {
'holiday_type': holiday_type,
'learned_multiplier': round(learned_multiplier, 3),
'hardcoded_multiplier': 1.5,
'learned_impact_pct': round(learned_impact, 2),
'hardcoded_impact_pct': 50.0,
'baseline_avg': round(baseline_avg, 2),
'holiday_avg': round(holiday_avg, 2),
'sample_size': len(holiday_days),
'p_value': round(p_value, 4)
},
'actionable': True,
'recommendation_actions': [
{
'label': 'Update Holiday Rule',
'action': 'update_holiday_multiplier',
'params': {
'holiday_type': holiday_type,
'new_multiplier': round(learned_multiplier, 3)
}
}
],
'source_service': 'forecasting',
'source_model': 'dynamic_rules_engine'
}
insights.append(insight)
logger.info(
"Holiday rule optimization identified",
holiday_type=holiday_type,
learned=f"{learned_multiplier:.2f}x",
hardcoded="1.5x"
)
# Overall holiday impact
all_holidays = merged[merged['is_holiday'] == True]
if len(all_holidays) >= min_samples:
overall_avg = all_holidays['quantity'].mean()
overall_multiplier = overall_avg / baseline_avg
holiday_rules['overall_learned_multiplier'] = float(overall_multiplier)
holiday_rules['overall_learned_impact_pct'] = float((overall_multiplier - 1) * 100)
return holiday_rules, insights
async def _learn_event_rules(
self,
sales_data: pd.DataFrame,
external_data: pd.DataFrame,
min_samples: int
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""
Learn event impact by event type instead of uniform +30%.
Hardcoded: All events = +30%
Learn: Sports events vs concerts vs festivals have different impacts
"""
logger.info("Learning event impact rules")
# Merge sales with event data
merged = sales_data.merge(
external_data[['date', 'event_name', 'event_type', 'event_attendance']],
on='date',
how='left'
)
# Baseline: non-event days
non_events = merged[merged['event_name'].isna()]
baseline_avg = non_events['quantity'].mean()
event_rules = {
'baseline_avg': float(baseline_avg),
'hardcoded_multiplier': 1.3, # Current +30%
'event_types': {}
}
insights = []
# Learn impact per event type
if 'event_type' in merged.columns:
for event_type in merged[merged['event_type'].notna()]['event_type'].unique():
if pd.isna(event_type):
continue
event_days = merged[merged['event_type'] == event_type]
if len(event_days) >= min_samples:
event_avg = event_days['quantity'].mean()
learned_multiplier = event_avg / baseline_avg
learned_impact = (learned_multiplier - 1) * 100
# Statistical test
t_stat, p_value = stats.ttest_ind(
event_days['quantity'].values,
non_events['quantity'].values,
equal_var=False
)
event_rules['event_types'][event_type] = {
'learned_multiplier': float(learned_multiplier),
'learned_impact_pct': float(learned_impact),
'sample_size': int(len(event_days)),
'avg_quantity': float(event_avg),
'p_value': float(p_value),
'significant': bool(p_value < 0.05)
}
# Compare with hardcoded +30%
hardcoded_multiplier = 1.3
difference = abs(learned_multiplier - hardcoded_multiplier)
if difference > 0.1 and p_value < 0.05:
insight = {
'type': 'recommendation',
'priority': 'medium',
'category': 'forecasting',
'title': f'Event Rule Optimization: {event_type}',
'description': f'{event_type} events show {learned_impact:.1f}% impact vs hardcoded +30%. Using learned multiplier could improve event forecasts.',
'impact_type': 'forecast_improvement',
'impact_value': difference * 100,
'impact_unit': 'percentage_points',
'confidence': self._calculate_confidence(len(event_days), p_value),
'metrics_json': {
'event_type': event_type,
'learned_multiplier': round(learned_multiplier, 3),
'hardcoded_multiplier': 1.3,
'learned_impact_pct': round(learned_impact, 2),
'hardcoded_impact_pct': 30.0,
'baseline_avg': round(baseline_avg, 2),
'event_avg': round(event_avg, 2),
'sample_size': len(event_days),
'p_value': round(p_value, 4)
},
'actionable': True,
'recommendation_actions': [
{
'label': 'Update Event Rule',
'action': 'update_event_multiplier',
'params': {
'event_type': event_type,
'new_multiplier': round(learned_multiplier, 3)
}
}
],
'source_service': 'forecasting',
'source_model': 'dynamic_rules_engine'
}
insights.append(insight)
return event_rules, insights
async def _learn_day_of_week_rules(
self,
sales_data: pd.DataFrame,
min_samples: int
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""
Learn day-of-week patterns per product.
Replace general assumptions with product-specific patterns.
"""
logger.info("Learning day-of-week patterns")
sales_data = sales_data.copy()
sales_data['day_of_week'] = sales_data['date'].dt.dayofweek
sales_data['day_name'] = sales_data['date'].dt.day_name()
# Calculate average per day of week
dow_avg = sales_data.groupby('day_of_week')['quantity'].agg(['mean', 'std', 'count'])
overall_avg = sales_data['quantity'].mean()
dow_rules = {
'overall_avg': float(overall_avg),
'days': {}
}
insights = []
day_names = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
for dow in range(7):
if dow not in dow_avg.index or dow_avg.loc[dow, 'count'] < min_samples:
continue
day_avg = dow_avg.loc[dow, 'mean']
day_std = dow_avg.loc[dow, 'std']
day_count = dow_avg.loc[dow, 'count']
multiplier = day_avg / overall_avg
impact_pct = (multiplier - 1) * 100
# Coefficient of variation
cv = (day_std / day_avg) if day_avg > 0 else 0
dow_rules['days'][day_names[dow]] = {
'day_of_week': int(dow),
'learned_multiplier': float(multiplier),
'impact_pct': float(impact_pct),
'avg_quantity': float(day_avg),
'std_quantity': float(day_std),
'sample_size': int(day_count),
'coefficient_of_variation': float(cv)
}
# Insight for significant deviations
if abs(impact_pct) > 20: # More than 20% difference
insight = {
'type': 'insight',
'priority': 'medium' if abs(impact_pct) > 30 else 'low',
'category': 'forecasting',
'title': f'{day_names[dow]} Pattern: {abs(impact_pct):.0f}% {"Higher" if impact_pct > 0 else "Lower"}',
'description': f'{day_names[dow]} sales average {day_avg:.1f} units ({impact_pct:+.1f}% vs weekly average {overall_avg:.1f}). Consider this pattern in production planning.',
'impact_type': 'operational_insight',
'impact_value': abs(impact_pct),
'impact_unit': 'percentage',
'confidence': self._calculate_confidence(day_count, 0.01), # Low p-value for large samples
'metrics_json': {
'day_of_week': day_names[dow],
'day_multiplier': round(multiplier, 3),
'impact_pct': round(impact_pct, 2),
'day_avg': round(day_avg, 2),
'overall_avg': round(overall_avg, 2),
'sample_size': int(day_count),
'std': round(day_std, 2)
},
'actionable': True,
'recommendation_actions': [
{
'label': 'Adjust Production Schedule',
'action': 'adjust_weekly_production',
'params': {
'day': day_names[dow],
'multiplier': round(multiplier, 3)
}
}
],
'source_service': 'forecasting',
'source_model': 'dynamic_rules_engine'
}
insights.append(insight)
return dow_rules, insights
async def _learn_month_rules(
self,
sales_data: pd.DataFrame,
min_samples: int
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""
Learn monthly seasonality patterns per product.
"""
logger.info("Learning monthly seasonality patterns")
sales_data = sales_data.copy()
sales_data['month'] = sales_data['date'].dt.month
sales_data['month_name'] = sales_data['date'].dt.month_name()
# Calculate average per month
month_avg = sales_data.groupby('month')['quantity'].agg(['mean', 'std', 'count'])
overall_avg = sales_data['quantity'].mean()
month_rules = {
'overall_avg': float(overall_avg),
'months': {}
}
insights = []
month_names = ['January', 'February', 'March', 'April', 'May', 'June',
'July', 'August', 'September', 'October', 'November', 'December']
for month in range(1, 13):
if month not in month_avg.index or month_avg.loc[month, 'count'] < min_samples:
continue
month_mean = month_avg.loc[month, 'mean']
month_std = month_avg.loc[month, 'std']
month_count = month_avg.loc[month, 'count']
multiplier = month_mean / overall_avg
impact_pct = (multiplier - 1) * 100
month_rules['months'][month_names[month - 1]] = {
'month': int(month),
'learned_multiplier': float(multiplier),
'impact_pct': float(impact_pct),
'avg_quantity': float(month_mean),
'std_quantity': float(month_std),
'sample_size': int(month_count)
}
# Insight for significant seasonal patterns
if abs(impact_pct) > 25: # More than 25% seasonal variation
insight = {
'type': 'insight',
'priority': 'medium',
'category': 'forecasting',
'title': f'Seasonal Pattern: {month_names[month - 1]} {abs(impact_pct):.0f}% {"Higher" if impact_pct > 0 else "Lower"}',
'description': f'{month_names[month - 1]} shows strong seasonality with {impact_pct:+.1f}% vs annual average. Plan inventory accordingly.',
'impact_type': 'operational_insight',
'impact_value': abs(impact_pct),
'impact_unit': 'percentage',
'confidence': self._calculate_confidence(month_count, 0.01),
'metrics_json': {
'month': month_names[month - 1],
'multiplier': round(multiplier, 3),
'impact_pct': round(impact_pct, 2),
'month_avg': round(month_mean, 2),
'annual_avg': round(overall_avg, 2),
'sample_size': int(month_count)
},
'actionable': True,
'recommendation_actions': [
{
'label': 'Adjust Seasonal Planning',
'action': 'adjust_seasonal_forecast',
'params': {
'month': month_names[month - 1],
'multiplier': round(multiplier, 3)
}
}
],
'source_service': 'forecasting',
'source_model': 'dynamic_rules_engine'
}
insights.append(insight)
return month_rules, insights
def _calculate_confidence(self, sample_size: int, p_value: float) -> int:
"""
Calculate confidence score (0-100) based on sample size and statistical significance.
Args:
sample_size: Number of observations
p_value: Statistical significance p-value
Returns:
Confidence score 0-100
"""
# Sample size score (0-50 points)
if sample_size >= 100:
sample_score = 50
elif sample_size >= 50:
sample_score = 40
elif sample_size >= 30:
sample_score = 30
elif sample_size >= 20:
sample_score = 20
else:
sample_score = 10
# Statistical significance score (0-50 points)
if p_value < 0.001:
sig_score = 50
elif p_value < 0.01:
sig_score = 45
elif p_value < 0.05:
sig_score = 35
elif p_value < 0.1:
sig_score = 20
else:
sig_score = 10
return min(100, sample_score + sig_score)
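
    # Worked example of the scoring above: sample_size=60, p_value=0.02
    # -> sample_score=40, sig_score=35 -> confidence=75.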
def get_rule(
self,
inventory_product_id: str,
rule_type: str,
key: str
) -> Optional[float]:
"""
Get learned rule multiplier for a specific condition.
Args:
inventory_product_id: Product identifier
rule_type: 'weather', 'holiday', 'event', 'day_of_week', 'month'
key: Specific condition key (e.g., 'rain', 'Christmas', 'Monday')
Returns:
Learned multiplier or None if not learned
"""
if rule_type == 'weather':
rules = self.weather_rules.get(inventory_product_id, {})
return rules.get('conditions', {}).get(key, {}).get('learned_multiplier')
elif rule_type == 'holiday':
rules = self.holiday_rules.get(inventory_product_id, {})
return rules.get('holiday_types', {}).get(key, {}).get('learned_multiplier')
elif rule_type == 'event':
rules = self.event_rules.get(inventory_product_id, {})
return rules.get('event_types', {}).get(key, {}).get('learned_multiplier')
elif rule_type == 'day_of_week':
rules = self.dow_rules.get(inventory_product_id, {})
return rules.get('days', {}).get(key, {}).get('learned_multiplier')
elif rule_type == 'month':
rules = self.month_rules.get(inventory_product_id, {})
return rules.get('months', {}).get(key, {}).get('learned_multiplier')
return None
def export_rules_for_prophet(
self,
inventory_product_id: str
) -> Dict[str, Any]:
"""
Export learned rules in format suitable for Prophet model integration.
Returns:
Dictionary with multipliers for Prophet custom seasonality/regressors
"""
return {
'weather': self.weather_rules.get(inventory_product_id, {}),
'holidays': self.holiday_rules.get(inventory_product_id, {}),
'events': self.event_rules.get(inventory_product_id, {}),
'day_of_week': self.dow_rules.get(inventory_product_id, {}),
'months': self.month_rules.get(inventory_product_id, {})
}
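

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): learning rules from a synthetic history and
# reading a learned multiplier back. The product ID and data are assumptions;
# with no external weather/holiday/event data, only day-of-week and monthly
# rules are learned.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        engine = DynamicRulesEngine()
        # One year of daily sales with a repeating weekly pattern
        history = pd.DataFrame({
            "date": pd.date_range("2024-01-01", periods=365, freq="D"),
            "quantity": [50 + (i % 7) * 4 for i in range(365)],
        })
        results = await engine.learn_all_rules(
            tenant_id="00000000-0000-0000-0000-000000000001",
            inventory_product_id="demo-product",
            sales_data=history,
        )
        print("rules learned:", list(results["rules"].keys()))
        print("Saturday multiplier:", engine.get_rule("demo-product", "day_of_week", "Saturday"))

    asyncio.run(_demo())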

View File

@@ -0,0 +1,263 @@
"""
Multi-Horizon Forecasting System
Generates forecasts for multiple time horizons (7, 14, 30, 90 days)
"""
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta, date
import structlog
logger = structlog.get_logger()
class MultiHorizonForecaster:
"""
Multi-horizon forecasting with horizon-specific models.
Horizons:
- Short-term (1-7 days): High precision, detailed features
- Medium-term (8-14 days): Balanced approach
- Long-term (15-30 days): Focus on trends, seasonal patterns
- Very long-term (31-90 days): Strategic planning, major trends only
"""
HORIZONS = {
'short': (1, 7),
'medium': (8, 14),
'long': (15, 30),
'very_long': (31, 90)
}
def __init__(self, base_forecaster=None):
"""
Initialize multi-horizon forecaster.
Args:
base_forecaster: Base forecaster (e.g., BakeryForecaster) to use
"""
self.base_forecaster = base_forecaster
async def generate_multi_horizon_forecast(
self,
tenant_id: str,
inventory_product_id: str,
start_date: date,
horizons: List[str] = None,
include_confidence_intervals: bool = True
) -> Dict[str, Any]:
"""
Generate forecasts for multiple horizons.
Args:
tenant_id: Tenant identifier
inventory_product_id: Product identifier
start_date: Start date for forecasts
horizons: List of horizons to forecast ('short', 'medium', 'long', 'very_long')
include_confidence_intervals: Include confidence intervals
Returns:
Dictionary with forecasts by horizon
"""
if horizons is None:
horizons = ['short', 'medium', 'long']
logger.info(
"Generating multi-horizon forecast",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
horizons=horizons
)
results = {
'tenant_id': tenant_id,
'inventory_product_id': inventory_product_id,
'start_date': start_date.isoformat(),
'generated_at': datetime.now().isoformat(),
'horizons': {}
}
for horizon_name in horizons:
if horizon_name not in self.HORIZONS:
logger.warning(f"Unknown horizon: {horizon_name}, skipping")
continue
start_day, end_day = self.HORIZONS[horizon_name]
# Generate forecast for this horizon
horizon_forecast = await self._generate_horizon_forecast(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
start_date=start_date,
days_ahead=end_day,
horizon_name=horizon_name,
include_confidence=include_confidence_intervals
)
results['horizons'][horizon_name] = horizon_forecast
logger.info("Multi-horizon forecast complete",
horizons_generated=len(results['horizons']))
return results
async def _generate_horizon_forecast(
self,
tenant_id: str,
inventory_product_id: str,
start_date: date,
days_ahead: int,
horizon_name: str,
include_confidence: bool
) -> Dict[str, Any]:
"""
Generate forecast for a specific horizon.
Args:
tenant_id: Tenant identifier
inventory_product_id: Product identifier
start_date: Start date
days_ahead: Number of days ahead
horizon_name: Horizon name ('short', 'medium', etc.)
include_confidence: Include confidence intervals
Returns:
Forecast data for the horizon
"""
# Generate date range
dates = [start_date + timedelta(days=i) for i in range(days_ahead)]
# Use base forecaster if available
if self.base_forecaster:
# Call base forecaster for predictions
forecasts = []
for forecast_date in dates:
try:
# This would call the actual forecasting service
# For now, we'll return a structured response
forecasts.append({
'date': forecast_date.isoformat(),
'predicted_demand': 0, # Placeholder
'confidence_lower': 0 if include_confidence else None,
'confidence_upper': 0 if include_confidence else None
})
except Exception as e:
logger.error(f"Failed to generate forecast for {forecast_date}: {e}")
return {
'horizon_name': horizon_name,
'days_ahead': days_ahead,
'start_date': start_date.isoformat(),
'end_date': dates[-1].isoformat(),
'forecasts': forecasts,
'aggregates': self._calculate_horizon_aggregates(forecasts)
}
else:
logger.warning("No base forecaster available, returning placeholder")
return {
'horizon_name': horizon_name,
'days_ahead': days_ahead,
'forecasts': [],
'aggregates': {}
}
def _calculate_horizon_aggregates(self, forecasts: List[Dict]) -> Dict[str, float]:
"""
Calculate aggregate statistics for a horizon.
Args:
forecasts: List of daily forecasts
Returns:
Aggregate statistics
"""
if not forecasts:
return {}
demands = [f['predicted_demand'] for f in forecasts if f.get('predicted_demand')]
if not demands:
return {}
return {
'total_demand': sum(demands),
'avg_daily_demand': np.mean(demands),
'max_daily_demand': max(demands),
'min_daily_demand': min(demands),
'demand_volatility': np.std(demands) if len(demands) > 1 else 0
}
def get_horizon_recommendation(
self,
horizon_name: str,
forecast_data: Dict[str, Any]
) -> Dict[str, Any]:
"""
Generate recommendations based on horizon forecast.
Args:
horizon_name: Horizon name
forecast_data: Forecast data for the horizon
Returns:
Recommendations dictionary
"""
aggregates = forecast_data.get('aggregates', {})
total_demand = aggregates.get('total_demand', 0)
volatility = aggregates.get('demand_volatility', 0)
recommendations = {
'horizon': horizon_name,
'actions': []
}
if horizon_name == 'short':
# Short-term: Operational recommendations
if total_demand > 0:
recommendations['actions'].append(f"Prepare {total_demand:.0f} units for next 7 days")
if volatility > 10:
recommendations['actions'].append("High volatility expected - increase safety stock")
elif horizon_name == 'medium':
# Medium-term: Procurement planning
recommendations['actions'].append(f"Order supplies for {total_demand:.0f} units (2-week demand)")
if aggregates.get('max_daily_demand', 0) > aggregates.get('avg_daily_demand', 0) * 1.5:
recommendations['actions'].append("Peak demand day detected - plan extra capacity")
elif horizon_name == 'long':
# Long-term: Strategic planning
avg_weekly_demand = total_demand / 4 if total_demand > 0 else 0
recommendations['actions'].append(f"Monthly demand projection: {total_demand:.0f} units")
recommendations['actions'].append(f"Average weekly demand: {avg_weekly_demand:.0f} units")
elif horizon_name == 'very_long':
# Very long-term: Capacity planning
recommendations['actions'].append(f"Quarterly demand projection: {total_demand:.0f} units")
recommendations['actions'].append("Review capacity and staffing needs")
return recommendations
def get_appropriate_horizons_for_use_case(use_case: str) -> List[str]:
"""
Get appropriate forecast horizons for a use case.
Args:
use_case: Use case name (e.g., 'production_planning', 'procurement', 'strategic')
Returns:
List of horizon names
"""
use_case_horizons = {
'production_planning': ['short'],
'procurement': ['short', 'medium'],
'inventory_optimization': ['short', 'medium'],
'capacity_planning': ['medium', 'long'],
'strategic_planning': ['long', 'very_long'],
'financial_planning': ['long', 'very_long'],
'all': ['short', 'medium', 'long', 'very_long']
}
return use_case_horizons.get(use_case, ['short', 'medium'])
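

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): running the multi-horizon forecaster without
# a base model. Until a BakeryForecaster instance is wired in, each horizon
# returns a placeholder structure, so this only exercises the orchestration path.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Horizons suggested for procurement planning
        # (see get_appropriate_horizons_for_use_case above)
        horizons = ["short", "medium"]
        forecaster = MultiHorizonForecaster(base_forecaster=None)
        result = await forecaster.generate_multi_horizon_forecast(
            tenant_id="00000000-0000-0000-0000-000000000001",
            inventory_product_id="demo-product",
            start_date=date(2025, 6, 1),
            horizons=horizons,
        )
        print("horizons generated:", list(result["horizons"].keys()))

    asyncio.run(_demo())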

View File

@@ -0,0 +1,593 @@
"""
Pattern Detection Engine for Sales Data
Automatically identifies patterns and generates insights
"""
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta
import structlog
from scipy import stats
from collections import defaultdict
logger = structlog.get_logger()
class SalesPatternDetector:
"""
Detect sales patterns and generate actionable insights.
Patterns detected:
- Time-of-day patterns (hourly peaks)
- Day-of-week patterns (weekend spikes)
- Weekly seasonality patterns
- Monthly patterns
- Holiday impact patterns
- Weather correlation patterns
"""
def __init__(self, significance_threshold: float = 0.15):
"""
Initialize pattern detector.
Args:
significance_threshold: Minimum percentage difference to consider significant (default 15%)
"""
self.significance_threshold = significance_threshold
self.detected_patterns = []
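
        # Example: with the default threshold of 0.15, a day-of-week average must
        # deviate from the overall mean by more than 15% before an insight is raised.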
async def detect_all_patterns(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
min_confidence: int = 70
) -> List[Dict[str, Any]]:
"""
Detect all patterns in sales data and generate insights.
Args:
tenant_id: Tenant identifier
inventory_product_id: Product identifier
sales_data: Sales data with columns: date, quantity, (optional: hour, temperature, etc.)
min_confidence: Minimum confidence score for insights
Returns:
List of insight dictionaries ready for AI Insights Service
"""
logger.info(
"Starting pattern detection",
tenant_id=tenant_id,
product_id=inventory_product_id,
data_points=len(sales_data)
)
insights = []
        # Ensure date column is datetime; work on a copy so the helper columns
        # added by the individual detectors do not mutate the caller's frame
        sales_data = sales_data.copy()
        if 'date' in sales_data.columns:
            sales_data['date'] = pd.to_datetime(sales_data['date'])
# 1. Day-of-week patterns
dow_insights = await self._detect_day_of_week_patterns(
tenant_id, inventory_product_id, sales_data, min_confidence
)
insights.extend(dow_insights)
# 2. Weekend vs weekday patterns
weekend_insights = await self._detect_weekend_patterns(
tenant_id, inventory_product_id, sales_data, min_confidence
)
insights.extend(weekend_insights)
# 3. Month-end patterns
month_end_insights = await self._detect_month_end_patterns(
tenant_id, inventory_product_id, sales_data, min_confidence
)
insights.extend(month_end_insights)
# 4. Hourly patterns (if hour data available)
if 'hour' in sales_data.columns:
hourly_insights = await self._detect_hourly_patterns(
tenant_id, inventory_product_id, sales_data, min_confidence
)
insights.extend(hourly_insights)
# 5. Weather correlation (if temperature data available)
if 'temperature' in sales_data.columns:
weather_insights = await self._detect_weather_correlations(
tenant_id, inventory_product_id, sales_data, min_confidence
)
insights.extend(weather_insights)
# 6. Trend detection
trend_insights = await self._detect_trends(
tenant_id, inventory_product_id, sales_data, min_confidence
)
insights.extend(trend_insights)
logger.info(
"Pattern detection complete",
total_insights=len(insights),
product_id=inventory_product_id
)
return insights
async def _detect_day_of_week_patterns(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
min_confidence: int
) -> List[Dict[str, Any]]:
"""Detect day-of-week patterns (e.g., Friday sales spike)."""
insights = []
if 'date' not in sales_data.columns or 'quantity' not in sales_data.columns:
return insights
# Add day of week
sales_data['day_of_week'] = sales_data['date'].dt.dayofweek
sales_data['day_name'] = sales_data['date'].dt.day_name()
# Calculate average sales per day of week
dow_avg = sales_data.groupby(['day_of_week', 'day_name'])['quantity'].agg(['mean', 'count']).reset_index()
# Only consider days with sufficient data (at least 4 observations)
dow_avg = dow_avg[dow_avg['count'] >= 4]
if len(dow_avg) < 2:
return insights
overall_avg = sales_data['quantity'].mean()
# Find days significantly above average
for _, row in dow_avg.iterrows():
day_avg = row['mean']
pct_diff = ((day_avg - overall_avg) / overall_avg) * 100
if abs(pct_diff) > self.significance_threshold * 100:
# Calculate confidence based on sample size and consistency
confidence = self._calculate_pattern_confidence(
sample_size=int(row['count']),
effect_size=abs(pct_diff) / 100,
variability=sales_data['quantity'].std()
)
if confidence >= min_confidence:
if pct_diff > 0:
insight = self._create_insight(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
insight_type='pattern',
category='sales',
priority='medium' if pct_diff > 20 else 'low',
title=f'{row["day_name"]} Sales Pattern Detected',
description=f'Sales on {row["day_name"]} are {abs(pct_diff):.1f}% {"higher" if pct_diff > 0 else "lower"} than average ({day_avg:.1f} vs {overall_avg:.1f} units).',
confidence=confidence,
metrics={
'day_of_week': row['day_name'],
'avg_sales': float(day_avg),
'overall_avg': float(overall_avg),
'difference_pct': float(pct_diff),
'sample_size': int(row['count'])
},
actionable=True,
actions=[
{'label': 'Adjust Production', 'action': 'adjust_daily_production'},
{'label': 'Review Schedule', 'action': 'review_production_schedule'}
]
)
insights.append(insight)
return insights
async def _detect_weekend_patterns(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
min_confidence: int
) -> List[Dict[str, Any]]:
"""Detect weekend vs weekday patterns."""
insights = []
if 'date' not in sales_data.columns or 'quantity' not in sales_data.columns:
return insights
# Classify weekend vs weekday
sales_data['is_weekend'] = sales_data['date'].dt.dayofweek.isin([5, 6])
# Calculate averages
weekend_avg = sales_data[sales_data['is_weekend']]['quantity'].mean()
weekday_avg = sales_data[~sales_data['is_weekend']]['quantity'].mean()
weekend_count = sales_data[sales_data['is_weekend']]['quantity'].count()
weekday_count = sales_data[~sales_data['is_weekend']]['quantity'].count()
if weekend_count < 4 or weekday_count < 4:
return insights
pct_diff = ((weekend_avg - weekday_avg) / weekday_avg) * 100
if abs(pct_diff) > self.significance_threshold * 100:
confidence = self._calculate_pattern_confidence(
sample_size=min(weekend_count, weekday_count),
effect_size=abs(pct_diff) / 100,
variability=sales_data['quantity'].std()
)
if confidence >= min_confidence:
# Estimate revenue impact
impact_value = abs(weekend_avg - weekday_avg) * 8  # ~8 weekend days per month (2 per week x 4 weeks)
insight = self._create_insight(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
insight_type='recommendation',
category='forecasting',
priority='high' if abs(pct_diff) > 25 else 'medium',
title=f'Weekend Demand Pattern: {abs(pct_diff):.0f}% {"Higher" if pct_diff > 0 else "Lower"}',
description=f'Weekend sales average {weekend_avg:.1f} units vs {weekday_avg:.1f} on weekdays ({abs(pct_diff):.0f}% {"increase" if pct_diff > 0 else "decrease"}). Recommend adjusting weekend production targets.',
confidence=confidence,
impact_type='revenue_increase' if pct_diff > 0 else 'cost_savings',
impact_value=float(impact_value),
impact_unit='units/month',
metrics={
'weekend_avg': float(weekend_avg),
'weekday_avg': float(weekday_avg),
'difference_pct': float(pct_diff),
'weekend_samples': int(weekend_count),
'weekday_samples': int(weekday_count)
},
actionable=True,
actions=[
{'label': 'Increase Weekend Production', 'action': 'adjust_weekend_production'},
{'label': 'Update Forecast Multiplier', 'action': 'update_forecast_rule'}
]
)
insights.append(insight)
return insights
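# Worked example of the impact estimate above (illustrative numbers, assuming ~8 weekend days per month):
#   weekend_avg = 60, weekday_avg = 45 -> pct_diff = (60 - 45) / 45 * 100 = 33.3%
#   impact_value = |60 - 45| * 8 = 120 units/month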
async def _detect_month_end_patterns(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
min_confidence: int
) -> List[Dict[str, Any]]:
"""Detect month-end and payday patterns."""
insights = []
if 'date' not in sales_data.columns or 'quantity' not in sales_data.columns:
return insights
# Identify payday periods (15th and last 3 days of month)
sales_data['day_of_month'] = sales_data['date'].dt.day
sales_data['is_payday'] = (
(sales_data['day_of_month'] == 15) |
(sales_data['date'].dt.is_month_end) |
(sales_data['day_of_month'] >= sales_data['date'].dt.days_in_month - 2)
)
payday_avg = sales_data[sales_data['is_payday']]['quantity'].mean()
regular_avg = sales_data[~sales_data['is_payday']]['quantity'].mean()
payday_count = sales_data[sales_data['is_payday']]['quantity'].count()
if payday_count < 4:
return insights
pct_diff = ((payday_avg - regular_avg) / regular_avg) * 100
if abs(pct_diff) > self.significance_threshold * 100:
confidence = self._calculate_pattern_confidence(
sample_size=payday_count,
effect_size=abs(pct_diff) / 100,
variability=sales_data['quantity'].std()
)
if confidence >= min_confidence and pct_diff > 0:
insight = self._create_insight(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
insight_type='pattern',
category='sales',
priority='medium',
title='Payday Shopping Pattern Detected',
description=f'Sales increase {pct_diff:.0f}% during payday periods (15th and month-end). Average {payday_avg:.1f} vs {regular_avg:.1f} units.',
confidence=confidence,
metrics={
'payday_avg': float(payday_avg),
'regular_avg': float(regular_avg),
'difference_pct': float(pct_diff)
},
actionable=True,
actions=[
{'label': 'Increase Payday Stock', 'action': 'adjust_payday_production'}
]
)
insights.append(insight)
return insights
async def _detect_hourly_patterns(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
min_confidence: int
) -> List[Dict[str, Any]]:
"""Detect hourly sales patterns (if POS data available)."""
insights = []
if 'hour' not in sales_data.columns or 'quantity' not in sales_data.columns:
return insights
hourly_avg = sales_data.groupby('hour')['quantity'].agg(['mean', 'count']).reset_index()
hourly_avg = hourly_avg[hourly_avg['count'] >= 3] # At least 3 observations
if len(hourly_avg) < 3:
return insights
overall_avg = sales_data['quantity'].mean()
# Find peak hours (top 3)
top_hours = hourly_avg.nlargest(3, 'mean')
for _, row in top_hours.iterrows():
hour_avg = row['mean']
pct_diff = ((hour_avg - overall_avg) / overall_avg) * 100
if pct_diff > self.significance_threshold * 100:
confidence = self._calculate_pattern_confidence(
sample_size=int(row['count']),
effect_size=pct_diff / 100,
variability=sales_data['quantity'].std()
)
if confidence >= min_confidence:
hour = int(row['hour'])
time_label = f"{hour:02d}:00-{(hour+1):02d}:00"
insight = self._create_insight(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
insight_type='pattern',
category='sales',
priority='low',
title=f'Peak Sales Hour: {time_label}',
description=f'Sales peak during {time_label} with {hour_avg:.1f} units ({pct_diff:.0f}% above average).',
confidence=confidence,
metrics={
'peak_hour': hour,
'avg_sales': float(hour_avg),
'overall_avg': float(overall_avg),
'difference_pct': float(pct_diff)
},
actionable=True,
actions=[
{'label': 'Ensure Fresh Stock', 'action': 'schedule_production'},
{'label': 'Increase Staffing', 'action': 'adjust_staffing'}
]
)
insights.append(insight)
return insights
async def _detect_weather_correlations(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
min_confidence: int
) -> List[Dict[str, Any]]:
"""Detect weather-sales correlations."""
insights = []
if 'temperature' not in sales_data.columns or 'quantity' not in sales_data.columns:
return insights
# Remove NaN values
clean_data = sales_data[['temperature', 'quantity']].dropna()
if len(clean_data) < 30: # Need sufficient data
return insights
# Calculate correlation
correlation, p_value = stats.pearsonr(clean_data['temperature'], clean_data['quantity'])
if abs(correlation) > 0.3 and p_value < 0.05: # Moderate correlation and significant
confidence = self._calculate_correlation_confidence(correlation, p_value, len(clean_data))
if confidence >= min_confidence:
direction = 'increase' if correlation > 0 else 'decrease'
insight = self._create_insight(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
insight_type='insight',
category='forecasting',
priority='medium' if abs(correlation) > 0.5 else 'low',
title=f'Temperature Impact on Sales: {abs(correlation):.0%} Correlation',
description=f'Sales {direction} with temperature (correlation: {correlation:.2f}). {"Warmer" if correlation > 0 else "Colder"} weather associated with {"higher" if correlation > 0 else "lower"} sales.',
confidence=confidence,
metrics={
'correlation': float(correlation),
'p_value': float(p_value),
'sample_size': len(clean_data),
'direction': direction
},
actionable=False
)
insights.append(insight)
return insights
async def _detect_trends(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
min_confidence: int
) -> List[Dict[str, Any]]:
"""Detect overall trends (growing, declining, stable)."""
insights = []
if 'date' not in sales_data.columns or 'quantity' not in sales_data.columns or len(sales_data) < 60:
return insights
# Sort by date
sales_data = sales_data.sort_values('date')
# Calculate 30-day rolling average
sales_data['rolling_30d'] = sales_data['quantity'].rolling(window=30, min_periods=15).mean()
# Compare first and last 30-day averages
first_30_avg = sales_data['rolling_30d'].iloc[:30].mean()
last_30_avg = sales_data['rolling_30d'].iloc[-30:].mean()
if pd.isna(first_30_avg) or pd.isna(last_30_avg):
return insights
pct_change = ((last_30_avg - first_30_avg) / first_30_avg) * 100
if abs(pct_change) > 10: # 10% change is significant
confidence = min(95, 70 + int(abs(pct_change))) # Higher change = higher confidence
trend_type = 'growing' if pct_change > 0 else 'declining'
insight = self._create_insight(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
insight_type='prediction',
category='forecasting',
priority='high' if abs(pct_change) > 20 else 'medium',
title=f'Sales Trend: {trend_type.title()} {abs(pct_change):.0f}%',
description=f'Sales show a {trend_type} trend over the period. Current 30-day average: {last_30_avg:.1f} vs earlier: {first_30_avg:.1f} ({pct_change:+.0f}%).',
confidence=confidence,
metrics={
'current_avg': float(last_30_avg),
'previous_avg': float(first_30_avg),
'change_pct': float(pct_change),
'trend': trend_type
},
actionable=True,
actions=[
{'label': 'Adjust Forecast Model', 'action': 'update_forecast'},
{'label': 'Review Capacity', 'action': 'review_production_capacity'}
]
)
insights.append(insight)
return insights
def _calculate_pattern_confidence(
self,
sample_size: int,
effect_size: float,
variability: float
) -> int:
"""
Calculate confidence score for detected pattern.
Args:
sample_size: Number of observations
effect_size: Size of the effect (e.g., 0.25 for 25% difference)
variability: Standard deviation of data
Returns:
Confidence score (0-100)
"""
# Base confidence from sample size
if sample_size < 4:
base = 50
elif sample_size < 10:
base = 65
elif sample_size < 30:
base = 75
elif sample_size < 100:
base = 85
else:
base = 90
# Adjust for effect size
effect_boost = min(15, effect_size * 30)
# Adjust for variability (penalize high variability)
variability_penalty = min(10, variability / 10)
confidence = base + effect_boost - variability_penalty
return int(max(0, min(100, confidence)))
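# Worked example of the scoring above (illustrative values):
#   sample_size=12, effect_size=0.25, variability=20
#   base = 75 (10 <= n < 30)
#   effect_boost = min(15, 0.25 * 30) = 7.5
#   variability_penalty = min(10, 20 / 10) = 2
#   confidence = int(75 + 7.5 - 2) = 80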
def _calculate_correlation_confidence(
self,
correlation: float,
p_value: float,
sample_size: int
) -> int:
"""Calculate confidence for correlation insights."""
# Base confidence from correlation strength
base = abs(correlation) * 100
# Boost for significance
if p_value < 0.001:
significance_boost = 15
elif p_value < 0.01:
significance_boost = 10
elif p_value < 0.05:
significance_boost = 5
else:
significance_boost = 0
# Boost for sample size
if sample_size > 100:
sample_boost = 10
elif sample_size > 50:
sample_boost = 5
else:
sample_boost = 0
confidence = base + significance_boost + sample_boost
return int(max(0, min(100, confidence)))
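# Worked example (illustrative values): correlation=0.45, p_value=0.008, sample_size=120
#   base = 45, significance_boost = 10 (p < 0.01), sample_boost = 10 (n > 100)
#   confidence = 45 + 10 + 10 = 65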
def _create_insight(
self,
tenant_id: str,
inventory_product_id: str,
insight_type: str,
category: str,
priority: str,
title: str,
description: str,
confidence: int,
metrics: Dict[str, Any],
actionable: bool,
actions: List[Dict[str, str]] = None,
impact_type: str = None,
impact_value: float = None,
impact_unit: str = None
) -> Dict[str, Any]:
"""Create an insight dictionary for AI Insights Service."""
return {
'tenant_id': tenant_id,
'type': insight_type,
'priority': priority,
'category': category,
'title': title,
'description': description,
'impact_type': impact_type,
'impact_value': impact_value,
'impact_unit': impact_unit,
'confidence': confidence,
'metrics_json': metrics,
'actionable': actionable,
'recommendation_actions': actions or [],
'source_service': 'forecasting',
'source_data_id': f'pattern_detection_{inventory_product_id}_{datetime.utcnow().strftime("%Y%m%d")}'
}
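# Example of a resulting insight payload (values are illustrative; field names
# match the dictionary built above):
#
#     {
#         'tenant_id': 'tenant-1',
#         'type': 'pattern',
#         'priority': 'medium',
#         'category': 'sales',
#         'title': 'Friday Sales Pattern Detected',
#         'description': 'Sales on Friday are 22.0% higher than average ...',
#         'confidence': 82,
#         'metrics_json': {'day_of_week': 'Friday', 'difference_pct': 22.0, ...},
#         'actionable': True,
#         'recommendation_actions': [{'label': 'Adjust Production', 'action': 'adjust_daily_production'}],
#         'source_service': 'forecasting',
#         'source_data_id': 'pattern_detection_product-1_20250101'
#     }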

View File

@@ -0,0 +1,854 @@
# ================================================================
# services/forecasting/app/ml/predictor.py
# ================================================================
"""
Enhanced predictor module with advanced forecasting capabilities
"""
import structlog
from typing import Dict, List, Any, Optional, Tuple
import pandas as pd
import numpy as np
from datetime import datetime, date, timedelta
import pickle
import json
from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector
from shared.database.base import create_database_manager
logger = structlog.get_logger()
metrics = MetricsCollector("forecasting-service")
class BakeryPredictor:
"""
Advanced predictor for bakery demand forecasting with dependency injection
Handles Prophet models and business-specific logic
"""
def __init__(self, database_manager=None, use_dynamic_rules=True):
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "forecasting-service")
self.model_cache = {}
self.use_dynamic_rules = use_dynamic_rules
if use_dynamic_rules:
try:
from app.ml.dynamic_rules_engine import DynamicRulesEngine
from shared.clients.ai_insights_client import AIInsightsClient
self.rules_engine = DynamicRulesEngine()
self.ai_insights_client = AIInsightsClient(
base_url=settings.AI_INSIGHTS_SERVICE_URL or "http://ai-insights-service:8000"
)
# Also provide business_rules for consistency
self.business_rules = BakeryBusinessRules(
use_dynamic_rules=True,
ai_insights_client=self.ai_insights_client
)
except ImportError as e:
logger.warning(f"Failed to import dynamic rules engine: {e}. Falling back to basic business rules.")
self.use_dynamic_rules = False
self.business_rules = BakeryBusinessRules()
else:
self.business_rules = BakeryBusinessRules()
class BakeryForecaster:
"""
Enhanced forecaster that integrates with repository pattern
Uses enhanced features from training service for predictions
"""
def __init__(self, database_manager=None, use_enhanced_features=True):
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "forecasting-service")
self.predictor = BakeryPredictor(database_manager)
self.use_enhanced_features = use_enhanced_features
# Initialize business rules (required by predict_demand); reuse the predictor's
# AI insights client when it was created with dynamic rules enabled
self.business_rules = BakeryBusinessRules(
use_dynamic_rules=True,
ai_insights_client=getattr(self.predictor, 'ai_insights_client', None)
)
# Initialize POI feature service
from app.services.poi_feature_service import POIFeatureService
self.poi_feature_service = POIFeatureService()
# Initialize enhanced data processor from shared module
if use_enhanced_features:
try:
from shared.ml.data_processor import EnhancedBakeryDataProcessor
self.data_processor = EnhancedBakeryDataProcessor(region='MD')
logger.info("Enhanced features enabled using shared data processor")
except ImportError as e:
logger.warning(
f"Could not import EnhancedBakeryDataProcessor from shared module: {e}. "
"Falling back to basic features."
)
self.use_enhanced_features = False
self.data_processor = None
else:
self.data_processor = None
async def predict_demand(self, model, features: Dict[str, Any],
business_type: str = "individual") -> Dict[str, float]:
"""Generate demand prediction with business rules applied"""
try:
# Generate base prediction
base_prediction = await self._generate_base_prediction(model, features)
# Apply business rules
adjusted_prediction = await self.business_rules.apply_rules(
base_prediction, features, business_type
)
# Add uncertainty estimation
final_prediction = self._add_uncertainty_bands(adjusted_prediction, features)
return final_prediction
except Exception as e:
logger.error("Error in demand prediction", error=str(e))
raise
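# Usage sketch (illustrative; `trained_model` is assumed to be a fitted Prophet
# model loaded elsewhere, and the feature values are examples):
#
#     forecaster = BakeryForecaster(database_manager)
#     result = await forecaster.predict_demand(
#         model=trained_model,
#         features={
#             'date': '2025-06-07', 'temperature': 24.0, 'precipitation': 0.0,
#             'is_weekend': True, 'is_holiday': False, 'day_of_week': 5,
#         },
#         business_type='individual',
#     )
#     # result -> {'demand': ..., 'lower_bound': ..., 'upper_bound': ..., 'uncertainty_factor': ...}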
async def _generate_base_prediction(self, model, features: Dict[str, Any]) -> Dict[str, float]:
"""Generate base prediction from Prophet model"""
try:
# Convert features to Prophet DataFrame
df = await self._prepare_prophet_dataframe(features)
# Generate forecast
forecast = model.predict(df)
if len(forecast) > 0:
row = forecast.iloc[0]
return {
"yhat": float(row['yhat']),
"yhat_lower": float(row['yhat_lower']),
"yhat_upper": float(row['yhat_upper']),
"trend": float(row.get('trend', 0)),
"seasonal": float(row.get('seasonal', 0)),
"weekly": float(row.get('weekly', 0)),
"yearly": float(row.get('yearly', 0)),
"holidays": float(row.get('holidays', 0))
}
else:
raise ValueError("No prediction generated from model")
except Exception as e:
logger.error("Error generating base prediction", error=str(e))
raise
async def _prepare_prophet_dataframe(self, features: Dict[str, Any],
historical_data: pd.DataFrame = None) -> pd.DataFrame:
"""
Convert features to Prophet-compatible DataFrame.
Uses enhanced features when available (60+ features vs basic 10).
"""
try:
if self.use_enhanced_features and self.data_processor:
# Use enhanced data processor from training service
logger.info("Generating enhanced features for prediction")
# Create future date range
future_dates = pd.DatetimeIndex([pd.to_datetime(features['date'])])
# Prepare weather forecast DataFrame
weather_df = pd.DataFrame({
'date': [pd.to_datetime(features['date'])],
'temperature': [features.get('temperature', 15.0)],
'precipitation': [features.get('precipitation', 0.0)],
'humidity': [features.get('humidity', 60.0)],
'wind_speed': [features.get('wind_speed', 5.0)],
'pressure': [features.get('pressure', 1013.0)]
})
# Fetch POI features if tenant_id is available
poi_features = None
if 'tenant_id' in features:
poi_features = await self.poi_feature_service.get_poi_features(
features['tenant_id']
)
if poi_features:
logger.info(
f"Retrieved {len(poi_features)} POI features for prediction",
tenant_id=features['tenant_id']
)
# Use data processor to create ALL enhanced features
df = await self.data_processor.prepare_prediction_features(
future_dates=future_dates,
weather_forecast=weather_df,
traffic_forecast=None, # Will add when traffic forecasting is implemented
poi_features=poi_features, # POI features for location-based forecasting
historical_data=historical_data # For lagged features
)
logger.info(f"Generated {len(df.columns)} enhanced features for prediction")
return df
else:
# Fallback to basic features
logger.info("Using basic features for prediction")
# Create base DataFrame
df = pd.DataFrame({
'ds': [pd.to_datetime(features['date'])]
})
# Add regressor features
feature_mapping = {
'temperature': 'temperature',
'precipitation': 'precipitation',
'humidity': 'humidity',
'wind_speed': 'wind_speed',
'traffic_volume': 'traffic_volume',
'pedestrian_count': 'pedestrian_count'
}
for feature_key, df_column in feature_mapping.items():
if feature_key in features and features[feature_key] is not None:
df[df_column] = float(features[feature_key])
else:
df[df_column] = 0.0
# Add categorical features
df['day_of_week'] = int(features.get('day_of_week', 0))
df['is_weekend'] = int(features.get('is_weekend', False))
df['is_holiday'] = int(features.get('is_holiday', False))
# Business type
business_type = features.get('business_type', 'individual')
df['is_central_workshop'] = int(business_type == 'central_workshop')
return df
except Exception as e:
logger.error(f"Error preparing Prophet dataframe: {e}, falling back to basic features")
# Fallback to basic implementation on error
df = pd.DataFrame({'ds': [pd.to_datetime(features['date'])]})
df['temperature'] = features.get('temperature', 15.0)
df['precipitation'] = features.get('precipitation', 0.0)
df['is_weekend'] = int(features.get('is_weekend', False))
df['is_holiday'] = int(features.get('is_holiday', False))
return df
def _add_uncertainty_bands(self, prediction: Dict[str, float],
features: Dict[str, Any]) -> Dict[str, float]:
"""Add uncertainty estimation based on external factors"""
try:
base_demand = prediction["yhat"]
base_lower = prediction["yhat_lower"]
base_upper = prediction["yhat_upper"]
# Weather uncertainty
weather_uncertainty = self._calculate_weather_uncertainty(features)
# Holiday uncertainty
holiday_uncertainty = self._calculate_holiday_uncertainty(features)
# Weekend uncertainty
weekend_uncertainty = self._calculate_weekend_uncertainty(features)
# Total uncertainty factor
total_uncertainty = 1.0 + weather_uncertainty + holiday_uncertainty + weekend_uncertainty
# Adjust bounds
uncertainty_range = (base_upper - base_lower) * total_uncertainty
center_point = base_demand
adjusted_lower = center_point - (uncertainty_range / 2)
adjusted_upper = center_point + (uncertainty_range / 2)
return {
"demand": max(0, base_demand), # Never predict negative demand
"lower_bound": max(0, adjusted_lower),
"upper_bound": adjusted_upper,
"uncertainty_factor": total_uncertainty,
"trend": prediction.get("trend", 0),
"seasonal": prediction.get("seasonal", 0),
"holiday_effect": prediction.get("holidays", 0)
}
except Exception as e:
logger.error("Error adding uncertainty bands", error=str(e))
# Return basic prediction if uncertainty calculation fails
return {
"demand": max(0, prediction["yhat"]),
"lower_bound": max(0, prediction["yhat_lower"]),
"upper_bound": prediction["yhat_upper"],
"uncertainty_factor": 1.0
}
def _calculate_weather_uncertainty(self, features: Dict[str, Any]) -> float:
"""Calculate weather-based uncertainty"""
uncertainty = 0.0
# Temperature extremes add uncertainty
temp = features.get('temperature')
if temp is not None:
if temp < settings.TEMPERATURE_THRESHOLD_COLD or temp > settings.TEMPERATURE_THRESHOLD_HOT:
uncertainty += 0.1
# Rain adds uncertainty
precipitation = features.get('precipitation')
if precipitation is not None and precipitation > 0:
uncertainty += 0.05 * min(precipitation, 10) # Cap at 50mm
return uncertainty
def _calculate_holiday_uncertainty(self, features: Dict[str, Any]) -> float:
"""Calculate holiday-based uncertainty"""
if features.get('is_holiday', False):
return 0.2 # 20% additional uncertainty on holidays
return 0.0
def _calculate_weekend_uncertainty(self, features: Dict[str, Any]) -> float:
"""Calculate weekend-based uncertainty"""
if features.get('is_weekend', False):
return 0.1 # 10% additional uncertainty on weekends
return 0.0
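# Worked example of the uncertainty adjustment (illustrative numbers):
#   yhat=100, yhat_lower=90, yhat_upper=110; rainy weekend (5 mm) with a mild
#   temperature inside the configured thresholds
#   weather = 0.05 * min(5, 10) = 0.25, holiday = 0.0, weekend = 0.1
#   total_uncertainty = 1.0 + 0.25 + 0.0 + 0.1 = 1.35
#   range = (110 - 90) * 1.35 = 27 -> bounds ~[86.5, 113.5] around demand 100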
async def analyze_demand_patterns(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
forecast_horizon_days: int = 30,
min_history_days: int = 90
) -> Dict[str, Any]:
"""
Analyze demand patterns by delegating to the sales service.
NOTE: Sales data analysis is the responsibility of the sales service.
This method calls the sales service API to get demand pattern analysis.
Args:
tenant_id: Tenant identifier
inventory_product_id: Product identifier
sales_data: Historical sales DataFrame (not used - kept for backward compatibility)
forecast_horizon_days: Days to forecast ahead (not used currently)
min_history_days: Minimum history required
Returns:
Analysis results with patterns, trends, and insights from sales service
"""
try:
from shared.clients.sales_client import SalesServiceClient
from datetime import date, timedelta
logger.info(
"Requesting demand pattern analysis from sales service",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id
)
# Initialize sales client
sales_client = SalesServiceClient(config=settings, calling_service_name="forecasting")
# Calculate date range
end_date = date.today()
start_date = end_date - timedelta(days=min_history_days)
# Call sales service for pattern analysis
patterns = await sales_client.get_product_demand_patterns(
tenant_id=tenant_id,
product_id=inventory_product_id,
start_date=start_date,
end_date=end_date,
min_history_days=min_history_days
)
# Generate insights from patterns
insights = self._generate_insights_from_patterns(
patterns,
tenant_id,
inventory_product_id
)
# Add insights to the result
patterns['insights'] = insights
logger.info(
"Demand pattern analysis received from sales service",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
insights_generated=len(insights)
)
return patterns
except Exception as e:
logger.error(
"Error getting demand patterns from sales service",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e),
exc_info=True
)
return {
'analyzed_at': datetime.utcnow().isoformat(),
'history_days': 0,
'insights': [],
'patterns': {},
'trend_analysis': {},
'seasonal_factors': {},
'statistics': {},
'error': str(e)
}
def _generate_insights_from_patterns(
self,
patterns: Dict[str, Any],
tenant_id: str,
inventory_product_id: str
) -> List[Dict[str, Any]]:
"""
Generate actionable insights from demand patterns provided by sales service.
Args:
patterns: Demand patterns from sales service
tenant_id: Tenant identifier
inventory_product_id: Product identifier
Returns:
List of insights for AI Insights Service
"""
insights = []
# Check if there was an error in pattern analysis
if 'error' in patterns:
return insights
trend = patterns.get('trend_analysis', {})
stats = patterns.get('statistics', {})
seasonal = patterns.get('seasonal_factors', {})
# Trend insight
if trend.get('is_increasing'):
insights.append({
'type': 'insight',
'priority': 'medium',
'category': 'forecasting',
'title': 'Increasing Demand Trend Detected',
'description': f"Product shows {trend.get('direction', 'increasing')} demand trend. Consider increasing inventory levels.",
'impact_type': 'demand_increase',
'impact_value': abs(trend.get('correlation', 0) * 100),
'impact_unit': 'percent',
'confidence': min(int(abs(trend.get('correlation', 0)) * 100), 95),
'metrics_json': trend,
'actionable': True,
'recommendation_actions': [
{
'label': 'Increase Safety Stock',
'action': 'increase_safety_stock',
'params': {'product_id': inventory_product_id, 'factor': 1.2}
}
]
})
elif trend.get('is_decreasing'):
insights.append({
'type': 'insight',
'priority': 'low',
'category': 'forecasting',
'title': 'Decreasing Demand Trend Detected',
'description': f"Product shows {trend.get('direction', 'decreasing')} demand trend. Consider reviewing inventory strategy.",
'impact_type': 'demand_decrease',
'impact_value': abs(trend.get('correlation', 0) * 100),
'impact_unit': 'percent',
'confidence': min(int(abs(trend.get('correlation', 0)) * 100), 95),
'metrics_json': trend,
'actionable': True,
'recommendation_actions': [
{
'label': 'Review Inventory Levels',
'action': 'review_inventory',
'params': {'product_id': inventory_product_id}
}
]
})
# Volatility insight
cv = stats.get('coefficient_of_variation', 0)
if cv > 0.5:
insights.append({
'type': 'alert',
'priority': 'medium',
'category': 'forecasting',
'title': 'High Demand Variability Detected',
'description': f'Product has high demand variability (CV: {cv:.2f}). Consider dynamic safety stock levels.',
'impact_type': 'demand_variability',
'impact_value': round(cv * 100, 1),
'impact_unit': 'percent',
'confidence': 85,
'metrics_json': stats,
'actionable': True,
'recommendation_actions': [
{
'label': 'Enable Dynamic Safety Stock',
'action': 'enable_dynamic_safety_stock',
'params': {'product_id': inventory_product_id}
}
]
})
# Seasonal pattern insight
peak_ratio = seasonal.get('peak_ratio', 1.0)
if peak_ratio > 1.5:
pattern_data = patterns.get('patterns', {})
peak_day = pattern_data.get('peak_day', 0)
low_day = pattern_data.get('low_day', 0)
insights.append({
'type': 'insight',
'priority': 'medium',
'category': 'forecasting',
'title': 'Strong Weekly Pattern Detected',
'description': f'Demand is {peak_ratio:.1f}x higher on day {peak_day} compared to day {low_day}. Adjust production schedule accordingly.',
'impact_type': 'seasonal_pattern',
'impact_value': round((peak_ratio - 1) * 100, 1),
'impact_unit': 'percent',
'confidence': 80,
'metrics_json': {**seasonal, **pattern_data},
'actionable': True,
'recommendation_actions': [
{
'label': 'Adjust Production Schedule',
'action': 'adjust_production',
'params': {'product_id': inventory_product_id, 'pattern': 'weekly'}
}
]
})
return insights
async def _get_dynamic_rules(self, tenant_id: str, inventory_product_id: str, rule_type: str) -> Dict[str, float]:
"""
Fetch learned dynamic rules from AI Insights Service.
Args:
tenant_id: Tenant UUID
inventory_product_id: Product UUID
rule_type: Type of rules (weather, temporal, holiday, etc.)
Returns:
Dictionary of learned rules with factors
"""
try:
from uuid import UUID
# Fetch latest rules insight for this product
insights = await self.ai_insights_client.get_insights(
tenant_id=UUID(tenant_id),
filters={
'category': 'forecasting',
'actionable_only': False,
'page_size': 100
}
)
if not insights or 'items' not in insights:
return {}
# Find the most recent rules insight for this product
for insight in insights['items']:
if insight.get('source_model') == 'dynamic_rules_engine':
metrics = insight.get('metrics_json', {})
if metrics.get('inventory_product_id') == inventory_product_id:
rules_data = metrics.get('rules', {})
return rules_data.get(rule_type, {})
return {}
except Exception as e:
logger.warning(f"Failed to fetch dynamic rules: {e}")
return {}
async def generate_forecast_with_repository(self, tenant_id: str, inventory_product_id: str,
forecast_date: date, model_id: str = None) -> Dict[str, Any]:
"""Generate forecast with repository integration"""
try:
# Placeholder implementation: a full version would load the model and cached
# history from repositories. For now, prepare minimal features for prediction.
features = {
'date': forecast_date.isoformat(),
'day_of_week': forecast_date.weekday(),
'is_weekend': 1 if forecast_date.weekday() >= 5 else 0,
'is_holiday': 0, # Would come from calendar service in real implementation
# Add default weather values if needed
'temperature': 20.0,
'precipitation': 0.0,
}
# This is a placeholder - in a full implementation, we would:
# 1. Load the appropriate model from repository
# 2. Use historical data to make prediction
# 3. Apply business rules
# For now, return the structure with basic info
# For more realistic implementation, we'd use self.predict_demand method
# but that requires a model object which needs to be loaded
return {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"forecast_date": forecast_date.isoformat(),
"prediction": 10.0, # Placeholder value - in reality would be calculated
"confidence_interval": {"lower": 8.0, "upper": 12.0}, # Placeholder values
"status": "completed",
"repository_integration": True,
"forecast_method": "placeholder"
}
except Exception as e:
logger.error("Forecast generation failed", error=str(e))
raise
class BakeryBusinessRules:
"""
Business rules for Spanish bakeries
Applies domain-specific adjustments to predictions
Supports both dynamic learned rules and hardcoded fallbacks
"""
def __init__(self, use_dynamic_rules=False, ai_insights_client=None):
self.use_dynamic_rules = use_dynamic_rules
self.ai_insights_client = ai_insights_client
self.rules_cache = {}
async def apply_rules(self, prediction: Dict[str, float], features: Dict[str, Any],
business_type: str, tenant_id: str = None, inventory_product_id: str = None) -> Dict[str, float]:
"""Apply all business rules to prediction (dynamic or hardcoded)"""
adjusted_prediction = prediction.copy()
# Apply weather rules
adjusted_prediction = await self._apply_weather_rules(
adjusted_prediction, features, tenant_id, inventory_product_id
)
# Apply time-based rules
adjusted_prediction = await self._apply_time_rules(
adjusted_prediction, features, tenant_id, inventory_product_id
)
# Apply business type rules
adjusted_prediction = self._apply_business_type_rules(adjusted_prediction, business_type)
# Apply Spanish-specific rules
adjusted_prediction = self._apply_spanish_rules(adjusted_prediction, features)
return adjusted_prediction
async def _get_dynamic_rules(self, tenant_id: str, inventory_product_id: str, rule_type: str) -> Dict[str, float]:
"""
Fetch learned dynamic rules from AI Insights Service.
Args:
tenant_id: Tenant UUID
inventory_product_id: Product UUID
rule_type: Type of rules (weather, temporal, holiday, etc.)
Returns:
Dictionary of learned rules with factors
"""
# Check cache first
cache_key = f"{tenant_id}:{inventory_product_id}:{rule_type}"
if cache_key in self.rules_cache:
return self.rules_cache[cache_key]
try:
from uuid import UUID
if not self.ai_insights_client:
return {}
# Fetch latest rules insight for this product
insights = await self.ai_insights_client.get_insights(
tenant_id=UUID(tenant_id),
filters={
'category': 'forecasting',
'actionable_only': False,
'page_size': 100
}
)
if not insights or 'items' not in insights:
return {}
# Find the most recent rules insight for this product
for insight in insights['items']:
if insight.get('source_model') == 'dynamic_rules_engine':
metrics = insight.get('metrics_json', {})
if metrics.get('inventory_product_id') == inventory_product_id:
rules_data = metrics.get('rules', {})
result = rules_data.get(rule_type, {})
# Cache the result
self.rules_cache[cache_key] = result
return result
return {}
except Exception as e:
logger.warning(f"Failed to fetch dynamic rules: {e}")
return {}
async def _apply_weather_rules(self, prediction: Dict[str, float],
features: Dict[str, Any],
tenant_id: str = None,
inventory_product_id: str = None) -> Dict[str, float]:
"""Apply weather-based business rules (dynamic or hardcoded fallback)"""
if self.use_dynamic_rules and tenant_id and inventory_product_id:
try:
# Fetch dynamic weather rules
rules = await self._get_dynamic_rules(tenant_id, inventory_product_id, 'weather')
# Apply learned rain impact
precipitation = features.get('precipitation', 0)
if precipitation > 0:
rain_factor = rules.get('rain_factor', settings.RAIN_IMPACT_FACTOR)
prediction["yhat"] *= rain_factor
prediction["yhat_lower"] *= rain_factor
prediction["yhat_upper"] *= rain_factor
# Apply learned temperature impact
temperature = features.get('temperature')
if temperature is not None:
if temperature > settings.TEMPERATURE_THRESHOLD_HOT:
hot_factor = rules.get('temperature_hot_factor', 0.9)
prediction["yhat"] *= hot_factor
elif temperature < settings.TEMPERATURE_THRESHOLD_COLD:
cold_factor = rules.get('temperature_cold_factor', 1.1)
prediction["yhat"] *= cold_factor
except Exception as e:
logger.warning(f"Failed to apply dynamic weather rules, using fallback: {e}")
# Fallback to hardcoded
precipitation = features.get('precipitation', 0)
if precipitation > 0:
prediction["yhat"] *= settings.RAIN_IMPACT_FACTOR
prediction["yhat_lower"] *= settings.RAIN_IMPACT_FACTOR
prediction["yhat_upper"] *= settings.RAIN_IMPACT_FACTOR
temperature = features.get('temperature')
if temperature is not None:
if temperature > settings.TEMPERATURE_THRESHOLD_HOT:
prediction["yhat"] *= 0.9
elif temperature < settings.TEMPERATURE_THRESHOLD_COLD:
prediction["yhat"] *= 1.1
else:
# Use hardcoded rules
precipitation = features.get('precipitation', 0)
if precipitation > 0:
rain_factor = settings.RAIN_IMPACT_FACTOR
prediction["yhat"] *= rain_factor
prediction["yhat_lower"] *= rain_factor
prediction["yhat_upper"] *= rain_factor
temperature = features.get('temperature')
if temperature is not None:
if temperature > settings.TEMPERATURE_THRESHOLD_HOT:
prediction["yhat"] *= 0.9
elif temperature < settings.TEMPERATURE_THRESHOLD_COLD:
prediction["yhat"] *= 1.1
return prediction
async def _apply_time_rules(self, prediction: Dict[str, float],
features: Dict[str, Any],
tenant_id: str = None,
inventory_product_id: str = None) -> Dict[str, float]:
"""Apply time-based business rules (dynamic or hardcoded fallback)"""
if self.use_dynamic_rules and tenant_id and inventory_product_id:
try:
# Fetch dynamic temporal rules
rules = await self._get_dynamic_rules(tenant_id, inventory_product_id, 'temporal')
# Apply learned weekend adjustment
if features.get('is_weekend', False):
weekend_factor = rules.get('weekend_factor', settings.WEEKEND_ADJUSTMENT_FACTOR)
prediction["yhat"] *= weekend_factor
prediction["yhat_lower"] *= weekend_factor
prediction["yhat_upper"] *= weekend_factor
# Apply learned holiday adjustment
if features.get('is_holiday', False):
holiday_factor = rules.get('holiday_factor', settings.HOLIDAY_ADJUSTMENT_FACTOR)
prediction["yhat"] *= holiday_factor
prediction["yhat_lower"] *= holiday_factor
prediction["yhat_upper"] *= holiday_factor
except Exception as e:
logger.warning(f"Failed to apply dynamic time rules, using fallback: {e}")
# Fallback to hardcoded
if features.get('is_weekend', False):
prediction["yhat"] *= settings.WEEKEND_ADJUSTMENT_FACTOR
prediction["yhat_lower"] *= settings.WEEKEND_ADJUSTMENT_FACTOR
prediction["yhat_upper"] *= settings.WEEKEND_ADJUSTMENT_FACTOR
if features.get('is_holiday', False):
prediction["yhat"] *= settings.HOLIDAY_ADJUSTMENT_FACTOR
prediction["yhat_lower"] *= settings.HOLIDAY_ADJUSTMENT_FACTOR
prediction["yhat_upper"] *= settings.HOLIDAY_ADJUSTMENT_FACTOR
else:
# Use hardcoded rules
if features.get('is_weekend', False):
weekend_factor = settings.WEEKEND_ADJUSTMENT_FACTOR
prediction["yhat"] *= weekend_factor
prediction["yhat_lower"] *= weekend_factor
prediction["yhat_upper"] *= weekend_factor
if features.get('is_holiday', False):
holiday_factor = settings.HOLIDAY_ADJUSTMENT_FACTOR
prediction["yhat"] *= holiday_factor
prediction["yhat_lower"] *= holiday_factor
prediction["yhat_upper"] *= holiday_factor
return prediction
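# Worked example of how the rule factors compound (illustrative; the actual
# settings values are configuration-dependent assumptions here):
#   yhat = 100, rainy day (RAIN_IMPACT_FACTOR assumed 0.85),
#   weekend (WEEKEND_ADJUSTMENT_FACTOR assumed 1.3)
#   after weather rules: 100 * 0.85 = 85
#   after time rules:    85 * 1.3  = 110.5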
def _apply_business_type_rules(self, prediction: Dict[str, float],
business_type: str) -> Dict[str, float]:
"""Apply business type specific rules"""
if business_type == "central_workshop":
# Central workshops have more stable demand
uncertainty_reduction = 0.8
center = prediction["yhat"]
lower = prediction["yhat_lower"]
upper = prediction["yhat_upper"]
# Reduce uncertainty band
new_range = (upper - lower) * uncertainty_reduction
prediction["yhat_lower"] = center - (new_range / 2)
prediction["yhat_upper"] = center + (new_range / 2)
return prediction
def _apply_spanish_rules(self, prediction: Dict[str, float],
features: Dict[str, Any]) -> Dict[str, float]:
"""Apply Spanish bakery specific rules"""
# Spanish siesta time considerations
date_str = features.get('date')
if date_str:
try:
current_date = pd.to_datetime(date_str)
day_of_week = current_date.weekday()
# Reduced activity during typical siesta hours (14:00-17:00)
# This affects afternoon sales planning
if day_of_week < 5: # Weekdays
prediction["yhat"] *= 0.95 # Slight reduction for siesta effect
except Exception as e:
logger.warning(f"Error processing date in spanish rules: {e}")
else:
logger.warning("Date not provided in features, skipping Spanish rules")
return prediction

View File

@@ -0,0 +1,312 @@
"""
Rules Orchestrator
Coordinates dynamic rules learning, insight posting, and integration with forecasting service
"""
import pandas as pd
from typing import Dict, List, Any, Optional
import structlog
from datetime import datetime
from uuid import UUID
from app.ml.dynamic_rules_engine import DynamicRulesEngine
from app.clients.ai_insights_client import AIInsightsClient
from shared.messaging import UnifiedEventPublisher
logger = structlog.get_logger()
class RulesOrchestrator:
"""
Orchestrates dynamic rules learning and insight generation workflow.
Workflow:
1. Learn dynamic rules from historical data
2. Generate insights comparing learned vs hardcoded rules
3. Post insights to AI Insights Service
4. Provide learned rules for forecasting integration
5. Track rule updates and performance
"""
def __init__(
self,
ai_insights_base_url: str = "http://ai-insights-service:8000",
event_publisher: Optional[UnifiedEventPublisher] = None
):
self.rules_engine = DynamicRulesEngine()
self.ai_insights_client = AIInsightsClient(ai_insights_base_url)
self.event_publisher = event_publisher
async def learn_and_post_rules(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
external_data: Optional[pd.DataFrame] = None,
min_samples: int = 10
) -> Dict[str, Any]:
"""
Complete workflow: Learn rules and post insights.
Args:
tenant_id: Tenant identifier
inventory_product_id: Product identifier
sales_data: Historical sales data
external_data: Optional weather/events/holidays data
min_samples: Minimum samples for rule learning
Returns:
Workflow results with learned rules and posted insights
"""
logger.info(
"Starting dynamic rules learning workflow",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id
)
# Step 1: Learn all rules from data
rules_results = await self.rules_engine.learn_all_rules(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
sales_data=sales_data,
external_data=external_data,
min_samples=min_samples
)
logger.info(
"Rules learning complete",
insights_generated=len(rules_results['insights']),
rules_learned=len(rules_results['rules'])
)
# Step 2: Enrich insights with tenant_id and product context
enriched_insights = self._enrich_insights(
rules_results['insights'],
tenant_id,
inventory_product_id
)
# Step 3: Post insights to AI Insights Service
if enriched_insights:
post_results = await self.ai_insights_client.create_insights_bulk(
tenant_id=UUID(tenant_id),
insights=enriched_insights
)
logger.info(
"Insights posted to AI Insights Service",
total=post_results['total'],
successful=post_results['successful'],
failed=post_results['failed']
)
else:
post_results = {'total': 0, 'successful': 0, 'failed': 0}
logger.info("No insights to post")
# Step 4: Publish insight events to RabbitMQ
created_insights = post_results.get('created_insights', [])
if created_insights:
product_context = {'inventory_product_id': inventory_product_id}
await self._publish_insight_events(
tenant_id=tenant_id,
insights=created_insights,
product_context=product_context
)
# Step 5: Return comprehensive results
return {
'tenant_id': tenant_id,
'inventory_product_id': inventory_product_id,
'learned_at': rules_results['learned_at'],
'rules': rules_results['rules'],
'insights_generated': len(enriched_insights),
'insights_posted': post_results['successful'],
'insights_failed': post_results['failed'],
'created_insights': post_results.get('created_insights', [])
}
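# Usage sketch (illustrative; `publisher`, the UUIDs, and the DataFrames are
# assumptions, and the service URL falls back to the constructor default above):
#
#     orchestrator = RulesOrchestrator(event_publisher=publisher)
#     results = await orchestrator.learn_and_post_rules(
#         tenant_id=str(tenant_uuid),
#         inventory_product_id=str(product_uuid),
#         sales_data=sales_df,
#         external_data=weather_df,
#         min_samples=10,
#     )
#     # results['insights_posted'] -> number of insights accepted by the AI Insights Service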
def _enrich_insights(
self,
insights: List[Dict[str, Any]],
tenant_id: str,
inventory_product_id: str
) -> List[Dict[str, Any]]:
"""
Enrich insights with required fields for AI Insights Service.
Args:
insights: Raw insights from rules engine
tenant_id: Tenant identifier
inventory_product_id: Product identifier
Returns:
Enriched insights ready for posting
"""
enriched = []
for insight in insights:
# Add required tenant_id and product context
enriched_insight = insight.copy()
enriched_insight['tenant_id'] = tenant_id
# Add product context to metrics
if 'metrics_json' not in enriched_insight:
enriched_insight['metrics_json'] = {}
enriched_insight['metrics_json']['inventory_product_id'] = inventory_product_id
# Add source metadata
enriched_insight['source_service'] = 'forecasting'
enriched_insight['source_model'] = 'dynamic_rules_engine'
enriched_insight['detected_at'] = datetime.utcnow().isoformat()
enriched.append(enriched_insight)
return enriched
async def get_learned_rules_for_forecasting(
self,
inventory_product_id: str
) -> Dict[str, Any]:
"""
Get learned rules in format ready for forecasting integration.
Args:
inventory_product_id: Product identifier
Returns:
Dictionary with learned multipliers for all rule types
"""
return self.rules_engine.export_rules_for_prophet(inventory_product_id)
def get_rule_multiplier(
self,
inventory_product_id: str,
rule_type: str,
key: str,
default: float = 1.0
) -> float:
"""
Get learned rule multiplier with fallback to default.
Args:
inventory_product_id: Product identifier
rule_type: 'weather', 'holiday', 'event', 'day_of_week', 'month'
key: Condition key
default: Default multiplier if rule not learned
Returns:
Learned multiplier or default
"""
learned = self.rules_engine.get_rule(inventory_product_id, rule_type, key)
return learned if learned is not None else default
async def update_rules_periodically(
self,
tenant_id: str,
inventory_product_id: str,
sales_data: pd.DataFrame,
external_data: Optional[pd.DataFrame] = None
) -> Dict[str, Any]:
"""
Update learned rules with new data (for periodic refresh).
Args:
tenant_id: Tenant identifier
inventory_product_id: Product identifier
sales_data: Updated historical sales data
external_data: Updated external data
Returns:
Update results
"""
logger.info(
"Updating learned rules with new data",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
new_data_points=len(sales_data)
)
# Re-learn rules with updated data
results = await self.learn_and_post_rules(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
sales_data=sales_data,
external_data=external_data
)
logger.info(
"Rules update complete",
insights_posted=results['insights_posted']
)
return results
async def _publish_insight_events(self, tenant_id, insights, product_context=None):
"""
Publish insight events to RabbitMQ for alert processing.
Args:
tenant_id: Tenant identifier
insights: List of created insights
product_context: Additional context about the product
"""
if not self.event_publisher:
logger.warning("No event publisher available for business rules insights")
return
for insight in insights:
# Determine severity based on confidence and priority
confidence = insight.get('confidence', 0)
priority = insight.get('priority', 'medium')
# Map priority to severity, with confidence as tiebreaker
if priority == 'critical' or (priority == 'high' and confidence >= 70):
severity = 'high'
elif priority == 'high' or (priority == 'medium' and confidence >= 80):
severity = 'medium'
else:
severity = 'low'
# Prepare the event data
event_data = {
'insight_id': insight.get('id'),
'type': insight.get('type'),
'title': insight.get('title'),
'description': insight.get('description'),
'category': insight.get('category'),
'priority': insight.get('priority'),
'confidence': confidence,
'recommendation': insight.get('recommendation_actions', []),
'impact_type': insight.get('impact_type'),
'impact_value': insight.get('impact_value'),
'inventory_product_id': product_context.get('inventory_product_id') if product_context else None,
'timestamp': insight.get('detected_at', datetime.utcnow().isoformat()),
'source_service': 'forecasting',
'source_model': 'dynamic_rules_engine'
}
try:
await self.event_publisher.publish_recommendation(
event_type='ai_business_rule',
tenant_id=tenant_id,
severity=severity,
data=event_data
)
logger.info(
"Published business rules insight event",
tenant_id=tenant_id,
insight_id=insight.get('id'),
severity=severity
)
except Exception as e:
logger.error(
"Failed to publish business rules insight event",
tenant_id=tenant_id,
insight_id=insight.get('id'),
error=str(e)
)
async def close(self):
"""Close HTTP client connections."""
await self.ai_insights_client.close()

View File

@@ -0,0 +1,385 @@
"""
Scenario Planning System
What-if analysis for demand forecasting
"""
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional
from datetime import datetime, date, timedelta
import structlog
from enum import Enum
logger = structlog.get_logger()
class ScenarioType(str, Enum):
"""Types of scenarios"""
BASELINE = "baseline"
OPTIMISTIC = "optimistic"
PESSIMISTIC = "pessimistic"
CUSTOM = "custom"
PROMOTION = "promotion"
EVENT = "event"
WEATHER = "weather"
PRICE_CHANGE = "price_change"
class ScenarioPlanner:
"""
Scenario planning for demand forecasting.
Allows testing "what-if" scenarios:
- What if we run a promotion?
- What if there's a local festival?
- What if weather is unusually bad?
- What if we change prices?
"""
def __init__(self, base_forecaster=None):
"""
Initialize scenario planner.
Args:
base_forecaster: Base forecaster to use for baseline predictions
"""
self.base_forecaster = base_forecaster
async def create_scenario(
self,
tenant_id: str,
inventory_product_id: str,
scenario_name: str,
scenario_type: ScenarioType,
start_date: date,
end_date: date,
adjustments: Dict[str, Any]
) -> Dict[str, Any]:
"""
Create a forecast scenario with adjustments.
Args:
tenant_id: Tenant identifier
inventory_product_id: Product identifier
scenario_name: Name for the scenario
scenario_type: Type of scenario
start_date: Scenario start date
end_date: Scenario end date
adjustments: Dictionary of adjustments to apply
Returns:
Scenario forecast results
"""
logger.info(
"Creating forecast scenario",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
scenario_name=scenario_name,
scenario_type=scenario_type
)
# Generate baseline forecast first
baseline_forecast = await self._generate_baseline_forecast(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
start_date=start_date,
end_date=end_date
)
# Apply scenario adjustments
scenario_forecast = self._apply_scenario_adjustments(
baseline_forecast=baseline_forecast,
adjustments=adjustments,
scenario_type=scenario_type
)
# Calculate impact
impact_analysis = self._calculate_scenario_impact(
baseline_forecast=baseline_forecast,
scenario_forecast=scenario_forecast
)
return {
'scenario_id': f"scenario_{tenant_id}_{inventory_product_id}_{datetime.now().strftime('%Y%m%d%H%M%S')}",
'scenario_name': scenario_name,
'scenario_type': scenario_type,
'tenant_id': tenant_id,
'inventory_product_id': inventory_product_id,
'date_range': {
'start': start_date.isoformat(),
'end': end_date.isoformat()
},
'baseline_forecast': baseline_forecast,
'scenario_forecast': scenario_forecast,
'impact_analysis': impact_analysis,
'adjustments_applied': adjustments,
'created_at': datetime.now().isoformat()
}
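# Usage sketch (illustrative identifiers and adjustment values):
#
#     planner = ScenarioPlanner()
#     scenario = await planner.create_scenario(
#         tenant_id='tenant-1',
#         inventory_product_id='product-1',
#         scenario_name='Promotion Campaign',
#         scenario_type=ScenarioType.PROMOTION,
#         start_date=date(2025, 6, 1),
#         end_date=date(2025, 6, 7),
#         adjustments={'demand_multiplier': 1.5},
#     )
#     # scenario['impact_analysis']['percent_change'] -> ~50.0 vs the placeholder baseline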
async def compare_scenarios(
self,
scenarios: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Compare multiple scenarios side-by-side.
Args:
scenarios: List of scenario results from create_scenario()
Returns:
Comparison analysis
"""
if len(scenarios) < 2:
return {'error': 'Need at least 2 scenarios to compare'}
comparison = {
'scenarios_compared': len(scenarios),
'scenario_names': [s['scenario_name'] for s in scenarios],
'comparison_metrics': {}
}
# Extract total demand for each scenario
for scenario in scenarios:
scenario_name = scenario['scenario_name']
scenario_forecast = scenario['scenario_forecast']
total_demand = sum(f['predicted_demand'] for f in scenario_forecast)
comparison['comparison_metrics'][scenario_name] = {
'total_demand': total_demand,
'avg_daily_demand': total_demand / len(scenario_forecast) if scenario_forecast else 0,
'peak_demand': max(f['predicted_demand'] for f in scenario_forecast) if scenario_forecast else 0
}
# Determine best and worst scenarios
total_demands = {
name: metrics['total_demand']
for name, metrics in comparison['comparison_metrics'].items()
}
comparison['best_scenario'] = max(total_demands, key=total_demands.get)
comparison['worst_scenario'] = min(total_demands, key=total_demands.get)
comparison['demand_range'] = {
'min': min(total_demands.values()),
'max': max(total_demands.values()),
'spread': max(total_demands.values()) - min(total_demands.values())
}
return comparison
async def _generate_baseline_forecast(
self,
tenant_id: str,
inventory_product_id: str,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
Generate baseline forecast without adjustments.
Args:
tenant_id: Tenant identifier
inventory_product_id: Product identifier
start_date: Start date
end_date: End date
Returns:
List of daily forecasts
"""
# Generate date range
dates = []
current_date = start_date
while current_date <= end_date:
dates.append(current_date)
current_date += timedelta(days=1)
# Placeholder forecast (in real implementation, call forecasting service)
baseline = []
for forecast_date in dates:
baseline.append({
'date': forecast_date.isoformat(),
'predicted_demand': 100, # Placeholder
'confidence_lower': 80,
'confidence_upper': 120
})
return baseline
def _apply_scenario_adjustments(
self,
baseline_forecast: List[Dict[str, Any]],
adjustments: Dict[str, Any],
scenario_type: ScenarioType
) -> List[Dict[str, Any]]:
"""
Apply adjustments to baseline forecast.
Args:
baseline_forecast: Baseline forecast data
adjustments: Adjustments to apply
scenario_type: Type of scenario
Returns:
Adjusted forecast
"""
scenario_forecast = []
for day_forecast in baseline_forecast:
adjusted_forecast = day_forecast.copy()
# Apply different adjustment types
if 'demand_multiplier' in adjustments:
# Multiply demand by factor
multiplier = adjustments['demand_multiplier']
adjusted_forecast['predicted_demand'] *= multiplier
adjusted_forecast['confidence_lower'] *= multiplier
adjusted_forecast['confidence_upper'] *= multiplier
if 'demand_offset' in adjustments:
# Add/subtract fixed amount
offset = adjustments['demand_offset']
adjusted_forecast['predicted_demand'] += offset
adjusted_forecast['confidence_lower'] += offset
adjusted_forecast['confidence_upper'] += offset
if 'event_impact' in adjustments:
# Apply event-specific impact
event_multiplier = adjustments['event_impact']
adjusted_forecast['predicted_demand'] *= event_multiplier
if 'weather_impact' in adjustments:
# Apply weather adjustments
weather_factor = adjustments['weather_impact']
adjusted_forecast['predicted_demand'] *= weather_factor
if 'price_elasticity' in adjustments and 'price_change_percent' in adjustments:
# Apply price elasticity
elasticity = adjustments['price_elasticity']
price_change = adjustments['price_change_percent']
demand_change = -elasticity * price_change # Negative correlation
adjusted_forecast['predicted_demand'] *= (1 + demand_change)
# Ensure non-negative demand
adjusted_forecast['predicted_demand'] = max(0, adjusted_forecast['predicted_demand'])
adjusted_forecast['confidence_lower'] = max(0, adjusted_forecast['confidence_lower'])
scenario_forecast.append(adjusted_forecast)
return scenario_forecast
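# Worked example of the price-elasticity adjustment above (illustrative numbers):
#   elasticity = 1.2, price_change_percent = 0.10 (a 10% price increase)
#   demand_change = -1.2 * 0.10 = -0.12
#   predicted_demand *= (1 - 0.12) -> a 12% demand reduction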
def _calculate_scenario_impact(
self,
baseline_forecast: List[Dict[str, Any]],
scenario_forecast: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Calculate impact of scenario vs baseline.
Args:
baseline_forecast: Baseline forecast
scenario_forecast: Scenario forecast
Returns:
Impact analysis
"""
baseline_total = sum(f['predicted_demand'] for f in baseline_forecast)
scenario_total = sum(f['predicted_demand'] for f in scenario_forecast)
difference = scenario_total - baseline_total
percent_change = (difference / baseline_total * 100) if baseline_total > 0 else 0
return {
'baseline_total_demand': baseline_total,
'scenario_total_demand': scenario_total,
'absolute_difference': difference,
'percent_change': percent_change,
'impact_category': self._categorize_impact(percent_change),
'days_analyzed': len(baseline_forecast)
}
def _categorize_impact(self, percent_change: float) -> str:
"""Categorize impact magnitude"""
if abs(percent_change) < 5:
return "minimal"
elif abs(percent_change) < 15:
return "moderate"
elif abs(percent_change) < 30:
return "significant"
else:
return "major"
def generate_predefined_scenarios(
self,
base_scenario: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""
Generate common predefined scenarios for comparison.
Args:
base_scenario: Base scenario parameters
Returns:
List of scenario configurations
"""
scenarios = []
# Baseline scenario
scenarios.append({
'scenario_name': 'Baseline',
'scenario_type': ScenarioType.BASELINE,
'adjustments': {}
})
# Optimistic scenario
scenarios.append({
'scenario_name': 'Optimistic',
'scenario_type': ScenarioType.OPTIMISTIC,
'adjustments': {
'demand_multiplier': 1.2, # 20% increase
'description': '+20% demand increase'
}
})
# Pessimistic scenario
scenarios.append({
'scenario_name': 'Pessimistic',
'scenario_type': ScenarioType.PESSIMISTIC,
'adjustments': {
'demand_multiplier': 0.8, # 20% decrease
'description': '-20% demand decrease'
}
})
# Promotion scenario
scenarios.append({
'scenario_name': 'Promotion Campaign',
'scenario_type': ScenarioType.PROMOTION,
'adjustments': {
'demand_multiplier': 1.5, # 50% increase
'description': '50% promotion boost'
}
})
# Bad weather scenario
scenarios.append({
'scenario_name': 'Bad Weather',
'scenario_type': ScenarioType.WEATHER,
'adjustments': {
'weather_impact': 0.7, # 30% decrease
'description': 'Bad weather reduces foot traffic'
}
})
# Price increase scenario
scenarios.append({
'scenario_name': 'Price Increase 10%',
'scenario_type': ScenarioType.PRICE_CHANGE,
'adjustments': {
'price_elasticity': 1.2, # Elastic demand
'price_change_percent': 0.10, # 10% price increase
'description': '10% price increase with elastic demand'
}
})
return scenarios
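# Usage sketch tying the pieces together (illustrative; `base` is an assumed dict
# carrying the shared tenant/product/date parameters):
#
#     planner = ScenarioPlanner()
#     configs = planner.generate_predefined_scenarios(base_scenario=base)
#     scenarios = [
#         await planner.create_scenario(
#             tenant_id=base['tenant_id'],
#             inventory_product_id=base['inventory_product_id'],
#             scenario_name=cfg['scenario_name'],
#             scenario_type=cfg['scenario_type'],
#             start_date=base['start_date'],
#             end_date=base['end_date'],
#             adjustments=cfg['adjustments'],
#         )
#         for cfg in configs
#     ]
#     comparison = await planner.compare_scenarios(scenarios)
#     # comparison['best_scenario'] -> the scenario name with the highest total demand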

View File

@@ -0,0 +1,29 @@
"""
Forecasting Service Models Package
Import all models to ensure they are registered with SQLAlchemy Base.
"""
# Import AuditLog model for this service
from shared.security import create_audit_log_model
from shared.database.base import Base
# Create audit log model for this service
AuditLog = create_audit_log_model(Base)
# Import all models to register them with the Base metadata
from .forecasts import Forecast, PredictionBatch
from .predictions import ModelPerformanceMetric, PredictionCache
from .validation_run import ValidationRun
from .sales_data_update import SalesDataUpdate
# List all models for easier access
__all__ = [
"Forecast",
"PredictionBatch",
"ModelPerformanceMetric",
"PredictionCache",
"ValidationRun",
"SalesDataUpdate",
"AuditLog",
]

View File

@@ -0,0 +1,101 @@
# ================================================================
# services/forecasting/app/models/forecasts.py
# ================================================================
"""
Forecast models for the forecasting service
"""
from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text, JSON, UniqueConstraint, Index
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime, timezone
import uuid
from shared.database.base import Base
class Forecast(Base):
"""Forecast model for storing prediction results"""
__tablename__ = "forecasts"
__table_args__ = (
# Unique constraint to prevent duplicate forecasts
# Ensures only one forecast per (tenant, product, date, location) combination
UniqueConstraint(
'tenant_id', 'inventory_product_id', 'forecast_date', 'location',
name='uq_forecast_tenant_product_date_location'
),
# Composite index for common query patterns
Index('ix_forecasts_tenant_product_date', 'tenant_id', 'inventory_product_id', 'forecast_date'),
)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
inventory_product_id = Column(UUID(as_uuid=True), nullable=False, index=True) # Reference to inventory service
product_name = Column(String(255), nullable=True, index=True) # Product name (optional - use inventory_product_id as reference)
location = Column(String(255), nullable=False, index=True)
# Forecast period
forecast_date = Column(DateTime(timezone=True), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
# Prediction results
predicted_demand = Column(Float, nullable=False)
confidence_lower = Column(Float, nullable=False)
confidence_upper = Column(Float, nullable=False)
confidence_level = Column(Float, default=0.8)
# Model information
model_id = Column(String(255), nullable=False)
model_version = Column(String(50), nullable=False)
algorithm = Column(String(50), default="prophet")
# Business context
business_type = Column(String(50), default="individual") # individual or central_workshop
day_of_week = Column(Integer, nullable=False)
is_holiday = Column(Boolean, default=False)
is_weekend = Column(Boolean, default=False)
# External factors
weather_temperature = Column(Float)
weather_precipitation = Column(Float)
weather_description = Column(String(100))
traffic_volume = Column(Integer)
# Metadata
processing_time_ms = Column(Integer)
features_used = Column(JSON)
def __repr__(self):
return f"<Forecast(id={self.id}, inventory_product_id={self.inventory_product_id}, date={self.forecast_date})>"
class PredictionBatch(Base):
"""Batch prediction requests"""
__tablename__ = "prediction_batches"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# Batch information
batch_name = Column(String(255), nullable=False)
requested_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
completed_at = Column(DateTime(timezone=True))
# Status
status = Column(String(50), default="pending") # pending, processing, completed, failed
total_products = Column(Integer, default=0)
completed_products = Column(Integer, default=0)
failed_products = Column(Integer, default=0)
# Configuration
forecast_days = Column(Integer, default=7)
business_type = Column(String(50), default="individual")
# Results
error_message = Column(Text)
processing_time_ms = Column(Integer)
cancelled_by = Column(String, nullable=True)
def __repr__(self):
return f"<PredictionBatch(id={self.id}, status={self.status})>"

View File

@@ -0,0 +1,67 @@
# ================================================================
# services/forecasting/app/models/predictions.py
# ================================================================
"""
Additional prediction models for the forecasting service
"""
from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text, JSON
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime, timezone
import uuid
from shared.database.base import Base
class ModelPerformanceMetric(Base):
"""Track model performance over time"""
__tablename__ = "model_performance_metrics"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
model_id = Column(UUID(as_uuid=True), nullable=False, index=True)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
inventory_product_id = Column(UUID(as_uuid=True), nullable=False) # Reference to inventory service
# Performance metrics
mae = Column(Float) # Mean Absolute Error
mape = Column(Float) # Mean Absolute Percentage Error
rmse = Column(Float) # Root Mean Square Error
accuracy_score = Column(Float)
# Evaluation period
evaluation_date = Column(DateTime(timezone=True), nullable=False)
evaluation_period_start = Column(DateTime(timezone=True))
evaluation_period_end = Column(DateTime(timezone=True))
# Metadata
sample_size = Column(Integer)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
def __repr__(self):
return f"<ModelPerformanceMetric(model_id={self.model_id}, mae={self.mae})>"
class PredictionCache(Base):
"""Cache frequently requested predictions"""
__tablename__ = "prediction_cache"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
cache_key = Column(String(255), unique=True, nullable=False, index=True)
# Cached data
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
inventory_product_id = Column(UUID(as_uuid=True), nullable=False) # Reference to inventory service
location = Column(String(255), nullable=False)
forecast_date = Column(DateTime(timezone=True), nullable=False)
# Cached results
predicted_demand = Column(Float, nullable=False)
confidence_lower = Column(Float, nullable=False)
confidence_upper = Column(Float, nullable=False)
model_id = Column(UUID(as_uuid=True), nullable=False)
# Cache metadata
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
expires_at = Column(DateTime(timezone=True), nullable=False)
hit_count = Column(Integer, default=0)
def __repr__(self):
return f"<PredictionCache(key={self.cache_key}, inventory_product_id={self.inventory_product_id})>"

View File

@@ -0,0 +1,78 @@
# ================================================================
# services/forecasting/app/models/sales_data_update.py
# ================================================================
"""
Sales Data Update Tracking Model
Tracks when sales data is added or updated for past dates,
enabling automated historical validation backfill.
"""
from sqlalchemy import Column, String, Integer, DateTime, Boolean, Index, Date
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime, timezone
import uuid
from shared.database.base import Base
class SalesDataUpdate(Base):
"""Track sales data updates for historical validation"""
__tablename__ = "sales_data_updates"
__table_args__ = (
Index('ix_sales_updates_tenant_status', 'tenant_id', 'validation_status', 'created_at'),
Index('ix_sales_updates_date_range', 'tenant_id', 'update_date_start', 'update_date_end'),
Index('ix_sales_updates_validation_status', 'validation_status'),
)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# Date range of sales data that was added/updated
update_date_start = Column(Date, nullable=False, index=True)
update_date_end = Column(Date, nullable=False, index=True)
# Update metadata
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
update_source = Column(String(100), nullable=True) # import, manual, pos_sync
records_affected = Column(Integer, default=0)
# Validation tracking
validation_status = Column(String(50), default="pending") # pending, processing, completed, failed
validation_run_id = Column(UUID(as_uuid=True), nullable=True)
validated_at = Column(DateTime(timezone=True), nullable=True)
validation_error = Column(String(500), nullable=True)
# Determines if this update should trigger validation
requires_validation = Column(Boolean, default=True)
# Additional context
import_job_id = Column(String(255), nullable=True) # Link to sales import job if applicable
notes = Column(String(500), nullable=True)
def __repr__(self):
return (
f"<SalesDataUpdate(id={self.id}, tenant_id={self.tenant_id}, "
f"date_range={self.update_date_start} to {self.update_date_end}, "
f"status={self.validation_status})>"
)
def to_dict(self):
"""Convert to dictionary for API responses"""
return {
'id': str(self.id),
'tenant_id': str(self.tenant_id),
'update_date_start': self.update_date_start.isoformat() if self.update_date_start else None,
'update_date_end': self.update_date_end.isoformat() if self.update_date_end else None,
'created_at': self.created_at.isoformat() if self.created_at else None,
'update_source': self.update_source,
'records_affected': self.records_affected,
'validation_status': self.validation_status,
'validation_run_id': str(self.validation_run_id) if self.validation_run_id else None,
'validated_at': self.validated_at.isoformat() if self.validated_at else None,
'validation_error': self.validation_error,
'requires_validation': self.requires_validation,
'import_job_id': self.import_job_id,
'notes': self.notes
}
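# A minimal sketch of recording a backfill trigger after a sales import; the
# field values are placeholders and the import job identifier is hypothetical.
from datetime import date

def record_sales_import(tenant_id, start: date, end: date, rows: int) -> SalesDataUpdate:
    return SalesDataUpdate(
        tenant_id=tenant_id,
        update_date_start=start,
        update_date_end=end,
        update_source="import",
        records_affected=rows,
        requires_validation=True,
        import_job_id="import-job-0001",        # hypothetical identifier
        notes="Backfill triggered by bulk sales import",
    )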

View File

@@ -0,0 +1,110 @@
# ================================================================
# services/forecasting/app/models/validation_run.py
# ================================================================
"""
Validation run models for tracking forecast validation executions
"""
from sqlalchemy import Column, String, Integer, Float, DateTime, Text, JSON, Index
from sqlalchemy.dialects.postgresql import UUID
from datetime import datetime, timezone
import uuid
from shared.database.base import Base
class ValidationRun(Base):
"""Track forecast validation execution runs"""
__tablename__ = "validation_runs"
__table_args__ = (
Index('ix_validation_runs_tenant_created', 'tenant_id', 'started_at'),
Index('ix_validation_runs_status', 'status', 'started_at'),
Index('ix_validation_runs_orchestration', 'orchestration_run_id'),
)
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# Link to orchestration run (if triggered by orchestrator)
orchestration_run_id = Column(UUID(as_uuid=True), nullable=True)
# Validation period
validation_start_date = Column(DateTime(timezone=True), nullable=False)
validation_end_date = Column(DateTime(timezone=True), nullable=False)
# Execution metadata
started_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
completed_at = Column(DateTime(timezone=True), nullable=True)
duration_seconds = Column(Float, nullable=True)
# Status and results
status = Column(String(50), default="pending") # pending, running, completed, failed
# Validation statistics
total_forecasts_evaluated = Column(Integer, default=0)
forecasts_with_actuals = Column(Integer, default=0)
forecasts_without_actuals = Column(Integer, default=0)
# Accuracy metrics summary (across all validated forecasts)
overall_mae = Column(Float, nullable=True)
overall_mape = Column(Float, nullable=True)
overall_rmse = Column(Float, nullable=True)
overall_r2_score = Column(Float, nullable=True)
overall_accuracy_percentage = Column(Float, nullable=True)
# Additional statistics
total_predicted_demand = Column(Float, default=0.0)
total_actual_demand = Column(Float, default=0.0)
# Breakdown by product/location (JSON)
metrics_by_product = Column(JSON, nullable=True) # {product_id: {mae, mape, ...}}
metrics_by_location = Column(JSON, nullable=True) # {location: {mae, mape, ...}}
# Performance metrics created count
metrics_records_created = Column(Integer, default=0)
# Error tracking
error_message = Column(Text, nullable=True)
error_details = Column(JSON, nullable=True)
# Execution context
triggered_by = Column(String(100), default="manual") # manual, orchestrator, scheduled
execution_mode = Column(String(50), default="batch") # batch, single_day, real_time
def __repr__(self):
return (
f"<ValidationRun(id={self.id}, tenant_id={self.tenant_id}, "
f"status={self.status}, forecasts_evaluated={self.total_forecasts_evaluated})>"
)
def to_dict(self):
"""Convert to dictionary for API responses"""
return {
'id': str(self.id),
'tenant_id': str(self.tenant_id),
'orchestration_run_id': str(self.orchestration_run_id) if self.orchestration_run_id else None,
'validation_start_date': self.validation_start_date.isoformat() if self.validation_start_date else None,
'validation_end_date': self.validation_end_date.isoformat() if self.validation_end_date else None,
'started_at': self.started_at.isoformat() if self.started_at else None,
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
'duration_seconds': self.duration_seconds,
'status': self.status,
'total_forecasts_evaluated': self.total_forecasts_evaluated,
'forecasts_with_actuals': self.forecasts_with_actuals,
'forecasts_without_actuals': self.forecasts_without_actuals,
'overall_mae': self.overall_mae,
'overall_mape': self.overall_mape,
'overall_rmse': self.overall_rmse,
'overall_r2_score': self.overall_r2_score,
'overall_accuracy_percentage': self.overall_accuracy_percentage,
'total_predicted_demand': self.total_predicted_demand,
'total_actual_demand': self.total_actual_demand,
'metrics_by_product': self.metrics_by_product,
'metrics_by_location': self.metrics_by_location,
'metrics_records_created': self.metrics_records_created,
'error_message': self.error_message,
'error_details': self.error_details,
'triggered_by': self.triggered_by,
'execution_mode': self.execution_mode,
}
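# A compact sketch of the error metrics the overall_* columns summarize, computed
# from paired (predicted, actual) demand values; how the service filters or
# weights pairs is not shown here, so this is illustrative only.
import math

def summary_metrics(pairs: list) -> dict:
    if not pairs:
        return {}
    n = len(pairs)
    mae = sum(abs(p - a) for p, a in pairs) / n
    rmse = math.sqrt(sum((p - a) ** 2 for p, a in pairs) / n)
    nonzero = [(p, a) for p, a in pairs if a != 0]                  # MAPE skips zero actuals
    mape = sum(abs(p - a) / abs(a) for p, a in nonzero) / len(nonzero) * 100 if nonzero else None
    return {
        "overall_mae": mae,
        "overall_rmse": rmse,
        "overall_mape": mape,
        "total_predicted_demand": sum(p for p, _ in pairs),
        "total_actual_demand": sum(a for _, a in pairs),
    }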

View File

@@ -0,0 +1,18 @@
"""
Forecasting Service Repositories
Repository implementations for forecasting service
"""
from .base import ForecastingBaseRepository
from .forecast_repository import ForecastRepository
from .prediction_batch_repository import PredictionBatchRepository
from .performance_metric_repository import PerformanceMetricRepository
from .prediction_cache_repository import PredictionCacheRepository
__all__ = [
"ForecastingBaseRepository",
"ForecastRepository",
"PredictionBatchRepository",
"PerformanceMetricRepository",
"PredictionCacheRepository"
]

View File

@@ -0,0 +1,253 @@
"""
Base Repository for Forecasting Service
Service-specific repository base class with forecasting utilities
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, date, timedelta, timezone
import structlog
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class ForecastingBaseRepository(BaseRepository):
"""Base repository for forecasting service with common forecasting operations"""
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Forecasting data benefits from medium cache time (10 minutes)
super().__init__(model, session, cache_ttl)
async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
"""Get records by tenant ID"""
if hasattr(self.model, 'tenant_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"tenant_id": tenant_id},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def get_by_inventory_product_id(
self,
tenant_id: str,
inventory_product_id: str,
skip: int = 0,
limit: int = 100
) -> List:
"""Get records by tenant and inventory product"""
if hasattr(self.model, 'inventory_product_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id
},
order_by="created_at",
order_desc=True
)
return await self.get_by_tenant_id(tenant_id, skip, limit)
async def get_by_date_range(
self,
tenant_id: str,
start_date: datetime,
end_date: datetime,
skip: int = 0,
limit: int = 100
) -> List:
"""Get records within date range for a tenant"""
if not hasattr(self.model, 'forecast_date') and not hasattr(self.model, 'created_at'):
logger.warning(f"Model {self.model.__name__} has no date field for filtering")
return []
try:
table_name = self.model.__tablename__
date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"
query_text = f"""
SELECT * FROM {table_name}
WHERE tenant_id = :tenant_id
AND {date_field} >= :start_date
AND {date_field} <= :end_date
ORDER BY {date_field} DESC
LIMIT :limit OFFSET :skip
"""
result = await self.session.execute(text(query_text), {
"tenant_id": tenant_id,
"start_date": start_date,
"end_date": end_date,
"limit": limit,
"skip": skip
})
# Convert rows to model objects
records = []
for row in result.fetchall():
record_dict = dict(row._mapping)
record = self.model(**record_dict)
records.append(record)
return records
except Exception as e:
logger.error("Failed to get records by date range",
model=self.model.__name__,
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Date range query failed: {str(e)}")
async def get_recent_records(
self,
tenant_id: str,
hours: int = 24,
skip: int = 0,
limit: int = 100
) -> List:
"""Get recent records for a tenant"""
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
return await self.get_by_date_range(
tenant_id, cutoff_time, datetime.now(timezone.utc), skip, limit
)
async def cleanup_old_records(self, days_old: int = 90) -> int:
"""Clean up old forecasting records"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
table_name = self.model.__tablename__
# Use created_at or forecast_date for cleanup
date_field = "forecast_date" if hasattr(self.model, 'forecast_date') else "created_at"
query_text = f"""
DELETE FROM {table_name}
WHERE {date_field} < :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info(f"Cleaned up old {self.model.__name__} records",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old records",
model=self.model.__name__,
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
"""Get statistics for a tenant"""
try:
table_name = self.model.__tablename__
# Get basic counts
total_records = await self.count(filters={"tenant_id": tenant_id})
# Get recent activity (records in last 7 days)
seven_days_ago = datetime.now(timezone.utc) - timedelta(days=7)
recent_records = len(await self.get_by_date_range(
tenant_id, seven_days_ago, datetime.now(timezone.utc), limit=1000
))
# Get records by product if applicable
product_stats = {}
if hasattr(self.model, 'inventory_product_id'):
product_query = text(f"""
SELECT inventory_product_id, COUNT(*) as count
FROM {table_name}
WHERE tenant_id = :tenant_id
GROUP BY inventory_product_id
ORDER BY count DESC
""")
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
product_stats = {row.inventory_product_id: row.count for row in result.fetchall()}
return {
"total_records": total_records,
"recent_records_7d": recent_records,
"records_by_product": product_stats
}
except Exception as e:
logger.error("Failed to get tenant statistics",
model=self.model.__name__,
tenant_id=tenant_id,
error=str(e))
return {
"total_records": 0,
"recent_records_7d": 0,
"records_by_product": {}
}
def _validate_forecast_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
"""Validate forecasting-related data"""
errors = []
for field in required_fields:
if field not in data or data[field] is None:
errors.append(f"Missing required field: {field}")
# Validate tenant_id format if present
if "tenant_id" in data and data["tenant_id"]:
tenant_id = data["tenant_id"]
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
errors.append("Invalid tenant_id format")
# Validate inventory_product_id if present
if "inventory_product_id" in data and data["inventory_product_id"]:
inventory_product_id = data["inventory_product_id"]
if not isinstance(inventory_product_id, str) or len(inventory_product_id) < 1:
errors.append("Invalid inventory_product_id format")
# Validate dates if present - accept datetime objects, date objects, and date strings
date_fields = ["forecast_date", "created_at", "evaluation_date", "expires_at"]
for field in date_fields:
if field in data and data[field]:
field_value = data[field]
field_type = type(field_value).__name__
if isinstance(field_value, (datetime, date)):
logger.debug(f"Date field {field} is valid {field_type}", field_value=str(field_value))
continue # Already a datetime or date, valid
elif isinstance(field_value, str):
# Try to parse the string date
try:
from dateutil.parser import parse
parse(field_value) # Just validate, don't convert yet
logger.debug(f"Date field {field} is valid string", field_value=field_value)
except (ValueError, TypeError) as e:
logger.error(f"Date parsing failed for {field}", field_value=field_value, error=str(e))
errors.append(f"Invalid {field} format - must be datetime or valid date string")
else:
logger.error(f"Date field {field} has invalid type {field_type}", field_value=str(field_value))
errors.append(f"Invalid {field} format - must be datetime or valid date string")
# Validate numeric fields
numeric_fields = [
"predicted_demand", "confidence_lower", "confidence_upper",
"mae", "mape", "rmse", "accuracy_score"
]
for field in numeric_fields:
if field in data and data[field] is not None:
try:
float(data[field])
except (ValueError, TypeError):
errors.append(f"Invalid {field} format - must be numeric")
return {
"is_valid": len(errors) == 0,
"errors": errors
}
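# A quick usage sketch for the validator above; the payload values are made up.
def example_validation(repo: ForecastingBaseRepository) -> dict:
    payload = {
        "tenant_id": "11111111-1111-1111-1111-111111111111",
        "inventory_product_id": "22222222-2222-2222-2222-222222222222",
        "forecast_date": "2026-01-22",       # date strings are parsed with dateutil
        "predicted_demand": "41.5",          # numeric strings pass the float() check
    }
    result = repo._validate_forecast_data(payload, ["tenant_id", "inventory_product_id", "forecast_date"])
    # -> {"is_valid": True, "errors": []} when all checks pass
    return result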

View File

@@ -0,0 +1,565 @@
"""
Forecast Repository
Repository for forecast operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc, func
from sqlalchemy.exc import IntegrityError
from datetime import datetime, timedelta, date, timezone
import structlog
from .base import ForecastingBaseRepository
from app.models.forecasts import Forecast
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class ForecastRepository(ForecastingBaseRepository):
"""Repository for forecast operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Forecasts are relatively stable, medium cache time (10 minutes)
super().__init__(Forecast, session, cache_ttl)
async def create_forecast(self, forecast_data: Dict[str, Any]) -> Forecast:
"""
Create a new forecast with validation.
Handles duplicate forecast race condition gracefully:
If a forecast already exists for the same (tenant, product, date, location),
it will be updated instead of creating a duplicate.
"""
try:
# Validate forecast data
validation_result = self._validate_forecast_data(
forecast_data,
["tenant_id", "inventory_product_id", "location", "forecast_date",
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid forecast data: {validation_result['errors']}")
# Set default values
if "confidence_level" not in forecast_data:
forecast_data["confidence_level"] = 0.8
if "algorithm" not in forecast_data:
forecast_data["algorithm"] = "prophet"
if "business_type" not in forecast_data:
forecast_data["business_type"] = "individual"
# Try to create forecast
try:
forecast = await self.create(forecast_data)
logger.info("Forecast created successfully",
forecast_id=forecast.id,
tenant_id=forecast.tenant_id,
inventory_product_id=forecast.inventory_product_id,
forecast_date=forecast.forecast_date.isoformat())
return forecast
except IntegrityError as ie:
# Handle unique constraint violation (duplicate forecast)
error_msg = str(ie).lower()
if "unique constraint" in error_msg or "duplicate" in error_msg or "uq_forecast_tenant_product_date_location" in error_msg:
logger.warning("Forecast already exists (race condition), updating instead",
tenant_id=forecast_data.get("tenant_id"),
inventory_product_id=forecast_data.get("inventory_product_id"),
forecast_date=str(forecast_data.get("forecast_date")))
# Rollback the failed insert
await self.session.rollback()
# Fetch the existing forecast
existing_forecast = await self.get_existing_forecast(
tenant_id=forecast_data["tenant_id"],
inventory_product_id=forecast_data["inventory_product_id"],
forecast_date=forecast_data["forecast_date"],
location=forecast_data["location"]
)
if existing_forecast:
# Update existing forecast with new prediction data
update_data = {
"predicted_demand": forecast_data["predicted_demand"],
"confidence_lower": forecast_data["confidence_lower"],
"confidence_upper": forecast_data["confidence_upper"],
"confidence_level": forecast_data.get("confidence_level", 0.8),
"model_id": forecast_data["model_id"],
"model_version": forecast_data.get("model_version"),
"algorithm": forecast_data.get("algorithm", "prophet"),
"processing_time_ms": forecast_data.get("processing_time_ms"),
"features_used": forecast_data.get("features_used"),
"weather_temperature": forecast_data.get("weather_temperature"),
"weather_precipitation": forecast_data.get("weather_precipitation"),
"weather_description": forecast_data.get("weather_description"),
}
updated_forecast = await self.update(str(existing_forecast.id), update_data)
logger.info("Existing forecast updated after duplicate detection",
forecast_id=updated_forecast.id,
tenant_id=updated_forecast.tenant_id,
inventory_product_id=updated_forecast.inventory_product_id)
return updated_forecast
else:
# This shouldn't happen, but log it
logger.error("Duplicate forecast detected but not found in database")
raise DatabaseError("Duplicate forecast detected but not found")
else:
# Different integrity error, re-raise
raise
except ValidationError:
raise
except IntegrityError as ie:
# Re-raise integrity errors that weren't handled above
logger.error("Database integrity error creating forecast",
tenant_id=forecast_data.get("tenant_id"),
error=str(ie))
raise DatabaseError(f"Database integrity error: {str(ie)}")
except Exception as e:
logger.error("Failed to create forecast",
tenant_id=forecast_data.get("tenant_id"),
inventory_product_id=forecast_data.get("inventory_product_id"),
error=str(e))
raise DatabaseError(f"Failed to create forecast: {str(e)}")
async def get_existing_forecast(
self,
tenant_id: str,
inventory_product_id: str,
forecast_date: datetime,
location: str
) -> Optional[Forecast]:
"""Get an existing forecast by unique key (tenant, product, date, location)"""
try:
query = select(Forecast).where(
and_(
Forecast.tenant_id == tenant_id,
Forecast.inventory_product_id == inventory_product_id,
Forecast.forecast_date == forecast_date,
Forecast.location == location
)
)
result = await self.session.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error("Failed to get existing forecast", error=str(e))
return None
async def get_forecasts_by_date_range(
self,
tenant_id: str,
start_date: date,
end_date: date,
inventory_product_id: str = None,
location: str = None
) -> List[Forecast]:
"""Get forecasts within a date range"""
try:
# Convert dates to datetime bounds for comparison
start_datetime = datetime.combine(start_date, datetime.min.time())
end_datetime = datetime.combine(end_date, datetime.max.time())
query = select(Forecast).where(
and_(
Forecast.tenant_id == tenant_id,
Forecast.forecast_date >= start_datetime,
Forecast.forecast_date <= end_datetime
)
)
# Apply the optional filters (previously built but never passed on)
if inventory_product_id:
query = query.where(Forecast.inventory_product_id == inventory_product_id)
if location:
query = query.where(Forecast.location == location)
query = query.order_by(desc(Forecast.forecast_date))
result = await self.session.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error("Failed to get forecasts by date range",
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
error=str(e))
raise DatabaseError(f"Failed to get forecasts: {str(e)}")
async def get_latest_forecast_for_product(
self,
tenant_id: str,
inventory_product_id: str,
location: str = None
) -> Optional[Forecast]:
"""Get the most recent forecast for a product"""
try:
filters = {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id
}
if location:
filters["location"] = location
forecasts = await self.get_multi(
filters=filters,
limit=1,
order_by="forecast_date",
order_desc=True
)
return forecasts[0] if forecasts else None
except Exception as e:
logger.error("Failed to get latest forecast for product",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
raise DatabaseError(f"Failed to get latest forecast: {str(e)}")
async def get_forecasts_for_date(
self,
tenant_id: str,
forecast_date: date,
inventory_product_id: str = None
) -> List[Forecast]:
"""Get all forecasts for a specific date"""
try:
# Delegate to get_forecasts_by_date so the optional product filter is actually applied
return await self.get_forecasts_by_date(
tenant_id=tenant_id,
forecast_date=forecast_date,
inventory_product_id=inventory_product_id
)
except Exception as e:
logger.error("Failed to get forecasts for date",
tenant_id=tenant_id,
forecast_date=forecast_date,
error=str(e))
raise DatabaseError(f"Failed to get forecasts for date: {str(e)}")
async def get_forecast_accuracy_metrics(
self,
tenant_id: str,
inventory_product_id: str = None,
days_back: int = 30
) -> Dict[str, Any]:
"""Get forecast accuracy metrics"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
# Build base query conditions
conditions = ["tenant_id = :tenant_id", "forecast_date >= :cutoff_date"]
params = {
"tenant_id": tenant_id,
"cutoff_date": cutoff_date
}
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
query_text = f"""
SELECT
COUNT(*) as total_forecasts,
AVG(predicted_demand) as avg_predicted_demand,
MIN(predicted_demand) as min_predicted_demand,
MAX(predicted_demand) as max_predicted_demand,
AVG(confidence_upper - confidence_lower) as avg_confidence_interval,
AVG(processing_time_ms) as avg_processing_time_ms,
COUNT(DISTINCT inventory_product_id) as unique_products,
COUNT(DISTINCT model_id) as unique_models
FROM forecasts
WHERE {' AND '.join(conditions)}
"""
result = await self.session.execute(text(query_text), params)
row = result.fetchone()
if row and row.total_forecasts > 0:
return {
"total_forecasts": int(row.total_forecasts),
"avg_predicted_demand": float(row.avg_predicted_demand or 0),
"min_predicted_demand": float(row.min_predicted_demand or 0),
"max_predicted_demand": float(row.max_predicted_demand or 0),
"avg_confidence_interval": float(row.avg_confidence_interval or 0),
"avg_processing_time_ms": float(row.avg_processing_time_ms or 0),
"unique_products": int(row.unique_products or 0),
"unique_models": int(row.unique_models or 0),
"period_days": days_back
}
return {
"total_forecasts": 0,
"avg_predicted_demand": 0.0,
"min_predicted_demand": 0.0,
"max_predicted_demand": 0.0,
"avg_confidence_interval": 0.0,
"avg_processing_time_ms": 0.0,
"unique_products": 0,
"unique_models": 0,
"period_days": days_back
}
except Exception as e:
logger.error("Failed to get forecast accuracy metrics",
tenant_id=tenant_id,
error=str(e))
return {
"total_forecasts": 0,
"avg_predicted_demand": 0.0,
"min_predicted_demand": 0.0,
"max_predicted_demand": 0.0,
"avg_confidence_interval": 0.0,
"avg_processing_time_ms": 0.0,
"unique_products": 0,
"unique_models": 0,
"period_days": days_back
}
async def get_demand_trends(
self,
tenant_id: str,
inventory_product_id: str,
days_back: int = 30
) -> Dict[str, Any]:
"""Get demand trends for a product"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
query_text = """
SELECT
DATE(forecast_date) as date,
AVG(predicted_demand) as avg_demand,
MIN(predicted_demand) as min_demand,
MAX(predicted_demand) as max_demand,
COUNT(*) as forecast_count
FROM forecasts
WHERE tenant_id = :tenant_id
AND inventory_product_id = :inventory_product_id
AND forecast_date >= :cutoff_date
GROUP BY DATE(forecast_date)
ORDER BY date DESC
"""
result = await self.session.execute(text(query_text), {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"cutoff_date": cutoff_date
})
trends = []
for row in result.fetchall():
trends.append({
"date": row.date.isoformat() if row.date else None,
"avg_demand": float(row.avg_demand),
"min_demand": float(row.min_demand),
"max_demand": float(row.max_demand),
"forecast_count": int(row.forecast_count)
})
# Calculate overall trend direction (recent week vs oldest week, ~5% band counts as stable)
if len(trends) >= 2:
recent_avg = sum(t["avg_demand"] for t in trends[:7]) / min(7, len(trends))
older_slice = trends[-7:]
older_avg = sum(t["avg_demand"] for t in older_slice) / len(older_slice)
if older_avg > 0 and abs(recent_avg - older_avg) / older_avg < 0.05:
trend_direction = "stable"
elif recent_avg > older_avg:
trend_direction = "increasing"
else:
trend_direction = "decreasing"
else:
trend_direction = "stable"
return {
"inventory_product_id": inventory_product_id,
"period_days": days_back,
"trends": trends,
"trend_direction": trend_direction,
"total_data_points": len(trends)
}
except Exception as e:
logger.error("Failed to get demand trends",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
return {
"inventory_product_id": inventory_product_id,
"period_days": days_back,
"trends": [],
"trend_direction": "unknown",
"total_data_points": 0
}
async def get_model_usage_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get statistics about model usage"""
try:
# Get model usage counts
model_query = text("""
SELECT
model_id,
algorithm,
COUNT(*) as usage_count,
AVG(predicted_demand) as avg_prediction,
MAX(forecast_date) as last_used,
COUNT(DISTINCT inventory_product_id) as products_covered
FROM forecasts
WHERE tenant_id = :tenant_id
GROUP BY model_id, algorithm
ORDER BY usage_count DESC
""")
result = await self.session.execute(model_query, {"tenant_id": tenant_id})
model_stats = []
for row in result.fetchall():
model_stats.append({
"model_id": row.model_id,
"algorithm": row.algorithm,
"usage_count": int(row.usage_count),
"avg_prediction": float(row.avg_prediction),
"last_used": row.last_used.isoformat() if row.last_used else None,
"products_covered": int(row.products_covered)
})
# Get algorithm distribution
algorithm_query = text("""
SELECT algorithm, COUNT(*) as count
FROM forecasts
WHERE tenant_id = :tenant_id
GROUP BY algorithm
""")
algorithm_result = await self.session.execute(algorithm_query, {"tenant_id": tenant_id})
algorithm_distribution = {row.algorithm: row.count for row in algorithm_result.fetchall()}
return {
"model_statistics": model_stats,
"algorithm_distribution": algorithm_distribution,
"total_unique_models": len(model_stats)
}
except Exception as e:
logger.error("Failed to get model usage statistics",
tenant_id=tenant_id,
error=str(e))
return {
"model_statistics": [],
"algorithm_distribution": {},
"total_unique_models": 0
}
async def cleanup_old_forecasts(self, days_old: int = 90) -> int:
"""Clean up old forecasts"""
return await self.cleanup_old_records(days_old=days_old)
async def get_forecast_summary(self, tenant_id: str) -> Dict[str, Any]:
"""Get comprehensive forecast summary for a tenant"""
try:
# Get basic statistics
basic_stats = await self.get_statistics_by_tenant(tenant_id)
# Get accuracy metrics
accuracy_metrics = await self.get_forecast_accuracy_metrics(tenant_id)
# Get model usage
model_usage = await self.get_model_usage_statistics(tenant_id)
# Get recent activity
recent_forecasts = await self.get_recent_records(tenant_id, hours=24)
return {
"tenant_id": tenant_id,
"basic_statistics": basic_stats,
"accuracy_metrics": accuracy_metrics,
"model_usage": model_usage,
"recent_activity": {
"forecasts_last_24h": len(recent_forecasts),
"latest_forecast": recent_forecasts[0].forecast_date.isoformat() if recent_forecasts else None
}
}
except Exception as e:
logger.error("Failed to get forecast summary",
tenant_id=tenant_id,
error=str(e))
return {"error": f"Failed to get forecast summary: {str(e)}"}
async def get_forecasts_by_date(
self,
tenant_id: str,
forecast_date: date,
inventory_product_id: str = None
) -> List[Forecast]:
"""
Get all forecasts for a specific date.
Used for forecast validation against actual sales.
Args:
tenant_id: Tenant UUID
forecast_date: Date to get forecasts for
inventory_product_id: Optional product filter
Returns:
List of forecasts for the date
"""
try:
query = select(Forecast).where(
and_(
Forecast.tenant_id == tenant_id,
func.date(Forecast.forecast_date) == forecast_date
)
)
if inventory_product_id:
query = query.where(Forecast.inventory_product_id == inventory_product_id)
result = await self.session.execute(query)
forecasts = result.scalars().all()
logger.info("Retrieved forecasts by date",
tenant_id=tenant_id,
forecast_date=forecast_date.isoformat(),
count=len(forecasts))
return list(forecasts)
except Exception as e:
logger.error("Failed to get forecasts by date",
tenant_id=tenant_id,
forecast_date=forecast_date.isoformat(),
error=str(e))
raise DatabaseError(f"Failed to get forecasts: {str(e)}")
async def bulk_create_forecasts(self, forecasts_data: List[Dict[str, Any]]) -> List[Forecast]:
"""Bulk create multiple forecasts"""
try:
created_forecasts = []
for forecast_data in forecasts_data:
# Validate each forecast
validation_result = self._validate_forecast_data(
forecast_data,
["tenant_id", "inventory_product_id", "location", "forecast_date",
"predicted_demand", "confidence_lower", "confidence_upper", "model_id"]
)
if not validation_result["is_valid"]:
logger.warning("Skipping invalid forecast data",
errors=validation_result["errors"],
data=forecast_data)
continue
forecast = await self.create(forecast_data)
created_forecasts.append(forecast)
logger.info("Bulk created forecasts",
requested_count=len(forecasts_data),
created_count=len(created_forecasts))
return created_forecasts
except Exception as e:
logger.error("Failed to bulk create forecasts",
requested_count=len(forecasts_data),
error=str(e))
raise DatabaseError(f"Bulk forecast creation failed: {str(e)}")

View File

@@ -0,0 +1,214 @@
# services/forecasting/app/repositories/forecasting_alert_repository.py
"""
Forecasting Alert Repository
Data access layer for forecasting-specific alert detection and analysis
"""
from typing import List, Dict, Any
from uuid import UUID
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
import structlog
logger = structlog.get_logger()
class ForecastingAlertRepository:
"""Repository for forecasting alert data access"""
def __init__(self, session: AsyncSession):
self.session = session
async def get_weekend_demand_surges(self, tenant_id: UUID) -> List[Dict[str, Any]]:
"""
Get predicted weekend demand surges
Returns forecasts showing significant growth over previous weeks
"""
try:
query = text("""
WITH weekend_forecast AS (
SELECT
f.tenant_id,
f.inventory_product_id,
f.product_name,
f.predicted_demand,
f.forecast_date,
LAG(f.predicted_demand, 7) OVER (
PARTITION BY f.tenant_id, f.inventory_product_id
ORDER BY f.forecast_date
) as prev_week_demand,
AVG(f.predicted_demand) OVER (
PARTITION BY f.tenant_id, f.inventory_product_id
ORDER BY f.forecast_date
ROWS BETWEEN 6 PRECEDING AND CURRENT ROW
) as avg_weekly_demand
FROM forecasts f
WHERE f.forecast_date >= CURRENT_DATE + INTERVAL '1 day'
AND f.forecast_date <= CURRENT_DATE + INTERVAL '3 days'
AND EXTRACT(DOW FROM f.forecast_date) IN (6, 0)
AND f.tenant_id = :tenant_id
),
surge_analysis AS (
SELECT *,
CASE
WHEN prev_week_demand > 0 THEN
(predicted_demand - prev_week_demand) / prev_week_demand * 100
ELSE 0
END as growth_percentage,
CASE
WHEN avg_weekly_demand > 0 THEN
(predicted_demand - avg_weekly_demand) / avg_weekly_demand * 100
ELSE 0
END as avg_growth_percentage
FROM weekend_forecast
)
SELECT * FROM surge_analysis
WHERE growth_percentage > 50 OR avg_growth_percentage > 50
ORDER BY growth_percentage DESC
""")
result = await self.session.execute(query, {"tenant_id": tenant_id})
return [dict(row._mapping) for row in result.fetchall()]
except Exception as e:
logger.error("Failed to get weekend demand surges", error=str(e), tenant_id=str(tenant_id))
raise
async def get_weather_impact_forecasts(self, tenant_id: UUID) -> List[Dict[str, Any]]:
"""
Get weather impact on demand forecasts
Returns forecasts with rain or significant demand changes
"""
try:
query = text("""
WITH weather_impact AS (
SELECT
f.tenant_id,
f.inventory_product_id,
f.product_name,
f.predicted_demand,
f.forecast_date,
f.weather_precipitation,
f.weather_temperature,
f.traffic_volume,
AVG(f.predicted_demand) OVER (
PARTITION BY f.tenant_id, f.inventory_product_id
ORDER BY f.forecast_date
ROWS BETWEEN 6 PRECEDING AND CURRENT ROW
) as avg_demand
FROM forecasts f
WHERE f.forecast_date >= CURRENT_DATE + INTERVAL '1 day'
AND f.forecast_date <= CURRENT_DATE + INTERVAL '2 days'
AND f.tenant_id = :tenant_id
),
rain_impact AS (
SELECT *,
CASE
WHEN weather_precipitation > 2.0 THEN true
ELSE false
END as rain_forecast,
CASE
WHEN traffic_volume < 80 THEN true
ELSE false
END as low_traffic_expected,
(predicted_demand - avg_demand) / NULLIF(avg_demand, 0) * 100 as demand_change
FROM weather_impact
)
SELECT * FROM rain_impact
WHERE rain_forecast = true OR demand_change < -15
ORDER BY demand_change ASC
""")
result = await self.session.execute(query, {"tenant_id": tenant_id})
return [dict(row._mapping) for row in result.fetchall()]
except Exception as e:
logger.error("Failed to get weather impact forecasts", error=str(e), tenant_id=str(tenant_id))
raise
async def get_holiday_demand_spikes(self, tenant_id: UUID) -> List[Dict[str, Any]]:
"""
Get historical holiday demand spike analysis
Returns products with significant holiday demand increases
"""
try:
query = text("""
WITH holiday_demand AS (
SELECT
f.tenant_id,
f.inventory_product_id,
f.product_name,
AVG(CASE WHEN f.is_holiday = true THEN f.predicted_demand END) as avg_holiday_demand,
AVG(CASE WHEN f.is_holiday = false THEN f.predicted_demand END) as avg_normal_demand,
COUNT(*) as forecast_count
FROM forecasts f
WHERE f.created_at > CURRENT_DATE - INTERVAL '365 days'
AND f.tenant_id = :tenant_id
GROUP BY f.tenant_id, f.inventory_product_id, f.product_name
HAVING COUNT(*) >= 10
),
demand_spike_analysis AS (
SELECT *,
CASE
WHEN avg_normal_demand > 0 THEN
(avg_holiday_demand - avg_normal_demand) / avg_normal_demand * 100
ELSE 0
END as spike_percentage
FROM holiday_demand
)
SELECT * FROM demand_spike_analysis
WHERE spike_percentage > 25
ORDER BY spike_percentage DESC
""")
result = await self.session.execute(query, {"tenant_id": tenant_id})
return [dict(row._mapping) for row in result.fetchall()]
except Exception as e:
logger.error("Failed to get holiday demand spikes", error=str(e), tenant_id=str(tenant_id))
raise
async def get_demand_pattern_analysis(self, tenant_id: UUID) -> List[Dict[str, Any]]:
"""
Get weekly demand pattern analysis for optimization
Returns products with significant demand variations
"""
try:
query = text("""
WITH weekly_patterns AS (
SELECT
f.tenant_id,
f.inventory_product_id,
f.product_name,
EXTRACT(DOW FROM f.forecast_date) as day_of_week,
AVG(f.predicted_demand) as avg_demand,
STDDEV(f.predicted_demand) as demand_variance,
COUNT(*) as data_points
FROM forecasts f
WHERE f.created_at > CURRENT_DATE - INTERVAL '60 days'
AND f.tenant_id = :tenant_id
GROUP BY f.tenant_id, f.inventory_product_id, f.product_name, EXTRACT(DOW FROM f.forecast_date)
HAVING COUNT(*) >= 5
),
pattern_analysis AS (
SELECT
tenant_id, inventory_product_id, product_name,
MAX(avg_demand) as peak_demand,
MIN(avg_demand) as min_demand,
AVG(avg_demand) as overall_avg,
MAX(avg_demand) - MIN(avg_demand) as demand_range
FROM weekly_patterns
GROUP BY tenant_id, inventory_product_id, product_name
)
SELECT * FROM pattern_analysis
WHERE demand_range > overall_avg * 0.3
AND peak_demand > overall_avg * 1.5
ORDER BY demand_range DESC
""")
result = await self.session.execute(query, {"tenant_id": tenant_id})
return [dict(row._mapping) for row in result.fetchall()]
except Exception as e:
logger.error("Failed to get demand pattern analysis", error=str(e), tenant_id=str(tenant_id))
raise
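# A small consumer sketch turning surge rows into alert payloads; the alert
# shape and the 100% severity threshold are assumptions for illustration.
async def build_surge_alerts(session: AsyncSession, tenant_id: UUID) -> List[Dict[str, Any]]:
    repo = ForecastingAlertRepository(session)
    rows = await repo.get_weekend_demand_surges(tenant_id)
    return [
        {
            "inventory_product_id": str(row["inventory_product_id"]),
            "forecast_date": str(row["forecast_date"]),
            "growth_percentage": float(row["growth_percentage"] or 0),
            "severity": "high" if (row["growth_percentage"] or 0) > 100 else "medium",
        }
        for row in rows
    ]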

View File

@@ -0,0 +1,271 @@
"""
Performance Metric Repository
Repository for model performance metrics in forecasting service
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta, timezone
import structlog
from .base import ForecastingBaseRepository
from app.models.predictions import ModelPerformanceMetric
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PerformanceMetricRepository(ForecastingBaseRepository):
"""Repository for model performance metrics operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
# Performance metrics are stable, longer cache time (15 minutes)
super().__init__(ModelPerformanceMetric, session, cache_ttl)
async def create_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
"""Create a new performance metric"""
try:
# Validate metric data
validation_result = self._validate_forecast_data(
metric_data,
["model_id", "tenant_id", "inventory_product_id", "evaluation_date"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid metric data: {validation_result['errors']}")
metric = await self.create(metric_data)
logger.info("Performance metric created",
metric_id=metric.id,
model_id=metric.model_id,
tenant_id=metric.tenant_id,
inventory_product_id=metric.inventory_product_id)
return metric
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create performance metric",
model_id=metric_data.get("model_id"),
error=str(e))
raise DatabaseError(f"Failed to create metric: {str(e)}")
async def get_metrics_by_model(
self,
model_id: str,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get all metrics for a model"""
try:
return await self.get_multi(
filters={"model_id": model_id},
skip=skip,
limit=limit,
order_by="evaluation_date",
order_desc=True
)
except Exception as e:
logger.error("Failed to get metrics by model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get metrics: {str(e)}")
async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
"""Get the latest performance metric for a model"""
try:
metrics = await self.get_multi(
filters={"model_id": model_id},
limit=1,
order_by="evaluation_date",
order_desc=True
)
return metrics[0] if metrics else None
except Exception as e:
logger.error("Failed to get latest metric for model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get latest metric: {str(e)}")
async def get_performance_trends(
self,
tenant_id: str,
inventory_product_id: str = None,
days: int = 30
) -> Dict[str, Any]:
"""Get performance trends over time"""
try:
start_date = datetime.now(timezone.utc) - timedelta(days=days)
conditions = [
"tenant_id = :tenant_id",
"evaluation_date >= :start_date"
]
params = {
"tenant_id": tenant_id,
"start_date": start_date
}
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
query_text = f"""
SELECT
DATE(evaluation_date) as date,
inventory_product_id,
AVG(mae) as avg_mae,
AVG(mape) as avg_mape,
AVG(rmse) as avg_rmse,
AVG(accuracy_score) as avg_accuracy,
COUNT(*) as measurement_count
FROM model_performance_metrics
WHERE {' AND '.join(conditions)}
GROUP BY DATE(evaluation_date), inventory_product_id
ORDER BY date DESC, inventory_product_id
"""
result = await self.session.execute(text(query_text), params)
trends = []
for row in result.fetchall():
trends.append({
"date": row.date.isoformat() if row.date else None,
"inventory_product_id": row.inventory_product_id,
"metrics": {
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
"avg_rmse": float(row.avg_rmse) if row.avg_rmse else None,
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
},
"measurement_count": int(row.measurement_count)
})
return {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"period_days": days,
"trends": trends,
"total_measurements": len(trends)
}
except Exception as e:
logger.error("Failed to get performance trends",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
return {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"period_days": days,
"trends": [],
"total_measurements": 0
}
async def cleanup_old_metrics(self, days_old: int = 180) -> int:
"""Clean up old performance metrics"""
return await self.cleanup_old_records(days_old=days_old)
async def bulk_create_metrics(self, metrics: List[ModelPerformanceMetric]) -> int:
"""
Bulk insert performance metrics for validation
Args:
metrics: List of ModelPerformanceMetric objects to insert
Returns:
Number of metrics created
"""
try:
if not metrics:
return 0
self.session.add_all(metrics)
await self.session.flush()
logger.info(
"Bulk created performance metrics",
count=len(metrics)
)
return len(metrics)
except Exception as e:
logger.error(
"Failed to bulk create performance metrics",
count=len(metrics),
error=str(e)
)
raise DatabaseError(f"Failed to bulk create metrics: {str(e)}")
async def get_metrics_by_date_range(
self,
tenant_id: str,
start_date: datetime,
end_date: datetime,
inventory_product_id: Optional[str] = None
) -> List[ModelPerformanceMetric]:
"""
Get performance metrics for a date range
Args:
tenant_id: Tenant identifier
start_date: Start of date range
end_date: End of date range
inventory_product_id: Optional product filter
Returns:
List of performance metrics
"""
try:
filters = {
"tenant_id": tenant_id
}
if inventory_product_id:
filters["inventory_product_id"] = inventory_product_id
# Build custom query for date range
query_text = """
SELECT *
FROM model_performance_metrics
WHERE tenant_id = :tenant_id
AND evaluation_date >= :start_date
AND evaluation_date <= :end_date
"""
params = {
"tenant_id": tenant_id,
"start_date": start_date,
"end_date": end_date
}
if inventory_product_id:
query_text += " AND inventory_product_id = :inventory_product_id"
params["inventory_product_id"] = inventory_product_id
query_text += " ORDER BY evaluation_date DESC"
result = await self.session.execute(text(query_text), params)
rows = result.fetchall()
# Convert rows to ModelPerformanceMetric objects
metrics = []
for row in rows:
metric = ModelPerformanceMetric()
for column in row._mapping.keys():
setattr(metric, column, row._mapping[column])
metrics.append(metric)
return metrics
except Exception as e:
logger.error(
"Failed to get metrics by date range",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to get metrics: {str(e)}")

View File

@@ -0,0 +1,388 @@
"""
Prediction Batch Repository
Repository for prediction batch operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta, timezone
import structlog
from .base import ForecastingBaseRepository
from app.models.forecasts import PredictionBatch
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PredictionBatchRepository(ForecastingBaseRepository):
"""Repository for prediction batch operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Batch operations change frequently, shorter cache time (5 minutes)
super().__init__(PredictionBatch, session, cache_ttl)
async def create_batch(self, batch_data: Dict[str, Any]) -> PredictionBatch:
"""Create a new prediction batch"""
try:
# Validate batch data
validation_result = self._validate_forecast_data(
batch_data,
["tenant_id", "batch_name"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid batch data: {validation_result['errors']}")
# Set default values
if "status" not in batch_data:
batch_data["status"] = "pending"
if "forecast_days" not in batch_data:
batch_data["forecast_days"] = 7
if "business_type" not in batch_data:
batch_data["business_type"] = "individual"
batch = await self.create(batch_data)
logger.info("Prediction batch created",
batch_id=batch.id,
tenant_id=batch.tenant_id,
batch_name=batch.batch_name)
return batch
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create prediction batch",
tenant_id=batch_data.get("tenant_id"),
error=str(e))
raise DatabaseError(f"Failed to create batch: {str(e)}")
async def update_batch_progress(
self,
batch_id: str,
completed_products: int = None,
failed_products: int = None,
total_products: int = None,
status: str = None
) -> Optional[PredictionBatch]:
"""Update batch progress"""
try:
update_data = {}
if completed_products is not None:
update_data["completed_products"] = completed_products
if failed_products is not None:
update_data["failed_products"] = failed_products
if total_products is not None:
update_data["total_products"] = total_products
if status:
update_data["status"] = status
if status in ["completed", "failed"]:
update_data["completed_at"] = datetime.now(timezone.utc)
if not update_data:
return await self.get_by_id(batch_id)
updated_batch = await self.update(batch_id, update_data)
logger.debug("Batch progress updated",
batch_id=batch_id,
status=status,
completed=completed_products)
return updated_batch
except Exception as e:
logger.error("Failed to update batch progress",
batch_id=batch_id,
error=str(e))
raise DatabaseError(f"Failed to update batch: {str(e)}")
async def complete_batch(
self,
batch_id: str,
processing_time_ms: int = None
) -> Optional[PredictionBatch]:
"""Mark batch as completed"""
try:
update_data = {
"status": "completed",
"completed_at": datetime.now(timezone.utc)
}
if processing_time_ms:
update_data["processing_time_ms"] = processing_time_ms
updated_batch = await self.update(batch_id, update_data)
logger.info("Batch completed",
batch_id=batch_id,
processing_time_ms=processing_time_ms)
return updated_batch
except Exception as e:
logger.error("Failed to complete batch",
batch_id=batch_id,
error=str(e))
raise DatabaseError(f"Failed to complete batch: {str(e)}")
async def fail_batch(
self,
batch_id: str,
error_message: str,
processing_time_ms: int = None
) -> Optional[PredictionBatch]:
"""Mark batch as failed"""
try:
update_data = {
"status": "failed",
"completed_at": datetime.now(timezone.utc),
"error_message": error_message
}
if processing_time_ms:
update_data["processing_time_ms"] = processing_time_ms
updated_batch = await self.update(batch_id, update_data)
logger.error("Batch failed",
batch_id=batch_id,
error_message=error_message)
return updated_batch
except Exception as e:
logger.error("Failed to mark batch as failed",
batch_id=batch_id,
error=str(e))
raise DatabaseError(f"Failed to fail batch: {str(e)}")
async def cancel_batch(
self,
batch_id: str,
cancelled_by: str = None
) -> Optional[PredictionBatch]:
"""Cancel a batch"""
try:
batch = await self.get_by_id(batch_id)
if not batch:
return None
if batch.status in ["completed", "failed"]:
logger.warning("Cannot cancel finished batch",
batch_id=batch_id,
status=batch.status)
return batch
update_data = {
"status": "cancelled",
"completed_at": datetime.now(timezone.utc),
"cancelled_by": cancelled_by,
"error_message": f"Cancelled by {cancelled_by}" if cancelled_by else "Cancelled"
}
updated_batch = await self.update(batch_id, update_data)
logger.info("Batch cancelled",
batch_id=batch_id,
cancelled_by=cancelled_by)
return updated_batch
except Exception as e:
logger.error("Failed to cancel batch",
batch_id=batch_id,
error=str(e))
raise DatabaseError(f"Failed to cancel batch: {str(e)}")
async def get_active_batches(self, tenant_id: str = None) -> List[PredictionBatch]:
"""Get currently active (pending/processing) batches"""
try:
filters = {"status": "processing"}
if tenant_id:
# Need to handle multiple status values with raw query
query_text = """
SELECT * FROM prediction_batches
WHERE status IN ('pending', 'processing')
AND tenant_id = :tenant_id
ORDER BY requested_at DESC
"""
params = {"tenant_id": tenant_id}
else:
query_text = """
SELECT * FROM prediction_batches
WHERE status IN ('pending', 'processing')
ORDER BY requested_at DESC
"""
params = {}
result = await self.session.execute(text(query_text), params)
batches = []
for row in result.fetchall():
record_dict = dict(row._mapping)
batch = self.model(**record_dict)
batches.append(batch)
return batches
except Exception as e:
logger.error("Failed to get active batches",
tenant_id=tenant_id,
error=str(e))
return []
async def get_batch_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get batch processing statistics"""
try:
base_filter = "WHERE 1=1"
params = {}
if tenant_id:
base_filter = "WHERE tenant_id = :tenant_id"
params["tenant_id"] = tenant_id
# Get counts by status
status_query = text(f"""
SELECT
status,
COUNT(*) as count,
AVG(CASE WHEN processing_time_ms IS NOT NULL THEN processing_time_ms END) as avg_processing_time_ms
FROM prediction_batches
{base_filter}
GROUP BY status
""")
result = await self.session.execute(status_query, params)
status_stats = {}
total_batches = 0
avg_processing_times = {}
for row in result.fetchall():
status_stats[row.status] = row.count
total_batches += row.count
if row.avg_processing_time_ms:
avg_processing_times[row.status] = float(row.avg_processing_time_ms)
# Get recent activity (batches in last 7 days)
seven_days_ago = datetime.now(timezone.utc) - timedelta(days=7)
recent_query = text(f"""
SELECT COUNT(*) as count
FROM prediction_batches
{base_filter}
AND requested_at >= :seven_days_ago
""")
recent_result = await self.session.execute(recent_query, {
**params,
"seven_days_ago": seven_days_ago
})
recent_batches = recent_result.scalar() or 0
# Calculate success rate
completed = status_stats.get("completed", 0)
failed = status_stats.get("failed", 0)
cancelled = status_stats.get("cancelled", 0)
finished_batches = completed + failed + cancelled
success_rate = (completed / finished_batches * 100) if finished_batches > 0 else 0
return {
"total_batches": total_batches,
"batches_by_status": status_stats,
"success_rate": round(success_rate, 2),
"recent_batches_7d": recent_batches,
"avg_processing_times_ms": avg_processing_times
}
except Exception as e:
logger.error("Failed to get batch statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_batches": 0,
"batches_by_status": {},
"success_rate": 0.0,
"recent_batches_7d": 0,
"avg_processing_times_ms": {}
}
async def cleanup_old_batches(self, days_old: int = 30) -> int:
"""Clean up old completed/failed batches"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
query_text = """
DELETE FROM prediction_batches
WHERE status IN ('completed', 'failed', 'cancelled')
AND completed_at < :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info("Cleaned up old prediction batches",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old batches",
error=str(e))
raise DatabaseError(f"Batch cleanup failed: {str(e)}")
async def get_batch_details(self, batch_id: str) -> Dict[str, Any]:
"""Get detailed batch information"""
try:
batch = await self.get_by_id(batch_id)
if not batch:
return {"error": "Batch not found"}
# Calculate completion percentage
completion_percentage = 0
if batch.total_products > 0:
completion_percentage = (batch.completed_products / batch.total_products) * 100
# Calculate elapsed time
elapsed_time_ms = 0
if batch.completed_at:
elapsed_time_ms = int((batch.completed_at - batch.requested_at).total_seconds() * 1000)
elif batch.status in ["pending", "processing"]:
elapsed_time_ms = int((datetime.now(timezone.utc) - batch.requested_at).total_seconds() * 1000)
return {
"batch_id": str(batch.id),
"tenant_id": str(batch.tenant_id),
"batch_name": batch.batch_name,
"status": batch.status,
"progress": {
"total_products": batch.total_products,
"completed_products": batch.completed_products,
"failed_products": batch.failed_products,
"completion_percentage": round(completion_percentage, 2)
},
"timing": {
"requested_at": batch.requested_at.isoformat(),
"completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
"elapsed_time_ms": elapsed_time_ms,
"processing_time_ms": batch.processing_time_ms
},
"configuration": {
"forecast_days": batch.forecast_days,
"business_type": batch.business_type
},
"error_message": batch.error_message,
"cancelled_by": batch.cancelled_by
}
except Exception as e:
logger.error("Failed to get batch details",
batch_id=batch_id,
error=str(e))
return {"error": f"Failed to get batch details: {str(e)}"}

View File

@@ -0,0 +1,302 @@
"""
Prediction Cache Repository
Repository for prediction cache operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta, timezone
import structlog
import hashlib
from .base import ForecastingBaseRepository
from app.models.predictions import PredictionCache
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PredictionCacheRepository(ForecastingBaseRepository):
"""Repository for prediction cache operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 60):
# Cache entries change very frequently, short cache time (1 minute)
super().__init__(PredictionCache, session, cache_ttl)
def _generate_cache_key(
self,
tenant_id: str,
inventory_product_id: str,
location: str,
forecast_date: datetime
) -> str:
"""Generate cache key for prediction"""
key_data = f"{tenant_id}:{inventory_product_id}:{location}:{forecast_date.isoformat()}"
return hashlib.md5(key_data.encode()).hexdigest()
async def cache_prediction(
self,
tenant_id: str,
inventory_product_id: str,
location: str,
forecast_date: datetime,
predicted_demand: float,
confidence_lower: float,
confidence_upper: float,
model_id: str,
expires_in_hours: int = 24
) -> PredictionCache:
"""Cache a prediction result"""
try:
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
expires_at = datetime.now(timezone.utc) + timedelta(hours=expires_in_hours)
cache_data = {
"cache_key": cache_key,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"location": location,
"forecast_date": forecast_date,
"predicted_demand": predicted_demand,
"confidence_lower": confidence_lower,
"confidence_upper": confidence_upper,
"model_id": model_id,
"expires_at": expires_at,
"hit_count": 0
}
# Try to update existing cache entry first
existing_cache = await self.get_by_field("cache_key", cache_key)
if existing_cache:
cache_entry = await self.update(existing_cache.id, cache_data)
logger.debug("Updated cache entry", cache_key=cache_key)
else:
cache_entry = await self.create(cache_data)
logger.debug("Created cache entry", cache_key=cache_key)
return cache_entry
except Exception as e:
logger.error("Failed to cache prediction",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
raise DatabaseError(f"Failed to cache prediction: {str(e)}")
async def get_cached_prediction(
self,
tenant_id: str,
inventory_product_id: str,
location: str,
forecast_date: datetime
) -> Optional[PredictionCache]:
"""Get cached prediction if valid"""
try:
cache_key = self._generate_cache_key(tenant_id, inventory_product_id, location, forecast_date)
cache_entry = await self.get_by_field("cache_key", cache_key)
if not cache_entry:
logger.debug("Cache miss", cache_key=cache_key)
return None
# Check if cache entry has expired
if cache_entry.expires_at < datetime.now(timezone.utc):
logger.debug("Cache expired", cache_key=cache_key)
await self.delete(cache_entry.id)
return None
# Increment hit count
await self.update(cache_entry.id, {"hit_count": cache_entry.hit_count + 1})
logger.debug("Cache hit",
cache_key=cache_key,
hit_count=cache_entry.hit_count + 1)
return cache_entry
except Exception as e:
logger.error("Failed to get cached prediction",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
error=str(e))
return None
async def invalidate_cache(
self,
tenant_id: str,
inventory_product_id: str = None,
location: str = None
) -> int:
"""Invalidate cache entries"""
try:
conditions = ["tenant_id = :tenant_id"]
params = {"tenant_id": tenant_id}
if inventory_product_id:
conditions.append("inventory_product_id = :inventory_product_id")
params["inventory_product_id"] = inventory_product_id
if location:
conditions.append("location = :location")
params["location"] = location
query_text = f"""
DELETE FROM prediction_cache
WHERE {' AND '.join(conditions)}
"""
result = await self.session.execute(text(query_text), params)
invalidated_count = result.rowcount
logger.info("Cache invalidated",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
location=location,
invalidated_count=invalidated_count)
return invalidated_count
except Exception as e:
logger.error("Failed to invalidate cache",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Cache invalidation failed: {str(e)}")
async def cleanup_expired_cache(self) -> int:
"""Clean up expired cache entries"""
try:
query_text = """
DELETE FROM prediction_cache
WHERE expires_at < :now
"""
result = await self.session.execute(text(query_text), {"now": datetime.now(timezone.utc)})
deleted_count = result.rowcount
logger.info("Cleaned up expired cache entries",
deleted_count=deleted_count)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup expired cache",
error=str(e))
raise DatabaseError(f"Cache cleanup failed: {str(e)}")
async def get_cache_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get cache performance statistics"""
try:
base_filter = "WHERE 1=1"
params = {}
if tenant_id:
base_filter = "WHERE tenant_id = :tenant_id"
params["tenant_id"] = tenant_id
# Get cache statistics
stats_query = text(f"""
SELECT
COUNT(*) as total_entries,
COUNT(CASE WHEN expires_at > :now THEN 1 END) as active_entries,
COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_entries,
SUM(hit_count) as total_hits,
AVG(hit_count) as avg_hits_per_entry,
MAX(hit_count) as max_hits,
COUNT(DISTINCT inventory_product_id) as unique_products
FROM prediction_cache
{base_filter}
""")
params["now"] = datetime.now(timezone.utc)
result = await self.session.execute(stats_query, params)
row = result.fetchone()
if row:
return {
"total_entries": int(row.total_entries or 0),
"active_entries": int(row.active_entries or 0),
"expired_entries": int(row.expired_entries or 0),
"total_hits": int(row.total_hits or 0),
"avg_hits_per_entry": float(row.avg_hits_per_entry or 0),
"max_hits": int(row.max_hits or 0),
"unique_products": int(row.unique_products or 0),
"cache_hit_ratio": round((row.total_hits / max(row.total_entries, 1)), 2)
}
return {
"total_entries": 0,
"active_entries": 0,
"expired_entries": 0,
"total_hits": 0,
"avg_hits_per_entry": 0.0,
"max_hits": 0,
"unique_products": 0,
"cache_hit_ratio": 0.0
}
except Exception as e:
logger.error("Failed to get cache statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_entries": 0,
"active_entries": 0,
"expired_entries": 0,
"total_hits": 0,
"avg_hits_per_entry": 0.0,
"max_hits": 0,
"unique_products": 0,
"cache_hit_ratio": 0.0
}
async def get_most_accessed_predictions(
self,
tenant_id: str = None,
limit: int = 10
) -> List[Dict[str, Any]]:
"""Get most frequently accessed cached predictions"""
try:
base_filter = "WHERE hit_count > 0"
params = {"limit": limit}
if tenant_id:
base_filter = "WHERE tenant_id = :tenant_id AND hit_count > 0"
params["tenant_id"] = tenant_id
query_text = f"""
SELECT
inventory_product_id,
location,
hit_count,
predicted_demand,
created_at,
expires_at
FROM prediction_cache
{base_filter}
ORDER BY hit_count DESC
LIMIT :limit
"""
result = await self.session.execute(text(query_text), params)
popular_predictions = []
for row in result.fetchall():
popular_predictions.append({
"inventory_product_id": row.inventory_product_id,
"location": row.location,
"hit_count": int(row.hit_count),
"predicted_demand": float(row.predicted_demand),
"created_at": row.created_at.isoformat() if row.created_at else None,
"expires_at": row.expires_at.isoformat() if row.expires_at else None
})
return popular_predictions
except Exception as e:
logger.error("Failed to get most accessed predictions",
tenant_id=tenant_id,
error=str(e))
return []
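
# A minimal write-then-read sketch for PredictionCacheRepository above, assuming an
# async_session factory and placeholder identifiers (the location and model_id values
# are illustrative, not real configuration).
async def demo_prediction_cache(async_session, tenant_id: str, product_id: str) -> None:
    """Cache a prediction, then read it back through the key-based lookup (illustrative only)."""
    forecast_date = datetime(2026, 1, 22, tzinfo=timezone.utc)
    async with async_session() as session:
        repo = PredictionCacheRepository(session)
        await repo.cache_prediction(
            tenant_id=tenant_id,
            inventory_product_id=product_id,
            location="store-01",   # hypothetical location identifier
            forecast_date=forecast_date,
            predicted_demand=42.0,
            confidence_lower=35.0,
            confidence_upper=49.0,
            model_id="model-123",  # hypothetical model id
            expires_in_hours=24,
        )
        cached = await repo.get_cached_prediction(tenant_id, product_id, "store-01", forecast_date)
        print(cached.predicted_demand if cached else "miss")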

View File

@@ -0,0 +1,302 @@
# ================================================================
# services/forecasting/app/schemas/forecasts.py
# ================================================================
"""
Forecast schemas for request/response validation
"""
from pydantic import BaseModel, Field, validator
from datetime import datetime, date
from typing import Optional, List, Dict, Any
from enum import Enum
from uuid import UUID
class BusinessType(str, Enum):
INDIVIDUAL = "individual"
CENTRAL_WORKSHOP = "central_workshop"
class ForecastRequest(BaseModel):
"""Request schema for generating forecasts"""
inventory_product_id: str = Field(..., description="Inventory product UUID reference")
# product_name: str = Field(..., description="Product name") # DEPRECATED - use inventory_product_id
forecast_date: date = Field(..., description="Starting date for forecast")
forecast_days: int = Field(1, ge=1, le=30, description="Number of days to forecast")
location: str = Field(..., description="Location identifier")
# Optional parameters - internally handled
confidence_level: float = Field(0.8, ge=0.5, le=0.95, description="Confidence level")
@validator('inventory_product_id')
def validate_inventory_product_id(cls, v):
"""Validate that inventory_product_id is a valid UUID"""
try:
UUID(v)
except (ValueError, AttributeError):
raise ValueError(f"inventory_product_id must be a valid UUID, got: {v}")
return v
@validator('forecast_date')
def validate_forecast_date(cls, v):
if v < date.today():
raise ValueError("Forecast date cannot be in the past")
return v
class BatchForecastRequest(BaseModel):
"""Request schema for batch forecasting"""
tenant_id: Optional[str] = None # Optional, can be from path parameter
batch_name: str = Field(..., description="Batch name for tracking")
inventory_product_ids: List[str] = Field(..., description="List of inventory product IDs")
forecast_days: int = Field(7, ge=1, le=30, description="Number of days to forecast")
@validator('tenant_id')
def validate_tenant_id(cls, v):
"""Validate that tenant_id is a valid UUID if provided"""
if v is not None:
try:
UUID(v)
except (ValueError, AttributeError):
raise ValueError(f"tenant_id must be a valid UUID, got: {v}")
return v
@validator('inventory_product_ids')
def validate_inventory_product_ids(cls, v):
"""Validate that all inventory_product_ids are valid UUIDs"""
for product_id in v:
try:
UUID(product_id)
except (ValueError, AttributeError):
raise ValueError(f"All inventory_product_ids must be valid UUIDs, got invalid: {product_id}")
return v
class ForecastResponse(BaseModel):
"""Response schema for forecast results"""
id: str
tenant_id: str
inventory_product_id: str # Reference to inventory service
# product_name: str # Can be fetched from inventory service if needed for display
location: str
forecast_date: datetime
# Predictions
predicted_demand: float
confidence_lower: float
confidence_upper: float
confidence_level: float
# Model info
model_id: str
model_version: str
algorithm: str
# Context
business_type: str
is_holiday: bool
is_weekend: bool
day_of_week: int
# External factors
weather_temperature: Optional[float]
weather_precipitation: Optional[float]
weather_description: Optional[str]
traffic_volume: Optional[int]
# Metadata
created_at: datetime
processing_time_ms: Optional[int]
features_used: Optional[Dict[str, Any]]
class BatchForecastResponse(BaseModel):
"""Response schema for batch forecast requests"""
id: str
tenant_id: str
batch_name: str
status: str
total_products: int
completed_products: int
failed_products: int
# Timing
requested_at: datetime
completed_at: Optional[datetime]
processing_time_ms: Optional[int]
# Results
forecasts: Optional[List[ForecastResponse]]
error_message: Optional[str]
class MultiDayForecastResponse(BaseModel):
"""Response schema for multi-day forecast results"""
tenant_id: str = Field(..., description="Tenant ID")
inventory_product_id: str = Field(..., description="Inventory product ID")
forecast_start_date: date = Field(..., description="Start date of forecast period")
forecast_days: int = Field(..., description="Number of forecasted days")
forecasts: List[ForecastResponse] = Field(..., description="Daily forecasts")
total_predicted_demand: float = Field(..., description="Total demand across all days")
average_confidence_level: float = Field(..., description="Average confidence across all days")
processing_time_ms: int = Field(..., description="Total processing time")
# ================================================================
# SCENARIO SIMULATION SCHEMAS - PROFESSIONAL/ENTERPRISE ONLY
# ================================================================
class ScenarioType(str, Enum):
"""Types of scenarios available for simulation"""
WEATHER = "weather" # Weather impact (heatwave, cold snap, rain, etc.)
COMPETITION = "competition" # New competitor opening nearby
EVENT = "event" # Local event (festival, sports, concert, etc.)
PRICING = "pricing" # Price changes
PROMOTION = "promotion" # Promotional campaigns
HOLIDAY = "holiday" # Holiday periods
SUPPLY_DISRUPTION = "supply_disruption" # Supply chain issues
CUSTOM = "custom" # Custom user-defined scenario
class WeatherScenario(BaseModel):
"""Weather scenario parameters"""
temperature_change: Optional[float] = Field(None, ge=-30, le=30, description="Temperature change in °C")
precipitation_change: Optional[float] = Field(None, ge=0, le=100, description="Precipitation change in mm")
weather_type: Optional[str] = Field(None, description="Weather type (heatwave, cold_snap, rainy, etc.)")
class CompetitionScenario(BaseModel):
"""Competition scenario parameters"""
new_competitors: int = Field(1, ge=1, le=10, description="Number of new competitors")
distance_km: float = Field(0.5, ge=0.1, le=10, description="Distance from location in km")
estimated_market_share_loss: float = Field(0.1, ge=0, le=0.5, description="Estimated market share loss (0-50%)")
class EventScenario(BaseModel):
"""Event scenario parameters"""
event_type: str = Field(..., description="Type of event (festival, sports, concert, etc.)")
expected_attendance: int = Field(..., ge=0, description="Expected attendance")
distance_km: float = Field(0.5, ge=0, le=50, description="Distance from location in km")
duration_days: int = Field(1, ge=1, le=30, description="Duration in days")
class PricingScenario(BaseModel):
"""Pricing scenario parameters"""
price_change_percent: float = Field(..., ge=-50, le=100, description="Price change percentage")
affected_products: Optional[List[str]] = Field(None, description="List of affected product IDs")
class PromotionScenario(BaseModel):
"""Promotion scenario parameters"""
discount_percent: float = Field(..., ge=0, le=75, description="Discount percentage")
promotion_type: str = Field(..., description="Type of promotion (bogo, discount, bundle, etc.)")
expected_traffic_increase: float = Field(0.2, ge=0, le=2, description="Expected traffic increase (0-200%)")
class ScenarioSimulationRequest(BaseModel):
"""Request schema for scenario simulation - PROFESSIONAL/ENTERPRISE ONLY"""
scenario_name: str = Field(..., min_length=3, max_length=200, description="Name for this scenario")
scenario_type: ScenarioType = Field(..., description="Type of scenario to simulate")
inventory_product_ids: List[str] = Field(..., min_items=1, description="Products to simulate")
start_date: date = Field(..., description="Simulation start date")
duration_days: int = Field(7, ge=1, le=30, description="Simulation duration in days")
# Scenario-specific parameters (one should be provided based on scenario_type)
weather_params: Optional[WeatherScenario] = None
competition_params: Optional[CompetitionScenario] = None
event_params: Optional[EventScenario] = None
pricing_params: Optional[PricingScenario] = None
promotion_params: Optional[PromotionScenario] = None
# Custom scenario parameters
custom_multipliers: Optional[Dict[str, float]] = Field(
None,
description="Custom multipliers for baseline forecast (e.g., {'demand': 1.2, 'traffic': 0.8})"
)
# Comparison settings
include_baseline: bool = Field(True, description="Include baseline forecast for comparison")
@validator('start_date')
def validate_start_date(cls, v):
if v < date.today():
raise ValueError("Simulation start date cannot be in the past")
return v
class ScenarioImpact(BaseModel):
"""Impact of scenario on a specific product"""
inventory_product_id: str
baseline_demand: float
simulated_demand: float
demand_change_percent: float
confidence_range: tuple[float, float]
impact_factors: Dict[str, Any] # Breakdown of what drove the change
class ScenarioSimulationResponse(BaseModel):
"""Response schema for scenario simulation"""
id: str = Field(..., description="Simulation ID")
tenant_id: str
scenario_name: str
scenario_type: ScenarioType
# Simulation parameters
start_date: date
end_date: date
duration_days: int
# Results
baseline_forecasts: Optional[List[ForecastResponse]] = Field(
None,
description="Baseline forecasts (if requested)"
)
scenario_forecasts: List[ForecastResponse] = Field(..., description="Forecasts with scenario applied")
# Impact summary
total_baseline_demand: float
total_scenario_demand: float
overall_impact_percent: float
product_impacts: List[ScenarioImpact]
# Insights and recommendations
insights: List[str] = Field(..., description="AI-generated insights about the scenario")
recommendations: List[str] = Field(..., description="Actionable recommendations")
risk_level: str = Field(..., description="Risk level: low, medium, high")
# Metadata
created_at: datetime
processing_time_ms: int
class Config:
json_schema_extra = {
"example": {
"id": "scenario_123",
"tenant_id": "tenant_456",
"scenario_name": "Summer Heatwave Impact",
"scenario_type": "weather",
"overall_impact_percent": 15.5,
"insights": [
"Cold beverages expected to increase by 45%",
"Bread products may decrease by 8% due to reduced appetite",
"Ice cream demand projected to surge by 120%"
],
"recommendations": [
"Increase cold beverage inventory by 40%",
"Reduce bread production by 10%",
"Stock additional ice cream varieties"
],
"risk_level": "medium"
}
}
class ScenarioComparisonRequest(BaseModel):
"""Request to compare multiple scenarios"""
scenario_ids: List[str] = Field(..., min_items=2, max_items=5, description="Scenario IDs to compare")
class ScenarioComparisonResponse(BaseModel):
"""Response comparing multiple scenarios"""
scenarios: List[ScenarioSimulationResponse]
comparison_matrix: Dict[str, Dict[str, Any]]
best_case_scenario_id: str
worst_case_scenario_id: str
recommended_action: str
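
# A short construction sketch for the request schemas above. The UUIDs and the
# location identifier are placeholders; the validators require UUID strings and
# non-past dates, which is why tomorrow's date is used.
from datetime import timedelta
from uuid import uuid4

example_request = ForecastRequest(
    inventory_product_id=str(uuid4()),
    forecast_date=date.today() + timedelta(days=1),
    forecast_days=7,
    location="store-01",  # hypothetical location identifier
)

example_scenario = ScenarioSimulationRequest(
    scenario_name="Summer heatwave next week",
    scenario_type=ScenarioType.WEATHER,
    inventory_product_ids=[str(uuid4())],
    start_date=date.today() + timedelta(days=1),
    duration_days=7,
    weather_params=WeatherScenario(temperature_change=8.0, weather_type="heatwave"),
)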

View File

@@ -0,0 +1,17 @@
"""
Forecasting Service Layer
Business logic services for demand forecasting and prediction
"""
from .forecasting_service import ForecastingService, EnhancedForecastingService
from .prediction_service import PredictionService
from .model_client import ModelClient
from .data_client import DataClient
__all__ = [
"ForecastingService",
"EnhancedForecastingService",
"PredictionService",
"ModelClient",
"DataClient"
]

View File

@@ -0,0 +1,132 @@
# services/forecasting/app/services/data_client.py
"""
Forecasting Service Data Client
Migrated to use the shared service clients, which keeps this module much simpler.
"""
import structlog
from typing import Dict, Any, List, Optional
from datetime import datetime
# Import the shared clients
from shared.clients import get_sales_client, get_external_client, get_service_clients
from app.core.config import settings
logger = structlog.get_logger()
class DataClient:
"""
Data client for training service
Now uses the shared data service client under the hood
"""
def __init__(self):
# Get the new specialized clients
self.sales_client = get_sales_client(settings, "forecasting")
self.external_client = get_external_client(settings, "forecasting")
# Or alternatively, get all clients at once:
# self.clients = get_service_clients(settings, "forecasting")
# Then use: self.clients.sales.get_sales_data(...) and self.clients.external.get_weather_forecast(...)
async def fetch_weather_forecast(
self,
tenant_id: str,
days: int = 7,
latitude: Optional[float] = None,
longitude: Optional[float] = None
) -> List[Dict[str, Any]]:
"""
Fetch weather forecast data
Uses new v2.0 optimized endpoint via shared external client
"""
try:
weather_data = await self.external_client.get_weather_forecast(
tenant_id=tenant_id,
days=days,
latitude=latitude,
longitude=longitude
)
if weather_data:
logger.info(f"Fetched {len(weather_data)} weather records",
tenant_id=tenant_id)
return weather_data
else:
logger.warning("No weather data returned", tenant_id=tenant_id)
return []
except Exception as e:
logger.error(f"Error fetching weather data: {e}", tenant_id=tenant_id)
return []
async def fetch_tenant_calendar(
self,
tenant_id: str
) -> Optional[Dict[str, Any]]:
"""
Fetch tenant's assigned school calendar
Returns None if no calendar assigned
"""
try:
location_context = await self.external_client.get_tenant_location_context(
tenant_id=tenant_id
)
if location_context and location_context.get("calendar"):
logger.info(
"Fetched calendar for tenant",
tenant_id=tenant_id,
calendar_name=location_context["calendar"].get("calendar_name")
)
return location_context["calendar"]
else:
logger.info("No calendar assigned to tenant", tenant_id=tenant_id)
return None
except Exception as e:
logger.error(f"Error fetching calendar: {e}", tenant_id=tenant_id)
return None
async def check_school_holiday(
self,
calendar_id: str,
check_date: str,
tenant_id: str
) -> bool:
"""
Check if a date is a school holiday
Args:
calendar_id: School calendar UUID
check_date: Date in ISO format (YYYY-MM-DD)
tenant_id: Tenant ID for auth
Returns:
True if school holiday, False otherwise
"""
try:
result = await self.external_client.check_is_school_holiday(
calendar_id=calendar_id,
check_date=check_date,
tenant_id=tenant_id
)
if result:
is_holiday = result.get("is_holiday", False)
if is_holiday:
logger.debug(
"School holiday detected",
date=check_date,
holiday_name=result.get("holiday_name")
)
return is_holiday
return False
except Exception as e:
logger.error(f"Error checking school holiday: {e}", date=check_date)
return False
# Global instance - same as before, but much simpler implementation
data_client = DataClient()
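
# A short sketch of how the module-level data_client above might be used when assembling
# forecast features. The calendar payload key names ("id" / "calendar_id") are assumptions;
# only "calendar_name" is referenced in the code above.
async def demo_external_context(tenant_id: str, check_date: str = "2026-01-22") -> dict:
    """Pull weather and school-calendar context for a tenant (illustrative only)."""
    weather = await data_client.fetch_weather_forecast(tenant_id=tenant_id, days=7)
    calendar = await data_client.fetch_tenant_calendar(tenant_id)
    is_school_holiday = False
    if calendar:
        is_school_holiday = await data_client.check_school_holiday(
            calendar_id=calendar.get("id") or calendar.get("calendar_id"),  # assumed key names
            check_date=check_date,
            tenant_id=tenant_id,
        )
    return {"weather_days": len(weather), "is_school_holiday": is_school_holiday}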

View File

@@ -0,0 +1,260 @@
"""
Enterprise forecasting service for aggregated demand across parent-child tenants
"""
import logging
from typing import Dict, Any, List, Optional
from datetime import date, datetime
import json
import redis.asyncio as redis
from shared.clients.forecast_client import ForecastServiceClient
from shared.clients.tenant_client import TenantServiceClient
logger = logging.getLogger(__name__)
class EnterpriseForecastingService:
"""
Service for aggregating forecasts across parent and child tenants
"""
def __init__(
self,
forecast_client: ForecastServiceClient,
tenant_client: TenantServiceClient,
redis_client: redis.Redis
):
self.forecast_client = forecast_client
self.tenant_client = tenant_client
self.redis_client = redis_client
self.cache_ttl_seconds = 3600 # 1 hour TTL
async def get_aggregated_forecast(
self,
parent_tenant_id: str,
start_date: date,
end_date: date,
product_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Get aggregated forecast across parent and all child tenants
Args:
parent_tenant_id: Parent tenant ID
start_date: Start date for forecast aggregation
end_date: End date for forecast aggregation
product_id: Optional product ID to filter by
Returns:
Dict with aggregated forecast data by date and product
"""
# Create cache key
cache_key = f"agg_forecast:{parent_tenant_id}:{start_date}:{end_date}:{product_id or 'all'}"
# Try to get from cache first
try:
cached_result = await self.redis_client.get(cache_key)
if cached_result:
logger.info(f"Cache hit for aggregated forecast: {cache_key}")
return json.loads(cached_result)
except Exception as e:
logger.warning(f"Cache read failed: {e}")
logger.info(f"Computing aggregated forecast for parent {parent_tenant_id} from {start_date} to {end_date}")
# Get child tenant IDs
child_tenants = await self.tenant_client.get_child_tenants(parent_tenant_id)
child_tenant_ids = [child['id'] for child in child_tenants]
# Include parent tenant in the list for complete aggregation
all_tenant_ids = [parent_tenant_id] + child_tenant_ids
# Fetch forecasts for all tenants (parent + children)
all_forecasts = {}
tenant_contributions = {} # Track which tenant contributed to each forecast
for tenant_id in all_tenant_ids:
try:
tenant_forecasts = await self.forecast_client.get_forecasts(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
product_id=product_id
)
for forecast_date_str, products in tenant_forecasts.items():
if forecast_date_str not in all_forecasts:
all_forecasts[forecast_date_str] = {}
tenant_contributions[forecast_date_str] = {}
for product_id_key, forecast_data in products.items():
if product_id_key not in all_forecasts[forecast_date_str]:
all_forecasts[forecast_date_str][product_id_key] = {
'predicted_demand': 0,
'confidence_lower': 0,
'confidence_upper': 0,
'tenant_contributions': []
}
# Aggregate the forecast values
all_forecasts[forecast_date_str][product_id_key]['predicted_demand'] += forecast_data.get('predicted_demand', 0)
# For confidence intervals, we'll use a simple approach
# In a real implementation, this would require proper statistical combination
all_forecasts[forecast_date_str][product_id_key]['confidence_lower'] += forecast_data.get('confidence_lower', 0)
all_forecasts[forecast_date_str][product_id_key]['confidence_upper'] += forecast_data.get('confidence_upper', 0)
# Track contribution by tenant
all_forecasts[forecast_date_str][product_id_key]['tenant_contributions'].append({
'tenant_id': tenant_id,
'demand': forecast_data.get('predicted_demand', 0),
'confidence_lower': forecast_data.get('confidence_lower', 0),
'confidence_upper': forecast_data.get('confidence_upper', 0)
})
except Exception as e:
logger.error(f"Failed to fetch forecasts for tenant {tenant_id}: {e}")
# Continue with other tenants even if one fails
# Prepare result
result = {
"parent_tenant_id": parent_tenant_id,
"aggregated_forecasts": all_forecasts,
"tenant_contributions": tenant_contributions,
"child_tenant_count": len(child_tenant_ids),
"forecast_dates": list(all_forecasts.keys()),
"computed_at": datetime.utcnow().isoformat()
}
# Cache the result
try:
await self.redis_client.setex(
cache_key,
self.cache_ttl_seconds,
json.dumps(result, default=str) # Handle date serialization
)
logger.info(f"Forecast cached for {cache_key}")
except Exception as e:
logger.warning(f"Cache write failed: {e}")
return result
async def get_network_performance_metrics(
self,
parent_tenant_id: str,
start_date: date,
end_date: date
) -> Dict[str, Any]:
"""
Get aggregated performance metrics across the tenant network
Args:
parent_tenant_id: Parent tenant ID
start_date: Start date for metrics
end_date: End date for metrics
Returns:
Dict with aggregated performance metrics
"""
child_tenants = await self.tenant_client.get_child_tenants(parent_tenant_id)
child_tenant_ids = [child['id'] for child in child_tenants]
# Include parent tenant in the list for complete aggregation
all_tenant_ids = [parent_tenant_id] + child_tenant_ids
total_sales = 0
total_forecasted = 0
total_accuracy = 0
tenant_count = 0
performance_data = {}
for tenant_id in all_tenant_ids:
try:
# Fetch sales and forecast data for the period
sales_data = await self._fetch_sales_data(tenant_id, start_date, end_date)
forecast_data = await self.get_aggregated_forecast(tenant_id, start_date, end_date)
tenant_performance = {
'tenant_id': tenant_id,
'sales': sales_data.get('total_sales', 0),
'forecasted': sum(
sum(product.get('predicted_demand', 0) for product in day.values())
if isinstance(day, dict) else day
for day in forecast_data.get('aggregated_forecasts', {}).values()
),
}
# Calculate accuracy if both sales and forecast data exist
if tenant_performance['sales'] > 0 and tenant_performance['forecasted'] > 0:
accuracy = 1 - abs(tenant_performance['forecasted'] - tenant_performance['sales']) / tenant_performance['sales']
tenant_performance['accuracy'] = max(0, min(1, accuracy)) # Clamp between 0 and 1
else:
tenant_performance['accuracy'] = 0
performance_data[tenant_id] = tenant_performance
total_sales += tenant_performance['sales']
total_forecasted += tenant_performance['forecasted']
total_accuracy += tenant_performance['accuracy']
tenant_count += 1
except Exception as e:
logger.error(f"Failed to fetch performance data for tenant {tenant_id}: {e}")
network_performance = {
"parent_tenant_id": parent_tenant_id,
"total_sales": total_sales,
"total_forecasted": total_forecasted,
"average_accuracy": total_accuracy / tenant_count if tenant_count > 0 else 0,
"tenant_count": tenant_count,
"child_tenant_count": len(child_tenant_ids),
"tenant_performances": performance_data,
"computed_at": datetime.utcnow().isoformat()
}
return network_performance
async def _fetch_sales_data(self, tenant_id: str, start_date: date, end_date: date) -> Dict[str, Any]:
"""
Helper method to fetch sales data from the sales service
"""
try:
from shared.clients.sales_client import SalesServiceClient
from shared.config.base import get_settings
# Create sales client
config = get_settings()
sales_client = SalesServiceClient(config, calling_service_name="forecasting")
# Fetch sales data for the date range
sales_data = await sales_client.get_sales_data(
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat(),
aggregation="daily"
)
# Calculate total sales from the retrieved data
total_sales = 0
if sales_data:
for sale in sales_data:
# Sum up quantity_sold or total_amount depending on what's available
total_sales += sale.get('quantity_sold', 0)
return {
'total_sales': total_sales,
'date_range': f"{start_date} to {end_date}",
'tenant_id': tenant_id,
'record_count': len(sales_data) if sales_data else 0
}
except Exception as e:
logger.error(f"Failed to fetch sales data for tenant {tenant_id}: {e}")
# Return empty result on error
return {
'total_sales': 0,
'date_range': f"{start_date} to {end_date}",
'tenant_id': tenant_id,
'error': str(e)
}
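
# A usage sketch for the aggregation service above. Pre-built clients are passed in as
# parameters because the constructor signatures of the shared forecast/tenant clients
# are not shown in this file.
async def demo_network_forecast(forecast_client, tenant_client, redis_client, parent_tenant_id: str) -> None:
    """Aggregate forecasts and performance across a parent tenant and its children (illustrative only)."""
    service = EnterpriseForecastingService(forecast_client, tenant_client, redis_client)
    result = await service.get_aggregated_forecast(
        parent_tenant_id=parent_tenant_id,
        start_date=date(2026, 1, 22),
        end_date=date(2026, 1, 28),
    )
    print(result["child_tenant_count"], len(result["forecast_dates"]))

    metrics = await service.get_network_performance_metrics(
        parent_tenant_id, date(2026, 1, 1), date(2026, 1, 21)
    )
    print(metrics["average_accuracy"])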

View File

@@ -0,0 +1,495 @@
# services/forecasting/app/services/forecast_cache.py
"""
Forecast Cache Service - Redis-based caching for forecast results
Provides service-level caching for forecast predictions to eliminate redundant
computations when multiple services (Orders, Production) request the same
forecast data within a short time window.
Cache Strategy:
- Key: forecast:{tenant_id}:{product_id}:{forecast_date}
- TTL: Until midnight of day after forecast_date
- Invalidation: On model retraining for specific products
- Metadata: Includes 'cached' flag for observability
"""
import hashlib
import json
from datetime import datetime, date, timedelta
from typing import Optional, Dict, Any, List
from uuid import UUID
import structlog
from shared.redis_utils import get_redis_client
logger = structlog.get_logger()
class ForecastCacheService:
"""Service-level caching for forecast predictions"""
def __init__(self):
"""Initialize forecast cache service"""
pass
async def _get_redis(self):
"""Get shared Redis client"""
return await get_redis_client()
async def is_available(self) -> bool:
"""Check if Redis cache is available"""
try:
client = await self._get_redis()
await client.ping()
return True
except Exception:
return False
# ================================================================
# FORECAST CACHING
# ================================================================
def _get_forecast_key(
self,
tenant_id: UUID,
product_id: UUID,
forecast_date: date
) -> str:
"""Generate cache key for forecast"""
return f"forecast:{tenant_id}:{product_id}:{forecast_date.isoformat()}"
def _get_batch_forecast_key(
self,
tenant_id: UUID,
product_ids: List[UUID],
forecast_date: date
) -> str:
"""Generate cache key for batch forecast"""
# Sort product IDs and hash them with md5 so the key is stable across processes
# (the built-in hash() is salted per interpreter run, which would defeat cache hits)
sorted_ids = sorted(str(pid) for pid in product_ids)
products_hash = hashlib.md5(":".join(sorted_ids).encode()).hexdigest()
return f"forecast:batch:{tenant_id}:{products_hash}:{forecast_date.isoformat()}"
def _calculate_ttl(self, forecast_date: date) -> int:
"""
Calculate TTL for forecast cache entry
Forecasts expire at midnight of the day after forecast_date.
This ensures forecasts remain cached throughout the forecasted day
but don't become stale.
Args:
forecast_date: Date of the forecast
Returns:
TTL in seconds
"""
# Expire at midnight after forecast_date
expiry_datetime = datetime.combine(
forecast_date + timedelta(days=1),
datetime.min.time()
)
now = datetime.now()
ttl_seconds = int((expiry_datetime - now).total_seconds())
# Minimum TTL of 1 hour, maximum of 48 hours
return max(3600, min(ttl_seconds, 172800))
async def get_cached_forecast(
self,
tenant_id: UUID,
product_id: UUID,
forecast_date: date
) -> Optional[Dict[str, Any]]:
"""
Retrieve cached forecast if available
Args:
tenant_id: Tenant identifier
product_id: Product identifier
forecast_date: Date of forecast
Returns:
Cached forecast data or None if not found
"""
if not await self.is_available():
return None
try:
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
client = await self._get_redis()
cached_data = await client.get(key)
if cached_data:
forecast_data = json.loads(cached_data)
# Add cache hit metadata
forecast_data['cached'] = True
forecast_data['cache_hit_at'] = datetime.now().isoformat()
logger.info("Forecast cache HIT",
tenant_id=str(tenant_id),
product_id=str(product_id),
forecast_date=str(forecast_date))
return forecast_data
logger.debug("Forecast cache MISS",
tenant_id=str(tenant_id),
product_id=str(product_id),
forecast_date=str(forecast_date))
return None
except Exception as e:
logger.error("Error retrieving cached forecast",
error=str(e),
tenant_id=str(tenant_id))
return None
async def cache_forecast(
self,
tenant_id: UUID,
product_id: UUID,
forecast_date: date,
forecast_data: Dict[str, Any]
) -> bool:
"""
Cache forecast prediction result
Args:
tenant_id: Tenant identifier
product_id: Product identifier
forecast_date: Date of forecast
forecast_data: Forecast prediction data to cache
Returns:
True if cached successfully, False otherwise
"""
if not await self.is_available():
logger.warning("Redis not available, skipping forecast cache")
return False
try:
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
ttl = self._calculate_ttl(forecast_date)
# Add caching metadata
cache_entry = {
**forecast_data,
'cached_at': datetime.now().isoformat(),
'cache_key': key,
'ttl_seconds': ttl
}
# Serialize and cache
client = await self._get_redis()
await client.setex(
key,
ttl,
json.dumps(cache_entry, default=str)
)
logger.info("Forecast cached successfully",
tenant_id=str(tenant_id),
product_id=str(product_id),
forecast_date=str(forecast_date),
ttl_hours=round(ttl / 3600, 2))
return True
except Exception as e:
logger.error("Error caching forecast",
error=str(e),
tenant_id=str(tenant_id))
return False
async def get_cached_batch_forecast(
self,
tenant_id: UUID,
product_ids: List[UUID],
forecast_date: date
) -> Optional[Dict[str, Any]]:
"""
Retrieve cached batch forecast
Args:
tenant_id: Tenant identifier
product_ids: List of product identifiers
forecast_date: Date of forecast
Returns:
Cached batch forecast data or None
"""
if not await self.is_available():
return None
try:
key = self._get_batch_forecast_key(tenant_id, product_ids, forecast_date)
client = await self._get_redis()
cached_data = await client.get(key)
if cached_data:
forecast_data = json.loads(cached_data)
forecast_data['cached'] = True
forecast_data['cache_hit_at'] = datetime.now().isoformat()
logger.info("Batch forecast cache HIT",
tenant_id=str(tenant_id),
products_count=len(product_ids),
forecast_date=str(forecast_date))
return forecast_data
return None
except Exception as e:
logger.error("Error retrieving cached batch forecast", error=str(e))
return None
async def cache_batch_forecast(
self,
tenant_id: UUID,
product_ids: List[UUID],
forecast_date: date,
forecast_data: Dict[str, Any]
) -> bool:
"""Cache batch forecast result"""
if not await self.is_available():
return False
try:
key = self._get_batch_forecast_key(tenant_id, product_ids, forecast_date)
ttl = self._calculate_ttl(forecast_date)
cache_entry = {
**forecast_data,
'cached_at': datetime.now().isoformat(),
'cache_key': key,
'ttl_seconds': ttl
}
client = await self._get_redis()
await client.setex(key, ttl, json.dumps(cache_entry, default=str))
logger.info("Batch forecast cached successfully",
tenant_id=str(tenant_id),
products_count=len(product_ids),
ttl_hours=round(ttl / 3600, 2))
return True
except Exception as e:
logger.error("Error caching batch forecast", error=str(e))
return False
# ================================================================
# CACHE INVALIDATION
# ================================================================
async def invalidate_product_forecasts(
self,
tenant_id: UUID,
product_id: UUID
) -> int:
"""
Invalidate all forecast cache entries for a product
Called when model is retrained for specific product.
Args:
tenant_id: Tenant identifier
product_id: Product identifier
Returns:
Number of cache entries invalidated
"""
if not await self.is_available():
return 0
try:
# Find all keys matching this product
pattern = f"forecast:{tenant_id}:{product_id}:*"
client = await self._get_redis()
keys = await client.keys(pattern)
if keys:
deleted = await client.delete(*keys)
logger.info("Invalidated product forecast cache",
tenant_id=str(tenant_id),
product_id=str(product_id),
keys_deleted=deleted)
return deleted
return 0
except Exception as e:
logger.error("Error invalidating product forecasts",
error=str(e),
tenant_id=str(tenant_id))
return 0
async def invalidate_tenant_forecasts(
self,
tenant_id: UUID,
forecast_date: Optional[date] = None
) -> int:
"""
Invalidate forecast cache for tenant
Args:
tenant_id: Tenant identifier
forecast_date: Optional specific date to invalidate
Returns:
Number of cache entries invalidated
"""
if not await self.is_available():
return 0
try:
if forecast_date:
pattern = f"forecast:{tenant_id}:*:{forecast_date.isoformat()}"
else:
pattern = f"forecast:{tenant_id}:*"
client = await self._get_redis()
keys = await client.keys(pattern)
if keys:
deleted = await client.delete(*keys)
logger.info("Invalidated tenant forecast cache",
tenant_id=str(tenant_id),
forecast_date=str(forecast_date) if forecast_date else "all",
keys_deleted=deleted)
return deleted
return 0
except Exception as e:
logger.error("Error invalidating tenant forecasts", error=str(e))
return 0
async def invalidate_all_forecasts(self) -> int:
"""
Invalidate all forecast cache entries (use with caution)
Returns:
Number of cache entries invalidated
"""
if not await self.is_available():
return 0
try:
pattern = "forecast:*"
client = await self._get_redis()
keys = await client.keys(pattern)
if keys:
deleted = await client.delete(*keys)
logger.warning("Invalidated ALL forecast cache", keys_deleted=deleted)
return deleted
return 0
except Exception as e:
logger.error("Error invalidating all forecasts", error=str(e))
return 0
# ================================================================
# CACHE STATISTICS & MONITORING
# ================================================================
async def get_cache_stats(self) -> Dict[str, Any]:
"""
Get cache statistics for monitoring
Returns:
Dictionary with cache metrics
"""
if not await self.is_available():
return {"available": False}
try:
client = await self._get_redis()
info = await client.info()
# Get forecast-specific stats
forecast_keys = await client.keys("forecast:*")
batch_keys = await client.keys("forecast:batch:*")
return {
"available": True,
"total_forecast_keys": len(forecast_keys),
"batch_forecast_keys": len(batch_keys),
"single_forecast_keys": len(forecast_keys) - len(batch_keys),
"used_memory": info.get("used_memory_human"),
"connected_clients": info.get("connected_clients"),
"keyspace_hits": info.get("keyspace_hits", 0),
"keyspace_misses": info.get("keyspace_misses", 0),
"hit_rate_percent": self._calculate_hit_rate(
info.get("keyspace_hits", 0),
info.get("keyspace_misses", 0)
),
"total_commands_processed": info.get("total_commands_processed", 0)
}
except Exception as e:
logger.error("Error getting cache stats", error=str(e))
return {"available": False, "error": str(e)}
def _calculate_hit_rate(self, hits: int, misses: int) -> float:
"""Calculate cache hit rate percentage"""
total = hits + misses
return round((hits / total * 100), 2) if total > 0 else 0.0
async def get_cached_forecast_info(
self,
tenant_id: UUID,
product_id: UUID,
forecast_date: date
) -> Optional[Dict[str, Any]]:
"""
Get metadata about cached forecast without retrieving full data
Args:
tenant_id: Tenant identifier
product_id: Product identifier
forecast_date: Date of forecast
Returns:
Cache metadata or None
"""
if not await self.is_available():
return None
try:
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
client = await self._get_redis()
ttl = await client.ttl(key)
if ttl > 0:
return {
"cached": True,
"cache_key": key,
"ttl_seconds": ttl,
"ttl_hours": round(ttl / 3600, 2),
"expires_at": (datetime.now() + timedelta(seconds=ttl)).isoformat()
}
return None
except Exception as e:
logger.error("Error getting forecast cache info", error=str(e))
return None
# Global cache service instance
_cache_service = None
def get_forecast_cache_service() -> ForecastCacheService:
"""
Get the global forecast cache service instance
Returns:
ForecastCacheService instance
"""
global _cache_service
if _cache_service is None:
_cache_service = ForecastCacheService()
return _cache_service
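
# A minimal read-through caching sketch built on the singleton accessor above; the
# forecast date and payload are placeholders. Note that the invalidation helpers use
# Redis KEYS pattern scans, which block the server on large keyspaces; SCAN-based
# iteration is the usual alternative if that becomes a concern.
async def get_or_cache_forecast(tenant_id: UUID, product_id: UUID, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return the cached forecast if present, otherwise cache the freshly computed payload."""
    cache = get_forecast_cache_service()
    forecast_date = date(2026, 1, 22)
    cached = await cache.get_cached_forecast(tenant_id, product_id, forecast_date)
    if cached is not None:
        return cached  # carries cached=True and cache_hit_at metadata
    await cache.cache_forecast(tenant_id, product_id, forecast_date, payload)
    return payload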

View File

@@ -0,0 +1,533 @@
# services/forecasting/app/services/forecast_feedback_service.py
"""
Forecast Feedback Service
Business logic for collecting and analyzing forecast feedback
"""
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta, date
import uuid
import structlog
from dataclasses import dataclass
logger = structlog.get_logger()
@dataclass
class ForecastFeedback:
"""Data class for forecast feedback"""
feedback_id: uuid.UUID
forecast_id: uuid.UUID
tenant_id: str
feedback_type: str
confidence: str
actual_value: Optional[float]
notes: Optional[str]
feedback_data: Dict[str, Any]
created_at: datetime
created_by: Optional[str]
@dataclass
class ForecastAccuracyMetrics:
"""Data class for forecast accuracy metrics"""
forecast_id: str
total_feedback_count: int
accuracy_score: float
feedback_distribution: Dict[str, int]
average_confidence: float
last_feedback_date: Optional[datetime]
@dataclass
class ForecasterPerformanceMetrics:
"""Data class for forecaster performance metrics"""
overall_accuracy: float
total_forecasts_with_feedback: int
accuracy_by_product: Dict[str, float]
accuracy_trend: str
improvement_suggestions: List[str]
class ForecastFeedbackService:
"""
Service for managing forecast feedback and accuracy tracking
"""
def __init__(self, database_manager):
self.database_manager = database_manager
async def forecast_exists(self, tenant_id: str, forecast_id: str) -> bool:
"""
Check if a forecast exists
"""
try:
async with self.database_manager.get_session() as session:
from sqlalchemy import text
result = await session.execute(
text("""
SELECT 1 FROM forecasts
WHERE tenant_id = :tenant_id AND id = :forecast_id
"""),
{"tenant_id": tenant_id, "forecast_id": forecast_id}
)
return result.scalar() is not None
except Exception as e:
logger.error("Failed to check forecast existence", error=str(e))
raise Exception(f"Failed to check forecast existence: {str(e)}")
async def submit_feedback(
self,
tenant_id: str,
forecast_id: str,
feedback_type: str,
confidence: str,
actual_value: Optional[float] = None,
notes: Optional[str] = None,
feedback_data: Optional[Dict[str, Any]] = None
) -> ForecastFeedback:
"""
Submit feedback on forecast accuracy
"""
try:
async with self.database_manager.get_session() as session:
# Create feedback record
feedback_id = uuid.uuid4()
created_at = datetime.now()
# In a real implementation, this would insert into a forecast_feedback table
# For demo purposes, we'll simulate the database operation
feedback = ForecastFeedback(
feedback_id=feedback_id,
forecast_id=uuid.UUID(forecast_id),
tenant_id=tenant_id,
feedback_type=feedback_type,
confidence=confidence,
actual_value=actual_value,
notes=notes,
feedback_data=feedback_data or {},
created_at=created_at,
created_by="system" # In real implementation, this would be the user ID
)
# Simulate database insert
logger.info("Feedback submitted",
feedback_id=str(feedback_id),
forecast_id=forecast_id,
feedback_type=feedback_type)
return feedback
except Exception as e:
logger.error("Failed to submit feedback", error=str(e))
raise Exception(f"Failed to submit feedback: {str(e)}")
async def get_feedback_for_forecast(
self,
tenant_id: str,
forecast_id: str,
limit: int = 50,
offset: int = 0
) -> List[ForecastFeedback]:
"""
Get all feedback for a specific forecast
"""
try:
# In a real implementation, this would query the forecast_feedback table
# For demo purposes, we'll return simulated data
# Simulate some feedback data
simulated_feedback = []
for i in range(min(limit, 3)): # Return up to 3 simulated feedback items
feedback = ForecastFeedback(
feedback_id=uuid.uuid4(),
forecast_id=uuid.UUID(forecast_id),
tenant_id=tenant_id,
feedback_type=["too_high", "too_low", "accurate"][i % 3],
confidence=["medium", "high", "low"][i % 3],
actual_value=150.0 + i * 20 if i < 2 else None,
notes=f"Feedback sample {i+1}" if i == 0 else None,
feedback_data={"sample": i+1, "demo": True},
created_at=datetime.now() - timedelta(days=i),
created_by="demo_user"
)
simulated_feedback.append(feedback)
return simulated_feedback
except Exception as e:
logger.error("Failed to get feedback for forecast", error=str(e))
raise Exception(f"Failed to get feedback: {str(e)}")
async def calculate_accuracy_metrics(
self,
tenant_id: str,
forecast_id: str
) -> ForecastAccuracyMetrics:
"""
Calculate accuracy metrics for a forecast
"""
try:
# Get feedback for this forecast
feedback_list = await self.get_feedback_for_forecast(tenant_id, forecast_id)
if not feedback_list:
return None
# Calculate metrics
total_feedback = len(feedback_list)
# Count feedback distribution
feedback_distribution = {
"too_high": 0,
"too_low": 0,
"accurate": 0,
"uncertain": 0
}
confidence_scores = {
"low": 1,
"medium": 2,
"high": 3
}
total_confidence = 0
for feedback in feedback_list:
feedback_distribution[feedback.feedback_type] += 1
total_confidence += confidence_scores.get(feedback.confidence, 1)
# Calculate accuracy score (simplified)
accurate_count = feedback_distribution["accurate"]
accuracy_score = (accurate_count / total_feedback) * 100
# Adjust for confidence
avg_confidence = total_confidence / total_feedback
adjusted_accuracy = accuracy_score * (avg_confidence / 3) # Normalize confidence to 0-1 range
return ForecastAccuracyMetrics(
forecast_id=forecast_id,
total_feedback_count=total_feedback,
accuracy_score=round(adjusted_accuracy, 1),
feedback_distribution=feedback_distribution,
average_confidence=round(avg_confidence, 1),
last_feedback_date=max(f.created_at for f in feedback_list)
)
except Exception as e:
logger.error("Failed to calculate accuracy metrics", error=str(e))
raise Exception(f"Failed to calculate metrics: {str(e)}")
async def calculate_performance_summary(
self,
tenant_id: str,
start_date: Optional[date] = None,
end_date: Optional[date] = None,
product_id: Optional[str] = None
) -> ForecasterPerformanceMetrics:
"""
Calculate overall forecaster performance summary
"""
try:
# In a real implementation, this would aggregate data across multiple forecasts
# For demo purposes, we'll return simulated metrics
# Simulate performance data
accuracy_by_product = {
"baguette": 85.5,
"croissant": 78.2,
"pain_au_chocolat": 92.1
}
if product_id and product_id in accuracy_by_product:
# Return metrics for specific product
product_accuracy = accuracy_by_product[product_id]
accuracy_by_product = {product_id: product_accuracy}
# Calculate overall accuracy
overall_accuracy = sum(accuracy_by_product.values()) / len(accuracy_by_product)
# Determine trend (simulated)
trend_data = [82.3, 84.1, 85.5, 86.8, 88.2] # Last 5 periods
if trend_data[-1] > trend_data[0]:
trend = "improving"
elif trend_data[-1] < trend_data[0]:
trend = "declining"
else:
trend = "stable"
# Generate improvement suggestions
suggestions = []
for product, accuracy in accuracy_by_product.items():
if accuracy < 80:
suggestions.append(f"Improve {product} forecast accuracy (current: {accuracy}%)")
elif accuracy < 90:
suggestions.append(f"Consider fine-tuning {product} forecast model (current: {accuracy}%)")
if not suggestions:
suggestions.append("Overall forecast accuracy is excellent - maintain current approach")
return ForecasterPerformanceMetrics(
overall_accuracy=round(overall_accuracy, 1),
total_forecasts_with_feedback=42,
accuracy_by_product=accuracy_by_product,
accuracy_trend=trend,
improvement_suggestions=suggestions
)
except Exception as e:
logger.error("Failed to calculate performance summary", error=str(e))
raise Exception(f"Failed to calculate summary: {str(e)}")
async def get_feedback_trends(
self,
tenant_id: str,
days: int = 30,
product_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
Get feedback trends over time
"""
try:
# Simulate trend data
trends = []
end_date = datetime.now()
# Generate daily trend data
for i in range(days):
trend_date = end_date - timedelta(days=i)  # avoid shadowing the imported date type
# Simulate varying accuracy with weekly pattern
base_accuracy = 85.0
weekly_variation = 3.0 * (i % 7 / 6 - 0.5) # Weekly pattern
daily_noise = (i % 3 - 1) * 1.5 # Daily noise
accuracy = max(70, min(95, base_accuracy + weekly_variation + daily_noise))
trends.append({
'date': trend_date.strftime('%Y-%m-%d'),
'accuracy_score': round(accuracy, 1),
'feedback_count': max(1, int(5 + i % 10)),
'confidence_score': round(2.5 + (i % 5 - 2) * 0.2, 1)
})
# Sort by date (oldest first)
trends.sort(key=lambda x: x['date'])
return trends
except Exception as e:
logger.error("Failed to get feedback trends", error=str(e))
raise Exception(f"Failed to get trends: {str(e)}")
async def trigger_retraining_from_feedback(
self,
tenant_id: str,
forecast_id: str
) -> Dict[str, Any]:
"""
Trigger model retraining based on feedback
"""
try:
# In a real implementation, this would:
# 1. Collect recent feedback data
# 2. Prepare training dataset
# 3. Submit retraining job to ML service
# 4. Return job ID
# For demo purposes, simulate a retraining job
job_id = str(uuid.uuid4())
logger.info("Retraining job triggered",
job_id=job_id,
tenant_id=tenant_id,
forecast_id=forecast_id)
return {
'job_id': job_id,
'forecasts_included': 15,
'feedback_samples_used': 42,
'status': 'queued',
'estimated_completion': (datetime.now() + timedelta(minutes=30)).isoformat()
}
except Exception as e:
logger.error("Failed to trigger retraining", error=str(e))
raise Exception(f"Failed to trigger retraining: {str(e)}")
async def get_improvement_suggestions(
self,
tenant_id: str,
forecast_id: str
) -> List[Dict[str, Any]]:
"""
Get AI-generated improvement suggestions
"""
try:
# Get accuracy metrics for this forecast
metrics = await self.calculate_accuracy_metrics(tenant_id, forecast_id)
if not metrics:
return [
{
'suggestion': 'Insufficient feedback data to generate suggestions',
'type': 'data',
'priority': 'low',
'confidence': 0.7
}
]
# Generate suggestions based on metrics
suggestions = []
# Analyze feedback distribution
feedback_dist = metrics.feedback_distribution
total_feedback = metrics.total_feedback_count
if feedback_dist['too_high'] > total_feedback * 0.4:
suggestions.append({
'suggestion': 'Forecasts are consistently too high - consider adjusting demand estimation parameters',
'type': 'bias',
'priority': 'high',
'confidence': 0.9,
'details': {
'too_high_percentage': feedback_dist['too_high'] / total_feedback * 100,
'recommended_action': 'Reduce demand estimation by 10-15%'
}
})
if feedback_dist['too_low'] > total_feedback * 0.4:
suggestions.append({
'suggestion': 'Forecasts are consistently too low - consider increasing demand estimation parameters',
'type': 'bias',
'priority': 'high',
'confidence': 0.9,
'details': {
'too_low_percentage': feedback_dist['too_low'] / total_feedback * 100,
'recommended_action': 'Increase demand estimation by 10-15%'
}
})
if metrics.accuracy_score < 70:
suggestions.append({
'suggestion': 'Low overall accuracy - consider comprehensive model review and retraining',
'type': 'model',
'priority': 'critical',
'confidence': 0.85,
'details': {
'current_accuracy': metrics.accuracy_score,
'recommended_action': 'Full model retraining with expanded feature set'
}
})
elif metrics.accuracy_score < 85:
suggestions.append({
'suggestion': 'Moderate accuracy - consider feature engineering improvements',
'type': 'features',
'priority': 'medium',
'confidence': 0.8,
'details': {
'current_accuracy': metrics.accuracy_score,
'recommended_action': 'Add weather data, promotions, and seasonal features'
}
})
if metrics.average_confidence < 2.0: # Below 'medium' on the low=1 / medium=2 / high=3 confidence scale
suggestions.append({
'suggestion': 'Low confidence in feedback - consider improving feedback collection process',
'type': 'process',
'priority': 'medium',
'confidence': 0.75,
'details': {
'average_confidence': metrics.average_confidence,
'recommended_action': 'Provide clearer guidance to users on feedback submission'
}
})
if not suggestions:
suggestions.append({
'suggestion': 'Forecast accuracy is good - consider expanding to additional products',
'type': 'expansion',
'priority': 'low',
'confidence': 0.85,
'details': {
'current_accuracy': metrics.accuracy_score,
'recommended_action': 'Extend forecasting to new product categories'
}
})
return suggestions
except Exception as e:
logger.error("Failed to generate improvement suggestions", error=str(e))
raise Exception(f"Failed to generate suggestions: {str(e)}")
# Helper class for feedback analysis
class FeedbackAnalyzer:
"""
Helper class for analyzing feedback patterns
"""
@staticmethod
def detect_feedback_patterns(feedback_list: List[ForecastFeedback]) -> Dict[str, Any]:
"""
Detect patterns in feedback data
"""
if not feedback_list:
return {'patterns': [], 'anomalies': []}
patterns = []
anomalies = []
# Simple pattern detection (in real implementation, this would be more sophisticated)
feedback_types = [f.feedback_type for f in feedback_list]
if len(set(feedback_types)) == 1:
patterns.append({
'type': 'consistent_feedback',
'pattern': f'All feedback is "{feedback_types[0]}"',
'confidence': 0.9
})
return {'patterns': patterns, 'anomalies': anomalies}
# Helper class for accuracy calculation
class AccuracyCalculator:
"""
Helper class for calculating forecast accuracy metrics
"""
@staticmethod
def calculate_mape(actual: float, predicted: float) -> float:
"""
Calculate the absolute percentage error for a single actual/predicted pair (returns 0.0 when actual is 0, where MAPE is undefined)
"""
if actual == 0:
return 0.0
return abs((actual - predicted) / actual) * 100
@staticmethod
def calculate_rmse(actual: float, predicted: float) -> float:
"""
Calculate the squared error for a single actual/predicted pair; average these values and take the square root across points to obtain RMSE
"""
return (actual - predicted) ** 2
@staticmethod
def feedback_to_accuracy_score(feedback_type: str) -> float:
"""
Convert feedback type to accuracy score
"""
feedback_scores = {
'accurate': 100,
'too_high': 50,
'too_low': 50,
'uncertain': 75
}
return feedback_scores.get(feedback_type, 75)
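# ----------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It only
# exercises the static helpers above; the ForecastFeedback stand-in is a
# SimpleNamespace assumption so the example can run without the ORM model.
# ----------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Point-wise percentage error: forecasting 90 units against 100 actual -> 10%.
    print(AccuracyCalculator.calculate_mape(actual=100.0, predicted=90.0))  # 10.0

    # Squared error for the same pair; average these and take the square root
    # across many points to obtain an RMSE.
    print(AccuracyCalculator.calculate_rmse(actual=100.0, predicted=90.0))  # 100.0

    # Coarse mapping from user feedback labels to accuracy scores.
    print(AccuracyCalculator.feedback_to_accuracy_score("too_high"))  # 50

    # Pattern detection flags a run of identical feedback labels.
    feedback = [SimpleNamespace(feedback_type="too_high") for _ in range(4)]
    print(FeedbackAnalyzer.detect_feedback_patterns(feedback)["patterns"])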

View File

@@ -0,0 +1,338 @@
"""
Forecasting Alert Service - Simplified
Emits minimal events using EventPublisher.
All enrichment handled by alert_processor.
"""
import json
from typing import List, Dict, Any, Optional
from uuid import UUID
from datetime import datetime, timedelta
import structlog
from shared.messaging import UnifiedEventPublisher
logger = structlog.get_logger()
class ForecastingAlertService:
"""Simplified forecasting alert service using EventPublisher"""
def __init__(self, event_publisher: UnifiedEventPublisher):
self.publisher = event_publisher
async def start(self):
"""Start the forecasting alert service"""
logger.info("ForecastingAlertService started")
# Add any initialization logic here if needed
async def stop(self):
"""Stop the forecasting alert service"""
logger.info("ForecastingAlertService stopped")
# Add any cleanup logic here if needed
async def health_check(self):
"""Health check for the forecasting alert service"""
try:
# Check if the event publisher is available and operational
if hasattr(self, 'publisher') and self.publisher:
# Basic check if publisher is available
return True
return False
except Exception as e:
logger.error("ForecastingAlertService health check failed", error=str(e))
return False
async def emit_demand_surge_weekend(
self,
tenant_id: UUID,
product_name: str,
inventory_product_id: str,
predicted_demand: float,
growth_percentage: float,
forecast_date: str,
weather_favorable: bool = False
):
"""Emit weekend demand surge alert"""
# Determine severity based on growth magnitude
if growth_percentage > 100:
severity = 'high'
elif growth_percentage > 75:
severity = 'medium'
else:
severity = 'low'
metadata = {
"product_name": product_name,
"inventory_product_id": str(inventory_product_id),
"predicted_demand": float(predicted_demand),
"growth_percentage": float(growth_percentage),
"forecast_date": forecast_date,
"weather_favorable": weather_favorable
}
await self.publisher.publish_alert(
event_type="forecasting.demand_surge_weekend",
tenant_id=tenant_id,
severity=severity,
data=metadata
)
logger.info(
"demand_surge_weekend_emitted",
tenant_id=str(tenant_id),
product_name=product_name,
growth_percentage=growth_percentage
)
async def emit_weather_impact_alert(
self,
tenant_id: UUID,
forecast_date: str,
precipitation: float,
expected_demand_change: float,
traffic_volume: int,
weather_type: str = "general",
product_name: Optional[str] = None
):
"""Emit weather impact alert"""
# Determine severity based on impact
if expected_demand_change < -20:
severity = 'high'
elif expected_demand_change < -10:
severity = 'medium'
else:
severity = 'low'
metadata = {
"forecast_date": forecast_date,
"precipitation_mm": float(precipitation),
"expected_demand_change": float(expected_demand_change),
"traffic_volume": traffic_volume,
"weather_type": weather_type
}
if product_name:
metadata["product_name"] = product_name
# Add triggers information
triggers = ['weather_conditions', 'demand_forecast']
if precipitation > 0:
triggers.append('rain_forecast')
if expected_demand_change < -15:
triggers.append('outdoor_events_cancelled')
metadata["triggers"] = triggers
await self.publisher.publish_alert(
event_type="forecasting.weather_impact_alert",
tenant_id=tenant_id,
severity=severity,
data=metadata
)
logger.info(
"weather_impact_alert_emitted",
tenant_id=str(tenant_id),
weather_type=weather_type,
expected_demand_change=expected_demand_change
)
async def emit_holiday_preparation(
self,
tenant_id: UUID,
holiday_name: str,
days_until_holiday: int,
product_name: str,
spike_percentage: float,
avg_holiday_demand: float,
avg_normal_demand: float,
holiday_date: str
):
"""Emit holiday preparation alert"""
# Determine severity based on spike magnitude and preparation time
if spike_percentage > 75 and days_until_holiday <= 3:
severity = 'high'
elif spike_percentage > 50 or days_until_holiday <= 3:
severity = 'medium'
else:
severity = 'low'
metadata = {
"holiday_name": holiday_name,
"days_until_holiday": days_until_holiday,
"product_name": product_name,
"spike_percentage": float(spike_percentage),
"avg_holiday_demand": float(avg_holiday_demand),
"avg_normal_demand": float(avg_normal_demand),
"holiday_date": holiday_date
}
# Add triggers information
triggers = [f'spanish_holiday_in_{days_until_holiday}_days']
if spike_percentage > 25:
triggers.append('historical_demand_spike')
metadata["triggers"] = triggers
await self.publisher.publish_alert(
event_type="forecasting.holiday_preparation",
tenant_id=tenant_id,
severity=severity,
data=metadata
)
logger.info(
"holiday_preparation_emitted",
tenant_id=str(tenant_id),
holiday_name=holiday_name,
spike_percentage=spike_percentage
)
async def emit_demand_optimization_recommendation(
self,
tenant_id: UUID,
product_name: str,
optimization_potential: float,
peak_demand: float,
min_demand: float,
demand_range: float
):
"""Emit demand pattern optimization recommendation"""
metadata = {
"product_name": product_name,
"optimization_potential": float(optimization_potential),
"peak_demand": float(peak_demand),
"min_demand": float(min_demand),
"demand_range": float(demand_range)
}
await self.publisher.publish_recommendation(
event_type="forecasting.demand_pattern_optimization",
tenant_id=tenant_id,
data=metadata
)
logger.info(
"demand_pattern_optimization_emitted",
tenant_id=str(tenant_id),
product_name=product_name,
optimization_potential=optimization_potential
)
async def emit_demand_spike_detected(
self,
tenant_id: UUID,
product_name: str,
spike_percentage: float
):
"""Emit demand spike detected event"""
# Determine severity based on spike magnitude
if spike_percentage > 50:
severity = 'high'
elif spike_percentage > 20:
severity = 'medium'
else:
severity = 'low'
metadata = {
"product_name": product_name,
"spike_percentage": float(spike_percentage),
"detection_source": "database"
}
await self.publisher.publish_alert(
event_type="forecasting.demand_spike_detected",
tenant_id=tenant_id,
severity=severity,
data=metadata
)
logger.info(
"demand_spike_detected_emitted",
tenant_id=str(tenant_id),
product_name=product_name,
spike_percentage=spike_percentage
)
async def emit_severe_weather_impact(
self,
tenant_id: UUID,
weather_type: str,
severity_level: str,
duration_hours: int
):
"""Emit severe weather impact event"""
# Determine alert severity based on weather severity
if severity_level == 'critical' or duration_hours > 24:
alert_severity = 'urgent'
elif severity_level == 'high' or duration_hours > 12:
alert_severity = 'high'
else:
alert_severity = 'medium'
metadata = {
"weather_type": weather_type,
"severity_level": severity_level,
"duration_hours": duration_hours
}
await self.publisher.publish_alert(
event_type="forecasting.severe_weather_impact",
tenant_id=tenant_id,
severity=alert_severity,
data=metadata
)
logger.info(
"severe_weather_impact_emitted",
tenant_id=str(tenant_id),
weather_type=weather_type,
severity_level=severity_level
)
async def emit_unexpected_demand_spike(
self,
tenant_id: UUID,
product_name: str,
spike_percentage: float,
current_sales: float,
forecasted_sales: float
):
"""Emit unexpected sales spike event"""
# Determine severity based on spike magnitude
if spike_percentage > 75:
severity = 'high'
elif spike_percentage > 40:
severity = 'medium'
else:
severity = 'low'
metadata = {
"product_name": product_name,
"spike_percentage": float(spike_percentage),
"current_sales": float(current_sales),
"forecasted_sales": float(forecasted_sales)
}
await self.publisher.publish_alert(
event_type="forecasting.unexpected_demand_spike",
tenant_id=tenant_id,
severity=severity,
data=metadata
)
logger.info(
"unexpected_demand_spike_emitted",
tenant_id=str(tenant_id),
product_name=product_name,
spike_percentage=spike_percentage
)
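# ----------------------------------------------------------------
# Illustrative wiring sketch (not part of the original module). The stub
# publisher below is an assumption that only mimics the publish_alert
# coroutine used by this service; real deployments inject the
# shared.messaging UnifiedEventPublisher instead.
# ----------------------------------------------------------------
if __name__ == "__main__":
    import asyncio
    from uuid import uuid4

    class _StubPublisher:
        async def publish_alert(self, event_type, tenant_id, severity, data):
            print(f"{event_type} [{severity}] -> {data['product_name']}")

    async def _demo():
        service = ForecastingAlertService(event_publisher=_StubPublisher())
        # A 62% spike is above the 50% threshold, so this emits a 'high' severity alert.
        await service.emit_demand_spike_detected(
            tenant_id=uuid4(),
            product_name="Baguette",
            spike_percentage=62.0
        )

    asyncio.run(_demo())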

View File

@@ -0,0 +1,246 @@
"""
Forecasting Recommendation Service - Simplified
Emits minimal events using EventPublisher.
All enrichment handled by alert_processor.
"""
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List
from uuid import UUID
import structlog
from shared.messaging import UnifiedEventPublisher
logger = structlog.get_logger()
class ForecastingRecommendationService:
"""
Service for emitting forecasting recommendations using EventPublisher.
"""
def __init__(self, event_publisher: UnifiedEventPublisher):
self.publisher = event_publisher
async def emit_demand_surge_recommendation(
self,
tenant_id: UUID,
product_sku: str,
product_name: str,
predicted_demand: float,
normal_demand: float,
surge_percentage: float,
surge_date: str,
confidence_score: float,
reasoning: str,
) -> None:
"""
Emit RECOMMENDATION for predicted demand surge.
"""
metadata = {
"product_sku": product_sku,
"product_name": product_name,
"predicted_demand": float(predicted_demand),
"normal_demand": float(normal_demand),
"surge_percentage": float(surge_percentage),
"surge_date": surge_date,
"confidence_score": float(confidence_score),
"reasoning": reasoning,
"estimated_impact": {
"additional_revenue_eur": predicted_demand * 5, # Rough estimate
"stockout_risk": "high" if surge_percentage > 50 else "medium",
},
}
await self.publisher.publish_recommendation(
event_type="demand.demand_surge_predicted",
tenant_id=tenant_id,
data=metadata
)
logger.info(
"demand_surge_recommendation_emitted",
tenant_id=str(tenant_id),
product_name=product_name,
surge_percentage=surge_percentage
)
async def emit_weather_impact_recommendation(
self,
tenant_id: UUID,
weather_event: str, # 'rain', 'snow', 'heatwave', etc.
forecast_date: str,
affected_products: List[Dict[str, Any]],
impact_description: str,
confidence_score: float,
) -> None:
"""
Emit RECOMMENDATION for weather impact on demand.
"""
metadata = {
"weather_event": weather_event,
"forecast_date": forecast_date,
"affected_products": affected_products,
"impact_description": impact_description,
"confidence_score": float(confidence_score),
}
await self.publisher.publish_recommendation(
event_type="demand.weather_impact_forecast",
tenant_id=tenant_id,
data=metadata
)
logger.info(
"weather_impact_recommendation_emitted",
tenant_id=str(tenant_id),
weather_event=weather_event
)
async def emit_holiday_preparation_recommendation(
self,
tenant_id: UUID,
holiday_name: str,
holiday_date: str,
days_until_holiday: int,
recommended_products: List[Dict[str, Any]],
preparation_tips: List[str],
) -> None:
"""
Emit RECOMMENDATION for holiday preparation.
"""
metadata = {
"holiday_name": holiday_name,
"holiday_date": holiday_date,
"days_until_holiday": days_until_holiday,
"recommended_products": recommended_products,
"preparation_tips": preparation_tips,
"confidence_score": 0.9, # High confidence for known holidays
}
await self.publisher.publish_recommendation(
event_type="demand.holiday_preparation",
tenant_id=tenant_id,
data=metadata
)
logger.info(
"holiday_preparation_recommendation_emitted",
tenant_id=str(tenant_id),
holiday=holiday_name
)
async def emit_seasonal_trend_recommendation(
self,
tenant_id: UUID,
season: str, # 'spring', 'summer', 'fall', 'winter'
trend_type: str, # 'increasing', 'decreasing', 'stable'
affected_categories: List[str],
trend_description: str,
suggested_actions: List[str],
) -> None:
"""
Emit RECOMMENDATION for seasonal trend insight.
"""
metadata = {
"season": season,
"trend_type": trend_type,
"affected_categories": affected_categories,
"trend_description": trend_description,
"suggested_actions": suggested_actions,
"confidence_score": 0.85,
}
await self.publisher.publish_recommendation(
event_type="demand.seasonal_trend_insight",
tenant_id=tenant_id,
data=metadata
)
logger.info(
"seasonal_trend_recommendation_emitted",
tenant_id=str(tenant_id),
season=season
)
async def emit_inventory_optimization_recommendation(
self,
tenant_id: UUID,
ingredient_id: str,
ingredient_name: str,
current_stock: float,
optimal_stock: float,
unit: str,
reason: str,
estimated_savings_eur: Optional[float] = None,
) -> None:
"""
Emit RECOMMENDATION for inventory optimization.
"""
difference = abs(current_stock - optimal_stock)
action = "reduce" if current_stock > optimal_stock else "increase"
estimated_impact = {}
if estimated_savings_eur:
estimated_impact["financial_savings_eur"] = estimated_savings_eur
metadata = {
"ingredient_id": ingredient_id,
"ingredient_name": ingredient_name,
"current_stock": float(current_stock),
"optimal_stock": float(optimal_stock),
"difference": float(difference),
"action": action,
"unit": unit,
"reason": reason,
"estimated_impact": estimated_impact if estimated_impact else None,
"confidence_score": 0.75,
}
await self.publisher.publish_recommendation(
event_type="inventory.inventory_optimization_opportunity",
tenant_id=tenant_id,
data=metadata
)
logger.info(
"inventory_optimization_recommendation_emitted",
tenant_id=str(tenant_id),
ingredient_name=ingredient_name
)
async def emit_cost_reduction_recommendation(
self,
tenant_id: UUID,
opportunity_type: str, # 'supplier_switch', 'bulk_purchase', 'seasonal_buying'
title: str,
description: str,
estimated_savings_eur: float,
suggested_actions: List[str],
details: Dict[str, Any],
) -> None:
"""
Emit RECOMMENDATION for cost reduction opportunity.
"""
metadata = {
"opportunity_type": opportunity_type,
"title": title,
"description": description,
"estimated_savings_eur": float(estimated_savings_eur),
"suggested_actions": suggested_actions,
"details": details,
"confidence_score": 0.8,
}
await self.publisher.publish_recommendation(
event_type="supply_chain.cost_reduction_suggestion",
tenant_id=tenant_id,
data=metadata
)
logger.info(
"cost_reduction_recommendation_emitted",
tenant_id=str(tenant_id),
opportunity_type=opportunity_type
)
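# ----------------------------------------------------------------
# Illustrative wiring sketch (not part of the original module). As with the
# alert service, the stub publisher is an assumption that only mirrors the
# publish_recommendation coroutine used above.
# ----------------------------------------------------------------
if __name__ == "__main__":
    import asyncio
    from uuid import uuid4

    class _StubPublisher:
        async def publish_recommendation(self, event_type, tenant_id, data):
            print(f"{event_type}: {data['action']} {data['ingredient_name']} to {data['optimal_stock']} {data['unit']}")

    async def _demo():
        service = ForecastingRecommendationService(event_publisher=_StubPublisher())
        # Current stock is above optimal, so the derived action is "reduce".
        await service.emit_inventory_optimization_recommendation(
            tenant_id=uuid4(),
            ingredient_id="ing-001",
            ingredient_name="Flour",
            current_stock=120.0,
            optimal_stock=80.0,
            unit="kg",
            reason="Demand forecast trending down",
            estimated_savings_eur=45.0
        )

    asyncio.run(_demo())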

File diff suppressed because it is too large

View File

@@ -0,0 +1,480 @@
# ================================================================
# services/forecasting/app/services/historical_validation_service.py
# ================================================================
"""
Historical Validation Service
Handles validation backfill when historical sales data is uploaded late.
Detects gaps in validation coverage and automatically triggers validation
for periods where forecasts exist but haven't been validated yet.
"""
from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, func, Date, or_
from datetime import datetime, timedelta, timezone, date
import structlog
import uuid
from app.models.forecasts import Forecast
from app.models.validation_run import ValidationRun
from app.models.sales_data_update import SalesDataUpdate
from app.services.validation_service import ValidationService
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class HistoricalValidationService:
"""Service for backfilling historical validation when sales data arrives late"""
def __init__(self, db_session: AsyncSession):
self.db = db_session
self.validation_service = ValidationService(db_session)
async def detect_validation_gaps(
self,
tenant_id: uuid.UUID,
lookback_days: int = 90
) -> List[Dict[str, Any]]:
"""
Detect date ranges where forecasts exist but haven't been validated
Args:
tenant_id: Tenant identifier
lookback_days: How far back to check (default 90 days)
Returns:
List of gap periods with date ranges
"""
try:
end_date = datetime.now(timezone.utc)
start_date = end_date - timedelta(days=lookback_days)
logger.info(
"Detecting validation gaps",
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat()
)
# Get all dates with forecasts
forecast_query = select(
func.cast(Forecast.forecast_date, Date).label('forecast_date')
).where(
and_(
Forecast.tenant_id == tenant_id,
Forecast.forecast_date >= start_date,
Forecast.forecast_date <= end_date
)
).group_by(
func.cast(Forecast.forecast_date, Date)
).order_by(
func.cast(Forecast.forecast_date, Date)
)
forecast_result = await self.db.execute(forecast_query)
forecast_dates = {row.forecast_date for row in forecast_result.fetchall()}
if not forecast_dates:
logger.info("No forecasts found in lookback period", tenant_id=tenant_id)
return []
# Get all dates that have been validated
validation_query = select(
func.cast(ValidationRun.validation_start_date, Date).label('validated_date')
).where(
and_(
ValidationRun.tenant_id == tenant_id,
ValidationRun.status == "completed",
ValidationRun.validation_start_date >= start_date,
ValidationRun.validation_end_date <= end_date
)
).group_by(
func.cast(ValidationRun.validation_start_date, Date)
)
validation_result = await self.db.execute(validation_query)
validated_dates = {row.validated_date for row in validation_result.fetchall()}
# Find gaps (dates with forecasts but no validation)
gap_dates = sorted(forecast_dates - validated_dates)
if not gap_dates:
logger.info("No validation gaps found", tenant_id=tenant_id)
return []
# Group consecutive dates into ranges
gaps = []
current_gap_start = gap_dates[0]
current_gap_end = gap_dates[0]
for i in range(1, len(gap_dates)):
if (gap_dates[i] - current_gap_end).days == 1:
# Consecutive date, extend current gap
current_gap_end = gap_dates[i]
else:
# Gap in dates, save current gap and start new one
gaps.append({
"start_date": current_gap_start,
"end_date": current_gap_end,
"days_count": (current_gap_end - current_gap_start).days + 1
})
current_gap_start = gap_dates[i]
current_gap_end = gap_dates[i]
# Don't forget the last gap
gaps.append({
"start_date": current_gap_start,
"end_date": current_gap_end,
"days_count": (current_gap_end - current_gap_start).days + 1
})
logger.info(
"Validation gaps detected",
tenant_id=tenant_id,
gaps_count=len(gaps),
total_days=len(gap_dates)
)
return gaps
except Exception as e:
logger.error(
"Failed to detect validation gaps",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to detect validation gaps: {str(e)}")
async def backfill_validation(
self,
tenant_id: uuid.UUID,
start_date: date,
end_date: date,
triggered_by: str = "manual",
sales_data_update_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""
Backfill validation for a historical date range
Args:
tenant_id: Tenant identifier
start_date: Start date for backfill
end_date: End date for backfill
triggered_by: How this backfill was triggered
sales_data_update_id: Optional link to sales data update record
Returns:
Backfill results with validation summary
"""
try:
logger.info(
"Starting validation backfill",
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat(),
triggered_by=triggered_by
)
# Convert dates to datetime
start_datetime = datetime.combine(start_date, datetime.min.time()).replace(tzinfo=timezone.utc)
end_datetime = datetime.combine(end_date, datetime.max.time()).replace(tzinfo=timezone.utc)
# Run validation for the date range
validation_result = await self.validation_service.validate_date_range(
tenant_id=tenant_id,
start_date=start_datetime,
end_date=end_datetime,
orchestration_run_id=None,
triggered_by=triggered_by
)
# Update sales data update record if provided
if sales_data_update_id:
await self._update_sales_data_record(
sales_data_update_id=sales_data_update_id,
validation_run_id=uuid.UUID(validation_result["validation_run_id"]),
status="completed" if validation_result["status"] == "completed" else "failed"
)
logger.info(
"Validation backfill completed",
tenant_id=tenant_id,
validation_run_id=validation_result.get("validation_run_id"),
forecasts_evaluated=validation_result.get("forecasts_evaluated")
)
return {
**validation_result,
"backfill_date_range": {
"start": start_date.isoformat(),
"end": end_date.isoformat()
}
}
except Exception as e:
logger.error(
"Validation backfill failed",
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat(),
error=str(e)
)
if sales_data_update_id:
await self._update_sales_data_record(
sales_data_update_id=sales_data_update_id,
validation_run_id=None,
status="failed",
error_message=str(e)
)
raise DatabaseError(f"Validation backfill failed: {str(e)}")
async def auto_backfill_gaps(
self,
tenant_id: uuid.UUID,
lookback_days: int = 90,
max_gaps_to_process: int = 10
) -> Dict[str, Any]:
"""
Automatically detect and backfill validation gaps
Args:
tenant_id: Tenant identifier
lookback_days: How far back to check
max_gaps_to_process: Maximum number of gaps to process in one run
Returns:
Summary of backfill operations
"""
try:
logger.info(
"Starting auto backfill",
tenant_id=tenant_id,
lookback_days=lookback_days
)
# Detect gaps
gaps = await self.detect_validation_gaps(tenant_id, lookback_days)
if not gaps:
return {
"gaps_found": 0,
"gaps_processed": 0,
"validations_completed": 0,
"message": "No validation gaps found"
}
# Limit number of gaps to process
gaps_to_process = gaps[:max_gaps_to_process]
results = []
for gap in gaps_to_process:
try:
result = await self.backfill_validation(
tenant_id=tenant_id,
start_date=gap["start_date"],
end_date=gap["end_date"],
triggered_by="auto_backfill"
)
results.append({
"gap": gap,
"result": result,
"status": "success"
})
except Exception as e:
logger.error(
"Failed to backfill gap",
gap=gap,
error=str(e)
)
results.append({
"gap": gap,
"error": str(e),
"status": "failed"
})
successful = sum(1 for r in results if r["status"] == "success")
logger.info(
"Auto backfill completed",
tenant_id=tenant_id,
gaps_found=len(gaps),
gaps_processed=len(results),
successful=successful
)
return {
"gaps_found": len(gaps),
"gaps_processed": len(results),
"validations_completed": successful,
"validations_failed": len(results) - successful,
"results": results
}
except Exception as e:
logger.error(
"Auto backfill failed",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Auto backfill failed: {str(e)}")
async def register_sales_data_update(
self,
tenant_id: uuid.UUID,
start_date: date,
end_date: date,
records_affected: int,
update_source: str = "import",
import_job_id: Optional[str] = None,
auto_trigger_validation: bool = True
) -> Dict[str, Any]:
"""
Register a sales data update and optionally trigger validation
Args:
tenant_id: Tenant identifier
start_date: Start date of updated data
end_date: End date of updated data
records_affected: Number of sales records affected
update_source: Source of update (import, manual, pos_sync)
import_job_id: Optional import job ID
auto_trigger_validation: Whether to automatically trigger validation
Returns:
Update record and validation result if triggered
"""
try:
# Create sales data update record
update_record = SalesDataUpdate(
tenant_id=tenant_id,
update_date_start=start_date,
update_date_end=end_date,
records_affected=records_affected,
update_source=update_source,
import_job_id=import_job_id,
requires_validation=auto_trigger_validation,
validation_status="pending" if auto_trigger_validation else "not_required"
)
self.db.add(update_record)
await self.db.flush()
logger.info(
"Registered sales data update",
tenant_id=tenant_id,
update_id=update_record.id,
date_range=f"{start_date} to {end_date}",
records_affected=records_affected
)
result = {
"update_id": str(update_record.id),
"update_record": update_record.to_dict(),
"validation_triggered": False
}
# Trigger validation if requested
if auto_trigger_validation:
try:
validation_result = await self.backfill_validation(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
triggered_by="sales_data_update",
sales_data_update_id=update_record.id
)
result["validation_triggered"] = True
result["validation_result"] = validation_result
logger.info(
"Validation triggered for sales data update",
update_id=update_record.id,
validation_run_id=validation_result.get("validation_run_id")
)
except Exception as e:
logger.error(
"Failed to trigger validation for sales data update",
update_id=update_record.id,
error=str(e)
)
update_record.validation_status = "failed"
update_record.validation_error = str(e)[:500]
await self.db.commit()
return result
except Exception as e:
logger.error(
"Failed to register sales data update",
tenant_id=tenant_id,
error=str(e)
)
await self.db.rollback()
raise DatabaseError(f"Failed to register sales data update: {str(e)}")
async def _update_sales_data_record(
self,
sales_data_update_id: uuid.UUID,
validation_run_id: Optional[uuid.UUID],
status: str,
error_message: Optional[str] = None
):
"""Update sales data update record with validation results"""
try:
query = select(SalesDataUpdate).where(SalesDataUpdate.id == sales_data_update_id)
result = await self.db.execute(query)
update_record = result.scalar_one_or_none()
if update_record:
update_record.validation_status = status
update_record.validation_run_id = validation_run_id
update_record.validated_at = datetime.now(timezone.utc)
if error_message:
update_record.validation_error = error_message[:500]
await self.db.commit()
except Exception as e:
logger.error(
"Failed to update sales data record",
sales_data_update_id=sales_data_update_id,
error=str(e)
)
async def get_pending_validations(
self,
tenant_id: uuid.UUID,
limit: int = 50
) -> List[SalesDataUpdate]:
"""Get pending sales data updates that need validation"""
try:
query = (
select(SalesDataUpdate)
.where(
and_(
SalesDataUpdate.tenant_id == tenant_id,
SalesDataUpdate.validation_status == "pending",
SalesDataUpdate.requires_validation == True
)
)
.order_by(SalesDataUpdate.created_at)
.limit(limit)
)
result = await self.db.execute(query)
return result.scalars().all()
except Exception as e:
logger.error(
"Failed to get pending validations",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to get pending validations: {str(e)}")

View File

@@ -0,0 +1,240 @@
# services/forecasting/app/services/model_client.py
"""
Forecast Service Model Client
Demonstrates calling training service to get models
"""
import structlog
from typing import Dict, Any, List, Optional
# Import shared clients - no more code duplication!
from shared.clients import get_service_clients, get_training_client, get_sales_client
from shared.database.base import create_database_manager
from app.core.config import settings
logger = structlog.get_logger()
class ModelClient:
"""
Client for managing models in forecasting service with dependency injection
Shows how to call multiple services cleanly
"""
def __init__(self, database_manager=None):
self.database_manager = database_manager or create_database_manager(
settings.DATABASE_URL, "forecasting-service"
)
# Option 1: Get all clients at once
self.clients = get_service_clients(settings, "forecasting")
# Option 2: Get specific clients
# self.training_client = get_training_client(settings, "forecasting")
# self.sales_client = get_sales_client(settings, "forecasting")
async def get_available_models(
self,
tenant_id: str,
model_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
Get available trained models from training service
"""
try:
models = await self.clients.training.list_models(
tenant_id=tenant_id,
status="deployed", # Only get deployed models
model_type=model_type
)
if models:
logger.info(f"Found {len(models)} available models",
tenant_id=tenant_id, model_type=model_type)
return models
else:
logger.warning("No available models found", tenant_id=tenant_id)
return []
except Exception as e:
logger.error(f"Error fetching available models: {e}", tenant_id=tenant_id)
return []
async def get_best_model_for_forecasting(
self,
tenant_id: str,
inventory_product_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""
Get the best model for forecasting based on performance metrics
"""
try:
# Get latest model
latest_model = await self.clients.training.get_active_model_for_product(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id
)
if not latest_model:
logger.warning("No trained models found", tenant_id=tenant_id)
return None
# ✅ FIX 1: Use "model_id" instead of "id"
model_id = latest_model.get("model_id")
if not model_id:
logger.error("Model response missing model_id field", tenant_id=tenant_id)
return None
# ✅ FIX 2: Handle metrics endpoint failure gracefully
try:
# Get model metrics to validate quality
metrics = await self.clients.training.get_model_metrics(
tenant_id=tenant_id,
model_id=model_id
)
# If metrics call succeeded, check accuracy threshold
if metrics and metrics.get("accuracy", 0) > 0.7: # 70% accuracy threshold
logger.info(f"Selected model {model_id} with accuracy {metrics.get('accuracy')}",
tenant_id=tenant_id)
return latest_model
elif metrics:
logger.warning(f"Model accuracy too low: {metrics.get('accuracy', 'unknown')}",
tenant_id=tenant_id)
# Still return the model even if accuracy is low - better than no prediction
logger.info("Returning model despite low accuracy - no alternative available",
tenant_id=tenant_id)
return latest_model
else:
logger.warning("No metrics returned from training service", tenant_id=tenant_id)
# Return model anyway - metrics service might be temporarily down
return latest_model
except Exception as metrics_error:
# ✅ FIX 3: If metrics endpoint fails, still return the model
logger.warning(f"Failed to get model metrics: {metrics_error}", tenant_id=tenant_id)
logger.info("Proceeding with model despite metrics failure", tenant_id=tenant_id)
return latest_model
except Exception as e:
logger.error(f"Error selecting best model: {e}", tenant_id=tenant_id)
return None
async def get_any_model_for_tenant(
self,
tenant_id: str
) -> Optional[Dict[str, Any]]:
"""
Get any available model for a tenant, used as fallback when specific product models aren't found
"""
try:
# First try to get any active models for this tenant
models = await self.get_available_models(tenant_id)
if models:
# Return the most recently trained model
sorted_models = sorted(models, key=lambda x: x.get('created_at', ''), reverse=True)
best_model = sorted_models[0]
logger.info("Found fallback model for tenant",
tenant_id=tenant_id,
model_id=best_model.get('model_id', best_model.get('id', 'unknown')),
inventory_product_id=best_model.get('inventory_product_id', 'unknown'))
return best_model
logger.warning("No fallback models available for tenant", tenant_id=tenant_id)
return None
except Exception as e:
logger.error("Error getting fallback model for tenant",
tenant_id=tenant_id,
error=str(e))
return None
async def validate_model_data_compatibility(
self,
tenant_id: str,
model_id: str,
forecast_start_date: str,
forecast_end_date: str
) -> Dict[str, Any]:
"""
Validate that we have sufficient data for the model to make forecasts
Demonstrates calling both training and data services
"""
try:
# Get model details from training service
model = await self.clients.training.get_model(
tenant_id=tenant_id,
model_id=model_id
)
if not model:
return {"is_valid": False, "error": "Model not found"}
# Get data statistics from data service
data_stats = await self.clients.data.get_data_statistics(
tenant_id=tenant_id,
start_date=forecast_start_date,
end_date=forecast_end_date
)
if not data_stats:
return {"is_valid": False, "error": "Could not retrieve data statistics"}
# Check if we have minimum required data points
min_required = model.get("metadata", {}).get("min_data_points", 30)
available_points = data_stats.get("total_records", 0)
is_valid = available_points >= min_required
result = {
"is_valid": is_valid,
"model_id": model_id,
"required_points": min_required,
"available_points": available_points,
"data_coverage": data_stats.get("coverage_percentage", 0)
}
if not is_valid:
result["error"] = f"Insufficient data: need {min_required}, have {available_points}"
logger.info("Model data compatibility check completed",
tenant_id=tenant_id, model_id=model_id, is_valid=is_valid)
return result
except Exception as e:
logger.error(f"Error validating model compatibility: {e}",
tenant_id=tenant_id, model_id=model_id)
return {"is_valid": False, "error": str(e)}
async def trigger_model_retraining(
self,
tenant_id: str,
include_weather: bool = True,
include_traffic: bool = False
) -> Optional[Dict[str, Any]]:
"""
Trigger a new training job if current model is outdated
"""
try:
# Create training job through training service
job = await self.clients.training.create_training_job(
tenant_id=tenant_id,
include_weather=include_weather,
include_traffic=include_traffic,
min_data_points=50 # Higher threshold for forecasting
)
if job:
logger.info(f"Training job created: {job['job_id']}", tenant_id=tenant_id)
return job
else:
logger.error("Failed to create training job", tenant_id=tenant_id)
return None
except Exception as e:
logger.error(f"Error triggering model retraining: {e}", tenant_id=tenant_id)
return None
# Global instance
model_client = ModelClient()
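# ----------------------------------------------------------------
# Sketch of the fallback selection in get_any_model_for_tenant
# (illustrative only; the dictionaries below are assumptions about the shape
# returned by the training client, not captured API responses).
# ----------------------------------------------------------------
if __name__ == "__main__":
    candidate_models = [
        {"model_id": "model-a", "created_at": "2025-01-10T08:00:00Z"},
        {"model_id": "model-b", "created_at": "2025-06-01T09:30:00Z"},
        {"model_id": "model-c"},  # missing created_at sorts last via the '' default
    ]
    best = sorted(candidate_models, key=lambda x: x.get("created_at", ""), reverse=True)[0]
    print(best["model_id"])  # model-b: the most recently trained model wins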

View File

@@ -0,0 +1,435 @@
# ================================================================
# services/forecasting/app/services/performance_monitoring_service.py
# ================================================================
"""
Performance Monitoring Service
Monitors forecast accuracy over time and triggers actions when
performance degrades below acceptable thresholds.
"""
from typing import Dict, Any, List, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, func, desc
from datetime import datetime, timedelta, timezone
import structlog
import uuid
from app.models.validation_run import ValidationRun
from app.models.predictions import ModelPerformanceMetric
from app.models.forecasts import Forecast
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class PerformanceMonitoringService:
"""Service for monitoring forecast performance and triggering improvements"""
# Configurable thresholds
MAPE_WARNING_THRESHOLD = 20.0 # Warning if MAPE > 20%
MAPE_CRITICAL_THRESHOLD = 30.0 # Critical if MAPE > 30%
MAPE_TREND_THRESHOLD = 5.0 # Alert if MAPE increases by > 5% over period
MIN_SAMPLES_FOR_ALERT = 5 # Minimum validations before alerting
TREND_LOOKBACK_DAYS = 30 # Days to analyze for trends
def __init__(self, db_session: AsyncSession):
self.db = db_session
async def get_accuracy_summary(
self,
tenant_id: uuid.UUID,
days: int = 30
) -> Dict[str, Any]:
"""
Get accuracy summary for recent period
Args:
tenant_id: Tenant identifier
days: Number of days to analyze
Returns:
Summary with overall metrics and trends
"""
try:
start_date = datetime.now(timezone.utc) - timedelta(days=days)
# Get recent validation runs
query = (
select(ValidationRun)
.where(
and_(
ValidationRun.tenant_id == tenant_id,
ValidationRun.status == "completed",
ValidationRun.started_at >= start_date,
ValidationRun.forecasts_with_actuals > 0 # Only runs with actual data
)
)
.order_by(desc(ValidationRun.started_at))
)
result = await self.db.execute(query)
runs = result.scalars().all()
if not runs:
return {
"status": "no_data",
"message": f"No validation runs found in last {days} days",
"period_days": days
}
# Calculate summary statistics
total_forecasts = sum(r.total_forecasts_evaluated for r in runs)
total_with_actuals = sum(r.forecasts_with_actuals for r in runs)
mape_values = [r.overall_mape for r in runs if r.overall_mape is not None]
mae_values = [r.overall_mae for r in runs if r.overall_mae is not None]
rmse_values = [r.overall_rmse for r in runs if r.overall_rmse is not None]
avg_mape = sum(mape_values) / len(mape_values) if mape_values else None
avg_mae = sum(mae_values) / len(mae_values) if mae_values else None
avg_rmse = sum(rmse_values) / len(rmse_values) if rmse_values else None
# Determine health status
health_status = "healthy"
if avg_mape and avg_mape > self.MAPE_CRITICAL_THRESHOLD:
health_status = "critical"
elif avg_mape and avg_mape > self.MAPE_WARNING_THRESHOLD:
health_status = "warning"
return {
"status": "ok",
"period_days": days,
"validation_runs": len(runs),
"total_forecasts_evaluated": total_forecasts,
"total_forecasts_with_actuals": total_with_actuals,
"coverage_percentage": round(
(total_with_actuals / total_forecasts * 100) if total_forecasts > 0 else 0, 2
),
"average_metrics": {
"mape": round(avg_mape, 2) if avg_mape else None,
"mae": round(avg_mae, 2) if avg_mae else None,
"rmse": round(avg_rmse, 2) if avg_rmse else None,
"accuracy_percentage": round(100 - avg_mape, 2) if avg_mape else None
},
"health_status": health_status,
"thresholds": {
"warning": self.MAPE_WARNING_THRESHOLD,
"critical": self.MAPE_CRITICAL_THRESHOLD
}
}
except Exception as e:
logger.error(
"Failed to get accuracy summary",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to get accuracy summary: {str(e)}")
async def detect_performance_degradation(
self,
tenant_id: uuid.UUID,
lookback_days: int = 30
) -> Dict[str, Any]:
"""
Detect if forecast performance is degrading over time
Args:
tenant_id: Tenant identifier
lookback_days: Days to analyze for trends
Returns:
Degradation analysis with recommendations
"""
try:
logger.info(
"Detecting performance degradation",
tenant_id=tenant_id,
lookback_days=lookback_days
)
start_date = datetime.now(timezone.utc) - timedelta(days=lookback_days)
# Get validation runs ordered by time
query = (
select(ValidationRun)
.where(
and_(
ValidationRun.tenant_id == tenant_id,
ValidationRun.status == "completed",
ValidationRun.started_at >= start_date,
ValidationRun.forecasts_with_actuals > 0
)
)
.order_by(ValidationRun.started_at)
)
result = await self.db.execute(query)
runs = list(result.scalars().all())
if len(runs) < self.MIN_SAMPLES_FOR_ALERT:
return {
"status": "insufficient_data",
"message": f"Need at least {self.MIN_SAMPLES_FOR_ALERT} validation runs",
"runs_found": len(runs)
}
# Split into first half and second half
midpoint = len(runs) // 2
first_half = runs[:midpoint]
second_half = runs[midpoint:]
# Calculate average MAPE for each half (skip runs without a MAPE value and guard against empty halves)
first_half_values = [r.overall_mape for r in first_half if r.overall_mape is not None]
second_half_values = [r.overall_mape for r in second_half if r.overall_mape is not None]
if not first_half_values or not second_half_values:
return {
"status": "insufficient_data",
"message": "Not enough validation runs with MAPE values to compare trend periods",
"runs_found": len(runs)
}
first_half_mape = sum(first_half_values) / len(first_half_values)
second_half_mape = sum(second_half_values) / len(second_half_values)
mape_change = second_half_mape - first_half_mape
mape_change_percentage = (mape_change / first_half_mape * 100) if first_half_mape > 0 else 0
# Determine if degradation is significant
is_degrading = mape_change > self.MAPE_TREND_THRESHOLD
severity = "none"
if is_degrading:
if mape_change > self.MAPE_TREND_THRESHOLD * 2:
severity = "high"
elif mape_change > self.MAPE_TREND_THRESHOLD:
severity = "medium"
# Get products with worst performance
poor_products = await self._identify_poor_performers(tenant_id, lookback_days)
result = {
"status": "analyzed",
"period_days": lookback_days,
"samples_analyzed": len(runs),
"is_degrading": is_degrading,
"severity": severity,
"metrics": {
"first_period_mape": round(first_half_mape, 2),
"second_period_mape": round(second_half_mape, 2),
"mape_change": round(mape_change, 2),
"mape_change_percentage": round(mape_change_percentage, 2)
},
"poor_performers": poor_products,
"recommendations": []
}
# Add recommendations
if is_degrading:
result["recommendations"].append({
"action": "retrain_models",
"priority": "high" if severity == "high" else "medium",
"reason": f"MAPE increased by {abs(mape_change):.1f}% over {lookback_days} days"
})
if poor_products:
result["recommendations"].append({
"action": "retrain_poor_performers",
"priority": "high",
"reason": f"{len(poor_products)} products with MAPE > {self.MAPE_CRITICAL_THRESHOLD}%",
"products": poor_products[:10] # Top 10 worst
})
logger.info(
"Performance degradation analysis complete",
tenant_id=tenant_id,
is_degrading=is_degrading,
severity=severity,
poor_performers=len(poor_products)
)
return result
except Exception as e:
logger.error(
"Failed to detect performance degradation",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to detect degradation: {str(e)}")
async def _identify_poor_performers(
self,
tenant_id: uuid.UUID,
lookback_days: int
) -> List[Dict[str, Any]]:
"""Identify products/locations with poor accuracy"""
try:
start_date = datetime.now(timezone.utc) - timedelta(days=lookback_days)
# Get recent performance metrics grouped by product
query = select(
ModelPerformanceMetric.inventory_product_id,
func.avg(ModelPerformanceMetric.mape).label('avg_mape'),
func.avg(ModelPerformanceMetric.mae).label('avg_mae'),
func.count(ModelPerformanceMetric.id).label('sample_count')
).where(
and_(
ModelPerformanceMetric.tenant_id == tenant_id,
ModelPerformanceMetric.created_at >= start_date,
ModelPerformanceMetric.mape.isnot(None)
)
).group_by(
ModelPerformanceMetric.inventory_product_id
).having(
func.avg(ModelPerformanceMetric.mape) > self.MAPE_CRITICAL_THRESHOLD
).order_by(
desc(func.avg(ModelPerformanceMetric.mape))
).limit(20)
result = await self.db.execute(query)
poor_performers = []
for row in result.fetchall():
poor_performers.append({
"inventory_product_id": str(row.inventory_product_id),
"avg_mape": round(row.avg_mape, 2),
"avg_mae": round(row.avg_mae, 2),
"sample_count": row.sample_count,
"requires_retraining": True
})
return poor_performers
except Exception as e:
logger.error("Failed to identify poor performers", error=str(e))
return []
async def check_model_age(
self,
tenant_id: uuid.UUID,
max_age_days: int = 30
) -> Dict[str, Any]:
"""
Check if models are outdated and need retraining
Args:
tenant_id: Tenant identifier
max_age_days: Maximum acceptable model age
Returns:
Analysis of model ages
"""
try:
# Get distinct models used in recent forecasts
cutoff_date = datetime.now(timezone.utc) - timedelta(days=7)
query = select(
Forecast.model_id,
Forecast.model_version,
Forecast.inventory_product_id,
func.max(Forecast.created_at).label('last_used'),
func.count(Forecast.id).label('forecast_count')
).where(
and_(
Forecast.tenant_id == tenant_id,
Forecast.created_at >= cutoff_date
)
).group_by(
Forecast.model_id,
Forecast.model_version,
Forecast.inventory_product_id
)
result = await self.db.execute(query)
models_info = []
outdated_count = 0
for row in result.fetchall():
# Check age against training service (would need to query training service)
# For now, assume models older than max_age_days need retraining
models_info.append({
"model_id": row.model_id,
"model_version": row.model_version,
"inventory_product_id": str(row.inventory_product_id),
"last_used": row.last_used.isoformat(),
"forecast_count": row.forecast_count
})
return {
"status": "analyzed",
"models_in_use": len(models_info),
"outdated_models": outdated_count,
"max_age_days": max_age_days,
"models": models_info[:20] # Top 20
}
except Exception as e:
logger.error("Failed to check model age", error=str(e))
raise DatabaseError(f"Failed to check model age: {str(e)}")
async def generate_performance_report(
self,
tenant_id: uuid.UUID,
days: int = 30
) -> Dict[str, Any]:
"""
Generate comprehensive performance report
Args:
tenant_id: Tenant identifier
days: Analysis period
Returns:
Complete performance report with recommendations
"""
try:
logger.info(
"Generating performance report",
tenant_id=tenant_id,
days=days
)
# Get all analyses
summary = await self.get_accuracy_summary(tenant_id, days)
degradation = await self.detect_performance_degradation(tenant_id, days)
model_age = await self.check_model_age(tenant_id)
# Compile recommendations
all_recommendations = []
if summary.get("health_status") == "critical":
all_recommendations.append({
"priority": "critical",
"action": "immediate_review",
"reason": f"Overall MAPE is {summary['average_metrics']['mape']}%",
"details": "Forecast accuracy is critically low"
})
all_recommendations.extend(degradation.get("recommendations", []))
report = {
"generated_at": datetime.now(timezone.utc).isoformat(),
"tenant_id": str(tenant_id),
"analysis_period_days": days,
"summary": summary,
"degradation_analysis": degradation,
"model_age_analysis": model_age,
"recommendations": all_recommendations,
"requires_action": len(all_recommendations) > 0
}
logger.info(
"Performance report generated",
tenant_id=tenant_id,
health_status=summary.get("health_status"),
recommendations_count=len(all_recommendations)
)
return report
except Exception as e:
logger.error(
"Failed to generate performance report",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to generate report: {str(e)}")

View File

@@ -0,0 +1,99 @@
"""
POI Feature Service for Forecasting
Fetches POI features for use in demand forecasting predictions.
Ensures feature consistency between training and prediction.
"""
from typing import Dict, Any, Optional
import structlog
from shared.clients.external_client import ExternalServiceClient
logger = structlog.get_logger()
class POIFeatureService:
"""
POI feature service for forecasting.
Fetches POI context from External service to ensure
prediction uses the same features as training.
"""
def __init__(self, external_client: ExternalServiceClient = None):
"""
Initialize POI feature service.
Args:
external_client: External service client instance (optional)
"""
if external_client is None:
from app.core.config import settings
self.external_client = ExternalServiceClient(settings, "forecasting-service")
else:
self.external_client = external_client
async def get_poi_features(
self,
tenant_id: str
) -> Dict[str, Any]:
"""
Get POI features for tenant.
Args:
tenant_id: Tenant UUID
Returns:
Dictionary with POI features or empty dict if not available
"""
try:
result = await self.external_client.get_poi_context(tenant_id)
if result:
poi_context = result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
"POI features retrieved for forecasting",
tenant_id=tenant_id,
feature_count=len(ml_features)
)
return ml_features
else:
logger.warning(
"No POI context found for tenant",
tenant_id=tenant_id
)
return {}
except Exception as e:
logger.error(
"Failed to fetch POI features for forecasting",
tenant_id=tenant_id,
error=str(e),
exc_info=True
)
return {}
async def check_poi_service_health(self) -> bool:
"""
Check if POI service is accessible through the external client.
Returns:
True if service is healthy, False otherwise
"""
try:
# Test the external service health by attempting to get POI context for a dummy tenant
# This will go through the proper authentication and routing
dummy_context = await self.external_client.get_poi_context("test-tenant")
# If we can successfully make a request (even if it returns None for missing tenant),
# it means the service is accessible
return True
except Exception as e:
logger.error(
"POI service health check failed",
error=str(e)
)
return False
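# ----------------------------------------------------------------
# Illustrative wiring sketch (not part of the original module). The stub
# client is an assumption that only mimics the get_poi_context coroutine and
# its {"poi_context": {"ml_features": {...}}} response shape.
# ----------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    class _StubExternalClient:
        async def get_poi_context(self, tenant_id):
            return {"poi_context": {"ml_features": {"poi_density": 0.42, "near_transit": 1}}}

    async def _demo():
        service = POIFeatureService(external_client=_StubExternalClient())
        features = await service.get_poi_features("11111111-1111-1111-1111-111111111111")
        print(features)  # {'poi_density': 0.42, 'near_transit': 1}

    asyncio.run(_demo())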

File diff suppressed because it is too large

View File

@@ -0,0 +1,486 @@
# ================================================================
# services/forecasting/app/services/retraining_trigger_service.py
# ================================================================
"""
Retraining Trigger Service
Automatically triggers model retraining based on performance metrics,
accuracy degradation, or data availability.
"""
from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime, timezone
import structlog
import uuid
from app.services.performance_monitoring_service import PerformanceMonitoringService
from shared.clients.training_client import TrainingServiceClient
from shared.config.base import BaseServiceSettings
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class RetrainingTriggerService:
"""Service for triggering automatic model retraining"""
def __init__(self, db_session: AsyncSession):
self.db = db_session
self.performance_service = PerformanceMonitoringService(db_session)
# Initialize training client
config = BaseServiceSettings()
self.training_client = TrainingServiceClient(config, calling_service_name="forecasting")
async def evaluate_and_trigger_retraining(
self,
tenant_id: uuid.UUID,
auto_trigger: bool = True
) -> Dict[str, Any]:
"""
Evaluate performance and trigger retraining if needed
Args:
tenant_id: Tenant identifier
auto_trigger: Whether to automatically trigger retraining
Returns:
Evaluation results and retraining actions taken
"""
try:
logger.info(
"Evaluating retraining needs",
tenant_id=tenant_id,
auto_trigger=auto_trigger
)
# Generate performance report
report = await self.performance_service.generate_performance_report(
tenant_id=tenant_id,
days=30
)
if not report.get("requires_action"):
logger.info(
"No retraining required",
tenant_id=tenant_id,
health_status=report["summary"].get("health_status")
)
return {
"status": "no_action_needed",
"tenant_id": str(tenant_id),
"health_status": report["summary"].get("health_status"),
"report": report
}
# Extract products that need retraining
products_to_retrain = []
recommendations = report.get("recommendations", [])
for rec in recommendations:
if rec.get("action") == "retrain_poor_performers":
products_to_retrain.extend(rec.get("products", []))
if not products_to_retrain and auto_trigger:
# If degradation detected but no specific products, consider retraining all
degradation = report.get("degradation_analysis", {})
if degradation.get("is_degrading") and degradation.get("severity") in ["high", "medium"]:
logger.info(
"General degradation detected, considering full retraining",
tenant_id=tenant_id,
severity=degradation.get("severity")
)
retraining_results = []
if auto_trigger and products_to_retrain:
# Trigger retraining for poor performers
for product in products_to_retrain:
try:
result = await self._trigger_product_retraining(
tenant_id=tenant_id,
inventory_product_id=uuid.UUID(product["inventory_product_id"]),
reason=f"MAPE {product['avg_mape']}% exceeds threshold",
priority="high"
)
retraining_results.append(result)
except Exception as e:
logger.error(
"Failed to trigger retraining for product",
product_id=product["inventory_product_id"],
error=str(e)
)
retraining_results.append({
"product_id": product["inventory_product_id"],
"status": "failed",
"error": str(e)
})
logger.info(
"Retraining evaluation complete",
tenant_id=tenant_id,
products_evaluated=len(products_to_retrain),
retraining_triggered=len(retraining_results)
)
return {
"status": "evaluated",
"tenant_id": str(tenant_id),
"requires_action": report.get("requires_action"),
"products_needing_retraining": len(products_to_retrain),
"retraining_triggered": len(retraining_results),
"auto_trigger_enabled": auto_trigger,
"retraining_results": retraining_results,
"performance_report": report
}
except Exception as e:
logger.error(
"Failed to evaluate and trigger retraining",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to evaluate retraining: {str(e)}")
async def _trigger_product_retraining(
self,
tenant_id: uuid.UUID,
inventory_product_id: uuid.UUID,
reason: str,
priority: str = "normal"
) -> Dict[str, Any]:
"""
Trigger retraining for a specific product
Args:
tenant_id: Tenant identifier
inventory_product_id: Product to retrain
reason: Reason for retraining
priority: Priority level (low, normal, high)
Returns:
Retraining trigger result
"""
try:
logger.info(
"Triggering product retraining",
tenant_id=tenant_id,
product_id=inventory_product_id,
reason=reason,
priority=priority
)
# Call training service to trigger retraining
result = await self.training_client.trigger_retrain(
tenant_id=str(tenant_id),
inventory_product_id=str(inventory_product_id),
reason=reason,
priority=priority
)
if result:
logger.info(
"Retraining triggered successfully",
tenant_id=tenant_id,
product_id=inventory_product_id,
training_job_id=result.get("training_job_id")
)
return {
"status": "triggered",
"product_id": str(inventory_product_id),
"training_job_id": result.get("training_job_id"),
"reason": reason,
"priority": priority,
"triggered_at": datetime.now(timezone.utc).isoformat()
}
else:
logger.warning(
"Retraining trigger returned no result",
tenant_id=tenant_id,
product_id=inventory_product_id
)
return {
"status": "no_response",
"product_id": str(inventory_product_id),
"reason": reason
}
except Exception as e:
logger.error(
"Failed to trigger product retraining",
tenant_id=tenant_id,
product_id=inventory_product_id,
error=str(e)
)
return {
"status": "failed",
"product_id": str(inventory_product_id),
"error": str(e),
"reason": reason
}
async def trigger_bulk_retraining(
self,
tenant_id: uuid.UUID,
product_ids: List[uuid.UUID],
reason: str = "Bulk retraining requested"
) -> Dict[str, Any]:
"""
Trigger retraining for multiple products
Args:
tenant_id: Tenant identifier
product_ids: List of products to retrain
reason: Reason for bulk retraining
Returns:
Bulk retraining results
"""
try:
logger.info(
"Triggering bulk retraining",
tenant_id=tenant_id,
product_count=len(product_ids)
)
results = []
for product_id in product_ids:
result = await self._trigger_product_retraining(
tenant_id=tenant_id,
inventory_product_id=product_id,
reason=reason,
priority="normal"
)
results.append(result)
successful = sum(1 for r in results if r["status"] == "triggered")
logger.info(
"Bulk retraining completed",
tenant_id=tenant_id,
total=len(product_ids),
successful=successful,
failed=len(product_ids) - successful
)
return {
"status": "completed",
"tenant_id": str(tenant_id),
"total_products": len(product_ids),
"successful": successful,
"failed": len(product_ids) - successful,
"results": results
}
except Exception as e:
logger.error(
"Bulk retraining failed",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Bulk retraining failed: {str(e)}")
async def check_and_trigger_scheduled_retraining(
self,
tenant_id: uuid.UUID,
max_model_age_days: int = 30
) -> Dict[str, Any]:
"""
Check model ages and trigger retraining for outdated models
Args:
tenant_id: Tenant identifier
max_model_age_days: Maximum acceptable model age
Returns:
Scheduled retraining results
"""
try:
logger.info(
"Checking for scheduled retraining needs",
tenant_id=tenant_id,
max_model_age_days=max_model_age_days
)
# Get model age analysis
model_age_analysis = await self.performance_service.check_model_age(
tenant_id=tenant_id,
max_age_days=max_model_age_days
)
outdated_count = model_age_analysis.get("outdated_models", 0)
if outdated_count == 0:
logger.info(
"No outdated models found",
tenant_id=tenant_id
)
return {
"status": "no_action_needed",
"tenant_id": str(tenant_id),
"outdated_models": 0
}
# Trigger retraining for outdated models
try:
from shared.clients.training_client import TrainingServiceClient
from shared.config.base import get_settings
from shared.messaging import get_rabbitmq_client
config = get_settings()
training_client = TrainingServiceClient(config, "forecasting")
# Get list of models that need retraining
outdated_models = await training_client.get_outdated_models(
tenant_id=str(tenant_id),
max_age_days=max_model_age_days,
min_accuracy=0.85, # Configurable threshold
min_new_data_points=1000 # Configurable threshold
)
if not outdated_models:
logger.info("No specific models returned for retraining", tenant_id=tenant_id)
return {
"status": "no_models_found",
"tenant_id": str(tenant_id),
"outdated_models": outdated_count
}
# Publish retraining events to RabbitMQ for each model
rabbitmq_client = get_rabbitmq_client()
triggered_models = []
if rabbitmq_client:
for model in outdated_models:
try:
retraining_event = {
"event_id": str(uuid.uuid4()),
"event_type": "training.retrain.requested",
"timestamp": datetime.now(timezone.utc).isoformat(),
"tenant_id": str(tenant_id),
"data": {
"model_id": model.get('id'),
"product_id": model.get('product_id'),
"model_type": model.get('model_type'),
"current_accuracy": model.get('accuracy'),
"model_age_days": model.get('age_days'),
"new_data_points": model.get('new_data_points', 0),
"trigger_reason": model.get('trigger_reason', 'scheduled_check'),
"priority": model.get('priority', 'normal'),
"requested_by": "system_scheduled_check"
}
}
await rabbitmq_client.publish_event(
exchange_name="training.events",
routing_key="training.retrain.requested",
event_data=retraining_event
)
triggered_models.append({
'model_id': model.get('id'),
'product_id': model.get('product_id'),
'event_id': retraining_event['event_id']
})
logger.info(
"Published retraining request",
model_id=model.get('id'),
product_id=model.get('product_id'),
event_id=retraining_event['event_id'],
trigger_reason=model.get('trigger_reason')
)
except Exception as publish_error:
logger.error(
"Failed to publish retraining event",
model_id=model.get('id'),
error=str(publish_error)
)
# Continue with other models even if one fails
else:
logger.warning(
"RabbitMQ client not available, cannot trigger retraining",
tenant_id=tenant_id
)
return {
"status": "retraining_triggered",
"tenant_id": str(tenant_id),
"outdated_models": outdated_count,
"triggered_count": len(triggered_models),
"triggered_models": triggered_models,
"message": f"Triggered retraining for {len(triggered_models)} models"
}
except Exception as trigger_error:
logger.error(
"Failed to trigger retraining",
tenant_id=tenant_id,
error=str(trigger_error),
exc_info=True
)
# Return analysis result even if triggering failed
return {
"status": "trigger_failed",
"tenant_id": str(tenant_id),
"outdated_models": outdated_count,
"error": str(trigger_error),
"message": "Analysis complete but failed to trigger retraining"
}
except Exception as e:
logger.error(
"Scheduled retraining check failed",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Scheduled retraining check failed: {str(e)}")
async def get_retraining_recommendations(
self,
tenant_id: uuid.UUID
) -> Dict[str, Any]:
"""
Get retraining recommendations without triggering
Args:
tenant_id: Tenant identifier
Returns:
Recommendations for manual review
"""
try:
# Evaluate without auto-triggering
result = await self.evaluate_and_trigger_retraining(
tenant_id=tenant_id,
auto_trigger=False
)
# Extract just the recommendations
report = result.get("performance_report", {})
recommendations = report.get("recommendations", [])
return {
"tenant_id": str(tenant_id),
"generated_at": datetime.now(timezone.utc).isoformat(),
"requires_action": result.get("requires_action", False),
"recommendations": recommendations,
"summary": report.get("summary", {}),
"degradation_detected": report.get("degradation_analysis", {}).get("is_degrading", False)
}
except Exception as e:
logger.error(
"Failed to get retraining recommendations",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to get recommendations: {str(e)}")

View File

@@ -0,0 +1,97 @@
# ================================================================
# services/forecasting/app/services/sales_client.py
# ================================================================
"""
Sales Client for Forecasting Service
Wrapper around shared sales client with forecasting-specific methods
"""
from typing import List, Dict, Any, Optional
from datetime import datetime
import structlog
import uuid
from shared.clients.sales_client import SalesServiceClient
from shared.config.base import BaseServiceSettings
logger = structlog.get_logger()
class SalesClient:
"""Client for fetching sales data from sales service"""
def __init__(self):
"""Initialize sales client"""
# Load configuration
config = BaseServiceSettings()
self.client = SalesServiceClient(config, calling_service_name="forecasting")
async def get_sales_by_date_range(
self,
tenant_id: uuid.UUID,
start_date: datetime,
end_date: datetime,
product_id: Optional[uuid.UUID] = None
) -> List[Dict[str, Any]]:
"""
Get sales data for a date range
Args:
tenant_id: Tenant identifier
start_date: Start of date range
end_date: End of date range
product_id: Optional product filter
Returns:
List of sales records
"""
try:
# Convert datetime to ISO format strings
start_date_str = start_date.isoformat() if start_date else None
end_date_str = end_date.isoformat() if end_date else None
product_id_str = str(product_id) if product_id else None
# Use the paginated method to get all sales data
sales_data = await self.client.get_all_sales_data(
tenant_id=str(tenant_id),
start_date=start_date_str,
end_date=end_date_str,
product_id=product_id_str,
aggregation="none", # Get raw data without aggregation
page_size=1000,
max_pages=100
)
if not sales_data:
logger.info(
"No sales data found for date range",
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat()
)
return []
logger.info(
"Retrieved sales data",
tenant_id=tenant_id,
records_count=len(sales_data),
start_date=start_date.isoformat(),
end_date=end_date.isoformat()
)
return sales_data
except Exception as e:
logger.error(
"Failed to fetch sales data",
tenant_id=tenant_id,
error=str(e),
error_type=type(e).__name__
)
# Return empty list instead of raising to allow validation to continue
return []
async def close(self):
"""Close the client connection"""
if hasattr(self.client, 'close'):
await self.client.close()
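# --- Usage sketch (illustrative, not part of the original module) ------------
# Fetch the trailing seven days of raw sales for one tenant; the wrapper
# function itself is hypothetical.
async def fetch_last_week_sales(tenant_id: uuid.UUID) -> List[Dict[str, Any]]:
    from datetime import timedelta, timezone  # local import keeps the sketch self-contained
    client = SalesClient()
    try:
        end_date = datetime.now(timezone.utc)
        start_date = end_date - timedelta(days=7)
        return await client.get_sales_by_date_range(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date
        )
    finally:
        await client.close()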

View File

@@ -0,0 +1,240 @@
# services/forecasting/app/services/tenant_deletion_service.py
"""
Tenant Data Deletion Service for Forecasting Service
Handles deletion of all forecasting-related data for a tenant
"""
from typing import Dict
from sqlalchemy import select, func, delete
from sqlalchemy.ext.asyncio import AsyncSession
import structlog
from shared.services.tenant_deletion import (
BaseTenantDataDeletionService,
TenantDataDeletionResult
)
from app.models import (
Forecast,
PredictionBatch,
ModelPerformanceMetric,
PredictionCache,
AuditLog
)
logger = structlog.get_logger(__name__)
class ForecastingTenantDeletionService(BaseTenantDataDeletionService):
"""Service for deleting all forecasting-related data for a tenant"""
def __init__(self, db: AsyncSession):
self.db = db
self.service_name = "forecasting"
async def get_tenant_data_preview(self, tenant_id: str) -> Dict[str, int]:
"""
Get counts of what would be deleted for a tenant (dry-run)
Args:
tenant_id: The tenant ID to preview deletion for
Returns:
Dictionary with entity names and their counts
"""
logger.info("forecasting.tenant_deletion.preview", tenant_id=tenant_id)
preview = {}
try:
# Count forecasts
forecast_count = await self.db.scalar(
select(func.count(Forecast.id)).where(
Forecast.tenant_id == tenant_id
)
)
preview["forecasts"] = forecast_count or 0
# Count prediction batches
batch_count = await self.db.scalar(
select(func.count(PredictionBatch.id)).where(
PredictionBatch.tenant_id == tenant_id
)
)
preview["prediction_batches"] = batch_count or 0
# Count model performance metrics
metric_count = await self.db.scalar(
select(func.count(ModelPerformanceMetric.id)).where(
ModelPerformanceMetric.tenant_id == tenant_id
)
)
preview["model_performance_metrics"] = metric_count or 0
# Count prediction cache entries
cache_count = await self.db.scalar(
select(func.count(PredictionCache.id)).where(
PredictionCache.tenant_id == tenant_id
)
)
preview["prediction_cache"] = cache_count or 0
# Count audit logs
audit_count = await self.db.scalar(
select(func.count(AuditLog.id)).where(
AuditLog.tenant_id == tenant_id
)
)
preview["audit_logs"] = audit_count or 0
logger.info(
"forecasting.tenant_deletion.preview_complete",
tenant_id=tenant_id,
preview=preview
)
except Exception as e:
logger.error(
"forecasting.tenant_deletion.preview_error",
tenant_id=tenant_id,
error=str(e),
exc_info=True
)
raise
return preview
async def delete_tenant_data(self, tenant_id: str) -> TenantDataDeletionResult:
"""
Permanently delete all forecasting data for a tenant
Deletion order:
1. PredictionCache (independent)
2. ModelPerformanceMetric (independent)
3. PredictionBatch (independent)
4. Forecast (independent)
5. AuditLog (independent)
Note: All tables are independent with no foreign key relationships
Args:
tenant_id: The tenant ID to delete data for
Returns:
TenantDataDeletionResult with deletion counts and any errors
"""
logger.info("forecasting.tenant_deletion.started", tenant_id=tenant_id)
result = TenantDataDeletionResult(tenant_id=tenant_id, service_name=self.service_name)
try:
# Step 1: Delete prediction cache
logger.info("forecasting.tenant_deletion.deleting_cache", tenant_id=tenant_id)
cache_result = await self.db.execute(
delete(PredictionCache).where(
PredictionCache.tenant_id == tenant_id
)
)
result.deleted_counts["prediction_cache"] = cache_result.rowcount
logger.info(
"forecasting.tenant_deletion.cache_deleted",
tenant_id=tenant_id,
count=cache_result.rowcount
)
# Step 2: Delete model performance metrics
logger.info("forecasting.tenant_deletion.deleting_metrics", tenant_id=tenant_id)
metrics_result = await self.db.execute(
delete(ModelPerformanceMetric).where(
ModelPerformanceMetric.tenant_id == tenant_id
)
)
result.deleted_counts["model_performance_metrics"] = metrics_result.rowcount
logger.info(
"forecasting.tenant_deletion.metrics_deleted",
tenant_id=tenant_id,
count=metrics_result.rowcount
)
# Step 3: Delete prediction batches
logger.info("forecasting.tenant_deletion.deleting_batches", tenant_id=tenant_id)
batches_result = await self.db.execute(
delete(PredictionBatch).where(
PredictionBatch.tenant_id == tenant_id
)
)
result.deleted_counts["prediction_batches"] = batches_result.rowcount
logger.info(
"forecasting.tenant_deletion.batches_deleted",
tenant_id=tenant_id,
count=batches_result.rowcount
)
# Step 4: Delete forecasts
logger.info("forecasting.tenant_deletion.deleting_forecasts", tenant_id=tenant_id)
forecasts_result = await self.db.execute(
delete(Forecast).where(
Forecast.tenant_id == tenant_id
)
)
result.deleted_counts["forecasts"] = forecasts_result.rowcount
logger.info(
"forecasting.tenant_deletion.forecasts_deleted",
tenant_id=tenant_id,
count=forecasts_result.rowcount
)
# Step 5: Delete audit logs
logger.info("forecasting.tenant_deletion.deleting_audit_logs", tenant_id=tenant_id)
audit_result = await self.db.execute(
delete(AuditLog).where(
AuditLog.tenant_id == tenant_id
)
)
result.deleted_counts["audit_logs"] = audit_result.rowcount
logger.info(
"forecasting.tenant_deletion.audit_logs_deleted",
tenant_id=tenant_id,
count=audit_result.rowcount
)
# Commit the transaction
await self.db.commit()
# Calculate total deleted
total_deleted = sum(result.deleted_counts.values())
logger.info(
"forecasting.tenant_deletion.completed",
tenant_id=tenant_id,
total_deleted=total_deleted,
breakdown=result.deleted_counts
)
result.success = True
except Exception as e:
await self.db.rollback()
error_msg = f"Failed to delete forecasting data for tenant {tenant_id}: {str(e)}"
logger.error(
"forecasting.tenant_deletion.failed",
tenant_id=tenant_id,
error=str(e),
exc_info=True
)
result.errors.append(error_msg)
result.success = False
return result
def get_forecasting_tenant_deletion_service(
db: AsyncSession
) -> ForecastingTenantDeletionService:
"""
Factory function to create ForecastingTenantDeletionService instance
Args:
db: AsyncSession database session
Returns:
ForecastingTenantDeletionService instance
"""
return ForecastingTenantDeletionService(db)
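# --- Usage sketch (illustrative, not part of the original module) ------------
# Preview-then-delete flow an offboarding endpoint might follow; the wrapper
# function itself is hypothetical.
async def purge_tenant_forecasting_data(db: AsyncSession, tenant_id: str) -> TenantDataDeletionResult:
    service = get_forecasting_tenant_deletion_service(db)
    preview = await service.get_tenant_data_preview(tenant_id)
    logger.info("forecasting.tenant_deletion.preview_before_delete", tenant_id=tenant_id, preview=preview)
    return await service.delete_tenant_data(tenant_id)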

View File

@@ -0,0 +1,586 @@
# ================================================================
# services/forecasting/app/services/validation_service.py
# ================================================================
"""
Forecast Validation Service
Compares historical forecasts with actual sales data to:
1. Calculate accuracy metrics (MAE, MAPE, RMSE, R², accuracy percentage)
2. Store performance metrics in the database
3. Track validation runs for audit purposes
4. Enable continuous model improvement
"""
from typing import Dict, Any, List, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, func, Date
from datetime import datetime, timedelta, timezone
import structlog
import math
import uuid
from app.models.forecasts import Forecast
from app.models.predictions import ModelPerformanceMetric
from app.models.validation_run import ValidationRun
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class ValidationService:
"""Service for validating forecasts against actual sales data"""
def __init__(self, db_session: AsyncSession):
self.db = db_session
async def validate_date_range(
self,
tenant_id: uuid.UUID,
start_date: datetime,
end_date: datetime,
orchestration_run_id: Optional[uuid.UUID] = None,
triggered_by: str = "manual"
) -> Dict[str, Any]:
"""
Validate forecasts against actual sales for a date range
Args:
tenant_id: Tenant identifier
start_date: Start of validation period
end_date: End of validation period
orchestration_run_id: Optional link to orchestration run
triggered_by: How this validation was triggered (manual, orchestrator, scheduled)
Returns:
Dictionary with validation results and metrics
"""
validation_run = None
try:
# Create validation run record
validation_run = ValidationRun(
tenant_id=tenant_id,
orchestration_run_id=orchestration_run_id,
validation_start_date=start_date,
validation_end_date=end_date,
status="running",
triggered_by=triggered_by,
execution_mode="batch"
)
self.db.add(validation_run)
await self.db.flush()
logger.info(
"Starting forecast validation",
validation_run_id=validation_run.id,
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat()
)
# Fetch forecasts with matching sales data
forecasts_with_sales = await self._fetch_forecasts_with_sales(
tenant_id, start_date, end_date
)
if not forecasts_with_sales:
logger.warning(
"No forecasts with matching sales data found",
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat()
)
validation_run.status = "completed"
validation_run.completed_at = datetime.now(timezone.utc)
validation_run.duration_seconds = (
validation_run.completed_at - validation_run.started_at
).total_seconds()
await self.db.commit()
return {
"validation_run_id": str(validation_run.id),
"status": "completed",
"message": "No forecasts with matching sales data found",
"forecasts_evaluated": 0,
"metrics_created": 0
}
# Calculate metrics and create performance records
metrics_results = await self._calculate_and_store_metrics(
forecasts_with_sales, validation_run.id
)
# Update validation run with results
validation_run.total_forecasts_evaluated = metrics_results["total_evaluated"]
validation_run.forecasts_with_actuals = metrics_results["with_actuals"]
validation_run.forecasts_without_actuals = metrics_results["without_actuals"]
validation_run.overall_mae = metrics_results["overall_mae"]
validation_run.overall_mape = metrics_results["overall_mape"]
validation_run.overall_rmse = metrics_results["overall_rmse"]
validation_run.overall_r2_score = metrics_results["overall_r2_score"]
validation_run.overall_accuracy_percentage = metrics_results["overall_accuracy_percentage"]
validation_run.total_predicted_demand = metrics_results["total_predicted"]
validation_run.total_actual_demand = metrics_results["total_actual"]
validation_run.metrics_by_product = metrics_results["metrics_by_product"]
validation_run.metrics_by_location = metrics_results["metrics_by_location"]
validation_run.metrics_records_created = metrics_results["metrics_created"]
validation_run.status = "completed"
validation_run.completed_at = datetime.now(timezone.utc)
validation_run.duration_seconds = (
validation_run.completed_at - validation_run.started_at
).total_seconds()
await self.db.commit()
logger.info(
"Forecast validation completed successfully",
validation_run_id=validation_run.id,
forecasts_evaluated=validation_run.total_forecasts_evaluated,
metrics_created=validation_run.metrics_records_created,
overall_mape=validation_run.overall_mape,
duration_seconds=validation_run.duration_seconds
)
# Extract poor accuracy products (MAPE > 30%)
poor_accuracy_products = []
if validation_run.metrics_by_product:
for product_id, product_metrics in validation_run.metrics_by_product.items():
if product_metrics.get("mape", 0) > 30:
poor_accuracy_products.append({
"product_id": product_id,
"mape": product_metrics.get("mape"),
"mae": product_metrics.get("mae"),
"accuracy_percentage": product_metrics.get("accuracy_percentage")
})
return {
"validation_run_id": str(validation_run.id),
"status": "completed",
"forecasts_evaluated": validation_run.total_forecasts_evaluated,
"forecasts_with_actuals": validation_run.forecasts_with_actuals,
"forecasts_without_actuals": validation_run.forecasts_without_actuals,
"metrics_created": validation_run.metrics_records_created,
"overall_metrics": {
"mae": validation_run.overall_mae,
"mape": validation_run.overall_mape,
"rmse": validation_run.overall_rmse,
"r2_score": validation_run.overall_r2_score,
"accuracy_percentage": validation_run.overall_accuracy_percentage
},
"total_predicted_demand": validation_run.total_predicted_demand,
"total_actual_demand": validation_run.total_actual_demand,
"duration_seconds": validation_run.duration_seconds,
"poor_accuracy_products": poor_accuracy_products,
"metrics_by_product": validation_run.metrics_by_product,
"metrics_by_location": validation_run.metrics_by_location
}
except Exception as e:
logger.error(
"Forecast validation failed",
tenant_id=tenant_id,
error=str(e),
error_type=type(e).__name__
)
if validation_run:
validation_run.status = "failed"
validation_run.error_message = str(e)
validation_run.error_details = {"error_type": type(e).__name__}
validation_run.completed_at = datetime.now(timezone.utc)
validation_run.duration_seconds = (
validation_run.completed_at - validation_run.started_at
).total_seconds()
await self.db.commit()
raise DatabaseError(f"Forecast validation failed: {str(e)}")
async def validate_yesterday(
self,
tenant_id: uuid.UUID,
orchestration_run_id: Optional[uuid.UUID] = None,
triggered_by: str = "orchestrator"
) -> Dict[str, Any]:
"""
Convenience method to validate yesterday's forecasts
Args:
tenant_id: Tenant identifier
orchestration_run_id: Optional link to orchestration run
triggered_by: How this validation was triggered
Returns:
Dictionary with validation results
"""
yesterday = datetime.now(timezone.utc) - timedelta(days=1)
start_date = yesterday.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = yesterday.replace(hour=23, minute=59, second=59, microsecond=999999)
return await self.validate_date_range(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
orchestration_run_id=orchestration_run_id,
triggered_by=triggered_by
)
async def _fetch_forecasts_with_sales(
self,
tenant_id: uuid.UUID,
start_date: datetime,
end_date: datetime
) -> List[Dict[str, Any]]:
"""
Fetch forecasts with their corresponding actual sales data
Returns list of dictionaries containing forecast and sales data
"""
try:
# Import here to avoid circular dependency
from app.services.sales_client import SalesClient
# Query to get all forecasts in the date range
query = select(Forecast).where(
and_(
Forecast.tenant_id == tenant_id,
func.cast(Forecast.forecast_date, Date) >= start_date.date(),
func.cast(Forecast.forecast_date, Date) <= end_date.date()
)
).order_by(Forecast.forecast_date, Forecast.inventory_product_id)
result = await self.db.execute(query)
forecasts = result.scalars().all()
if not forecasts:
return []
# Fetch actual sales data from sales service
sales_client = SalesClient()
sales_data = await sales_client.get_sales_by_date_range(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date
)
# Create lookup dict: (product_id, date) -> sales quantity
sales_lookup = {}
for sale in sales_data:
sale_date = sale['date']
if isinstance(sale_date, str):
sale_date = datetime.fromisoformat(sale_date.replace('Z', '+00:00'))
key = (str(sale['inventory_product_id']), sale_date.date())
# Sum quantities if multiple sales records for same product/date
if key in sales_lookup:
sales_lookup[key]['quantity_sold'] += sale['quantity_sold']
else:
sales_lookup[key] = sale
# Match forecasts with sales data
forecasts_with_sales = []
for forecast in forecasts:
forecast_date = forecast.forecast_date.date() if hasattr(forecast.forecast_date, 'date') else forecast.forecast_date
key = (str(forecast.inventory_product_id), forecast_date)
sales_record = sales_lookup.get(key)
forecasts_with_sales.append({
"forecast_id": forecast.id,
"tenant_id": forecast.tenant_id,
"inventory_product_id": forecast.inventory_product_id,
"product_name": forecast.product_name,
"location": forecast.location,
"forecast_date": forecast.forecast_date,
"predicted_demand": forecast.predicted_demand,
"confidence_lower": forecast.confidence_lower,
"confidence_upper": forecast.confidence_upper,
"model_id": forecast.model_id,
"model_version": forecast.model_version,
"algorithm": forecast.algorithm,
"created_at": forecast.created_at,
"actual_sales": sales_record['quantity_sold'] if sales_record else None,
"has_actual_data": sales_record is not None
})
return forecasts_with_sales
except Exception as e:
logger.error(
"Failed to fetch forecasts with sales data",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to fetch forecast and sales data: {str(e)}")
async def _calculate_and_store_metrics(
self,
forecasts_with_sales: List[Dict[str, Any]],
validation_run_id: uuid.UUID
) -> Dict[str, Any]:
"""
Calculate accuracy metrics and store them in the database
Returns summary of metrics calculated
"""
# Separate forecasts with and without actual data
forecasts_with_actuals = [f for f in forecasts_with_sales if f["has_actual_data"]]
forecasts_without_actuals = [f for f in forecasts_with_sales if not f["has_actual_data"]]
if not forecasts_with_actuals:
return {
"total_evaluated": len(forecasts_with_sales),
"with_actuals": 0,
"without_actuals": len(forecasts_without_actuals),
"metrics_created": 0,
"overall_mae": None,
"overall_mape": None,
"overall_rmse": None,
"overall_r2_score": None,
"overall_accuracy_percentage": None,
"total_predicted": 0.0,
"total_actual": 0.0,
"metrics_by_product": {},
"metrics_by_location": {}
}
# Calculate individual metrics and prepare for bulk insert
performance_metrics = []
errors = []
for forecast in forecasts_with_actuals:
predicted = forecast["predicted_demand"]
actual = forecast["actual_sales"]
# Calculate error
error = abs(predicted - actual)
errors.append(error)
# Calculate percentage error (avoid division by zero)
percentage_error = (error / actual * 100) if actual > 0 else 0.0
# Calculate individual metrics
mae = error
mape = percentage_error
rmse = error  # for a single observation (sample_size=1), RMSE equals the absolute error
# Calculate accuracy percentage (100% - MAPE, capped at 0)
accuracy_percentage = max(0.0, 100.0 - mape)
# Create performance metric record
metric = ModelPerformanceMetric(
model_id=uuid.UUID(forecast["model_id"]) if isinstance(forecast["model_id"], str) else forecast["model_id"],
tenant_id=forecast["tenant_id"],
inventory_product_id=forecast["inventory_product_id"],
mae=mae,
mape=mape,
rmse=rmse,
accuracy_score=accuracy_percentage / 100.0, # Store as 0-1 scale
evaluation_date=forecast["forecast_date"],
evaluation_period_start=forecast["forecast_date"],
evaluation_period_end=forecast["forecast_date"],
sample_size=1
)
performance_metrics.append(metric)
# Bulk insert all performance metrics
if performance_metrics:
self.db.add_all(performance_metrics)
await self.db.flush()
# Calculate overall metrics
overall_metrics = self._calculate_overall_metrics(forecasts_with_actuals)
# Calculate metrics by product
metrics_by_product = self._calculate_metrics_by_dimension(
forecasts_with_actuals, "inventory_product_id"
)
# Calculate metrics by location
metrics_by_location = self._calculate_metrics_by_dimension(
forecasts_with_actuals, "location"
)
return {
"total_evaluated": len(forecasts_with_sales),
"with_actuals": len(forecasts_with_actuals),
"without_actuals": len(forecasts_without_actuals),
"metrics_created": len(performance_metrics),
"overall_mae": overall_metrics["mae"],
"overall_mape": overall_metrics["mape"],
"overall_rmse": overall_metrics["rmse"],
"overall_r2_score": overall_metrics["r2_score"],
"overall_accuracy_percentage": overall_metrics["accuracy_percentage"],
"total_predicted": overall_metrics["total_predicted"],
"total_actual": overall_metrics["total_actual"],
"metrics_by_product": metrics_by_product,
"metrics_by_location": metrics_by_location
}
def _calculate_overall_metrics(self, forecasts: List[Dict[str, Any]]) -> Dict[str, float]:
"""Calculate aggregated metrics across all forecasts"""
if not forecasts:
return {
"mae": None, "mape": None, "rmse": None,
"r2_score": None, "accuracy_percentage": None,
"total_predicted": 0.0, "total_actual": 0.0
}
predicted_values = [f["predicted_demand"] for f in forecasts]
actual_values = [f["actual_sales"] for f in forecasts]
# MAE: Mean Absolute Error
mae = sum(abs(p - a) for p, a in zip(predicted_values, actual_values)) / len(forecasts)
# MAPE: Mean Absolute Percentage Error (handle division by zero)
mape_values = [
abs(p - a) / a * 100 if a > 0 else 0.0
for p, a in zip(predicted_values, actual_values)
]
mape = sum(mape_values) / len(mape_values)
# RMSE: Root Mean Square Error
squared_errors = [(p - a) ** 2 for p, a in zip(predicted_values, actual_values)]
rmse = math.sqrt(sum(squared_errors) / len(squared_errors))
# R² Score (coefficient of determination)
mean_actual = sum(actual_values) / len(actual_values)
ss_total = sum((a - mean_actual) ** 2 for a in actual_values)
ss_residual = sum((a - p) ** 2 for a, p in zip(actual_values, predicted_values))
r2_score = 1 - (ss_residual / ss_total) if ss_total > 0 else 0.0
# Accuracy percentage (100% - MAPE)
accuracy_percentage = max(0.0, 100.0 - mape)
return {
"mae": round(mae, 2),
"mape": round(mape, 2),
"rmse": round(rmse, 2),
"r2_score": round(r2_score, 4),
"accuracy_percentage": round(accuracy_percentage, 2),
"total_predicted": round(sum(predicted_values), 2),
"total_actual": round(sum(actual_values), 2)
}
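# Worked example (comment only): with predictions [10, 20] and actuals [8, 25],
# MAE = (|10-8| + |20-25|) / 2 = 3.5,
# MAPE = (2/8 * 100 + 5/25 * 100) / 2 = 22.5%,
# RMSE = sqrt((4 + 25) / 2) ≈ 3.81,
# accuracy_percentage = 100 - 22.5 = 77.5.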
def _calculate_metrics_by_dimension(
self,
forecasts: List[Dict[str, Any]],
dimension_key: str
) -> Dict[str, Dict[str, float]]:
"""Calculate metrics grouped by a dimension (product_id or location)"""
dimension_groups = {}
# Group forecasts by dimension
for forecast in forecasts:
key = str(forecast[dimension_key])
if key not in dimension_groups:
dimension_groups[key] = []
dimension_groups[key].append(forecast)
# Calculate metrics for each group
metrics_by_dimension = {}
for key, group_forecasts in dimension_groups.items():
metrics = self._calculate_overall_metrics(group_forecasts)
metrics_by_dimension[key] = {
"count": len(group_forecasts),
"mae": metrics["mae"],
"mape": metrics["mape"],
"rmse": metrics["rmse"],
"accuracy_percentage": metrics["accuracy_percentage"],
"total_predicted": metrics["total_predicted"],
"total_actual": metrics["total_actual"]
}
return metrics_by_dimension
async def get_validation_run(self, validation_run_id: uuid.UUID) -> Optional[ValidationRun]:
"""Get a validation run by ID"""
try:
query = select(ValidationRun).where(ValidationRun.id == validation_run_id)
result = await self.db.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error("Failed to get validation run", validation_run_id=validation_run_id, error=str(e))
raise DatabaseError(f"Failed to get validation run: {str(e)}")
async def get_validation_runs_by_tenant(
self,
tenant_id: uuid.UUID,
limit: int = 50,
skip: int = 0
) -> List[ValidationRun]:
"""Get validation runs for a tenant"""
try:
query = (
select(ValidationRun)
.where(ValidationRun.tenant_id == tenant_id)
.order_by(ValidationRun.created_at.desc())
.limit(limit)
.offset(skip)
)
result = await self.db.execute(query)
return result.scalars().all()
except Exception as e:
logger.error("Failed to get validation runs", tenant_id=tenant_id, error=str(e))
raise DatabaseError(f"Failed to get validation runs: {str(e)}")
async def get_accuracy_trends(
self,
tenant_id: uuid.UUID,
days: int = 30
) -> Dict[str, Any]:
"""Get accuracy trends over time"""
try:
start_date = datetime.now(timezone.utc) - timedelta(days=days)
query = (
select(ValidationRun)
.where(
and_(
ValidationRun.tenant_id == tenant_id,
ValidationRun.status == "completed",
ValidationRun.created_at >= start_date
)
)
.order_by(ValidationRun.created_at)
)
result = await self.db.execute(query)
runs = result.scalars().all()
if not runs:
return {
"period_days": days,
"total_runs": 0,
"trends": []
}
trends = [
{
"date": run.validation_start_date.isoformat(),
"mae": run.overall_mae,
"mape": run.overall_mape,
"rmse": run.overall_rmse,
"accuracy_percentage": run.overall_accuracy_percentage,
"forecasts_evaluated": run.total_forecasts_evaluated,
"forecasts_with_actuals": run.forecasts_with_actuals
}
for run in runs
]
# Calculate averages
valid_runs = [r for r in runs if r.overall_mape is not None]
avg_mape = sum(r.overall_mape for r in valid_runs) / len(valid_runs) if valid_runs else None
avg_accuracy = sum(r.overall_accuracy_percentage for r in valid_runs) / len(valid_runs) if valid_runs else None
return {
"period_days": days,
"total_runs": len(runs),
"average_mape": round(avg_mape, 2) if avg_mape else None,
"average_accuracy": round(avg_accuracy, 2) if avg_accuracy else None,
"trends": trends
}
except Exception as e:
logger.error("Failed to get accuracy trends", tenant_id=tenant_id, error=str(e))
raise DatabaseError(f"Failed to get accuracy trends: {str(e)}")

View File

@@ -0,0 +1,3 @@
"""
Utility modules for forecasting service
"""

View File

@@ -0,0 +1,258 @@
"""
Distributed Locking Mechanisms for Forecasting Service
Prevents concurrent forecast generation for the same product/date
"""
import asyncio
import hashlib
import time
from typing import Optional
import logging
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timezone, timedelta
logger = logging.getLogger(__name__)
class LockAcquisitionError(Exception):
"""Raised when lock cannot be acquired"""
pass
class DatabaseLock:
"""
Database-based distributed lock using PostgreSQL advisory locks.
Works across multiple service instances.
"""
def __init__(self, lock_name: str, timeout: float = 30.0):
"""
Initialize database lock.
Args:
lock_name: Unique identifier for the lock
timeout: Maximum seconds to wait for lock acquisition
"""
self.lock_name = lock_name
self.timeout = timeout
self.lock_id = self._hash_lock_name(lock_name)
def _hash_lock_name(self, name: str) -> int:
"""Convert lock name to a stable integer ID for a PostgreSQL advisory lock"""
# Python's built-in hash() is salted per process (PYTHONHASHSEED), so it can
# differ across service instances; use a deterministic digest instead.
return int.from_bytes(hashlib.sha256(name.encode("utf-8")).digest()[:4], "big") % (2**31)
@asynccontextmanager
async def acquire(self, session: AsyncSession):
"""
Acquire distributed lock as async context manager.
Args:
session: Database session for lock operations
Raises:
LockAcquisitionError: If lock cannot be acquired within timeout
"""
acquired = False
start_time = time.time()
try:
# Try to acquire lock with timeout
while time.time() - start_time < self.timeout:
# Try non-blocking lock acquisition
result = await session.execute(
text("SELECT pg_try_advisory_lock(:lock_id)"),
{"lock_id": self.lock_id}
)
acquired = result.scalar()
if acquired:
logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
break
# Wait a bit before retrying
await asyncio.sleep(0.1)
if not acquired:
raise LockAcquisitionError(
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
)
yield
finally:
if acquired:
# Release lock
await session.execute(
text("SELECT pg_advisory_unlock(:lock_id)"),
{"lock_id": self.lock_id}
)
logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")
class SimpleDatabaseLock:
"""
Simple table-based distributed lock.
Alternative to advisory locks, uses a dedicated locks table.
"""
def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
"""
Initialize simple database lock.
Args:
lock_name: Unique identifier for the lock
timeout: Maximum seconds to wait for lock acquisition
ttl: Time-to-live for stale lock cleanup (seconds)
"""
self.lock_name = lock_name
self.timeout = timeout
self.ttl = ttl
async def _ensure_lock_table(self, session: AsyncSession):
"""Ensure locks table exists"""
create_table_sql = """
CREATE TABLE IF NOT EXISTS distributed_locks (
lock_name VARCHAR(255) PRIMARY KEY,
acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
acquired_by VARCHAR(255),
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
)
"""
await session.execute(text(create_table_sql))
await session.commit()
async def _cleanup_stale_locks(self, session: AsyncSession):
"""Remove expired locks"""
cleanup_sql = """
DELETE FROM distributed_locks
WHERE expires_at < :now
"""
await session.execute(
text(cleanup_sql),
{"now": datetime.now(timezone.utc)}
)
await session.commit()
@asynccontextmanager
async def acquire(self, session: AsyncSession, owner: str = "forecasting-service"):
"""
Acquire simple database lock.
Args:
session: Database session
owner: Identifier for lock owner
Raises:
LockAcquisitionError: If lock cannot be acquired
"""
await self._ensure_lock_table(session)
await self._cleanup_stale_locks(session)
acquired = False
start_time = time.time()
try:
# Try to acquire lock
while time.time() - start_time < self.timeout:
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=self.ttl)
try:
# Try to insert lock record
insert_sql = """
INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
ON CONFLICT (lock_name) DO NOTHING
RETURNING lock_name
"""
result = await session.execute(
text(insert_sql),
{
"lock_name": self.lock_name,
"acquired_at": now,
"acquired_by": owner,
"expires_at": expires_at
}
)
await session.commit()
if result.rowcount > 0:
acquired = True
logger.info(f"Acquired simple lock: {self.lock_name}")
break
except Exception as e:
logger.debug(f"Lock acquisition attempt failed: {e}")
await session.rollback()
# Wait before retrying
await asyncio.sleep(0.5)
if not acquired:
raise LockAcquisitionError(
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
)
yield
finally:
if acquired:
# Release lock
delete_sql = """
DELETE FROM distributed_locks
WHERE lock_name = :lock_name
"""
await session.execute(
text(delete_sql),
{"lock_name": self.lock_name}
)
await session.commit()
logger.info(f"Released simple lock: {self.lock_name}")
def get_forecast_lock(
tenant_id: str,
product_id: str,
forecast_date: str,
use_advisory: bool = True
) -> DatabaseLock:
"""
Get distributed lock for generating a forecast for a specific product and date.
Args:
tenant_id: Tenant identifier
product_id: Product identifier
forecast_date: Forecast date (ISO format)
use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)
Returns:
Lock instance
"""
lock_name = f"forecast:{tenant_id}:{product_id}:{forecast_date}"
if use_advisory:
return DatabaseLock(lock_name, timeout=30.0)
else:
return SimpleDatabaseLock(lock_name, timeout=30.0, ttl=300.0)
def get_batch_forecast_lock(tenant_id: str, use_advisory: bool = True) -> DatabaseLock:
"""
Get distributed lock for batch forecast generation for a tenant.
Args:
tenant_id: Tenant identifier
use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)
Returns:
Lock instance
"""
lock_name = f"forecast_batch:{tenant_id}"
if use_advisory:
return DatabaseLock(lock_name, timeout=60.0)
else:
return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)
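# --- Usage sketch (illustrative, not part of the original module) ------------
# Guarding single-product forecast generation with the advisory-lock variant;
# the wrapper function and the inner "generate" step are hypothetical.
async def generate_forecast_exclusively(
    session: AsyncSession,
    tenant_id: str,
    product_id: str,
    forecast_date: str
) -> bool:
    lock = get_forecast_lock(tenant_id, product_id, forecast_date, use_advisory=True)
    try:
        async with lock.acquire(session):
            # Only one instance reaches this point for a given product/date
            ...  # call the actual forecast generation here
            return True
    except LockAcquisitionError:
        logger.warning("Forecast for %s/%s on %s is already being generated", tenant_id, product_id, forecast_date)
        return False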